From f79374e7b952d8638179390b98f22628d6c54fc5 Mon Sep 17 00:00:00 2001 From: MasterPtato Date: Thu, 25 Jun 2026 12:17:20 -0700 Subject: [PATCH] [slopfix] test(universaldb): postgres leader failover + abort resolver task on driver drop --- engine/artifacts/config-schema.json | 32 -- .../packages/config/src/config/api_public.rs | 25 -- engine/packages/config/src/config/mod.rs | 11 - engine/packages/engine/src/main.rs | 2 + engine/packages/metrics-server/src/server.rs | 2 +- .../pegboard-envoy/src/ws_to_tunnel_task.rs | 9 +- .../pegboard-gateway/src/shared_state.rs | 2 +- .../packages/pegboard-gateway2/src/metrics.rs | 6 - .../pegboard-gateway2/src/shared_state.rs | 20 +- engine/packages/perf/src/lib.rs | 2 +- engine/packages/service-manager/src/lib.rs | 2 - engine/packages/test-deps/src/datacenter.rs | 1 - engine/packages/universaldb/Cargo.toml | 1 + .../src/driver/postgres/database.rs | 7 +- .../src/driver/postgres/resolver/mod.rs | 7 +- engine/packages/universaldb/tests/failover.rs | 219 ++++++++++++ .../src/driver/postgres/mod.rs | 319 ++++++++++++++---- .../universalpubsub/tests/reconnect.rs | 35 +- scripts/run/engine-postgres.sh | 21 +- 19 files changed, 535 insertions(+), 188 deletions(-) delete mode 100644 engine/packages/config/src/config/api_public.rs create mode 100644 engine/packages/universaldb/tests/failover.rs diff --git a/engine/artifacts/config-schema.json b/engine/artifacts/config-schema.json index 6a08895e60..a026e50760 100644 --- a/engine/artifacts/config-schema.json +++ b/engine/artifacts/config-schema.json @@ -40,17 +40,6 @@ } ] }, - "api_public": { - "default": null, - "anyOf": [ - { - "$ref": "#/definitions/ApiPublic" - }, - { - "type": "null" - } - ] - }, "auth": { "default": null, "anyOf": [ @@ -214,27 +203,6 @@ }, "additionalProperties": false }, - "ApiPublic": { - "description": "Configuration for the public API service.", - "type": "object", - "properties": { - "respect_forwarded_for": { - "description": "Flag to respect the X-Forwarded-For header for client IP addresses.\n\nWill be ignored in favor of CF-Connecting-IP if DNS provider is configured as Cloudflare.", - "type": [ - "boolean", - "null" - ] - }, - "verbose_errors": { - "description": "Flag to enable verbose error reporting.", - "type": [ - "boolean", - "null" - ] - } - }, - "additionalProperties": false - }, "Auth": { "type": "object", "required": [ diff --git a/engine/packages/config/src/config/api_public.rs b/engine/packages/config/src/config/api_public.rs deleted file mode 100644 index 53cb280a37..0000000000 --- a/engine/packages/config/src/config/api_public.rs +++ /dev/null @@ -1,25 +0,0 @@ -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; - -/// Configuration for the public API service. -#[derive(Debug, Serialize, Deserialize, Clone, Default, JsonSchema)] -#[serde(deny_unknown_fields)] -pub struct ApiPublic { - /// Flag to enable verbose error reporting. - pub verbose_errors: Option, - /// Flag to respect the X-Forwarded-For header for client IP addresses. - /// - /// Will be ignored in favor of CF-Connecting-IP if DNS provider is - /// configured as Cloudflare. - pub respect_forwarded_for: Option, -} - -impl ApiPublic { - pub fn verbose_errors(&self) -> bool { - self.verbose_errors.unwrap_or(true) - } - - pub fn respect_forwarded_for(&self) -> bool { - self.respect_forwarded_for.unwrap_or(false) - } -} diff --git a/engine/packages/config/src/config/mod.rs b/engine/packages/config/src/config/mod.rs index 3ee2aace1d..caa8bf196e 100644 --- a/engine/packages/config/src/config/mod.rs +++ b/engine/packages/config/src/config/mod.rs @@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize}; use std::sync::LazyLock; pub mod api_peer; -pub mod api_public; pub mod auth; pub mod cache; pub mod clickhouse; @@ -21,7 +20,6 @@ pub mod telemetry; pub mod topology; pub use api_peer::*; -pub use api_public::*; pub use auth::*; pub use cache::*; pub use clickhouse::*; @@ -74,9 +72,6 @@ pub struct Root { #[serde(default)] pub guard: Option, - #[serde(default)] - pub api_public: Option, - #[serde(default)] pub api_peer: Option, @@ -122,7 +117,6 @@ impl Default for Root { Root { auth: None, guard: None, - api_public: None, api_peer: None, pegboard: None, logs: None, @@ -146,11 +140,6 @@ impl Root { self.guard.as_ref().unwrap_or(&DEFAULT) } - pub fn api_public(&self) -> &ApiPublic { - static DEFAULT: LazyLock = LazyLock::new(ApiPublic::default); - self.api_public.as_ref().unwrap_or(&DEFAULT) - } - pub fn api_peer(&self) -> &ApiPeer { static DEFAULT: LazyLock = LazyLock::new(ApiPeer::default); self.api_peer.as_ref().unwrap_or(&DEFAULT) diff --git a/engine/packages/engine/src/main.rs b/engine/packages/engine/src/main.rs index d0f38a6efb..88e74b0f1b 100644 --- a/engine/packages/engine/src/main.rs +++ b/engine/packages/engine/src/main.rs @@ -36,6 +36,8 @@ fn main() -> Result<()> { } async fn main_inner() -> Result<()> { + tracing::info!(version=%build_meta::VERSION, git_sha=%build_meta::GIT_SHA, built_at=%build_meta::BUILD_TIMESTAMP, "starting rivet"); + let cli = Cli::parse(); // Load config diff --git a/engine/packages/metrics-server/src/server.rs b/engine/packages/metrics-server/src/server.rs index 48b974c537..b2512b3e31 100644 --- a/engine/packages/metrics-server/src/server.rs +++ b/engine/packages/metrics-server/src/server.rs @@ -28,7 +28,7 @@ pub async fn run_standalone(config: rivet_config::Config) -> Result<()> { Ok::<_, hyper::Error>(service_fn(serve_req)) })); - tracing::info!(?host, ?port, "started metrics server"); + tracing::debug!(?host, ?port, "started metrics server"); server.await?; Ok(()) diff --git a/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs b/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs index 746bea158d..2e023fe34a 100644 --- a/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs @@ -474,9 +474,14 @@ pub async fn task_inner( // backpressure to the runner rather than dropping protocol messages. let mut rate_limit = rivet_util::throttle::RateLimiter::new( rivet_util::throttle::RateLimitMethod::LeakyBucket { - requests: ctx.config().pegboard().envoy_websocket_rate_limit_requests(), + requests: ctx + .config() + .pegboard() + .envoy_websocket_rate_limit_requests(), drip_rate: Duration::from_micros( - ctx.config().pegboard().envoy_websocket_rate_limit_drip_rate_us(), + ctx.config() + .pegboard() + .envoy_websocket_rate_limit_drip_rate_us(), ), }, ); diff --git a/engine/packages/pegboard-gateway/src/shared_state.rs b/engine/packages/pegboard-gateway/src/shared_state.rs index 580da8eabc..742b146921 100644 --- a/engine/packages/pegboard-gateway/src/shared_state.rs +++ b/engine/packages/pegboard-gateway/src/shared_state.rs @@ -142,7 +142,7 @@ pub struct SharedState(Arc); impl SharedState { pub fn new(config: &rivet_config::Config, ups: PubSub) -> Self { let gateway_id = protocol::util::generate_gateway_id(); - tracing::info!(gateway_id = %protocol::util::id_to_string(&gateway_id), "setting up shared state for gateway"); + tracing::debug!(gateway_id = %protocol::util::id_to_string(&gateway_id), "setting up shared state for gateway"); let receiver_subject = GatewayReceiverSubject::new(gateway_id); let pegboard_config = config.pegboard(); diff --git a/engine/packages/pegboard-gateway2/src/metrics.rs b/engine/packages/pegboard-gateway2/src/metrics.rs index 244df7f882..f087d2ac3b 100644 --- a/engine/packages/pegboard-gateway2/src/metrics.rs +++ b/engine/packages/pegboard-gateway2/src/metrics.rs @@ -53,12 +53,6 @@ lazy_static::lazy_static! { &["namespace_id", "pool_name", "protocol", "reason"], *REGISTRY ).unwrap(); - pub static ref SHUTDOWN_IN_FLIGHT_ABORTED_TOTAL: IntCounter = - register_int_counter_with_registry!( - "gateway2_shutdown_in_flight_aborted_total", - "In-flight gateway requests abandoned on pod shutdown without sending close.", - *REGISTRY - ).unwrap(); pub static ref MSG_SENT_TOTAL: IntCounterVec = register_int_counter_vec_with_registry!( "gateway2_msg_sent_total", "Count of total of tunnel messages sent.", diff --git a/engine/packages/pegboard-gateway2/src/shared_state.rs b/engine/packages/pegboard-gateway2/src/shared_state.rs index b3de6e5887..1d1b7a95cd 100644 --- a/engine/packages/pegboard-gateway2/src/shared_state.rs +++ b/engine/packages/pegboard-gateway2/src/shared_state.rs @@ -163,7 +163,7 @@ impl SharedState { init_slow_ping_threshold_from_env(); let gateway_id = protocol::util::generate_gateway_id(); - tracing::info!(gateway_id = %display_id(&gateway_id), "setting up shared state for gateway"); + tracing::debug!(gateway_id = %display_id(&gateway_id), "setting up shared state for gateway"); let receiver_subject = GatewayReceiverSubject::new(gateway_id); let pegboard_config = config.pegboard(); @@ -194,27 +194,9 @@ impl SharedState { let self_clone = self.clone(); tokio::spawn(async move { self_clone.gc().await }); - let self_clone = self.clone(); - tokio::spawn(async move { self_clone.shutdown_watcher().await }); - Ok(()) } - #[tracing::instrument(skip_all)] - async fn shutdown_watcher(&self) { - let mut term_signal = __rivet_runtime::TermSignal::get(); - term_signal.recv().await; - - let in_flight_aborted = self.in_flight_requests.len(); - if in_flight_aborted > 0 { - metrics::SHUTDOWN_IN_FLIGHT_ABORTED_TOTAL.inc_by(in_flight_aborted as u64); - } - tracing::info!( - in_flight_aborted, - "gateway shutdown in-flight requests abandoned without close" - ); - } - #[tracing::instrument(skip_all)] async fn receiver(&self) { // Automatically resubscribe if unsubscribed diff --git a/engine/packages/perf/src/lib.rs b/engine/packages/perf/src/lib.rs index 7642c406ad..f7ebb0dc71 100644 --- a/engine/packages/perf/src/lib.rs +++ b/engine/packages/perf/src/lib.rs @@ -95,7 +95,7 @@ impl Drop for PerfMeasure { let elapsed = self.start.elapsed(); let _guard = self.span.enter(); - tracing::warn!( + tracing::debug!( name = self.name, elapsed_ms = PerfMeasure::__elapsed_ms(elapsed), "PerfMeasure dropped without finish() - measurement discarded", diff --git a/engine/packages/service-manager/src/lib.rs b/engine/packages/service-manager/src/lib.rs index 6c9f198d78..a60854f3e0 100644 --- a/engine/packages/service-manager/src/lib.rs +++ b/engine/packages/service-manager/src/lib.rs @@ -160,8 +160,6 @@ pub async fn start( let shutting_down = Arc::new(AtomicBool::new(false)); for service in services { - tracing::debug!(name=%service.name, kind=?service.kind, "server starting service"); - match service.kind.behavior() { ServiceBehavior::Service => { let config = config.clone(); diff --git a/engine/packages/test-deps/src/datacenter.rs b/engine/packages/test-deps/src/datacenter.rs index 33843062df..d7fdfcf040 100644 --- a/engine/packages/test-deps/src/datacenter.rs +++ b/engine/packages/test-deps/src/datacenter.rs @@ -73,7 +73,6 @@ pub async fn setup_single_datacenter( let mut root = rivet_config::config::Root::default(); root.database = Some(db_config); root.pubsub = Some(pubsub_config); - root.api_public = Some(Default::default()); root.api_peer = Some(rivet_config::config::ApiPeer { port: Some(api_peer_port), ..Default::default() diff --git a/engine/packages/universaldb/Cargo.toml b/engine/packages/universaldb/Cargo.toml index b2f6b5f045..57dba4ab33 100644 --- a/engine/packages/universaldb/Cargo.toml +++ b/engine/packages/universaldb/Cargo.toml @@ -39,4 +39,5 @@ rivet-config.workspace = true rivet-env.workspace = true rivet-pools.workspace = true rivet-test-deps-docker.workspace = true +tokio-postgres.workspace = true tracing-subscriber.workspace = true diff --git a/engine/packages/universaldb/src/driver/postgres/database.rs b/engine/packages/universaldb/src/driver/postgres/database.rs index 6e24cf647d..6cded0f320 100644 --- a/engine/packages/universaldb/src/driver/postgres/database.rs +++ b/engine/packages/universaldb/src/driver/postgres/database.rs @@ -59,6 +59,7 @@ impl PostgresConfig { pub struct PostgresDatabaseDriver { shared: Arc, max_retries: AtomicI32, + resolver_handle: JoinHandle<()>, gc_handle: JoinHandle<()>, } @@ -112,13 +113,14 @@ impl PostgresDatabaseDriver { let shared = PostgresShared::new(pool, node_id, listener); // Every node runs the resolver; only the elected leader drains the commit queue. - resolver::spawn(shared.clone()); + let resolver_handle = resolver::spawn(shared.clone()); let gc_handle = Self::spawn_gc(shared.clone()); Ok(PostgresDatabaseDriver { shared, max_retries: AtomicI32::new(100), + resolver_handle, gc_handle, }) } @@ -292,6 +294,9 @@ impl DatabaseDriver for PostgresDatabaseDriver { impl Drop for PostgresDatabaseDriver { fn drop(&mut self) { + // Abort the resolver so a dropped node stops renewing its lease; the lease then expires and + // another node can take over. Without this a dropped leader would renew its lease forever. + self.resolver_handle.abort(); self.gc_handle.abort(); } } diff --git a/engine/packages/universaldb/src/driver/postgres/resolver/mod.rs b/engine/packages/universaldb/src/driver/postgres/resolver/mod.rs index 79f1f1f48b..3dfac50fce 100644 --- a/engine/packages/universaldb/src/driver/postgres/resolver/mod.rs +++ b/engine/packages/universaldb/src/driver/postgres/resolver/mod.rs @@ -33,9 +33,10 @@ enum DrainOutcome { } /// Spawn the per-process resolver task. Every node runs this; only the elected leader drains the -/// commit queue. -pub fn spawn(shared: Arc) { - tokio::spawn(run(shared)); +/// commit queue. The returned handle is aborted when the owning driver drops, which stops lease +/// renewal so the lease expires and another node can take over (node-death / failover path). +pub fn spawn(shared: Arc) -> tokio::task::JoinHandle<()> { + tokio::spawn(run(shared)) } async fn run(shared: Arc) { diff --git a/engine/packages/universaldb/tests/failover.rs b/engine/packages/universaldb/tests/failover.rs new file mode 100644 index 0000000000..7d122ca781 --- /dev/null +++ b/engine/packages/universaldb/tests/failover.rs @@ -0,0 +1,219 @@ +use std::{sync::Arc, time::Duration}; + +use rivet_test_deps_docker::TestDatabase; +use tokio_postgres::NoTls; +use universaldb::{Database, utils::IsolationLevel::*}; +use uuid::Uuid; + +const ALPHA_KEY: &[u8] = b"failover/alpha"; +const BETA_KEY: &[u8] = b"failover/beta"; + +/// Build a fresh Postgres-backed `Database`. Each call spins up an independent driver (its own pool, +/// node id, listener, and resolver), so two of them against one Postgres model two engine nodes. +async fn make_db(connection_string: &str) -> Database { + let driver = universaldb::driver::PostgresDatabaseDriver::new_with_config( + universaldb::driver::postgres::PostgresConfig::new(connection_string.to_string()), + ) + .await + .unwrap(); + Database::new(Arc::new(driver)) +} + +/// Raw verification connection used to inspect leader/lease/version state out of band. +async fn connect_raw(connection_string: &str) -> tokio_postgres::Client { + let (client, connection) = tokio_postgres::connect(connection_string, NoTls) + .await + .unwrap(); + tokio::spawn(async move { + let _ = connection.await; + }); + client +} + +struct LeaseRow { + epoch: i64, + leader_addr: String, + durable_version: i64, +} + +async fn read_lease(client: &tokio_postgres::Client) -> Option { + let row = client + .query_opt( + "SELECT epoch, leader_addr, durable_version FROM udb_lease WHERE id = 1", + &[], + ) + .await + .unwrap()?; + Some(LeaseRow { + epoch: row.get(0), + leader_addr: row.get(1), + durable_version: row.get(2), + }) +} + +/// High-water of the LOGGED version sequence. A freshly elected leader must continue from at least +/// this value, never regress below it. +async fn read_seq_high(client: &tokio_postgres::Client) -> i64 { + client + .query_one("SELECT last_value FROM udb_version_seq", &[]) + .await + .unwrap() + .get(0) +} + +/// Poll `udb_lease` until `pred` holds or the deadline passes. +async fn wait_for_lease bool>( + client: &tokio_postgres::Client, + timeout: Duration, + pred: F, +) -> LeaseRow { + let deadline = tokio::time::Instant::now() + timeout; + loop { + if let Some(lease) = read_lease(client).await { + if pred(&lease) { + return lease; + } + } + if tokio::time::Instant::now() >= deadline { + panic!("timed out waiting for lease condition"); + } + tokio::time::sleep(Duration::from_millis(200)).await; + } +} + +async fn write_key(db: &Database, key: &'static [u8], value: &'static [u8]) { + db.txn("test_failover", move |tx| async move { + tx.set(key, value); + Ok(()) + }) + .await + .unwrap(); +} + +async fn read_key(db: &Database, key: &'static [u8]) -> Option> { + db.txn("test_failover", move |tx| async move { + let val = tx.get(key, Serializable).await?; + Ok(val) + }) + .await + .unwrap() + .map(|slice| slice.to_vec()) +} + +/// Exercises leader failover: two nodes share one Postgres, the elected leader is killed, the +/// survivor must take over the lease (new epoch), continue the crash-safe version sequence without +/// regression, preserve the dead leader's committed data, and resume accepting commits. +#[tokio::test] +async fn test_postgres_leader_failover() { + let _ = tracing_subscriber::fmt() + .with_env_filter("info") + .with_test_writer() + .try_init(); + + let (db_config, docker_config) = TestDatabase::Postgres + .config(Uuid::new_v4(), 1) + .await + .unwrap(); + let mut docker_config = docker_config.unwrap(); + docker_config.start().await.unwrap(); + + tokio::time::sleep(Duration::from_secs(4)).await; + + let rivet_config::config::Database::Postgres(postgres_config) = db_config else { + unreachable!(); + }; + let connection_string = postgres_config.url.read().clone(); + + let raw = connect_raw(&connection_string).await; + + // Node 1 comes up first and deterministically wins the first election (epoch 1). + let db1 = make_db(&connection_string).await; + let lease1 = wait_for_lease(&raw, Duration::from_secs(15), |l| l.epoch == 1).await; + let leader1_addr = lease1.leader_addr.clone(); + + // Node 2 joins while node 1 holds a valid lease, so it loses the election and runs as a + // follower. + let db2 = make_db(&connection_string).await; + + // Leader (node 1) commits data. The version sequence and watermark advance. + write_key(&db1, ALPHA_KEY, b"1").await; + + // The follower (node 2) reads through its own snapshot and sees the leader's committed write, + // proving cross-node reads work before any failover. + assert_eq!( + read_key(&db2, ALPHA_KEY).await, + Some(b"1".to_vec()), + "follower must see the leader's committed write" + ); + + let lease_before = read_lease(&raw).await.unwrap(); + let seq_before = read_seq_high(&raw).await; + assert!( + lease_before.durable_version >= 1, + "durable_version must have advanced after the first commit" + ); + + // Kill node 1. Dropping the driver aborts its resolver, so it stops renewing the lease. + drop(db1); + + // Node 2 must take over once node 1's lease expires (TTL is 10s). The epoch is bumped and the + // leader address changes to node 2. + let lease_after = wait_for_lease(&raw, Duration::from_secs(40), |l| { + l.epoch > lease_before.epoch + }) + .await; + assert!( + lease_after.epoch > lease_before.epoch, + "new leader must bump the epoch (was {}, now {})", + lease_before.epoch, + lease_after.epoch + ); + assert_ne!( + lease_after.leader_addr, leader1_addr, + "the surviving node must become the new leader" + ); + + // The crash-safe LOGGED sequence continues from the prior high-water; it never regresses. + let seq_after_takeover = read_seq_high(&raw).await; + assert!( + seq_after_takeover >= seq_before, + "version sequence regressed across failover ({} -> {})", + seq_before, + seq_after_takeover + ); + assert!( + lease_after.durable_version >= lease_before.durable_version, + "durable_version regressed across failover ({} -> {})", + lease_before.durable_version, + lease_after.durable_version + ); + + // The data the dead leader committed survives the failover. + assert_eq!( + read_key(&db2, ALPHA_KEY).await, + Some(b"1".to_vec()), + "committed data must survive leader failover" + ); + + // The new leader resumes accepting commits. + write_key(&db2, BETA_KEY, b"2").await; + assert_eq!( + read_key(&db2, BETA_KEY).await, + Some(b"2".to_vec()), + "new leader must accept and durably apply commits" + ); + + // The new commit advanced the version sequence and watermark past the pre-failover floor, + // confirming the new leader sequences from a strictly higher version. + let lease_final = read_lease(&raw).await.unwrap(); + assert!( + read_seq_high(&raw).await > seq_before, + "a post-failover commit must advance the version sequence" + ); + assert!( + lease_final.durable_version > lease_before.durable_version, + "a post-failover commit must advance the durable watermark" + ); + + drop(db2); +} diff --git a/engine/packages/universalpubsub/src/driver/postgres/mod.rs b/engine/packages/universalpubsub/src/driver/postgres/mod.rs index d361eb1149..2a492fbccf 100644 --- a/engine/packages/universalpubsub/src/driver/postgres/mod.rs +++ b/engine/packages/universalpubsub/src/driver/postgres/mod.rs @@ -40,9 +40,12 @@ const POLL_INTERVAL: Duration = Duration::from_secs(1); /// rather than a cluster outage. const LISTEN_IDLE_IN_TRANSACTION_TIMEOUT_MS: i64 = 30_000; -const QUEUE_SUB_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(10); -/// How long a queue subscriber's heartbeat must be within to be considered active. -const QUEUE_SUB_TTL_SECS: i64 = 30; +/// How often this process refreshes its node liveness heartbeat. One heartbeat per +/// process keeps all of its subscriber registrations alive at once. +const NODE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(10); +/// How recent a node's heartbeat must be for its subscribers to count as live +/// responders. +const NODE_TTL_SECS: i64 = 30; /// How often to GC expired broadcast messages. const MESSAGE_GC_INTERVAL: Duration = Duration::from_secs(5); @@ -51,19 +54,34 @@ const MESSAGE_GC_INTERVAL: Duration = Duration::from_secs(5); /// messages, matching NATS-core at-most-once semantics for slow consumers. const MESSAGE_MAX_AGE_SECS: i64 = 10; +/// How often to GC dead nodes and the subscriber rows orphaned by them. +const REGISTRY_GC_INTERVAL: Duration = Duration::from_secs(30); + /// How often to GC orphaned queue messages. const QUEUE_MESSAGE_GC_INTERVAL: Duration = Duration::from_secs(300); /// Max age before an unconsumed queue message is garbage collected. const QUEUE_MESSAGE_MAX_AGE_SECS: i64 = 3600; +/// Per-shard signal carried over a subscriber's in-process wakeup channel. +#[derive(Clone)] +enum ShardSignal { + /// A doorbell NOTIFY landed for this shard. Poll the table. + Wakeup, + /// A local request found no responders for the given reply subject. The matching + /// reply subscriber surfaces a no-responders result. + NoResponders { subject: String }, +} + #[derive(Clone)] pub struct PostgresDriver { pool: Arc, client: Arc>>, + /// Identifies this process in the subscriber registry. A single heartbeat keeps + /// all of this node's registrations live. + node_id: String, /// Wakeup channels keyed by doorbell shard channel name. Shared by broadcast and - /// queue subscribers whose subjects map to the same shard. Carries empty wakeups - /// only; payload lives in the table. - shard_subscriptions: Arc>>, + /// queue subscribers whose subjects map to the same shard. + shard_subscriptions: Arc>>, doorbell: Arc, client_ready: tokio::sync::watch::Receiver, } @@ -105,9 +123,10 @@ impl PostgresDriver { tracing::debug!("postgres pool created successfully"); let pool = Arc::new(pool); - let shard_subscriptions: Arc>> = + let shard_subscriptions: Arc>> = Arc::new(HashMap::new()); let client: Arc>> = Arc::new(Mutex::new(None)); + let node_id = Uuid::new_v4().to_string(); // Create channel for client ready notifications let (ready_tx, client_ready) = tokio::sync::watch::channel(false); @@ -128,6 +147,7 @@ impl PostgresDriver { let driver = Self { pool, client, + node_id, shard_subscriptions, doorbell, client_ready, @@ -138,6 +158,7 @@ impl PostgresDriver { // Create tables eagerly so they exist before any publish or subscribe. { + tracing::debug!("configuring postgres udb tables"); let conn = driver .pool .get() @@ -157,11 +178,23 @@ impl PostgresDriver { ); \ CREATE INDEX IF NOT EXISTS ups_messages_subject_id \ ON ups_messages (subject_hash, id); \ + CREATE TABLE IF NOT EXISTS ups_nodes ( \ + node_id TEXT PRIMARY KEY, \ + heartbeat_at TIMESTAMPTZ NOT NULL DEFAULT NOW() \ + ); \ + CREATE TABLE IF NOT EXISTS ups_subs ( \ + id TEXT PRIMARY KEY, \ + node_id TEXT NOT NULL, \ + subject_hash TEXT NOT NULL, \ + subject TEXT NOT NULL \ + ); \ + CREATE INDEX IF NOT EXISTS ups_subs_subject \ + ON ups_subs (subject_hash); \ CREATE TABLE IF NOT EXISTS ups_queue_subs ( \ id TEXT PRIMARY KEY, \ + node_id TEXT NOT NULL, \ subject_hash TEXT NOT NULL, \ - queue_hash TEXT NOT NULL, \ - heartbeat_at TIMESTAMPTZ NOT NULL DEFAULT NOW() \ + queue_hash TEXT NOT NULL \ ); \ CREATE INDEX IF NOT EXISTS ups_queue_subs_subject_queue \ ON ups_queue_subs (subject_hash, queue_hash); \ @@ -177,9 +210,24 @@ impl PostgresDriver { ) .await .context("failed to create tables")?; - tracing::debug!("tables ready"); + tracing::debug!("postgres udb tables ready"); } + // Register this node and start its liveness heartbeat. + driver.heartbeat_node().await?; + let heartbeat_driver = driver.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(NODE_HEARTBEAT_INTERVAL); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + loop { + interval.tick().await; + if let Err(e) = heartbeat_driver.heartbeat_node().await { + tracing::warn!(?e, "failed to heartbeat node"); + } + } + }); + // Spawn GC task for expired broadcast messages let message_gc_driver = driver.clone(); tokio::spawn(async move { @@ -203,6 +251,49 @@ impl PostgresDriver { } }); + // Spawn GC task for dead nodes and orphaned subscriber rows. + let registry_gc_driver = driver.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(REGISTRY_GC_INTERVAL); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + loop { + interval.tick().await; + if let Ok(conn) = registry_gc_driver.pool.get().await { + if let Err(e) = conn + .execute( + "DELETE FROM ups_nodes \ + WHERE heartbeat_at < NOW() - ($1::bigint * INTERVAL '1 second')", + &[&NODE_TTL_SECS], + ) + .await + { + tracing::warn!(?e, "failed to gc dead nodes"); + } + if let Err(e) = conn + .execute( + "DELETE FROM ups_subs \ + WHERE node_id NOT IN (SELECT node_id FROM ups_nodes)", + &[], + ) + .await + { + tracing::warn!(?e, "failed to gc orphaned subs"); + } + if let Err(e) = conn + .execute( + "DELETE FROM ups_queue_subs \ + WHERE node_id NOT IN (SELECT node_id FROM ups_nodes)", + &[], + ) + .await + { + tracing::warn!(?e, "failed to gc orphaned queue subs"); + } + } + } + }); + // Spawn GC task for orphaned queue messages let gc_driver = driver.clone(); tokio::spawn(async move { @@ -232,7 +323,7 @@ impl PostgresDriver { /// Manages the connection lifecycle with automatic reconnection async fn spawn_connection_lifecycle( conn_str: String, - shard_subscriptions: Arc>>, + shard_subscriptions: Arc>>, client: Arc>>, ready_tx: tokio::sync::watch::Sender, ssl_root_cert_path: Option, @@ -337,7 +428,7 @@ impl PostgresDriver { /// Polls the connection for notifications until it closes or errors async fn poll_connection( mut conn: tokio_postgres::Connection, - shard_subscriptions: Arc>>, + shard_subscriptions: Arc>>, ) where T: tokio_postgres::tls::TlsStream + Unpin, { @@ -348,7 +439,7 @@ impl PostgresDriver { // Doorbell notifications are payload-free wakeup signals only. // Subscribers read their payload from the table. if let Some(sub) = shard_subscriptions.get_async(note.channel()).await { - let _ = sub.send(()); + let _ = sub.send(ShardSignal::Wakeup); } else { tracing::trace!(channel = %note.channel(), "wakeup for unknown shard"); } @@ -404,6 +495,24 @@ impl PostgresDriver { format!("{:x}", hasher.finish()) } + /// Upserts this node's liveness heartbeat. Re-inserts the row if a GC pass removed + /// it after a transient stall. + async fn heartbeat_node(&self) -> Result<()> { + let conn = self + .pool + .get() + .await + .context("failed to get connection for node heartbeat")?; + conn.execute( + "INSERT INTO ups_nodes (node_id, heartbeat_at) VALUES ($1, NOW()) \ + ON CONFLICT (node_id) DO UPDATE SET heartbeat_at = NOW()", + &[&self.node_id], + ) + .await + .context("failed to upsert node heartbeat")?; + Ok(()) + } + /// Returns the current max broadcast message id, used as a subscriber's starting /// cursor so it only sees future messages (NATS at-most-once, no replay). async fn current_max_id(&self) -> Result { @@ -424,7 +533,10 @@ impl PostgresDriver { async fn ensure_shard_listen( &self, shard: usize, - ) -> (broadcast::Receiver<()>, tokio_util::sync::DropGuard) { + ) -> ( + broadcast::Receiver, + tokio_util::sync::DropGuard, + ) { let channel = shard_channel(shard); match self.shard_subscriptions.entry_async(channel.clone()).await { @@ -465,7 +577,7 @@ impl PostgresDriver { fn spawn_shard_cleanup_task( &self, channel: String, - tx: broadcast::Sender<()>, + tx: broadcast::Sender, ) -> tokio_util::sync::DropGuard { let driver = self.clone(); let token = tokio_util::sync::CancellationToken::new(); @@ -515,14 +627,15 @@ impl PostgresDriver { .await .context("failed to insert broadcast message")?; - // Queue rows for every active queue group on this subject. Batched into the - // same transaction so a crash never strands a row mid-publish. + // Queue rows for every live queue group on this subject. Batched into the same + // transaction so a crash never strands a row mid-publish. let rows = tx .query( - "SELECT DISTINCT queue_hash FROM ups_queue_subs \ - WHERE subject_hash = $1 \ - AND heartbeat_at > NOW() - ($2::bigint * INTERVAL '1 second')", - &[&subject_hash, &QUEUE_SUB_TTL_SECS], + "SELECT DISTINCT s.queue_hash FROM ups_queue_subs s \ + JOIN ups_nodes n ON s.node_id = n.node_id \ + WHERE s.subject_hash = $1 \ + AND n.heartbeat_at > NOW() - ($2::bigint * INTERVAL '1 second')", + &[&subject_hash, &NODE_TTL_SECS], ) .await .context("failed to query active queue subs")?; @@ -542,6 +655,50 @@ impl PostgresDriver { Ok(()) } + + /// Returns whether any live subscriber (broadcast or queue) exists for the subject + /// anywhere in the fleet. Used to decide whether a request surfaces a no-responders + /// result instead of waiting out its timeout. + async fn has_responders(&self, subject_hash: &str, subject: &str) -> Result { + let conn = self + .pool + .get() + .await + .context("failed to get connection for responder check")?; + let row = conn + .query_one( + "SELECT \ + EXISTS( \ + SELECT 1 FROM ups_subs s \ + JOIN ups_nodes n ON s.node_id = n.node_id \ + WHERE s.subject_hash = $1 AND s.subject = $2 \ + AND n.heartbeat_at > NOW() - ($3::bigint * INTERVAL '1 second') \ + ) \ + OR EXISTS( \ + SELECT 1 FROM ups_queue_subs s \ + JOIN ups_nodes n ON s.node_id = n.node_id \ + WHERE s.subject_hash = $1 \ + AND n.heartbeat_at > NOW() - ($3::bigint * INTERVAL '1 second') \ + )", + &[&subject_hash, &subject, &NODE_TTL_SECS], + ) + .await + .context("failed to check responders")?; + Ok(row.get(0)) + } + + /// Delivers a no-responders result to the local reply subscriber. The requester is + /// always in this process, so the signal is routed in-memory over the reply + /// subject's shard channel rather than the table. + async fn signal_no_responders(&self, reply_subject: &str) { + let reply_hash = self.hash_subject(reply_subject); + let channel = shard_channel(shard_for(&reply_hash)); + if let Some(tx) = self.shard_subscriptions.get_async(&channel).await { + let _ = tx.send(ShardSignal::NoResponders { + subject: reply_subject.to_string(), + }); + } + } } #[async_trait] @@ -549,7 +706,7 @@ impl PubSubDriver for PostgresDriver { async fn subscribe( &self, subject: &str, - _reply_id: Option, + reply_id: Option, ) -> Result { let subject_hash = self.hash_subject(subject); let shard = shard_for(&subject_hash); @@ -561,6 +718,28 @@ impl PubSubDriver for PostgresDriver { let (rx, drop_guard) = self.ensure_shard_listen(shard).await; + // Register in the responder registry so requests to this subject can detect + // responders. Reply inboxes are never request targets, so they skip the + // registry to keep request latency off this path. + let sub_id = if reply_id.is_none() { + let sub_id = Uuid::new_v4().to_string(); + let conn = self + .pool + .get() + .await + .context("failed to get connection for subscribe")?; + conn.execute( + "INSERT INTO ups_subs (id, node_id, subject_hash, subject) \ + VALUES ($1, $2, $3, $4)", + &[&sub_id, &self.node_id, &subject_hash, &subject], + ) + .await + .context("failed to register subscriber")?; + Some(sub_id) + } else { + None + }; + Ok(Box::new(PostgresSubscriber { subject: subject.to_string(), subject_hash, @@ -568,6 +747,7 @@ impl PubSubDriver for PostgresDriver { cursor, buffer: VecDeque::new(), rx, + sub_id, _drop_guard: drop_guard, })) } @@ -586,8 +766,9 @@ impl PubSubDriver for PostgresDriver { .await .context("failed to get connection for queue subscribe")?; conn.execute( - "INSERT INTO ups_queue_subs (id, subject_hash, queue_hash) VALUES ($1, $2, $3)", - &[&sub_id, &subject_hash, &queue_hash], + "INSERT INTO ups_queue_subs (id, node_id, subject_hash, queue_hash) \ + VALUES ($1, $2, $3, $4)", + &[&sub_id, &self.node_id, &subject_hash, &queue_hash], ) .await .context("failed to register queue subscriber")?; @@ -595,35 +776,6 @@ impl PubSubDriver for PostgresDriver { let (rx, drop_guard) = self.ensure_shard_listen(shard).await; - // Spawn heartbeat task to keep the registration alive - let pool = self.pool.clone(); - let sub_id_for_heartbeat = sub_id.clone(); - let heartbeat_token = tokio_util::sync::CancellationToken::new(); - let heartbeat_token_child = heartbeat_token.clone(); - tokio::spawn(async move { - let mut interval = tokio::time::interval(QUEUE_SUB_HEARTBEAT_INTERVAL); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - - loop { - tokio::select! { - _ = heartbeat_token_child.cancelled() => break, - _ = interval.tick() => { - if let Ok(conn) = pool.get().await { - if let Err(e) = conn - .execute( - "UPDATE ups_queue_subs SET heartbeat_at = NOW() WHERE id = $1", - &[&sub_id_for_heartbeat], - ) - .await - { - tracing::warn!(?e, id = %sub_id_for_heartbeat, "failed to heartbeat queue sub"); - } - } - } - } - } - }); - Ok(Box::new(PostgresQueueSubscriber { subject: subject.to_string(), subject_hash, @@ -632,7 +784,6 @@ impl PubSubDriver for PostgresDriver { pool: self.pool.clone(), rx, _drop_guard: drop_guard, - _heartbeat_token: heartbeat_token, })) } @@ -640,11 +791,29 @@ impl PubSubDriver for PostgresDriver { &self, subject: &str, payload: &[u8], - _reply_subject: Option<&str>, + reply_subject: Option<&str>, ) -> Result<()> { let subject_hash = self.hash_subject(subject); let shard = shard_for(&subject_hash); + // Request semantics: if a reply is expected and no responder exists anywhere, + // surface a no-responders result immediately instead of persisting a message + // nobody will read. + if let Some(reply_subject) = reply_subject { + match self.has_responders(&subject_hash, subject).await { + Result::Ok(false) => { + self.signal_no_responders(reply_subject).await; + return Ok(()); + } + Result::Ok(true) => {} + Result::Err(e) => { + // On a failed check, fall through to a normal publish rather than + // risk a false no-responders result. + tracing::warn!(?e, %subject, "responder check failed, publishing anyway"); + } + } + } + // Persist the message, retrying on transient connection errors. The row is // committed before the doorbell rings so any wakeup observes it. let mut backoff = Backoff::default(); @@ -686,7 +855,9 @@ pub struct PostgresSubscriber { pool: Arc, cursor: i64, buffer: VecDeque>, - rx: broadcast::Receiver<()>, + rx: broadcast::Receiver, + /// Responder-registry row id, present for non-inbox subscriptions. Deleted on drop. + sub_id: Option, _drop_guard: tokio_util::sync::DropGuard, } @@ -744,11 +915,17 @@ impl SubscriberDriver for PostgresSubscriber { continue; } - // Wait for a doorbell wakeup or the poll backstop, whichever is first. + // Wait for a doorbell wakeup, a no-responders signal, or the poll backstop. tokio::select! { res = self.rx.recv() => { match res { - std::result::Result::Ok(()) => {} + std::result::Result::Ok(ShardSignal::Wakeup) => {} + std::result::Result::Ok(ShardSignal::NoResponders { subject }) + if subject == self.subject => + { + return Ok(DriverOutput::NoResponders); + } + std::result::Result::Ok(ShardSignal::NoResponders { .. }) => {} Err(broadcast::error::RecvError::Lagged(_)) => {} Err(broadcast::error::RecvError::Closed) => { return Ok(DriverOutput::Unsubscribed); @@ -761,15 +938,33 @@ impl SubscriberDriver for PostgresSubscriber { } } +impl Drop for PostgresSubscriber { + fn drop(&mut self) { + let Some(sub_id) = self.sub_id.take() else { + return; + }; + let pool = self.pool.clone(); + tokio::spawn(async move { + if let Ok(conn) = pool.get().await { + if let Err(e) = conn + .execute("DELETE FROM ups_subs WHERE id = $1", &[&sub_id]) + .await + { + tracing::warn!(?e, %sub_id, "failed to deregister subscriber"); + } + } + }); + } +} + pub struct PostgresQueueSubscriber { subject: String, subject_hash: String, queue_hash: String, sub_id: String, pool: Arc, - rx: broadcast::Receiver<()>, + rx: broadcast::Receiver, _drop_guard: tokio_util::sync::DropGuard, - _heartbeat_token: tokio_util::sync::CancellationToken, } impl PostgresQueueSubscriber { @@ -820,11 +1015,11 @@ impl SubscriberDriver for PostgresQueueSubscriber { } } - // Wait for a doorbell wakeup or the poll backstop, then loop back to claim. + // Wait for any shard signal or the poll backstop, then loop back to claim. tokio::select! { res = self.rx.recv() => { match res { - std::result::Result::Ok(()) => {} + std::result::Result::Ok(_) => {} Err(broadcast::error::RecvError::Lagged(_)) => {} Err(broadcast::error::RecvError::Closed) => { return Ok(DriverOutput::Unsubscribed); diff --git a/engine/packages/universalpubsub/tests/reconnect.rs b/engine/packages/universalpubsub/tests/reconnect.rs index 35ace375d3..8c15f6ab54 100644 --- a/engine/packages/universalpubsub/tests/reconnect.rs +++ b/engine/packages/universalpubsub/tests/reconnect.rs @@ -43,7 +43,7 @@ async fn test_nats_driver_with_memory_reconnect() { .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true); - test_all_inner(&pubsub, &docker).await; + test_all_inner(&pubsub, &docker, true).await; } #[tokio::test] @@ -77,7 +77,7 @@ async fn test_nats_driver_without_memory_reconnect() { .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false); - test_all_inner(&pubsub, &docker).await; + test_all_inner(&pubsub, &docker, true).await; } #[tokio::test] @@ -95,13 +95,12 @@ async fn test_postgres_driver_with_memory_reconnect() { }; let url = pg.url.read().clone(); - let driver = - universalpubsub::driver::postgres::PostgresDriver::connect(url, true, None, None, None) - .await - .unwrap(); + let driver = universalpubsub::driver::postgres::PostgresDriver::connect(url, None, None, None) + .await + .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true); - test_all_inner(&pubsub, &docker).await; + test_all_inner(&pubsub, &docker, false).await; } #[tokio::test] @@ -119,19 +118,27 @@ async fn test_postgres_driver_without_memory_reconnect() { }; let url = pg.url.read().clone(); - let driver = - universalpubsub::driver::postgres::PostgresDriver::connect(url, false, None, None, None) - .await - .unwrap(); + let driver = universalpubsub::driver::postgres::PostgresDriver::connect(url, None, None, None) + .await + .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false); - test_all_inner(&pubsub, &docker).await; + test_all_inner(&pubsub, &docker, false).await; } -async fn test_all_inner(pubsub: &PubSub, docker: &rivet_test_deps_docker::DockerRunConfig) { +async fn test_all_inner( + pubsub: &PubSub, + docker: &rivet_test_deps_docker::DockerRunConfig, + supports_subscribe_while_stopped: bool, +) { test_reconnect_inner(&pubsub, &docker).await; test_publish_while_stopped(&pubsub, &docker).await; - test_subscribe_while_stopped(&pubsub, &docker).await; + // The table-backed Postgres driver must read its cursor and register in the + // responder table when subscribing, so it cannot subscribe while the backend is + // fully down. NATS buffers the subscribe and reconnects, so it can. + if supports_subscribe_while_stopped { + test_subscribe_while_stopped(&pubsub, &docker).await; + } } async fn test_reconnect_inner(pubsub: &PubSub, docker: &rivet_test_deps_docker::DockerRunConfig) { diff --git a/scripts/run/engine-postgres.sh b/scripts/run/engine-postgres.sh index 966c421bec..70689517bd 100755 --- a/scripts/run/engine-postgres.sh +++ b/scripts/run/engine-postgres.sh @@ -4,19 +4,26 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" -if ! command -v nc >/dev/null 2>&1; then - echo "error: required command 'nc' not found." - exit 1 -fi +POSTGRES_IMAGE="postgres:18" + +# pg_isready reports ready only once the server is actually accepting connections. +# The Postgres entrypoint binds the port during its bootstrap phase and then +# restarts, so a plain port check (nc -z) passes too early and the engine hits +# "connection reset" / "early eof" on first connect. Run pg_isready from a throwaway +# container on the host network so no client binary needs to be installed locally. +postgres_ready() { + docker run --rm --network host "${POSTGRES_IMAGE}" \ + pg_isready -h localhost -p 5432 -U postgres -d postgres >/dev/null 2>&1 +} -if ! nc -z localhost 5432 >/dev/null 2>&1; then - echo "Postgres is not reachable at localhost:5432." +if ! postgres_ready; then + echo "Postgres is not accepting connections." echo "Starting postgres container..." "${SCRIPT_DIR}/postgres.sh" echo "Waiting for postgres to be ready..." for i in {1..30}; do - if nc -z localhost 5432 >/dev/null 2>&1; then + if postgres_ready; then echo "Postgres is ready!" break fi