Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/defguard_common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,12 @@ pub struct DefGuardConfig {

/// Maximum number of requests per second per client IP before rate limiting kicks in.
/// Set to 0 to disable rate limiting.
#[arg(long, env = "DEFGUARD_RATELIMIT_PERSECOND", default_value_t = 10)]
#[arg(long, env = "DEFGUARD_RATELIMIT_PERSECOND", default_value_t = 100)]
pub rate_limit_per_second: u64,

/// Maximum burst size for the rate limiter (token bucket capacity per client IP).
/// Set to 0 to disable rate limiting.
#[arg(long, env = "DEFGUARD_RATELIMIT_BURST", default_value_t = 100)]
#[arg(long, env = "DEFGUARD_RATELIMIT_BURST", default_value_t = 1000)]
pub rate_limit_burst: u32,
}

Expand Down Expand Up @@ -302,8 +302,8 @@ impl DefGuardConfig {
grpc_bind_address: None,
adopt_gateway: None,
adopt_edge: None,
rate_limit_per_second: 10,
rate_limit_burst: 100,
rate_limit_per_second: 0,
rate_limit_burst: 0,
};

config
Expand Down
106 changes: 63 additions & 43 deletions crates/defguard_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ pub fn build_webapp(
incompatible_components: Arc<RwLock<IncompatibleComponents>>,
proxy_control_tx: tokio::sync::mpsc::Sender<ProxyControlMessage>,
tls_active: Arc<AtomicBool>,
server_config: &DefGuardConfig,
) -> Router {
let webapp: Router<AppState> = Router::new()
.route("/", get(index))
Expand All @@ -270,7 +271,8 @@ pub fn build_webapp(
.route("/svg/{*path}", get(svg))
.fallback_service(get(handle_404));

let webapp = webapp.nest(
// Collect all API routes into a single router so the rate-limiter can be scoped to API routes only
let api_router: Router<AppState> = Router::new().nest(
"/api/v1",
Router::new()
.route("/health", get(health_check))
Expand Down Expand Up @@ -430,7 +432,7 @@ pub fn build_webapp(
);

// Enterprise features
let webapp = webapp.nest(
let api_router = api_router.nest(
"/api/v1/openid",
Router::new()
.route(
Expand All @@ -448,15 +450,15 @@ pub fn build_webapp(
.route("/auth_info", get(get_auth_info)),
);

let webapp = webapp.nest(
let api_router = api_router.nest(
"/api/v1",
Router::new()
.route("/enterprise_info", get(check_enterprise_info))
.route("/test_directory_sync", get(test_dirsync_connection)),
);

// activity log stream
let webapp = webapp.nest(
let api_router = api_router.nest(
"/api/v1/activity_log_stream",
Router::new()
.route(
Expand All @@ -469,7 +471,7 @@ pub fn build_webapp(
),
);

let webapp = webapp
let api_router = api_router
.nest(
"/api/v1/oauth",
Router::new()
Expand All @@ -491,7 +493,7 @@ pub fn build_webapp(
get(openid_configuration),
);

let webapp = webapp.nest(
let api_router = api_router.nest(
"/api/v1/acl",
Router::new()
.route("/rule", get(list_acl_rules).post(create_acl_rule))
Expand Down Expand Up @@ -526,7 +528,7 @@ pub fn build_webapp(
.route("/destination/apply", put(apply_acl_destinations)),
);

let webapp = webapp.nest(
let api_router = api_router.nest(
"/api/v1",
Router::new()
// FIXME: Conflict; change /device/{device_id} to /device/{username}.
Expand Down Expand Up @@ -623,7 +625,7 @@ pub fn build_webapp(
.route("/license/check", post(license_check)),
);

let webapp = webapp.nest(
let api_router = api_router.nest(
"/api/v1/worker",
Router::new()
.route("/job", post(create_job))
Expand All @@ -633,6 +635,49 @@ pub fn build_webapp(
.layer(Extension(worker_state)),
);

// Setup rate limiter
debug!(
"Configuring rate limiter, per_second: {}, burst: {}",
server_config.rate_limit_per_second, server_config.rate_limit_burst
);
let governor_config = GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_second(server_config.rate_limit_per_second)
.burst_size(server_config.rate_limit_burst)
.finish();
let governor_config = if let Some(conf) = governor_config {
let governor_limiter = conf.limiter().clone();
spawn(async move {
loop {
sleep(RATE_LIMITER_CLEANUP_PERIOD).await;
debug!(
"Cleaning-up rate limiter storage, current size: {}",
governor_limiter.len()
);
governor_limiter.retain_recent();
}
});
Comment thread
wojcik91 marked this conversation as resolved.
info!(
"Rate limiter configured: {} req/s per IP, burst {}",
server_config.rate_limit_per_second, server_config.rate_limit_burst
);
Some(Arc::new(conf))
} else {
info!("Rate limiting disabled (per_second or burst is 0)");
None
};

// Apply rate-limiter to API routes only, leaving static asset routes unaffected.
// Use Arc::clone so the same underlying limiter is shared with the SSE routes below.
let api_router = if let Some(ref conf) = governor_config {
api_router.layer(GovernorLayer::new(Arc::clone(conf)))
} else {
api_router
};

// Merge rate-limited API routes into the static-assets webapp.
let webapp = webapp.merge(api_router);
Comment thread
wojcik91 marked this conversation as resolved.
Comment thread
wojcik91 marked this conversation as resolved.

// SSE routes are long-lived connections; they must not be wrapped by the
// request timeout. They are merged in after TimeoutLayer is applied to the
// main router so that they bypass the timeout while still receiving all
Expand All @@ -647,6 +692,11 @@ pub fn build_webapp(
get(setup_gateway_tls_stream),
),
);
let sse_routes = if let Some(conf) = governor_config {
sse_routes.layer(GovernorLayer::new(conf))
} else {
sse_routes
};

let app_state = AppState::new(
pool.clone(),
Expand Down Expand Up @@ -735,7 +785,9 @@ pub async fn run_web_server(

let tls_active = Arc::new(AtomicBool::new(false));

let mut webapp = build_webapp(
let server_config = server_config();

let webapp = build_webapp(
webhook_tx,
webhook_rx,
wireguard_tx,
Expand All @@ -748,43 +800,11 @@ pub async fn run_web_server(
incompatible_components,
proxy_control_tx,
Arc::clone(&tls_active),
server_config,
);
info!("Started web services");
let server_config = server_config();

// Setup rate limiter. Both fields default to non-zero so limiting is on by default;
// operators can set either env var to 0 to disable.
debug!(
"Configuring rate limiter, per_second: {}, burst: {}",
server_config.rate_limit_per_second, server_config.rate_limit_burst
);
let governor_conf = GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_second(server_config.rate_limit_per_second)
.burst_size(server_config.rate_limit_burst)
.finish();
if let Some(conf) = governor_conf {
let governor_limiter = conf.limiter().clone();
spawn(async move {
loop {
sleep(RATE_LIMITER_CLEANUP_PERIOD).await;
debug!(
"Cleaning-up rate limiter storage, current size: {}",
governor_limiter.len()
);
governor_limiter.retain_recent();
}
});
info!(
"Rate limiter configured: {} req/s per IP, burst {}",
server_config.rate_limit_per_second, server_config.rate_limit_burst
);
webapp = webapp.layer(GovernorLayer::new(conf));
} else {
info!("Rate limiting disabled (per_second or burst is 0)");
}

webapp = apply_security_layers(webapp, Arc::clone(&tls_active));
let webapp = apply_security_layers(webapp, Arc::clone(&tls_active));

let addr = SocketAddr::new(
server_config
Expand Down
1 change: 1 addition & 0 deletions crates/defguard_core/tests/integration/api/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ pub(crate) async fn make_base_client(
Arc::default(),
proxy_control_tx,
Arc::clone(&tls_active),
&config,
);
let webapp = apply_security_layers(webapp, tls_active);

Expand Down
3 changes: 2 additions & 1 deletion crates/defguard_core/tests/integration/api/proxy_certs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async fn make_test_client_with_proxy_rx(
.await
.expect("Could not bind ephemeral socket");
let port = listener.local_addr().unwrap().port();
let _config = init_config(Some(&format!("http://localhost:{port}")), &pool).await;
let config = init_config(Some(&format!("http://localhost:{port}")), &pool).await;
initialize_users(&pool).await;
initialize_current_settings(&pool)
.await
Expand Down Expand Up @@ -150,6 +150,7 @@ async fn make_test_client_with_proxy_rx(
Arc::default(),
proxy_control_tx,
Arc::new(AtomicBool::new(false)),
&config,
);

let client = TestClient::new(webapp, listener, api_event_rx);
Expand Down
Loading