Skip to content
Open
11 changes: 10 additions & 1 deletion engine/artifacts/config-schema.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions engine/packages/config/src/config/pegboard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
pub gateway_websocket_open_timeout_ms: Option<u64>,
/// Timeout for response to start in milliseconds.
pub gateway_response_start_timeout_ms: Option<u64>,
/// Timeout between streaming HTTP response chunks in milliseconds.
pub gateway_response_chunk_idle_timeout_ms: Option<u64>,
/// Ping interval for gateway updates in milliseconds.
pub gateway_update_ping_interval_ms: Option<u64>,
/// GC interval for in-flight requests in milliseconds.
Expand Down Expand Up @@ -278,8 +280,12 @@
pub fn gateway_response_start_timeout_ms(&self) -> u64 {
self.gateway_response_start_timeout_ms
.unwrap_or(5 * 60 * 1000)
}

Check warning on line 283 in engine/packages/config/src/config/pegboard.rs

View workflow job for this annotation

GitHub Actions / Rustfmt

Diff in /home/runner/work/rivet/rivet/engine/packages/config/src/config/pegboard.rs

pub fn gateway_response_chunk_idle_timeout_ms(&self) -> u64 {
self.gateway_response_chunk_idle_timeout_ms.unwrap_or(30_000)
}

pub fn gateway_update_ping_interval_ms(&self) -> u64 {
self.gateway_update_ping_interval_ms.unwrap_or(3_000)
}
Expand Down
17 changes: 16 additions & 1 deletion engine/packages/guard-core/src/custom_serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::{Result, bail};
use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::Full;
use hyper::{Request, Response};
use hyper::{Request, Response, body::Incoming as BodyIncoming};
use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;

use crate::WebSocketHandle;
Expand All @@ -17,13 +17,28 @@ pub enum HibernationResult {
/// Trait for custom request serving logic that can handle both HTTP and WebSocket requests
#[async_trait]
pub trait CustomServeTrait: Send + Sync {
/// Returns true when this service wants the original request body stream.
/// The default buffered path keeps retry semantics for existing custom routes.
fn streams_request_body(&self) -> bool {
false
}

/// Handle a regular HTTP request
async fn handle_request(
&self,
req: Request<Full<Bytes>>,
req_ctx: &mut RequestContext,
) -> Result<Response<ResponseBody>>;

/// Handle a regular HTTP request with the original inbound body stream.
async fn handle_streaming_request(
&self,
_req: Request<BodyIncoming>,
_req_ctx: &mut RequestContext,
) -> Result<Response<ResponseBody>> {
bail!("service does not support streaming request bodies");
}

/// Handle a WebSocket connection after upgrade. Supports connection retries.
async fn handle_websocket(
&self,
Expand Down
4 changes: 4 additions & 0 deletions engine/packages/guard-core/src/proxy_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,10 @@ impl ProxyService {
.build());
}
ResolveRouteOutput::CustomServe(mut handler) => {
if handler.streams_request_body() {
return handler.handle_streaming_request(req, req_ctx).await;
}

// Collect request body
let (req_parts, body) = req.into_parts();
let req_body =
Expand Down
17 changes: 16 additions & 1 deletion engine/packages/guard-core/src/response_body.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use bytes::Bytes;
use http_body_util::Full;
use hyper::body::Incoming as BodyIncoming;
use tokio::sync::mpsc;

pub type ResponseBodyError = Box<dyn std::error::Error + Send + Sync>;

/// Response body type that can handle both streaming and buffered responses
#[derive(Debug)]
Expand All @@ -9,11 +12,13 @@ pub enum ResponseBody {
Full(Full<Bytes>),
/// Streaming response body
Incoming(BodyIncoming),
/// Channel-backed streaming response body
Channel(mpsc::Receiver<Result<Bytes, ResponseBodyError>>),
}

impl http_body::Body for ResponseBody {
type Data = Bytes;
type Error = Box<dyn std::error::Error + Send + Sync>;
type Error = ResponseBodyError;

fn poll_frame(
self: std::pin::Pin<&mut Self>,
Expand Down Expand Up @@ -46,20 +51,30 @@ impl http_body::Body for ResponseBody {
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
ResponseBody::Channel(rx) => match rx.poll_recv(cx) {
std::task::Poll::Ready(Some(Ok(bytes))) => {
std::task::Poll::Ready(Some(Ok(http_body::Frame::data(bytes))))
}
std::task::Poll::Ready(Some(Err(err))) => std::task::Poll::Ready(Some(Err(err))),
std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
std::task::Poll::Pending => std::task::Poll::Pending,
},
}
}

fn is_end_stream(&self) -> bool {
match self {
ResponseBody::Full(body) => body.is_end_stream(),
ResponseBody::Incoming(body) => body.is_end_stream(),
ResponseBody::Channel(rx) => rx.is_closed() && rx.is_empty(),
}
}

fn size_hint(&self) -> http_body::SizeHint {
match self {
ResponseBody::Full(body) => body.size_hint(),
ResponseBody::Incoming(body) => body.size_hint(),
ResponseBody::Channel(_) => http_body::SizeHint::default(),
}
}
}
30 changes: 30 additions & 0 deletions engine/packages/guard-core/tests/response_body.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use bytes::Bytes;
use http_body_util::BodyExt;
use rivet_guard_core::ResponseBody;
use tokio::sync::mpsc;

#[tokio::test]
async fn channel_body_yields_sent_chunks() {
let (tx, rx) = mpsc::channel(2);
tx.send(Ok(Bytes::from_static(b"hello "))).await.unwrap();
tx.send(Ok(Bytes::from_static(b"world"))).await.unwrap();
drop(tx);

let collected = ResponseBody::Channel(rx).collect().await.unwrap();

assert_eq!(collected.to_bytes(), Bytes::from_static(b"hello world"));
}

#[tokio::test]
async fn channel_body_surfaces_errors() {
let (tx, rx) = mpsc::channel(1);
tx.send(Err(std::io::Error::other("stream failed").into()))
.await
.unwrap();
drop(tx);

let mut body = ResponseBody::Channel(rx);
let frame = body.frame().await.expect("expected frame");

assert!(frame.is_err());
}
4 changes: 2 additions & 2 deletions engine/packages/pegboard-envoy/src/tunnel_to_ws_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ fn to_envoy_tunnel_message_kind_name(kind: &protocol::ToEnvoyTunnelMessageKind)
match kind {
protocol::ToEnvoyTunnelMessageKind::ToEnvoyRequestStart(_) => "ToEnvoyRequestStart",
protocol::ToEnvoyTunnelMessageKind::ToEnvoyRequestChunk(_) => "ToEnvoyRequestChunk",
protocol::ToEnvoyTunnelMessageKind::ToEnvoyRequestAbort => "ToEnvoyRequestAbort",
protocol::ToEnvoyTunnelMessageKind::ToEnvoyRequestAbort(_) => "ToEnvoyRequestAbort",
protocol::ToEnvoyTunnelMessageKind::ToEnvoyWebSocketOpen(_) => "ToEnvoyWebSocketOpen",
protocol::ToEnvoyTunnelMessageKind::ToEnvoyWebSocketMessage(_) => "ToEnvoyWebSocketMessage",
protocol::ToEnvoyTunnelMessageKind::ToEnvoyWebSocketClose(_) => "ToEnvoyWebSocketClose",
Expand All @@ -272,7 +272,7 @@ fn to_envoy_tunnel_message_inner_data_len(kind: &protocol::ToEnvoyTunnelMessageK
}
protocol::ToEnvoyTunnelMessageKind::ToEnvoyRequestChunk(msg) => msg.body.len(),
protocol::ToEnvoyTunnelMessageKind::ToEnvoyWebSocketMessage(msg) => msg.data.len(),
protocol::ToEnvoyTunnelMessageKind::ToEnvoyRequestAbort
protocol::ToEnvoyTunnelMessageKind::ToEnvoyRequestAbort(_)
| protocol::ToEnvoyTunnelMessageKind::ToEnvoyWebSocketOpen(_)
| protocol::ToEnvoyTunnelMessageKind::ToEnvoyWebSocketClose(_) => 0,
}
Expand Down
4 changes: 2 additions & 2 deletions engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ fn tunnel_message_kind_name(kind: &protocol::ToRivetTunnelMessageKind) -> &'stat
match kind {
ToRivetTunnelMessageKind::ToRivetResponseStart(_) => "ToRivetResponseStart",
ToRivetTunnelMessageKind::ToRivetResponseChunk(_) => "ToRivetResponseChunk",
ToRivetTunnelMessageKind::ToRivetResponseAbort => "ToRivetResponseAbort",
ToRivetTunnelMessageKind::ToRivetResponseAbort(_) => "ToRivetResponseAbort",
ToRivetTunnelMessageKind::ToRivetWebSocketOpen(_) => "ToRivetWebSocketOpen",
ToRivetTunnelMessageKind::ToRivetWebSocketMessage(_) => "ToRivetWebSocketMessage",
ToRivetTunnelMessageKind::ToRivetWebSocketMessageAck(_) => "ToRivetWebSocketMessageAck",
Expand Down Expand Up @@ -2067,7 +2067,7 @@ fn tunnel_message_inner_data_len(kind: &protocol::ToRivetTunnelMessageKind) -> u
}
ToRivetTunnelMessageKind::ToRivetResponseChunk(chunk) => chunk.body.len(),
ToRivetTunnelMessageKind::ToRivetWebSocketMessage(msg) => msg.data.len(),
ToRivetTunnelMessageKind::ToRivetResponseAbort
ToRivetTunnelMessageKind::ToRivetResponseAbort(_)
| ToRivetTunnelMessageKind::ToRivetWebSocketOpen(_)
| ToRivetTunnelMessageKind::ToRivetWebSocketMessageAck(_)
| ToRivetTunnelMessageKind::ToRivetWebSocketClose(_) => 0,
Expand Down
6 changes: 3 additions & 3 deletions engine/packages/pegboard-gateway2/src/hibernation_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::sync::{
use tokio::sync::{mpsc, watch};
use tokio_tungstenite::tungstenite::Message;

use crate::shared_state::{InFlightRequestHandle, MsgGcReason};
use crate::shared_state::{InFlightRequestHandle, InFlightTunnelMessage, MsgGcReason};

use super::HibernationLifecycleResult;

Expand All @@ -26,7 +26,7 @@ pub async fn task(
in_flight_req: InFlightRequestHandle,
ctx: StandaloneCtx,
actor_id: Id,
mut msg_rx: mpsc::UnboundedReceiver<protocol::ToRivetTunnelMessageKind>,
mut msg_rx: mpsc::UnboundedReceiver<InFlightTunnelMessage>,
mut drop_rx: watch::Receiver<Option<MsgGcReason>>,
egress_bytes: Arc<AtomicU64>,
mut hibernation_abort_rx: watch::Receiver<()>,
Expand Down Expand Up @@ -55,7 +55,7 @@ pub async fn task(
tokio::select! {
res = msg_rx.recv() => {
if let Some(msg) = res {
match msg {
match msg.message_kind {
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketMessage(ws_msg) => {
tracing::trace!(
request_id=%protocol::util::id_to_string(&in_flight_req.request_id),
Expand Down
Loading
Loading