Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 92 additions & 133 deletions crates/openshell-server/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ use tracing::{info, warn};
use tracing_subscriber::EnvFilter;

use crate::certgen;
use crate::compute::{DockerComputeConfig, VmComputeConfig};
use crate::compute::driver_config::GuestTlsPaths;
use crate::config_file::{self, ConfigFile, GatewayFileSection};
use crate::defaults::{self, LocalTlsPaths};
use crate::{run_server, tracing_bus::TracingLogBus};
use crate::{ServerStartupConfig, run_server, tracing_bus::TracingLogBus};

/// `OpenShell` gateway process - gRPC and HTTP server with protocol multiplexing.
///
Expand Down Expand Up @@ -222,33 +222,29 @@ pub async fn run_cli() -> Result<()> {
}
}

async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
fn prepare_server_config(args: &mut RunArgs, matches: &ArgMatches) -> Result<ServerStartupConfig> {
// Load TOML when explicitly requested, or from the default XDG location
// when that file exists. Missing default config is not an error: runtime
// defaults and OPENSHELL_* env vars are enough for package-managed starts.
let config_path = resolve_config_path(&args)?;
let config_path = resolve_config_path(args)?;
let file: Option<ConfigFile> = if let Some(path) = config_path {
Some(config_file::load(&path).map_err(|e| miette::miette!("{e}"))?)
} else {
None
};
if let Some(file) = file.as_ref() {
merge_file_into_args(&mut args, &file.openshell.gateway, &matches);
merge_file_into_args(args, &file.openshell.gateway, matches);
}

let local_tls = apply_runtime_defaults(&mut args)?;
let local_tls = apply_runtime_defaults(args)?;
let guest_tls = local_tls.as_ref().map(GuestTlsPaths::from);
let local_jwt = defaults::complete_local_jwt_config()?;

let tracing_log_bus = TracingLogBus::new();
tracing_log_bus.install_subscriber(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&args.log_level)),
);

let bind = SocketAddr::new(args.bind_address, args.port);

let has_client_ca = args.tls_client_ca.is_some();
let has_oidc = args.oidc_issuer.is_some();
let mtls_auth_enabled = resolve_mtls_auth_enabled(&args, &matches, file.as_ref());
let mtls_auth_enabled = resolve_mtls_auth_enabled(args, matches, file.as_ref());

if args.disable_tls && has_client_ca {
return Err(miette::miette!(
Expand All @@ -267,7 +263,7 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
}
if mtls_auth_enabled
&& matches!(
effective_single_driver(&args),
effective_single_driver(args),
Some(ComputeDriverKind::Kubernetes)
)
{
Expand Down Expand Up @@ -318,14 +314,14 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
let health_bind = resolve_aux_listener(
args.bind_address,
args.health_port,
&matches,
matches,
"health_port",
|| file_gateway.and_then(|g| g.health_bind_address),
);
let metrics_bind = resolve_aux_listener(
args.bind_address,
args.metrics_port,
&matches,
matches,
"metrics_port",
|| file_gateway.and_then(|g| g.metrics_bind_address),
);
Expand Down Expand Up @@ -404,15 +400,31 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
config.gateway_jwt = Some(jwt);
}

let vm_config = build_vm_config(
file.as_ref(),
local_tls.as_ref(),
args.disable_tls,
args.port,
)?;
let docker_config = build_docker_config(file.as_ref(), local_tls.as_ref())?;
Ok(ServerStartupConfig {
config,
config_file: file,
guest_tls,
})
}

async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
let prepared = prepare_server_config(&mut args, &matches)?;

let tracing_log_bus = TracingLogBus::new();
tracing_log_bus.install_subscriber(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new(&prepared.config.log_level)),
);

if args.disable_tls {
let has_client_ca = prepared
.config
.tls
.as_ref()
.and_then(|tls| tls.client_ca_path.as_ref())
.is_some();
let has_oidc = prepared.config.oidc.is_some();

if prepared.config.tls.is_none() {
warn!("TLS disabled — listening on plaintext HTTP");
} else {
info!("TLS enabled — listening on encrypted HTTPS");
Expand All @@ -421,40 +433,34 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
if has_client_ca {
info!("TLS client certificate verification enabled");
}
if config.mtls_auth.enabled {
if prepared.config.mtls_auth.enabled {
info!("mTLS user authentication enabled");
}
if has_oidc {
info!("OIDC authentication enabled");
}
if config.auth.allow_unauthenticated_users {
if prepared.config.auth.allow_unauthenticated_users {
warn!(
"Unauthenticated user access enabled — only use this for trusted local development or a fully trusted fronting proxy"
);
}

if !config.auth.allow_unauthenticated_users
&& !config.mtls_auth.enabled
if !prepared.config.auth.allow_unauthenticated_users
&& !prepared.config.mtls_auth.enabled
&& !has_oidc
&& config.gateway_jwt.is_none()
&& prepared.config.gateway_jwt.is_none()
{
warn!(
"Neither mTLS user auth nor OIDC nor sandbox JWT auth is configured — \
the gateway has no authentication mechanism"
);
}

info!(bind = %config.bind_address, "Starting OpenShell server");
info!(bind = %prepared.config.bind_address, "Starting OpenShell server");

Box::pin(run_server(
config,
vm_config,
docker_config,
file,
tracing_log_bus,
))
.await
.into_diagnostic()
Box::pin(run_server(prepared, tracing_log_bus))
.await
.into_diagnostic()
}

fn parse_compute_driver(value: &str) -> std::result::Result<ComputeDriverKind, String> {
Expand Down Expand Up @@ -691,87 +697,6 @@ fn resolve_mtls_auth_enabled(
is_singleplayer_driver(args)
}

/// Build [`VmComputeConfig`] from the `[openshell.drivers.vm]` table
/// inherited from `[openshell.gateway]`.
fn build_vm_config(
file: Option<&ConfigFile>,
local_tls: Option<&LocalTlsPaths>,
disable_tls: bool,
gateway_port: u16,
) -> Result<VmComputeConfig> {
let mut cfg = if let Some(file) = file {
let merged = config_file::driver_table(
ComputeDriverKind::Vm,
&file.openshell.gateway,
file.openshell.drivers.get("vm"),
);
merged
.try_into::<VmComputeConfig>()
.map_err(|e| miette::miette!("invalid [openshell.drivers.vm] table: {e}"))?
} else {
VmComputeConfig::default()
};

if cfg.state_dir.as_os_str().is_empty() {
cfg.state_dir = VmComputeConfig::default_state_dir();
}
if cfg.grpc_endpoint.trim().is_empty() && (disable_tls || local_tls.is_some()) {
let scheme = if disable_tls { "http" } else { "https" };
cfg.grpc_endpoint = format!("{scheme}://127.0.0.1:{gateway_port}");
}
apply_guest_tls_defaults(
&mut cfg.guest_tls_ca,
&mut cfg.guest_tls_cert,
&mut cfg.guest_tls_key,
local_tls,
);
Ok(cfg)
}

/// Build [`DockerComputeConfig`] using the same inheritance pattern as
/// [`build_vm_config`].
fn build_docker_config(
file: Option<&ConfigFile>,
local_tls: Option<&LocalTlsPaths>,
) -> Result<DockerComputeConfig> {
let mut cfg = if let Some(file) = file {
let merged = config_file::driver_table(
ComputeDriverKind::Docker,
&file.openshell.gateway,
file.openshell.drivers.get("docker"),
);
merged
.try_into::<DockerComputeConfig>()
.map_err(|e| miette::miette!("invalid [openshell.drivers.docker] table: {e}"))?
} else {
DockerComputeConfig::default()
};
apply_guest_tls_defaults(
&mut cfg.guest_tls_ca,
&mut cfg.guest_tls_cert,
&mut cfg.guest_tls_key,
local_tls,
);
Ok(cfg)
}

fn apply_guest_tls_defaults(
ca: &mut Option<PathBuf>,
cert: &mut Option<PathBuf>,
key: &mut Option<PathBuf>,
local_tls: Option<&LocalTlsPaths>,
) {
if ca.is_none()
&& cert.is_none()
&& key.is_none()
&& let Some(paths) = local_tls
{
*ca = Some(paths.ca.clone());
*cert = Some(paths.client_cert.clone());
*key = Some(paths.client_key.clone());
}
}

#[cfg(test)]
mod tests {
use super::{Cli, command};
Expand Down Expand Up @@ -1613,6 +1538,54 @@ enable_loopback_service_http = false
);
}

#[test]
fn server_config_preparation_ignores_unselected_driver_tables() {
let _lock = ENV_LOCK
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let state = tempfile::tempdir().unwrap();
let local_tls = tempfile::tempdir().unwrap();
let _g1 = EnvVarGuard::set("XDG_STATE_HOME", state.path().to_str().unwrap());
let _g2 = EnvVarGuard::set(
"OPENSHELL_LOCAL_TLS_DIR",
local_tls.path().to_str().unwrap(),
);
let config_path = state.path().join("gateway.toml");
std::fs::write(
&config_path,
r#"
[openshell.drivers.docker]
unknown_docker_key = true

[openshell.drivers.vm]
mem_mib = "not-a-number"
"#,
)
.unwrap();

let (mut args, matches) = parse_with_args(&[
"openshell-gateway",
"--config",
config_path.to_str().unwrap(),
"--db-url",
"sqlite::memory:",
"--drivers",
"podman",
"--disable-tls",
]);

let prepared =
super::prepare_server_config(&mut args, &matches).expect("server config is prepared");

assert_eq!(
prepared.config.compute_drivers,
vec![super::ComputeDriverKind::Podman]
);
let file = prepared.config_file.expect("config file is preserved");
assert!(file.openshell.drivers.contains_key("docker"));
assert!(file.openshell.drivers.contains_key("vm"));
}

#[test]
fn driver_inherits_shared_image_from_gateway_section() {
// [openshell.gateway].default_image inherits into the K8s driver
Expand Down Expand Up @@ -1659,18 +1632,4 @@ default_image = "k8s-specific:1.0"
.expect("deserializes");
assert_eq!(parsed.default_image, "k8s-specific:1.0");
}

#[test]
fn docker_config_reads_bind_mount_opt_in_from_driver_table() {
let file = config_file_from_toml(
r"
[openshell.drivers.docker]
enable_bind_mounts = true
",
);

let cfg = super::build_docker_config(Some(&file), None).expect("docker config");

assert!(cfg.enable_bind_mounts);
}
}
Loading
Loading