diff --git a/.agents/skills/debug-openshell-cluster/SKILL.md b/.agents/skills/debug-openshell-cluster/SKILL.md index 68ecc7749..7ef6f16cb 100644 --- a/.agents/skills/debug-openshell-cluster/SKILL.md +++ b/.agents/skills/debug-openshell-cluster/SKILL.md @@ -143,6 +143,15 @@ release. Look for failed installs, unexpected values, missing namespace, wrong image tag, TLS settings that do not match the registered endpoint, and scheduling failures. +When no external credential driver is enabled, the Helm chart uses the +gateway's default encrypted database credential storage. The chart creates a +retained Kubernetes Secret for the shared key-encryption key (KEK) and injects it into +gateway pods; the gateway stores encrypted credential envelopes in the OpenShell database. For +`workload.kind=deployment` or multi-replica gateways, confirm +`server.externalDbSecret` points at a shared database. A render/install error +mentioning `server.credentialDrivers` means the values selected multiple +external credential backends. + For HA or PostgreSQL-backed installs, also check the external database Secret referenced by `server.externalDbSecret` and the PostgreSQL workload if the test or operator deployed one in-cluster: diff --git a/.github/workflows/branch-e2e.yml b/.github/workflows/branch-e2e.yml index 23d7d1cf6..6b12d9a09 100644 --- a/.github/workflows/branch-e2e.yml +++ b/.github/workflows/branch-e2e.yml @@ -24,6 +24,7 @@ jobs: run_core_e2e: ${{ steps.labels.outputs.run_core_e2e }} run_gpu_e2e: ${{ steps.labels.outputs.run_gpu_e2e }} run_kubernetes_ha_e2e: ${{ steps.labels.outputs.run_kubernetes_ha_e2e }} + run_kubernetes_credential_drivers_e2e: ${{ steps.labels.outputs.run_kubernetes_credential_drivers_e2e }} run_any_e2e: ${{ steps.labels.outputs.run_any_e2e }} steps: - uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 @@ -41,12 +42,14 @@ jobs: run_core_e2e=true run_gpu_e2e=true run_kubernetes_ha_e2e=true + run_kubernetes_credential_drivers_e2e=true else run_core_e2e="$(jq -r 
'index("test:e2e") != null' <<< "$LABELS_JSON")" run_gpu_e2e="$(jq -r 'index("test:e2e-gpu") != null' <<< "$LABELS_JSON")" run_kubernetes_ha_e2e="$(jq -r 'index("test:e2e-kubernetes") != null' <<< "$LABELS_JSON")" + run_kubernetes_credential_drivers_e2e="$(jq -r 'index("test:e2e-kubernetes") != null' <<< "$LABELS_JSON")" fi - if [ "$run_core_e2e" = "true" ] || [ "$run_gpu_e2e" = "true" ] || [ "$run_kubernetes_ha_e2e" = "true" ]; then + if [ "$run_core_e2e" = "true" ] || [ "$run_gpu_e2e" = "true" ] || [ "$run_kubernetes_ha_e2e" = "true" ] || [ "$run_kubernetes_credential_drivers_e2e" = "true" ]; then run_any_e2e=true else run_any_e2e=false @@ -55,12 +58,13 @@ jobs: echo "run_core_e2e=$run_core_e2e" echo "run_gpu_e2e=$run_gpu_e2e" echo "run_kubernetes_ha_e2e=$run_kubernetes_ha_e2e" + echo "run_kubernetes_credential_drivers_e2e=$run_kubernetes_credential_drivers_e2e" echo "run_any_e2e=$run_any_e2e" } >> "$GITHUB_OUTPUT" build-gateway: needs: [pr_metadata] - if: needs.pr_metadata.outputs.should_run == 'true' && (needs.pr_metadata.outputs.run_core_e2e == 'true' || needs.pr_metadata.outputs.run_kubernetes_ha_e2e == 'true') + if: needs.pr_metadata.outputs.should_run == 'true' && (needs.pr_metadata.outputs.run_core_e2e == 'true' || needs.pr_metadata.outputs.run_kubernetes_ha_e2e == 'true' || needs.pr_metadata.outputs.run_kubernetes_credential_drivers_e2e == 'true') permissions: contents: read packages: write @@ -135,6 +139,18 @@ jobs: extra-helm-values: deploy/helm/openshell/ci/values-high-availability.yaml external-postgres-secret: openshell-ha-pg + kubernetes-credential-drivers-e2e: + needs: [pr_metadata, build-gateway, build-supervisor] + if: needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_kubernetes_credential_drivers_e2e == 'true' + permissions: + contents: read + packages: read + uses: ./.github/workflows/e2e-kubernetes-test.yml + with: + image-tag: ${{ github.sha }} + job-name: Kubernetes Credential Drivers E2E + e2e-task: 
e2e:kubernetes:credential-drivers + core-e2e-result: name: Core E2E result needs: [pr_metadata, build-gateway, build-supervisor, e2e, kubernetes-e2e] @@ -215,3 +231,30 @@ jobs: fi done exit "$failed" + + kubernetes-credential-drivers-e2e-result: + name: Kubernetes Credential Drivers E2E result + needs: [pr_metadata, build-gateway, build-supervisor, kubernetes-credential-drivers-e2e] + if: always() && needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_kubernetes_credential_drivers_e2e == 'true' + runs-on: ubuntu-latest + steps: + - name: Verify Kubernetes credential drivers E2E jobs + env: + BUILD_GATEWAY_RESULT: ${{ needs.build-gateway.result }} + BUILD_SUPERVISOR_RESULT: ${{ needs.build-supervisor.result }} + KUBERNETES_CREDENTIAL_DRIVERS_E2E_RESULT: ${{ needs.kubernetes-credential-drivers-e2e.result }} + run: | + set -euo pipefail + failed=0 + for item in \ + "build-gateway:$BUILD_GATEWAY_RESULT" \ + "build-supervisor:$BUILD_SUPERVISOR_RESULT" \ + "kubernetes-credential-drivers-e2e:$KUBERNETES_CREDENTIAL_DRIVERS_E2E_RESULT"; do + name="${item%%:*}" + result="${item#*:}" + if [ "$result" != "success" ]; then + echo "::error::$name concluded $result" + failed=1 + fi + done + exit "$failed" diff --git a/.github/workflows/e2e-kubernetes-test.yml b/.github/workflows/e2e-kubernetes-test.yml index 5ff375922..9dc526758 100644 --- a/.github/workflows/e2e-kubernetes-test.yml +++ b/.github/workflows/e2e-kubernetes-test.yml @@ -32,6 +32,11 @@ on: required: false type: string default: "" + e2e-task: + description: "mise task to run for the Kubernetes e2e job" + required: false + type: string + default: "e2e:kubernetes" mise-version: description: "mise version to install on the bare Kubernetes e2e runner" required: false @@ -112,11 +117,12 @@ jobs: kind load image-archive "$archive" --name "$KIND_CLUSTER_NAME" done - - name: Run Kubernetes E2E (Rust smoke) + - name: Run Kubernetes E2E env: OPENSHELL_E2E_KUBE_CONTEXT: kind-${{ env.KIND_CLUSTER_NAME 
}} OPENSHELL_E2E_KUBE_EXTRA_VALUES: ${{ inputs.extra-helm-values }} OPENSHELL_E2E_KUBE_EXTERNAL_POSTGRES_SECRET: ${{ inputs.external-postgres-secret }} IMAGE_TAG: ${{ inputs.image-tag }} OPENSHELL_REGISTRY: ghcr.io/nvidia/openshell - run: mise run --no-deps --skip-deps e2e:kubernetes + E2E_TASK: ${{ inputs.e2e-task }} + run: mise run --no-deps --skip-deps "$E2E_TASK" diff --git a/.github/workflows/e2e-label-help.yml b/.github/workflows/e2e-label-help.yml index 1190bcd3d..cbcbcb76c 100644 --- a/.github/workflows/e2e-label-help.yml +++ b/.github/workflows/e2e-label-help.yml @@ -51,7 +51,7 @@ jobs: status_summary="The matching required CI gate status on this PR will flip green automatically once the run finishes." ;; test:e2e-kubernetes) - suite_summary="Kubernetes HA E2E" + suite_summary="Kubernetes HA and credential-driver E2E" build_summary="gateway and supervisor images" status_summary="This is an optional proof-of-life suite; failures are visible in the workflow run but do not publish a required CI gate status." ;; diff --git a/AGENTS.md b/AGENTS.md index 53b4e8049..ace2f5d1b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -39,6 +39,8 @@ These pipelines connect skills into end-to-end workflows. 
Individual skill files | `crates/openshell-core/` | Shared core | Common types, configuration, error handling | | `crates/openshell-providers/` | Provider management | Credential provider backends | | `crates/openshell-tui/` | Terminal UI | Ratatui-based dashboard for monitoring | +| `crates/openshell-driver-kubernetes-secrets/` | Kubernetes Secrets credential driver | In-process `CredentialDriver` backend for OpenShell-managed K8s Secret storage | +| `crates/openshell-driver-vault/` | Vault credential driver | In-process `CredentialDriver` backend for Vault-compatible KV storage | | `crates/openshell-driver-kubernetes/` | Kubernetes compute driver | In-process `ComputeDriver` backend for K8s sandbox pods | | `crates/openshell-driver-docker/` | Docker compute driver | In-process `ComputeDriver` backend for local Docker sandbox containers | | `crates/openshell-driver-vm/` | VM compute driver | Standalone libkrun-backed `ComputeDriver` subprocess (embeds its own rootfs + runtime) | diff --git a/CI.md b/CI.md index d04668aaf..5123831bc 100644 --- a/CI.md +++ b/CI.md @@ -15,10 +15,11 @@ Three opt-in labels enable the long-running E2E suites: - `test:e2e` runs the standard E2E suite in `Branch E2E Checks` - `test:e2e-gpu` runs GPU E2E in `Branch E2E Checks` - `test:e2e-kubernetes` runs Kubernetes E2E with the HA Helm overlay - (`replicaCount: 2` and bundled PostgreSQL) in `Branch E2E Checks` + (`replicaCount: 2` and bundled PostgreSQL) and the credential-driver suite + (Kubernetes Secrets plus Vault) in `Branch E2E Checks` When multiple labels are present, `Branch E2E Checks` builds the shared gateway and supervisor images once and fans out all enabled suites in parallel. -The `OpenShell / E2E` and `OpenShell / GPU E2E` required statuses are evaluated from separate suite result jobs inside that workflow. `test:e2e-kubernetes` is optional while HA behavior is under active iteration: failures are visible in the workflow run but do not publish a required CI gate status. 
+The `OpenShell / E2E` and `OpenShell / GPU E2E` required statuses are evaluated from separate suite result jobs inside that workflow. `test:e2e-kubernetes` is optional while Kubernetes HA and credential-driver behavior are under active iteration: failures are visible in the workflow run but do not publish a required CI gate status. The GitHub ruleset should require the `OpenShell / ...` statuses published by `Required CI Gates`, not the push-triggered workflow jobs directly. @@ -110,7 +111,7 @@ The bot's full administrator documentation is internal to NVIDIA. The only comma | File | Role | |---|---| | `.github/workflows/branch-checks.yml` | Required non-E2E PR checks. Triggers on `push: pull-request/[0-9]+`. | -| `.github/workflows/branch-e2e.yml` | Opt-in standard, GPU, and Kubernetes HA E2E. Triggers on `push: pull-request/[0-9]+` and runs jobs selected by `test:e2e`, `test:e2e-gpu`, or `test:e2e-kubernetes`. | +| `.github/workflows/branch-e2e.yml` | Opt-in standard, GPU, Kubernetes HA, and Kubernetes credential-driver E2E. Triggers on `push: pull-request/[0-9]+` and runs jobs selected by `test:e2e`, `test:e2e-gpu`, or `test:e2e-kubernetes`. | | `.github/workflows/helm-lint.yml` | Helm chart validation. Triggers on `push: pull-request/[0-9]+` and skips lint jobs unless Helm inputs changed. | | `.github/actions/pr-gate/action.yml` | Composite action that resolves PR metadata and verifies the required label is set. | | `.github/actions/pr-merge-base/action.yml` | Composite action that resolves and fetches the merge-base commit for `pull-request/` push workflows. 
| diff --git a/Cargo.lock b/Cargo.lock index f693acd66..860e464a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3471,6 +3471,23 @@ dependencies = [ "url", ] +[[package]] +name = "openshell-driver-db-credstore" +version = "0.0.0" +dependencies = [ + "async-trait", + "base64 0.22.1", + "openshell-core", + "ring", + "serde", + "serde_json", + "sha2 0.10.9", + "tempfile", + "tokio", + "toml", + "tonic", +] + [[package]] name = "openshell-driver-docker" version = "0.0.0" @@ -3516,6 +3533,25 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "openshell-driver-kubernetes-secrets" +version = "0.0.0" +dependencies = [ + "clap", + "futures", + "k8s-openapi", + "kube", + "miette", + "openshell-core", + "serde", + "sha2 0.10.9", + "tokio", + "toml", + "tonic", + "tracing", + "tracing-subscriber", +] + [[package]] name = "openshell-driver-podman" version = "0.0.0" @@ -3541,6 +3577,27 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "openshell-driver-vault" +version = "0.0.0" +dependencies = [ + "clap", + "futures", + "miette", + "openshell-core", + "reqwest 0.12.28", + "serde", + "serde_json", + "sha2 0.10.9", + "tempfile", + "tokio", + "toml", + "tonic", + "tracing", + "tracing-subscriber", + "wiremock", +] + [[package]] name = "openshell-driver-vm" version = "0.0.0" @@ -3673,6 +3730,7 @@ dependencies = [ "arc-swap", "async-trait", "axum", + "base64 0.22.1", "bytes", "clap", "futures", @@ -3696,9 +3754,12 @@ dependencies = [ "notify", "openshell-bootstrap", "openshell-core", + "openshell-driver-db-credstore", "openshell-driver-docker", "openshell-driver-kubernetes", + "openshell-driver-kubernetes-secrets", "openshell-driver-podman", + "openshell-driver-vault", "openshell-ocsf", "openshell-policy", "openshell-prover", @@ -3712,6 +3773,7 @@ dependencies = [ "rand 0.9.4", "rcgen", "reqwest 0.12.28", + "ring", "russh", "rustix 1.1.4", "rustls", diff --git a/Cargo.toml b/Cargo.toml index 86025646a..57dd9b542 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ 
-88,6 +88,7 @@ sha2 = "0.10" rand = "0.9" jsonwebtoken = "9" getrandom = "0.3" +ring = "0.17" spiffe = { version = "0.15", default-features = false, features = ["workload-api-jwt", "tracing"] } # Filesystem embedding diff --git a/TESTING.md b/TESTING.md index d32dc385c..f23127e3c 100644 --- a/TESTING.md +++ b/TESTING.md @@ -150,6 +150,7 @@ Suites: - Common suite (`--features e2e`) - driver-neutral CLI behavior, sandbox lifecycle, sync, port forwarding, policy, and provider tests. - Docker suite (`--features e2e-docker`) - common suite plus Docker-only coverage such as Dockerfile image builds, Docker preflight checks, and managed Docker gateway resume. - Docker GPU suite (`--features e2e-docker-gpu`) - Docker suite plus GPU sandbox smoke coverage. +- Kubernetes credential-driver suite (`--features e2e-kubernetes-credential-drivers`) - targeted Kubernetes Secrets and Vault provider credential storage coverage. GPU device-selection tests compare OpenShell sandboxes against a plain Docker or Podman container that requests `--device nvidia.com/gpu=all`. The probe image @@ -173,6 +174,14 @@ Run the Podman-backed Rust CLI e2e suite: mise run e2e:podman ``` +Run the targeted Kubernetes credential-driver e2e suite. 
This deploys an +OpenBao fixture for the Vault-compatible driver path and validates Kubernetes +Secrets and Vault storage backends one at a time: + +```shell +mise run e2e:kubernetes:credential-drivers +``` + Run a single test directly with cargo: ```shell @@ -203,3 +212,4 @@ The harness (`e2e/rust/src/harness/`) provides: | `OPENSHELL_GATEWAY` | Override active gateway name for E2E tests | | `OPENSHELL_GATEWAY_ENDPOINT` | Run E2E tests against an existing plaintext HTTP gateway endpoint | | `OPENSHELL_E2E_DRIVER` | Driver name exported by the e2e gateway wrapper (`docker`, `podman`, or `vm`) | +| `OPENSHELL_E2E_CREDENTIAL_DRIVERS` | Enables the Kubernetes credential-driver fixture path in `e2e/with-kube-gateway.sh` | diff --git a/architecture/gateway.md b/architecture/gateway.md index 979422d7e..9ec55e009 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -159,9 +159,14 @@ default WAL journal mode), which mirror the same sensitive contents. Persisted state includes sandboxes, providers, provider credential refresh state, SSH sessions, policy revisions, settings, inference configuration, and deployment records. Provider refresh material is stored as a separate object -scoped to the provider instance through `objects.scope`; the provider record -keeps only the current injectable credential values and optional per-credential -expiry timestamps. +scoped to the provider instance through `objects.scope`. Provider records keep +inline credential values only for legacy records created before credential +driver storage. New provider writes keep driver-owned credential handles and +optional per-credential expiry timestamps. When no external credential driver +is configured, gateways use server-owned encrypted database credential storage +for defense in depth. Multi-replica deployments can use that default with a +shared database and shared key-encryption key, or opt into an external backend such as Vault +or Kubernetes Secrets. 
### Optimistic Concurrency (CAS) diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 1c3fd8a82..5e519acc5 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -3539,7 +3539,7 @@ fn format_provider_attachment_table(providers: &[Provider], color: bool) -> Stri for provider in providers { let provider_name = provider.object_name(); let provider_type = &provider.r#type; - let credential_keys = provider.credentials.len(); + let credential_keys = provider_credential_keys(provider).len(); let config_keys = provider.config.len(); let _ = writeln!( output, @@ -3814,6 +3814,7 @@ async fn auto_create_provider( credentials: discovered.credentials.clone(), config: discovered.config.clone(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }), }; @@ -3856,6 +3857,7 @@ async fn auto_create_provider( credentials: discovered.credentials.clone(), config: discovered.config.clone(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }), }; @@ -4493,7 +4495,7 @@ fn missing_credentials_error(provider_type: &str) -> miette::Report { "no credentials resolved for provider type '{provider_type}'. \ Set GOOGLE_VERTEX_AI_TOKEN, VERTEX_AI_TOKEN, \ GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN, or VERTEX_AI_SERVICE_ACCOUNT_TOKEN; \ - or use --from-gcloud-adc / --from-existing with those env vars set." + or use --from-gcloud-adc or --from-existing with those env vars set." ); } @@ -4507,8 +4509,8 @@ fn missing_credentials_error(provider_type: &str) -> miette::Report { miette::miette!( "no credentials resolved for provider type '{provider_type}'. \ - Use --credential KEY[=VALUE], --runtime-credentials for runtime-resolved profile credentials, \ - or --from-existing with the appropriate env vars set." + Use --credential KEY[=VALUE], --runtime-credentials for runtime-resolved profile credentials, or --from-existing \ + with the appropriate env vars set." 
) } @@ -4551,7 +4553,7 @@ pub async fn provider_create_with_options( ) -> Result<()> { if from_gcloud_adc && (from_existing || !credentials.is_empty() || runtime_credentials) { return Err(miette::miette!( - "--from-gcloud-adc cannot be combined with --from-existing or --credential; it also cannot be combined with --runtime-credentials" + "--from-gcloud-adc cannot be combined with --from-existing, --credential, or --runtime-credentials" )); } if from_existing && (!credentials.is_empty() || runtime_credentials) { @@ -4695,6 +4697,7 @@ pub async fn provider_create_with_options( credentials: credential_map, config: config_map, credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }), }) .await @@ -4780,7 +4783,7 @@ pub async fn provider_get(server: &str, name: &str, tls: &TlsOptions) -> Result< .provider .ok_or_else(|| miette::miette!("provider missing from response"))?; - let credential_keys = provider.credentials.keys().cloned().collect::>(); + let credential_keys = provider_credential_keys(&provider); let config_keys = provider.config.keys().cloned().collect::>(); println!("{}", "Provider:".cyan().bold()); @@ -4827,7 +4830,7 @@ fn provider_to_json(provider: &Provider) -> serde_json::Value { obj.insert("type".to_string(), serde_json::json!(provider.r#type)); // Credential keys (NEVER values - security) - let credential_keys: Vec = provider.credentials.keys().cloned().collect(); + let credential_keys = provider_credential_keys(provider); obj.insert( "credential_keys".to_string(), serde_json::json!(credential_keys), @@ -4869,6 +4872,18 @@ fn provider_to_json(provider: &Provider) -> serde_json::Value { serde_json::Value::Object(obj) } +fn provider_credential_keys(provider: &Provider) -> Vec { + let mut keys: Vec = provider + .credentials + .keys() + .chain(provider.credential_handles.keys()) + .cloned() + .collect(); + keys.sort(); + keys.dedup(); + keys +} + pub async fn provider_list( server: &str, limit: u32, @@ -5626,6 +5641,7 @@ pub async 
fn provider_update( credentials: credential_map, config: config_map, credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }), credential_expires_at_ms, }) @@ -7963,6 +7979,7 @@ mod tests { )) .collect(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }], false, ); @@ -9410,6 +9427,7 @@ mod tests { credentials: std::collections::HashMap::new(), config: std::collections::HashMap::new(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }; let json = super::provider_to_json(&provider); @@ -9431,6 +9449,7 @@ mod tests { credentials, config: std::collections::HashMap::new(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }; let json = super::provider_to_json(&provider); @@ -9468,6 +9487,7 @@ mod tests { credentials: std::collections::HashMap::new(), config, credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }; let json = super::provider_to_json(&provider); @@ -9498,6 +9518,7 @@ mod tests { credentials: std::collections::HashMap::new(), config: std::collections::HashMap::new(), // Empty config credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }; let json = super::provider_to_json(&provider); @@ -9527,6 +9548,7 @@ mod tests { credentials: std::collections::HashMap::new(), config: std::collections::HashMap::new(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }; let json = super::provider_to_json(&provider); @@ -9552,6 +9574,7 @@ mod tests { credentials: std::collections::HashMap::new(), config: std::collections::HashMap::new(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }; let 
json = super::provider_to_json(&provider); @@ -9581,6 +9604,7 @@ mod tests { credentials: std::collections::HashMap::new(), config: std::collections::HashMap::new(), credential_expires_at_ms, + credential_handles: std::collections::HashMap::new(), }; let json = super::provider_to_json(&provider); @@ -9606,6 +9630,7 @@ mod tests { credentials: std::collections::HashMap::new(), config: std::collections::HashMap::new(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }; let json = super::provider_to_json(&provider); diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index 7bf8612b4..4e3a63936 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -68,6 +68,7 @@ impl TestOpenShell { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ); } @@ -357,6 +358,11 @@ impl OpenShell for TestOpenShell { existing.credential_expires_at_ms, provider.credential_expires_at_ms, ), + credential_handles: if provider.credential_handles.is_empty() { + existing.credential_handles + } else { + provider.credential_handles + }, }; let updated_name = updated.object_name().to_string(); providers.insert(updated_name, updated.clone()); diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 5a6e53eb1..bc54a84ef 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -340,7 +340,7 @@ impl OpenShell for TestOpenShell { .into_inner() .provider .ok_or_else(|| Status::invalid_argument("provider is required"))?; - if provider.credentials.is_empty() { + if provider.credentials.is_empty() && 
provider.credential_handles.is_empty() { let bootstrap_allowed = if let Some(profile) = openshell_providers::get_default_profile(&provider.r#type) { profile.allows_empty_provider_credentials() @@ -610,6 +610,11 @@ impl OpenShell for TestOpenShell { existing.credential_expires_at_ms, provider.credential_expires_at_ms, ), + credential_handles: if provider.credential_handles.is_empty() { + existing.credential_handles + } else { + provider.credential_handles + }, }; let updated_name = updated.object_name().to_string(); providers.insert(updated_name, updated.clone()); @@ -1755,6 +1760,7 @@ async fn provider_update_from_existing_uses_profile_discovery_when_v2_enabled() credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ); let _env = EnvVarGuard::set(&[("CUSTOM_UPDATE_DISCOVERY_API_KEY", "updated-profile-secret")]); @@ -2005,6 +2011,41 @@ async fn provider_create_supports_generic_type_and_env_lookup_credentials() { ); } +#[tokio::test] +async fn provider_create_sends_inline_credentials() { + let ts = run_server().await; + + run::provider_create_with_options( + &ts.endpoint, + "openai-inline", + "openai", + false, + &["OPENAI_API_KEY=sk-test".to_string()], + false, + false, + &[], + &ts.tls, + ) + .await + .expect("provider create with inline credential"); + + let stored = ts.state.providers.lock().await; + assert_eq!( + stored + .get("openai-inline") + .and_then(|provider| provider.credentials.get("OPENAI_API_KEY")) + .map(String::as_str), + Some("sk-test") + ); + assert!( + stored + .get("openai-inline") + .expect("provider") + .credential_handles + .is_empty() + ); +} + #[tokio::test] async fn provider_create_rejects_combined_from_existing_and_credentials() { let ts = run_server().await; @@ -2048,7 +2089,7 @@ async fn provider_create_rejects_combined_from_gcloud_adc_and_from_existing() { assert!( err.to_string() - .contains("--from-gcloud-adc cannot be combined with --from-existing or 
--credential"), + .contains("--from-gcloud-adc cannot be combined with --from-existing, --credential"), "unexpected error: {err}" ); assert!(ts.state.providers.lock().await.is_empty()); @@ -2073,7 +2114,7 @@ async fn provider_create_rejects_combined_from_gcloud_adc_and_credentials() { assert!( err.to_string() - .contains("--from-gcloud-adc cannot be combined with --from-existing or --credential"), + .contains("--from-gcloud-adc cannot be combined with --from-existing, --credential"), "unexpected error: {err}" ); assert!(ts.state.providers.lock().await.is_empty()); diff --git a/crates/openshell-core/src/config.rs b/crates/openshell-core/src/config.rs index eaaf1e4a0..1eeea4bf6 100644 --- a/crates/openshell-core/src/config.rs +++ b/crates/openshell-core/src/config.rs @@ -360,6 +360,13 @@ pub struct Config { /// configured driver. pub compute_drivers: Vec, + /// Credential drivers enabled for provider credential storage. + pub credential_drivers: Vec, + + /// Optional credential-driver default retained for compatibility. When + /// set, it must match the single enabled credential driver. + pub default_credential_driver: Option, + /// TTL for SSH session tokens, in seconds. 0 disables expiry. pub ssh_session_ttl_secs: u64, @@ -559,6 +566,8 @@ impl Config { gateway_jwt: None, database_url: String::new(), compute_drivers: vec![], + credential_drivers: Vec::new(), + default_credential_driver: None, ssh_session_ttl_secs: default_ssh_session_ttl_secs(), grpc_rate_limit_requests: None, grpc_rate_limit_window_secs: None, @@ -622,6 +631,24 @@ impl Config { self } + /// Create a new configuration with the configured credential drivers. + #[must_use] + pub fn with_credential_drivers(mut self, drivers: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.credential_drivers = drivers.into_iter().map(Into::into).collect(); + self + } + + /// Create a new configuration with the default credential driver. 
+ #[must_use] + pub fn with_default_credential_driver(mut self, driver: Option>) -> Self { + self.default_credential_driver = driver.map(Into::into); + self + } + /// Create a new configuration with the SSH session TTL. #[must_use] pub const fn with_ssh_session_ttl_secs(mut self, secs: u64) -> Self { @@ -821,6 +848,29 @@ mod tests { assert!(!cfg.auth.allow_unauthenticated_users); } + #[test] + fn config_defaults_to_internal_credential_storage() { + let cfg = Config::new(None); + assert!(cfg.credential_drivers.is_empty()); + assert!(cfg.default_credential_driver.is_none()); + } + + #[test] + fn config_accepts_credential_driver_settings() { + let cfg = Config::new(None) + .with_credential_drivers(["kubernetes-secrets", "vault"]) + .with_default_credential_driver(Some("kubernetes-secrets")); + + assert_eq!( + cfg.credential_drivers, + vec!["kubernetes-secrets".to_string(), "vault".to_string()] + ); + assert_eq!( + cfg.default_credential_driver.as_deref(), + Some("kubernetes-secrets") + ); + } + #[test] fn gateway_jwt_ttl_defaults_to_non_expiring() { let cfg: GatewayJwtConfig = serde_json::from_value(serde_json::json!({ diff --git a/crates/openshell-core/src/proto/mod.rs b/crates/openshell-core/src/proto/mod.rs index 08b062d2e..a1801c617 100644 --- a/crates/openshell-core/src/proto/mod.rs +++ b/crates/openshell-core/src/proto/mod.rs @@ -55,6 +55,19 @@ pub mod compute { } } +#[allow( + clippy::all, + clippy::pedantic, + clippy::nursery, + unused_qualifications, + rust_2018_idioms +)] +pub mod credentials { + pub mod v1 { + include!(concat!(env!("OUT_DIR"), "/openshell.credentials.v1.rs")); + } +} + #[allow( clippy::all, clippy::pedantic, diff --git a/crates/openshell-driver-db-credstore/Cargo.toml b/crates/openshell-driver-db-credstore/Cargo.toml new file mode 100644 index 000000000..87b24733e --- /dev/null +++ b/crates/openshell-driver-db-credstore/Cargo.toml @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-driver-db-credstore" +description = "Encrypted database credential storage driver for OpenShell" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +openshell-core = { path = "../openshell-core", default-features = false } + +async-trait = "0.1" +base64 = { workspace = true } +ring = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +sha2 = { workspace = true } +toml = { workspace = true } +tonic = { workspace = true } + +[dev-dependencies] +tempfile = "3" +tokio = { workspace = true } + +[lints] +workspace = true diff --git a/crates/openshell-driver-db-credstore/src/lib.rs b/crates/openshell-driver-db-credstore/src/lib.rs new file mode 100644 index 000000000..86701f7b6 --- /dev/null +++ b/crates/openshell-driver-db-credstore/src/lib.rs @@ -0,0 +1,1203 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Encrypted database-backed credential storage driver. +//! +//! The driver persists encrypted credential envelopes through a caller-provided +//! object store. `openshell-server` supplies the object-store adapter for the +//! gateway database, while this crate owns the credential driver behavior and +//! envelope cryptography. 
+ +use std::collections::HashMap; +use std::fs::{self, OpenOptions}; +use std::io::Write; +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use async_trait::async_trait; +use base64::{ + Engine as _, + engine::general_purpose::{STANDARD as BASE64, STANDARD_NO_PAD as BASE64_NO_PAD}, +}; +use openshell_core::proto::CredentialHandle; +use openshell_core::proto::credentials::v1::{ + DeleteCredentialRequest, ResolveCredentialRequest, ResolvedCredential, StoreCredentialRequest, +}; +use openshell_core::{Error, Result as CoreResult}; +use ring::aead::{AES_256_GCM, Aad, LessSafeKey, Nonce, UnboundKey}; +use ring::rand::{SecureRandom, SystemRandom}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use tonic::Status; + +const HANDLE_VERSION: &str = "v1"; +const ENVELOPE_VERSION: u32 = 1; +const KEY_LEN: usize = 32; +const NONCE_LEN: usize = 12; +const HANDLE_ID_LEN: usize = 64; +const ALGORITHM: &str = "AES-256-GCM"; +const DEFAULT_KEY_ENCRYPTION_KEY_FILE: &str = "key-encryption-key.bin"; + +pub const DRIVER_NAME: &str = "openshell-driver-db-credstore"; +pub const OBJECT_TYPE: &str = "credential.gateway-encrypted"; + +#[derive(Debug, Clone)] +pub struct DbCredstoreCredentialDriver { + store: Arc, + crypto: EncryptedGatewayCredentialStoreCrypto, +} + +#[async_trait] +pub trait DbCredstoreObjectStore: std::fmt::Debug + Send + Sync { + async fn get_credential_object( + &self, + object_type: &str, + id: &str, + operation: &'static str, + ) -> Result, Status>; + + async fn put_credential_object( + &self, + write: CredentialObjectWrite, + operation: &'static str, + ) -> Result<(), Status>; + + async fn delete_credential_object( + &self, + object_type: &str, + id: &str, + expected_resource_version: u64, + operation: &'static str, + ) -> Result<(), Status>; +} + +#[derive(Debug, Clone)] +pub struct StoredCredentialObject { + pub object_type: String, + pub id: String, + pub payload: Vec, + pub 
resource_version: u64, +} + +#[derive(Debug, Clone)] +pub struct CredentialObjectWrite { + pub object_type: String, + pub id: String, + pub name: String, + pub payload: Vec, + pub labels: Option, + pub condition: DbCredstoreWriteCondition, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DbCredstoreWriteCondition { + MustCreate, + MatchResourceVersion(u64), +} + +#[derive(Clone)] +pub struct EncryptedGatewayCredentialStoreCrypto { + state: EncryptedGatewayCredentialState, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct EncryptedGatewayCredentialSettings { + key_encryption_key_path: Option, + key_encryption_key_env: Option, +} + +#[derive(Clone)] +struct EncryptedGatewayCredentialState { + settings: EncryptedGatewayCredentialSettings, + key_encryption_key: [u8; KEY_LEN], + key_encryption_key_id: String, +} + +#[derive(Debug, Clone, Default, Deserialize)] +#[serde(default, deny_unknown_fields)] +struct EncryptedGatewayCredentialConfig { + key_encryption_key_path: Option, + key_encryption_key_env: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct EncryptedCredentialEnvelope { + version: u32, + id: String, + provider_name: String, + credential_key: String, + algorithm: String, + key_encryption_key_id: String, + wrapped_dek: EncryptedBytes, + value: EncryptedBytes, +} + +#[derive(Debug, Serialize, Deserialize)] +struct EncryptedBytes { + nonce: String, + ciphertext: String, +} + +impl DbCredstoreCredentialDriver { + pub const NAME: &'static str = DRIVER_NAME; + pub const OBJECT_TYPE: &'static str = OBJECT_TYPE; + + pub fn from_config( + store: Arc, + config: &toml::Table, + ) -> CoreResult { + Ok(Self { + store, + crypto: EncryptedGatewayCredentialStoreCrypto::from_config(config)?, + }) + } + + pub async fn store_credential( + &self, + request: StoreCredentialRequest, + ) -> Result { + let credential_key = EncryptedGatewayCredentialStoreCrypto::validate_credential_key( + &request.credential_key, + )? 
+ .to_string(); + let provider_name = + EncryptedGatewayCredentialStoreCrypto::validate_provider_name(&request.provider_name)? + .to_string(); + + if let Some(existing_handle) = request.existing_handle.as_ref() { + let id = EncryptedGatewayCredentialStoreCrypto::id_from_handle(existing_handle)?; + let existing = self + .store + .get_credential_object(OBJECT_TYPE, &id, "load existing credential") + .await?; + if let Some(record) = existing { + let envelope = deserialize_credential_envelope(&record)?; + EncryptedGatewayCredentialStoreCrypto::ensure_envelope_owner( + &envelope, + &id, + &provider_name, + &credential_key, + )?; + self.write_envelope( + &id, + &provider_name, + &credential_key, + &request.value, + DbCredstoreWriteCondition::MatchResourceVersion(record.resource_version), + ) + .await?; + } else { + self.write_envelope( + &id, + &provider_name, + &credential_key, + &request.value, + DbCredstoreWriteCondition::MustCreate, + ) + .await?; + } + return self.crypto.credential_handle(&id); + } + + for _ in 0..16 { + let id = EncryptedGatewayCredentialStoreCrypto::new_handle_id()?; + match self + .write_envelope( + &id, + &provider_name, + &credential_key, + &request.value, + DbCredstoreWriteCondition::MustCreate, + ) + .await + { + Ok(()) => return self.crypto.credential_handle(&id), + Err(err) if err.code() == tonic::Code::AlreadyExists => {} + Err(err) => return Err(err), + } + } + + Err(Status::unavailable( + "failed to allocate unused default credential handle", + )) + } + + pub async fn delete_credential(&self, request: DeleteCredentialRequest) -> Result<(), Status> { + let handle = + EncryptedGatewayCredentialStoreCrypto::handle_from_request("delete", request.handle)?; + let id = EncryptedGatewayCredentialStoreCrypto::id_from_handle(&handle)?; + let record = self + .store + .get_credential_object(OBJECT_TYPE, &id, "load credential for deletion") + .await?; + let Some(record) = record else { + return Ok(()); + }; + + let envelope = 
deserialize_credential_envelope(&record)?; + EncryptedGatewayCredentialStoreCrypto::ensure_envelope_owner( + &envelope, + &id, + EncryptedGatewayCredentialStoreCrypto::validate_provider_name(&request.provider_name)?, + EncryptedGatewayCredentialStoreCrypto::validate_credential_key( + &request.credential_key, + )?, + )?; + + self.store + .delete_credential_object( + OBJECT_TYPE, + &id, + record.resource_version, + "delete credential", + ) + .await + } + + pub async fn resolve_credentials( + &self, + requests: Vec, + ) -> Result, Status> { + let mut responses = Vec::with_capacity(requests.len()); + for request in requests { + let handle = EncryptedGatewayCredentialStoreCrypto::handle_from_request( + &request.request_id, + request.handle, + )?; + let id = EncryptedGatewayCredentialStoreCrypto::id_from_handle(&handle)?; + let record = self + .store + .get_credential_object(OBJECT_TYPE, &id, "load credential") + .await? + .ok_or_else(|| { + Status::not_found(format!("default credential '{id}' was not found")) + })?; + let envelope = deserialize_credential_envelope(&record)?; + EncryptedGatewayCredentialStoreCrypto::ensure_envelope_owner( + &envelope, + &id, + EncryptedGatewayCredentialStoreCrypto::validate_provider_name( + &request.provider_name, + )?, + EncryptedGatewayCredentialStoreCrypto::validate_credential_key( + &request.credential_key, + )?, + )?; + let value = self.crypto.decrypt_envelope(&envelope)?; + responses.push(ResolvedCredential { + request_id: request.request_id, + value, + expires_at_ms: 0, + }); + } + + Ok(responses) + } + + async fn write_envelope( + &self, + id: &str, + provider_name: &str, + credential_key: &str, + value: &str, + condition: DbCredstoreWriteCondition, + ) -> Result<(), Status> { + let envelope = self + .crypto + .encrypt_envelope(id, provider_name, credential_key, value)?; + let payload = EncryptedGatewayCredentialStoreCrypto::serialize_envelope(&envelope)?; + let labels = credential_labels(provider_name, credential_key)?; + + 
self.store + .put_credential_object( + CredentialObjectWrite { + object_type: OBJECT_TYPE.to_string(), + id: id.to_string(), + name: id.to_string(), + payload, + labels: Some(labels), + condition, + }, + "persist credential", + ) + .await + } +} + +impl EncryptedGatewayCredentialStoreCrypto { + pub fn from_config(config: &toml::Table) -> CoreResult { + let settings = EncryptedGatewayCredentialSettings::from_table(config)?; + Ok(Self { + state: EncryptedGatewayCredentialState::from_settings(settings)?, + }) + } + + pub fn new_handle_id() -> Result { + new_handle_id() + } + + pub fn credential_handle(&self, id: &str) -> Result { + validate_handle_id(id)?; + Ok(credential_handle(&self.state, id)) + } + + pub fn handle_from_request( + request_id: &str, + handle: Option, + ) -> Result { + let handle = handle.ok_or_else(|| { + Status::invalid_argument(format!( + "default credential storage request '{request_id}' is missing handle" + )) + })?; + validate_handle_owner(&handle)?; + Ok(handle) + } + + pub fn id_from_handle(handle: &CredentialHandle) -> Result { + validate_handle_owner(handle)?; + let id = handle + .handle + .strip_prefix(&format!("{HANDLE_VERSION}:")) + .ok_or_else(|| { + Status::invalid_argument("default credential storage handle is malformed") + })?; + validate_handle_id(id)?; + Ok(id.to_string()) + } + + pub fn encrypt_envelope( + &self, + id: &str, + provider_name: &str, + credential_key: &str, + value: &str, + ) -> Result { + encrypt_envelope(&self.state, id, provider_name, credential_key, value) + } + + pub fn decrypt_envelope( + &self, + envelope: &EncryptedCredentialEnvelope, + ) -> Result { + decrypt_envelope(&self.state, envelope) + } + + pub fn ensure_envelope_owner( + envelope: &EncryptedCredentialEnvelope, + id: &str, + provider_name: &str, + credential_key: &str, + ) -> Result<(), Status> { + ensure_envelope_owner(envelope, id, provider_name, credential_key) + } + + pub fn validate_provider_name(value: &str) -> Result<&str, Status> { + 
validate_provider_name(value) + } + + pub fn validate_credential_key(value: &str) -> Result<&str, Status> { + validate_credential_key(value) + } + + pub fn serialize_envelope(envelope: &EncryptedCredentialEnvelope) -> Result, Status> { + serialize_envelope(envelope) + } + + pub fn deserialize_envelope( + bytes: &[u8], + description: impl std::fmt::Display, + ) -> Result { + serde_json::from_slice(bytes).map_err(|err| { + Status::data_loss(format!( + "default credential storage object '{description}' has invalid envelope JSON: {err}" + )) + }) + } +} + +impl std::fmt::Debug for EncryptedGatewayCredentialStoreCrypto { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EncryptedGatewayCredentialStoreCrypto") + .field("settings", &self.state.settings) + .field("key_encryption_key_id", &self.state.key_encryption_key_id) + .finish_non_exhaustive() + } +} + +impl EncryptedGatewayCredentialSettings { + fn from_table(config: &toml::Table) -> CoreResult { + let config: EncryptedGatewayCredentialConfig = toml::Value::Table(config.clone()) + .try_into() + .map_err(|err| { + Error::config(format!( + "invalid [openshell.gateway.credential_storage]: {err}" + )) + })?; + + if config.key_encryption_key_path.is_some() && config.key_encryption_key_env.is_some() { + return Err(Error::config( + "[openshell.gateway.credential_storage] set only one of key_encryption_key_path or key_encryption_key_env", + )); + } + + let key_encryption_key_path = match config.key_encryption_key_path { + Some(path) => Some(validate_path("key_encryption_key_path", path)?), + None if config.key_encryption_key_env.is_some() => None, + None => Some(default_key_encryption_key_path()?), + }; + let key_encryption_key_env = config + .key_encryption_key_env + .map(|name| validate_env_name("key_encryption_key_env", &name)) + .transpose()?; + + Ok(Self { + key_encryption_key_path, + key_encryption_key_env, + }) + } +} + +impl EncryptedGatewayCredentialState { + fn 
from_settings(settings: EncryptedGatewayCredentialSettings) -> CoreResult { + let key_encryption_key = load_key_encryption_key(&settings)?; + let key_encryption_key_id = key_id(&key_encryption_key); + Ok(Self { + settings, + key_encryption_key, + key_encryption_key_id, + }) + } +} + +fn default_key_encryption_key_path() -> CoreResult { + let state_dir = openshell_core::paths::openshell_state_dir().map_err(|err| { + Error::config(format!( + "failed to resolve default credential storage key-encryption key path: {err}" + )) + })?; + Ok(state_dir + .join("gateway") + .join("credentials") + .join(DEFAULT_KEY_ENCRYPTION_KEY_FILE)) +} + +fn validate_path(field_name: &str, path: PathBuf) -> CoreResult { + if path.as_os_str().is_empty() { + return Err(Error::config(format!( + "[openshell.gateway.credential_storage] {field_name} must not be empty" + ))); + } + if !path.is_absolute() { + return Err(Error::config(format!( + "[openshell.gateway.credential_storage] {field_name} must be absolute" + ))); + } + Ok(path) +} + +fn validate_env_name(field_name: &str, value: &str) -> CoreResult { + let trimmed = value.trim(); + if trimmed.is_empty() || trimmed.len() != value.len() { + return Err(Error::config(format!( + "[openshell.gateway.credential_storage] {field_name} must not be empty or contain surrounding whitespace" + ))); + } + if !trimmed + .bytes() + .all(|byte| byte.is_ascii_alphanumeric() || byte == b'_') + { + return Err(Error::config(format!( + "[openshell.gateway.credential_storage] {field_name} must name an environment variable using only letters, digits, and underscores" + ))); + } + Ok(trimmed.to_string()) +} + +fn load_key_encryption_key( + settings: &EncryptedGatewayCredentialSettings, +) -> CoreResult<[u8; KEY_LEN]> { + if let Some(env_name) = &settings.key_encryption_key_env { + let value = std::env::var(env_name).map_err(|_| { + Error::config(format!( + "[openshell.gateway.credential_storage] environment variable '{env_name}' is not set" + )) + })?; + return 
decode_key_encryption_key_base64(&value).map_err(Error::config); + } + let path = settings + .key_encryption_key_path + .as_ref() + .expect("settings always has key_encryption_key_path unless key_encryption_key_env is set"); + load_or_create_file_key_encryption_key(path) +} + +fn decode_key_encryption_key_base64(value: &str) -> Result<[u8; KEY_LEN], String> { + let trimmed = value.trim(); + let bytes = BASE64 + .decode(trimmed) + .or_else(|_| BASE64_NO_PAD.decode(trimmed)) + .map_err(|err| { + format!("key_encryption_key_env value must be base64-encoded 32-byte key: {err}") + })?; + fixed_bytes::(&bytes) + .map_err(|()| "key_encryption_key_env value must decode to exactly 32 bytes".to_string()) +} + +fn load_or_create_file_key_encryption_key(path: &Path) -> CoreResult<[u8; KEY_LEN]> { + match fs::read(path) { + Ok(bytes) => { + openshell_core::paths::set_file_owner_only(path).map_err(|err| { + Error::config(format!( + "failed to restrict default credential storage key-encryption key '{}': {err}", + path.display() + )) + })?; + return fixed_bytes::(&bytes).map_err(|()| { + Error::config(format!( + "[openshell.gateway.credential_storage] key_encryption_key_path '{}' must contain exactly 32 bytes", + path.display() + )) + }); + } + Err(err) if err.kind() == std::io::ErrorKind::NotFound => {} + Err(err) => { + return Err(Error::config(format!( + "failed to read default credential storage key-encryption key '{}': {err}", + path.display() + ))); + } + } + + openshell_core::paths::ensure_parent_dir_restricted(path).map_err(|err| { + Error::config(format!( + "failed to prepare default credential storage key-encryption key directory '{}': {err}", + path.display() + )) + })?; + let key_encryption_key = random_bytes_core::()?; + let mut options = OpenOptions::new(); + options.write(true).create_new(true); + #[cfg(unix)] + options.mode(0o600); + match options.open(path) { + Ok(mut file) => { + if let Err(err) = file.write_all(&key_encryption_key) { + let _ = 
fs::remove_file(path); + return Err(Error::config(format!( + "failed to write default credential storage key-encryption key '{}': {err}", + path.display() + ))); + } + openshell_core::paths::set_file_owner_only(path).map_err(|err| { + Error::config(format!( + "failed to restrict default credential storage key-encryption key '{}': {err}", + path.display() + )) + })?; + Ok(key_encryption_key) + } + Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => { + load_or_create_file_key_encryption_key(path) + } + Err(err) => Err(Error::config(format!( + "failed to create default credential storage key-encryption key '{}': {err}", + path.display() + ))), + } +} + +fn new_handle_id() -> Result { + Ok(hex_encode(&random_bytes_status::()?)) +} + +fn credential_handle(state: &EncryptedGatewayCredentialState, id: &str) -> CredentialHandle { + CredentialHandle { + driver: DRIVER_NAME.to_string(), + handle: format!("{HANDLE_VERSION}:{id}"), + metadata: [ + ("algorithm".to_string(), ALGORITHM.to_string()), + ( + "key_encryption_key_id".to_string(), + state.key_encryption_key_id.clone(), + ), + ] + .into_iter() + .collect(), + } +} + +fn validate_handle_owner(handle: &CredentialHandle) -> Result<(), Status> { + if handle.driver.trim() == DRIVER_NAME { + return Ok(()); + } + Err(Status::invalid_argument(format!( + "default credential storage cannot use handle owned by '{}'", + handle.driver + ))) +} + +fn encrypt_envelope( + state: &EncryptedGatewayCredentialState, + id: &str, + provider_name: &str, + credential_key: &str, + value: &str, +) -> Result { + let dek = random_bytes_status::()?; + let wrapped_dek = encrypt_bytes( + &state.key_encryption_key, + &dek_aad(id, provider_name, credential_key), + &dek, + )?; + let encrypted_value = encrypt_bytes( + &dek, + &value_aad(id, provider_name, credential_key), + value.as_bytes(), + )?; + + Ok(EncryptedCredentialEnvelope { + version: ENVELOPE_VERSION, + id: id.to_string(), + provider_name: provider_name.to_string(), + 
credential_key: credential_key.to_string(), + algorithm: ALGORITHM.to_string(), + key_encryption_key_id: state.key_encryption_key_id.clone(), + wrapped_dek, + value: encrypted_value, + }) +} + +fn decrypt_envelope( + state: &EncryptedGatewayCredentialState, + envelope: &EncryptedCredentialEnvelope, +) -> Result { + validate_envelope_metadata(envelope)?; + if envelope.key_encryption_key_id != state.key_encryption_key_id { + return Err(Status::failed_precondition( + "default credential storage object was encrypted with a different key-encryption key", + )); + } + let dek = decrypt_bytes( + &state.key_encryption_key, + &dek_aad( + &envelope.id, + &envelope.provider_name, + &envelope.credential_key, + ), + &envelope.wrapped_dek, + )?; + let dek = fixed_bytes::(&dek) + .map_err(|()| Status::data_loss("default credential storage DEK has invalid length"))?; + let plaintext = decrypt_bytes( + &dek, + &value_aad( + &envelope.id, + &envelope.provider_name, + &envelope.credential_key, + ), + &envelope.value, + )?; + String::from_utf8(plaintext) + .map_err(|_| Status::data_loss("default credential storage value is not valid UTF-8")) +} + +fn encrypt_bytes( + key_bytes: &[u8; KEY_LEN], + aad: &[u8], + plaintext: &[u8], +) -> Result { + let nonce = random_bytes_status::()?; + let key = aead_key(key_bytes)?; + let mut in_out = plaintext.to_vec(); + key.seal_in_place_append_tag( + Nonce::assume_unique_for_key(nonce), + Aad::from(aad), + &mut in_out, + ) + .map_err(|_| Status::internal("failed to encrypt default credential storage value"))?; + Ok(EncryptedBytes { + nonce: BASE64.encode(nonce), + ciphertext: BASE64.encode(in_out), + }) +} + +fn decrypt_bytes( + key_bytes: &[u8; KEY_LEN], + aad: &[u8], + encrypted: &EncryptedBytes, +) -> Result, Status> { + let nonce = decode_b64_array::("nonce", &encrypted.nonce)?; + let mut in_out = decode_b64_vec("ciphertext", &encrypted.ciphertext)?; + let key = aead_key(key_bytes)?; + let plaintext = key + .open_in_place( + 
Nonce::assume_unique_for_key(nonce), + Aad::from(aad), + &mut in_out, + ) + .map_err(|_| Status::data_loss("failed to decrypt default credential storage value"))?; + Ok(plaintext.to_vec()) +} + +fn aead_key(key_bytes: &[u8; KEY_LEN]) -> Result { + let unbound = UnboundKey::new(&AES_256_GCM, key_bytes).map_err(|_| { + Status::internal("failed to initialize default credential storage AEAD key") + })?; + Ok(LessSafeKey::new(unbound)) +} + +fn dek_aad(id: &str, provider_name: &str, credential_key: &str) -> Vec { + format!("openshell:gateway-credential-storage:v1:dek:{id}:{provider_name}:{credential_key}") + .into_bytes() +} + +fn value_aad(id: &str, provider_name: &str, credential_key: &str) -> Vec { + format!("openshell:gateway-credential-storage:v1:value:{id}:{provider_name}:{credential_key}") + .into_bytes() +} + +fn serialize_envelope(envelope: &EncryptedCredentialEnvelope) -> Result, Status> { + serde_json::to_vec(envelope).map_err(|err| { + Status::internal(format!( + "failed to serialize default credential storage envelope: {err}" + )) + }) +} + +fn deserialize_credential_envelope( + record: &StoredCredentialObject, +) -> Result { + EncryptedGatewayCredentialStoreCrypto::deserialize_envelope( + &record.payload, + format!("{}/{}", record.object_type, record.id), + ) +} + +fn credential_labels(provider_name: &str, credential_key: &str) -> Result { + serde_json::to_string(&HashMap::from([ + ("provider_name", provider_name), + ("credential_key", credential_key), + ])) + .map_err(|err| { + Status::internal(format!( + "failed to serialize default credential labels: {err}" + )) + }) +} + +fn validate_envelope_metadata(envelope: &EncryptedCredentialEnvelope) -> Result<(), Status> { + if envelope.version != ENVELOPE_VERSION { + return Err(Status::data_loss(format!( + "default credential storage envelope version {} is unsupported", + envelope.version + ))); + } + validate_handle_id(&envelope.id)?; + if envelope.algorithm != ALGORITHM { + return 
Err(Status::data_loss(format!( + "default credential storage algorithm '{}' is unsupported", + envelope.algorithm + ))); + } + validate_provider_name(&envelope.provider_name)?; + validate_credential_key(&envelope.credential_key)?; + Ok(()) +} + +fn ensure_envelope_owner( + envelope: &EncryptedCredentialEnvelope, + id: &str, + provider_name: &str, + credential_key: &str, +) -> Result<(), Status> { + validate_envelope_metadata(envelope)?; + if envelope.id == id + && envelope.provider_name == provider_name + && envelope.credential_key == credential_key + { + return Ok(()); + } + Err(Status::failed_precondition( + "default credential storage handle is not managed for this provider credential", + )) +} + +fn validate_handle_id(id: &str) -> Result<(), Status> { + if id.len() == HANDLE_ID_LEN + && id + .bytes() + .all(|byte| byte.is_ascii_digit() || matches!(byte, b'a'..=b'f')) + { + return Ok(()); + } + Err(Status::invalid_argument( + "default credential storage handle id is invalid", + )) +} + +fn validate_provider_name(value: &str) -> Result<&str, Status> { + validate_request_component("provider_name", value) +} + +fn validate_credential_key(value: &str) -> Result<&str, Status> { + validate_request_component("credential_key", value) +} + +fn validate_request_component<'a>(field_name: &str, value: &'a str) -> Result<&'a str, Status> { + let trimmed = value.trim(); + if trimmed.is_empty() { + return Err(Status::invalid_argument(format!( + "default credential storage request {field_name} is required" + ))); + } + if trimmed.len() != value.len() { + return Err(Status::invalid_argument(format!( + "default credential storage request {field_name} must not contain leading or trailing whitespace" + ))); + } + Ok(trimmed) +} + +fn decode_b64_array(field_name: &str, value: &str) -> Result<[u8; N], Status> { + let bytes = decode_b64_vec(field_name, value)?; + fixed_bytes::(&bytes).map_err(|()| { + Status::data_loss(format!( + "default credential storage envelope {field_name} has 
invalid length" + )) + }) +} + +fn decode_b64_vec(field_name: &str, value: &str) -> Result, Status> { + BASE64.decode(value).map_err(|err| { + Status::data_loss(format!( + "default credential storage envelope {field_name} is invalid base64: {err}" + )) + }) +} + +fn fixed_bytes(bytes: &[u8]) -> Result<[u8; N], ()> { + bytes.try_into().map_err(|_| ()) +} + +fn random_bytes_core() -> CoreResult<[u8; N]> { + let mut bytes = [0_u8; N]; + SystemRandom::new() + .fill(&mut bytes) + .map_err(|_| Error::config("failed to generate default credential storage key material"))?; + Ok(bytes) +} + +fn random_bytes_status() -> Result<[u8; N], Status> { + let mut bytes = [0_u8; N]; + SystemRandom::new().fill(&mut bytes).map_err(|_| { + Status::internal("failed to generate default credential storage randomness") + })?; + Ok(bytes) +} + +fn key_id(key: &[u8; KEY_LEN]) -> String { + let digest = Sha256::digest(key); + format!("sha256:{}", hex_encode(&digest)) +} + +fn hex_encode(bytes: &[u8]) -> String { + const HEX: &[u8; 16] = b"0123456789abcdef"; + let mut out = String::with_capacity(bytes.len() * 2); + for byte in bytes { + out.push(HEX[(byte >> 4) as usize] as char); + out.push(HEX[(byte & 0x0f) as usize] as char); + } + out +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::{Arc, Mutex}; + use tonic::Code; + + #[derive(Debug, Default)] + struct MemoryObjectStore { + objects: Mutex>, + } + + #[async_trait] + impl DbCredstoreObjectStore for MemoryObjectStore { + async fn get_credential_object( + &self, + _object_type: &str, + id: &str, + _operation: &'static str, + ) -> Result, Status> { + Ok(self.objects.lock().unwrap().get(id).cloned()) + } + + async fn put_credential_object( + &self, + write: CredentialObjectWrite, + _operation: &'static str, + ) -> Result<(), Status> { + let mut objects = self.objects.lock().unwrap(); + match write.condition { + DbCredstoreWriteCondition::MustCreate if objects.contains_key(&write.id) => { + return 
Err(Status::already_exists("object already exists")); + } + DbCredstoreWriteCondition::MatchResourceVersion(expected) => { + let Some(current) = objects.get(&write.id) else { + return Err(Status::not_found("object not found")); + }; + if current.resource_version != expected { + return Err(Status::aborted("resource version conflict")); + } + } + DbCredstoreWriteCondition::MustCreate => {} + } + + let resource_version = objects + .get(&write.id) + .map_or(1, |current| current.resource_version + 1); + objects.insert( + write.id.clone(), + StoredCredentialObject { + object_type: write.object_type, + id: write.id, + payload: write.payload, + resource_version, + }, + ); + Ok(()) + } + + async fn delete_credential_object( + &self, + _object_type: &str, + id: &str, + expected_resource_version: u64, + _operation: &'static str, + ) -> Result<(), Status> { + let mut objects = self.objects.lock().unwrap(); + let Some(current) = objects.get(id) else { + return Ok(()); + }; + if current.resource_version != expected_resource_version { + return Err(Status::aborted("resource version conflict")); + } + objects.remove(id); + Ok(()) + } + } + + fn crypto_for_key_encryption_key_path(path: &Path) -> EncryptedGatewayCredentialStoreCrypto { + let mut config = toml::Table::new(); + config.insert( + "key_encryption_key_path".to_string(), + toml::Value::String(path.to_string_lossy().to_string()), + ); + EncryptedGatewayCredentialStoreCrypto::from_config(&config).unwrap() + } + + fn driver_config_for_key_encryption_key_path(path: &Path) -> toml::Table { + let mut config = toml::Table::new(); + config.insert( + "key_encryption_key_path".to_string(), + toml::Value::String(path.to_string_lossy().to_string()), + ); + config + } + + fn request( + provider_name: &str, + credential_key: &str, + value: &str, + existing_handle: Option, + ) -> StoreCredentialRequest { + StoreCredentialRequest { + provider_name: provider_name.to_string(), + credential_key: credential_key.to_string(), + value: 
value.to_string(), + existing_handle, + } + } + + fn resolve_request( + request_id: &str, + provider_name: &str, + credential_key: &str, + handle: CredentialHandle, + ) -> ResolveCredentialRequest { + ResolveCredentialRequest { + request_id: request_id.to_string(), + provider_name: provider_name.to_string(), + credential_key: credential_key.to_string(), + handle: Some(handle), + } + } + + #[tokio::test] + async fn driver_stores_resolves_updates_and_deletes_encrypted_objects() { + let tmp = tempfile::tempdir().unwrap(); + let config = driver_config_for_key_encryption_key_path( + &tmp.path().join(DEFAULT_KEY_ENCRYPTION_KEY_FILE), + ); + let store = Arc::new(MemoryObjectStore::default()); + let object_store: Arc = store.clone(); + let driver = DbCredstoreCredentialDriver::from_config(object_store, &config).unwrap(); + + let first = driver + .store_credential(request( + "openai-local", + "OPENAI_API_KEY", + "sk-original", + None, + )) + .await + .unwrap(); + assert_eq!(first.driver, DbCredstoreCredentialDriver::NAME); + let handle_id = first.handle.strip_prefix("v1:").unwrap(); + let payload = store + .objects + .lock() + .unwrap() + .get(handle_id) + .unwrap() + .payload + .clone(); + assert!(!String::from_utf8_lossy(&payload).contains("sk-original")); + + let resolved = driver + .resolve_credentials(vec![resolve_request( + "credential-0", + "openai-local", + "OPENAI_API_KEY", + first.clone(), + )]) + .await + .unwrap(); + assert_eq!(resolved[0].value, "sk-original"); + + let updated = driver + .store_credential(request( + "openai-local", + "OPENAI_API_KEY", + "sk-updated", + Some(first.clone()), + )) + .await + .unwrap(); + assert_eq!(updated.handle, first.handle); + + let resolved = driver + .resolve_credentials(vec![resolve_request( + "credential-0", + "openai-local", + "OPENAI_API_KEY", + updated.clone(), + )]) + .await + .unwrap(); + assert_eq!(resolved[0].value, "sk-updated"); + + driver + .delete_credential(DeleteCredentialRequest { + provider_name: 
"openai-local".to_string(), + credential_key: "OPENAI_API_KEY".to_string(), + handle: Some(updated.clone()), + }) + .await + .unwrap(); + + let err = driver + .resolve_credentials(vec![resolve_request( + "credential-0", + "openai-local", + "OPENAI_API_KEY", + updated, + )]) + .await + .unwrap_err(); + assert_eq!(err.code(), Code::NotFound); + } + + #[test] + fn encrypts_decrypts_and_serializes_envelope() { + let tmp = tempfile::tempdir().unwrap(); + let crypto = + crypto_for_key_encryption_key_path(&tmp.path().join(DEFAULT_KEY_ENCRYPTION_KEY_FILE)); + let id = EncryptedGatewayCredentialStoreCrypto::new_handle_id().unwrap(); + let envelope = crypto + .encrypt_envelope(&id, "openai-local", "OPENAI_API_KEY", "sk-original") + .unwrap(); + let serialized = + EncryptedGatewayCredentialStoreCrypto::serialize_envelope(&envelope).unwrap(); + assert!(!String::from_utf8_lossy(&serialized).contains("sk-original")); + + let envelope = + EncryptedGatewayCredentialStoreCrypto::deserialize_envelope(&serialized, "test") + .unwrap(); + assert_eq!(crypto.decrypt_envelope(&envelope).unwrap(), "sk-original"); + } + + #[test] + fn file_key_encryption_key_is_reused_across_instances() { + let tmp = tempfile::tempdir().unwrap(); + let key_encryption_key_path = tmp.path().join(DEFAULT_KEY_ENCRYPTION_KEY_FILE); + let crypto = crypto_for_key_encryption_key_path(&key_encryption_key_path); + let id = EncryptedGatewayCredentialStoreCrypto::new_handle_id().unwrap(); + let envelope = crypto + .encrypt_envelope(&id, "openai-local", "OPENAI_API_KEY", "sk-persisted") + .unwrap(); + + let restarted = crypto_for_key_encryption_key_path(&key_encryption_key_path); + assert_eq!( + restarted.decrypt_envelope(&envelope).unwrap(), + "sk-persisted" + ); + } + + #[test] + fn rejects_handle_for_different_provider() { + let tmp = tempfile::tempdir().unwrap(); + let crypto = + crypto_for_key_encryption_key_path(&tmp.path().join(DEFAULT_KEY_ENCRYPTION_KEY_FILE)); + let id = 
EncryptedGatewayCredentialStoreCrypto::new_handle_id().unwrap(); + let envelope = crypto + .encrypt_envelope(&id, "openai-local", "OPENAI_API_KEY", "sk-original") + .unwrap(); + + let err = EncryptedGatewayCredentialStoreCrypto::ensure_envelope_owner( + &envelope, + &id, + "other-provider", + "OPENAI_API_KEY", + ) + .unwrap_err(); + assert_eq!(err.code(), Code::FailedPrecondition); + } + + #[test] + fn env_key_encryption_key_must_decode_to_32_bytes() { + let err = decode_key_encryption_key_base64(&BASE64.encode([1_u8; 31])).unwrap_err(); + assert!(err.contains("32 bytes")); + assert!(decode_key_encryption_key_base64(&BASE64.encode([1_u8; KEY_LEN])).is_ok()); + assert!(decode_key_encryption_key_base64(&BASE64_NO_PAD.encode([1_u8; KEY_LEN])).is_ok()); + } + + #[cfg(unix)] + #[test] + fn generated_key_encryption_key_file_is_owner_only() { + use std::os::unix::fs::PermissionsExt; + + let tmp = tempfile::tempdir().unwrap(); + let key_encryption_key_path = tmp.path().join(DEFAULT_KEY_ENCRYPTION_KEY_FILE); + let _crypto = crypto_for_key_encryption_key_path(&key_encryption_key_path); + let key_encryption_key_mode = fs::metadata(key_encryption_key_path) + .unwrap() + .permissions() + .mode() + & 0o777; + assert_eq!(key_encryption_key_mode, 0o600); + } +} diff --git a/crates/openshell-driver-kubernetes-secrets/Cargo.toml b/crates/openshell-driver-kubernetes-secrets/Cargo.toml new file mode 100644 index 000000000..3655013ff --- /dev/null +++ b/crates/openshell-driver-kubernetes-secrets/Cargo.toml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-driver-kubernetes-secrets" +description = "Kubernetes Secrets credential driver for OpenShell" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true + +[[bin]] +name = "openshell-driver-kubernetes-secrets" +path = "src/main.rs" + +[dependencies] +openshell-core = { path = "../openshell-core", default-features = false } + +clap = { workspace = true } +futures = { workspace = true } +k8s-openapi = { workspace = true } +kube = { workspace = true } +miette = { workspace = true } +serde = { workspace = true } +sha2 = { workspace = true } +tokio = { workspace = true } +toml = { workspace = true } +tonic = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } + +[lints] +workspace = true diff --git a/crates/openshell-driver-kubernetes-secrets/src/lib.rs b/crates/openshell-driver-kubernetes-secrets/src/lib.rs new file mode 100644 index 000000000..3104fc608 --- /dev/null +++ b/crates/openshell-driver-kubernetes-secrets/src/lib.rs @@ -0,0 +1,793 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Credential driver backed by Kubernetes Secret objects. 
+ +use std::collections::BTreeMap; + +use k8s_openapi::api::core::v1::Secret; +use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta; +use kube::api::{DeleteParams, Patch, PatchParams, PostParams}; +use kube::{Api, Client}; +use openshell_core::VERSION; +use openshell_core::proto::CredentialHandle; +use openshell_core::proto::credentials::v1::{ + DeleteCredentialRequest, DeleteCredentialResponse, GetCredentialDriverCapabilitiesRequest, + GetCredentialDriverCapabilitiesResponse, ListCredentialsRequest, ListCredentialsResponse, + ResolveCredentialRequest, ResolveCredentialsRequest, ResolveCredentialsResponse, + ResolvedCredential, StoreCredentialRequest, StoreCredentialResponse, + credential_driver_server::CredentialDriver, +}; +use openshell_core::{Error, Result as CoreResult}; +use sha2::{Digest, Sha256}; +use tonic::{Request, Response, Status}; + +const SERVICE_ACCOUNT_NAMESPACE_PATH: &str = + "/var/run/secrets/kubernetes.io/serviceaccount/namespace"; +const HANDLE_VERSION: &str = "v1"; +const MANAGED_BY_LABEL: &str = "app.kubernetes.io/managed-by"; +const MANAGED_BY_VALUE: &str = "openshell"; +const OWNER_ANNOTATION: &str = "openshell.nvidia.com/provider-credential-id"; + +pub struct KubernetesSecretsCredentialDriver { + client: Client, + settings: KubernetesSecretsDriverSettings, +} + +#[derive(Debug, Clone)] +pub struct CredentialDriverService { + driver: KubernetesSecretsCredentialDriver, +} + +impl CredentialDriverService { + #[must_use] + pub fn new(driver: KubernetesSecretsCredentialDriver) -> Self { + Self { driver } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct KubernetesSecretsDriverSettings { + namespace: String, + allow_reference_namespace: bool, +} + +#[derive(Debug, Clone, Default, serde::Deserialize)] +#[serde(default, deny_unknown_fields)] +struct KubernetesSecretsDriverConfig { + namespace: Option, + allow_reference_namespace: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct KubernetesSecretReference { + namespace: 
String, + secret_name: String, + key: String, +} + +impl KubernetesSecretsCredentialDriver { + pub const NAME: &'static str = "kubernetes-secrets"; + + pub async fn from_config(config: &toml::Table) -> CoreResult { + let settings = KubernetesSecretsDriverSettings::from_table(config)?; + let client = Client::try_default().await.map_err(|err| { + Error::config(format!( + "failed to configure kubernetes-secrets credential driver: {err}" + )) + })?; + Ok(Self { client, settings }) + } + + fn handle_from_request( + request_id: &str, + handle: Option, + ) -> Result { + handle.ok_or_else(|| { + Status::invalid_argument(format!( + "kubernetes-secrets credential request '{request_id}' is missing handle" + )) + }) + } + + fn resolve_handle( + handle: &CredentialHandle, + credential_key: &str, + ) -> Result { + let parts = handle.handle.split(':').collect::>(); + if parts.len() != 3 || parts[0] != HANDLE_VERSION { + return Err(Status::invalid_argument( + "kubernetes-secrets credential handle is malformed", + )); + } + let namespace = required_handle_component("namespace", parts[1])?; + if !is_dns_label(namespace) { + return Err(Status::invalid_argument( + "kubernetes-secrets credential handle namespace is invalid", + )); + } + let secret_name = required_handle_component("secret", parts[2])?; + if !is_dns_subdomain(secret_name) { + return Err(Status::invalid_argument( + "kubernetes-secrets credential handle Secret name is invalid", + )); + } + let key = required_handle_component("credential_key", credential_key)?; + if !is_secret_data_key(key) { + return Err(Status::invalid_argument( + "kubernetes-secrets credential key must be a valid Kubernetes Secret data key", + )); + } + + Ok(KubernetesSecretReference { + namespace: namespace.to_string(), + secret_name: secret_name.to_string(), + key: key.to_string(), + }) + } + + pub async fn store_credential( + &self, + request: StoreCredentialRequest, + ) -> Result { + let owner_id = credential_owner_id(&request.provider_name, 
&request.credential_key); + let reference = if let Some(existing_handle) = request.existing_handle.as_ref() { + let reference = Self::resolve_handle(existing_handle, &request.credential_key)?; + validate_expected_secret_name( + &request.provider_name, + &request.credential_key, + &reference.secret_name, + )?; + reference + } else { + KubernetesSecretReference { + namespace: self.settings.namespace.clone(), + secret_name: managed_secret_name(&request.provider_name, &request.credential_key), + key: required_handle_component("credential_key", &request.credential_key)? + .to_string(), + } + }; + if !is_secret_data_key(&reference.key) { + return Err(Status::invalid_argument( + "kubernetes-secrets credential key must be a valid Kubernetes Secret data key", + )); + } + if request.existing_handle.is_some() { + self.overwrite_secret_value(&reference, &owner_id, &request.value) + .await?; + } else { + self.create_secret_value(&reference, &owner_id, &request.value) + .await?; + } + Ok(CredentialHandle { + driver: Self::NAME.to_string(), + handle: format!( + "{HANDLE_VERSION}:{}:{}", + reference.namespace, reference.secret_name + ), + metadata: std::collections::HashMap::new(), + }) + } + + pub async fn delete_credential(&self, request: DeleteCredentialRequest) -> Result<(), Status> { + let handle = Self::handle_from_request("delete", request.handle)?; + let reference = Self::resolve_handle(&handle, &request.credential_key)?; + validate_expected_secret_name( + &request.provider_name, + &request.credential_key, + &reference.secret_name, + )?; + let owner_id = credential_owner_id(&request.provider_name, &request.credential_key); + let api: Api = Api::namespaced(self.client.clone(), &reference.namespace); + let secret = match api.get(&reference.secret_name).await { + Ok(secret) => secret, + Err(kube::Error::Api(api_err)) if api_err.code == 404 => return Ok(()), + Err(err) => { + return Err(kube_error_to_status( + &reference.namespace, + &reference.secret_name, + err, + )); + } + 
}; + ensure_secret_is_managed_for(&secret, &reference, &owner_id)?; + match api + .delete(&reference.secret_name, &DeleteParams::default()) + .await + { + Ok(_) => Ok(()), + Err(kube::Error::Api(api_err)) if api_err.code == 404 => Ok(()), + Err(kube::Error::Api(api_err)) if api_err.code == 403 => { + Err(Status::permission_denied(format!( + "gateway is not allowed to delete Kubernetes Secret '{}' in namespace '{}'", + reference.secret_name, reference.namespace + ))) + } + Err(err) => Err(Status::unavailable(format!( + "failed to delete Kubernetes Secret '{}' in namespace '{}': {err}", + reference.secret_name, reference.namespace + ))), + } + } + + pub async fn resolve_credentials( + &self, + requests: Vec, + ) -> Result, Status> { + let mut responses = Vec::with_capacity(requests.len()); + for request in requests { + let handle = Self::handle_from_request(&request.request_id, request.handle)?; + let reference = Self::resolve_handle(&handle, &request.credential_key)?; + validate_expected_secret_name( + &request.provider_name, + &request.credential_key, + &reference.secret_name, + )?; + let owner_id = credential_owner_id(&request.provider_name, &request.credential_key); + let value = self.resolve_secret_value(&reference, &owner_id).await?; + responses.push(ResolvedCredential { + request_id: request.request_id, + value, + expires_at_ms: 0, + }); + } + Ok(responses) + } + + async fn create_secret_value( + &self, + reference: &KubernetesSecretReference, + owner_id: &str, + value: &str, + ) -> Result<(), Status> { + let api: Api = Api::namespaced(self.client.clone(), &reference.namespace); + let secret = managed_secret(&reference.secret_name, &reference.key, owner_id, value); + match api.create(&PostParams::default(), &secret).await { + Ok(_) => Ok(()), + Err(kube::Error::Api(api_err)) if api_err.code == 409 => { + Err(Status::already_exists(format!( + "Kubernetes Secret '{}' in namespace '{}' already exists; refusing to overwrite a Secret not created for this provider 
credential", + reference.secret_name, reference.namespace + ))) + } + Err(err) => Err(kube_write_error_to_status( + &reference.namespace, + &reference.secret_name, + err, + )), + } + } + + async fn overwrite_secret_value( + &self, + reference: &KubernetesSecretReference, + owner_id: &str, + value: &str, + ) -> Result<(), Status> { + let api: Api = Api::namespaced(self.client.clone(), &reference.namespace); + let secret = match api.get(&reference.secret_name).await { + Ok(secret) => secret, + Err(kube::Error::Api(api_err)) if api_err.code == 404 => { + return self.create_secret_value(reference, owner_id, value).await; + } + Err(err) => { + return Err(kube_error_to_status( + &reference.namespace, + &reference.secret_name, + err, + )); + } + }; + ensure_secret_is_managed_for(&secret, reference, owner_id)?; + + let patch = managed_secret(&reference.secret_name, &reference.key, owner_id, value); + api.patch( + &reference.secret_name, + &PatchParams::default(), + &Patch::Merge(&patch), + ) + .await + .map(|_| ()) + .map_err(|err| { + kube_write_error_to_status(&reference.namespace, &reference.secret_name, err) + }) + } + + async fn resolve_secret_value( + &self, + reference: &KubernetesSecretReference, + owner_id: &str, + ) -> Result { + let api: Api = Api::namespaced(self.client.clone(), &reference.namespace); + let secret = api.get(&reference.secret_name).await.map_err(|err| { + kube_error_to_status(&reference.namespace, &reference.secret_name, err) + })?; + ensure_secret_is_managed_for(&secret, reference, owner_id)?; + let data = secret.data.ok_or_else(|| { + Status::not_found(format!( + "Kubernetes Secret '{}' in namespace '{}' has no data", + reference.secret_name, reference.namespace + )) + })?; + let value = data.get(&reference.key).ok_or_else(|| { + Status::not_found(format!( + "Kubernetes Secret '{}' in namespace '{}' does not contain key '{}'", + reference.secret_name, reference.namespace, reference.key + )) + })?; + 
String::from_utf8(value.0.clone()).map_err(|_| { + Status::invalid_argument(format!( + "Kubernetes Secret '{}' in namespace '{}' key '{}' is not valid UTF-8", + reference.secret_name, reference.namespace, reference.key + )) + }) + } +} + +impl std::fmt::Debug for KubernetesSecretsCredentialDriver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("KubernetesSecretsCredentialDriver") + .field("settings", &self.settings) + .finish_non_exhaustive() + } +} + +impl Clone for KubernetesSecretsCredentialDriver { + fn clone(&self) -> Self { + Self { + client: self.client.clone(), + settings: self.settings.clone(), + } + } +} + +#[tonic::async_trait] +impl CredentialDriver for CredentialDriverService { + async fn get_capabilities( + &self, + _request: Request, + ) -> Result, Status> { + Ok(Response::new(GetCredentialDriverCapabilitiesResponse { + driver_name: KubernetesSecretsCredentialDriver::NAME.to_string(), + driver_version: VERSION.to_string(), + backend_kind: KubernetesSecretsCredentialDriver::NAME.to_string(), + supports_list: false, + supports_expires_at: false, + })) + } + + async fn store_credential( + &self, + request: Request, + ) -> Result, Status> { + let handle = self.driver.store_credential(request.into_inner()).await?; + Ok(Response::new(StoreCredentialResponse { + handle: Some(handle), + })) + } + + async fn delete_credential( + &self, + request: Request, + ) -> Result, Status> { + self.driver.delete_credential(request.into_inner()).await?; + Ok(Response::new(DeleteCredentialResponse {})) + } + + async fn resolve_credentials( + &self, + request: Request, + ) -> Result, Status> { + let credentials = self + .driver + .resolve_credentials(request.into_inner().credentials) + .await?; + Ok(Response::new(ResolveCredentialsResponse { credentials })) + } + + async fn list_credentials( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "kubernetes-secrets credential driver does not support 
listing credentials", + )) + } +} + +impl KubernetesSecretsDriverSettings { + fn from_table(config: &toml::Table) -> CoreResult { + let config: KubernetesSecretsDriverConfig = toml::Value::Table(config.clone()) + .try_into() + .map_err(|err| { + Error::config(format!( + "invalid [openshell.credential_drivers.kubernetes-secrets]: {err}" + )) + })?; + let namespace = match config.namespace { + Some(namespace) => { + let namespace = trimmed_config_string("namespace", &namespace)?; + if !is_dns_label(namespace) { + return Err(Error::config( + "[openshell.credential_drivers.kubernetes-secrets] namespace must be a Kubernetes namespace name", + )); + } + namespace.to_string() + } + None => default_namespace(), + }; + + Ok(Self { + namespace, + allow_reference_namespace: config.allow_reference_namespace, + }) + } +} + +fn kube_error_to_status(namespace: &str, secret_name: &str, err: kube::Error) -> Status { + match err { + kube::Error::Api(api_err) if api_err.code == 404 => Status::not_found(format!( + "Kubernetes Secret '{secret_name}' in namespace '{namespace}' was not found" + )), + kube::Error::Api(api_err) if api_err.code == 403 => Status::permission_denied(format!( + "gateway is not allowed to read Kubernetes Secret '{secret_name}' in namespace '{namespace}'" + )), + other => Status::unavailable(format!( + "failed to read Kubernetes Secret '{secret_name}' in namespace '{namespace}': {other}" + )), + } +} + +fn default_namespace() -> String { + std::fs::read_to_string(SERVICE_ACCOUNT_NAMESPACE_PATH) + .ok() + .map(|namespace| namespace.trim().to_string()) + .filter(|namespace| !namespace.is_empty() && is_dns_label(namespace)) + .unwrap_or_else(|| "default".to_string()) +} + +fn kube_write_error_to_status(namespace: &str, secret_name: &str, err: kube::Error) -> Status { + match err { + kube::Error::Api(api_err) if api_err.code == 403 => Status::permission_denied(format!( + "gateway is not allowed to write Kubernetes Secret '{secret_name}' in namespace '{namespace}'" + 
)), + other => Status::unavailable(format!( + "failed to write Kubernetes Secret '{secret_name}' in namespace '{namespace}': {other}" + )), + } +} + +fn managed_secret(secret_name: &str, key: &str, owner_id: &str, value: &str) -> Secret { + let labels = BTreeMap::from([(MANAGED_BY_LABEL.to_string(), MANAGED_BY_VALUE.to_string())]); + let annotations = BTreeMap::from([(OWNER_ANNOTATION.to_string(), owner_id.to_string())]); + Secret { + metadata: ObjectMeta { + name: Some(secret_name.to_string()), + labels: Some(labels), + annotations: Some(annotations), + ..Default::default() + }, + string_data: Some(BTreeMap::from([(key.to_string(), value.to_string())])), + type_: Some("Opaque".to_string()), + ..Default::default() + } +} + +fn credential_owner_id(provider_name: &str, credential_key: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(provider_name.as_bytes()); + hasher.update([0]); + hasher.update(credential_key.as_bytes()); + let digest = hasher.finalize(); + format!("{digest:x}") +} + +fn managed_secret_name(provider_name: &str, credential_key: &str) -> String { + let hex = credential_owner_id(provider_name, credential_key); + format!("openshell-cred-{}", &hex[..40]) +} + +fn validate_expected_secret_name( + provider_name: &str, + credential_key: &str, + secret_name: &str, +) -> Result<(), Status> { + let expected = managed_secret_name(provider_name, credential_key); + if secret_name != expected { + return Err(Status::invalid_argument(format!( + "kubernetes-secrets credential handle Secret name '{secret_name}' does not match the managed Secret for provider credential '{credential_key}'" + ))); + } + Ok(()) +} + +fn ensure_secret_is_managed_for( + secret: &Secret, + reference: &KubernetesSecretReference, + owner_id: &str, +) -> Result<(), Status> { + let managed_by = secret + .metadata + .labels + .as_ref() + .and_then(|labels| labels.get(MANAGED_BY_LABEL)) + .map(String::as_str); + let owner = secret + .metadata + .annotations + .as_ref() + 
.and_then(|annotations| annotations.get(OWNER_ANNOTATION)) + .map(String::as_str); + if managed_by == Some(MANAGED_BY_VALUE) && owner == Some(owner_id) { + return Ok(()); + } + Err(Status::failed_precondition(format!( + "Kubernetes Secret '{}' in namespace '{}' is not managed by OpenShell for this provider credential", + reference.secret_name, reference.namespace + ))) +} + +fn trimmed_config_string<'a>(field_name: &str, value: &'a str) -> CoreResult<&'a str> { + let trimmed = value.trim(); + if trimmed.is_empty() { + return Err(Error::config(format!( + "[openshell.credential_drivers.kubernetes-secrets] {field_name} must not be empty" + ))); + } + if trimmed.len() != value.len() { + return Err(Error::config(format!( + "[openshell.credential_drivers.kubernetes-secrets] {field_name} must not contain leading or trailing whitespace" + ))); + } + Ok(trimmed) +} + +fn required_handle_component<'a>(field_name: &str, value: &'a str) -> Result<&'a str, Status> { + let trimmed = value.trim(); + if trimmed.is_empty() { + return Err(Status::invalid_argument(format!( + "kubernetes-secrets credential handle {field_name} is required" + ))); + } + if trimmed.len() != value.len() { + return Err(Status::invalid_argument(format!( + "kubernetes-secrets credential handle {field_name} must not contain leading or trailing whitespace" + ))); + } + Ok(trimmed) +} + +fn is_dns_subdomain(value: &str) -> bool { + !value.is_empty() + && value.len() <= 253 + && value.split('.').all(is_dns_label) + && !value.contains("..") +} + +fn is_dns_label(value: &str) -> bool { + !value.is_empty() + && value.len() <= 63 + && value + .bytes() + .all(|byte| byte.is_ascii_lowercase() || byte.is_ascii_digit() || byte == b'-') + && value + .as_bytes() + .first() + .is_some_and(u8::is_ascii_alphanumeric) + && value + .as_bytes() + .last() + .is_some_and(u8::is_ascii_alphanumeric) +} + +fn is_secret_data_key(value: &str) -> bool { + !value.is_empty() + && value.len() <= 253 + && value + .bytes() + .all(|byte| 
byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.')) +} + +#[cfg(test)] +mod tests { + use super::*; + use tonic::Code; + + fn handle(value: &str) -> CredentialHandle { + CredentialHandle { + driver: "kubernetes-secrets".to_string(), + handle: value.to_string(), + metadata: std::collections::HashMap::new(), + } + } + + #[test] + fn settings_parse_configured_namespace() { + let settings = KubernetesSecretsDriverSettings::from_table(&toml::toml! { + namespace = "openshell" + allow_reference_namespace = true + }) + .unwrap(); + + assert_eq!(settings.namespace, "openshell"); + assert!(settings.allow_reference_namespace); + } + + #[test] + fn settings_reject_unknown_fields() { + let err = KubernetesSecretsDriverSettings::from_table(&toml::toml! { + namespace = "openshell" + unknown = "value" + }) + .unwrap_err(); + + assert!(err.to_string().contains("unknown field")); + } + + #[test] + fn settings_reject_invalid_namespace() { + let err = KubernetesSecretsDriverSettings::from_table(&toml::toml! 
{ + namespace = "OpenShell" + }) + .unwrap_err(); + + assert!(err.to_string().contains("namespace")); + } + + #[test] + fn handle_resolves_secret_reference() { + let reference = KubernetesSecretsCredentialDriver::resolve_handle( + &handle("v1:openshell:provider-secret"), + "API_KEY", + ) + .unwrap(); + + assert_eq!(reference.namespace, "openshell"); + assert_eq!(reference.secret_name, "provider-secret"); + assert_eq!(reference.key, "API_KEY"); + } + + #[test] + fn handle_rejects_malformed_value() { + let err = KubernetesSecretsCredentialDriver::resolve_handle( + &handle("provider-secret"), + "API_KEY", + ) + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("malformed")); + } + + #[test] + fn handle_rejects_invalid_namespace() { + let err = KubernetesSecretsCredentialDriver::resolve_handle( + &handle("v1:OpenShell:provider-secret"), + "API_KEY", + ) + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("namespace")); + } + + #[test] + fn handle_rejects_invalid_secret_name() { + let err = KubernetesSecretsCredentialDriver::resolve_handle( + &handle("v1:openshell:ProviderSecret"), + "API_KEY", + ) + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("Secret name")); + } + + #[test] + fn handle_rejects_invalid_credential_key() { + let err = KubernetesSecretsCredentialDriver::resolve_handle( + &handle("v1:openshell:provider-secret"), + "api/key", + ) + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("data key")); + } + + #[test] + fn managed_secret_names_are_stable_dns_subdomains() { + let name = managed_secret_name("openai-prod", "OPENAI_API_KEY"); + + assert!(name.starts_with("openshell-cred-")); + assert!(is_dns_subdomain(&name)); + assert_eq!(name, managed_secret_name("openai-prod", "OPENAI_API_KEY")); + } + + #[test] + fn managed_secret_carries_owner_metadata() { + let 
owner_id = credential_owner_id("openai-prod", "OPENAI_API_KEY"); + let secret = managed_secret("provider-secret", "OPENAI_API_KEY", &owner_id, "sk-test"); + + assert_eq!( + secret + .metadata + .labels + .as_ref() + .and_then(|labels| labels.get(MANAGED_BY_LABEL)) + .map(String::as_str), + Some(MANAGED_BY_VALUE) + ); + assert_eq!( + secret + .metadata + .annotations + .as_ref() + .and_then(|annotations| annotations.get(OWNER_ANNOTATION)) + .map(String::as_str), + Some(owner_id.as_str()) + ); + } + + #[test] + fn expected_secret_name_rejects_arbitrary_handle_names() { + let err = + validate_expected_secret_name("openai-prod", "OPENAI_API_KEY", "preexisting-secret") + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("does not match")); + } + + #[test] + fn ownership_check_accepts_matching_managed_secret() { + let owner_id = credential_owner_id("openai-prod", "OPENAI_API_KEY"); + let secret_name = managed_secret_name("openai-prod", "OPENAI_API_KEY"); + let reference = KubernetesSecretReference { + namespace: "openshell".to_string(), + secret_name: secret_name.clone(), + key: "OPENAI_API_KEY".to_string(), + }; + let secret = managed_secret(&secret_name, "OPENAI_API_KEY", &owner_id, "sk-test"); + + ensure_secret_is_managed_for(&secret, &reference, &owner_id).unwrap(); + } + + #[test] + fn ownership_check_rejects_unmanaged_secret() { + let owner_id = credential_owner_id("openai-prod", "OPENAI_API_KEY"); + let reference = KubernetesSecretReference { + namespace: "openshell".to_string(), + secret_name: "provider-secret".to_string(), + key: "OPENAI_API_KEY".to_string(), + }; + let secret = Secret { + metadata: ObjectMeta { + name: Some("provider-secret".to_string()), + ..Default::default() + }, + ..Default::default() + }; + + let err = ensure_secret_is_managed_for(&secret, &reference, &owner_id).unwrap_err(); + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("is not managed by OpenShell")); 
+ } + + #[test] + fn ownership_check_rejects_different_provider_credential() { + let owner_id = credential_owner_id("openai-prod", "OPENAI_API_KEY"); + let other_owner_id = credential_owner_id("other-provider", "OPENAI_API_KEY"); + let secret_name = managed_secret_name("openai-prod", "OPENAI_API_KEY"); + let reference = KubernetesSecretReference { + namespace: "openshell".to_string(), + secret_name: secret_name.clone(), + key: "OPENAI_API_KEY".to_string(), + }; + let secret = managed_secret(&secret_name, "OPENAI_API_KEY", &other_owner_id, "sk-test"); + + let err = ensure_secret_is_managed_for(&secret, &reference, &owner_id).unwrap_err(); + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("is not managed by OpenShell")); + } +} diff --git a/crates/openshell-driver-kubernetes-secrets/src/main.rs b/crates/openshell-driver-kubernetes-secrets/src/main.rs new file mode 100644 index 000000000..bb3cdacbd --- /dev/null +++ b/crates/openshell-driver-kubernetes-secrets/src/main.rs @@ -0,0 +1,145 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use std::io; +use std::os::unix::fs::{FileTypeExt, PermissionsExt}; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use clap::Parser; +use futures::Stream; +use miette::{IntoDiagnostic, Result, miette}; +use openshell_core::VERSION; +use openshell_core::proto::credentials::v1::credential_driver_server::CredentialDriverServer; +use openshell_driver_kubernetes_secrets::{ + CredentialDriverService, KubernetesSecretsCredentialDriver, +}; +use tokio::net::{UnixListener, UnixStream}; +use tracing::info; +use tracing_subscriber::EnvFilter; + +#[derive(Parser, Debug)] +#[command(name = "openshell-driver-kubernetes-secrets")] +#[command(version = VERSION)] +struct Args { + #[arg(long, env = "OPENSHELL_CREDENTIAL_DRIVER_SOCKET")] + bind_socket: PathBuf, + + #[arg(long, env = "OPENSHELL_LOG_LEVEL", default_value = "info")] + log_level: String, + + #[arg(long, env = "OPENSHELL_KUBERNETES_SECRETS_NAMESPACE")] + namespace: Option, + + #[arg( + long, + env = "OPENSHELL_KUBERNETES_SECRETS_ALLOW_REFERENCE_NAMESPACE", + default_value_t = false + )] + allow_reference_namespace: bool, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&args.log_level)), + ) + .init(); + + let driver = KubernetesSecretsCredentialDriver::from_config(&driver_config(&args)) + .await + .into_diagnostic()?; + + prepare_socket(&args.bind_socket)?; + let listener = UnixListener::bind(&args.bind_socket).into_diagnostic()?; + restrict_socket_permissions(&args.bind_socket)?; + + info!( + socket = %args.bind_socket.display(), + "Starting Kubernetes Secrets credential driver" + ); + let result = tonic::transport::Server::builder() + .add_service(CredentialDriverServer::new(CredentialDriverService::new( + driver, + ))) + .serve_with_incoming(UnixIncoming::new(listener)) + .await + 
.into_diagnostic(); + let _ = std::fs::remove_file(&args.bind_socket); + result +} + +fn driver_config(args: &Args) -> toml::Table { + let mut config = toml::Table::new(); + if let Some(namespace) = args.namespace.as_ref() { + config.insert( + "namespace".to_string(), + toml::Value::String(namespace.clone()), + ); + } + if args.allow_reference_namespace { + config.insert( + "allow_reference_namespace".to_string(), + toml::Value::Boolean(true), + ); + } + config +} + +fn prepare_socket(socket_path: &Path) -> Result<()> { + let parent = socket_path.parent().ok_or_else(|| { + miette!( + "credential driver socket path '{}' has no parent directory", + socket_path.display() + ) + })?; + std::fs::create_dir_all(parent).into_diagnostic()?; + + match std::fs::symlink_metadata(socket_path) { + Ok(metadata) if metadata.file_type().is_socket() => { + std::fs::remove_file(socket_path).into_diagnostic()?; + } + Ok(_) => { + return Err(miette!( + "credential driver socket path '{}' exists but is not a Unix socket", + socket_path.display() + )); + } + Err(err) if err.kind() == io::ErrorKind::NotFound => {} + Err(err) => return Err(err).into_diagnostic(), + } + Ok(()) +} + +fn restrict_socket_permissions(socket_path: &Path) -> Result<()> { + let mut permissions = std::fs::metadata(socket_path) + .into_diagnostic()? 
+ .permissions(); + permissions.set_mode(0o600); + std::fs::set_permissions(socket_path, permissions).into_diagnostic() +} + +struct UnixIncoming { + listener: UnixListener, +} + +impl UnixIncoming { + fn new(listener: UnixListener) -> Self { + Self { listener } + } +} + +impl Stream for UnixIncoming { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut().listener.poll_accept(cx) { + Poll::Ready(Ok((stream, _addr))) => Poll::Ready(Some(Ok(stream))), + Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/crates/openshell-driver-vault/Cargo.toml b/crates/openshell-driver-vault/Cargo.toml new file mode 100644 index 000000000..2d3878ed6 --- /dev/null +++ b/crates/openshell-driver-vault/Cargo.toml @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-driver-vault" +description = "Vault credential driver for OpenShell" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true + +[[bin]] +name = "openshell-driver-vault" +path = "src/main.rs" + +[dependencies] +openshell-core = { path = "../openshell-core", default-features = false } + +clap = { workspace = true } +futures = { workspace = true } +miette = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +sha2 = { workspace = true } +tokio = { workspace = true } +toml = { workspace = true } +tonic = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } + +[dev-dependencies] +tempfile = "3" +wiremock = "0.6" + +[lints] +workspace = true diff --git a/crates/openshell-driver-vault/src/lib.rs b/crates/openshell-driver-vault/src/lib.rs new file mode 100644 index 
000000000..5ba60d50c --- /dev/null +++ b/crates/openshell-driver-vault/src/lib.rs @@ -0,0 +1,1220 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Credential driver backed by a Vault-compatible HTTP API. + +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use openshell_core::VERSION; +use openshell_core::proto::CredentialHandle; +use openshell_core::proto::credentials::v1::{ + DeleteCredentialRequest, DeleteCredentialResponse, GetCredentialDriverCapabilitiesRequest, + GetCredentialDriverCapabilitiesResponse, ListCredentialsRequest, ListCredentialsResponse, + ResolveCredentialRequest, ResolveCredentialsRequest, ResolveCredentialsResponse, + ResolvedCredential, StoreCredentialRequest, StoreCredentialResponse, + credential_driver_server::CredentialDriver, +}; +use openshell_core::{Error, Result as CoreResult}; +use reqwest::{StatusCode, Url}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use tonic::{Request, Response, Status}; + +const DEFAULT_MOUNT: &str = "secret"; +const DEFAULT_AUTH_METHOD: &str = "kubernetes"; +const DEFAULT_KUBERNETES_AUTH_MOUNT: &str = "kubernetes"; +const DEFAULT_SERVICE_ACCOUNT_TOKEN_PATH: &str = + "/var/run/secrets/kubernetes.io/serviceaccount/token"; +const DEFAULT_TIMEOUT_SECS: u64 = 10; +const HANDLE_VERSION: &str = "v1"; +const STORED_VALUE_KEY: &str = "value"; + +pub struct VaultCredentialDriver { + client: reqwest::Client, + settings: VaultDriverSettings, +} + +#[derive(Debug, Clone)] +pub struct CredentialDriverService { + driver: VaultCredentialDriver, +} + +impl CredentialDriverService { + #[must_use] + pub fn new(driver: VaultCredentialDriver) -> Self { + Self { driver } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct VaultDriverSettings { + address: Url, + mount: String, + kv_version: KvVersion, + auth: VaultAuthSettings, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum 
VaultAuthSettings { + Kubernetes { + role: String, + auth_mount: String, + service_account_token_path: PathBuf, + }, + TokenFile { + token_path: PathBuf, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum KvVersion { + V1, + V2, +} + +#[derive(Debug, Clone, Default, Deserialize)] +#[serde(default, deny_unknown_fields)] +struct VaultDriverConfig { + address: Option, + mount: Option, + kv_version: Option, + auth_method: Option, + role: Option, + kubernetes_auth_mount: Option, + service_account_token_path: Option, + token_path: Option, + timeout_secs: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct VaultSecretReference { + api_path: String, + key: String, + kv_version: KvVersion, +} + +#[derive(Debug, Serialize)] +struct KubernetesLoginRequest<'a> { + role: &'a str, + jwt: &'a str, +} + +#[derive(Debug, Deserialize)] +struct KubernetesLoginResponse { + auth: Option, +} + +#[derive(Debug, Deserialize)] +struct KubernetesLoginAuth { + client_token: String, +} + +impl VaultCredentialDriver { + pub const NAME: &'static str = "vault"; + + pub fn from_config(config: &toml::Table) -> CoreResult { + let settings = VaultDriverSettings::from_table(config)?; + let timeout_secs = timeout_secs(config)?; + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(timeout_secs)) + .build() + .map_err(|err| { + Error::config(format!( + "failed to configure vault credential driver: {err}" + )) + })?; + Ok(Self { client, settings }) + } + + pub async fn store_credential( + &self, + request: StoreCredentialRequest, + ) -> Result { + let logical_path = if let Some(existing_handle) = request.existing_handle.as_ref() { + Self::logical_path_from_handle(existing_handle)? 
+ } else { + managed_secret_path(&request.provider_name, &request.credential_key) + }; + validate_secret_path(&logical_path).map_err(Status::invalid_argument)?; + validate_managed_secret_path( + &request.provider_name, + &request.credential_key, + &logical_path, + )?; + let token = self.auth_token().await?; + let reference = VaultSecretReference { + api_path: api_path_for_reference( + &self.settings.mount, + self.settings.kv_version, + &logical_path, + ), + key: STORED_VALUE_KEY.to_string(), + kv_version: self.settings.kv_version, + }; + self.store_secret_value(&reference, &request.value, &token) + .await?; + Ok(CredentialHandle { + driver: Self::NAME.to_string(), + handle: format!("{HANDLE_VERSION}:{logical_path}"), + metadata: std::collections::HashMap::new(), + }) + } + + pub async fn delete_credential(&self, request: DeleteCredentialRequest) -> Result<(), Status> { + let handle = Self::handle_from_request("delete", request.handle)?; + let logical_path = Self::logical_path_from_handle(&handle)?; + validate_managed_secret_path( + &request.provider_name, + &request.credential_key, + &logical_path, + )?; + let token = self.auth_token().await?; + let api_path = delete_api_path_for_reference( + &self.settings.mount, + self.settings.kv_version, + &logical_path, + ); + self.delete_secret_value(&api_path, &token).await + } + + pub async fn resolve_credentials( + &self, + requests: Vec, + ) -> Result, Status> { + let mut responses = Vec::with_capacity(requests.len()); + let mut resolved_requests = Vec::with_capacity(requests.len()); + for request in requests { + let handle = Self::handle_from_request(&request.request_id, request.handle)?; + let logical_path = Self::logical_path_from_handle(&handle)?; + validate_managed_secret_path( + &request.provider_name, + &request.credential_key, + &logical_path, + )?; + let reference = VaultSecretReference { + api_path: api_path_for_reference( + &self.settings.mount, + self.settings.kv_version, + &logical_path, + ), + key: 
STORED_VALUE_KEY.to_string(), + kv_version: self.settings.kv_version, + }; + resolved_requests.push((request.request_id, reference)); + } + + let token = self.auth_token().await?; + for (request_id, reference) in resolved_requests { + let value = self.resolve_secret_value(&reference, &token).await?; + responses.push(ResolvedCredential { + request_id, + value, + expires_at_ms: 0, + }); + } + Ok(responses) + } + + fn handle_from_request( + request_id: &str, + handle: Option, + ) -> Result { + handle.ok_or_else(|| { + Status::invalid_argument(format!( + "vault credential request '{request_id}' is missing handle" + )) + }) + } + + fn logical_path_from_handle(handle: &CredentialHandle) -> Result { + let logical_path = handle + .handle + .strip_prefix(&format!("{HANDLE_VERSION}:")) + .ok_or_else(|| Status::invalid_argument("vault credential handle is malformed"))?; + validate_secret_path(logical_path).map_err(Status::invalid_argument)?; + Ok(logical_path.to_string()) + } + + async fn auth_token(&self) -> Result { + match &self.settings.auth { + VaultAuthSettings::TokenFile { token_path } => { + read_secret_file(token_path, "Vault token file").await + } + VaultAuthSettings::Kubernetes { + role, + auth_mount, + service_account_token_path, + } => { + let jwt = read_secret_file( + service_account_token_path, + "Kubernetes service account token", + ) + .await?; + self.login_kubernetes(role, auth_mount, &jwt).await + } + } + } + + async fn login_kubernetes( + &self, + role: &str, + auth_mount: &str, + jwt: &str, + ) -> Result { + let path = format!("auth/{auth_mount}/login"); + let url = self.url_for_path(&path)?; + let response = self + .client + .post(url) + .json(&KubernetesLoginRequest { role, jwt }) + .send() + .await + .map_err(|err| { + Status::unavailable(format!("Vault Kubernetes auth request failed: {err}")) + })?; + let status = response.status(); + if !status.is_success() { + return Err(vault_auth_status(status)); + } + + let body = response + .json::() + .await + 
.map_err(|_| { + Status::failed_precondition("Vault Kubernetes auth returned invalid JSON") + })?; + let token = body + .auth + .map(|auth| auth.client_token) + .unwrap_or_default() + .trim() + .to_string(); + if token.is_empty() { + return Err(Status::failed_precondition( + "Vault Kubernetes auth returned an empty client token", + )); + } + Ok(token) + } + + async fn resolve_secret_value( + &self, + reference: &VaultSecretReference, + token: &str, + ) -> Result { + let url = self.url_for_path(&reference.api_path)?; + let response = self + .client + .get(url) + .header("X-Vault-Token", token) + .send() + .await + .map_err(|err| { + Status::unavailable(format!( + "Vault secret read failed for path '{}': {err}", + reference.api_path + )) + })?; + let status = response.status(); + if !status.is_success() { + return Err(vault_secret_status(status, &reference.api_path)); + } + + let body = response.json::().await.map_err(|_| { + Status::failed_precondition(format!( + "Vault secret path '{}' returned invalid JSON", + reference.api_path + )) + })?; + extract_secret_value(&body, reference) + } + + async fn store_secret_value( + &self, + reference: &VaultSecretReference, + value: &str, + token: &str, + ) -> Result<(), Status> { + let url = self.url_for_path(&reference.api_path)?; + let body = match reference.kv_version { + KvVersion::V1 => serde_json::json!({ &reference.key: value }), + KvVersion::V2 => serde_json::json!({ "data": { &reference.key: value } }), + }; + let response = self + .client + .post(url) + .header("X-Vault-Token", token) + .json(&body) + .send() + .await + .map_err(|err| { + Status::unavailable(format!( + "Vault secret write failed for path '{}': {err}", + reference.api_path + )) + })?; + let status = response.status(); + if status.is_success() { + Ok(()) + } else { + Err(vault_secret_status(status, &reference.api_path)) + } + } + + async fn delete_secret_value(&self, api_path: &str, token: &str) -> Result<(), Status> { + let url = 
self.url_for_path(api_path)?; + let response = self + .client + .delete(url) + .header("X-Vault-Token", token) + .send() + .await + .map_err(|err| { + Status::unavailable(format!( + "Vault secret delete failed for path '{api_path}': {err}" + )) + })?; + let status = response.status(); + if status.is_success() || status == StatusCode::NOT_FOUND { + Ok(()) + } else { + Err(vault_secret_status(status, api_path)) + } + } + + fn url_for_path(&self, path: &str) -> Result { + self.settings + .address + .join(&format!("v1/{path}")) + .map_err(|err| Status::internal(format!("failed to build Vault URL: {err}"))) + } +} + +impl std::fmt::Debug for VaultCredentialDriver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("VaultCredentialDriver") + .field("settings", &self.settings) + .finish_non_exhaustive() + } +} + +impl Clone for VaultCredentialDriver { + fn clone(&self) -> Self { + Self { + client: self.client.clone(), + settings: self.settings.clone(), + } + } +} + +#[tonic::async_trait] +impl CredentialDriver for CredentialDriverService { + async fn get_capabilities( + &self, + _request: Request, + ) -> Result, Status> { + Ok(Response::new(GetCredentialDriverCapabilitiesResponse { + driver_name: VaultCredentialDriver::NAME.to_string(), + driver_version: VERSION.to_string(), + backend_kind: VaultCredentialDriver::NAME.to_string(), + supports_list: false, + supports_expires_at: false, + })) + } + + async fn store_credential( + &self, + request: Request, + ) -> Result, Status> { + let handle = self.driver.store_credential(request.into_inner()).await?; + Ok(Response::new(StoreCredentialResponse { + handle: Some(handle), + })) + } + + async fn delete_credential( + &self, + request: Request, + ) -> Result, Status> { + self.driver.delete_credential(request.into_inner()).await?; + Ok(Response::new(DeleteCredentialResponse {})) + } + + async fn resolve_credentials( + &self, + request: Request, + ) -> Result, Status> { + let credentials = self 
+ .driver + .resolve_credentials(request.into_inner().credentials) + .await?; + Ok(Response::new(ResolveCredentialsResponse { credentials })) + } + + async fn list_credentials( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "vault credential driver does not support listing credentials", + )) + } +} + +impl VaultDriverSettings { + fn from_table(config: &toml::Table) -> CoreResult { + let config: VaultDriverConfig = + toml::Value::Table(config.clone()) + .try_into() + .map_err(|err| { + Error::config(format!( + "invalid [openshell.credential_drivers.vault]: {err}" + )) + })?; + let address = config + .address + .as_deref() + .ok_or_else(|| { + Error::config("[openshell.credential_drivers.vault] address is required") + }) + .and_then(vault_address)?; + let mount = config + .mount + .as_deref() + .map_or_else(|| Ok(DEFAULT_MOUNT.to_string()), mount_config)?; + let kv_version = config + .kv_version + .as_deref() + .map_or_else(|| Ok(KvVersion::V2), KvVersion::parse_config)?; + let auth_method = config + .auth_method + .as_deref() + .unwrap_or(DEFAULT_AUTH_METHOD) + .trim(); + let auth = match auth_method { + "kubernetes" => { + if config.token_path.is_some() { + return Err(Error::config( + "[openshell.credential_drivers.vault] token_path requires auth_method = 'token_file'", + )); + } + let role = config.role.as_deref().ok_or_else(|| { + Error::config( + "[openshell.credential_drivers.vault] role is required for auth_method = 'kubernetes'", + ) + })?; + let role = trimmed_config_string("role", role)?.to_string(); + let auth_mount = config.kubernetes_auth_mount.as_deref().map_or_else( + || Ok(DEFAULT_KUBERNETES_AUTH_MOUNT.to_string()), + |mount| path_config("kubernetes_auth_mount", mount), + )?; + let service_account_token_path = config + .service_account_token_path + .unwrap_or_else(|| PathBuf::from(DEFAULT_SERVICE_ACCOUNT_TOKEN_PATH)); + VaultAuthSettings::Kubernetes { + role, + auth_mount, + service_account_token_path, + } + } + 
"token_file" => { + if config.role.is_some() + || config.kubernetes_auth_mount.is_some() + || config.service_account_token_path.is_some() + { + return Err(Error::config( + "[openshell.credential_drivers.vault] Kubernetes auth fields require auth_method = 'kubernetes'", + )); + } + let token_path = config.token_path.ok_or_else(|| { + Error::config( + "[openshell.credential_drivers.vault] token_path is required for auth_method = 'token_file'", + ) + })?; + VaultAuthSettings::TokenFile { token_path } + } + other => { + return Err(Error::config(format!( + "[openshell.credential_drivers.vault] auth_method must be 'kubernetes' or 'token_file', got '{other}'" + ))); + } + }; + + Ok(Self { + address, + mount, + kv_version, + auth, + }) + } +} + +impl KvVersion { + fn parse_config(value: &str) -> CoreResult { + match trimmed_config_string("kv_version", value)? { + "1" => Ok(Self::V1), + "2" => Ok(Self::V2), + other => Err(Error::config(format!( + "[openshell.credential_drivers.vault] kv_version must be '1' or '2', got '{other}'" + ))), + } + } +} + +fn vault_address(value: &str) -> CoreResult { + let value = trimmed_config_string("address", value)?; + let mut url = Url::parse(value).map_err(|_| { + Error::config("[openshell.credential_drivers.vault] address must be an absolute URL") + })?; + if !matches!(url.scheme(), "http" | "https") { + return Err(Error::config( + "[openshell.credential_drivers.vault] address must use http or https", + )); + } + if !url.username().is_empty() || url.password().is_some() { + return Err(Error::config( + "[openshell.credential_drivers.vault] address must not include credentials", + )); + } + if url.query().is_some() || url.fragment().is_some() { + return Err(Error::config( + "[openshell.credential_drivers.vault] address must not include query or fragment", + )); + } + if !url.path().ends_with('/') { + let path = format!("{}/", url.path().trim_end_matches('/')); + url.set_path(&path); + } + Ok(url) +} + +fn timeout_secs(table: &toml::Table) 
-> CoreResult { + let Some(value) = table.get("timeout_secs") else { + return Ok(DEFAULT_TIMEOUT_SECS); + }; + let timeout = value.as_integer().ok_or_else(|| { + Error::config( + "[openshell.credential_drivers.vault] timeout_secs must be a positive integer", + ) + })?; + if timeout <= 0 { + return Err(Error::config( + "[openshell.credential_drivers.vault] timeout_secs must be a positive integer", + )); + } + u64::try_from(timeout).map_err(|_| { + Error::config("[openshell.credential_drivers.vault] timeout_secs is too large") + }) +} + +fn mount_config(value: &str) -> CoreResult { + path_config("mount", value) +} + +fn path_config(field_name: &str, value: &str) -> CoreResult { + let value = trimmed_config_string(field_name, value)?; + validate_secret_path(value).map_err(|message| { + Error::config(format!( + "[openshell.credential_drivers.vault] {field_name} {message}" + )) + })?; + Ok(value.to_string()) +} + +fn trimmed_config_string<'a>(field_name: &str, value: &'a str) -> CoreResult<&'a str> { + let trimmed = value.trim(); + if trimmed.is_empty() { + return Err(Error::config(format!( + "[openshell.credential_drivers.vault] {field_name} must not be empty" + ))); + } + if trimmed.len() != value.len() { + return Err(Error::config(format!( + "[openshell.credential_drivers.vault] {field_name} must not contain leading or trailing whitespace" + ))); + } + Ok(trimmed) +} + +fn validate_secret_path(value: &str) -> Result<(), &'static str> { + if value.is_empty() { + return Err("must not be empty"); + } + if value.len() > 1024 { + return Err("must be 1024 bytes or fewer"); + } + if value.starts_with('/') || value.ends_with('/') { + return Err("must be a relative path without leading or trailing slash"); + } + if value.contains("//") { + return Err("must not contain empty path segments"); + } + for segment in value.split('/') { + if matches!(segment, "." | "..") { + return Err("must not contain '.' or '..' 
path segments"); + } + if !segment + .bytes() + .all(|byte| byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.')) + { + return Err("may only contain ASCII letters, digits, '-', '_', '.', and '/'"); + } + } + Ok(()) +} + +fn api_path_for_reference(mount: &str, kv_version: KvVersion, target: &str) -> String { + match kv_version { + KvVersion::V1 => { + if target == mount || target.starts_with(&format!("{mount}/")) { + target.to_string() + } else { + format!("{mount}/{target}") + } + } + KvVersion::V2 => { + let data_prefix = format!("{mount}/data/"); + if target.starts_with(&data_prefix) { + target.to_string() + } else { + let logical_path = target.strip_prefix(&format!("{mount}/")).unwrap_or(target); + format!("{mount}/data/{logical_path}") + } + } + } +} + +fn delete_api_path_for_reference(mount: &str, kv_version: KvVersion, target: &str) -> String { + match kv_version { + KvVersion::V1 => api_path_for_reference(mount, kv_version, target), + KvVersion::V2 => { + let metadata_prefix = format!("{mount}/metadata/"); + if target.starts_with(&metadata_prefix) { + target.to_string() + } else { + let logical_path = target.strip_prefix(&format!("{mount}/")).unwrap_or(target); + format!("{mount}/metadata/{logical_path}") + } + } + } +} + +fn managed_secret_path(provider_name: &str, credential_key: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(provider_name.as_bytes()); + hasher.update([0]); + hasher.update(credential_key.as_bytes()); + let digest = hasher.finalize(); + let hex = format!("{digest:x}"); + format!("openshell/provider-credentials/{}", &hex[..40]) +} + +fn validate_managed_secret_path( + provider_name: &str, + credential_key: &str, + logical_path: &str, +) -> Result<(), Status> { + let expected = managed_secret_path(provider_name, credential_key); + if logical_path == expected { + return Ok(()); + } + Err(Status::invalid_argument(format!( + "vault credential handle path does not match the managed path for provider credential 
'{credential_key}'" + ))) +} + +async fn read_secret_file(path: &Path, description: &str) -> Result { + let contents = tokio::fs::read_to_string(path).await.map_err(|err| { + Status::unauthenticated(format!( + "failed to read {description} '{}': {err}", + path.display() + )) + })?; + let value = contents.trim().to_string(); + if value.is_empty() { + return Err(Status::unauthenticated(format!( + "{description} '{}' is empty", + path.display() + ))); + } + Ok(value) +} + +fn vault_auth_status(status: StatusCode) -> Status { + match status { + StatusCode::UNAUTHORIZED => { + Status::unauthenticated("Vault Kubernetes auth rejected the service account token") + } + StatusCode::FORBIDDEN => { + Status::permission_denied("Vault Kubernetes auth denied the configured role") + } + other => Status::unavailable(format!("Vault Kubernetes auth returned HTTP {other}")), + } +} + +fn vault_secret_status(status: StatusCode, path: &str) -> Status { + match status { + StatusCode::UNAUTHORIZED => { + Status::unauthenticated("Vault rejected the credential driver token") + } + StatusCode::FORBIDDEN => Status::permission_denied(format!( + "Vault token is not allowed to read secret path '{path}'" + )), + StatusCode::NOT_FOUND => { + Status::not_found(format!("Vault secret path '{path}' was not found")) + } + other => Status::unavailable(format!("Vault secret path '{path}' returned HTTP {other}")), + } +} + +fn extract_secret_value( + body: &serde_json::Value, + reference: &VaultSecretReference, +) -> Result { + let data = body + .get("data") + .and_then(serde_json::Value::as_object) + .ok_or_else(|| { + Status::failed_precondition(format!( + "Vault secret path '{}' response is missing data", + reference.api_path + )) + })?; + let fields = match reference.kv_version { + KvVersion::V1 => data, + KvVersion::V2 => data + .get("data") + .and_then(serde_json::Value::as_object) + .ok_or_else(|| { + Status::failed_precondition(format!( + "Vault KV v2 secret path '{}' response is missing 
data.data", + reference.api_path + )) + })?, + }; + let value = fields.get(&reference.key).ok_or_else(|| { + Status::not_found(format!( + "Vault secret path '{}' does not contain key '{}'", + reference.api_path, reference.key + )) + })?; + value.as_str().map(str::to_string).ok_or_else(|| { + Status::failed_precondition(format!( + "Vault secret path '{}' key '{}' is not a string", + reference.api_path, reference.key + )) + }) +} + +#[cfg(test)] +mod tests { + use openshell_core::proto::CredentialHandle; + use tonic::Code; + use wiremock::matchers::{body_string_contains, header, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use super::*; + + fn handle(value: &str) -> CredentialHandle { + CredentialHandle { + driver: "vault".to_string(), + handle: value.to_string(), + metadata: std::collections::HashMap::new(), + } + } + + fn table(values: &[(&str, toml::Value)]) -> toml::Table { + values + .iter() + .map(|(key, value)| ((*key).to_string(), value.clone())) + .collect() + } + + fn token_file(token: &str) -> tempfile::NamedTempFile { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write(file.path(), token).unwrap(); + file + } + + #[test] + fn settings_parse_kubernetes_auth() { + let settings = VaultDriverSettings::from_table(&table(&[ + ( + "address", + toml::Value::String("http://vault:8200".to_string()), + ), + ("mount", toml::Value::String("team-secret".to_string())), + ("kv_version", toml::Value::String("1".to_string())), + ("auth_method", toml::Value::String("kubernetes".to_string())), + ("role", toml::Value::String("openshell-gateway".to_string())), + ])) + .unwrap(); + + assert_eq!(settings.mount, "team-secret"); + assert_eq!(settings.kv_version, KvVersion::V1); + assert!(matches!( + settings.auth, + VaultAuthSettings::Kubernetes { .. 
} + )); + } + + #[test] + fn settings_parse_token_file_auth() { + let settings = VaultDriverSettings::from_table(&table(&[ + ( + "address", + toml::Value::String("http://vault:8200".to_string()), + ), + ("auth_method", toml::Value::String("token_file".to_string())), + ( + "token_path", + toml::Value::String("/run/secrets/vault-token".to_string()), + ), + ])) + .unwrap(); + + assert!(matches!(settings.auth, VaultAuthSettings::TokenFile { .. })); + assert_eq!(settings.kv_version, KvVersion::V2); + } + + #[test] + fn settings_reject_unknown_fields() { + let err = VaultDriverSettings::from_table(&table(&[ + ( + "address", + toml::Value::String("http://vault:8200".to_string()), + ), + ("auth_method", toml::Value::String("token_file".to_string())), + ( + "token_path", + toml::Value::String("/run/secrets/vault-token".to_string()), + ), + ("token", toml::Value::String("literal-secret".to_string())), + ])) + .unwrap_err(); + + assert!(err.to_string().contains("unknown field")); + } + + #[test] + fn settings_reject_token_file_without_token_path() { + let err = VaultDriverSettings::from_table(&table(&[ + ( + "address", + toml::Value::String("http://vault:8200".to_string()), + ), + ("auth_method", toml::Value::String("token_file".to_string())), + ])) + .unwrap_err(); + + assert!(err.to_string().contains("token_path is required")); + } + + #[test] + fn api_path_builds_kv2_api_path_from_logical_path() { + assert_eq!( + api_path_for_reference( + "secret", + KvVersion::V2, + "openshell/provider-credentials/abc" + ), + "secret/data/openshell/provider-credentials/abc" + ); + } + + #[test] + fn delete_api_path_builds_kv2_metadata_path_from_logical_path() { + assert_eq!( + delete_api_path_for_reference( + "secret", + KvVersion::V2, + "openshell/provider-credentials/abc" + ), + "secret/metadata/openshell/provider-credentials/abc" + ); + } + + #[test] + fn handle_rejects_malformed_value() { + let err = VaultCredentialDriver::logical_path_from_handle(&handle("providers/nvidia")) + 
.unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("malformed")); + } + + #[test] + fn handle_rejects_invalid_path() { + let err = + VaultCredentialDriver::logical_path_from_handle(&handle("v1:../providers/nvidia")) + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("path segments")); + } + + #[test] + fn handle_rejects_unexpected_managed_path() { + let err = validate_managed_secret_path( + "nvidia-prod", + "NVIDIA_API_KEY", + "openshell/provider-credentials/other", + ) + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("managed path")); + } + + #[tokio::test] + async fn store_and_resolve_token_file_kv2_secret() { + let mock_server = MockServer::start().await; + let logical_path = managed_secret_path("nvidia-prod", "NVIDIA_API_KEY"); + let api_path = format!("/v1/secret/data/{logical_path}"); + Mock::given(method("POST")) + .and(path(api_path.as_str())) + .and(header("x-vault-token", "dev-token")) + .and(body_string_contains("nvapi-test")) + .respond_with(ResponseTemplate::new(200)) + .mount(&mock_server) + .await; + Mock::given(method("GET")) + .and(path(api_path.as_str())) + .and(header("x-vault-token", "dev-token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "data": { + "data": { + "value": "nvapi-test" + }, + "metadata": { + "version": 1 + } + } + }))) + .mount(&mock_server) + .await; + let token_file = token_file("dev-token\n"); + let driver = VaultCredentialDriver::from_config(&table(&[ + ("address", toml::Value::String(mock_server.uri())), + ("auth_method", toml::Value::String("token_file".to_string())), + ( + "token_path", + toml::Value::String(token_file.path().display().to_string()), + ), + ])) + .unwrap(); + + let stored = driver + .store_credential(StoreCredentialRequest { + provider_name: "nvidia-prod".to_string(), + credential_key: "NVIDIA_API_KEY".to_string(), + value: 
"nvapi-test".to_string(), + existing_handle: None, + }) + .await + .unwrap(); + assert_eq!(stored.handle, format!("v1:{logical_path}")); + + let resolved = driver + .resolve_credentials(vec![ResolveCredentialRequest { + request_id: "credential-0".to_string(), + provider_name: "nvidia-prod".to_string(), + credential_key: "NVIDIA_API_KEY".to_string(), + handle: Some(stored), + }]) + .await + .unwrap(); + + assert_eq!(resolved[0].value, "nvapi-test"); + } + + #[tokio::test] + async fn store_with_existing_handle_reuses_logical_path() { + let mock_server = MockServer::start().await; + let logical_path = managed_secret_path("nvidia-prod", "NVIDIA_API_KEY"); + Mock::given(method("POST")) + .and(path(format!("/v1/secret/data/{logical_path}"))) + .and(header("x-vault-token", "dev-token")) + .and(body_string_contains("updated-secret")) + .respond_with(ResponseTemplate::new(200)) + .mount(&mock_server) + .await; + let token_file = token_file("dev-token\n"); + let driver = VaultCredentialDriver::from_config(&table(&[ + ("address", toml::Value::String(mock_server.uri())), + ("auth_method", toml::Value::String("token_file".to_string())), + ( + "token_path", + toml::Value::String(token_file.path().display().to_string()), + ), + ])) + .unwrap(); + + let stored = driver + .store_credential(StoreCredentialRequest { + provider_name: "nvidia-prod".to_string(), + credential_key: "NVIDIA_API_KEY".to_string(), + value: "updated-secret".to_string(), + existing_handle: Some(handle(&format!("v1:{logical_path}"))), + }) + .await + .unwrap(); + + assert_eq!(stored.handle, format!("v1:{logical_path}")); + } + + #[tokio::test] + async fn delete_token_file_kv2_secret() { + let mock_server = MockServer::start().await; + let logical_path = managed_secret_path("nvidia-prod", "NVIDIA_API_KEY"); + Mock::given(method("DELETE")) + .and(path(format!("/v1/secret/metadata/{logical_path}"))) + .and(header("x-vault-token", "dev-token")) + .respond_with(ResponseTemplate::new(204)) + .mount(&mock_server) + 
.await; + let token_file = token_file("dev-token\n"); + let driver = VaultCredentialDriver::from_config(&table(&[ + ("address", toml::Value::String(mock_server.uri())), + ("auth_method", toml::Value::String("token_file".to_string())), + ( + "token_path", + toml::Value::String(token_file.path().display().to_string()), + ), + ])) + .unwrap(); + + driver + .delete_credential(DeleteCredentialRequest { + provider_name: "nvidia-prod".to_string(), + credential_key: "NVIDIA_API_KEY".to_string(), + handle: Some(handle(&format!("v1:{logical_path}"))), + }) + .await + .unwrap(); + } + + #[tokio::test] + async fn resolve_kubernetes_auth_kv2_secret() { + let mock_server = MockServer::start().await; + let logical_path = managed_secret_path("github-prod", "GITHUB_TOKEN"); + Mock::given(method("POST")) + .and(path("/v1/auth/kubernetes/login")) + .and(body_string_contains("openshell-gateway")) + .and(body_string_contains("jwt-test")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "auth": { + "client_token": "bao-token" + } + }))) + .mount(&mock_server) + .await; + Mock::given(method("GET")) + .and(path(format!("/v1/secret/data/{logical_path}"))) + .and(header("x-vault-token", "bao-token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "data": { + "data": { + "value": "ghp-test" + } + } + }))) + .mount(&mock_server) + .await; + let jwt_file = token_file("jwt-test\n"); + let driver = VaultCredentialDriver::from_config(&table(&[ + ("address", toml::Value::String(mock_server.uri())), + ("auth_method", toml::Value::String("kubernetes".to_string())), + ("role", toml::Value::String("openshell-gateway".to_string())), + ( + "service_account_token_path", + toml::Value::String(jwt_file.path().display().to_string()), + ), + ])) + .unwrap(); + + let resolved = driver + .resolve_credentials(vec![ResolveCredentialRequest { + request_id: "credential-0".to_string(), + provider_name: "github-prod".to_string(), + credential_key: 
"GITHUB_TOKEN".to_string(), + handle: Some(handle(&format!("v1:{logical_path}"))), + }]) + .await + .unwrap(); + + assert_eq!(resolved[0].value, "ghp-test"); + } + + #[tokio::test] + async fn resolve_maps_missing_key() { + let mock_server = MockServer::start().await; + let logical_path = managed_secret_path("nvidia-prod", "NVIDIA_API_KEY"); + Mock::given(method("GET")) + .and(path(format!("/v1/secret/data/{logical_path}"))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "data": { + "data": {} + } + }))) + .mount(&mock_server) + .await; + let token_file = token_file("dev-token\n"); + let driver = VaultCredentialDriver::from_config(&table(&[ + ("address", toml::Value::String(mock_server.uri())), + ("auth_method", toml::Value::String("token_file".to_string())), + ( + "token_path", + toml::Value::String(token_file.path().display().to_string()), + ), + ])) + .unwrap(); + + let err = driver + .resolve_credentials(vec![ResolveCredentialRequest { + request_id: "credential-0".to_string(), + provider_name: "nvidia-prod".to_string(), + credential_key: "NVIDIA_API_KEY".to_string(), + handle: Some(handle(&format!("v1:{logical_path}"))), + }]) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::NotFound); + assert!(err.message().contains("does not contain key")); + } + + #[tokio::test] + async fn resolve_maps_permission_denied() { + let mock_server = MockServer::start().await; + let logical_path = managed_secret_path("nvidia-prod", "NVIDIA_API_KEY"); + Mock::given(method("GET")) + .and(path(format!("/v1/secret/data/{logical_path}"))) + .respond_with(ResponseTemplate::new(403)) + .mount(&mock_server) + .await; + let token_file = token_file("dev-token\n"); + let driver = VaultCredentialDriver::from_config(&table(&[ + ("address", toml::Value::String(mock_server.uri())), + ("auth_method", toml::Value::String("token_file".to_string())), + ( + "token_path", + toml::Value::String(token_file.path().display().to_string()), + ), + ])) + .unwrap(); + 
+ let err = driver + .resolve_credentials(vec![ResolveCredentialRequest { + request_id: "credential-0".to_string(), + provider_name: "nvidia-prod".to_string(), + credential_key: "NVIDIA_API_KEY".to_string(), + handle: Some(handle(&format!("v1:{logical_path}"))), + }]) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::PermissionDenied); + } +} diff --git a/crates/openshell-driver-vault/src/main.rs b/crates/openshell-driver-vault/src/main.rs new file mode 100644 index 000000000..1b8265ef8 --- /dev/null +++ b/crates/openshell-driver-vault/src/main.rs @@ -0,0 +1,180 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::io; +use std::os::unix::fs::{FileTypeExt, PermissionsExt}; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use clap::Parser; +use futures::Stream; +use miette::{IntoDiagnostic, Result, miette}; +use openshell_core::VERSION; +use openshell_core::proto::credentials::v1::credential_driver_server::CredentialDriverServer; +use openshell_driver_vault::{CredentialDriverService, VaultCredentialDriver}; +use tokio::net::{UnixListener, UnixStream}; +use tracing::info; +use tracing_subscriber::EnvFilter; + +#[derive(Parser, Debug)] +#[command(name = "openshell-driver-vault")] +#[command(version = VERSION)] +struct Args { + #[arg(long, env = "OPENSHELL_CREDENTIAL_DRIVER_SOCKET")] + bind_socket: PathBuf, + + #[arg(long, env = "OPENSHELL_LOG_LEVEL", default_value = "info")] + log_level: String, + + #[arg(long, env = "OPENSHELL_VAULT_ADDRESS")] + address: Option, + + #[arg(long, env = "OPENSHELL_VAULT_MOUNT")] + mount: Option, + + #[arg(long, env = "OPENSHELL_VAULT_KV_VERSION")] + kv_version: Option, + + #[arg(long, env = "OPENSHELL_VAULT_AUTH_METHOD")] + auth_method: Option, + + #[arg(long, env = "OPENSHELL_VAULT_ROLE")] + role: Option, + + #[arg(long, env = "OPENSHELL_VAULT_KUBERNETES_AUTH_MOUNT")] + 
kubernetes_auth_mount: Option, + + #[arg(long, env = "OPENSHELL_VAULT_SERVICE_ACCOUNT_TOKEN_PATH")] + service_account_token_path: Option, + + #[arg(long, env = "OPENSHELL_VAULT_TOKEN_PATH")] + token_path: Option, + + #[arg(long, env = "OPENSHELL_VAULT_TIMEOUT_SECS")] + timeout_secs: Option, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&args.log_level)), + ) + .init(); + + let driver = VaultCredentialDriver::from_config(&driver_config(&args)).into_diagnostic()?; + + prepare_socket(&args.bind_socket)?; + let listener = UnixListener::bind(&args.bind_socket).into_diagnostic()?; + restrict_socket_permissions(&args.bind_socket)?; + + info!(socket = %args.bind_socket.display(), "Starting Vault credential driver"); + let result = tonic::transport::Server::builder() + .add_service(CredentialDriverServer::new(CredentialDriverService::new( + driver, + ))) + .serve_with_incoming(UnixIncoming::new(listener)) + .await + .into_diagnostic(); + let _ = std::fs::remove_file(&args.bind_socket); + result +} + +fn driver_config(args: &Args) -> toml::Table { + let mut config = toml::Table::new(); + insert_string(&mut config, "address", args.address.as_ref()); + insert_string(&mut config, "mount", args.mount.as_ref()); + insert_string(&mut config, "kv_version", args.kv_version.as_ref()); + insert_string(&mut config, "auth_method", args.auth_method.as_ref()); + insert_string(&mut config, "role", args.role.as_ref()); + insert_string( + &mut config, + "kubernetes_auth_mount", + args.kubernetes_auth_mount.as_ref(), + ); + insert_path( + &mut config, + "service_account_token_path", + args.service_account_token_path.as_ref(), + ); + insert_path(&mut config, "token_path", args.token_path.as_ref()); + if let Some(timeout_secs) = args.timeout_secs { + config.insert( + "timeout_secs".to_string(), + 
toml::Value::Integer(i64::try_from(timeout_secs).unwrap_or(i64::MAX)), + ); + } + config +} + +fn insert_string(config: &mut toml::Table, key: &str, value: Option<&String>) { + if let Some(value) = value { + config.insert(key.to_string(), toml::Value::String(value.clone())); + } +} + +fn insert_path(config: &mut toml::Table, key: &str, value: Option<&PathBuf>) { + if let Some(value) = value { + config.insert( + key.to_string(), + toml::Value::String(value.display().to_string()), + ); + } +} + +fn prepare_socket(socket_path: &Path) -> Result<()> { + let parent = socket_path.parent().ok_or_else(|| { + miette!( + "credential driver socket path '{}' has no parent directory", + socket_path.display() + ) + })?; + std::fs::create_dir_all(parent).into_diagnostic()?; + + match std::fs::symlink_metadata(socket_path) { + Ok(metadata) if metadata.file_type().is_socket() => { + std::fs::remove_file(socket_path).into_diagnostic()?; + } + Ok(_) => { + return Err(miette!( + "credential driver socket path '{}' exists but is not a Unix socket", + socket_path.display() + )); + } + Err(err) if err.kind() == io::ErrorKind::NotFound => {} + Err(err) => return Err(err).into_diagnostic(), + } + Ok(()) +} + +fn restrict_socket_permissions(socket_path: &Path) -> Result<()> { + let mut permissions = std::fs::metadata(socket_path) + .into_diagnostic()? 
+ .permissions(); + permissions.set_mode(0o600); + std::fs::set_permissions(socket_path, permissions).into_diagnostic() +} + +struct UnixIncoming { + listener: UnixListener, +} + +impl UnixIncoming { + fn new(listener: UnixListener) -> Self { + Self { listener } + } +} + +impl Stream for UnixIncoming { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut().listener.poll_accept(cx) { + Poll::Ready(Ok((stream, _addr))) => Poll::Ready(Some(Ok(stream))), + Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index 39a26b14e..c4d33f879 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -17,8 +17,11 @@ path = "src/main.rs" [dependencies] openshell-bootstrap = { path = "../openshell-bootstrap" } openshell-core = { path = "../openshell-core", default-features = false } +openshell-driver-db-credstore = { path = "../openshell-driver-db-credstore" } openshell-driver-docker = { path = "../openshell-driver-docker" } openshell-driver-kubernetes = { path = "../openshell-driver-kubernetes" } +openshell-driver-kubernetes-secrets = { path = "../openshell-driver-kubernetes-secrets" } +openshell-driver-vault = { path = "../openshell-driver-vault" } openshell-driver-podman = { path = "../openshell-driver-podman" } openshell-ocsf = { path = "../openshell-ocsf" } openshell-policy = { path = "../openshell-policy" } @@ -71,9 +74,11 @@ metrics = { workspace = true } metrics-exporter-prometheus = { workspace = true } # Utilities +base64 = { workspace = true } futures = { workspace = true } bytes = { workspace = true } pin-project-lite = { workspace = true } +ring = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } toml = { workspace = true } diff --git a/crates/openshell-server/src/cli.rs b/crates/openshell-server/src/cli.rs 
index ce7734262..e494f255f 100644 --- a/crates/openshell-server/src/cli.rs +++ b/crates/openshell-server/src/cli.rs @@ -367,6 +367,15 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> { ) .with_server_sans(args.server_sans.clone()) .with_loopback_service_http(args.enable_loopback_service_http); + + if let Some(gateway_file) = file.as_ref().map(|f| &f.openshell.gateway) { + if let Some(drivers) = &gateway_file.credential_drivers { + config = config.with_credential_drivers(drivers.clone()); + } + if let Some(default_driver) = &gateway_file.default_credential_driver { + config = config.with_default_credential_driver(Some(default_driver.clone())); + } + } validate_grpc_rate_limit_args( args.grpc_rate_limit_requests, args.grpc_rate_limit_window_seconds, diff --git a/crates/openshell-server/src/config_file.rs b/crates/openshell-server/src/config_file.rs index 39cf02bba..e3e090cd1 100644 --- a/crates/openshell-server/src/config_file.rs +++ b/crates/openshell-server/src/config_file.rs @@ -61,6 +61,11 @@ pub struct OpenShellRoot { /// independently of this crate. #[serde(default)] pub drivers: BTreeMap, + + /// `[openshell.credential_drivers.]` tables — passed verbatim to + /// credential driver implementations after gateway-level selection. + #[serde(default)] + pub credential_drivers: BTreeMap, } /// `[openshell.gateway]` section. 
@@ -88,6 +93,12 @@ pub struct GatewayFileSection { // ── Drivers ────────────────────────────────────────────────────────── #[serde(default)] pub compute_drivers: Option>, + #[serde(default)] + pub credential_drivers: Option>, + #[serde(default)] + pub default_credential_driver: Option, + #[serde(default)] + pub credential_storage: Option, // ── Sandbox / SSH ──────────────────────────────────────────────────── #[serde(default)] @@ -186,6 +197,11 @@ pub enum ConfigFileError { env: &'static str, cli: &'static str, }, + #[error("invalid gateway config field `{field}`: {message}")] + InvalidValue { + field: &'static str, + message: &'static str, + }, } /// Load and validate a TOML config file. @@ -218,6 +234,18 @@ pub fn load(path: &Path) -> Result { cli: "--db-url", }); } + if file + .openshell + .gateway + .credential_drivers + .as_ref() + .is_some_and(Vec::is_empty) + { + return Err(ConfigFileError::InvalidValue { + field: "openshell.gateway.credential_drivers", + message: "omit the field to use default encrypted gateway credential storage, or specify exactly one external credential driver", + }); + } Ok(file) } @@ -352,6 +380,7 @@ bind_address = "0.0.0.0:8080" health_bind_address = "0.0.0.0:8081" log_level = "info" compute_drivers = ["kubernetes"] +credential_drivers = ["kubernetes-secrets"] sandbox_namespace = "agents" grpc_rate_limit_requests = 120 grpc_rate_limit_window_seconds = 60 @@ -372,6 +401,9 @@ audience = "openshell-cli" [openshell.drivers.kubernetes] namespace = "agents" grpc_endpoint = "https://openshell-gateway.agents.svc:8080" + +[openshell.credential_drivers.kubernetes-secrets] +namespace = "agents" "#; let tmp = write_tmp(toml); let file = load(tmp.path()).expect("valid file parses"); @@ -385,7 +417,32 @@ grpc_endpoint = "https://openshell-gateway.agents.svc:8080" assert_eq!(gw.grpc_rate_limit_window_seconds, Some(60)); assert!(gw.tls.is_some()); assert!(gw.oidc.is_some()); + assert_eq!( + gw.credential_drivers.as_deref(), + 
Some(&["kubernetes-secrets".to_string()][..]) + ); + assert!(gw.default_credential_driver.is_none()); assert!(file.openshell.drivers.contains_key("kubernetes")); + assert!( + file.openshell + .credential_drivers + .contains_key("kubernetes-secrets") + ); + } + + #[test] + fn rejects_explicit_empty_credential_drivers() { + let tmp = write_tmp( + r" +[openshell.gateway] +credential_drivers = [] +", + ); + + let err = load(tmp.path()).unwrap_err(); + + assert!(err.to_string().contains("credential_drivers")); + assert!(err.to_string().contains("omit the field")); } #[test] diff --git a/crates/openshell-server/src/credentials.rs b/crates/openshell-server/src/credentials.rs new file mode 100644 index 000000000..ccb3f5e19 --- /dev/null +++ b/crates/openshell-server/src/credentials.rs @@ -0,0 +1,2033 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Gateway credential-driver runtime scaffolding. +//! +//! This module owns gateway-level credential-driver selection and resolution +//! dispatch. Concrete production backends and remote UDS transport plug in here +//! in later implementation slices. 
+ +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::path::{Path, PathBuf}; +#[cfg(unix)] +use std::process::Stdio; +use std::sync::Arc; +#[cfg(unix)] +use std::time::{Duration, Instant}; +#[cfg(unix)] +use std::{ + io::ErrorKind, + os::unix::fs::{FileTypeExt, MetadataExt}, +}; + +use async_trait::async_trait; +#[cfg(unix)] +use hyper_util::rt::TokioIo; +use openshell_core::proto::credentials::v1::{ + DeleteCredentialRequest, GetCredentialDriverCapabilitiesRequest, ResolveCredentialRequest, + ResolveCredentialsRequest, ResolvedCredential, StoreCredentialRequest, + credential_driver_client::CredentialDriverClient, +}; +use openshell_core::proto::{CredentialHandle, Provider}; +use openshell_core::{Config, Error, Result as CoreResult}; +use openshell_driver_db_credstore::{ + CredentialObjectWrite, DbCredstoreCredentialDriver, DbCredstoreObjectStore, + DbCredstoreWriteCondition, StoredCredentialObject, +}; +use openshell_driver_kubernetes_secrets::KubernetesSecretsCredentialDriver; +use openshell_driver_vault::VaultCredentialDriver; +#[cfg(unix)] +use tokio::net::UnixStream; +#[cfg(unix)] +use tokio::process::Command; +#[cfg(unix)] +use tonic::transport::{Channel, Endpoint}; +use tonic::{Request, Status}; +#[cfg(unix)] +use tower::service_fn; + +use crate::persistence::{PersistenceError, Store, WriteCondition}; + +const DEFAULT_CREDENTIAL_DRIVER_STARTUP_TIMEOUT_SECS: u64 = 10; +const COMMON_CREDENTIAL_DRIVER_FIELDS: &[&str] = &[ + "transport", + "socket_path", + "command", + "args", + "startup_timeout_secs", +]; +#[cfg(unix)] +const CREDENTIAL_DRIVER_CONNECT_INTERVAL: Duration = Duration::from_millis(100); + +#[async_trait] +pub trait CredentialDriver: std::fmt::Debug + Send + Sync { + async fn store_credential( + &self, + request: StoreCredentialRequest, + ) -> Result; + + async fn delete_credential(&self, request: DeleteCredentialRequest) -> Result<(), Status>; + + async fn resolve_credentials( + &self, + requests: Vec, + ) -> Result, Status>; 
+} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct ResolvedProviderCredentials { + pub values: HashMap, + pub expires_at_ms: HashMap, +} + +#[derive(Debug, Clone)] +pub struct CredentialRuntime { + registry: CredentialDriverRegistry, + drivers: BTreeMap>, + _driver_processes: Vec>, +} + +impl CredentialRuntime { + pub fn from_config(config: &Config) -> CoreResult { + Self::from_config_with_optional_store(config, None) + } + + pub fn from_config_with_store(config: &Config, store: Arc) -> CoreResult { + Self::from_config_with_optional_store(config, Some(store)) + } + + fn from_config_with_optional_store( + config: &Config, + store: Option>, + ) -> CoreResult { + let registry = CredentialDriverRegistry::from_config(config)?; + let mut drivers = BTreeMap::new(); + connect_default_credential_store( + &mut drivers, + store.clone(), + &toml::Table::new(), + registry.requires_default_store(), + )?; + + for driver_name in registry.enabled_driver_names() { + if let Some(driver) = build_sync_builtin_driver(driver_name, store.clone()) { + drivers.insert(driver_name.clone(), driver); + } else if BuiltinCredentialDriverKind::from_name(driver_name).is_none() { + return Err(unknown_credential_driver_error(driver_name)); + } + } + + Ok(Self { + registry, + drivers, + _driver_processes: Vec::new(), + }) + } + + pub async fn from_config_file( + config: &Config, + config_file: Option<&crate::config_file::ConfigFile>, + ) -> CoreResult { + Self::from_config_file_with_optional_store(config, config_file, None).await + } + + pub async fn from_config_file_with_store( + config: &Config, + config_file: Option<&crate::config_file::ConfigFile>, + store: Arc, + ) -> CoreResult { + Self::from_config_file_with_optional_store(config, config_file, Some(store)).await + } + + async fn from_config_file_with_optional_store( + config: &Config, + config_file: Option<&crate::config_file::ConfigFile>, + store: Option>, + ) -> CoreResult { + let registry = 
CredentialDriverRegistry::from_config(config)?; + let mut drivers = BTreeMap::new(); + let mut driver_processes = Vec::new(); + let empty_config = toml::Table::new(); + let default_store_config = config_file + .and_then(|file| file.openshell.gateway.credential_storage.as_ref()) + .unwrap_or(&empty_config); + connect_default_credential_store( + &mut drivers, + store.clone(), + default_store_config, + registry.requires_default_store(), + )?; + + for driver_name in registry.enabled_driver_names() { + let driver_config = config_file + .and_then(|file| file.openshell.credential_drivers.get(driver_name)) + .map(|value| parse_driver_table(driver_name, value)) + .transpose()?; + + if let Some(driver_config) = driver_config { + let built = + build_configured_driver(driver_name, driver_config, store.clone()).await?; + drivers.insert(driver_name.clone(), built.driver); + if let Some(process) = built.process { + driver_processes.push(process); + } + } else { + let driver = build_default_in_tree_driver(driver_name, store.clone()).await?; + drivers.insert(driver_name.clone(), driver); + } + } + + Ok(Self { + registry, + drivers, + _driver_processes: driver_processes, + }) + } + + pub fn validate_provider_handles(&self, provider: &Provider) -> Result<(), Status> { + self.registry.validate_provider_handles(provider) + } + + pub fn stores_provider_credentials(&self) -> bool { + let driver_name = self.registry.storage_owner_name(); + self.drivers.contains_key(&driver_name) + } + + pub fn storage_owns_handle(&self, handle: &CredentialHandle) -> bool { + normalize_driver_name(&handle.driver) == self.registry.storage_owner_name() + } + + pub async fn store_provider_credentials( + &self, + provider_name: &str, + credentials: &HashMap, + existing_handles: &HashMap, + ) -> Result, Status> { + if credentials.is_empty() { + return Ok(HashMap::new()); + } + let driver_name = self.registry.storage_owner_name(); + let driver = self.connected_driver(&driver_name)?; + let mut handles = 
HashMap::with_capacity(credentials.len()); + + for (credential_key, value) in credentials { + let existing_handle = existing_handles + .get(credential_key) + .filter(|handle| normalize_driver_name(&handle.driver) == driver_name) + .cloned(); + let replaced_handle = existing_handles + .get(credential_key) + .filter(|handle| normalize_driver_name(&handle.driver) != driver_name) + .cloned(); + let mut handle = driver + .store_credential(StoreCredentialRequest { + provider_name: provider_name.to_string(), + credential_key: credential_key.clone(), + value: value.clone(), + existing_handle, + }) + .await?; + handle.driver.clone_from(&driver_name); + if handle.handle.trim().is_empty() { + return Err(Status::internal(format!( + "credential driver '{driver_name}' returned an empty handle for provider credential '{credential_key}'" + ))); + } + if let Some(replaced_handle) = replaced_handle { + self.delete_provider_credential_handle( + provider_name, + credential_key, + replaced_handle, + ) + .await?; + } + handles.insert(credential_key.clone(), handle); + } + + Ok(handles) + } + + pub async fn delete_provider_credential_handles( + &self, + provider_name: &str, + handles: &HashMap, + ) -> Result<(), Status> { + for (credential_key, handle) in handles { + self.delete_provider_credential_handle(provider_name, credential_key, handle.clone()) + .await?; + } + Ok(()) + } + + async fn delete_provider_credential_handle( + &self, + provider_name: &str, + credential_key: &str, + handle: CredentialHandle, + ) -> Result<(), Status> { + let driver_name = self.registry.driver_for_handle(credential_key, &handle)?; + let driver = self.connected_driver(&driver_name)?; + driver + .delete_credential(DeleteCredentialRequest { + provider_name: provider_name.to_string(), + credential_key: credential_key.to_string(), + handle: Some(handle), + }) + .await + } + + pub async fn resolve_provider_handles( + &self, + provider: &Provider, + now_ms: i64, + ) -> Result { + 
self.registry.validate_provider_handles(provider)?; + if provider.credential_handles.is_empty() { + return Ok(ResolvedProviderCredentials::default()); + } + + let provider_name = provider + .metadata + .as_ref() + .map(|metadata| metadata.name.clone()) + .unwrap_or_default(); + let mut request_keys = HashMap::new(); + let mut requests_by_driver: BTreeMap> = + BTreeMap::new(); + + for (credential_key, handle) in &provider.credential_handles { + let driver_name = self.registry.driver_for_handle(credential_key, handle)?; + let request_id = format!("credential-{}", request_keys.len()); + request_keys.insert(request_id.clone(), credential_key.clone()); + + let mut selected_handle = handle.clone(); + selected_handle.driver.clone_from(&driver_name); + requests_by_driver + .entry(driver_name) + .or_default() + .push(ResolveCredentialRequest { + request_id, + provider_name: provider_name.clone(), + credential_key: credential_key.clone(), + handle: Some(selected_handle), + }); + } + + let mut resolved = ResolvedProviderCredentials::default(); + let mut seen_responses = HashSet::new(); + + for (driver_name, requests) in requests_by_driver { + let expected_request_ids: HashSet<_> = requests + .iter() + .map(|request| request.request_id.clone()) + .collect(); + let driver = self.connected_driver(&driver_name)?; + + let responses = driver.resolve_credentials(requests).await?; + for response in responses { + if response.request_id.is_empty() { + return Err(Status::internal(format!( + "credential driver '{driver_name}' returned a response without request_id" + ))); + } + if !expected_request_ids.contains(&response.request_id) { + return Err(Status::internal(format!( + "credential driver '{driver_name}' returned unknown request_id '{}'", + response.request_id + ))); + } + if !seen_responses.insert(response.request_id.clone()) { + return Err(Status::internal(format!( + "credential driver '{driver_name}' returned duplicate request_id '{}'", + response.request_id + ))); + } + + let 
credential_key = request_keys + .get(&response.request_id) + .expect("validated response request_id") + .clone(); + if response.expires_at_ms > 0 && response.expires_at_ms <= now_ms { + return Err(Status::failed_precondition(format!( + "credential driver '{driver_name}' returned expired credential for provider credential '{credential_key}'" + ))); + } + if response.expires_at_ms > 0 { + resolved + .expires_at_ms + .insert(credential_key.clone(), response.expires_at_ms); + } + resolved.values.insert(credential_key, response.value); + } + + for request_id in expected_request_ids { + if !seen_responses.contains(&request_id) { + return Err(Status::internal(format!( + "credential driver '{driver_name}' did not return a response for request_id '{request_id}'" + ))); + } + } + } + + Ok(resolved) + } + + fn connected_driver(&self, driver_name: &str) -> Result<&Arc, Status> { + self.drivers.get(driver_name).ok_or_else(|| { + Status::failed_precondition(format!( + "credential driver '{driver_name}' is enabled but not connected" + )) + }) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum BuiltinCredentialDriverKind { + KubernetesSecrets, + Vault, + #[cfg(any(test, feature = "test-support"))] + TestStatic, +} + +impl BuiltinCredentialDriverKind { + fn from_name(name: &str) -> Option { + match name { + KubernetesSecretsCredentialDriver::NAME => Some(Self::KubernetesSecrets), + VaultCredentialDriver::NAME => Some(Self::Vault), + #[cfg(any(test, feature = "test-support"))] + TestStaticCredentialDriver::NAME => Some(Self::TestStatic), + _ => None, + } + } +} + +#[cfg(any(test, feature = "test-support"))] +fn builtin_credential_driver_names() -> &'static [&'static str] { + &[ + KubernetesSecretsCredentialDriver::NAME, + VaultCredentialDriver::NAME, + TestStaticCredentialDriver::NAME, + ] +} + +#[cfg(not(any(test, feature = "test-support")))] +fn builtin_credential_driver_names() -> &'static [&'static str] { + &[ + KubernetesSecretsCredentialDriver::NAME, + 
VaultCredentialDriver::NAME, + ] +} + +fn unknown_credential_driver_error(driver_name: &str) -> Error { + Error::config(format!( + "credential driver '{driver_name}' is not a built-in credential driver and has no [openshell.credential_drivers.{driver_name}] table; configure an external driver with transport = 'uds' or choose one of: {}", + builtin_credential_driver_names().join(", ") + )) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CredentialDriverRegistry { + enabled: BTreeSet, + default_driver: Option, +} + +impl CredentialDriverRegistry { + pub fn from_config(config: &Config) -> CoreResult { + let mut enabled = BTreeSet::new(); + for driver in &config.credential_drivers { + let driver = normalize_driver_name(driver); + if driver.is_empty() { + return Err(Error::config( + "credential_drivers entries must be non-empty strings", + )); + } + enabled.insert(driver); + } + + let default_driver = config + .default_credential_driver + .as_deref() + .map(normalize_driver_name) + .filter(|driver| !driver.is_empty()); + + if default_driver.is_some() && enabled.is_empty() { + return Err(Error::config( + "default_credential_driver requires credential_drivers to name an external credential driver", + )); + } + + if let Some(default_driver) = default_driver.as_deref() + && !enabled.contains(default_driver) + { + return Err(Error::config(format!( + "default_credential_driver '{default_driver}' is not listed in credential_drivers" + ))); + } + + if enabled.len() > 1 { + return Err(Error::config( + "credential_drivers supports at most one enabled credential driver", + )); + } + + Ok(Self { + enabled, + default_driver, + }) + } + + pub fn storage_owner_name(&self) -> String { + if self.enabled.is_empty() { + return DbCredstoreCredentialDriver::NAME.to_string(); + } + if let Some(default_driver) = self.default_driver.clone() { + return default_driver; + } + self.enabled + .iter() + .next() + .expect("enabled is non-empty") + .clone() + } + + fn 
requires_default_store(&self) -> bool { + self.enabled.is_empty() + } + + pub fn validate_provider_handles(&self, provider: &Provider) -> Result<(), Status> { + if provider.credential_handles.is_empty() { + return Ok(()); + } + for (credential_key, handle) in &provider.credential_handles { + self.driver_for_handle(credential_key, handle)?; + } + + Ok(()) + } + + fn enabled_driver_names(&self) -> impl Iterator { + self.enabled.iter() + } + + fn driver_for_handle( + &self, + credential_key: &str, + handle: &CredentialHandle, + ) -> Result { + let driver = normalize_driver_name(&handle.driver); + if driver.is_empty() { + return Err(Status::invalid_argument(format!( + "provider credential_handles['{credential_key}'] is missing driver" + ))); + } + if handle.handle.trim().is_empty() { + return Err(Status::invalid_argument(format!( + "provider credential_handles['{credential_key}'] is missing handle" + ))); + } + + if driver == DbCredstoreCredentialDriver::NAME { + return Ok(driver); + } + + if !self.enabled.contains(&driver) { + return Err(Status::invalid_argument(format!( + "provider credential_handles['{credential_key}'] references credential driver '{driver}' that is not enabled" + ))); + } + + Ok(driver) + } +} + +fn normalize_driver_name(driver: &str) -> String { + driver.trim().to_string() +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CredentialDriverTransport { + InTree, + Uds, +} + +#[derive(Debug, Clone, PartialEq)] +struct ConfiguredCredentialDriver { + transport: CredentialDriverTransport, + socket_path: Option, + command: Option, + args: Vec, + startup_timeout_secs: u64, + backend_config: toml::Table, +} + +fn parse_driver_table( + driver_name: &str, + value: &toml::Value, +) -> CoreResult { + let table = value.as_table().ok_or_else(|| { + Error::config(format!( + "[openshell.credential_drivers.{driver_name}] must be a TOML table" + )) + })?; + + let transport = table + .get("transport") + .map(|value| string_field(driver_name, "transport", 
value)) + .transpose()? + .unwrap_or_else(|| "in_tree".to_string()); + let transport = match transport.as_str() { + "in_tree" => CredentialDriverTransport::InTree, + "uds" => CredentialDriverTransport::Uds, + other => { + return Err(Error::config(format!( + "[openshell.credential_drivers.{driver_name}] transport must be 'in_tree' or 'uds', got '{other}'" + ))); + } + }; + + let socket_path = table + .get("socket_path") + .map(|value| string_field(driver_name, "socket_path", value)) + .transpose()? + .map(PathBuf::from); + let command = table + .get("command") + .map(|value| string_field(driver_name, "command", value)) + .transpose()? + .map(PathBuf::from); + let args = table + .get("args") + .map(|value| string_array_field(driver_name, "args", value)) + .transpose()? + .unwrap_or_default(); + let startup_timeout_secs = table + .get("startup_timeout_secs") + .map(|value| positive_integer_field(driver_name, "startup_timeout_secs", value)) + .transpose()? + .unwrap_or(DEFAULT_CREDENTIAL_DRIVER_STARTUP_TIMEOUT_SECS); + + if transport == CredentialDriverTransport::Uds { + let socket_path = socket_path.as_ref().ok_or_else(|| { + Error::config(format!( + "[openshell.credential_drivers.{driver_name}] socket_path is required when transport = 'uds'" + )) + })?; + if !socket_path.is_absolute() { + return Err(Error::config(format!( + "[openshell.credential_drivers.{driver_name}] socket_path must be absolute" + ))); + } + if let Some(command) = command.as_ref() + && !command.is_absolute() + { + return Err(Error::config(format!( + "[openshell.credential_drivers.{driver_name}] command must be absolute" + ))); + } + if command.is_none() && !args.is_empty() { + return Err(Error::config(format!( + "[openshell.credential_drivers.{driver_name}] args requires command" + ))); + } + if command.is_none() && table.contains_key("startup_timeout_secs") { + return Err(Error::config(format!( + "[openshell.credential_drivers.{driver_name}] startup_timeout_secs requires command" + ))); + } + } 
else if command.is_some() || !args.is_empty() || table.contains_key("startup_timeout_secs") { + return Err(Error::config(format!( + "[openshell.credential_drivers.{driver_name}] command, args, and startup_timeout_secs require transport = 'uds'" + ))); + } + + Ok(ConfiguredCredentialDriver { + transport, + socket_path, + command, + args, + startup_timeout_secs, + backend_config: backend_config_table(table), + }) +} + +fn backend_config_table(table: &toml::Table) -> toml::Table { + let mut backend_config = table.clone(); + for field in COMMON_CREDENTIAL_DRIVER_FIELDS { + backend_config.remove(*field); + } + backend_config +} + +fn string_field( + driver_name: &str, + field_name: &'static str, + value: &toml::Value, +) -> CoreResult { + let value = value.as_str().ok_or_else(|| { + Error::config(format!( + "[openshell.credential_drivers.{driver_name}] {field_name} must be a string" + )) + })?; + let value = value.trim(); + if value.is_empty() { + return Err(Error::config(format!( + "[openshell.credential_drivers.{driver_name}] {field_name} must not be empty" + ))); + } + Ok(value.to_string()) +} + +fn string_array_field( + driver_name: &str, + field_name: &'static str, + value: &toml::Value, +) -> CoreResult> { + let values = value.as_array().ok_or_else(|| { + Error::config(format!( + "[openshell.credential_drivers.{driver_name}] {field_name} must be an array of strings" + )) + })?; + + values + .iter() + .map(|value| string_field(driver_name, field_name, value)) + .collect() +} + +fn positive_integer_field( + driver_name: &str, + field_name: &'static str, + value: &toml::Value, +) -> CoreResult { + let value = value.as_integer().ok_or_else(|| { + Error::config(format!( + "[openshell.credential_drivers.{driver_name}] {field_name} must be a positive integer" + )) + })?; + if value <= 0 { + return Err(Error::config(format!( + "[openshell.credential_drivers.{driver_name}] {field_name} must be a positive integer" + ))); + } + u64::try_from(value).map_err(|_| { + 
Error::config(format!( + "[openshell.credential_drivers.{driver_name}] {field_name} is too large" + )) + }) +} + +fn connect_default_credential_store( + drivers: &mut BTreeMap>, + store: Option>, + config: &toml::Table, + required: bool, +) -> CoreResult<()> { + if !required && config.is_empty() { + return Ok(()); + } + + let Some(store) = store else { + if required { + return Err(Error::config( + "default encrypted credential storage requires the gateway object store", + )); + } + return Ok(()); + }; + + let object_store: Arc = + Arc::new(ServerDbCredstoreObjectStore::new(store)); + let storage: Arc = Arc::new(DbCredstoreCredentialDriver::from_config( + object_store, + config, + )?); + drivers.insert(DbCredstoreCredentialDriver::NAME.to_string(), storage); + Ok(()) +} + +#[derive(Debug)] +struct BuiltCredentialDriver { + driver: Arc, + process: Option>, +} + +async fn build_configured_driver( + driver_name: &str, + config: ConfiguredCredentialDriver, + store: Option>, +) -> CoreResult { + match config.transport { + CredentialDriverTransport::InTree => { + let driver = build_in_tree_driver(driver_name, Some(&config.backend_config), store) + .await? + .ok_or_else(|| { + Error::config(format!( + "credential driver '{driver_name}' is configured with transport = 'in_tree', but no in-tree implementation is available" + )) + })?; + Ok(BuiltCredentialDriver { + driver, + process: None, + }) + } + CredentialDriverTransport::Uds => { + let socket_path = config + .socket_path + .clone() + .expect("UDS transport requires socket_path during parsing"); + connect_uds_driver(driver_name, config, &socket_path).await + } + } +} + +async fn build_default_in_tree_driver( + driver_name: &str, + store: Option>, +) -> CoreResult> { + build_in_tree_driver(driver_name, None, store) + .await? 
+ .ok_or_else(|| unknown_credential_driver_error(driver_name)) +} + +async fn build_in_tree_driver( + name: &str, + backend_config: Option<&toml::Table>, + _store: Option>, +) -> CoreResult>> { + let Some(kind) = BuiltinCredentialDriverKind::from_name(name) else { + return Ok(None); + }; + + let empty_config = toml::Table::new(); + let backend_config = backend_config.unwrap_or(&empty_config); + let driver: Arc = match kind { + BuiltinCredentialDriverKind::KubernetesSecrets => { + Arc::new(KubernetesSecretsCredentialDriver::from_config(backend_config).await?) + } + BuiltinCredentialDriverKind::Vault => { + Arc::new(VaultCredentialDriver::from_config(backend_config)?) + } + #[cfg(any(test, feature = "test-support"))] + BuiltinCredentialDriverKind::TestStatic => Arc::new(TestStaticCredentialDriver::new()), + }; + Ok(Some(driver)) +} + +fn build_sync_builtin_driver( + name: &str, + _store: Option>, +) -> Option> { + #[cfg(any(test, feature = "test-support"))] + if BuiltinCredentialDriverKind::from_name(name) == Some(BuiltinCredentialDriverKind::TestStatic) + { + let driver: Arc = Arc::new(TestStaticCredentialDriver::new()); + return Some(driver); + } + + let _ = name; + None +} + +#[derive(Debug, Clone)] +struct ServerDbCredstoreObjectStore { + store: Arc, +} + +impl ServerDbCredstoreObjectStore { + fn new(store: Arc) -> Self { + Self { store } + } +} + +#[async_trait] +impl DbCredstoreObjectStore for ServerDbCredstoreObjectStore { + async fn get_credential_object( + &self, + object_type: &str, + id: &str, + operation: &'static str, + ) -> Result, Status> { + self.store + .get(object_type, id) + .await + .map(|record| { + record.map(|record| StoredCredentialObject { + object_type: record.object_type, + id: record.id, + payload: record.payload, + resource_version: record.resource_version, + }) + }) + .map_err(|err| default_credential_store_persistence_error_to_status(err, operation)) + } + + async fn put_credential_object( + &self, + write: CredentialObjectWrite, + 
operation: &'static str, + ) -> Result<(), Status> { + let condition = match write.condition { + DbCredstoreWriteCondition::MustCreate => WriteCondition::MustCreate, + DbCredstoreWriteCondition::MatchResourceVersion(resource_version) => { + WriteCondition::MatchResourceVersion(resource_version) + } + }; + self.store + .put_if( + &write.object_type, + &write.id, + &write.name, + &write.payload, + write.labels.as_deref(), + condition, + ) + .await + .map(|_| ()) + .map_err(|err| default_credential_store_persistence_error_to_status(err, operation)) + } + + async fn delete_credential_object( + &self, + object_type: &str, + id: &str, + expected_resource_version: u64, + operation: &'static str, + ) -> Result<(), Status> { + self.store + .delete_if(object_type, id, expected_resource_version) + .await + .map(|_| ()) + .map_err(|err| default_credential_store_persistence_error_to_status(err, operation)) + } +} + +fn default_credential_store_persistence_error_to_status( + err: PersistenceError, + operation: &str, +) -> Status { + match err { + PersistenceError::UniqueViolation { .. 
} => { + Status::already_exists(format!("default credential already exists: {err}")) + } + PersistenceError::Conflict { + current_resource_version, + } => Status::aborted(format!( + "default credential was modified concurrently during {operation} (current resource_version: {})", + current_resource_version.unwrap_or(0) + )), + PersistenceError::Decode(err) => Status::data_loss(format!( + "default credential decode failed during {operation}: {err}" + )), + PersistenceError::Encode(err) => Status::internal(format!( + "default credential encode failed during {operation}: {err}" + )), + other => Status::unavailable(format!("default credential {operation} failed: {other}")), + } +} + +#[async_trait] +impl CredentialDriver for KubernetesSecretsCredentialDriver { + async fn store_credential( + &self, + request: StoreCredentialRequest, + ) -> Result { + Self::store_credential(self, request).await + } + + async fn delete_credential(&self, request: DeleteCredentialRequest) -> Result<(), Status> { + Self::delete_credential(self, request).await + } + + async fn resolve_credentials( + &self, + requests: Vec, + ) -> Result, Status> { + Self::resolve_credentials(self, requests).await + } +} + +#[async_trait] +impl CredentialDriver for DbCredstoreCredentialDriver { + async fn store_credential( + &self, + request: StoreCredentialRequest, + ) -> Result { + Self::store_credential(self, request).await + } + + async fn delete_credential(&self, request: DeleteCredentialRequest) -> Result<(), Status> { + Self::delete_credential(self, request).await + } + + async fn resolve_credentials( + &self, + requests: Vec, + ) -> Result, Status> { + Self::resolve_credentials(self, requests).await + } +} + +#[async_trait] +impl CredentialDriver for VaultCredentialDriver { + async fn store_credential( + &self, + request: StoreCredentialRequest, + ) -> Result { + Self::store_credential(self, request).await + } + + async fn delete_credential(&self, request: DeleteCredentialRequest) -> Result<(), Status> 
{ + Self::delete_credential(self, request).await + } + + async fn resolve_credentials( + &self, + requests: Vec, + ) -> Result, Status> { + Self::resolve_credentials(self, requests).await + } +} + +#[derive(Debug, Clone)] +#[cfg(unix)] +struct RemoteCredentialDriver { + channel: Channel, +} + +#[cfg(unix)] +impl RemoteCredentialDriver { + fn new(channel: Channel) -> Self { + Self { channel } + } + + fn client(&self) -> CredentialDriverClient { + CredentialDriverClient::new(self.channel.clone()) + } +} + +#[cfg(unix)] +#[async_trait] +impl CredentialDriver for RemoteCredentialDriver { + async fn store_credential( + &self, + request: StoreCredentialRequest, + ) -> Result { + let mut client = self.client(); + let response = client.store_credential(Request::new(request)).await?; + response + .into_inner() + .handle + .ok_or_else(|| Status::internal("credential driver returned no stored handle")) + } + + async fn delete_credential(&self, request: DeleteCredentialRequest) -> Result<(), Status> { + let mut client = self.client(); + client.delete_credential(Request::new(request)).await?; + Ok(()) + } + + async fn resolve_credentials( + &self, + requests: Vec, + ) -> Result, Status> { + let mut client = self.client(); + let response = client + .resolve_credentials(Request::new(ResolveCredentialsRequest { + credentials: requests, + })) + .await?; + Ok(response.into_inner().credentials) + } +} + +#[derive(Debug)] +struct ManagedCredentialDriverProcess { + child: std::sync::Mutex>, + socket_path: PathBuf, +} + +#[cfg(unix)] +impl ManagedCredentialDriverProcess { + fn new(child: tokio::process::Child, socket_path: PathBuf) -> Self { + Self { + child: std::sync::Mutex::new(Some(child)), + socket_path, + } + } +} + +impl Drop for ManagedCredentialDriverProcess { + fn drop(&mut self) { + if let Ok(mut child) = self.child.lock() { + let _ = child.take(); + } + let _ = std::fs::remove_file(&self.socket_path); + } +} + +#[cfg(unix)] +async fn connect_uds_driver( + driver_name: &str, 
+ config: ConfiguredCredentialDriver, + socket_path: &Path, +) -> CoreResult { + if config.command.is_some() { + spawn_uds_driver(driver_name, config, socket_path).await + } else { + let channel = connect_ready_credential_driver(driver_name, socket_path).await?; + Ok(BuiltCredentialDriver { + driver: Arc::new(RemoteCredentialDriver::new(channel)), + process: None, + }) + } +} + +#[cfg(not(unix))] +async fn connect_uds_driver( + driver_name: &str, + _config: ConfiguredCredentialDriver, + _socket_path: &Path, +) -> CoreResult { + Err(Error::config(format!( + "credential driver '{driver_name}' uses transport = 'uds', but this platform does not support Unix domain sockets" + ))) +} + +#[cfg(unix)] +async fn spawn_uds_driver( + driver_name: &str, + config: ConfiguredCredentialDriver, + socket_path: &Path, +) -> CoreResult { + let command_path = config + .command + .expect("UDS command exists when spawning credential driver"); + let parent = socket_path.parent().ok_or_else(|| { + Error::execution(format!( + "credential driver '{driver_name}' socket path '{}' has no parent directory", + socket_path.display() + )) + })?; + std::fs::create_dir_all(parent).map_err(|err| { + Error::execution(format!( + "failed to create credential driver '{driver_name}' socket dir '{}': {err}", + parent.display() + )) + })?; + remove_stale_launched_driver_socket(driver_name, socket_path)?; + + let mut command = Command::new(&command_path); + command.kill_on_drop(true); + command.stdin(Stdio::null()); + command.stdout(Stdio::inherit()); + command.stderr(Stdio::inherit()); + command.args(&config.args); + command.arg("--bind-socket").arg(socket_path); + + let mut child = command.spawn().map_err(|err| { + Error::execution(format!( + "failed to launch credential driver '{driver_name}' '{}': {err}", + command_path.display() + )) + })?; + let channel = wait_for_launched_credential_driver( + driver_name, + socket_path, + &mut child, + Duration::from_secs(config.startup_timeout_secs), + ) + .await?; + 
let process = Arc::new(ManagedCredentialDriverProcess::new( + child, + socket_path.to_path_buf(), + )); + Ok(BuiltCredentialDriver { + driver: Arc::new(RemoteCredentialDriver::new(channel)), + process: Some(process), + }) +} + +#[cfg(unix)] +fn remove_stale_launched_driver_socket(driver_name: &str, socket_path: &Path) -> CoreResult<()> { + let metadata = match std::fs::symlink_metadata(socket_path) { + Ok(metadata) => metadata, + Err(err) if err.kind() == ErrorKind::NotFound => return Ok(()), + Err(err) => { + return Err(Error::execution(format!( + "failed to stat credential driver '{driver_name}' socket '{}': {err}", + socket_path.display() + ))); + } + }; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(Error::execution(format!( + "credential driver '{driver_name}' socket '{}' is a symlink; refusing to remove it", + socket_path.display() + ))); + } + if !file_type.is_socket() { + return Err(Error::execution(format!( + "credential driver '{driver_name}' socket path '{}' exists but is not a Unix socket", + socket_path.display() + ))); + } + let expected_uid = rustix::process::geteuid().as_raw(); + if metadata.uid() != expected_uid { + return Err(Error::execution(format!( + "credential driver '{driver_name}' socket '{}' is owned by uid {} but current euid is {}", + socket_path.display(), + metadata.uid(), + expected_uid + ))); + } + std::fs::remove_file(socket_path).map_err(|err| { + Error::execution(format!( + "failed to remove stale credential driver '{driver_name}' socket '{}': {err}", + socket_path.display() + )) + }) +} + +#[cfg(unix)] +async fn wait_for_launched_credential_driver( + driver_name: &str, + socket_path: &Path, + child: &mut tokio::process::Child, + timeout: Duration, +) -> CoreResult { + let deadline = Instant::now() + timeout; + let mut last_error: Option; + + loop { + let try_wait_result = child.try_wait().map_err(|err| { + Error::execution(format!( + "failed to poll credential driver '{driver_name}' process: 
{err}" + )) + })?; + if let Some(status) = try_wait_result { + return Err(Error::execution(format!( + "credential driver '{driver_name}' exited before becoming ready with status {status}" + ))); + } + + match connect_ready_credential_driver(driver_name, socket_path).await { + Ok(channel) => return Ok(channel), + Err(err) => last_error = Some(err.to_string()), + } + + if Instant::now() >= deadline { + return Err(Error::execution(format!( + "timed out waiting for credential driver '{driver_name}' socket '{}': {}", + socket_path.display(), + last_error.unwrap_or_else(|| "unknown error".to_string()) + ))); + } + + tokio::time::sleep(CREDENTIAL_DRIVER_CONNECT_INTERVAL).await; + } +} + +#[cfg(unix)] +async fn connect_ready_credential_driver( + driver_name: &str, + socket_path: &Path, +) -> CoreResult { + let channel = connect_credential_driver_socket(driver_name, socket_path).await?; + let mut client = CredentialDriverClient::new(channel.clone()); + client + .get_capabilities(Request::new(GetCredentialDriverCapabilitiesRequest {})) + .await + .map_err(|status| { + Error::config(format!( + "credential driver '{driver_name}' GetCapabilities failed: {status}" + )) + })?; + Ok(channel) +} + +#[cfg(unix)] +async fn connect_credential_driver_socket( + driver_name: &str, + socket_path: &Path, +) -> CoreResult { + let socket_path = socket_path.to_path_buf(); + let display_path = socket_path.clone(); + Endpoint::from_static("http://[::]:50051") + .connect_with_connector(service_fn(move |_: tonic::transport::Uri| { + let socket_path = socket_path.clone(); + async move { UnixStream::connect(socket_path).await.map(TokioIo::new) } + })) + .await + .map_err(|err| { + Error::transport(format!( + "failed to connect to credential driver '{driver_name}' socket '{}': {err}", + display_path.display() + )) + }) +} + +#[cfg(any(test, feature = "test-support"))] +#[derive(Debug)] +struct TestStaticCredentialDriver { + values: std::sync::Mutex>, +} + +#[cfg(any(test, feature = "test-support"))] 
+impl TestStaticCredentialDriver { + const NAME: &'static str = "test-static"; + + fn new() -> Self { + Self { + values: std::sync::Mutex::new(HashMap::new()), + } + } + + fn handle_from_request( + request_id: &str, + handle: Option, + ) -> Result { + handle.ok_or_else(|| { + Status::invalid_argument(format!( + "test-static credential request '{request_id}' is missing handle" + )) + }) + } +} + +#[cfg(any(test, feature = "test-support"))] +#[async_trait] +impl CredentialDriver for TestStaticCredentialDriver { + async fn store_credential( + &self, + request: StoreCredentialRequest, + ) -> Result { + let handle = request + .existing_handle + .map(|handle| handle.handle) + .filter(|handle| !handle.trim().is_empty()) + .unwrap_or_else(|| format!("{}:{}", request.provider_name, request.credential_key)); + self.values + .lock() + .map_err(|_| Status::internal("test-static credential store lock poisoned"))? + .insert(handle.clone(), request.value); + Ok(CredentialHandle { + driver: Self::NAME.to_string(), + handle, + metadata: HashMap::new(), + }) + } + + async fn delete_credential(&self, request: DeleteCredentialRequest) -> Result<(), Status> { + let handle = Self::handle_from_request("delete", request.handle)?; + self.values + .lock() + .map_err(|_| Status::internal("test-static credential store lock poisoned"))? + .remove(&handle.handle); + Ok(()) + } + + async fn resolve_credentials( + &self, + requests: Vec, + ) -> Result, Status> { + let mut responses = Vec::with_capacity(requests.len()); + for request in requests { + let handle = Self::handle_from_request(&request.request_id, request.handle)?; + let value = self + .values + .lock() + .map_err(|_| Status::internal("test-static credential store lock poisoned"))? 
+ .get(&handle.handle) + .cloned() + .ok_or_else(|| Status::not_found("test-static credential handle not found"))?; + responses.push(ResolvedCredential { + request_id: request.request_id, + value, + expires_at_ms: 0, + }); + } + + Ok(responses) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use openshell_core::proto::{CredentialHandle, Provider}; + use tonic::Code; + + use super::*; + + fn provider_with_handle(driver: &str, handle: &str) -> Provider { + Provider { + metadata: Some(openshell_core::proto::ObjectMeta { + name: "openai-local".to_string(), + ..Default::default() + }), + credential_handles: HashMap::from([( + "OPENAI_API_KEY".to_string(), + CredentialHandle { + driver: driver.to_string(), + handle: handle.to_string(), + metadata: HashMap::new(), + }, + )]), + ..Default::default() + } + } + + fn config_file(toml: &str) -> crate::config_file::ConfigFile { + toml::from_str(toml).expect("config file TOML") + } + + fn driver_table(toml: &str) -> toml::Value { + toml::from_str(toml).expect("driver table TOML") + } + + #[test] + fn builtin_credential_driver_kind_resolves_known_names() { + assert_eq!( + BuiltinCredentialDriverKind::from_name("kubernetes-secrets"), + Some(BuiltinCredentialDriverKind::KubernetesSecrets) + ); + assert_eq!( + BuiltinCredentialDriverKind::from_name("openshell-gateway"), + None + ); + assert_eq!( + BuiltinCredentialDriverKind::from_name("vault"), + Some(BuiltinCredentialDriverKind::Vault) + ); + assert_eq!( + BuiltinCredentialDriverKind::from_name("enterprise-secrets"), + None + ); + } + + #[test] + fn registry_defaults_to_internal_credential_storage() { + let registry = CredentialDriverRegistry::from_config(&Config::new(None)).unwrap(); + + assert_eq!( + registry.storage_owner_name().as_str(), + DbCredstoreCredentialDriver::NAME + ); + } + + #[test] + fn registry_allows_legacy_inline_credentials_with_default_driver() { + let registry = CredentialDriverRegistry::from_config(&Config::new(None)).unwrap(); + + 
registry + .validate_provider_handles(&Provider::default()) + .expect("legacy inline provider should not require credential handles"); + } + + #[test] + fn registry_rejects_default_driver_without_external_driver() { + let config = Config::new(None).with_default_credential_driver(Some("vault")); + + let err = CredentialDriverRegistry::from_config(&config).unwrap_err(); + + assert!(err.to_string().contains("default_credential_driver")); + assert!(err.to_string().contains("requires credential_drivers")); + } + + #[test] + fn registry_rejects_empty_handle_driver() { + let config = Config::new(None).with_credential_drivers(["test-static"]); + let registry = CredentialDriverRegistry::from_config(&config).unwrap(); + + let err = registry + .validate_provider_handles(&provider_with_handle("", "openai/API_KEY")) + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("missing driver")); + } + + #[test] + fn registry_rejects_empty_handle_value() { + let config = Config::new(None).with_credential_drivers(["test-static"]); + let registry = CredentialDriverRegistry::from_config(&config).unwrap(); + + let err = registry + .validate_provider_handles(&provider_with_handle("test-static", "")) + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("missing handle")); + } + + #[test] + fn registry_rejects_unknown_handle_driver() { + let config = Config::new(None).with_credential_drivers(["test-static"]); + let registry = CredentialDriverRegistry::from_config(&config).unwrap(); + + let err = registry + .validate_provider_handles(&provider_with_handle("vault", "openai/API_KEY")) + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("not enabled")); + } + + #[test] + fn registry_rejects_default_driver_not_enabled() { + let config = Config::new(None) + .with_credential_drivers(["test-static"]) + .with_default_credential_driver(Some("vault")); + + let 
err = CredentialDriverRegistry::from_config(&config).unwrap_err(); + + assert!(err.to_string().contains("default_credential_driver")); + assert!(err.to_string().contains("not listed")); + } + + #[test] + fn registry_rejects_multiple_enabled_drivers() { + let config = Config::new(None).with_credential_drivers(["test-static", "vault"]); + + let err = CredentialDriverRegistry::from_config(&config).unwrap_err(); + + assert!(err.to_string().contains("at most one")); + } + + #[tokio::test] + async fn runtime_stores_and_resolves_test_static_handles() { + let config = Config::new(None) + .with_credential_drivers(["test-static"]) + .with_default_credential_driver(Some("test-static")); + let runtime = CredentialRuntime::from_config(&config).unwrap(); + let stored = runtime + .store_provider_credentials( + "openai-local", + &HashMap::from([("OPENAI_API_KEY".to_string(), "sk-test".to_string())]), + &HashMap::new(), + ) + .await + .unwrap(); + let mut provider = Provider { + metadata: Some(openshell_core::proto::ObjectMeta { + name: "openai-local".to_string(), + ..Default::default() + }), + ..Default::default() + }; + provider.credential_handles = stored; + + let resolved = runtime + .resolve_provider_handles(&provider, 1_000) + .await + .unwrap(); + + assert_eq!( + resolved.values.get("OPENAI_API_KEY").map(String::as_str), + Some("sk-test") + ); + } + + #[tokio::test] + async fn runtime_overwrites_existing_test_static_handle() { + let config = Config::new(None).with_credential_drivers(["test-static"]); + let runtime = CredentialRuntime::from_config(&config).unwrap(); + let first = runtime + .store_provider_credentials( + "openai-local", + &HashMap::from([("OPENAI_API_KEY".to_string(), "sk-first".to_string())]), + &HashMap::new(), + ) + .await + .unwrap(); + let second = runtime + .store_provider_credentials( + "openai-local", + &HashMap::from([("OPENAI_API_KEY".to_string(), "sk-second".to_string())]), + &first, + ) + .await + .unwrap(); + assert_eq!( + 
first.get("OPENAI_API_KEY").unwrap().handle, + second.get("OPENAI_API_KEY").unwrap().handle + ); + + let provider = Provider { + metadata: Some(openshell_core::proto::ObjectMeta { + name: "openai-local".to_string(), + ..Default::default() + }), + credential_handles: second, + ..Default::default() + }; + + let resolved = runtime + .resolve_provider_handles(&provider, 1_000) + .await + .unwrap(); + + assert_eq!( + resolved.values.get("OPENAI_API_KEY").map(String::as_str), + Some("sk-second") + ); + } + + #[tokio::test] + async fn runtime_deletes_stored_test_static_handle() { + let config = Config::new(None).with_credential_drivers(["test-static"]); + let runtime = CredentialRuntime::from_config(&config).unwrap(); + let stored = runtime + .store_provider_credentials( + "openai-local", + &HashMap::from([("OPENAI_API_KEY".to_string(), "sk-test".to_string())]), + &HashMap::new(), + ) + .await + .unwrap(); + + runtime + .delete_provider_credential_handles("openai-local", &stored) + .await + .unwrap(); + + let provider = Provider { + metadata: Some(openshell_core::proto::ObjectMeta { + name: "openai-local".to_string(), + ..Default::default() + }), + credential_handles: stored, + ..Default::default() + }; + let err = runtime + .resolve_provider_handles(&provider, 1_000) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::NotFound); + } + + #[tokio::test] + async fn runtime_uses_configured_in_tree_driver_table() { + let config = Config::new(None).with_credential_drivers(["test-static"]); + let file = config_file( + r#" +[openshell.credential_drivers.test-static] +transport = "in_tree" +backend_specific = "ignored-by-gateway" +"#, + ); + let runtime = CredentialRuntime::from_config_file(&config, Some(&file)) + .await + .unwrap(); + + let stored = runtime + .store_provider_credentials( + "openai-local", + &HashMap::from([("OPENAI_API_KEY".to_string(), "sk-test".to_string())]), + &HashMap::new(), + ) + .await + .unwrap(); + + assert_eq!( + stored + .get("OPENAI_API_KEY") 
+ .map(|handle| handle.driver.as_str()), + Some("test-static") + ); + } + + #[tokio::test] + async fn runtime_uses_configured_vault_in_tree_driver_table() { + let token_file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write(token_file.path(), "dev-token").unwrap(); + let config = Config::new(None).with_credential_drivers(["vault"]); + let file = config_file(&format!( + r#" +[openshell.credential_drivers.vault] +transport = "in_tree" +address = "http://127.0.0.1:8200" +auth_method = "token_file" +token_path = "{}" +"#, + token_file.path().display() + )); + let runtime = CredentialRuntime::from_config_file(&config, Some(&file)) + .await + .unwrap(); + + assert!(runtime.stores_provider_credentials()); + } + + #[tokio::test] + async fn runtime_uses_configured_default_credential_storage() { + let storage = tempfile::tempdir().unwrap(); + let key_encryption_key_path = storage.path().join("key-encryption-key.bin"); + let store = Arc::new(crate::persistence::test_store().await); + let config = Config::new(None); + let file = config_file(&format!( + r#" +[openshell.gateway.credential_storage] +key_encryption_key_path = "{}" +"#, + key_encryption_key_path.display() + )); + let runtime = CredentialRuntime::from_config_file_with_store( + &config, + Some(&file), + Arc::clone(&store), + ) + .await + .unwrap(); + + let stored = runtime + .store_provider_credentials( + "openai-local", + &HashMap::from([("OPENAI_API_KEY".to_string(), "sk-test".to_string())]), + &HashMap::new(), + ) + .await + .unwrap(); + + let handle_id = stored + .get("OPENAI_API_KEY") + .and_then(|handle| handle.handle.strip_prefix("v1:")) + .expect("stored db credstore handle id"); + let credential_record = store + .get(DbCredstoreCredentialDriver::OBJECT_TYPE, handle_id) + .await + .unwrap() + .expect("encrypted credential object"); + assert!( + !String::from_utf8_lossy(&credential_record.payload).contains("sk-test"), + "credential object payload must not contain plaintext credentials" + ); + + let 
provider = Provider { + metadata: Some(openshell_core::proto::ObjectMeta { + name: "openai-local".to_string(), + ..Default::default() + }), + credential_handles: stored, + ..Default::default() + }; + + let resolved = runtime + .resolve_provider_handles(&provider, 1_000) + .await + .unwrap(); + + assert_eq!( + resolved.values.get("OPENAI_API_KEY").map(String::as_str), + Some("sk-test") + ); + } + + #[tokio::test] + async fn runtime_rejects_in_tree_table_without_builtin_driver() { + let config = Config::new(None).with_credential_drivers(["enterprise-secrets"]); + let file = config_file( + r#" +[openshell.credential_drivers.enterprise-secrets] +transport = "in_tree" +"#, + ); + + let err = CredentialRuntime::from_config_file(&config, Some(&file)) + .await + .unwrap_err(); + + assert!(err.to_string().contains("no in-tree implementation")); + } + + #[tokio::test] + async fn runtime_rejects_unknown_driver_without_driver_table() { + let config = Config::new(None).with_credential_drivers(["enterprise-secrets"]); + + let err = CredentialRuntime::from_config_file(&config, None) + .await + .unwrap_err(); + + assert!(err.to_string().contains("not a built-in credential driver")); + assert!(err.to_string().contains("transport = 'uds'")); + } + + #[test] + fn runtime_from_config_rejects_unknown_driver_name() { + let config = Config::new(None).with_credential_drivers(["enterprise-secrets"]); + + let err = CredentialRuntime::from_config(&config).unwrap_err(); + + assert!(err.to_string().contains("not a built-in credential driver")); + } + + #[tokio::test] + async fn runtime_rejects_uds_table_without_socket_path() { + let config = Config::new(None).with_credential_drivers(["vault"]); + let file = config_file( + r#" +[openshell.credential_drivers.vault] +transport = "uds" +"#, + ); + + let err = CredentialRuntime::from_config_file(&config, Some(&file)) + .await + .unwrap_err(); + + assert!(err.to_string().contains("socket_path is required")); + } + + #[tokio::test] + async fn 
runtime_rejects_relative_uds_socket_path() { + let config = Config::new(None).with_credential_drivers(["vault"]); + let file = config_file( + r#" +[openshell.credential_drivers.vault] +transport = "uds" +socket_path = "vault.sock" +"#, + ); + + let err = CredentialRuntime::from_config_file(&config, Some(&file)) + .await + .unwrap_err(); + + assert!(err.to_string().contains("socket_path must be absolute")); + } + + #[tokio::test] + async fn runtime_rejects_unknown_transport() { + let config = Config::new(None).with_credential_drivers(["vault"]); + let file = config_file( + r#" +[openshell.credential_drivers.vault] +transport = "tcp" +"#, + ); + + let err = CredentialRuntime::from_config_file(&config, Some(&file)) + .await + .unwrap_err(); + + assert!(err.to_string().contains("transport must be")); + } + + #[test] + fn parse_uds_driver_launch_settings() { + let parsed = parse_driver_table( + "enterprise-secrets", + &driver_table( + r#" +transport = "uds" +socket_path = "/tmp/openshell-enterprise-secrets.sock" +command = "/usr/local/libexec/openshell-credential-driver-enterprise-secrets" +args = ["--profile", "dev"] +startup_timeout_secs = 3 +"#, + ), + ) + .unwrap(); + + assert_eq!(parsed.transport, CredentialDriverTransport::Uds); + assert_eq!( + parsed.socket_path.as_deref(), + Some(Path::new("/tmp/openshell-enterprise-secrets.sock")) + ); + assert_eq!( + parsed.command.as_deref(), + Some(Path::new( + "/usr/local/libexec/openshell-credential-driver-enterprise-secrets" + )) + ); + assert_eq!(parsed.args, ["--profile", "dev"]); + assert_eq!(parsed.startup_timeout_secs, 3); + } + + #[test] + fn parse_uds_driver_defaults_to_connect_only() { + let parsed = parse_driver_table( + "enterprise-secrets", + &driver_table( + r#" +transport = "uds" +socket_path = "/tmp/openshell-enterprise-secrets.sock" +"#, + ), + ) + .unwrap(); + + assert_eq!(parsed.transport, CredentialDriverTransport::Uds); + assert!(parsed.command.is_none()); + assert!(parsed.args.is_empty()); + 
assert_eq!( + parsed.startup_timeout_secs, + DEFAULT_CREDENTIAL_DRIVER_STARTUP_TIMEOUT_SECS + ); + } + + #[test] + fn parse_driver_table_preserves_backend_config_without_transport_fields() { + let parsed = parse_driver_table( + "kubernetes-secrets", + &driver_table( + r#" +transport = "in_tree" +namespace = "openshell" +allow_reference_namespace = true +"#, + ), + ) + .unwrap(); + + assert_eq!( + parsed + .backend_config + .get("namespace") + .and_then(toml::Value::as_str), + Some("openshell") + ); + assert_eq!( + parsed + .backend_config + .get("allow_reference_namespace") + .and_then(toml::Value::as_bool), + Some(true) + ); + assert!(!parsed.backend_config.contains_key("transport")); + } + + #[test] + fn parse_uds_driver_rejects_relative_command() { + let err = parse_driver_table( + "enterprise-secrets", + &driver_table( + r#" +transport = "uds" +socket_path = "/tmp/openshell-enterprise-secrets.sock" +command = "openshell-credential-driver-enterprise-secrets" +"#, + ), + ) + .unwrap_err(); + + assert!(err.to_string().contains("command must be absolute")); + } + + #[test] + fn parse_uds_driver_rejects_args_without_command() { + let err = parse_driver_table( + "enterprise-secrets", + &driver_table( + r#" +transport = "uds" +socket_path = "/tmp/openshell-enterprise-secrets.sock" +args = ["--profile", "dev"] +"#, + ), + ) + .unwrap_err(); + + assert!(err.to_string().contains("args requires command")); + } + + #[test] + fn parse_uds_driver_rejects_timeout_without_command() { + let err = parse_driver_table( + "enterprise-secrets", + &driver_table( + r#" +transport = "uds" +socket_path = "/tmp/openshell-enterprise-secrets.sock" +startup_timeout_secs = 3 +"#, + ), + ) + .unwrap_err(); + + assert!( + err.to_string() + .contains("startup_timeout_secs requires command") + ); + } + + #[test] + fn parse_in_tree_driver_rejects_launch_settings() { + let err = parse_driver_table( + "test-static", + &driver_table( + r#" +transport = "in_tree" +command = 
"/usr/local/libexec/openshell-credential-driver-test" +"#, + ), + ) + .unwrap_err(); + + assert!( + err.to_string() + .contains("command, args, and startup_timeout_secs require transport = 'uds'") + ); + } + + #[cfg(unix)] + #[test] + fn remove_stale_launched_driver_socket_removes_socket() { + use std::os::unix::net::UnixListener as StdUnixListener; + + let dir = tempfile::tempdir().unwrap(); + let socket_path = dir.path().join("driver.sock"); + let listener = StdUnixListener::bind(&socket_path).unwrap(); + + remove_stale_launched_driver_socket("enterprise-secrets", &socket_path).unwrap(); + + drop(listener); + assert!(!socket_path.exists()); + } + + #[cfg(unix)] + #[test] + fn remove_stale_launched_driver_socket_rejects_regular_file() { + let dir = tempfile::tempdir().unwrap(); + let socket_path = dir.path().join("driver.sock"); + std::fs::write(&socket_path, "not a socket").unwrap(); + + let err = + remove_stale_launched_driver_socket("enterprise-secrets", &socket_path).unwrap_err(); + + assert!(err.to_string().contains("not a Unix socket")); + assert!(socket_path.exists()); + } + + #[cfg(unix)] + #[test] + fn remove_stale_launched_driver_socket_rejects_symlink() { + let dir = tempfile::tempdir().unwrap(); + let target = dir.path().join("target.sock"); + let socket_path = dir.path().join("driver.sock"); + std::os::unix::fs::symlink(&target, &socket_path).unwrap(); + + let err = + remove_stale_launched_driver_socket("enterprise-secrets", &socket_path).unwrap_err(); + + assert!(err.to_string().contains("is a symlink")); + assert!(std::fs::symlink_metadata(&socket_path).is_ok()); + } + + #[tokio::test] + async fn runtime_rejects_unconnected_enabled_driver_on_resolution() { + let config = Config::new(None).with_credential_drivers(["vault"]); + let runtime = CredentialRuntime::from_config(&config).unwrap(); + + let err = runtime + .resolve_provider_handles(&provider_with_handle("vault", "v1:providers/openai"), 1_000) + .await + .unwrap_err(); + + 
assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("not connected")); + } +} diff --git a/crates/openshell-server/src/grpc/auth_rpc.rs b/crates/openshell-server/src/grpc/auth_rpc.rs index 88c771bed..bf10d3961 100644 --- a/crates/openshell-server/src/grpc/auth_rpc.rs +++ b/crates/openshell-server/src/grpc/auth_rpc.rs @@ -169,7 +169,9 @@ mod tests { ); let compute = new_test_runtime(store.clone()).await; let mut state = ServerState::new( - Config::new(None).with_database_url("sqlite::memory:?cache=shared"), + Config::new(None) + .with_database_url("sqlite::memory:?cache=shared") + .with_credential_drivers(["test-static"]), store, compute, SandboxIndex::new(), @@ -345,7 +347,9 @@ mod tests { ); let compute = new_test_runtime(store.clone()).await; let state = Arc::new(ServerState::new( - Config::new(None).with_database_url("sqlite::memory:?cache=shared"), + Config::new(None) + .with_database_url("sqlite::memory:?cache=shared") + .with_credential_drivers(["test-static"]), store, compute, SandboxIndex::new(), diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index fe2eb331c..8f9488402 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -698,7 +698,9 @@ pub mod test_support { ); let compute = new_test_runtime(store.clone()).await; Arc::new(ServerState::new( - Config::new(None).with_database_url("sqlite::memory:?cache=shared"), + Config::new(None) + .with_database_url("sqlite::memory:?cache=shared") + .with_credential_drivers(["test-static"]), store, compute, SandboxIndex::new(), diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 770bb71cc..ddb1165e1 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -1410,9 +1410,12 @@ pub(super) async fn handle_get_sandbox_provider_environment( let provider_names = spec.providers; let 
provider_env_revision = compute_provider_env_revision(state.store.as_ref(), &provider_names).await?; - let provider_environment = - super::provider::resolve_provider_environment(state.store.as_ref(), &provider_names) - .await?; + let provider_environment = super::provider::resolve_provider_environment_with_credentials( + state.store.as_ref(), + &provider_names, + &state.credentials, + ) + .await?; info!( sandbox_id = %sandbox_id, @@ -4334,6 +4337,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), } } diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index d5a5f5c90..5535cd658 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -8,8 +8,9 @@ use crate::persistence::{ ObjectId, ObjectLabels, ObjectName, ObjectType, Store, WriteCondition, generate_name, }; + use openshell_core::proto::{ - Provider, ProviderCredentialTokenGrantAudienceOverride, ProviderProfile, + CredentialHandle, Provider, ProviderCredentialTokenGrantAudienceOverride, ProviderProfile, ProviderProfileCredential, Sandbox, }; use openshell_core::telemetry::{ @@ -36,6 +37,13 @@ fn redact_provider_credentials(mut provider: Provider) -> Provider { for value in provider.credentials.values_mut() { *value = "REDACTED".to_string(); } + for key in provider.credential_handles.keys() { + provider + .credentials + .entry(key.clone()) + .or_insert_with(|| "REDACTED".to_string()); + } + provider.credential_handles.clear(); provider } @@ -63,9 +71,18 @@ impl ProviderEnvironment { } } +#[cfg(test)] pub(super) async fn create_provider_record( + store: &Store, + provider: Provider, +) -> Result { + create_provider_record_validating(store, provider, None).await +} + +async fn create_provider_record_validating( store: &Store, mut provider: Provider, + credentials: Option<&crate::credentials::CredentialRuntime>, ) -> Result { use 
crate::persistence::{ObjectName, current_time_ms}; @@ -97,6 +114,11 @@ pub(super) async fn create_provider_record( if provider.r#type.trim().is_empty() { return Err(Status::invalid_argument("provider.type is required")); } + if !provider.credential_handles.is_empty() { + return Err(Status::invalid_argument( + "provider.credential_handles is internal gateway state and cannot be supplied", + )); + } if provider.credentials.is_empty() && !provider_type_allows_empty_credentials(store, &provider.r#type).await? { @@ -115,8 +137,18 @@ pub(super) async fn create_provider_record( metadata.id.clone_from(&provider_id); } + let credentials_to_store = provider.credentials.clone(); + store_provider_credentials_if_configured( + credentials, + &mut provider, + &credentials_to_store, + &std::collections::HashMap::new(), + ) + .await?; + validate_provider_fields(&provider)?; + // Create with MustCreate condition to prevent duplicate creation race - let result = store + let write_result = store .put_if( Provider::object_type(), &provider_id, @@ -125,17 +157,30 @@ pub(super) async fn create_provider_record( None, WriteCondition::MustCreate, ) - .await - .map_err(|e| { + .await; + + let result = match write_result { + Ok(result) => result, + Err(e) => { + if !provider.credential_handles.is_empty() + && let Some(credentials) = credentials + { + let _ = credentials + .delete_provider_credential_handles( + provider.object_name(), + &provider.credential_handles, + ) + .await; + } if matches!( e, crate::persistence::PersistenceError::UniqueViolation { .. 
} ) { - Status::already_exists("provider already exists") - } else { - Status::internal(format!("persist provider failed: {e}")) + return Err(Status::already_exists("provider already exists")); } - })?; + return Err(Status::internal(format!("persist provider failed: {e}"))); + } + }; if let Some(metadata) = provider.metadata.as_mut() { metadata.resource_version = result.resource_version; @@ -176,6 +221,14 @@ pub(super) async fn list_provider_records( pub(super) async fn update_provider_record( store: &Store, provider: Provider, +) -> Result { + update_provider_record_validating(store, provider, None).await +} + +async fn update_provider_record_validating( + store: &Store, + provider: Provider, + credentials: Option<&crate::credentials::CredentialRuntime>, ) -> Result { use crate::persistence::{ObjectId, ObjectName}; @@ -204,6 +257,11 @@ pub(super) async fn update_provider_record( "provider type cannot be changed; delete and recreate the provider", )); } + if !provider.credential_handles.is_empty() { + return Err(Status::invalid_argument( + "provider.credential_handles is internal gateway state and cannot be supplied", + )); + } let current_version = existing.metadata.as_ref().map_or(0, |m| m.resource_version); @@ -215,6 +273,14 @@ pub(super) async fn update_provider_record( // Apply merge to create candidate let mut candidate = existing.clone(); + let existing_handles = existing.credential_handles.clone(); + let removed_credential_handles = credential_handles_removed_by_update(&existing, &provider); + let updated_credential_values = provider + .credentials + .iter() + .filter(|(_, value)| !value.is_empty()) + .map(|(key, value)| (key.clone(), value.clone())) + .collect::>(); candidate.credentials = merge_map(candidate.credentials, provider.credentials); candidate.config = merge_map(candidate.config, provider.config); candidate.credential_expires_at_ms = merge_i64_map( @@ -229,8 +295,42 @@ pub(super) async fn update_provider_record( // strand legacy records whose 
stored type predates current limits. See // #1347. super::validation::validate_object_metadata(candidate.metadata.as_ref(), "provider")?; - validate_provider_mutable_fields(&candidate)?; - validate_provider_update_against_attached_sandboxes(store, &candidate).await?; + let credential_update = prepare_provider_credential_update( + credentials, + candidate.object_name(), + &updated_credential_values, + &existing_handles, + ) + .await?; + if credentials.is_some_and(crate::credentials::CredentialRuntime::stores_provider_credentials) { + for key in updated_credential_values.keys() { + candidate.credentials.remove(key); + } + } + for key in removed_credential_handles.keys() { + candidate.credential_handles.remove(key); + } + candidate + .credential_handles + .extend(credential_update.pre_stored_handles.clone()); + if let Err(err) = validate_provider_mutable_fields(&candidate) { + cleanup_pre_stored_provider_credentials( + credentials, + candidate.object_name(), + &credential_update.pre_stored_handles, + ) + .await; + return Err(err); + } + if let Err(err) = validate_provider_update_against_attached_sandboxes(store, &candidate).await { + cleanup_pre_stored_provider_credentials( + credentials, + candidate.object_name(), + &credential_update.pre_stored_handles, + ) + .await; + return Err(err); + } // Serialize labels for storage let labels_map = candidate.object_labels(); @@ -240,14 +340,22 @@ pub(super) async fn update_provider_record( { None } else { - Some( - serde_json::to_string(&labels_map) - .map_err(|e| Status::internal(format!("serialize labels failed: {e}")))?, - ) + match serde_json::to_string(&labels_map) { + Ok(labels_json) => Some(labels_json), + Err(e) => { + cleanup_pre_stored_provider_credentials( + credentials, + candidate.object_name(), + &credential_update.pre_stored_handles, + ) + .await; + return Err(Status::internal(format!("serialize labels failed: {e}"))); + } + } }; // Write validated candidate with CAS condition - let result = store + let 
write_result = store .put_if( Provider::object_type(), candidate.object_id(), @@ -256,22 +364,29 @@ pub(super) async fn update_provider_record( labels_json.as_deref(), WriteCondition::MatchResourceVersion(cas_version), ) - .await - .map_err(|e| { - if matches!(e, crate::persistence::PersistenceError::Conflict { .. }) { - Status::aborted(format!( - "provider was modified concurrently (current resource_version: {})", - match e { - crate::persistence::PersistenceError::Conflict { - current_resource_version, - } => current_resource_version.unwrap_or(0), - _ => 0, - } - )) - } else { - Status::internal(format!("update provider failed: {e}")) - } - })?; + .await; + + let result = match write_result { + Ok(result) => result, + Err(e) => { + cleanup_pre_stored_provider_credentials( + credentials, + candidate.object_name(), + &credential_update.pre_stored_handles, + ) + .await; + return Err(provider_update_persistence_error_to_status(e)); + } + }; + + finish_provider_credential_update( + credentials, + candidate.object_name(), + credential_update, + &removed_credential_handles, + &existing_handles, + ) + .await?; // Update resource_version from successful write if let Some(metadata) = candidate.metadata.as_mut() { @@ -281,6 +396,7 @@ pub(super) async fn update_provider_record( Ok(redact_provider_credentials(candidate)) } +#[cfg(test)] pub(super) async fn delete_provider_record(store: &Store, name: &str) -> Result { if name.is_empty() { return Err(Status::invalid_argument("name is required")); @@ -311,6 +427,44 @@ pub(super) async fn delete_provider_record(store: &Store, name: &str) -> Result< .map_err(|e| Status::internal(format!("delete provider failed: {e}"))) } +pub(super) async fn delete_provider_record_with_credentials( + store: &Store, + credentials: &crate::credentials::CredentialRuntime, + name: &str, +) -> Result { + if name.is_empty() { + return Err(Status::invalid_argument("name is required")); + } + + let Some(provider) = store + .get_message_by_name::(name) + 
.await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? + else { + return Ok(false); + }; + + let blocking_sandboxes = sandboxes_using_provider(store, name).await?; + if !blocking_sandboxes.is_empty() { + return Err(Status::failed_precondition(format!( + "provider '{name}' is attached to sandbox(es): {}", + blocking_sandboxes.join(", ") + ))); + } + + credentials + .delete_provider_credential_handles(provider.object_name(), &provider.credential_handles) + .await?; + + crate::provider_refresh::delete_refresh_states_for_provider(store, provider.object_id()) + .await?; + + store + .delete_by_name(Provider::object_type(), name) + .await + .map_err(|e| Status::internal(format!("delete provider failed: {e}"))) +} + /// Iterate over every `Sandbox` in the store and collect items produced by /// `f`. `f` receives each decoded sandbox; returning `Some(T)` includes the /// value in the output, `None` skips it. @@ -421,6 +575,177 @@ fn merge_i64_map( existing } +fn credential_handles_removed_by_update( + existing: &Provider, + incoming: &Provider, +) -> std::collections::HashMap { + incoming + .credentials + .iter() + .filter(|(_, value)| value.is_empty()) + .filter_map(|(key, _)| { + existing + .credential_handles + .get(key) + .cloned() + .map(|handle| (key.clone(), handle)) + }) + .collect() +} + +#[derive(Debug, Clone, Default)] +struct ProviderCredentialUpdate { + pre_stored_handles: std::collections::HashMap, + deferred_store_values: std::collections::HashMap, + replaced_handles: std::collections::HashMap, +} + +async fn prepare_provider_credential_update( + credentials: Option<&crate::credentials::CredentialRuntime>, + provider_name: &str, + updated_values: &std::collections::HashMap, + existing_handles: &std::collections::HashMap, +) -> Result { + let Some(credentials) = credentials else { + return Ok(ProviderCredentialUpdate::default()); + }; + if !credentials.stores_provider_credentials() || updated_values.is_empty() { + return 
Ok(ProviderCredentialUpdate::default()); + } + + let mut update = ProviderCredentialUpdate::default(); + let mut values_requiring_new_handles = std::collections::HashMap::new(); + for (credential_key, value) in updated_values { + match existing_handles.get(credential_key) { + Some(existing_handle) if credentials.storage_owns_handle(existing_handle) => { + update + .deferred_store_values + .insert(credential_key.clone(), value.clone()); + } + Some(replaced_handle) => { + values_requiring_new_handles.insert(credential_key.clone(), value.clone()); + update + .replaced_handles + .insert(credential_key.clone(), replaced_handle.clone()); + } + None => { + values_requiring_new_handles.insert(credential_key.clone(), value.clone()); + } + } + } + + if !values_requiring_new_handles.is_empty() { + update.pre_stored_handles = credentials + .store_provider_credentials( + provider_name, + &values_requiring_new_handles, + &std::collections::HashMap::new(), + ) + .await?; + } + + Ok(update) +} + +async fn finish_provider_credential_update( + credentials: Option<&crate::credentials::CredentialRuntime>, + provider_name: &str, + update: ProviderCredentialUpdate, + removed_handles: &std::collections::HashMap, + existing_handles: &std::collections::HashMap, +) -> Result<(), Status> { + let Some(credentials) = credentials else { + return Ok(()); + }; + if !credentials.stores_provider_credentials() { + return Ok(()); + } + + if !update.deferred_store_values.is_empty() { + credentials + .store_provider_credentials( + provider_name, + &update.deferred_store_values, + existing_handles, + ) + .await?; + } + + let mut handles_to_delete = removed_handles.clone(); + handles_to_delete.extend(update.replaced_handles); + if !handles_to_delete.is_empty() { + credentials + .delete_provider_credential_handles(provider_name, &handles_to_delete) + .await?; + } + + Ok(()) +} + +async fn cleanup_pre_stored_provider_credentials( + credentials: Option<&crate::credentials::CredentialRuntime>, + 
provider_name: &str, + handles: &std::collections::HashMap, +) { + if handles.is_empty() { + return; + } + let Some(credentials) = credentials else { + return; + }; + if let Err(err) = credentials + .delete_provider_credential_handles(provider_name, handles) + .await + { + warn!( + provider_name = %provider_name, + error = %err, + "failed to clean up staged provider credentials after provider update failure" + ); + } +} + +fn provider_update_persistence_error_to_status( + err: crate::persistence::PersistenceError, +) -> Status { + if let crate::persistence::PersistenceError::Conflict { + current_resource_version, + } = err + { + Status::aborted(format!( + "provider was modified concurrently (current resource_version: {})", + current_resource_version.unwrap_or(0) + )) + } else { + Status::internal(format!("update provider failed: {err}")) + } +} + +async fn store_provider_credentials_if_configured( + credentials: Option<&crate::credentials::CredentialRuntime>, + provider: &mut Provider, + values_to_store: &std::collections::HashMap, + existing_handles: &std::collections::HashMap, +) -> Result<(), Status> { + let Some(credentials) = credentials else { + return Ok(()); + }; + if !credentials.stores_provider_credentials() || values_to_store.is_empty() { + return Ok(()); + } + + let provider_name = provider.object_name().to_string(); + let stored_handles = credentials + .store_provider_credentials(&provider_name, values_to_store, existing_handles) + .await?; + + for key in stored_handles.keys() { + provider.credentials.remove(key); + } + provider.credential_handles.extend(stored_handles); + Ok(()) +} + // --------------------------------------------------------------------------- // Provider environment resolution // --------------------------------------------------------------------------- @@ -431,9 +756,22 @@ fn merge_i64_map( /// collects credential key-value pairs. Returns a map of environment variables /// to inject into the sandbox. 
Credential keys must be unique across attached /// providers so one provider cannot silently overwrite another provider's token. +#[cfg(test)] pub(super) async fn resolve_provider_environment( store: &Store, provider_names: &[String], +) -> Result { + let credentials = crate::credentials::CredentialRuntime::from_config( + &openshell_core::Config::new(None).with_credential_drivers(["test-static"]), + ) + .map_err(|err| Status::internal(format!("initialize credential runtime failed: {err}")))?; + resolve_provider_environment_with_credentials(store, provider_names, &credentials).await +} + +pub(super) async fn resolve_provider_environment_with_credentials( + store: &Store, + provider_names: &[String], + credentials: &crate::credentials::CredentialRuntime, ) -> Result { if provider_names.is_empty() { return Ok(ProviderEnvironment::default()); @@ -489,6 +827,37 @@ pub(super) async fn resolve_provider_environment( } } + let resolved_refs = credentials + .resolve_provider_handles(&provider, now_ms) + .await?; + for (key, value) in resolved_refs.values { + if is_non_injectable_provider_credential(&provider, &key) { + warn!( + provider_name = %name, + key = %key, + "skipping non-injectable provider credential handle" + ); + continue; + } + if is_valid_env_key(&key) { + if let Some(expires_at_ms) = resolved_refs + .expires_at_ms + .get(&key) + .copied() + .filter(|expires_at_ms| *expires_at_ms > 0) + { + expires.entry(key.clone()).or_insert(expires_at_ms); + } + env.entry(key).or_insert(value); + } else { + warn!( + provider_name = %name, + key = %key, + "skipping credential handle with invalid env var key" + ); + } + } + registry.inject_env(&provider, &mut env); } @@ -1080,19 +1449,31 @@ async fn active_provider_environment_keys( } fn active_provider_credential_keys(provider: &Provider, now_ms: i64) -> Vec { - provider + let mut keys: Vec = provider .credentials .keys() .filter(|key| !is_non_injectable_provider_credential(provider, key)) .filter(|key| is_valid_env_key(key)) 
- .filter(|key| { - provider - .credential_expires_at_ms - .get(*key) - .is_none_or(|expires_at_ms| *expires_at_ms <= 0 || *expires_at_ms > now_ms) - }) + .filter(|key| provider_credential_not_expired(provider, key, now_ms)) .cloned() - .collect() + .collect(); + keys.extend( + provider + .credential_handles + .keys() + .filter(|key| !is_non_injectable_provider_credential(provider, key)) + .filter(|key| is_valid_env_key(key)) + .filter(|key| provider_credential_not_expired(provider, key, now_ms)) + .cloned(), + ); + keys +} + +fn provider_credential_not_expired(provider: &Provider, key: &str, now_ms: i64) -> bool { + provider + .credential_expires_at_ms + .get(key) + .is_none_or(|expires_at_ms| *expires_at_ms <= 0 || *expires_at_ms > now_ms) } fn is_non_injectable_provider_credential(provider: &Provider, key: &str) -> bool { @@ -1161,7 +1542,9 @@ pub(super) async fn handle_create_provider( return Err(Status::invalid_argument("provider is required")); }; let provider_type = provider.r#type.clone(); - let result = create_provider_record(state.store.as_ref(), provider).await; + let result = + create_provider_record_validating(state.store.as_ref(), provider, Some(&state.credentials)) + .await; match result { Ok(provider) => { emit_provider_lifecycle( @@ -1904,7 +2287,9 @@ pub(super) async fn handle_update_provider( provider .credential_expires_at_ms .extend(req.credential_expires_at_ms); - let result = update_provider_record(state.store.as_ref(), provider).await; + let result = + update_provider_record_validating(state.store.as_ref(), provider, Some(&state.credentials)) + .await; match result { Ok(provider) => { emit_provider_lifecycle( @@ -2154,6 +2539,7 @@ pub(super) async fn handle_configure_provider_refresh( credential_key.to_string(), expires_at_ms, )]), + credential_handles: std::collections::HashMap::new(), }; update_provider_record(state.store.as_ref(), updated).await?; } @@ -2180,6 +2566,7 @@ pub(super) async fn handle_rotate_provider_credential( } let 
refresh_state = crate::provider_refresh::refresh_provider_credential( state.store.as_ref(), + Some(&state.credentials), provider_name, credential_key, ) @@ -2249,6 +2636,7 @@ pub(super) async fn handle_delete_provider_refresh( credential_key.to_string(), 0, )]), + credential_handles: std::collections::HashMap::new(), }; update_provider_record(state.store.as_ref(), updated).await?; } @@ -2264,7 +2652,9 @@ pub(super) async fn handle_delete_provider( ) -> Result, Status> { let name = request.into_inner().name; let provider_profile = provider_profile_for_name(state.store.as_ref(), &name).await; - let result = delete_provider_record(state.store.as_ref(), &name).await; + let result = + delete_provider_record_with_credentials(state.store.as_ref(), &state.credentials, &name) + .await; match result { Ok(deleted) => { let outcome = TelemetryOutcome::from_success(deleted); @@ -2531,6 +2921,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -2995,6 +3386,58 @@ mod tests { .into_iter() .collect(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), + } + } + + fn provider_with_credential_handle( + name: &str, + provider_type: &str, + credential_key: &str, + ) -> Provider { + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: name.to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: provider_type.to_string(), + credentials: HashMap::new(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + credential_handles: std::iter::once(( + credential_key.to_string(), + CredentialHandle { + driver: "test-static".to_string(), + handle: format!("{name}:{credential_key}"), + metadata: HashMap::new(), + }, + )) + .collect(), + } + } + + fn provider_with_credential_value( + name: &str, + provider_type: &str, + credential_key: &str, + 
value: &str, + ) -> Provider { + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: name.to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: provider_type.to_string(), + credentials: std::iter::once((credential_key.to_string(), value.to_string())).collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), } } @@ -3633,6 +4076,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -3745,6 +4189,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -3808,6 +4253,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -3850,6 +4296,7 @@ mod tests { "MS_GRAPH_ACCESS_TOKEN".to_string(), manual_expires_at_ms, )]), + credential_handles: HashMap::new(), }, ) .await @@ -3903,6 +4350,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -3922,6 +4370,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -3991,6 +4440,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4080,6 +4530,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4147,6 +4598,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4299,6 +4751,7 @@ mod tests { config: 
std::iter::once(("endpoint".to_string(), "https://gitlab.com".to_string())) .collect(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4350,23 +4803,352 @@ mod tests { } #[tokio::test] - async fn delete_provider_removes_scoped_refresh_states() { + async fn create_provider_record_stores_credentials_with_runtime() { let store = test_store().await; + let config = openshell_core::Config::new(None).with_credential_drivers(["test-static"]); + let credentials = crate::credentials::CredentialRuntime::from_config(&config).unwrap(); - let provider = create_provider_record( + let persisted = create_provider_record_validating( &store, - Provider { - credential_expires_at_ms: HashMap::from([("API_TOKEN".to_string(), 123_456)]), - ..provider_with_values("gitlab-local", "gitlab") - }, + provider_with_credential_value("openai-local", "openai", "OPENAI_API_KEY", "sk-test"), + Some(&credentials), ) .await .unwrap(); - let refresh_state = crate::provider_refresh::new_refresh_state( - &provider, - "API_TOKEN", - crate::provider_refresh::NewRefreshStateConfig { - strategy: ProviderCredentialRefreshStrategy::External, + + assert_eq!(persisted.object_name(), "openai-local"); + assert_eq!( + persisted + .credentials + .get("OPENAI_API_KEY") + .map(String::as_str), + Some("REDACTED") + ); + assert!(persisted.credential_handles.is_empty()); + + let stored: Provider = store + .get_message_by_name("openai-local") + .await + .unwrap() + .unwrap(); + assert!(stored.credentials.is_empty()); + assert_eq!( + stored + .credential_handles + .get("OPENAI_API_KEY") + .map(|handle| handle.driver.as_str()), + Some("test-static") + ); + } + + #[tokio::test] + async fn update_provider_record_overwrites_credentials_with_runtime() { + let store = test_store().await; + let config = openshell_core::Config::new(None).with_credential_drivers(["test-static"]); + let credentials = crate::credentials::CredentialRuntime::from_config(&config).unwrap(); + + 
create_provider_record_validating( + &store, + provider_with_credential_value("openai-local", "openai", "OPENAI_API_KEY", "sk-first"), + Some(&credentials), + ) + .await + .unwrap(); + let stored_first: Provider = store + .get_message_by_name("openai-local") + .await + .unwrap() + .unwrap(); + let first_handle = stored_first + .credential_handles + .get("OPENAI_API_KEY") + .expect("stored handle") + .handle + .clone(); + + let updated = update_provider_record_validating( + &store, + provider_with_credential_value("openai-local", "openai", "OPENAI_API_KEY", "sk-second"), + Some(&credentials), + ) + .await + .unwrap(); + assert_eq!( + updated + .credentials + .get("OPENAI_API_KEY") + .map(String::as_str), + Some("REDACTED") + ); + assert!(updated.credential_handles.is_empty()); + + let stored_second: Provider = store + .get_message_by_name("openai-local") + .await + .unwrap() + .unwrap(); + assert!(stored_second.credentials.is_empty()); + assert_eq!( + stored_second + .credential_handles + .get("OPENAI_API_KEY") + .map(|handle| handle.handle.as_str()), + Some(first_handle.as_str()) + ); + + let result = resolve_provider_environment_with_credentials( + &store, + &["openai-local".to_string()], + &credentials, + ) + .await + .unwrap(); + assert_eq!(result.get("OPENAI_API_KEY"), Some(&"sk-second".to_string())); + } + + #[tokio::test] + async fn update_provider_record_with_runtime_preserves_legacy_inline_credentials_on_noop() { + let store = test_store().await; + let config = openshell_core::Config::new(None).with_credential_drivers(["test-static"]); + let credentials = crate::credentials::CredentialRuntime::from_config(&config).unwrap(); + + create_provider_record(&store, provider_with_values("legacy-provider", "openai")) + .await + .unwrap(); + + let updated = update_provider_record_validating( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "legacy-provider".to_string(), + created_at_ms: 0, + 
labels: HashMap::new(), + resource_version: 0, + }), + r#type: String::new(), + credentials: HashMap::new(), + config: std::iter::once(( + "endpoint".to_string(), + "https://updated.example.com".to_string(), + )) + .collect(), + credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), + }, + Some(&credentials), + ) + .await + .unwrap(); + assert_eq!(updated.credentials.len(), 2); + assert!(updated.credential_handles.is_empty()); + + let stored: Provider = store + .get_message_by_name("legacy-provider") + .await + .unwrap() + .unwrap(); + assert_eq!( + stored.credentials.get("API_TOKEN").map(String::as_str), + Some("token-123") + ); + assert_eq!( + stored.credentials.get("SECONDARY").map(String::as_str), + Some("secondary-token") + ); + assert!(stored.credential_handles.is_empty()); + } + + #[tokio::test] + async fn update_provider_record_with_runtime_stores_only_updated_legacy_inline_credentials() { + let store = test_store().await; + let config = openshell_core::Config::new(None).with_credential_drivers(["test-static"]); + let credentials = crate::credentials::CredentialRuntime::from_config(&config).unwrap(); + + create_provider_record(&store, provider_with_values("legacy-provider", "openai")) + .await + .unwrap(); + + update_provider_record_validating( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "legacy-provider".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: String::new(), + credentials: std::iter::once(( + "API_TOKEN".to_string(), + "rotated-token".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), + }, + Some(&credentials), + ) + .await + .unwrap(); + + let stored: Provider = store + .get_message_by_name("legacy-provider") + .await + .unwrap() + .unwrap(); + assert!(!stored.credentials.contains_key("API_TOKEN")); + 
assert_eq!( + stored.credentials.get("SECONDARY").map(String::as_str), + Some("secondary-token") + ); + assert_eq!( + stored + .credential_handles + .get("API_TOKEN") + .map(|handle| handle.driver.as_str()), + Some("test-static") + ); + + let result = resolve_provider_environment_with_credentials( + &store, + &["legacy-provider".to_string()], + &credentials, + ) + .await + .unwrap(); + assert_eq!(result.get("API_TOKEN"), Some(&"rotated-token".to_string())); + assert_eq!( + result.get("SECONDARY"), + Some(&"secondary-token".to_string()) + ); + } + + #[tokio::test] + async fn handle_create_provider_rejects_user_supplied_credential_handles() { + let state = test_server_state().await; + + let err = handle_create_provider( + &state, + Request::new(CreateProviderRequest { + provider: Some(provider_with_credential_handle( + "openai-ref", + "openai", + "OPENAI_API_KEY", + )), + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("internal gateway state")); + } + + #[tokio::test] + async fn handle_create_provider_stores_inline_credentials_with_enabled_driver() { + let mut state = test_server_state().await; + let config = state + .config + .clone() + .with_credential_drivers(["test-static"]); + let credentials = crate::credentials::CredentialRuntime::from_config(&config).unwrap(); + let state_mut = Arc::get_mut(&mut state).unwrap(); + state_mut.config = config; + state_mut.credentials = credentials; + + let response = handle_create_provider( + &state, + Request::new(CreateProviderRequest { + provider: Some(provider_with_credential_value( + "openai-local", + "openai", + "OPENAI_API_KEY", + "sk-test", + )), + }), + ) + .await + .unwrap() + .into_inner(); + + let provider = response.provider.expect("provider"); + assert_eq!( + provider + .credentials + .get("OPENAI_API_KEY") + .map(String::as_str), + Some("REDACTED") + ); + assert!(provider.credential_handles.is_empty()); + + let stored: Provider = state + .store 
+ .get_message_by_name("openai-local") + .await + .unwrap() + .unwrap(); + assert!(stored.credentials.is_empty()); + assert!(stored.credential_handles.contains_key("OPENAI_API_KEY")); + + let result = resolve_provider_environment_with_credentials( + state.store.as_ref(), + &["openai-local".to_string()], + &state.credentials, + ) + .await + .unwrap(); + assert_eq!(result.get("OPENAI_API_KEY"), Some(&"sk-test".to_string())); + } + + #[tokio::test] + async fn handle_update_provider_rejects_user_supplied_credential_handles() { + let state = test_server_state().await; + create_provider_record( + state.store.as_ref(), + provider_with_values("openai-local", "openai"), + ) + .await + .unwrap(); + + let err = handle_update_provider( + &state, + Request::new(UpdateProviderRequest { + provider: Some(provider_with_credential_handle( + "openai-local", + "openai", + "OPENAI_API_KEY", + )), + credential_expires_at_ms: HashMap::new(), + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("internal gateway state")); + } + + #[tokio::test] + async fn delete_provider_removes_scoped_refresh_states() { + let store = test_store().await; + + let provider = create_provider_record( + &store, + Provider { + credential_expires_at_ms: HashMap::from([("API_TOKEN".to_string(), 123_456)]), + ..provider_with_values("gitlab-local", "gitlab") + }, + ) + .await + .unwrap(); + let refresh_state = crate::provider_refresh::new_refresh_state( + &provider, + "API_TOKEN", + crate::provider_refresh::NewRefreshStateConfig { + strategy: ProviderCredentialRefreshStrategy::External, material: HashMap::from([( "endpoint".to_string(), "https://refresh.example.com".to_string(), @@ -4464,6 +5246,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4493,6 +5276,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + 
credential_handles: HashMap::new(), }, ) .await @@ -4523,6 +5307,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4543,6 +5328,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4617,6 +5403,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4653,6 +5440,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4689,6 +5477,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4709,6 +5498,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4735,6 +5525,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4763,6 +5554,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4810,6 +5602,7 @@ mod tests { credentials: std::iter::once(("SECONDARY".to_string(), String::new())).collect(), config: std::iter::once(("region".to_string(), String::new())).collect(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4861,6 +5654,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4890,6 +5684,7 @@ mod tests { credentials: 
HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4921,6 +5716,7 @@ mod tests { credentials: std::iter::once((oversized_key, "value".to_string())).collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -4949,6 +5745,7 @@ mod tests { credentials: std::iter::once(("API_TOKEN".to_string(), "old".to_string())).collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }; store.put_message(&legacy).await.unwrap(); @@ -4967,6 +5764,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5006,6 +5804,7 @@ mod tests { )) .collect(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }; create_provider_record(&store, provider).await.unwrap(); @@ -5035,6 +5834,57 @@ mod tests { assert!(result.dynamic_credentials.is_empty()); } + #[tokio::test] + async fn resolve_provider_env_rejects_unresolvable_credential_handle() { + let store = test_store().await; + let config = openshell_core::Config::new(None).with_credential_drivers(["test-static"]); + let credentials = crate::credentials::CredentialRuntime::from_config(&config).unwrap(); + create_provider_record_validating( + &store, + provider_with_credential_value("openai-local", "openai", "OPENAI_API_KEY", "sk-test"), + Some(&credentials), + ) + .await + .unwrap(); + let other_credentials = + crate::credentials::CredentialRuntime::from_config(&config).unwrap(); + + let err = resolve_provider_environment_with_credentials( + &store, + &["openai-local".to_string()], + &other_credentials, + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::NotFound); + assert!(err.message().contains("credential handle")); + } + + #[tokio::test] + async fn 
resolve_provider_env_resolves_credential_handles_with_runtime() { + let store = test_store().await; + let config = openshell_core::Config::new(None).with_credential_drivers(["test-static"]); + let credentials = crate::credentials::CredentialRuntime::from_config(&config).unwrap(); + create_provider_record_validating( + &store, + provider_with_credential_value("openai-local", "openai", "OPENAI_API_KEY", "sk-test"), + Some(&credentials), + ) + .await + .unwrap(); + + let result = resolve_provider_environment_with_credentials( + &store, + &["openai-local".to_string()], + &credentials, + ) + .await + .unwrap(); + + assert_eq!(result.get("OPENAI_API_KEY"), Some(&"sk-test".to_string())); + } + #[tokio::test] async fn resolve_provider_env_skips_expired_credentials_and_returns_expiry_metadata() { let store = test_store().await; @@ -5061,6 +5911,7 @@ mod tests { ] .into_iter() .collect(), + credential_handles: HashMap::new(), }; create_provider_record(&store, provider).await.unwrap(); @@ -5106,6 +5957,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }; create_provider_record(&store, provider).await.unwrap(); @@ -5138,6 +5990,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5157,6 +6010,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5190,6 +6044,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5212,6 +6067,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5229,6 +6085,52 @@ mod tests { assert!(err.message().contains("provider-b")); } + #[tokio::test] + async fn 
validate_provider_environment_keys_unique_includes_credential_handles() { + let store = test_store().await; + let config = openshell_core::Config::new(None).with_credential_drivers(["test-static"]); + let credentials = crate::credentials::CredentialRuntime::from_config(&config).unwrap(); + create_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "provider-a".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "claude".to_string(), + credentials: std::iter::once(("SHARED_KEY".to_string(), "first-value".to_string())) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), + }, + ) + .await + .unwrap(); + create_provider_record_validating( + &store, + provider_with_credential_value("provider-b", "gitlab", "SHARED_KEY", "second-value"), + Some(&credentials), + ) + .await + .unwrap(); + + let err = validate_provider_environment_keys_unique( + &store, + &["provider-a".to_string(), "provider-b".to_string()], + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("SHARED_KEY")); + assert!(err.message().contains("provider-a")); + assert!(err.message().contains("provider-b")); + } + #[tokio::test] async fn resolve_provider_env_injects_vertex_agent_config() { let store = test_store().await; @@ -5258,6 +6160,7 @@ mod tests { .into_iter() .collect(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5331,6 +6234,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5368,6 +6272,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5426,6 +6331,7 @@ mod tests { .into_iter() .collect(), 
credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5460,6 +6366,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5501,6 +6408,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5523,6 +6431,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5561,6 +6470,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5595,6 +6505,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }, ) .await @@ -5695,6 +6606,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }; // Attempt to update with an oversized credential key (exceeds MAX_MAP_KEY_LEN) @@ -5832,6 +6744,7 @@ mod tests { // Prepare an update with the correct resource_version let mut updated_provider = current.clone(); + updated_provider.credential_handles.clear(); updated_provider .credentials .insert("NEW_KEY".to_string(), "new-value".to_string()); @@ -5900,6 +6813,7 @@ mod tests { // Prepare an update with a stale resource_version let mut stale_provider = current.clone(); + stale_provider.credential_handles.clear(); stale_provider .credentials .insert("NEW_KEY".to_string(), "new-value".to_string()); @@ -5936,6 +6850,70 @@ mod tests { current_version ); assert!(!unchanged.credentials.contains_key("NEW_KEY")); + assert!(!unchanged.credential_handles.contains_key("NEW_KEY")); + } + + #[tokio::test] + async fn update_provider_stale_version_does_not_overwrite_stored_credential() { + let mut state = 
test_server_state().await; + let config = state + .config + .clone() + .with_credential_drivers(["test-static"]); + let credentials = crate::credentials::CredentialRuntime::from_config(&config).unwrap(); + let state_mut = Arc::get_mut(&mut state).unwrap(); + state_mut.config = config; + state_mut.credentials = credentials; + + handle_create_provider( + &state, + Request::new(CreateProviderRequest { + provider: Some(provider_with_credential_value( + "openai-local", + "openai", + "OPENAI_API_KEY", + "sk-first", + )), + }), + ) + .await + .unwrap(); + + let current = state + .store + .get_message_by_name::("openai-local") + .await + .unwrap() + .unwrap(); + let mut stale_provider = current.clone(); + stale_provider.credential_handles.clear(); + stale_provider + .credentials + .insert("OPENAI_API_KEY".to_string(), "sk-stale".to_string()); + stale_provider.metadata.as_mut().unwrap().resource_version = 99; + + let err = handle_update_provider( + &state, + Request::new(UpdateProviderRequest { + provider: Some(stale_provider), + credential_expires_at_ms: HashMap::new(), + }), + ) + .await + .unwrap_err(); + assert_eq!(err.code(), Code::Aborted); + + let resolved = resolve_provider_environment_with_credentials( + state.store.as_ref(), + &["openai-local".to_string()], + &state.credentials, + ) + .await + .unwrap(); + assert_eq!( + resolved.get("OPENAI_API_KEY"), + Some(&"sk-first".to_string()) + ); } #[tokio::test] @@ -5970,6 +6948,7 @@ mod tests { for i in 0..3 { let state_clone = Arc::clone(&state); let mut updated = initial.clone(); + updated.credential_handles.clear(); updated .credentials .insert(format!("KEY_{i}"), format!("value-{i}")); @@ -6024,7 +7003,11 @@ mod tests { // Exactly one of KEY_0, KEY_1, or KEY_2 should be present let new_keys_count = (0..3) - .filter(|i| final_provider.credentials.contains_key(&format!("KEY_{i}"))) + .filter(|i| { + final_provider + .credential_handles + .contains_key(&format!("KEY_{i}")) + }) .count(); assert_eq!(new_keys_count, 1); } 
@@ -6042,6 +7025,7 @@ mod tests { credentials: HashMap::new(), config, credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), } } @@ -6144,6 +7128,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::from([("project_id".to_string(), "should-be-ignored".to_string())]), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), }; let mut env = HashMap::new(); openshell_providers::ProviderRegistry::new().inject_env(&provider, &mut env); diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 28377394f..cbe4e0e5e 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -2204,6 +2204,7 @@ mod tests { .collect(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), } } diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index b3680c6e7..b38ee3859 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -9,7 +9,8 @@ #![allow(clippy::result_large_err)] // Validation returns Result<_, Status> use openshell_core::proto::{ - ExecSandboxRequest, Provider, SandboxPolicy as ProtoSandboxPolicy, SandboxTemplate, + CredentialHandle, ExecSandboxRequest, Provider, SandboxPolicy as ProtoSandboxPolicy, + SandboxTemplate, }; use prost::Message; use tonic::Status; @@ -344,6 +345,8 @@ pub(super) fn validate_provider_mutable_fields(provider: &Provider) -> Result<() MAX_MAP_VALUE_LEN, "provider.credentials", )?; + validate_provider_credential_handles(&provider.credential_handles)?; + validate_provider_credential_sources(provider)?; validate_string_map( &provider.config, MAX_PROVIDER_CONFIG_ENTRIES, @@ -373,6 +376,99 @@ pub(super) fn validate_provider_mutable_fields(provider: &Provider) -> Result<() Ok(()) } +fn validate_provider_credential_sources(provider: 
&Provider) -> Result<(), Status> { + let total_credentials = provider.credentials.len() + provider.credential_handles.len(); + if total_credentials > MAX_PROVIDER_CREDENTIALS_ENTRIES { + return Err(Status::invalid_argument(format!( + "provider credential sources exceed maximum entries ({total_credentials} > {MAX_PROVIDER_CREDENTIALS_ENTRIES})" + ))); + } + + for key in provider.credential_handles.keys() { + if provider.credentials.contains_key(key) { + return Err(Status::invalid_argument(format!( + "provider credential key '{key}' cannot be present in both provider.credentials and provider.credential_handles" + ))); + } + } + Ok(()) +} + +fn validate_provider_credential_handles( + credential_handles: &std::collections::HashMap, +) -> Result<(), Status> { + if credential_handles.len() > MAX_PROVIDER_CREDENTIALS_ENTRIES { + return Err(Status::invalid_argument(format!( + "provider.credential_handles exceeds maximum entries ({} > {MAX_PROVIDER_CREDENTIALS_ENTRIES})", + credential_handles.len() + ))); + } + + for (credential_key, handle) in credential_handles { + if credential_key.len() > MAX_MAP_KEY_LEN { + return Err(Status::invalid_argument(format!( + "provider.credential_handles key exceeds maximum length ({} > {MAX_MAP_KEY_LEN})", + credential_key.len() + ))); + } + if !super::provider::is_valid_env_key(credential_key) { + return Err(Status::invalid_argument(format!( + "provider.credential_handles keys must match ^[A-Za-z_][A-Za-z0-9_]*$; got '{credential_key}'" + ))); + } + validate_credential_handle( + handle, + &format!("provider.credential_handles['{credential_key}']"), + )?; + } + + Ok(()) +} + +fn validate_credential_handle(handle: &CredentialHandle, field_name: &str) -> Result<(), Status> { + validate_required_credential_handle_string(&handle.driver, field_name, "driver")?; + validate_required_credential_handle_string(&handle.handle, field_name, "handle")?; + validate_string_map( + &handle.metadata, + MAX_PROVIDER_CONFIG_ENTRIES, + MAX_MAP_KEY_LEN, + 
MAX_MAP_VALUE_LEN, + &format!("{field_name}.metadata"), + )?; + for (key, value) in &handle.metadata { + reject_control_chars(key, &format!("{field_name}.metadata key"))?; + reject_control_chars(value, &format!("{field_name}.metadata value for '{key}'"))?; + } + Ok(()) +} + +fn validate_required_credential_handle_string( + value: &str, + field_name: &str, + component: &str, +) -> Result<(), Status> { + if value.trim().is_empty() { + return Err(Status::invalid_argument(format!( + "{field_name}.{component} is required" + ))); + } + validate_optional_credential_handle_string(value, field_name, component) +} + +fn validate_optional_credential_handle_string( + value: &str, + field_name: &str, + component: &str, +) -> Result<(), Status> { + if value.len() > MAX_MAP_VALUE_LEN { + return Err(Status::invalid_argument(format!( + "{field_name}.{component} exceeds maximum length ({} > {MAX_MAP_VALUE_LEN})", + value.len() + ))); + } + reject_control_chars(value, &format!("{field_name}.{component}")) +} + // --------------------------------------------------------------------------- // Label selector validation // --------------------------------------------------------------------------- @@ -1118,6 +1214,18 @@ mod tests { std::iter::once(("KEY".to_string(), "val".to_string())).collect() } + fn one_credential_handle() -> HashMap { + std::iter::once(( + "API_KEY".to_string(), + CredentialHandle { + driver: "kubernetes-secrets".to_string(), + handle: "v1:openshell:provider-secret".to_string(), + metadata: HashMap::new(), + }, + )) + .collect() + } + fn make_test_provider( name: &str, provider_type: &str, @@ -1136,6 +1244,7 @@ mod tests { credentials, config, credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), } } @@ -1150,6 +1259,72 @@ mod tests { assert!(validate_provider_fields(&provider).is_ok()); } + #[test] + fn validate_provider_fields_accepts_credential_handles() { + let mut provider = + make_test_provider("my-provider", "claude", HashMap::new(), 
HashMap::new()); + provider.credential_handles = one_credential_handle(); + + assert!(validate_provider_fields(&provider).is_ok()); + } + + #[test] + fn validate_provider_fields_rejects_duplicate_inline_and_referenced_key() { + let mut provider = make_test_provider( + "my-provider", + "claude", + std::iter::once(("API_KEY".to_string(), "inline".to_string())).collect(), + HashMap::new(), + ); + provider.credential_handles = one_credential_handle(); + + let err = validate_provider_fields(&provider).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("provider.credentials")); + assert!(err.message().contains("provider.credential_handles")); + } + + #[test] + fn validate_provider_fields_rejects_too_many_combined_credential_sources() { + let refs: HashMap = (0..MAX_PROVIDER_CREDENTIALS_ENTRIES) + .map(|i| { + ( + format!("REF_{i}"), + CredentialHandle { + driver: "test".to_string(), + handle: format!("handle-{i}"), + metadata: HashMap::new(), + }, + ) + }) + .collect(); + let mut provider = make_test_provider("ok", "claude", one_credential(), HashMap::new()); + provider.credential_handles = refs; + + let err = validate_provider_fields(&provider).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("credential sources")); + } + + #[test] + fn validate_provider_fields_rejects_credential_handle_missing_handle() { + let mut provider = + make_test_provider("my-provider", "claude", HashMap::new(), HashMap::new()); + provider.credential_handles = std::iter::once(( + "API_KEY".to_string(), + CredentialHandle { + driver: "test".to_string(), + handle: String::new(), + metadata: HashMap::new(), + }, + )) + .collect(); + + let err = validate_provider_fields(&provider).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("handle is required")); + } + #[test] fn validate_provider_fields_rejects_over_limit_name() { let provider = make_test_provider( diff 
--git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index 43416c35d..61f7944cd 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -73,9 +73,12 @@ impl Inference for InferenceService { .extensions() .get::(), )?; - resolve_inference_bundle(self.state.store.as_ref()) - .await - .map(Response::new) + resolve_inference_bundle_with_credentials( + self.state.store.as_ref(), + Some(&self.state.credentials), + ) + .await + .map(Response::new) } #[rpc_auth(auth = "bearer", scope = "inference:write", role = "admin")] @@ -86,8 +89,9 @@ impl Inference for InferenceService { let req = request.into_inner(); let route_name = effective_route_name(&req.route_name)?; let verify = !req.no_verify; - let route = upsert_cluster_inference_route( + let route = upsert_cluster_inference_route_with_credentials( self.state.store.as_ref(), + Some(&self.state.credentials), route_name, &req.provider_name, &req.model_id, @@ -153,6 +157,7 @@ impl Inference for InferenceService { } } +#[cfg(test)] async fn upsert_cluster_inference_route( store: &Store, route_name: &str, @@ -160,6 +165,27 @@ async fn upsert_cluster_inference_route( model_id: &str, timeout_secs: u64, verify: bool, +) -> Result { + upsert_cluster_inference_route_with_credentials( + store, + None, + route_name, + provider_name, + model_id, + timeout_secs, + verify, + ) + .await +} + +async fn upsert_cluster_inference_route_with_credentials( + store: &Store, + credentials: Option<&crate::credentials::CredentialRuntime>, + route_name: &str, + provider_name: &str, + model_id: &str, + timeout_secs: u64, + verify: bool, ) -> Result { if provider_name.trim().is_empty() { return Err(Status::invalid_argument("provider_name is required")); @@ -175,6 +201,7 @@ async fn upsert_cluster_inference_route( .ok_or_else(|| { Status::failed_precondition(format!("provider '{provider_name}' not found")) })?; + let provider = resolve_provider_credentials(provider, 
credentials).await?; let resolved = resolve_provider_route(&provider, model_id)?; let validation = if verify { @@ -888,12 +915,26 @@ fn authorize_inference_bundle( } /// Resolve the inference bundle (all managed routes + revision hash). +#[cfg(test)] async fn resolve_inference_bundle(store: &Store) -> Result { + resolve_inference_bundle_with_credentials(store, None).await +} + +async fn resolve_inference_bundle_with_credentials( + store: &Store, + credentials: Option<&crate::credentials::CredentialRuntime>, +) -> Result { let mut routes = Vec::new(); - if let Some(r) = resolve_route_by_name(store, CLUSTER_INFERENCE_ROUTE_NAME).await? { + if let Some(r) = + resolve_route_by_name_with_credentials(store, credentials, CLUSTER_INFERENCE_ROUTE_NAME) + .await? + { routes.push(r); } - if let Some(r) = resolve_route_by_name(store, SANDBOX_SYSTEM_ROUTE_NAME).await? { + if let Some(r) = + resolve_route_by_name_with_credentials(store, credentials, SANDBOX_SYSTEM_ROUTE_NAME) + .await? + { routes.push(r); } @@ -930,9 +971,18 @@ async fn resolve_inference_bundle(store: &Store) -> Result Result, Status> { + resolve_route_by_name_with_credentials(store, None, route_name).await +} + +async fn resolve_route_by_name_with_credentials( + store: &Store, + credentials: Option<&crate::credentials::CredentialRuntime>, + route_name: &str, ) -> Result, Status> { let route = store .get_message_by_name::(route_name) @@ -969,6 +1019,7 @@ async fn resolve_route_by_name( config.provider_name )) })?; + let provider = resolve_provider_credentials(provider, credentials).await?; let resolved = resolve_provider_route(&provider, &config.model_id)?; @@ -985,6 +1036,30 @@ async fn resolve_route_by_name( })) } +async fn resolve_provider_credentials( + mut provider: Provider, + credentials: Option<&crate::credentials::CredentialRuntime>, +) -> Result { + if provider.credential_handles.is_empty() { + return Ok(provider); + } + + let credentials = credentials.ok_or_else(|| { + 
Status::failed_precondition(format!( + "provider '{}' stores credentials as handles, but credential storage is unavailable", + provider.object_name() + )) + })?; + let resolved = credentials + .resolve_provider_handles(&provider, current_time_ms()) + .await?; + provider.credentials.extend(resolved.values); + provider + .credential_expires_at_ms + .extend(resolved.expires_at_ms); + Ok(provider) +} + #[cfg(test)] mod tests { use super::*; @@ -1053,6 +1128,7 @@ mod tests { credentials: std::iter::once((key_name.to_string(), key_value.to_string())).collect(), config: std::collections::HashMap::new(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), } } @@ -1161,6 +1237,7 @@ mod tests { )) .collect(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }; store .put_message(&provider) @@ -1231,6 +1308,7 @@ mod tests { // Intentionally no BEDROCK_BASE_URL. config: std::collections::HashMap::new(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }; store .put_message(&provider) @@ -1278,6 +1356,7 @@ mod tests { )) .collect(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }; store .put_message(&provider) @@ -1504,6 +1583,7 @@ mod tests { )) .collect(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), }; store .put_message(&provider) @@ -1550,6 +1630,74 @@ mod tests { ); } + #[tokio::test] + async fn managed_route_resolves_default_credential_handles() { + let store = test_store().await; + let credentials = crate::credentials::CredentialRuntime::from_config_with_store( + &openshell_core::Config::new(None), + Arc::new(store.clone()), + ) + .expect("credential runtime should connect to default encrypted store"); + let handles = credentials + 
.store_provider_credentials( + "openai-dev", + &std::collections::HashMap::from([( + "OPENAI_API_KEY".to_string(), + "sk-encrypted".to_string(), + )]), + &std::collections::HashMap::new(), + ) + .await + .expect("credential should be stored"); + + let provider = Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "provider-1".to_string(), + name: "openai-dev".to_string(), + created_at_ms: 1_000_000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + r#type: "openai".to_string(), + credentials: std::collections::HashMap::new(), + config: std::iter::once(( + "OPENAI_BASE_URL".to_string(), + "https://station.example.com/v1".to_string(), + )) + .collect(), + credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: handles, + }; + store + .put_message(&provider) + .await + .expect("provider should persist"); + + upsert_cluster_inference_route_with_credentials( + &store, + Some(&credentials), + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "test/model", + 0, + false, + ) + .await + .expect("route should be created from handle-backed provider"); + + let managed = resolve_route_by_name_with_credentials( + &store, + Some(&credentials), + CLUSTER_INFERENCE_ROUTE_NAME, + ) + .await + .expect("route should resolve") + .expect("managed route should exist"); + + assert_eq!(managed.base_url, "https://station.example.com/v1"); + assert_eq!(managed.api_key, "sk-encrypted"); + } + #[tokio::test] async fn resolve_managed_route_reflects_provider_key_rotation() { let store = test_store().await; @@ -1579,6 +1727,7 @@ mod tests { .collect(), config: provider.config.clone(), credential_expires_at_ms: provider.credential_expires_at_ms.clone(), + credential_handles: std::collections::HashMap::new(), }; store .put_message(&rotated_provider) @@ -1645,6 +1794,7 @@ mod tests { .into_iter() .collect(), credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: 
std::collections::HashMap::new(), }; store .put_message(&provider) @@ -1974,6 +2124,7 @@ mod tests { .collect(), config, credential_expires_at_ms: std::collections::HashMap::new(), + credential_handles: std::collections::HashMap::new(), } } diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index dda8708e0..e3b43f68c 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -24,6 +24,7 @@ pub mod certgen; pub mod cli; mod compute; pub mod config_file; +mod credentials; mod defaults; mod grpc; mod http; @@ -85,6 +86,9 @@ pub struct ServerState { /// Compute orchestration over the configured driver. pub compute: ComputeRuntime, + /// Credential-driver selection and resolution runtime. + pub credentials: credentials::CredentialRuntime, + /// In-memory sandbox correlation index. pub sandbox_index: SandboxIndex, @@ -168,12 +172,43 @@ impl ServerState { tracing_log_bus: TracingLogBus, supervisor_sessions: Arc, oidc_cache: Option>, + ) -> Self { + let credentials = + credentials::CredentialRuntime::from_config_with_store(&config, Arc::clone(&store)) + .expect("server config should be validated before ServerState::new"); + Self::new_with_credentials( + config, + store, + compute, + sandbox_index, + sandbox_watch_bus, + tracing_log_bus, + supervisor_sessions, + oidc_cache, + credentials, + ) + } + + /// Create new server state with an already-initialized credential runtime. 
+ #[must_use] + #[allow(clippy::too_many_arguments)] + pub fn new_with_credentials( + config: Config, + store: Arc, + compute: ComputeRuntime, + sandbox_index: SandboxIndex, + sandbox_watch_bus: SandboxWatchBus, + tracing_log_bus: TracingLogBus, + supervisor_sessions: Arc, + oidc_cache: Option>, + credentials: credentials::CredentialRuntime, ) -> Self { let grpc_rate_limiter = multiplex::GrpcRateLimiter::from_config(&config); Self { config, store, compute, + credentials, sandbox_index, sandbox_watch_bus, tracing_log_bus, @@ -211,6 +246,12 @@ pub async fn run_server( } let store = Arc::new(Store::connect(database_url).await?); + let credentials = credentials::CredentialRuntime::from_config_file_with_store( + &config, + config_file.as_ref(), + Arc::clone(&store), + ) + .await?; let oidc_cache = if let Some(ref oidc) = config.oidc { // Validate RBAC configuration before starting. @@ -245,7 +286,7 @@ pub async fn run_server( supervisor_sessions.clone(), ) .await?; - let mut state = ServerState::new( + let mut state = ServerState::new_with_credentials( config.clone(), store.clone(), compute, @@ -254,6 +295,7 @@ pub async fn run_server( tracing_log_bus, supervisor_sessions, oidc_cache, + credentials, ); // Load the gateway-minted sandbox JWT signing key when configured. 
@@ -969,7 +1011,8 @@ mod tests { .with_database_url("sqlite::memory:?cache=shared") .with_bind_address(bind_addr) .with_server_sans(["*.dev.openshell.localhost"]) - .with_loopback_service_http(enable_loopback_service_http), + .with_loopback_service_http(enable_loopback_service_http) + .with_credential_drivers(["test-static"]), store, compute, crate::sandbox_index::SandboxIndex::new(), diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs index b0b9a927c..bc72aedfc 100644 --- a/crates/openshell-server/src/provider_refresh.rs +++ b/crates/openshell-server/src/provider_refresh.rs @@ -10,6 +10,7 @@ use openshell_core::proto::{ Provider, ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, StoredProviderCredentialRefreshState, }; +use openshell_core::{ObjectId, ObjectName}; use prost::Message; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -219,8 +220,6 @@ pub fn new_refresh_state( }) } -use openshell_core::{ObjectId, ObjectName}; - #[derive(Debug)] struct MintedCredential { access_token: String, @@ -294,6 +293,7 @@ pub fn is_gateway_mintable_strategy(strategy: ProviderCredentialRefreshStrategy) pub async fn refresh_provider_credential( store: &Store, + credentials: Option<&crate::credentials::CredentialRuntime>, provider_name: &str, credential_key: &str, ) -> Result { @@ -321,7 +321,8 @@ pub async fn refresh_provider_credential( Ok(minted) => { let now_ms = current_time_ms(); if let Err(err) = - apply_minted_credential(store, &provider, credential_key, &minted).await + apply_minted_credential(store, credentials, &provider, credential_key, &minted) + .await { state.status = "error".to_string(); state.last_error = err.message().to_string(); @@ -399,14 +400,36 @@ pub async fn refresh_provider_credential( async fn apply_minted_credential( store: &Store, + credentials: Option<&crate::credentials::CredentialRuntime>, provider: &Provider, credential_key: &str, minted: 
&MintedCredential, ) -> Result<(), Status> { let mut updated = provider.clone(); - updated - .credentials - .insert(credential_key.to_string(), minted.access_token.clone()); + let stored_handle = if let Some(credentials) = credentials + && credentials.stores_provider_credentials() + { + let stored = credentials + .store_provider_credentials( + provider.object_name(), + &HashMap::from([(credential_key.to_string(), minted.access_token.clone())]), + &provider.credential_handles, + ) + .await?; + let handle = stored.get(credential_key).cloned().ok_or_else(|| { + Status::internal("credential driver did not return refreshed credential handle") + })?; + updated.credentials.remove(credential_key); + updated + .credential_handles + .insert(credential_key.to_string(), handle.clone()); + Some(handle) + } else { + updated + .credentials + .insert(credential_key.to_string(), minted.access_token.clone()); + None + }; if minted.expires_at_ms > 0 { updated .credential_expires_at_ms @@ -418,9 +441,16 @@ async fn apply_minted_credential( .await?; store .update_message_cas::(provider.object_id(), 0, |current| { - current - .credentials - .insert(credential_key.to_string(), minted.access_token.clone()); + if let Some(handle) = stored_handle.clone() { + current.credentials.remove(credential_key); + current + .credential_handles + .insert(credential_key.to_string(), handle); + } else { + current + .credentials + .insert(credential_key.to_string(), minted.access_token.clone()); + } if minted.expires_at_ms > 0 { current .credential_expires_at_ms @@ -690,14 +720,19 @@ pub fn spawn_refresh_worker(state: std::sync::Arc, interval: ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); loop { ticker.tick().await; - if let Err(err) = run_refresh_worker_tick(state.store.as_ref()).await { + if let Err(err) = + run_refresh_worker_tick(state.store.as_ref(), Some(&state.credentials)).await + { warn!(error = %err, "provider credential refresh worker tick failed"); } } }); } -async 
fn run_refresh_worker_tick(store: &Store) -> Result<(), Status> { +async fn run_refresh_worker_tick( + store: &Store, + credentials: Option<&crate::credentials::CredentialRuntime>, +) -> Result<(), Status> { let now_ms = current_time_ms(); let states = list_all_refresh_states(store).await?; let watched_count = states.len(); @@ -752,8 +787,13 @@ async fn run_refresh_worker_tick(store: &Store) -> Result<(), Status> { status = %state.status, "refreshing provider credential" ); - if let Err(err) = - refresh_provider_credential(store, &state.provider_name, &state.credential_key).await + if let Err(err) = refresh_provider_credential( + store, + credentials, + &state.provider_name, + &state.credential_key, + ) + .await { warn!( provider = %state.provider_name, @@ -776,7 +816,9 @@ mod tests { refresh_provider_credential, refresh_state_name, refresh_strategy_name, run_refresh_worker_tick, seconds_until_ms, }; - use crate::persistence::test_store; + use crate::credentials::CredentialRuntime; + use crate::persistence::{current_time_ms, test_store}; + use openshell_core::Config; use openshell_core::ObjectId; use openshell_core::proto::datamodel::v1::ObjectMeta; use openshell_core::proto::{ @@ -849,7 +891,7 @@ mod tests { let store = test_store().await; let provider = provider("my-graph", "outlook"); store.put_message(&provider).await.unwrap(); - let before_refresh_ms = crate::persistence::current_time_ms(); + let before_refresh_ms = current_time_ms(); let state = new_refresh_state( &provider, "MS_GRAPH_ACCESS_TOKEN", @@ -870,9 +912,10 @@ mod tests { .unwrap(); put_refresh_state(&store, &state).await.unwrap(); - let refreshed = refresh_provider_credential(&store, "my-graph", "MS_GRAPH_ACCESS_TOKEN") - .await - .unwrap(); + let refreshed = + refresh_provider_credential(&store, None, "my-graph", "MS_GRAPH_ACCESS_TOKEN") + .await + .unwrap(); assert_eq!(refreshed.status, "refreshed"); assert!(refreshed.expires_at_ms > 0); assert!(refreshed.next_refresh_at_ms > 0); @@ -894,6 
+937,79 @@ mod tests { ); } + #[tokio::test] + async fn oauth2_client_credentials_refresh_stores_access_token_with_credential_runtime() { + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "access_token": "stored-graph-token", + "expires_in": 3600, + "token_type": "Bearer" + }))) + .mount(&mock_server) + .await; + + let store = test_store().await; + let provider = provider("my-stored-graph", "outlook"); + store.put_message(&provider).await.unwrap(); + let state = new_refresh_state( + &provider, + "MS_GRAPH_ACCESS_TOKEN", + NewRefreshStateConfig { + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials, + material: HashMap::from([ + ("client_id".to_string(), "client-id".to_string()), + ("client_secret".to_string(), "client-secret".to_string()), + ]), + secret_material_keys: vec!["client_secret".to_string()], + expires_at_ms: 0, + token_url: format!("{}/token", mock_server.uri()), + scopes: Vec::new(), + refresh_before_seconds: 30, + max_lifetime_seconds: 60, + }, + ) + .unwrap(); + put_refresh_state(&store, &state).await.unwrap(); + let config = Config::new(None).with_credential_drivers(["test-static"]); + let credentials = CredentialRuntime::from_config(&config).unwrap(); + + let refreshed = refresh_provider_credential( + &store, + Some(&credentials), + "my-stored-graph", + "MS_GRAPH_ACCESS_TOKEN", + ) + .await + .unwrap(); + + let stored = store + .get_message_by_name::("my-stored-graph") + .await + .unwrap() + .unwrap(); + assert!(!stored.credentials.contains_key("MS_GRAPH_ACCESS_TOKEN")); + let handle = stored + .credential_handles + .get("MS_GRAPH_ACCESS_TOKEN") + .unwrap(); + assert_eq!(handle.driver, "test-static"); + assert_eq!( + stored.credential_expires_at_ms.get("MS_GRAPH_ACCESS_TOKEN"), + Some(&refreshed.expires_at_ms) + ); + + let resolved = credentials + .resolve_provider_handles(&stored, current_time_ms()) + 
.await + .unwrap(); + assert_eq!( + resolved.values.get("MS_GRAPH_ACCESS_TOKEN"), + Some(&"stored-graph-token".to_string()) + ); + } + #[tokio::test] async fn refresh_rejects_minted_credential_key_collision_for_attached_sandbox() { let mock_server = MockServer::start().await; @@ -953,9 +1069,10 @@ mod tests { .unwrap(); put_refresh_state(&store, &state).await.unwrap(); - let err = refresh_provider_credential(&store, "refreshing-graph", "MS_GRAPH_ACCESS_TOKEN") - .await - .unwrap_err(); + let err = + refresh_provider_credential(&store, None, "refreshing-graph", "MS_GRAPH_ACCESS_TOKEN") + .await + .unwrap_err(); assert_eq!(err.code(), tonic::Code::FailedPrecondition); assert!(err.message().contains("MS_GRAPH_ACCESS_TOKEN")); @@ -1021,10 +1138,14 @@ mod tests { .unwrap(); put_refresh_state(&store, &state).await.unwrap(); - let refreshed = - refresh_provider_credential(&store, "my-delegated-graph", "MS_GRAPH_ACCESS_TOKEN") - .await - .unwrap(); + let refreshed = refresh_provider_credential( + &store, + None, + "my-delegated-graph", + "MS_GRAPH_ACCESS_TOKEN", + ) + .await + .unwrap(); assert_eq!(refreshed.status, "refreshed"); assert!(refreshed.expires_at_ms > 0); @@ -1104,7 +1225,7 @@ mod tests { put_refresh_state(&store, &state).await.unwrap(); let refreshed = - refresh_provider_credential(&store, "my-drive", "GOOGLE_DRIVE_ACCESS_TOKEN") + refresh_provider_credential(&store, None, "my-drive", "GOOGLE_DRIVE_ACCESS_TOKEN") .await .unwrap(); assert_eq!(refreshed.status, "refreshed"); @@ -1143,7 +1264,7 @@ mod tests { .unwrap(); put_refresh_state(&store, &state).await.unwrap(); - run_refresh_worker_tick(&store).await.unwrap(); + run_refresh_worker_tick(&store, None).await.unwrap(); let stored_state = get_refresh_state(&store, provider.object_id(), "MS_GRAPH_ACCESS_TOKEN") .await @@ -1177,6 +1298,7 @@ mod tests { credentials: HashMap::new(), config: HashMap::new(), credential_expires_at_ms: HashMap::new(), + credential_handles: HashMap::new(), } } diff --git 
a/crates/openshell-tui/src/lib.rs b/crates/openshell-tui/src/lib.rs index 7992666d3..665d3aaf7 100644 --- a/crates/openshell-tui/src/lib.rs +++ b/crates/openshell-tui/src/lib.rs @@ -1620,6 +1620,7 @@ fn spawn_create_provider(app: &App, tx: mpsc::UnboundedSender) { credentials: credentials.clone(), config: HashMap::default(), credential_expires_at_ms: HashMap::default(), + credential_handles: HashMap::default(), }), }; @@ -1712,6 +1713,7 @@ fn spawn_update_provider(app: &App, tx: mpsc::UnboundedSender) { credentials, config: HashMap::default(), credential_expires_at_ms: HashMap::default(), + credential_handles: HashMap::default(), }), credential_expires_at_ms: HashMap::default(), }; diff --git a/deploy/helm/openshell/README.md b/deploy/helm/openshell/README.md index e6d539592..66775bd10 100644 --- a/deploy/helm/openshell/README.md +++ b/deploy/helm/openshell/README.md @@ -99,6 +99,18 @@ gateways. `workload.kind=statefulset` is still available for single-replica SQLite installs and for operators who explicitly need StatefulSet identity or storage semantics. +### Credential storage + +By default, the chart uses the gateway's encrypted database credential storage. +The gateway writes encrypted provider credential envelopes to the OpenShell +database. The chart creates a retained Kubernetes Secret with the shared +key-encryption key and injects that key into every gateway pod, so the same +default works for single-replica and external database-backed HA deployments. + +Use `kubernetes-secrets` or `vault` instead when credentials should live in a +cluster or external secret backend. Enabling one external credential driver +disables the default credential-storage key-encryption key Secret and env injection. + #### OpenShift Append these flags to any of the PostgreSQL commands above for OpenShift: @@ -192,6 +204,20 @@ add `ci/values-spire.yaml` to the OpenShell release values files. | securityContext.runAsUser | int | `1000` | UID assigned to the gateway container. 
| | server.appArmorProfile | string | `"Unconfined"` | Kubernetes AppArmor profile requested for sandbox agent containers. Default Unconfined avoids runtime/default AppArmor blocking the supervisor's network namespace mount setup on AppArmor-enabled nodes. Set to "" to omit the field, "RuntimeDefault" to force the runtime default profile, or "Localhost/profile-name" for an operator-managed localhost profile. | | server.auth.allowUnauthenticatedUsers | bool | `false` | UNSAFE: accept unauthenticated CLI/user requests as a local developer principal. Intended only for trusted local Skaffold/k3d development or a fully trusted fronting proxy. Leave false for shared or production clusters. | +| server.credentialDrivers.kubernetesSecrets.allowReferenceNamespace | bool | `false` | Deprecated compatibility field. Credential storage no longer supports user-authored namespace references. | +| server.credentialDrivers.kubernetesSecrets.enabled | bool | `false` | Enable the in-tree Kubernetes Secret credential driver. | +| server.credentialDrivers.kubernetesSecrets.namespace | string | `""` | Namespace where OpenShell-managed provider Secret objects are stored. Empty = Helm release namespace. | +| server.credentialDrivers.kubernetesSecrets.rbac.create | bool | `true` | Create a Role/RoleBinding granting the gateway ServiceAccount read/write access to managed provider Secrets. | +| server.credentialDrivers.vault.address | string | `""` | Vault service base URL, for example http://vault.vault.svc.cluster.local:8200. | +| server.credentialDrivers.vault.authMethod | string | `"kubernetes"` | Authentication method. Use "kubernetes" in-cluster or "token_file" for local/dev validation. | +| server.credentialDrivers.vault.enabled | bool | `false` | Enable the in-tree Vault credential driver. | +| server.credentialDrivers.vault.kubernetesAuthMount | string | `"kubernetes"` | Vault Kubernetes auth mount. 
| +| server.credentialDrivers.vault.kvVersion | string | `"2"` | Default KV engine version. Use "1" or "2". | +| server.credentialDrivers.vault.mount | string | `"secret"` | Default KV mount name. | +| server.credentialDrivers.vault.role | string | `""` | Vault Kubernetes auth role when authMethod is kubernetes. | +| server.credentialDrivers.vault.serviceAccountTokenPath | string | `"/var/run/secrets/kubernetes.io/serviceaccount/token"` | ServiceAccount token path used for Kubernetes auth. | +| server.credentialDrivers.vault.timeoutSecs | string | `""` | HTTP request timeout in seconds. Empty = driver default. | +| server.credentialDrivers.vault.tokenPath | string | `""` | Mounted token file path when authMethod is token_file. | | server.dbUrl | string | `"sqlite:/var/openshell/openshell.db"` | Gateway database URL (used for the default SQLite backend). | | server.defaultRuntimeClassName | string | `""` | Default Kubernetes runtimeClassName for sandbox pods. Applied when a CreateSandbox request does not specify one. Empty (default) = omit the field, using the cluster's default RuntimeClass. Set to a RuntimeClass name (e.g. "kata-containers", "nvidia") to apply it to all sandboxes that don't explicitly override it. | | server.disableTls | bool | `false` | Disable TLS entirely - the server listens on plaintext HTTP. Set to true when a reverse proxy / tunnel terminates TLS at the edge. | diff --git a/deploy/helm/openshell/README.md.gotmpl b/deploy/helm/openshell/README.md.gotmpl index e246ca67b..705f6f7e9 100644 --- a/deploy/helm/openshell/README.md.gotmpl +++ b/deploy/helm/openshell/README.md.gotmpl @@ -99,6 +99,18 @@ gateways. `workload.kind=statefulset` is still available for single-replica SQLite installs and for operators who explicitly need StatefulSet identity or storage semantics. +### Credential storage + +By default, the chart uses the gateway's encrypted database credential storage. 
+The gateway writes encrypted provider credential envelopes to the OpenShell +database. The chart creates a retained Kubernetes Secret with the shared +key-encryption key and injects that key into every gateway pod, so the same +default works for single-replica and external database-backed HA deployments. + +Use `kubernetes-secrets` or `vault` instead when credentials should live in a +cluster or external secret backend. Enabling one external credential driver +disables the default credential-storage key-encryption key Secret and env injection. + #### OpenShift Append these flags to any of the PostgreSQL commands above for OpenShift: diff --git a/deploy/helm/openshell/ci/values-credential-driver-kubernetes-secrets.yaml b/deploy/helm/openshell/ci/values-credential-driver-kubernetes-secrets.yaml new file mode 100644 index 000000000..096ce46e2 --- /dev/null +++ b/deploy/helm/openshell/ci/values-credential-driver-kubernetes-secrets.yaml @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Local Kubernetes Secrets credential-driver validation overlay. +# +# Use with: +# skaffold run -p credential-driver-kubernetes-secrets +# +server: + credentialDrivers: + kubernetesSecrets: + enabled: true + namespace: openshell diff --git a/deploy/helm/openshell/ci/values-credential-driver-vault.yaml b/deploy/helm/openshell/ci/values-credential-driver-vault.yaml new file mode 100644 index 000000000..6cae3fdb5 --- /dev/null +++ b/deploy/helm/openshell/ci/values-credential-driver-vault.yaml @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Local Vault credential-driver validation overlay. +# +# Use with: +# skaffold run -p credential-driver-vault +# +# The profile assumes another process has already deployed a Vault-compatible +# backend. 
Local e2e validation deploys OpenBao in the `openbao` namespace with +# a Kubernetes auth role named `openshell-gateway` bound to the OpenShell +# gateway ServiceAccount in the `openshell` namespace. + +server: + credentialDrivers: + vault: + enabled: true + address: http://openbao.openbao.svc.cluster.local:8200 + role: openshell-gateway diff --git a/deploy/helm/openshell/skaffold.yaml b/deploy/helm/openshell/skaffold.yaml index 37a21fbac..8b444b0ba 100644 --- a/deploy/helm/openshell/skaffold.yaml +++ b/deploy/helm/openshell/skaffold.yaml @@ -127,3 +127,14 @@ deploy: image.tag: '{{.IMAGE_TAG_openshell_gateway}}' supervisor.image.repository: '{{.IMAGE_REPO_openshell_supervisor}}' supervisor.image.tag: '{{.IMAGE_TAG_openshell_supervisor}}' +profiles: + - name: credential-driver-kubernetes-secrets + patches: + - op: add + path: /deploy/helm/releases/0/valuesFiles/- + value: ci/values-credential-driver-kubernetes-secrets.yaml + - name: credential-driver-vault + patches: + - op: add + path: /deploy/helm/releases/0/valuesFiles/- + value: ci/values-credential-driver-vault.yaml diff --git a/deploy/helm/openshell/templates/_gateway-workload.tpl b/deploy/helm/openshell/templates/_gateway-workload.tpl index 5931047e5..1af71bfb0 100644 --- a/deploy/helm/openshell/templates/_gateway-workload.tpl +++ b/deploy/helm/openshell/templates/_gateway-workload.tpl @@ -50,6 +50,13 @@ spec: - {{ .Values.server.dbUrl | quote }} {{- end }} env: + {{- if not (or .Values.server.credentialDrivers.kubernetesSecrets.enabled .Values.server.credentialDrivers.vault.enabled) }} + - name: {{ include "openshell.credentialStorageKeyEncryptionKeyEnvName" . }} + valueFrom: + secretKeyRef: + name: {{ include "openshell.credentialStorageKeyEncryptionKeySecretName" . }} + key: {{ include "openshell.credentialStorageKeyEncryptionKeySecretKey" . 
}} + {{- end }} {{- if .Values.server.externalDbSecret }} - name: OPENSHELL_DB_URL valueFrom: @@ -57,10 +64,9 @@ spec: name: {{ .Values.server.externalDbSecret }} key: uri {{- end }} - # All gateway settings live in the ConfigMap-backed TOML file - # mounted at /etc/openshell/gateway.toml. The only env var below - # is a process-level setting consumed by libraries outside - # gateway code (currently just SSL_CERT_FILE for OIDC issuer TLS). + # Most gateway settings live in the ConfigMap-backed TOML file + # mounted at /etc/openshell/gateway.toml. Secret-bearing settings use + # env vars that the TOML references by name. {{- if and .Values.server.oidc.issuer .Values.server.oidc.caConfigMapName }} # OIDC issuer custom-CA: rustls/reqwest read SSL_CERT_FILE for # outbound TLS verification. This is a process-level env var diff --git a/deploy/helm/openshell/templates/_helpers.tpl b/deploy/helm/openshell/templates/_helpers.tpl index 30c027576..199463c15 100644 --- a/deploy/helm/openshell/templates/_helpers.tpl +++ b/deploy/helm/openshell/templates/_helpers.tpl @@ -102,6 +102,34 @@ Namespace where sandbox pods are created. An explicit {{- .Values.server.sandboxNamespace | default .Release.Namespace -}} {{- end }} +{{/* +Namespace where Kubernetes Secret-backed provider credentials live. +*/}} +{{- define "openshell.credentialKubernetesSecretsNamespace" -}} +{{- .Values.server.credentialDrivers.kubernetesSecrets.namespace | default .Release.Namespace -}} +{{- end }} + +{{/* +Name of the Secret holding the default credential storage key-encryption key. +*/}} +{{- define "openshell.credentialStorageKeyEncryptionKeySecretName" -}} +{{- printf "%s-credential-storage-key-encryption-key" (include "openshell.fullname" .) | trunc 63 | trimSuffix "-" -}} +{{- end }} + +{{/* +Key inside the default credential storage key-encryption key Secret. 
+*/}} +{{- define "openshell.credentialStorageKeyEncryptionKeySecretKey" -}} +key-encryption-key +{{- end }} + +{{/* +Gateway environment variable used to pass the default credential storage key-encryption key. +*/}} +{{- define "openshell.credentialStorageKeyEncryptionKeyEnvName" -}} +OPENSHELL_GATEWAY_CREDENTIAL_KEY_ENCRYPTION_KEY +{{- end }} + {{/* Name of the Secret holding gateway-minted sandbox JWT signing material. */}} @@ -178,4 +206,14 @@ Validate chart values that Helm would otherwise accept silently. {{- if and (eq $workloadKind "statefulset") (gt $replicaCount 1) (not (get $workload "allowMultiReplicaStatefulSet" | default false)) -}} {{- fail "replicaCount > 1 with workload.kind=statefulset requires workload.allowMultiReplicaStatefulSet=true; use workload.kind=deployment for external database-backed multi-replica gateways." -}} {{- end -}} +{{- $credentialDrivers := list -}} +{{- if .Values.server.credentialDrivers.kubernetesSecrets.enabled -}} +{{- $credentialDrivers = append $credentialDrivers "kubernetes-secrets" -}} +{{- end -}} +{{- if .Values.server.credentialDrivers.vault.enabled -}} +{{- $credentialDrivers = append $credentialDrivers "vault" -}} +{{- end -}} +{{- if gt (len $credentialDrivers) 1 -}} +{{- fail "only one external server.credentialDrivers backend can be enabled at a time." -}} +{{- end -}} {{- end }} diff --git a/deploy/helm/openshell/templates/credential-secrets-role.yaml b/deploy/helm/openshell/templates/credential-secrets-role.yaml new file mode 100644 index 000000000..a986bb70d --- /dev/null +++ b/deploy/helm/openshell/templates/credential-secrets-role.yaml @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +{{- if and .Values.server.credentialDrivers.kubernetesSecrets.enabled .Values.server.credentialDrivers.kubernetesSecrets.rbac.create }} +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: {{ include "openshell.fullname" . }}-credential-secrets + namespace: {{ include "openshell.credentialKubernetesSecretsNamespace" . }} + labels: + {{- include "openshell.labels" . | nindent 4 }} +rules: + - apiGroups: + - "" + resources: + - secrets + verbs: + - get + - create + - patch + - delete +{{- end }} diff --git a/deploy/helm/openshell/templates/credential-secrets-rolebinding.yaml b/deploy/helm/openshell/templates/credential-secrets-rolebinding.yaml new file mode 100644 index 000000000..4274fa6e1 --- /dev/null +++ b/deploy/helm/openshell/templates/credential-secrets-rolebinding.yaml @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +{{- if and .Values.server.credentialDrivers.kubernetesSecrets.enabled .Values.server.credentialDrivers.kubernetesSecrets.rbac.create }} +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: {{ include "openshell.fullname" . }}-credential-secrets + namespace: {{ include "openshell.credentialKubernetesSecretsNamespace" . }} + labels: + {{- include "openshell.labels" . | nindent 4 }} +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: {{ include "openshell.fullname" . }}-credential-secrets +subjects: + - kind: ServiceAccount + name: {{ include "openshell.serviceAccountName" . 
}} + namespace: {{ .Release.Namespace }} +{{- end }} diff --git a/deploy/helm/openshell/templates/credential-storage-key-encryption-key-secret.yaml b/deploy/helm/openshell/templates/credential-storage-key-encryption-key-secret.yaml new file mode 100644 index 000000000..54d844afa --- /dev/null +++ b/deploy/helm/openshell/templates/credential-storage-key-encryption-key-secret.yaml @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +{{- if not (or .Values.server.credentialDrivers.kubernetesSecrets.enabled .Values.server.credentialDrivers.vault.enabled) }} +{{- $secretName := include "openshell.credentialStorageKeyEncryptionKeySecretName" . -}} +{{- $secretKey := include "openshell.credentialStorageKeyEncryptionKeySecretKey" . -}} +{{- $existing := lookup "v1" "Secret" .Release.Namespace $secretName -}} +{{- $encodedKeyEncryptionKey := randBytes 32 | b64enc -}} +{{- if $existing -}} +{{- $existingData := get $existing "data" | default dict -}} +{{- if not (hasKey $existingData $secretKey) -}} +{{- fail (printf "existing credential storage key-encryption key Secret %s/%s is missing key %s" .Release.Namespace $secretName $secretKey) -}} +{{- end -}} +{{- $encodedKeyEncryptionKey = index $existingData $secretKey -}} +{{- end }} +apiVersion: v1 +kind: Secret +metadata: + name: {{ $secretName }} + namespace: {{ .Release.Namespace }} + labels: + {{- include "openshell.labels" . 
| nindent 4 }} + annotations: + helm.sh/resource-policy: keep +type: Opaque +data: + {{ $secretKey }}: {{ $encodedKeyEncryptionKey | quote }} +{{- end }} diff --git a/deploy/helm/openshell/templates/gateway-config.yaml b/deploy/helm/openshell/templates/gateway-config.yaml index 7037be88f..2e2a6cba5 100644 --- a/deploy/helm/openshell/templates/gateway-config.yaml +++ b/deploy/helm/openshell/templates/gateway-config.yaml @@ -12,6 +12,13 @@ One value is intentionally NOT rendered here: when server.externalDbSecret is set, otherwise --db-url arg for SQLite */}} +{{- $credentialDrivers := list -}} +{{- if .Values.server.credentialDrivers.kubernetesSecrets.enabled -}} +{{- $credentialDrivers = append $credentialDrivers "kubernetes-secrets" -}} +{{- end -}} +{{- if .Values.server.credentialDrivers.vault.enabled -}} +{{- $credentialDrivers = append $credentialDrivers "vault" -}} +{{- end -}} apiVersion: v1 kind: ConfigMap metadata: @@ -32,6 +39,9 @@ data: metrics_bind_address = "0.0.0.0:{{ .Values.service.metricsPort }}" {{- end }} log_level = {{ .Values.server.logLevel | quote }} + {{- if $credentialDrivers }} + credential_drivers = [{{- range $i, $driver := $credentialDrivers }}{{ if $i }}, {{ end }}{{ $driver | quote }}{{- end }}] + {{- end }} sandbox_namespace = {{ include "openshell.sandboxNamespace" . | quote }} default_image = {{ .Values.server.sandboxImage | quote }} supervisor_image = {{ include "openshell.supervisorImage" . | quote }} @@ -141,3 +151,40 @@ data: {{- if .Values.supervisor.image.pullPolicy }} supervisor_image_pull_policy = {{ .Values.supervisor.image.pullPolicy | quote }} {{- end }} + + {{- if not $credentialDrivers }} + + [openshell.gateway.credential_storage] + key_encryption_key_env = {{ include "openshell.credentialStorageKeyEncryptionKeyEnvName" . 
| quote }} + {{- end }} + + {{- if .Values.server.credentialDrivers.kubernetesSecrets.enabled }} + + [openshell.credential_drivers.kubernetes-secrets] + namespace = {{ include "openshell.credentialKubernetesSecretsNamespace" . | quote }} + allow_reference_namespace = {{ .Values.server.credentialDrivers.kubernetesSecrets.allowReferenceNamespace }} + {{- end }} + + {{- if .Values.server.credentialDrivers.vault.enabled }} + + [openshell.credential_drivers.vault] + address = {{ .Values.server.credentialDrivers.vault.address | quote }} + mount = {{ .Values.server.credentialDrivers.vault.mount | quote }} + kv_version = {{ .Values.server.credentialDrivers.vault.kvVersion | quote }} + auth_method = {{ .Values.server.credentialDrivers.vault.authMethod | quote }} + {{- if .Values.server.credentialDrivers.vault.role }} + role = {{ .Values.server.credentialDrivers.vault.role | quote }} + {{- end }} + {{- if .Values.server.credentialDrivers.vault.kubernetesAuthMount }} + kubernetes_auth_mount = {{ .Values.server.credentialDrivers.vault.kubernetesAuthMount | quote }} + {{- end }} + {{- if .Values.server.credentialDrivers.vault.serviceAccountTokenPath }} + service_account_token_path = {{ .Values.server.credentialDrivers.vault.serviceAccountTokenPath | quote }} + {{- end }} + {{- if .Values.server.credentialDrivers.vault.tokenPath }} + token_path = {{ .Values.server.credentialDrivers.vault.tokenPath | quote }} + {{- end }} + {{- if .Values.server.credentialDrivers.vault.timeoutSecs }} + timeout_secs = {{ .Values.server.credentialDrivers.vault.timeoutSecs }} + {{- end }} + {{- end }} diff --git a/deploy/helm/openshell/tests/credential_drivers_test.yaml b/deploy/helm/openshell/tests/credential_drivers_test.yaml new file mode 100644 index 000000000..76d8e081a --- /dev/null +++ b/deploy/helm/openshell/tests/credential_drivers_test.yaml @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +suite: credential drivers +templates: + - templates/gateway-config.yaml + - templates/credential-storage-key-encryption-key-secret.yaml + - templates/statefulset.yaml + - templates/deployment.yaml + - templates/credential-secrets-role.yaml + - templates/credential-secrets-rolebinding.yaml +release: + name: openshell + namespace: my-namespace + +tests: + - it: renders default encrypted credential storage by default + template: templates/gateway-config.yaml + asserts: + - notMatchRegex: + path: data["gateway.toml"] + pattern: 'credential_drivers\s*=' + - matchRegex: + path: data["gateway.toml"] + pattern: '(?ms)\[openshell\.gateway\.credential_storage\].*?key_encryption_key_env\s*=\s*"OPENSHELL_GATEWAY_CREDENTIAL_KEY_ENCRYPTION_KEY"' + - notMatchRegex: + path: data["gateway.toml"] + pattern: 'key_encryption_key_path\s*=' + + - it: creates a retained default credential storage key-encryption key Secret by default + template: templates/credential-storage-key-encryption-key-secret.yaml + asserts: + - equal: + path: kind + value: Secret + - matchRegex: + path: metadata.name + pattern: 'credential-storage-key-encryption-key$' + - equal: + path: metadata.annotations["helm.sh/resource-policy"] + value: keep + - matchRegex: + path: data["key-encryption-key"] + pattern: '.+' + + - it: injects the default credential storage key-encryption key Secret into the gateway pod by default + template: templates/statefulset.yaml + asserts: + - equal: + path: spec.template.spec.containers[0].env[0].name + value: OPENSHELL_GATEWAY_CREDENTIAL_KEY_ENCRYPTION_KEY + - matchRegex: + path: spec.template.spec.containers[0].env[0].valueFrom.secretKeyRef.name + pattern: 'credential-storage-key-encryption-key$' + - equal: + path: spec.template.spec.containers[0].env[0].valueFrom.secretKeyRef.key + value: key-encryption-key + + - it: renders Kubernetes Secrets credential driver config + template: templates/gateway-config.yaml + set: + 
server.credentialDrivers.kubernetesSecrets.enabled: true + server.credentialDrivers.kubernetesSecrets.namespace: provider-secrets + asserts: + - matchRegex: + path: data["gateway.toml"] + pattern: 'credential_drivers\s*=\s*\["kubernetes-secrets"\]' + - matchRegex: + path: data["gateway.toml"] + pattern: '(?ms)\[openshell\.credential_drivers\.kubernetes-secrets\].*?namespace\s*=\s*"provider-secrets".*?allow_reference_namespace\s*=\s*false' + - notMatchRegex: + path: data["gateway.toml"] + pattern: 'transport\s*=\s*"in_tree"' + - notMatchRegex: + path: data["gateway.toml"] + pattern: '\[openshell\.gateway\.credential_storage\]' + + - it: renders Vault credential driver config + template: templates/gateway-config.yaml + set: + server.credentialDrivers.vault.enabled: true + server.credentialDrivers.vault.address: http://vault.vault.svc.cluster.local:8200 + server.credentialDrivers.vault.role: openshell-gateway + asserts: + - matchRegex: + path: data["gateway.toml"] + pattern: 'credential_drivers\s*=\s*\["vault"\]' + - matchRegex: + path: data["gateway.toml"] + pattern: '(?ms)\[openshell\.credential_drivers\.vault\].*?address\s*=\s*"http://vault\.vault\.svc\.cluster\.local:8200".*?auth_method\s*=\s*"kubernetes".*?role\s*=\s*"openshell-gateway"' + - notMatchRegex: + path: data["gateway.toml"] + pattern: 'transport\s*=\s*"in_tree"' + + - it: rejects multiple enabled credential drivers + template: templates/statefulset.yaml + set: + server.credentialDrivers.kubernetesSecrets.enabled: true + server.credentialDrivers.vault.enabled: true + server.credentialDrivers.vault.address: http://vault.vault.svc.cluster.local:8200 + server.credentialDrivers.vault.role: openshell-gateway + asserts: + - failedTemplate: + errorPattern: "only one external server.credentialDrivers backend can be enabled at a time" + + - it: creates namespaced Kubernetes Secret manager RBAC + template: templates/credential-secrets-role.yaml + set: + server.credentialDrivers.kubernetesSecrets.enabled: true + 
server.credentialDrivers.kubernetesSecrets.namespace: provider-secrets + asserts: + - equal: + path: metadata.namespace + value: provider-secrets + - equal: + path: rules[0].resources[0] + value: secrets + - equal: + path: rules[0].verbs[0] + value: get + - contains: + path: rules[0].verbs + content: create + - contains: + path: rules[0].verbs + content: patch + - contains: + path: rules[0].verbs + content: delete + + - it: binds Kubernetes Secret manager RBAC to the gateway ServiceAccount + template: templates/credential-secrets-rolebinding.yaml + set: + server.credentialDrivers.kubernetesSecrets.enabled: true + server.credentialDrivers.kubernetesSecrets.namespace: provider-secrets + asserts: + - equal: + path: metadata.namespace + value: provider-secrets + - equal: + path: subjects[0].name + value: openshell + - equal: + path: subjects[0].namespace + value: my-namespace + + - it: allows default credential storage on a Deployment with an external database + template: templates/deployment.yaml + set: + workload.kind: deployment + server.externalDbSecret: openshell-pg + asserts: + - equal: + path: kind + value: Deployment + - equal: + path: spec.template.spec.containers[0].env[0].name + value: OPENSHELL_GATEWAY_CREDENTIAL_KEY_ENCRYPTION_KEY + + - it: allows default credential storage with multiple replicas and an external database + template: templates/statefulset.yaml + set: + replicaCount: 2 + server.externalDbSecret: openshell-pg + workload.allowMultiReplicaStatefulSet: true + asserts: + - equal: + path: spec.replicas + value: 2 + - equal: + path: spec.template.spec.containers[0].env[0].name + value: OPENSHELL_GATEWAY_CREDENTIAL_KEY_ENCRYPTION_KEY diff --git a/deploy/helm/openshell/tests/gateway_config_test.yaml b/deploy/helm/openshell/tests/gateway_config_test.yaml index c2708a20f..c97633005 100644 --- a/deploy/helm/openshell/tests/gateway_config_test.yaml +++ b/deploy/helm/openshell/tests/gateway_config_test.yaml @@ -274,6 +274,7 @@ tests: workload.kind: 
deployment replicaCount: 2 server.externalDbSecret: my-pg-secret + server.credentialDrivers.kubernetesSecrets.enabled: true asserts: - equal: path: kind @@ -321,6 +322,7 @@ tests: replicaCount: 2 server.externalDbSecret: my-pg-secret workload.allowMultiReplicaStatefulSet: true + server.credentialDrivers.kubernetesSecrets.enabled: true asserts: - equal: path: kind diff --git a/deploy/helm/openshell/values.yaml b/deploy/helm/openshell/values.yaml index d7ff8b257..d6b8819fb 100644 --- a/deploy/helm/openshell/values.yaml +++ b/deploy/helm/openshell/values.yaml @@ -218,6 +218,42 @@ server: # -- gRPC rate-limit window length in seconds. Must be positive (alongside # requests) to enable rate limiting; 0 (default) disables it. windowSeconds: 0 + # Provider credential drivers store provider credential secret material in an + # external or native backend. When no driver is enabled, the gateway uses its + # default encrypted database credential storage with a retained Kubernetes + # Secret for the shared key-encryption key. + credentialDrivers: + kubernetesSecrets: + # -- Enable the in-tree Kubernetes Secret credential driver. + enabled: false + # -- Namespace where OpenShell-managed provider Secret objects are stored. Empty = Helm release namespace. + namespace: "" + # -- Deprecated compatibility field. Credential storage no longer supports user-authored namespace references. + allowReferenceNamespace: false + rbac: + # -- Create a Role/RoleBinding granting the gateway ServiceAccount read/write access to managed provider Secrets. + create: true + vault: + # -- Enable the in-tree Vault credential driver. + enabled: false + # -- Vault service base URL, for example http://vault.vault.svc.cluster.local:8200. + address: "" + # -- Default KV mount name. + mount: secret + # -- Default KV engine version. Use "1" or "2". + kvVersion: "2" + # -- Authentication method. Use "kubernetes" in-cluster or "token_file" for local/dev validation. 
+ authMethod: kubernetes + # -- Vault Kubernetes auth role when authMethod is kubernetes. + role: "" + # -- Vault Kubernetes auth mount. + kubernetesAuthMount: kubernetes + # -- ServiceAccount token path used for Kubernetes auth. + serviceAccountTokenPath: /var/run/secrets/kubernetes.io/serviceaccount/token + # -- Mounted token file path when authMethod is token_file. + tokenPath: "" + # -- HTTP request timeout in seconds. Empty = driver default. + timeoutSecs: "" auth: # -- UNSAFE: accept unauthenticated CLI/user requests as a local developer # principal. Intended only for trusted local Skaffold/k3d development or a diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index ff4542136..3a3b8d102 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -31,7 +31,7 @@ Package-managed gateways do not require a TOML file. Create one at the package's ## Layout -The file is rooted at `[openshell]`. Gateway-wide settings live under `[openshell.gateway]`. Each compute driver owns its own `[openshell.drivers.<name>]` table. Shared keys set at gateway scope are inherited into driver tables when not overridden. +The file is rooted at `[openshell]`. Gateway-wide settings live under `[openshell.gateway]`. Each compute driver owns its own `[openshell.drivers.<name>]` table. Credential drivers own `[openshell.credential_drivers.<name>]` tables. Shared compute-driver keys set at gateway scope are inherited into compute driver tables when not overridden. ```toml [openshell] @@ -48,6 +48,9 @@ version = 1 [openshell.drivers.kubernetes] # ... driver-specific settings ... + +[openshell.credential_drivers.kubernetes-secrets] +# ... credential-driver-specific settings ... ``` ## Full Example @@ -72,6 +75,10 @@ log_level = "info" # VM is never auto-detected and requires an explicit entry here. compute_drivers = ["kubernetes"] +# Optional external provider credential storage backend.
Omit this key to use +# the gateway's default encrypted database credential storage. +credential_drivers = ["kubernetes-secrets"] + sandbox_namespace = "openshell" ssh_session_ttl_secs = 3600 @@ -132,6 +139,10 @@ roles_claim = "realm_access.roles" admin_role = "openshell-admin" user_role = "openshell-user" scopes_claim = "" + +[openshell.credential_drivers.kubernetes-secrets] +namespace = "openshell" +allow_reference_namespace = false ``` Local Docker, Podman, and VM gateways can also set `[openshell.gateway.mtls_auth] enabled = true` to authenticate CLI callers from verified client certificates. Kubernetes deployments must leave this unset and use OIDC or a trusted access proxy; the Helm chart does not render this table. @@ -142,6 +153,83 @@ Local Docker, Podman, and VM gateways can also set `[openshell.gateway.mtls_auth `image_pull_policy` is intentionally not a shared gateway key. Kubernetes and Docker use `Always`, `IfNotPresent`, or `Never`. Podman uses `always`, `missing`, `never`, or `newer`. Set it inside the relevant driver table. +## Credential Drivers + +Set `credential_drivers` only when the gateway should store provider credentials in an external credential backend. OpenShell supports at most one enabled credential driver at a time. When `credential_drivers` is omitted, the gateway uses its default encrypted database credential storage. `credential_drivers = []` is invalid in the TOML file; omit the field for the default encrypted store, or select a backend such as `kubernetes-secrets` or `vault`. + +Credential driver tables are backend-owned and live under `[openshell.credential_drivers.<name>]`. Built-in drivers default to in-tree transport, so they do not need a `transport` field. Use `transport = "uds"` with an absolute `socket_path` only for a remote gRPC driver over a Unix domain socket. + +For the default encrypted database store, set the key-encryption key file location with `key_encryption_key_path`:
+ +```toml +[openshell.gateway.credential_storage] +key_encryption_key_path = "/var/lib/openshell/credentials/key-encryption-key.bin" +``` + +For Kubernetes Secrets: + +```toml +[openshell.gateway] +credential_drivers = ["kubernetes-secrets"] + +[openshell.credential_drivers.kubernetes-secrets] +namespace = "openshell" +``` + +For Vault instead: + +```toml +[openshell.gateway] +credential_drivers = ["vault"] + +[openshell.credential_drivers.vault] +address = "http://vault.vault.svc.cluster.local:8200" +mount = "secret" +kv_version = "2" +auth_method = "kubernetes" +role = "openshell-gateway" +service_account_token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token" +``` + +For the default encrypted database store, OpenShell stores provider credentials as JSON envelopes encrypted with AES-256-GCM in the gateway database. Each credential gets a random data-encryption key; the gateway wraps that key with a local key-encryption key. By default, the key-encryption key is created at `$XDG_STATE_HOME/openshell/gateway/credentials/key-encryption-key.bin` with owner-only permissions. Use `[openshell.gateway.credential_storage] key_encryption_key_env` instead of `key_encryption_key_path` to load a base64-encoded 32-byte key-encryption key from an environment variable. Back up the database and key-encryption key together; losing either makes stored credentials unrecoverable. In Kubernetes, the Helm chart creates a retained Secret containing the shared key-encryption key, injects it as `OPENSHELL_GATEWAY_CREDENTIAL_KEY_ENCRYPTION_KEY`, and renders `key_encryption_key_env` for the gateway when no external credential driver is enabled. Multi-replica deployments need every replica to use the same database and key-encryption key; the chart default handles the key-encryption key side. + +For `kubernetes-secrets`, `namespace` sets where OpenShell-managed provider Secret objects are stored. 
When omitted, the driver uses the in-cluster ServiceAccount namespace when available, otherwise `default`. + +For `vault`, `address` points at the Vault service, `mount` and `kv_version` describe the KV engine where OpenShell-managed provider secrets are stored, and `auth_method = "kubernetes"` logs in with the gateway Pod's ServiceAccount token. For local or development validation, use `auth_method = "token_file"` with `token_path = "/path/to/token"`. Do not put literal Vault tokens in TOML. + +Provider records that already contain inline database credentials remain readable for upgrade compatibility. New provider create/update requests still submit credential values through the normal API, but the gateway stores those values through the active credential storage path and persists only handles. + +For remote credential drivers, set `transport = "uds"` with `socket_path`. Omit `command`, `args`, and `startup_timeout_secs` when another service manager prestarts the driver socket. Keep backend tokens out of TOML; point the driver at mounted token files or native identity mechanisms instead. + +The built-in `kubernetes-secrets` and `vault` drivers can also run out of +process over UDS. Set `command` to the standalone driver binary and pass +driver-specific settings through `args`; the gateway appends `--bind-socket +<socket_path>` when it launches the process.
+ +```toml +[openshell.gateway] +credential_drivers = ["kubernetes-secrets"] + +[openshell.credential_drivers.kubernetes-secrets] +transport = "uds" +socket_path = "/run/openshell/credential-drivers/kubernetes-secrets.sock" +command = "/usr/libexec/openshell/openshell-driver-kubernetes-secrets" +args = ["--namespace", "openshell"] +``` + +```toml +[openshell.gateway] +credential_drivers = ["vault"] + +[openshell.credential_drivers.vault] +transport = "uds" +socket_path = "/run/openshell/credential-drivers/vault.sock" +command = "/usr/libexec/openshell/openshell-driver-vault" +args = [ + "--address", "http://vault.vault.svc.cluster.local:8200", + "--auth-method", "kubernetes", + "--role", "openshell-gateway", +] +``` + ## Driver References Each example is a complete TOML file for one compute driver. The examples repeat `[openshell]` and `[openshell.gateway]` so they stay copyable, and the driver tables list the accepted driver-specific keys. Driver-specific values override inherited gateway defaults. The gateway rejects unknown driver fields after inheritance is merged. diff --git a/docs/sandboxes/providers-v2.mdx b/docs/sandboxes/providers-v2.mdx index 896ac641e..fef4f7ed4 100644 --- a/docs/sandboxes/providers-v2.mdx +++ b/docs/sandboxes/providers-v2.mdx @@ -55,6 +55,7 @@ Providers v2 currently includes these user-facing features: - `openshell provider list-profiles` with table, YAML, and JSON output. - `openshell provider profile export`, `import`, `update`, `lint`, and `delete` for custom profiles. - Provider instances created from built-in or imported profile IDs with `openshell provider create --type `. +- Provider instances whose submitted credentials can be stored by a configured gateway credential driver. - Profile-backed credential discovery for explicit `openshell provider create --from-existing` and `openshell provider update --from-existing` flows. 
The built-in `google-vertex-ai` profile also supplements discovery with Vertex config env vars such as `VERTEX_AI_PROJECT_ID` and `VERTEX_AI_REGION`. - Just-in-time effective policy composition from sandbox policy plus attached provider profiles. - Runtime sandbox provider lifecycle commands under `openshell sandbox provider list|attach|detach`. @@ -381,6 +382,30 @@ openshell provider create \ --credential CUSTOM_API_TOKEN ``` +Create a provider whose credential is stored by a configured gateway credential +driver: + +```shell +openshell provider create \ + --name openai-stored \ + --type openai \ + --credential OPENAI_API_KEY +``` + +The create/update API stores submitted provider credentials through the +gateway's active credential storage path and persists only internal credential +handles. By default, the gateway stores AES-256-GCM encrypted credential +envelopes in the gateway database outside the provider record. The Helm chart +creates a retained Kubernetes Secret for the default storage key-encryption key +and injects it into every gateway pod when no external credential driver is enabled. +`credential_drivers = []` is invalid. Multi-replica Kubernetes gateways can use +a shared database with the default encrypted store, or choose a shared backend +such as `kubernetes-secrets` or `vault`. + +Provider records that already contain inline database credentials remain +readable for upgrade compatibility. New provider create/update requests store +credential values through the active credential driver and persist only handles. + Provider profiles whose required credentials are fully runtime-resolvable through `token_grant` or gateway-managed refresh can be created without `--credential`. 
Inspect the provider: diff --git a/e2e/rust/Cargo.toml b/e2e/rust/Cargo.toml index 083c622df..4479e7f95 100644 --- a/e2e/rust/Cargo.toml +++ b/e2e/rust/Cargo.toml @@ -28,6 +28,7 @@ e2e-docker = ["e2e", "e2e-host-gateway", "e2e-local-container-driver"] e2e-gpu = ["e2e"] e2e-docker-gpu = ["e2e-docker", "e2e-gpu"] e2e-kubernetes = ["e2e"] +e2e-kubernetes-credential-drivers = ["e2e-kubernetes"] e2e-podman = ["e2e", "e2e-host-gateway", "e2e-local-container-driver"] e2e-podman-gpu = ["e2e-podman", "e2e-gpu"] e2e-vm = ["e2e"] @@ -72,6 +73,11 @@ name = "readyz_health" path = "tests/readyz_health.rs" required-features = ["e2e-kubernetes"] +[[test]] +name = "credential_drivers" +path = "tests/credential_drivers.rs" +required-features = ["e2e-kubernetes-credential-drivers"] + [[test]] name = "websocket_conformance" path = "tests/websocket_conformance.rs" diff --git a/e2e/rust/e2e-kubernetes.sh b/e2e/rust/e2e-kubernetes.sh index 0644a0618..fc15667d6 100755 --- a/e2e/rust/e2e-kubernetes.sh +++ b/e2e/rust/e2e-kubernetes.sh @@ -28,6 +28,22 @@ if [ -n "${OPENSHELL_E2E_KUBE_TEST:-}" ]; then test_filter+=(--test "${OPENSHELL_E2E_KUBE_TEST}") fi +run_suite() { + "${ROOT}/e2e/with-kube-gateway.sh" \ + cargo test --manifest-path "${ROOT}/e2e/rust/Cargo.toml" \ + --features "${E2E_FEATURES}" \ + --no-fail-fast \ + ${test_filter[@]+"${test_filter[@]}"} \ + -- --nocapture +} + +if [ "${OPENSHELL_E2E_CREDENTIAL_DRIVERS:-0}" = "1" ] \ + && [ -z "${OPENSHELL_E2E_CREDENTIAL_DRIVER:-}" ]; then + OPENSHELL_E2E_CREDENTIAL_DRIVER=kubernetes-secrets run_suite + OPENSHELL_E2E_CREDENTIAL_DRIVER=vault run_suite + exit 0 +fi + exec "${ROOT}/e2e/with-kube-gateway.sh" \ cargo test --manifest-path "${ROOT}/e2e/rust/Cargo.toml" \ --features "${E2E_FEATURES}" \ diff --git a/e2e/rust/tests/credential_drivers.rs b/e2e/rust/tests/credential_drivers.rs new file mode 100644 index 000000000..e4f8fc735 --- /dev/null +++ b/e2e/rust/tests/credential_drivers.rs @@ -0,0 +1,431 @@ +// SPDX-FileCopyrightText: Copyright 
(c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![cfg(feature = "e2e-kubernetes-credential-drivers")] + +use std::process::Stdio; +use std::time::{SystemTime, UNIX_EPOCH}; + +use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD}; +use openshell_e2e::harness::binary::openshell_cmd; +use openshell_e2e::harness::cli::run_cli; +use openshell_e2e::harness::output::strip_ansi; +use openshell_e2e::harness::sandbox::SandboxGuard; +use sha2::{Digest, Sha256}; +use tokio::io::AsyncWriteExt; + +const CREDENTIAL_KEY: &str = "OPENAI_API_KEY"; +const VAULT_POLICY: &str = r#"path "secret/data/openshell/provider-credentials/*" { + capabilities = ["create", "read", "update", "delete"] +} + +path "secret/metadata/openshell/provider-credentials/*" { + capabilities = ["read", "delete", "list"] +} +"#; + +fn unique_suffix() -> String { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + format!("{}-{millis}", std::process::id()) +} + +fn namespace() -> String { + std::env::var("OPENSHELL_E2E_SANDBOX_NAMESPACE").unwrap_or_else(|_| "openshell".to_string()) +} + +fn credential_driver() -> String { + std::env::var("OPENSHELL_E2E_CREDENTIAL_DRIVER") + .unwrap_or_else(|_| "kubernetes-secrets".to_string()) +} + +fn vault_namespace() -> String { + std::env::var("OPENSHELL_E2E_VAULT_NAMESPACE").unwrap_or_else(|_| "vault".to_string()) +} + +fn vault_pod() -> String { + std::env::var("OPENSHELL_E2E_VAULT_POD").unwrap_or_else(|_| "vault-0".to_string()) +} + +fn vault_token() -> String { + std::env::var("OPENSHELL_E2E_VAULT_TOKEN").unwrap_or_else(|_| "root".to_string()) +} + +fn managed_kubernetes_secret_name(provider_name: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(provider_name.as_bytes()); + hasher.update([0]); + hasher.update(CREDENTIAL_KEY.as_bytes()); + let digest = hasher.finalize(); + let hex = format!("{digest:x}"); + 
format!("openshell-cred-{}", &hex[..40]) +} + +fn managed_vault_path(provider_name: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(provider_name.as_bytes()); + hasher.update([0]); + hasher.update(CREDENTIAL_KEY.as_bytes()); + let digest = hasher.finalize(); + let hex = format!("{digest:x}"); + format!("openshell/provider-credentials/{}", &hex[..40]) +} + +fn contains_placeholder_for_env_key(output: &str, key: &str) -> bool { + let legacy = format!("openshell:resolve:env:{key}"); + let revision_prefix = "openshell:resolve:env:v"; + let revision_suffix = format!("_{key}"); + output.split_whitespace().any(|token| { + token == legacy || (token.starts_with(revision_prefix) && token.ends_with(&revision_suffix)) + }) +} + +fn kubectl_command() -> tokio::process::Command { + let mut cmd = tokio::process::Command::new("kubectl"); + if let Ok(context) = std::env::var("OPENSHELL_E2E_KUBE_CONTEXT_ACTIVE") + && !context.trim().is_empty() + { + cmd.arg("--context").arg(context); + } + cmd +} + +/// Runs `kubectl` (pinned to the active e2e context when one is set) with `args` and returns stdout followed by stderr; a non-zero exit becomes an `Err` carrying the exit code and combined output. +async fn kubectl(args: &[&str]) -> Result<String, String> { + let output = kubectl_command() + .args(args) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output() + .await + .map_err(|err| format!("failed to spawn kubectl {args:?}: {err}"))?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let combined = format!("{stdout}{stderr}"); + if !output.status.success() { + return Err(format!( + "kubectl {args:?} failed (exit {:?}):\n{combined}", + output.status.code() + )); + } + Ok(combined) +} + +/// Runs `bao` inside the Vault fixture pod through `kubectl exec`, passing the token as `BAO_TOKEN`; returns stdout followed by stderr, or an `Err` on a non-zero exit. +async fn bao(args: &[&str]) -> Result<String, String> { + let namespace = vault_namespace(); + let pod = vault_pod(); + let token = vault_token(); + let token_env = format!("BAO_TOKEN={token}"); + let mut command = kubectl_command(); + command.args(["-n", &namespace, "exec", &pod, "--", "env", &token_env, "bao"]); + command.args(args); + let output = command + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) +
.output() + .await + .map_err(|err| format!("failed to spawn bao {args:?}: {err}"))?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let combined = format!("{stdout}{stderr}"); + if !output.status.success() { + return Err(format!( + "bao {args:?} failed (exit {:?}):\n{combined}", + output.status.code() + )); + } + Ok(combined) +} + +/// Same as `bao`, but runs `kubectl exec -i` and writes `stdin` to the child process before waiting for it to exit. +async fn bao_with_stdin(args: &[&str], stdin: &str) -> Result<String, String> { + let namespace = vault_namespace(); + let pod = vault_pod(); + let token = vault_token(); + let token_env = format!("BAO_TOKEN={token}"); + let mut command = kubectl_command(); + command.args([ + "-n", &namespace, "exec", "-i", &pod, "--", "env", &token_env, "bao", + ]); + command.args(args); + command.stdin(Stdio::piped()); + command.stdout(Stdio::piped()); + command.stderr(Stdio::piped()); + + let mut child = command + .spawn() + .map_err(|err| format!("failed to spawn bao {args:?}: {err}"))?; + let mut child_stdin = child + .stdin + .take() + .ok_or_else(|| "failed to open bao stdin".to_string())?; + child_stdin + .write_all(stdin.as_bytes()) + .await + .map_err(|err| format!("failed to write bao stdin: {err}"))?; + drop(child_stdin); + + let output = child + .wait_with_output() + .await + .map_err(|err| format!("failed to wait for bao {args:?}: {err}"))?; + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let combined = format!("{stdout}{stderr}"); + if !output.status.success() { + return Err(format!( + "bao {args:?} failed (exit {:?}):\n{combined}", + output.status.code() + )); + } + Ok(combined) +} + +async fn delete_provider(name: &str) { + let mut cmd = openshell_cmd(); + cmd.arg("provider") + .arg("delete") + .arg(name) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + let _ = cmd.status().await; +} + +/// Creates an `openai` provider named `name` through the CLI with the credential `CREDENTIAL_KEY=secret_value`; returns the ANSI-stripped CLI output, or an `Err` when the CLI exits non-zero. +async fn create_provider(name: &str, secret_value: &str) -> Result<String, String> { + let
credential = format!("{CREDENTIAL_KEY}={secret_value}"); + let (output, code) = run_cli(&[ + "provider", + "create", + "--name", + name, + "--type", + "openai", + "--credential", + &credential, + ]) + .await; + let clean = strip_ansi(&output); + if code != 0 { + return Err(format!("provider create {name} failed (exit {code}):\n{clean}")); + } + Ok(clean) +} + +async fn assert_provider_get_does_not_expose_secret( + provider_name: &str, + secret_value: &str, +) -> Result<(), String> { + let (output, code) = run_cli(&["provider", "get", provider_name]).await; + let clean = strip_ansi(&output); + if code != 0 { + return Err(format!( + "provider get {provider_name} failed (exit {code}):\n{clean}" + )); + } + if clean.contains(secret_value) { + return Err(format!( + "provider get {provider_name} exposed credential material:\n{clean}" + )); + } + Ok(()) +} + +async fn assert_provider_placeholder_available_in_sandbox( + provider_name: &str, + sandbox_name: &str, + secret_value: &str, +) -> Result<(), String> { + let guard = SandboxGuard::create(&[ + "--name", + sandbox_name, + "--provider", + provider_name, + "--no-keep", + "--no-auto-providers", + "--no-tty", + "--", + "bash", + "-lc", + r#"printf '%s\n' "$OPENAI_API_KEY""#, + ]) + .await?; + let clean = strip_ansi(&guard.create_output); + if !contains_placeholder_for_env_key(&clean, CREDENTIAL_KEY) { + return Err(format!( + "sandbox {sandbox_name} did not receive provider credential placeholder:\n{clean}" + )); + } + if clean.contains(secret_value) { + return Err(format!( + "sandbox {sandbox_name} output exposed credential material:\n{clean}" + )); + } + Ok(()) +} + +async fn configure_vault_storage() -> Result<(), String> { + let _ = bao(&["secrets", "enable", "-path=secret", "kv-v2"]).await; + let _ = bao(&["auth", "enable", "kubernetes"]).await; + bao(&[ + "write", + "auth/kubernetes/config", + "kubernetes_host=https://kubernetes.default.svc", + 
"kubernetes_ca_cert=@/var/run/secrets/kubernetes.io/serviceaccount/ca.crt", + ]) + .await?; + bao_with_stdin( + &["policy", "write", "openshell-provider-storage", "-"], + VAULT_POLICY, + ) + .await?; + bao(&[ + "write", + "auth/kubernetes/role/openshell-gateway", + "bound_service_account_names=openshell", + &format!("bound_service_account_namespaces={}", namespace()), + "policies=openshell-provider-storage", + "ttl=1h", + ]) + .await?; + Ok(()) +} + +async fn assert_kubernetes_secret_stored( + provider_name: &str, + secret_value: &str, +) -> Result<(), String> { + let namespace = namespace(); + let secret_name = managed_kubernetes_secret_name(provider_name); + let encoded = kubectl(&[ + "-n", + &namespace, + "get", + "secret", + &secret_name, + "-o", + &format!("jsonpath={{.data.{CREDENTIAL_KEY}}}"), + ]) + .await?; + let decoded = BASE64_STANDARD + .decode(encoded.trim()) + .map_err(|err| format!("failed to decode Kubernetes Secret value: {err}"))?; + let decoded = String::from_utf8(decoded) + .map_err(|err| format!("Kubernetes Secret value was not UTF-8: {err}"))?; + if decoded != secret_value { + return Err("Kubernetes Secret stored an unexpected credential value".to_string()); + } + Ok(()) +} + +async fn assert_kubernetes_secret_deleted(provider_name: &str) -> Result<(), String> { + let namespace = namespace(); + let secret_name = managed_kubernetes_secret_name(provider_name); + match kubectl(&["-n", &namespace, "get", "secret", &secret_name]).await { + Ok(output) => Err(format!( + "Kubernetes Secret '{secret_name}' still exists after provider deletion:\n{output}" + )), + Err(_) => Ok(()), + } +} + +async fn assert_vault_secret_stored( + provider_name: &str, + secret_value: &str, +) -> Result<(), String> { + let logical_path = managed_vault_path(provider_name); + let output = bao(&[ + "kv", + "get", + "-field=value", + &format!("secret/{logical_path}"), + ]) + .await?; + if output.trim() != secret_value { + return Err("Vault stored an unexpected credential 
value".to_string()); + } + Ok(()) +} + +async fn assert_vault_secret_deleted(provider_name: &str) -> Result<(), String> { + let logical_path = managed_vault_path(provider_name); + match bao(&[ + "kv", + "get", + "-field=value", + &format!("secret/{logical_path}"), + ]) + .await + { + Ok(output) => Err(format!( + "Vault secret '{logical_path}' still exists after provider deletion:\n{output}" + )), + Err(_) => Ok(()), + } +} + +async fn assert_backend_stored( + driver: &str, + provider_name: &str, + secret_value: &str, +) -> Result<(), String> { + match driver { + "kubernetes-secrets" => assert_kubernetes_secret_stored(provider_name, secret_value).await, + "vault" => assert_vault_secret_stored(provider_name, secret_value).await, + other => Err(format!("unsupported credential driver '{other}'")), + } +} + +async fn assert_backend_deleted(driver: &str, provider_name: &str) -> Result<(), String> { + match driver { + "kubernetes-secrets" => assert_kubernetes_secret_deleted(provider_name).await, + "vault" => assert_vault_secret_deleted(provider_name).await, + other => Err(format!("unsupported credential driver '{other}'")), + } +} + +#[tokio::test] +async fn provider_credentials_are_stored_in_configured_backend() { + assert!( + matches!( + std::env::var("OPENSHELL_E2E_CREDENTIAL_DRIVERS").as_deref(), + Ok("1") + ), + "run with `mise run e2e:kubernetes:credential-drivers` so the Kubernetes wrapper enables a credential storage driver" + ); + + let driver = credential_driver(); + let suffix = unique_suffix(); + let driver_slug = driver.replace('-', ""); + let provider_name = format!("cred-storage-{driver_slug}-{suffix}"); + let sandbox_name = format!("cred-storage-sandbox-{driver_slug}-{suffix}"); + let secret_value = format!("example-e2e-{driver_slug}-{suffix}"); + + delete_provider(&provider_name).await; + if driver == "vault" { + configure_vault_storage() + .await + .expect("configure Vault storage fixture"); + } + + let result: Result<(), String> = async { + 
create_provider(&provider_name, &secret_value).await?; + assert_provider_get_does_not_expose_secret(&provider_name, &secret_value).await?; + assert_backend_stored(&driver, &provider_name, &secret_value).await?; + assert_provider_placeholder_available_in_sandbox( + &provider_name, + &sandbox_name, + &secret_value, + ) + .await?; + Ok(()) + } + .await; + + delete_provider(&provider_name).await; + assert_backend_deleted(&driver, &provider_name) + .await + .expect("credential backend object should be deleted with provider"); + result.expect("credential storage e2e failed"); +} diff --git a/e2e/with-kube-gateway.sh b/e2e/with-kube-gateway.sh index bea1c01d3..547280226 100755 --- a/e2e/with-kube-gateway.sh +++ b/e2e/with-kube-gateway.sh @@ -39,6 +39,13 @@ # PostgreSQL Deployment and a matching Secret with a `uri` key before # installing OpenShell. This is used by HA CI so the gateway can run multiple # replicas without requiring the OpenShell chart to own a database. +# +# Credential-driver fixture: +# Set OPENSHELL_E2E_CREDENTIAL_DRIVERS=1 to enable one credential storage +# backend. Set OPENSHELL_E2E_CREDENTIAL_DRIVER to `kubernetes-secrets` or +# `vault`; the Rust `credential_drivers` e2e test validates the active +# backend. Vault mode installs a dev OpenBao fixture because it exposes the +# Vault-compatible API used by the driver. set -euo pipefail @@ -79,6 +86,11 @@ EXTERNAL_PG_FIXTURE_SERVICE="openshell-e2e-postgres" EXTERNAL_PG_FIXTURE_USER="openshell" EXTERNAL_PG_FIXTURE_PASSWORD="openshell-e2e-postgres" EXTERNAL_PG_FIXTURE_DATABASE="openshell" +VAULT_FIXTURE_DEPLOYED=0 +VAULT_NAMESPACE="${OPENSHELL_E2E_VAULT_NAMESPACE:-openbao}" +VAULT_RELEASE_NAME="${OPENSHELL_E2E_VAULT_RELEASE_NAME:-openbao}" +VAULT_CHART_VERSION="${OPENSHELL_E2E_OPENBAO_CHART_VERSION:-0.28.3}" +VAULT_DEV_ROOT_TOKEN="${OPENSHELL_E2E_VAULT_DEV_ROOT_TOKEN:-root}" # Isolate CLI/SDK gateway metadata from the developer's real config. 
export XDG_CONFIG_HOME="${WORKDIR}/config" @@ -129,6 +141,47 @@ cleanup_postgres_fixture() { EXTERNAL_PG_FIXTURE_SECRET="" } +deploy_vault_fixture() { + echo "Deploying OpenBao fixture for Vault credential-driver validation..." + + helmctl repo add openbao https://openbao.github.io/openbao-helm \ + >/dev/null 2>&1 || true + helmctl repo update openbao >/dev/null + helmctl upgrade --install "${VAULT_RELEASE_NAME}" openbao/openbao \ + --namespace "${VAULT_NAMESPACE}" --create-namespace \ + --version "${VAULT_CHART_VERSION}" \ + --set "server.dev.enabled=true" \ + --set "server.dev.devRootToken=${VAULT_DEV_ROOT_TOKEN}" \ + --set "injector.enabled=false" \ + --wait --timeout 5m + VAULT_FIXTURE_DEPLOYED=1 + + kctl -n "${VAULT_NAMESPACE}" wait \ + --for=condition=Ready pod \ + -l "app.kubernetes.io/name=openbao,component=server" \ + --timeout=300s + + export OPENSHELL_E2E_VAULT_NAMESPACE="${VAULT_NAMESPACE}" + export OPENSHELL_E2E_VAULT_POD="${VAULT_RELEASE_NAME}-0" + export OPENSHELL_E2E_VAULT_TOKEN="${VAULT_DEV_ROOT_TOKEN}" +} + +cleanup_vault_fixture() { + [ -n "${KUBE_CONTEXT}" ] || return 0 + [ -n "${VAULT_NAMESPACE}" ] || return 0 + + if command -v helm >/dev/null 2>&1; then + helmctl uninstall "${VAULT_RELEASE_NAME}" \ + --namespace "${VAULT_NAMESPACE}" --wait --timeout 60s \ + >/dev/null 2>&1 || true + fi + if command -v kubectl >/dev/null 2>&1; then + kctl delete namespace "${VAULT_NAMESPACE}" --wait=true --timeout=60s \ + --ignore-not-found >/dev/null 2>&1 || true + fi + VAULT_FIXTURE_DEPLOYED=0 +} + cleanup() { local exit_code=$? 
@@ -172,6 +225,10 @@ cleanup() { cleanup_postgres_fixture "${EXTERNAL_PG_FIXTURE_SECRET}" fi + if [ "${VAULT_FIXTURE_DEPLOYED}" = "1" ]; then + cleanup_vault_fixture + fi + if [ "${HELM_INSTALLED}" = "1" ] && [ -n "${KUBE_CONTEXT}" ] && [ -n "${NAMESPACE}" ]; then if command -v helm >/dev/null 2>&1; then helmctl uninstall "${RELEASE_NAME}" --namespace "${NAMESPACE}" --wait \ @@ -341,6 +398,12 @@ run_scenario() { export OPENSHELL_GATEWAY="${GATEWAY_NAME}" export OPENSHELL_E2E_DRIVER="kubernetes" + # Kubernetes e2e runs against k3d/kind-style Docker-backed clusters. Host + # fixture containers must use the same Docker host so published ports and + # cluster host-gateway aliases line up even on machines where Podman is also + # installed. + export CONTAINER_ENGINE="${CONTAINER_ENGINE:-docker}" + export OPENSHELL_E2E_KUBE_CONTEXT_ACTIVE="${KUBE_CONTEXT}" export OPENSHELL_E2E_SANDBOX_NAMESPACE="${NAMESPACE}" export OPENSHELL_PROVISION_TIMEOUT="${OPENSHELL_PROVISION_TIMEOUT:-300}" @@ -536,12 +599,33 @@ kctl apply -f "${_agent_sandbox_base}/manifest.yaml" kctl wait --for=condition=Established crd/sandboxes.agents.x-k8s.io --timeout=120s kctl -n agent-sandbox-system rollout status deployment/agent-sandbox-controller --timeout=300s +ACTIVE_CREDENTIAL_DRIVER="${OPENSHELL_E2E_CREDENTIAL_DRIVER:-kubernetes-secrets}" +if [ "${OPENSHELL_E2E_CREDENTIAL_DRIVERS:-0}" = "1" ] \ + && [ "${ACTIVE_CREDENTIAL_DRIVER}" = "vault" ]; then + deploy_vault_fixture +fi + helm_extra_args=() if [ -n "${HOST_GATEWAY_IP}" ]; then helm_extra_args+=(--set "server.hostGatewayIP=${HOST_GATEWAY_IP}") fi helm_values_args=(--values "${ROOT}/deploy/helm/openshell/ci/values-skaffold.yaml") +if [ "${OPENSHELL_E2E_CREDENTIAL_DRIVERS:-0}" = "1" ]; then + case "${ACTIVE_CREDENTIAL_DRIVER}" in + kubernetes-secrets) + helm_values_args+=(--values "${ROOT}/deploy/helm/openshell/ci/values-credential-driver-kubernetes-secrets.yaml") + ;; + vault) + helm_values_args+=(--values 
"${ROOT}/deploy/helm/openshell/ci/values-credential-driver-vault.yaml") + ;; + *) + echo "ERROR: OPENSHELL_E2E_CREDENTIAL_DRIVER must be kubernetes-secrets or vault, got '${ACTIVE_CREDENTIAL_DRIVER}'" >&2 + exit 2 + ;; + esac + export OPENSHELL_E2E_CREDENTIAL_DRIVER="${ACTIVE_CREDENTIAL_DRIVER}" +fi if [ -n "${OPENSHELL_E2E_KUBE_EXTRA_VALUES:-}" ]; then IFS=':' read -r -a extra_values_files <<< "${OPENSHELL_E2E_KUBE_EXTRA_VALUES}" for values_file in "${extra_values_files[@]}"; do @@ -666,6 +750,12 @@ else export OPENSHELL_GATEWAY="${GATEWAY_NAME}" export OPENSHELL_E2E_DRIVER="kubernetes" + # Kubernetes e2e runs against k3d/kind-style Docker-backed clusters. Host + # fixture containers must use the same Docker host so published ports and + # cluster host-gateway aliases line up even on machines where Podman is also + # installed. + export CONTAINER_ENGINE="${CONTAINER_ENGINE:-docker}" + export OPENSHELL_E2E_KUBE_CONTEXT_ACTIVE="${KUBE_CONTEXT}" export OPENSHELL_E2E_SANDBOX_NAMESPACE="${NAMESPACE}" export OPENSHELL_PROVISION_TIMEOUT="${OPENSHELL_PROVISION_TIMEOUT:-300}" diff --git a/proto/credential_driver.proto b/proto/credential_driver.proto new file mode 100644 index 000000000..4432aa658 --- /dev/null +++ b/proto/credential_driver.proto @@ -0,0 +1,117 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package openshell.credentials.v1; + +import "datamodel.proto"; + +// Internal credential-driver contract used by the gateway. +// +// The gateway owns provider semantics and sandbox delivery. Credential drivers +// own backend-specific storage, deletion, authentication, and lookup for +// gateway-managed credential handles. +service CredentialDriver { + // Report driver identity and feature support. 
+  rpc GetCapabilities(GetCredentialDriverCapabilitiesRequest)
+      returns (GetCredentialDriverCapabilitiesResponse);
+
+  // Store or overwrite one provider credential and return an opaque handle.
+  rpc StoreCredential(StoreCredentialRequest) returns (StoreCredentialResponse);
+
+  // Delete one provider credential handle.
+  rpc DeleteCredential(DeleteCredentialRequest) returns (DeleteCredentialResponse);
+
+  // Resolve a batch of handles into string secret values, matched by request_id.
+  rpc ResolveCredentials(ResolveCredentialsRequest)
+      returns (ResolveCredentialsResponse);
+
+  // Optionally list discoverable credentials. Drivers may return UNIMPLEMENTED.
+  rpc ListCredentials(ListCredentialsRequest) returns (ListCredentialsResponse);
+}
+
+message GetCredentialDriverCapabilitiesRequest {}
+
+message GetCredentialDriverCapabilitiesResponse {
+  // Human-readable driver name.
+  string driver_name = 1;
+  // Driver implementation version string.
+  string driver_version = 2;
+  // Backend kind, such as "kubernetes-secrets" or "vault".
+  string backend_kind = 3;
+  // True when ListCredentials is supported (it may otherwise be UNIMPLEMENTED).
+  bool supports_list = 4;
+  // True when ResolveCredentials may return non-zero expires_at_ms values.
+  bool supports_expires_at = 5;
+}
+
+message StoreCredentialRequest {
+  // Provider instance name supplied for audit and backend policy decisions.
+  string provider_name = 1;
+  // Provider credential key that will receive the resolved value at runtime.
+  string credential_key = 2;
+  // Secret value to store. Drivers must never log this field.
+  string value = 3;
+  // Existing handle to overwrite, if any.
+  openshell.datamodel.v1.CredentialHandle existing_handle = 4;
+}
+
+message StoreCredentialResponse {
+  // Opaque handle for later resolution/deletion.
+  openshell.datamodel.v1.CredentialHandle handle = 1;
+}
+
+message DeleteCredentialRequest {
+  // Provider instance name supplied for audit and backend policy decisions.
+  string provider_name = 1;
+  // Provider credential key that owns the handle.
+  string credential_key = 2;
+  // Opaque handle to delete.
+  openshell.datamodel.v1.CredentialHandle handle = 3;
+}
+
+message DeleteCredentialResponse {}
+
+message ResolveCredentialsRequest {
+  repeated ResolveCredentialRequest credentials = 1;
+}
+
+message ResolveCredentialRequest {
+  // Gateway-chosen opaque ID used to correlate batch responses.
+  string request_id = 1;
+  // Provider instance name supplied for audit and backend policy decisions.
+  string provider_name = 2;
+  // Provider credential key that will receive the resolved value.
+  string credential_key = 3;
+  // Opaque handle to resolve.
+  openshell.datamodel.v1.CredentialHandle handle = 4;
+}
+
+message ResolveCredentialsResponse {
+  repeated ResolvedCredential credentials = 1;
+}
+
+message ResolvedCredential {
+  // Echoes ResolveCredentialRequest.request_id.
+  string request_id = 1;
+  // Secret string value. Drivers must never log this field.
+  string value = 2;
+  // Expiration timestamp in milliseconds since Unix epoch, or zero when absent.
+  int64 expires_at_ms = 3;
+}
+
+message ListCredentialsRequest {}
+
+message ListCredentialsResponse {
+  repeated ListedCredential credentials = 1;
+}
+
+message ListedCredential {
+  // Opaque handle identifier or driver-owned display name.
+  string handle = 1;
+  // Available credential keys under the backend object.
+  repeated string keys = 2;
+  // Driver-owned non-secret metadata.
+  map<string, string> metadata = 3;
+}
diff --git a/proto/datamodel.proto b/proto/datamodel.proto
index f92d7b7a3..0e52f4726 100644
--- a/proto/datamodel.proto
+++ b/proto/datamodel.proto
@@ -28,17 +28,31 @@ message ObjectMeta {
   uint64 resource_version = 5;
 }
 
+// Opaque handle for a provider credential stored by gateway credential storage.
+// Handles are created by OpenShell and must not be authored by users.
+message CredentialHandle {
+  // Internal storage owner or credential driver that owns this handle.
+  string driver = 1;
+  // Owner-owned opaque handle string.
+  string handle = 2;
+  // Owner-owned non-secret metadata.
+  map<string, string> metadata = 3;
+}
+
 // Provider model stored by OpenShell.
 message Provider {
   // Kubernetes-style metadata (id, name, labels, timestamps, resource version).
   ObjectMeta metadata = 1;
   // Canonical provider type slug (for example: "claude", "gitlab").
   string type = 2;
-  // Secret values used for authentication.
+  // Inline secret values used for authentication.
   map<string, string> credentials = 3;
   // Non-secret provider configuration.
   map<string, string> config = 4;
   // Expiration timestamps for credential values, keyed by credential/env var
   // name. A zero or missing value means the credential does not expire.
   map<string, int64> credential_expires_at_ms = 5;
+  // Opaque handles for secret values stored through gateway credential storage.
+  // This map is internal gateway state and is not accepted as user-authored input.
+  map<string, CredentialHandle> credential_handles = 6;
 }
diff --git a/tasks/scripts/setup-zig-cc-wrapper.sh b/tasks/scripts/setup-zig-cc-wrapper.sh
index ea78f5d28..205c34ca8 100755
--- a/tasks/scripts/setup-zig-cc-wrapper.sh
+++ b/tasks/scripts/setup-zig-cc-wrapper.sh
@@ -66,6 +66,16 @@ EOF
   chmod +x "$wrapper_dir/$tool"
 done
 
+for tool in ar ranlib; do
+  cat >"$wrapper_dir/$tool" <"$toolchain_file" <