diff --git a/crates/defguard_common/src/config.rs b/crates/defguard_common/src/config.rs index 63f85cce5..586d5098b 100644 --- a/crates/defguard_common/src/config.rs +++ b/crates/defguard_common/src/config.rs @@ -195,12 +195,12 @@ pub struct DefGuardConfig { /// Maximum number of requests per second per client IP before rate limiting kicks in. /// Set to 0 to disable rate limiting. - #[arg(long, env = "DEFGUARD_RATELIMIT_PERSECOND", default_value_t = 10)] + #[arg(long, env = "DEFGUARD_RATELIMIT_PERSECOND", default_value_t = 100)] pub rate_limit_per_second: u64, /// Maximum burst size for the rate limiter (token bucket capacity per client IP). /// Set to 0 to disable rate limiting. - #[arg(long, env = "DEFGUARD_RATELIMIT_BURST", default_value_t = 100)] + #[arg(long, env = "DEFGUARD_RATELIMIT_BURST", default_value_t = 1000)] pub rate_limit_burst: u32, } @@ -302,8 +302,8 @@ impl DefGuardConfig { grpc_bind_address: None, adopt_gateway: None, adopt_edge: None, - rate_limit_per_second: 10, - rate_limit_burst: 100, + rate_limit_per_second: 0, + rate_limit_burst: 0, }; config diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index de625351d..49f52dc10 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -261,6 +261,7 @@ pub fn build_webapp( incompatible_components: Arc>, proxy_control_tx: tokio::sync::mpsc::Sender, tls_active: Arc, + server_config: &DefGuardConfig, ) -> Router { let webapp: Router = Router::new() .route("/", get(index)) @@ -270,7 +271,8 @@ pub fn build_webapp( .route("/svg/{*path}", get(svg)) .fallback_service(get(handle_404)); - let webapp = webapp.nest( + // Collect all API routes into a single router so the rate-limiter can be scoped to API routes only + let api_router: Router = Router::new().nest( "/api/v1", Router::new() .route("/health", get(health_check)) @@ -430,7 +432,7 @@ pub fn build_webapp( ); // Enterprise features - let webapp = webapp.nest( + let api_router = api_router.nest( "/api/v1/openid", Router::new() .route( @@ -448,7 +450,7 @@ pub fn build_webapp( .route("/auth_info", get(get_auth_info)), ); - let webapp = webapp.nest( + let api_router = api_router.nest( "/api/v1", Router::new() .route("/enterprise_info", get(check_enterprise_info)) @@ -456,7 +458,7 @@ pub fn build_webapp( ); // activity log stream - let webapp = webapp.nest( + let api_router = api_router.nest( "/api/v1/activity_log_stream", Router::new() .route( @@ -469,7 +471,7 @@ pub fn build_webapp( ), ); - let webapp = webapp + let api_router = api_router .nest( "/api/v1/oauth", Router::new() @@ -491,7 +493,7 @@ pub fn build_webapp( get(openid_configuration), ); - let webapp = webapp.nest( + let api_router = api_router.nest( "/api/v1/acl", Router::new() .route("/rule", get(list_acl_rules).post(create_acl_rule)) @@ -526,7 +528,7 @@ pub fn build_webapp( .route("/destination/apply", put(apply_acl_destinations)), ); - let webapp = webapp.nest( + let api_router = api_router.nest( "/api/v1", Router::new() // FIXME: Conflict; change /device/{device_id} to /device/{username}. @@ -623,7 +625,7 @@ pub fn build_webapp( .route("/license/check", post(license_check)), ); - let webapp = webapp.nest( + let api_router = api_router.nest( "/api/v1/worker", Router::new() .route("/job", post(create_job)) @@ -633,6 +635,49 @@ pub fn build_webapp( .layer(Extension(worker_state)), ); + // Setup rate limiter + debug!( + "Configuring rate limiter, per_second: {}, burst: {}", + server_config.rate_limit_per_second, server_config.rate_limit_burst + ); + let governor_config = GovernorConfigBuilder::default() + .key_extractor(SmartIpKeyExtractor) + .per_second(server_config.rate_limit_per_second) + .burst_size(server_config.rate_limit_burst) + .finish(); + let governor_config = if let Some(conf) = governor_config { + let governor_limiter = conf.limiter().clone(); + spawn(async move { + loop { + sleep(RATE_LIMITER_CLEANUP_PERIOD).await; + debug!( + "Cleaning-up rate limiter storage, current size: {}", + governor_limiter.len() + ); + governor_limiter.retain_recent(); + } + }); + info!( + "Rate limiter configured: {} req/s per IP, burst {}", + server_config.rate_limit_per_second, server_config.rate_limit_burst + ); + Some(Arc::new(conf)) + } else { + info!("Rate limiting disabled (per_second or burst is 0)"); + None + }; + + // Apply rate-limiter to API routes only, leaving static asset routes unaffected. + // Use Arc::clone so the same underlying limiter is shared with the SSE routes below. + let api_router = if let Some(ref conf) = governor_config { + api_router.layer(GovernorLayer::new(Arc::clone(conf))) + } else { + api_router + }; + + // Merge rate-limited API routes into the static-assets webapp. + let webapp = webapp.merge(api_router); + // SSE routes are long-lived connections; they must not be wrapped by the // request timeout. They are merged in after TimeoutLayer is applied to the // main router so that they bypass the timeout while still receiving all @@ -647,6 +692,11 @@ pub fn build_webapp( get(setup_gateway_tls_stream), ), ); + let sse_routes = if let Some(conf) = governor_config { + sse_routes.layer(GovernorLayer::new(conf)) + } else { + sse_routes + }; let app_state = AppState::new( pool.clone(), @@ -735,7 +785,9 @@ pub async fn run_web_server( let tls_active = Arc::new(AtomicBool::new(false)); - let mut webapp = build_webapp( + let server_config = server_config(); + + let webapp = build_webapp( webhook_tx, webhook_rx, wireguard_tx, @@ -748,43 +800,11 @@ pub async fn run_web_server( incompatible_components, proxy_control_tx, Arc::clone(&tls_active), + server_config, ); info!("Started web services"); - let server_config = server_config(); - - // Setup rate limiter. Both fields default to non-zero so limiting is on by default; - // operators can set either env var to 0 to disable. - debug!( - "Configuring rate limiter, per_second: {}, burst: {}", - server_config.rate_limit_per_second, server_config.rate_limit_burst - ); - let governor_conf = GovernorConfigBuilder::default() - .key_extractor(SmartIpKeyExtractor) - .per_second(server_config.rate_limit_per_second) - .burst_size(server_config.rate_limit_burst) - .finish(); - if let Some(conf) = governor_conf { - let governor_limiter = conf.limiter().clone(); - spawn(async move { - loop { - sleep(RATE_LIMITER_CLEANUP_PERIOD).await; - debug!( - "Cleaning-up rate limiter storage, current size: {}", - governor_limiter.len() - ); - governor_limiter.retain_recent(); - } - }); - info!( - "Rate limiter configured: {} req/s per IP, burst {}", - server_config.rate_limit_per_second, server_config.rate_limit_burst - ); - webapp = webapp.layer(GovernorLayer::new(conf)); - } else { - info!("Rate limiting disabled (per_second or burst is 0)"); - } - webapp = apply_security_layers(webapp, Arc::clone(&tls_active)); + let webapp = apply_security_layers(webapp, Arc::clone(&tls_active)); let addr = SocketAddr::new( server_config diff --git a/crates/defguard_core/tests/integration/api/common/mod.rs b/crates/defguard_core/tests/integration/api/common/mod.rs index f628aafe6..aef405fbe 100644 --- a/crates/defguard_core/tests/integration/api/common/mod.rs +++ b/crates/defguard_core/tests/integration/api/common/mod.rs @@ -153,6 +153,7 @@ pub(crate) async fn make_base_client( Arc::default(), proxy_control_tx, Arc::clone(&tls_active), + &config, ); let webapp = apply_security_layers(webapp, tls_active); diff --git a/crates/defguard_core/tests/integration/api/proxy_certs.rs b/crates/defguard_core/tests/integration/api/proxy_certs.rs index 5f0fb378b..7a05f98c6 100644 --- a/crates/defguard_core/tests/integration/api/proxy_certs.rs +++ b/crates/defguard_core/tests/integration/api/proxy_certs.rs @@ -99,7 +99,7 @@ async fn make_test_client_with_proxy_rx( .await .expect("Could not bind ephemeral socket"); let port = listener.local_addr().unwrap().port(); - let _config = init_config(Some(&format!("http://localhost:{port}")), &pool).await; + let config = init_config(Some(&format!("http://localhost:{port}")), &pool).await; initialize_users(&pool).await; initialize_current_settings(&pool) .await @@ -150,6 +150,7 @@ async fn make_test_client_with_proxy_rx( Arc::default(), proxy_control_tx, Arc::new(AtomicBool::new(false)), + &config, ); let client = TestClient::new(webapp, listener, api_event_rx);