diff --git a/Cargo.lock b/Cargo.lock index bcb2038..d784e79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1078,6 +1078,7 @@ dependencies = [ "futures", "hex", "http-body-util", + "httpdate", "hyper", "hyper-util", "itertools 0.12.1", @@ -1113,6 +1114,7 @@ dependencies = [ "tracing-subscriber", "tracing-throttle", "urlencoding", + "webpki-roots 0.26.11", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index b692177..df9d029 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,8 @@ ws = [ "dep:hyper", "dep:hyper-util", "dep:http-body-util", + "dep:httpdate", + "dep:webpki-roots", # "dep:", ] database = [ @@ -77,7 +79,7 @@ log_throttling = ["dep:tracing-throttle"] [dependencies] serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = { version = "1.0", features = ["raw_value"] } eyre = "0.6" convert_case = "0.6" itertools = "0.12" @@ -127,10 +129,12 @@ regex = { version = "1.5", optional = true } rev_lines = { version = "0.3", optional = true } lazy_static = { version = "1.5", optional = true } async-trait = { version = "0.1", optional = true } -hyper = { version = "1", features = ["http1", "server"], optional = true } -hyper-util = { version = "0.1", features = ["tokio"], optional = true } +hyper = { version = "1", features = ["http1", "http2", "server", "client"], optional = true } +hyper-util = { version = "0.1", features = ["tokio", "server-auto"], optional = true } http-body-util = { version = "0.1", optional = true } +httpdate = { version = "1.0", optional = true } tracing-throttle = { version = "0.4", features = ["async"], optional = true } +webpki-roots = { version = "0.26", optional = true } # OpenTelemetry dependencies (default feature) opentelemetry = { version = "0.31", features = ["logs"] } diff --git a/src/libs/ws/basics.rs b/src/libs/ws/basics.rs index 9f5f558..a5d3a03 100644 --- a/src/libs/ws/basics.rs +++ b/src/libs/ws/basics.rs @@ -1,6 +1,8 @@ use parking_lot::RwLock; use serde::*; use serde_json::Value; +use serde_json::value::RawValue; +use std::collections::HashSet; use std::fmt::Debug; use std::net::SocketAddr; use std::sync::Arc; @@ -43,9 +45,8 @@ impl WsConnection { self.user_id.load(std::sync::atomic::Ordering::Acquire) } - pub fn get_roles(&self) -> Vec { - let roles = self.roles.read(); - roles.as_ref().clone() + pub fn get_roles(&self) -> Arc> { + self.roles.read().clone() } pub fn set_user_id(&self, user_id: u64) { @@ -59,8 +60,8 @@ impl WsConnection { } } -pub type WsSuccessResponse = WsSuccessResponseGeneric; -pub type WsStreamResponse = WsStreamResponseGeneric; +pub type WsSuccessResponse = WsSuccessResponseGeneric>; +pub type WsStreamResponse = WsStreamResponseGeneric>; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct WsForwardedResponse { @@ -101,11 +102,12 @@ pub enum WsResponseGeneric { Close, } -pub type WsResponseValue = WsResponseGeneric; +pub type WsResponseValue = WsResponseGeneric>; pub struct WsEndpoint { pub schema: EndpointSchema, pub handler: Arc, + pub allowed_roles: HashSet, } pub fn internal_error_to_resp( diff --git a/src/libs/ws/client.rs b/src/libs/ws/client.rs index e053333..d554b99 100644 --- a/src/libs/ws/client.rs +++ b/src/libs/ws/client.rs @@ -1,15 +1,26 @@ -use eyre::{Context, Result, bail, eyre}; +use std::net::SocketAddr; +use std::sync::Arc; + +use bytes::Bytes; +use eyre::{Context, Result, bail, ensure, eyre}; use futures::SinkExt; use futures::StreamExt; +use http_body_util::Empty; +use hyper::StatusCode; +use hyper::client::conn::http2; +use hyper_util::rt::{TokioExecutor, TokioIo}; use reqwest::header::HeaderValue; +use rustls::pki_types::ServerName; use serde::Serialize; use serde::de::DeserializeOwned; use tokio::net::TcpStream; +use tokio_rustls::TlsConnector; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::protocol::Role; use tracing::*; use crate::libs::log::LogLevel; @@ -28,11 +39,49 @@ pub trait WsRequest: Serialize + DeserializeOwned + Send + Sync + Clone { pub trait WsResponse: Serialize + DeserializeOwned + Send + Sync + Clone { type Request: WsRequest; } + +// --------------------------------------------------------------------------- +// Public types +// --------------------------------------------------------------------------- + +/// Which HTTP version to use when connecting. +#[derive(Debug, Clone, Copy, Default)] +pub enum WsVersionMode { + /// HTTP/1.1 upgrade handshake (existing behaviour). + #[default] + Http1Only, + /// HTTP/2 Extended CONNECT (RFC 8441): h2 for `wss://`, h2c for `ws://`. + Http2Only, + /// Try HTTP/2 first; fall back to HTTP/1.1 on any error. + Auto, +} + +/// Response metadata returned by [`WsClientBuilder::build`]. +pub struct WsConnectResponse { + pub status: u16, + pub headers: Vec<(String, String)>, +} + +// --------------------------------------------------------------------------- +// Internal stream abstraction +// --------------------------------------------------------------------------- + +enum WsStream { + H1(WebSocketStream>), + H2(WebSocketStream>), +} + +// --------------------------------------------------------------------------- +// WsClient +// --------------------------------------------------------------------------- + pub struct WsClient { - stream: WebSocketStream>, + stream: WsStream, seq: u32, } + impl WsClient { + // Existing HTTP/1.1 constructor — unchanged externally. pub async fn new( connect_addr: &str, protocol_header: &str, @@ -57,10 +106,38 @@ impl WsClient { .await .context("Failed to connect to endpoint")?; Ok((Self { - stream: ws_stream, + stream: WsStream::H1(ws_stream), seq: 0, }, response)) } + + // --- Private stream helpers ------------------------------------------- + + async fn stream_send(&mut self, msg: Message) -> Result<()> { + match &mut self.stream { + WsStream::H1(s) => s.send(msg).await?, + WsStream::H2(s) => s.send(msg).await?, + } + Ok(()) + } + + async fn stream_next(&mut self) -> Option> { + match &mut self.stream { + WsStream::H1(s) => s.next().await, + WsStream::H2(s) => s.next().await, + } + } + + async fn stream_close(&mut self) -> Result<()> { + match &mut self.stream { + WsStream::H1(s) => s.close(None).await?, + WsStream::H2(s) => s.close(None).await?, + } + Ok(()) + } + + // --- Public API (signatures unchanged) -------------------------------- + pub async fn send_req(&mut self, method: u32, params: impl Serialize) -> Result<()> { self.seq += 1; let req = serde_json::to_string(&WsRequestGeneric { @@ -69,32 +146,28 @@ impl WsClient { params, })?; debug!("send req: {}", req); - self.stream.send(Message::Text(req.into())).await?; - Ok(()) + self.stream_send(Message::Text(req.into())).await } + /// Send a fully pre-serialized request message. - /// The caller is responsible for wrapping params in the request envelope - /// (method, seq, params) and serializing to bytes ahead of time. - /// This avoids any allocation or serialization in the hot path. pub async fn send_raw(&mut self, request_bytes: &[u8]) -> Result<()> { let text = std::str::from_utf8(request_bytes).context("Invalid UTF-8 in request bytes")?; - self.stream.send(Message::Text(text.into())).await?; - Ok(()) + self.stream_send(Message::Text(text.into())).await } + pub async fn recv_raw(&mut self) -> Result { let msg = self - .stream - .next() + .stream_next() .await .ok_or(eyre!("Connection closed"))??; let resp: WsResponseValue = serde_json::from_str(&msg.to_string())?; Ok(resp) } + pub async fn recv_resp(&mut self) -> Result { loop { let msg = self - .stream - .next() + .stream_next() .await .ok_or(eyre!("Connection closed"))??; match msg { @@ -137,19 +210,273 @@ impl WsClient { } } Message::Close(_) => { - self.stream.close(None).await?; + self.stream_close().await?; bail!("Connection closed") } _ => {} } } } + pub async fn request(&mut self, params: T) -> Result { self.send_req(T::METHOD_ID, params).await?; self.recv_resp().await } + pub async fn close(mut self) -> Result<()> { - self.stream.close(None).await?; - Ok(()) + self.stream_close().await } } + +// --------------------------------------------------------------------------- +// WsClientBuilder +// --------------------------------------------------------------------------- + +pub struct WsClientBuilder { + mode: WsVersionMode, + protocol_header: String, + headers: Vec<(&'static str, &'static str)>, +} + +impl WsClientBuilder { + pub fn new() -> Self { + Self { + mode: WsVersionMode::Http1Only, + protocol_header: String::new(), + headers: Vec::new(), + } + } + + pub fn mode(mut self, mode: WsVersionMode) -> Self { + self.mode = mode; + self + } + + pub fn protocol_header(mut self, protocol: impl Into) -> Self { + self.protocol_header = protocol.into(); + self + } + + pub fn header(mut self, key: &'static str, value: &'static str) -> Self { + self.headers.push((key, value)); + self + } + + pub fn headers(mut self, headers: Vec<(&'static str, &'static str)>) -> Self { + self.headers.extend(headers); + self + } + + pub async fn build(self, connect_addr: &str) -> Result<(WsClient, WsConnectResponse)> { + match self.mode { + WsVersionMode::Http1Only => { + connect_h1(connect_addr, &self.protocol_header, &self.headers).await + } + WsVersionMode::Http2Only => { + connect_h2(connect_addr, &self.protocol_header, &self.headers).await + } + WsVersionMode::Auto => { + match connect_h2(connect_addr, &self.protocol_header, &self.headers).await { + Ok(result) => Ok(result), + Err(h2_err) => { + debug!("H2 connection failed ({}), falling back to HTTP/1.1", h2_err); + connect_h1(connect_addr, &self.protocol_header, &self.headers).await + } + } + } + } + } +} + +impl Default for WsClientBuilder { + fn default() -> Self { + Self::new() + } +} + +// --------------------------------------------------------------------------- +// Private helpers +// --------------------------------------------------------------------------- + +struct ParsedUrl { + tls: bool, + host: String, + port: u16, + path: String, +} + +fn parse_ws_url(url: &str) -> Result { + let (tls, rest) = if let Some(r) = url.strip_prefix("wss://") { + (true, r) + } else if let Some(r) = url.strip_prefix("ws://") { + (false, r) + } else { + bail!("URL must start with ws:// or wss://: {}", url) + }; + + let (authority, path) = match rest.find('/') { + Some(i) => (&rest[..i], rest[i..].to_owned()), + None => (rest, "/".to_owned()), + }; + + let (host, port) = match authority.rfind(':') { + Some(i) => { + let h = authority[..i].to_owned(); + let p: u16 = authority[i + 1..].parse().context("Invalid port in URL")?; + (h, p) + } + None => (authority.to_owned(), if tls { 443 } else { 80 }), + }; + + Ok(ParsedUrl { tls, host, port, path }) +} + +async fn connect_h1( + addr: &str, + protocol_header: &str, + headers: &[(&'static str, &'static str)], +) -> Result<(WsClient, WsConnectResponse)> { + let mut req = <&str as IntoClientRequest>::into_client_request(addr)?; + if !protocol_header.is_empty() { + req.headers_mut().insert( + "Sec-WebSocket-Protocol", + HeaderValue::from_str(protocol_header)?, + ); + } + for (k, v) in headers { + req.headers_mut().insert(*k, HeaderValue::from_str(v)?); + } + + let (ws_stream, response) = connect_async(req) + .await + .context("Failed to connect to endpoint")?; + + let conn_resp = WsConnectResponse { + status: response.status().as_u16(), + headers: response + .headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(), + }; + + Ok((WsClient { stream: WsStream::H1(ws_stream), seq: 0 }, conn_resp)) +} + +async fn connect_h2( + addr: &str, + protocol_header: &str, + headers: &[(&'static str, &'static str)], +) -> Result<(WsClient, WsConnectResponse)> { + let ParsedUrl { tls, host, port, path } = parse_ws_url(addr)?; + + let sock_addr: SocketAddr = tokio::net::lookup_host(format!("{}:{}", host, port)) + .await + .context("DNS resolution failed")? + .next() + .ok_or_else(|| eyre!("No addresses returned for {}:{}", host, port))?; + + let tcp = TcpStream::connect(sock_addr) + .await + .context("TCP connect failed")?; + tcp.set_nodelay(true)?; + + if tls { + let tls_stream = make_tls_stream(tcp, &host).await?; + h2_upgrade(TokioIo::new(tls_stream), &host, &path, tls, protocol_header, headers).await + } else { + h2_upgrade(TokioIo::new(tcp), &host, &path, tls, protocol_header, headers).await + } +} + +async fn h2_upgrade( + io: T, + host: &str, + path: &str, + tls: bool, + protocol_header: &str, + headers: &[(&'static str, &'static str)], +) -> Result<(WsClient, WsConnectResponse)> +where + T: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, +{ + let (mut sender, conn) = http2::Builder::new(TokioExecutor::new()) + .handshake(io) + .await + .context("HTTP/2 handshake failed")?; + tokio::spawn(async move { + if let Err(e) = conn.await { + debug!("H2 connection driver exited: {}", e); + } + }); + + let scheme = if tls { "https" } else { "http" }; + let mut builder = hyper::Request::builder() + .method(hyper::Method::CONNECT) + .uri(format!("{}://{}{}", scheme, host, path)) + .header("sec-websocket-version", "13"); + if !protocol_header.is_empty() { + builder = builder.header("sec-websocket-protocol", protocol_header); + } + for (k, v) in headers { + builder = builder.header(*k, *v); + } + let mut request = builder + .body(Empty::::new()) + .context("Failed to build H2 upgrade request")?; + + // :protocol pseudo-header must be set as an extension, not a raw header + request + .extensions_mut() + .insert(hyper::ext::Protocol::from_static("websocket")); + + // Capture the upgrade future BEFORE sending (stores the oneshot sender in request extensions) + let on_upgrade = hyper::upgrade::on(&mut request); + + let response = sender + .send_request(request) + .await + .context("Failed to send H2 upgrade request")?; + + ensure!( + response.status() == StatusCode::OK, + "H2 WebSocket upgrade rejected: {}", + response.status() + ); + + let conn_resp = WsConnectResponse { + status: response.status().as_u16(), + headers: response + .headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(), + }; + + let upgraded = on_upgrade.await.context("H2 upgrade failed")?; + let ws = + WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Client, None).await; + + Ok((WsClient { stream: WsStream::H2(ws), seq: 0 }, conn_resp)) +} + +async fn make_tls_stream( + tcp: TcpStream, + host: &str, +) -> Result> { + let mut root_store = rustls::RootCertStore::empty(); + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + + let mut tls_config = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + tls_config.alpn_protocols = vec![b"h2".to_vec()]; + + let connector = TlsConnector::from(Arc::new(tls_config)); + let server_name = + ServerName::try_from(host.to_owned()).context("Invalid TLS server name")?; + connector + .connect(server_name, tcp) + .await + .context("TLS handshake failed") +} diff --git a/src/libs/ws/listener.rs b/src/libs/ws/listener.rs index dc7d52b..9f59e67 100644 --- a/src/libs/ws/listener.rs +++ b/src/libs/ws/listener.rs @@ -66,10 +66,11 @@ impl TlsListener { let key = load_private_key(&priv_cert)?; let tls_cfg = { - let cfg = + let mut cfg = rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) .with_no_client_auth() .with_single_cert(certs, key)?; + cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; Arc::new(cfg) }; let acceptor = TlsAcceptor::from(tls_cfg); diff --git a/src/libs/ws/push.rs b/src/libs/ws/push.rs index 182f536..05ebabc 100644 --- a/src/libs/ws/push.rs +++ b/src/libs/ws/push.rs @@ -73,7 +73,7 @@ impl> SubscribeManager { filter: impl Fn(&RequestContext) -> bool, ) { if let Some(mut topic_2) = self.topics.get_mut(&topic) { - let data = serde_json::to_value(msg).unwrap(); + let data = serde_json::value::to_raw_value(msg).expect("Failed to serialize stream data"); let mut dead_connections = vec![]; let stream_code = topic.into(); for sub in topic_2.subscribers.values_mut() { diff --git a/src/libs/ws/server.rs b/src/libs/ws/server.rs index e4ba508..5cabd2e 100644 --- a/src/libs/ws/server.rs +++ b/src/libs/ws/server.rs @@ -2,10 +2,10 @@ use eyre::{ContextCompat, Result, bail, eyre}; use http_body_util::Empty; use hyper::body::{Bytes, Incoming}; use hyper::header::{CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, UPGRADE}; -use hyper::server::conn::http1; +use hyper::header::HeaderValue; use hyper::service::service_fn; -use hyper::{Request, Response, StatusCode}; -use hyper_util::rt::TokioIo; +use hyper::{Method, Request, Response, StatusCode, Version}; +use hyper_util::rt::{TokioExecutor, TokioIo}; use itertools::Itertools; use parking_lot::RwLock; use serde::{Deserialize, Serialize}; @@ -36,13 +36,18 @@ use crate::model::EndpointSchema; use super::{AuthController, ConnectionId, SimpleAuthController, WebsocketStates, WsEndpoint}; +static HDR_UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); +static HDR_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); +static HDR_SERVER: HeaderValue = HeaderValue::from_static("RustWebsocketServer/1.0"); +static HDR_CREDENTIALS_TRUE: HeaderValue = HeaderValue::from_static("true"); + pub struct WebsocketServer { pub auth_controller: Arc, pub handlers: HashMap, - pub allowed_roles: HashMap>, - pub message_receiver: Option>, + pub message_receiver: parking_lot::Mutex>>, pub toolbox: ArcToolbox, pub config: WsServerConfig, + pub cached_date: RwLock, } @@ -53,10 +58,14 @@ impl WebsocketServer { } Self { auth_controller: Arc::new(SimpleAuthController), - allowed_roles: HashMap::new(), handlers: Default::default(), - message_receiver: None, + message_receiver: parking_lot::Mutex::new(None), toolbox: Toolbox::new(), + cached_date: RwLock::new( + httpdate::fmt_http_date(std::time::SystemTime::now()) + .parse() + .unwrap(), + ), config, } } @@ -77,11 +86,9 @@ impl WebsocketServer { ) { let roles_set = roles.iter().cloned().collect::>(); - let _old_roles = self.allowed_roles.insert(schema.code, roles_set); - let old = self .handlers - .insert(schema.code, WsEndpoint { schema, handler }); + .insert(schema.code, WsEndpoint { schema, handler, allowed_roles: roles_set }); if let Some(old) = old { panic!( "Overwriting handler for endpoint {} {}", @@ -103,28 +110,50 @@ impl WebsocketServer { let this = Arc::clone(&self); let states = Arc::clone(&states); async move { - let is_upgrade = req - .headers() - .get(UPGRADE) - .and_then(|v| v.to_str().ok()) - .map(|v| v.eq_ignore_ascii_case("websocket")) - .unwrap_or(false); - let key = req - .headers() - .get(SEC_WEBSOCKET_KEY) - .map(|k| k.as_bytes().to_vec()); - - if !is_upgrade || key.is_none() { - let mut resp = Response::new(Empty::::new()); - *resp.status_mut() = StatusCode::BAD_REQUEST; - return Ok::<_, Infallible>(resp); - } - let derived = derive_accept_key(&key.unwrap()); + let is_http2 = req.version() == Version::HTTP_2; + + // Validate the upgrade request based on HTTP version. + // HTTP/1.1: GET + Upgrade: websocket + Sec-WebSocket-Key + // HTTP/2: CONNECT + :protocol = websocket (RFC 8441 / RFC 9113 §8.5) + let derived = if is_http2 { + if req.method() != Method::CONNECT { + let mut resp = Response::new(Empty::::new()); + *resp.status_mut() = StatusCode::METHOD_NOT_ALLOWED; + return Ok::<_, Infallible>(resp); + } + let proto_ok = req + .extensions() + .get::() + .map_or(false, |p| p.as_str() == "websocket"); + if !proto_ok { + let mut resp = Response::new(Empty::::new()); + *resp.status_mut() = StatusCode::BAD_REQUEST; + return Ok::<_, Infallible>(resp); + } + None + } else { + let is_upgrade = req + .headers() + .get(UPGRADE) + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false); + if !is_upgrade { + let mut resp = Response::new(Empty::::new()); + *resp.status_mut() = StatusCode::BAD_REQUEST; + return Ok::<_, Infallible>(resp); + } + let Some(key) = req.headers().get(SEC_WEBSOCKET_KEY) else { + let mut resp = Response::new(Empty::::new()); + *resp.status_mut() = StatusCode::BAD_REQUEST; + return Ok::<_, Infallible>(resp); + }; + Some(derive_accept_key(key.as_bytes())) + }; let protocol = req .headers() .get("Sec-WebSocket-Protocol") - .or_else(|| req.headers().get("sec-websocket-protocol")) .and_then(|v| v.to_str().ok()) .unwrap_or("") .to_string(); @@ -165,14 +194,16 @@ impl WebsocketServer { } }); + // HTTP/2: respond 200 OK (RFC 9113 §8.5); no 101 or Sec-WebSocket-Accept. + // HTTP/1.1: respond 101 Switching Protocols with the derived accept key. let mut resp = Response::new(Empty::::new()); - *resp.status_mut() = StatusCode::SWITCHING_PROTOCOLS; - resp.headers_mut() - .append(CONNECTION, "upgrade".parse().unwrap()); - resp.headers_mut() - .append(UPGRADE, "websocket".parse().unwrap()); - resp.headers_mut() - .append(SEC_WEBSOCKET_ACCEPT, derived.parse().unwrap()); + if let Some(derived) = derived { + *resp.status_mut() = StatusCode::SWITCHING_PROTOCOLS; + resp.headers_mut().append(CONNECTION, HDR_UPGRADE.clone()); + resp.headers_mut().append(UPGRADE, HDR_WEBSOCKET.clone()); + resp.headers_mut() + .append(SEC_WEBSOCKET_ACCEPT, derived.parse().unwrap()); + } if !protocol.is_empty() { let first = protocol.split(',').next().unwrap_or("").trim(); @@ -190,26 +221,23 @@ impl WebsocketServer { if let Ok(v) = origin.parse::() { resp.headers_mut() .append("Access-Control-Allow-Origin", v); - resp.headers_mut().append( - "Access-Control-Allow-Credentials", - "true".parse().unwrap(), - ); + resp.headers_mut() + .append("Access-Control-Allow-Credentials", HDR_CREDENTIALS_TRUE.clone()); } } } - resp.headers_mut() - .append("Date", chrono::Utc::now().to_rfc2822().parse().unwrap()); - resp.headers_mut() - .append("Server", "RustWebsocketServer/1.0".parse().unwrap()); + resp.headers_mut().append("Date", this.cached_date.read().clone()); + resp.headers_mut().append("Server", HDR_SERVER.clone()); Ok::<_, Infallible>(resp) } }); - http1::Builder::new() - .serve_connection(io, service) - .with_upgrades() + let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + builder.http2().enable_connect_protocol(); + builder + .serve_connection_with_upgrades(io, service) .await .map_err(|e| eyre!(e)) } @@ -308,47 +336,100 @@ impl WebsocketServer { async fn listen_impl(self, listener: Arc) -> Result<()> { let states = Arc::new(WebsocketStates::new()); - self.toolbox - .set_ws_states(states.clone_states(), self.config.header_only); let this = Arc::new(self); - let local_set = LocalSet::new(); - let (mut sigterm, mut sigint) = crate::libs::signal::init_signals()?; - local_set - .run_until(async { + this.toolbox + .set_ws_states(states.clone_states(), this.config.header_only); + + let num_shards = shard_count(); + info!("Starting {} WebSocket shards", num_shards); + + let mut shard_senders = Vec::with_capacity(num_shards); + for _ in 0..num_shards { + let (tx, rx) = mpsc::channel::<(T::Channel1, SocketAddr)>(256); + let this = Arc::clone(&this); + let states = Arc::clone(&states); + let listener = Arc::clone(&listener); + std::thread::spawn(move || { + WebsocketServer::run_shard(this, states, listener, rx); + }); + shard_senders.push(tx); + } + + // Date cache updater — runs on the multi-thread scheduler, no LocalSet needed. + tokio::spawn({ + let this = Arc::clone(&this); + async move { loop { - tokio::select! { - _ = crate::libs::signal::wait_for_signals(&mut sigterm, &mut sigint) => break, - accepted = listener.accept() => { - let (stream, addr) = match accepted { - Ok(x) => x, - Err(err) => { - error!("Error while accepting stream: {:?}", err); - continue; - } - }; - let listener = Arc::clone(&listener); - let this = Arc::clone(&this); - let states = Arc::clone(&states); - local_set.spawn_local(async move { - let stream = match listener.handshake(stream).await { - Ok(channel) => { - debug!("Accepted stream from {}", addr); - channel - } - Err(err) => { - error!("Error while handshaking stream: {:?}", err); - return; - } - }; - - let _ = TOOLBOX.scope(this.toolbox.clone(), this.handle_ws_handshake_and_connection(addr, states, stream)).await; - }); + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + *this.cached_date.write() = + httpdate::fmt_http_date(std::time::SystemTime::now()) + .parse() + .unwrap(); + } + } + }); + + let (mut sigterm, mut sigint) = crate::libs::signal::init_signals()?; + let mut shard_idx: usize = 0; + loop { + tokio::select! { + _ = crate::libs::signal::wait_for_signals(&mut sigterm, &mut sigint) => break, + accepted = listener.accept() => { + let (stream, addr) = match accepted { + Ok(x) => x, + Err(err) => { + error!("Error while accepting stream: {:?}", err); + continue; } + }; + let shard = &shard_senders[shard_idx % num_shards]; + shard_idx = shard_idx.wrapping_add(1); + if shard.send((stream, addr)).await.is_err() { + error!("Shard channel closed unexpectedly for addr {}", addr); } } - Ok(()) - }) - .await + } + } + + Ok(()) + } + + fn run_shard( + this: Arc, + states: Arc, + listener: Arc, + mut rx: mpsc::Receiver<(T::Channel1, SocketAddr)>, + ) { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("Failed to build shard runtime"); + let local_set = LocalSet::new(); + rt.block_on(local_set.run_until(async move { + while let Some((stream, addr)) = rx.recv().await { + let this = Arc::clone(&this); + let states = Arc::clone(&states); + let listener = Arc::clone(&listener); + tokio::task::spawn_local(async move { + let stream = match listener.handshake(stream).await { + Ok(channel) => { + debug!("Accepted stream from {}", addr); + channel + } + Err(err) => { + error!("Error while handshaking stream: {:?}", err); + return; + } + }; + let _ = TOOLBOX + .scope( + this.toolbox.clone(), + this.handle_ws_handshake_and_connection(addr, states, stream), + ) + .await; + }); + } + })); } pub fn dump_schemas(&self) -> Result<()> { @@ -370,6 +451,61 @@ impl WebsocketServer { } } +/// Determine the number of WebSocket shards to spawn. +/// +/// Resolution order (first match wins): +/// 1. `WS_SHARDS` environment variable — explicit operator override. +/// 2. cgroup v1 CPU quota — handles older Docker/k8s deployments that set +/// `cpu.cfs_quota_us` / `cpu.cfs_period_us` but do not use cgroup v2. +/// 3. `std::thread::available_parallelism` — reads cgroup v2 `cpu.max` and +/// `sched_getaffinity` on Linux (Rust 1.74+), logical CPU count elsewhere. +/// 4. Hard fallback of 1 if all detection fails. +fn shard_count() -> usize { + // 1. Explicit override. + if let Ok(val) = std::env::var("WS_SHARDS") { + if let Ok(n) = val.trim().parse::() { + if n > 0 { + return n; + } + } + warn!("WS_SHARDS env var set but invalid, ignoring: {:?}", val); + } + + // 2. cgroup v1 quota (common in older Docker / k8s). + if let Some(n) = read_cgroup_v1_quota() { + return n.max(1); + } + + // 3. stdlib — cgroup v2 + affinity-aware on Linux (Rust 1.74+). + std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(1) +} + +/// Read the cgroup v1 CPU quota and convert it to a thread count. +/// Returns `None` if the files are absent, unparseable, or the quota is +/// unlimited (quota == -1). +fn read_cgroup_v1_quota() -> Option { + let quota: i64 = std::fs::read_to_string("/sys/fs/cgroup/cpu/cpu.cfs_quota_us") + .ok()? + .trim() + .parse() + .ok()?; + if quota <= 0 { + return None; // -1 means no limit + } + let period: i64 = std::fs::read_to_string("/sys/fs/cgroup/cpu/cpu.cfs_period_us") + .ok()? + .trim() + .parse() + .ok()?; + if period <= 0 { + return None; + } + // Ceiling division: round up so a 1.5-CPU quota gives 2 shards. + Some(((quota + period - 1) / period) as usize) +} + pub fn wrap_ws_error(err: Result) -> Result { err.map_err(|x| eyre!(x)) } diff --git a/src/libs/ws/session.rs b/src/libs/ws/session.rs index 459643e..80d4ab3 100644 --- a/src/libs/ws/session.rs +++ b/src/libs/ws/session.rs @@ -100,10 +100,10 @@ impl< context.seq = req.seq; context.method = req.method; context.user_id = self.conn_info.get_user_id(); - context.roles = Arc::new(self.conn_info.get_roles()); + context.roles = self.conn_info.get_roles(); // Check roles - let Some(allowed_roles) = self.server.allowed_roles.get(&req.method) else { + let Some(endpoint) = self.server.handlers.get(&req.method) else { self.server.toolbox.send( context.connection_id, request_error_to_resp(&context, ErrorCode::NOT_IMPLEMENTED, Value::Null), @@ -111,8 +111,7 @@ impl< return Ok(true); }; - let allowed = check_roles(&context.roles, allowed_roles); - if !allowed { + if !check_roles(&context.roles, &endpoint.allowed_roles) { self.server.toolbox.send( context.connection_id, request_error_to_resp(&context, ErrorCode::FORBIDDEN, "Forbidden"), @@ -120,18 +119,7 @@ impl< return Ok(true); } - let handler = self.server.handlers.get(&req.method); - let handler = match handler { - Some(handler) => handler, - None => { - self.server.toolbox.send( - context.connection_id, - request_error_to_resp(&context, ErrorCode::NOT_IMPLEMENTED, Value::Null), - ); - return Ok(true); - } - }; - let handler = handler.handler.clone(); + let handler = endpoint.handler.clone(); let toolbox = self.server.toolbox.clone(); tokio::task::spawn_local(async move { TOOLBOX @@ -147,6 +135,14 @@ impl< async fn run_loop(&mut self) -> Result<()> { let conn_id = self.conn_info.connection_id; loop { + // Drain all pending outbound messages before blocking on new events. + while let Ok(msg) = self.rx.try_recv() { + self.send_message(msg).await?; + if self.server.config.header_only { + return Ok(()); + } + } + tokio::select! { msg = self.rx.recv() => { // info!(?conn_id, ?msg, "Received message to send"); @@ -185,15 +181,10 @@ impl< } fn check_roles(actual_roles: &[u32], allowed_roles: &HashSet) -> bool { - if allowed_roles.is_empty() { - return false; // No roles are allowed - } - for role in actual_roles.iter() { - if allowed_roles.contains(role) { - return true; // At least one role is allowed - } + if allowed_roles.is_empty() || actual_roles.is_empty() { + return false; } - false // No roles matched + actual_roles.iter().any(|role| allowed_roles.contains(role)) } #[cfg(test)] diff --git a/src/libs/ws/subs.rs b/src/libs/ws/subs.rs index 2f6dbef..f64983b 100644 --- a/src/libs/ws/subs.rs +++ b/src/libs/ws/subs.rs @@ -114,7 +114,7 @@ impl SubscriptionManager { return; }; - let data = serde_json::to_value(msg).unwrap(); + let data = serde_json::value::to_raw_value(msg).expect("Failed to serialize stream data"); let msg = WsResponseGeneric::Stream(WsStreamResponseGeneric { original_seq: sub.ctx.seq, @@ -174,7 +174,7 @@ impl SubscriptionManager { let Some(data) = filter(sub) else { continue; }; - let data = serde_json::to_value(&data).unwrap(); + let data = serde_json::value::to_raw_value(&data).expect("Failed to serialize stream data"); let msg = WsResponseGeneric::Stream(WsStreamResponseGeneric { original_seq: sub.ctx.seq, method: sub.ctx.method, diff --git a/src/libs/ws/toolbox.rs b/src/libs/ws/toolbox.rs index 839836e..548f5b8 100644 --- a/src/libs/ws/toolbox.rs +++ b/src/libs/ws/toolbox.rs @@ -1,11 +1,10 @@ use dashmap::DashMap; use eyre::{Context, Result}; -use parking_lot::RwLock; use serde::*; use serde_json::Value; use std::fmt::{Debug, Display, Formatter}; use std::net::{IpAddr, Ipv4Addr}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use tokio_tungstenite::tungstenite::Message; use tracing::*; @@ -99,14 +98,17 @@ impl RequestContext { } } +type SendFn = dyn Fn(ConnectionId, WsResponseValue) -> bool + Send + Sync; +type SendFnArc = Arc; + pub struct Toolbox { - pub send_msg: RwLock bool + Send + Sync>>, + pub send_msg: OnceLock, } pub type ArcToolbox = Arc; impl Toolbox { pub fn new() -> Arc { Arc::new(Self { - send_msg: RwLock::new(Arc::new(|_conn_id, _msg| false)), + send_msg: OnceLock::new(), }) } @@ -115,7 +117,7 @@ impl Toolbox { states: Arc>>, oneshot: bool, ) { - *self.send_msg.write() = Arc::new(move |conn_id, msg| { + let send_fn: SendFnArc = Arc::new(move |conn_id, msg| { let state = if let Some(state) = states.get(&conn_id) { state } else { @@ -124,6 +126,9 @@ impl Toolbox { Self::send_ws_msg(&state.message_queue, msg, oneshot); true }); + if self.send_msg.set(send_fn).is_err() { + panic!("set_ws_states called twice"); + } } pub fn send_ws_msg( @@ -132,15 +137,24 @@ impl Toolbox { oneshot: bool, ) { let resp = serde_json::to_string(&resp).unwrap(); - if let Err(err) = sender.try_send(resp.into()) { - warn!("Failed to send websocket message: {:?}", err) + match sender.try_send(resp.into()) { + Ok(()) => {} + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + error!("WebSocket send buffer full — client is too slow or disconnected"); + } + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { + debug!("WebSocket send channel closed (client already disconnected)"); + } } if oneshot { let _ = sender.try_send(Message::Close(None)); } } pub fn send(&self, conn_id: ConnectionId, resp: WsResponseValue) -> bool { - self.send_msg.read()(conn_id, resp) + match self.send_msg.get() { + Some(f) => f(conn_id, resp), + None => false, + } } pub fn send_response(&self, ctx: &RequestContext, resp: impl Serialize) { self.send( @@ -148,7 +162,7 @@ impl Toolbox { WsResponseValue::Immediate(WsSuccessResponse { method: ctx.method, seq: ctx.seq, - params: serde_json::to_value(&resp).unwrap(), + params: serde_json::value::to_raw_value(&resp).expect("Failed to serialize response"), }), ); } @@ -186,7 +200,7 @@ impl Toolbox { Ok(ok) => WsResponseValue::Immediate(WsSuccessResponse { method, seq, - params: serde_json::to_value(ok).expect("Failed to serialize response"), + params: serde_json::value::to_raw_value(&ok).expect("Failed to serialize response"), }), Err(err) if err.is::() => { return None;