From 1915ab8170e995b40f6c3e53af30ee24c2d1e011 Mon Sep 17 00:00:00 2001 From: MasterPtato Date: Thu, 25 Jun 2026 12:00:32 -0700 Subject: [PATCH] [SLOP(claude-opus-4-8-high)] feat(ups): table-backed postgres transport with coalesced doorbell --- .../src/driver/postgres/doorbell.rs | 133 ++++ .../src/driver/postgres/mod.rs | 692 +++++++++--------- 2 files changed, 477 insertions(+), 348 deletions(-) create mode 100644 engine/packages/universalpubsub/src/driver/postgres/doorbell.rs diff --git a/engine/packages/universalpubsub/src/driver/postgres/doorbell.rs b/engine/packages/universalpubsub/src/driver/postgres/doorbell.rs new file mode 100644 index 0000000000..f08ff7bbf3 --- /dev/null +++ b/engine/packages/universalpubsub/src/driver/postgres/doorbell.rs @@ -0,0 +1,133 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; + +use deadpool_postgres::Pool; +use tokio::sync::Notify; +use tokio::time::Instant; + +/// Number of doorbell shards. A subject maps to a shard via `hash(subject_hash) % K`. +/// Subscribers LISTEN their subject's shard channel; publishers wake the local +/// doorbell task which NOTIFYs the shard. +pub const DOORBELL_SHARD_COUNT: usize = 32; + +/// Debounce window. Caps each (process, shard) NOTIFY rate at one per window, which +/// bounds how many backends are woken per shard over time. +const DOORBELL_WINDOW: Duration = Duration::from_millis(5); + +/// Returns the NOTIFY channel name for a doorbell shard. +pub fn shard_channel(shard: usize) -> String { + format!("ups_db_{shard}") +} + +/// Returns the doorbell shard for a subject hash. +pub fn shard_for(subject_hash: &str) -> usize { + use std::hash::{DefaultHasher, Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + subject_hash.hash(&mut hasher); + (hasher.finish() as usize) % DOORBELL_SHARD_COUNT +} + +/// Coalesced, payload-free NOTIFY doorbell. +/// +/// Publishers call [`Doorbell::mark_dirty`] after committing a row. A single +/// per-process task drains dirty shards and emits at most one NOTIFY per shard per +/// debounce window using leading-edge fire plus a trailing-edge flush. The doorbell +/// is a latency optimization only. Correctness comes from the table plus the +/// subscriber poll backstop, so a dropped or failed NOTIFY only adds latency. +pub struct Doorbell { + dirty: [AtomicBool; DOORBELL_SHARD_COUNT], + notify: Notify, + pool: Arc, +} + +impl Doorbell { + pub fn new(pool: Arc) -> Arc { + let doorbell = Arc::new(Self { + dirty: std::array::from_fn(|_| AtomicBool::new(false)), + notify: Notify::new(), + pool, + }); + + let task_doorbell = doorbell.clone(); + tokio::spawn(async move { task_doorbell.run().await }); + + doorbell + } + + /// Marks a shard dirty and wakes the doorbell task. Never blocks. + pub fn mark_dirty(&self, shard: usize) { + self.dirty[shard].store(true, Ordering::Release); + self.notify.notify_one(); + } + + async fn run(self: Arc) { + // Per-shard timestamp of the last NOTIFY emitted by this process. + let mut last_notify: [Option; DOORBELL_SHARD_COUNT] = [None; DOORBELL_SHARD_COUNT]; + // Per-shard deadline for a pending trailing-edge NOTIFY, if any. + let mut trailing: [Option; DOORBELL_SHARD_COUNT] = [None; DOORBELL_SHARD_COUNT]; + + loop { + // Arm on the next pending trailing deadline so the trailing edge fires + // even with no further publishes. Wait on the notify permit otherwise. + let next_deadline = trailing.iter().filter_map(|x| *x).min(); + match next_deadline { + Some(deadline) => { + tokio::select! { + _ = self.notify.notified() => {} + _ = tokio::time::sleep_until(deadline) => {} + } + } + None => { + self.notify.notified().await; + } + } + + let now = Instant::now(); + for shard in 0..DOORBELL_SHARD_COUNT { + let is_dirty = self.dirty[shard].swap(false, Ordering::AcqRel); + if is_dirty { + match last_notify[shard] { + Some(last) if now.duration_since(last) < DOORBELL_WINDOW => { + // Within the window. Defer to a trailing-edge NOTIFY at + // window end so at most one NOTIFY fires per shard per W. + if trailing[shard].is_none() { + trailing[shard] = Some(last + DOORBELL_WINDOW); + } + } + _ => { + // Leading edge. Fire immediately for low idle latency. + self.notify_shard(shard).await; + last_notify[shard] = Some(now); + trailing[shard] = None; + } + } + } + + // Flush a trailing-edge NOTIFY whose window has elapsed. + if let Some(deadline) = trailing[shard] { + if now >= deadline { + self.notify_shard(shard).await; + last_notify[shard] = Some(now); + trailing[shard] = None; + } + } + } + } + } + + async fn notify_shard(&self, shard: usize) { + let channel = shard_channel(shard); + match self.pool.get().await { + Ok(conn) => { + // Payload-free doorbell. The payload lives in the table. + if let Err(err) = conn.execute("SELECT pg_notify($1, '')", &[&channel]).await { + tracing::warn!(?err, %channel, "failed to emit doorbell notify"); + } + } + Err(err) => { + tracing::warn!(?err, %channel, "failed to get connection for doorbell notify"); + } + } + } +} diff --git a/engine/packages/universalpubsub/src/driver/postgres/mod.rs b/engine/packages/universalpubsub/src/driver/postgres/mod.rs index 52bc219b2a..d361eb1149 100644 --- a/engine/packages/universalpubsub/src/driver/postgres/mod.rs +++ b/engine/packages/universalpubsub/src/driver/postgres/mod.rs @@ -1,12 +1,11 @@ -use anyhow::{Context, Result, anyhow}; +use anyhow::{Context, Result}; use async_trait::async_trait; -use base64::Engine; -use base64::engine::general_purpose::STANDARD_NO_PAD as BASE64; use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime}; use futures_util::future::poll_fn; use rivet_postgres_util::build_tls_config; use rivet_util::throttle::Backoff; use scc::HashMap; +use std::collections::VecDeque; use std::hash::{DefaultHasher, Hash, Hasher}; use std::path::PathBuf; use std::sync::Arc; @@ -21,37 +20,37 @@ use crate::driver::{PubSubDriver, SubscriberDriver, SubscriberDriverHandle}; use crate::metrics; use crate::pubsub::DriverOutput; -#[derive(Clone)] -struct Subscription { - // Channel to send messages to this subscription - tx: broadcast::Sender>, -} +mod doorbell; -impl Subscription { - fn new(tx: broadcast::Sender>) -> Self { - Self { tx } - } -} +use doorbell::{Doorbell, shard_channel, shard_for}; -/// > In the default configuration it must be shorter than 8000 bytes -/// -/// https://www.postgresql.org/docs/17/sql-notify.html -const MAX_NOTIFY_LENGTH: usize = 8000; +/// The transport is the table, not the NOTIFY payload, so there is no per-message +/// size cap from the 8000-byte NOTIFY limit. Match the NATS ceiling so chunking +/// behaves identically across drivers. +pub const POSTGRES_MAX_MESSAGE_SIZE: usize = 1024 * 1024; -/// Base64 encoding ratio -const BYTES_PER_BLOCK: usize = 3; -const CHARS_PER_BLOCK: usize = 4; +/// Poll backstop interval. Every subscriber reads its table on this interval +/// regardless of doorbell wakeups. This is the correctness floor that makes delivery +/// independent of any NOTIFY arriving. +const POLL_INTERVAL: Duration = Duration::from_secs(1); -/// Calculate max message size if encoded as base64 -/// -/// We need to remove BYTES_PER_BLOCK since there might be a tail on the base64-encoded data that -/// would bump it over the limit. -pub const POSTGRES_MAX_MESSAGE_SIZE: usize = - (MAX_NOTIFY_LENGTH * BYTES_PER_BLOCK) / CHARS_PER_BLOCK - BYTES_PER_BLOCK; +/// Idle-in-transaction timeout applied to the LISTEN connection. A wedged listener +/// holding a transaction open would otherwise fill the shared notify queue and fail +/// NOTIFY cluster-wide. Bounding it keeps a stuck listener degrading to added latency +/// rather than a cluster outage. +const LISTEN_IDLE_IN_TRANSACTION_TIMEOUT_MS: i64 = 30_000; const QUEUE_SUB_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(10); /// How long a queue subscriber's heartbeat must be within to be considered active. const QUEUE_SUB_TTL_SECS: i64 = 30; + +/// How often to GC expired broadcast messages. +const MESSAGE_GC_INTERVAL: Duration = Duration::from_secs(5); +/// Max age before a broadcast message row is garbage collected. Must exceed the poll +/// interval plus the reconnect gap. A subscriber that falls behind this misses +/// messages, matching NATS-core at-most-once semantics for slow consumers. +const MESSAGE_MAX_AGE_SECS: i64 = 10; + /// How often to GC orphaned queue messages. const QUEUE_MESSAGE_GC_INTERVAL: Duration = Duration::from_secs(300); /// Max age before an unconsumed queue message is garbage collected. @@ -61,9 +60,11 @@ const QUEUE_MESSAGE_MAX_AGE_SECS: i64 = 3600; pub struct PostgresDriver { pool: Arc, client: Arc>>, - subscriptions: Arc>, - /// Wakeup channels for queue subscriptions, keyed by queue channel name. - queue_subscriptions: Arc>, + /// Wakeup channels keyed by doorbell shard channel name. Shared by broadcast and + /// queue subscribers whose subjects map to the same shard. Carries empty wakeups + /// only; payload lives in the table. + shard_subscriptions: Arc>>, + doorbell: Arc, client_ready: tokio::sync::watch::Receiver, } @@ -103,8 +104,9 @@ impl PostgresDriver { .context("failed to create postgres pool")?; tracing::debug!("postgres pool created successfully"); - let subscriptions: Arc> = Arc::new(HashMap::new()); - let queue_subscriptions: Arc> = Arc::new(HashMap::new()); + let pool = Arc::new(pool); + let shard_subscriptions: Arc>> = + Arc::new(HashMap::new()); let client: Arc>> = Arc::new(Mutex::new(None)); // Create channel for client ready notifications @@ -113,8 +115,7 @@ impl PostgresDriver { // Spawn connection lifecycle task tokio::spawn(Self::spawn_connection_lifecycle( conn_str.clone(), - subscriptions.clone(), - queue_subscriptions.clone(), + shard_subscriptions.clone(), client.clone(), ready_tx, ssl_root_cert_path.clone(), @@ -122,26 +123,41 @@ impl PostgresDriver { ssl_client_key_path.clone(), )); + let doorbell = Doorbell::new(pool.clone()); + let driver = Self { - pool: Arc::new(pool), + pool, client, - subscriptions, - queue_subscriptions, + shard_subscriptions, + doorbell, client_ready, }; // Wait for initial connection to be established driver.wait_for_client().await?; - // Create queue tables eagerly so they exist before any publish or subscribe + // Create tables eagerly so they exist before any publish or subscribe. { let conn = driver .pool .get() .await - .context("failed to get connection for queue table creation")?; + .context("failed to get connection for table creation")?; conn.batch_execute( - "CREATE TABLE IF NOT EXISTS ups_queue_subs ( \ + // Broadcast transport table. UNLOGGED gives at-most-once across a + // crash, matching NATS-core semantics, and avoids WAL fsync on every + // publish. The real subject is stored so receivers can verify it and + // reject DefaultHasher subject-hash collisions. + "CREATE UNLOGGED TABLE IF NOT EXISTS ups_messages ( \ + id BIGSERIAL PRIMARY KEY, \ + subject_hash TEXT NOT NULL, \ + subject TEXT NOT NULL, \ + payload BYTEA NOT NULL, \ + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() \ + ); \ + CREATE INDEX IF NOT EXISTS ups_messages_subject_id \ + ON ups_messages (subject_hash, id); \ + CREATE TABLE IF NOT EXISTS ups_queue_subs ( \ id TEXT PRIMARY KEY, \ subject_hash TEXT NOT NULL, \ queue_hash TEXT NOT NULL, \ @@ -149,7 +165,7 @@ impl PostgresDriver { ); \ CREATE INDEX IF NOT EXISTS ups_queue_subs_subject_queue \ ON ups_queue_subs (subject_hash, queue_hash); \ - CREATE TABLE IF NOT EXISTS ups_queue_messages ( \ + CREATE UNLOGGED TABLE IF NOT EXISTS ups_queue_messages ( \ id BIGSERIAL PRIMARY KEY, \ subject_hash TEXT NOT NULL, \ queue_hash TEXT NOT NULL, \ @@ -160,10 +176,33 @@ impl PostgresDriver { ON ups_queue_messages (subject_hash, queue_hash, id);", ) .await - .context("failed to create queue tables")?; - tracing::debug!("queue tables ready"); + .context("failed to create tables")?; + tracing::debug!("tables ready"); } + // Spawn GC task for expired broadcast messages + let message_gc_driver = driver.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(MESSAGE_GC_INTERVAL); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + loop { + interval.tick().await; + if let Ok(conn) = message_gc_driver.pool.get().await { + let result = conn + .execute( + "DELETE FROM ups_messages \ + WHERE created_at < NOW() - ($1::bigint * INTERVAL '1 second')", + &[&MESSAGE_MAX_AGE_SECS], + ) + .await; + if let Err(e) = result { + tracing::warn!(?e, "failed to gc broadcast messages"); + } + } + } + }); + // Spawn GC task for orphaned queue messages let gc_driver = driver.clone(); tokio::spawn(async move { @@ -193,8 +232,7 @@ impl PostgresDriver { /// Manages the connection lifecycle with automatic reconnection async fn spawn_connection_lifecycle( conn_str: String, - subscriptions: Arc>, - queue_subscriptions: Arc>, + shard_subscriptions: Arc>>, client: Arc>>, ready_tx: tokio::sync::watch::Sender, ssl_root_cert_path: Option, @@ -227,41 +265,42 @@ impl PostgresDriver { // Spawn the polling task immediately // This must be done before any operations on the client - let subscriptions_clone = subscriptions.clone(); - let queue_subscriptions_clone = queue_subscriptions.clone(); + let shard_subscriptions_clone = shard_subscriptions.clone(); let poll_handle = tokio::spawn(async move { - Self::poll_connection(conn, subscriptions_clone, queue_subscriptions_clone) - .await; + Self::poll_connection(conn, shard_subscriptions_clone).await; }); - // Get regular channels to re-subscribe to + // Bound a stuck listener so it cannot wedge the shared notify queue. + if let Result::Err(e) = new_client + .execute( + &format!( + "SET idle_in_transaction_session_timeout = '{}'", + LISTEN_IDLE_IN_TRANSACTION_TIMEOUT_MS + ), + &[], + ) + .await + { + tracing::warn!(?e, "failed to set idle_in_transaction_session_timeout"); + } + + // Get shard channels to re-subscribe to let mut channels = Vec::new(); - subscriptions + shard_subscriptions .iter_async(|k, _| { channels.push(k.clone()); true }) .await; - // Get queue wakeup channels to re-subscribe to - let mut queue_channels = Vec::new(); - queue_subscriptions - .iter_async(|k, _| { - queue_channels.push(k.clone()); - true - }) - .await; - - let needs_resubscribe = !channels.is_empty() || !queue_channels.is_empty(); - if needs_resubscribe { + if !channels.is_empty() { tracing::debug!( - regular_channels = channels.len(), - queue_channels = queue_channels.len(), - "re-subscribing to channels after reconnection" + channels = channels.len(), + "re-subscribing to doorbell shards after reconnection" ); } - for channel in channels.iter().chain(queue_channels.iter()) { + for channel in channels.iter() { tracing::debug!(?channel, "re-subscribing to channel"); if let Result::Err(e) = new_client .execute(&format!("LISTEN \"{}\"", channel), &[]) @@ -298,30 +337,20 @@ impl PostgresDriver { /// Polls the connection for notifications until it closes or errors async fn poll_connection( mut conn: tokio_postgres::Connection, - subscriptions: Arc>, - queue_subscriptions: Arc>, + shard_subscriptions: Arc>>, ) where T: tokio_postgres::tls::TlsStream + Unpin, { loop { match poll_fn(|cx| conn.poll_message(cx)).await { Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => { - tracing::trace!(channel = %note.channel(), "received notification"); - if let Some(sub) = subscriptions.get_async(note.channel()).await { - let bytes = match BASE64.decode(note.payload()) { - std::result::Result::Ok(b) => b, - std::result::Result::Err(err) => { - tracing::error!(?err, "failed decoding base64"); - continue; - } - }; - tracing::trace!(channel = %note.channel(), bytes_len = bytes.len(), "sending to broadcast channel"); - let _ = sub.tx.send(bytes); - } else if let Some(sub) = queue_subscriptions.get_async(note.channel()).await { - // Queue notifications are wakeup signals only; payload lives in the table - let _ = sub.tx.send(Vec::new()); + tracing::trace!(channel = %note.channel(), "received doorbell wakeup"); + // Doorbell notifications are payload-free wakeup signals only. + // Subscribers read their payload from the table. + if let Some(sub) = shard_subscriptions.get_async(note.channel()).await { + let _ = sub.send(()); } else { - tracing::warn!(channel = %note.channel(), "received notification for unknown channel"); + tracing::trace!(channel = %note.channel(), "wakeup for unknown shard"); } } Some(std::result::Result::Ok(_)) => { @@ -361,8 +390,9 @@ impl PostgresDriver { } fn hash_subject(&self, subject: &str) -> String { - // Postgres channel names have a 64 character limit - // Hash the subject to ensure it fits + // Postgres channel names have a 64 character limit, but this hash is also the + // table index key. Collisions are possible and resolved by verifying the real + // subject stored alongside each row. let mut hasher = DefaultHasher::new(); subject.hash(&mut hasher); format!("ups_{:x}", hasher.finish()) @@ -374,57 +404,68 @@ impl PostgresDriver { format!("{:x}", hasher.finish()) } - /// Returns the NOTIFY channel name for a (subject, queue) pair. - fn queue_channel(&self, subject_hash: &str, queue_hash: &str) -> String { - // Max length: "ups_q_" (6) + 16 + "_" (1) + 16 = 39 chars, well within 64 - format!("ups_q_{}_{}", subject_hash, queue_hash) - } - - /// Inserts messages into the queue table and notifies active queue subscribers. - async fn publish_to_queues(&self, subject: &str, payload: &[u8]) -> Result<()> { - let subject_hash = self.hash_subject(subject); - + /// Returns the current max broadcast message id, used as a subscriber's starting + /// cursor so it only sees future messages (NATS at-most-once, no replay). + async fn current_max_id(&self) -> Result { let conn = self .pool .get() .await - .context("failed to get connection for queue publish")?; - - // Find active queue groups for this subject - let rows = conn - .query( - "SELECT DISTINCT queue_hash FROM ups_queue_subs \ - WHERE subject_hash = $1 \ - AND heartbeat_at > NOW() - ($2::bigint * INTERVAL '1 second')", - &[&subject_hash, &QUEUE_SUB_TTL_SECS], - ) + .context("failed to get connection for cursor init")?; + let row = conn + .query_one("SELECT COALESCE(MAX(id), 0) FROM ups_messages", &[]) .await - .context("failed to query active queue subs")?; + .context("failed to read current max id")?; + Ok(row.get(0)) + } - for row in rows { - let queue_hash: String = row.get(0); - let channel = self.queue_channel(&subject_hash, &queue_hash); + /// Ensures this process is LISTENing on the given doorbell shard and returns a + /// wakeup receiver plus a drop guard that UNLISTENs once no receivers remain. + async fn ensure_shard_listen( + &self, + shard: usize, + ) -> (broadcast::Receiver<()>, tokio_util::sync::DropGuard) { + let channel = shard_channel(shard); - conn.execute( - "INSERT INTO ups_queue_messages (subject_hash, queue_hash, payload) \ - VALUES ($1, $2, $3)", - &[&subject_hash, &queue_hash, &payload], - ) - .await - .context("failed to insert queue message")?; + match self.shard_subscriptions.entry_async(channel.clone()).await { + scc::hash_map::Entry::Occupied(existing) => { + let rx = existing.subscribe(); + let drop_guard = + self.spawn_shard_cleanup_task(channel.clone(), existing.get().clone()); + (rx, drop_guard) + } + scc::hash_map::Entry::Vacant(e) => { + let (tx, rx) = broadcast::channel(1024); + e.insert_entry(tx.clone()); + metrics::POSTGRES_SUBSCRIPTION_COUNT.set(self.shard_subscriptions.len() as i64); - conn.execute(&format!("NOTIFY \"{}\"", channel), &[]) - .await - .context("failed to notify queue channel")?; - } + if let Some(client) = &*self.client.lock().await { + match client + .execute(&format!("LISTEN \"{channel}\""), &[]) + .instrument(tracing::trace_span!("pg_listen")) + .await + { + Result::Ok(_) => { + tracing::debug!(%channel, "successfully subscribed to shard"); + } + Result::Err(e) => { + tracing::warn!(?e, %channel, "failed to LISTEN, will retry on reconnection"); + } + } + } else { + tracing::debug!(%channel, "client not connected, will LISTEN on reconnection"); + } - Ok(()) + let drop_guard = self.spawn_shard_cleanup_task(channel.clone(), tx.clone()); + (rx, drop_guard) + } + } } - fn spawn_subscription_cleanup_task( + fn spawn_shard_cleanup_task( &self, - subject_hash: String, - tx: broadcast::Sender>, + channel: String, + tx: broadcast::Sender<()>, ) -> tokio_util::sync::DropGuard { let driver = self.clone(); let token = tokio_util::sync::CancellationToken::new(); @@ -434,47 +475,72 @@ impl PostgresDriver { token.cancelled().await; if tx.receiver_count() == 0 { if let Some(client) = &*driver.client.lock().await { - let sql = format!("UNLISTEN \"{}\"", subject_hash); + let sql = format!("UNLISTEN \"{}\"", channel); if let Err(err) = client.execute(sql.as_str(), &[]).await { - tracing::warn!(?err, %subject_hash, "failed to UNLISTEN channel"); + tracing::warn!(?err, %channel, "failed to UNLISTEN channel"); } else { - tracing::trace!(%subject_hash, "unlistened channel"); + tracing::trace!(%channel, "unlistened channel"); } } - driver.subscriptions.remove_async(&subject_hash).await; - metrics::POSTGRES_SUBSCRIPTION_COUNT.set(driver.subscriptions.len() as i64); + driver.shard_subscriptions.remove_async(&channel).await; + metrics::POSTGRES_SUBSCRIPTION_COUNT.set(driver.shard_subscriptions.len() as i64); } }); drop_guard } - fn spawn_queue_subscription_cleanup_task( + /// Inserts the broadcast row and any active queue-group rows in one transaction. + async fn try_publish_to_db( &self, - channel: String, - tx: broadcast::Sender>, - ) -> tokio_util::sync::DropGuard { - let driver = self.clone(); - let token = tokio_util::sync::CancellationToken::new(); - let drop_guard = token.clone().drop_guard(); + subject: &str, + subject_hash: &str, + payload: &[u8], + ) -> Result<()> { + let mut conn = self + .pool + .get() + .await + .context("failed to get connection for publish")?; + let tx = conn + .transaction() + .await + .context("failed to begin publish transaction")?; - tokio::spawn(async move { - token.cancelled().await; - if tx.receiver_count() == 0 { - if let Some(client) = &*driver.client.lock().await { - let sql = format!("UNLISTEN \"{}\"", channel); + // Broadcast row. + tx.execute( + "INSERT INTO ups_messages (subject_hash, subject, payload) VALUES ($1, $2, $3)", + &[&subject_hash, &subject, &payload], + ) + .await + .context("failed to insert broadcast message")?; - if let Err(err) = client.execute(sql.as_str(), &[]).await { - tracing::warn!(?err, %channel, "failed to UNLISTEN queue channel"); - } else { - tracing::trace!(%channel, "unlistened queue channel"); - } - } - driver.queue_subscriptions.remove_async(&channel).await; - } - }); + // Queue rows for every active queue group on this subject. Batched into the + // same transaction so a crash never strands a row mid-publish. + let rows = tx + .query( + "SELECT DISTINCT queue_hash FROM ups_queue_subs \ + WHERE subject_hash = $1 \ + AND heartbeat_at > NOW() - ($2::bigint * INTERVAL '1 second')", + &[&subject_hash, &QUEUE_SUB_TTL_SECS], + ) + .await + .context("failed to query active queue subs")?; - drop_guard + for row in rows { + let queue_hash: String = row.get(0); + tx.execute( + "INSERT INTO ups_queue_messages (subject_hash, queue_hash, payload) \ + VALUES ($1, $2, $3)", + &[&subject_hash, &queue_hash, &payload], + ) + .await + .context("failed to insert queue message")?; + } + + tx.commit().await.context("failed to commit publish")?; + + Ok(()) } } @@ -485,66 +551,23 @@ impl PubSubDriver for PostgresDriver { subject: &str, _reply_id: Option, ) -> Result { - // TODO: To match NATS implementation, LISTEN must be pipelined (i.e. wait for the command - // to reach the server, but not wait for it to respond). However, this has to ensure that - // NOTIFY & LISTEN are called on the same connection (not diff connections in a pool) or - // else there will be race conditions where messages might be published before - // subscriptions are registered. - // - // tokio-postgres currently does not expose the API for pipelining, so we are SOL. - // - // We might be able to use a background tokio task in combination with flush if we use the - // same Postgres connection, but unsure if that will create a bottleneck. - - let hashed = self.hash_subject(subject); - - // Check if we already have a subscription for this channel - let (rx, drop_guard) = match self.subscriptions.entry_async(hashed.clone()).await { - scc::hash_map::Entry::Occupied(existing_sub) => { - // Reuse the existing broadcast channel - let rx = existing_sub.tx.subscribe(); - let drop_guard = - self.spawn_subscription_cleanup_task(hashed.clone(), existing_sub.tx.clone()); - (rx, drop_guard) - } - scc::hash_map::Entry::Vacant(e) => { - // Create a new broadcast channel for this subject - let (tx, rx) = tokio::sync::broadcast::channel(1024); - let subscription = Subscription::new(tx.clone()); - - // Register subscription - e.insert_entry(subscription.clone()); - metrics::POSTGRES_SUBSCRIPTION_COUNT.set(self.subscriptions.len() as i64); - - // Execute LISTEN command on the async client (for receiving notifications) - // This only needs to be done once per channel - // Try to LISTEN if client is available, but don't fail if disconnected - // The reconnection logic will handle re-subscribing - if let Some(client) = &*self.client.lock().await { - match client - .execute(&format!("LISTEN \"{hashed}\""), &[]) - .instrument(tracing::trace_span!("pg_listen")) - .await - { - Result::Ok(_) => { - tracing::debug!(%hashed, "successfully subscribed to channel"); - } - Result::Err(e) => { - tracing::warn!(?e, %hashed, "failed to LISTEN, will retry on reconnection"); - } - } - } else { - tracing::debug!(%hashed, "client not connected, will LISTEN on reconnection"); - } + let subject_hash = self.hash_subject(subject); + let shard = shard_for(&subject_hash); - let drop_guard = self.spawn_subscription_cleanup_task(hashed.clone(), tx.clone()); - (rx, drop_guard) - } - }; + // Capture the cursor before LISTENing. Any message inserted after this point + // has a higher id and is delivered either by the doorbell wakeup or the poll + // backstop, so there is no subscribe/publish race. + let cursor = self.current_max_id().await?; + + let (rx, drop_guard) = self.ensure_shard_listen(shard).await; Ok(Box::new(PostgresSubscriber { subject: subject.to_string(), - rx: Some(rx), + subject_hash, + pool: self.pool.clone(), + cursor, + buffer: VecDeque::new(), + rx, _drop_guard: drop_guard, })) } @@ -552,7 +575,7 @@ impl PubSubDriver for PostgresDriver { async fn queue_subscribe(&self, subject: &str, queue: &str) -> Result { let subject_hash = self.hash_subject(subject); let queue_hash = self.hash_queue(queue); - let channel = self.queue_channel(&subject_hash, &queue_hash); + let shard = shard_for(&subject_hash); // Register this subscriber in the database so publishers know the queue exists let sub_id = Uuid::new_v4().to_string(); @@ -570,44 +593,7 @@ impl PubSubDriver for PostgresDriver { .context("failed to register queue subscriber")?; } - // Set up a shared LISTEN/broadcast channel for the wakeup signal - let (rx, drop_guard) = match self.queue_subscriptions.entry_async(channel.clone()).await { - scc::hash_map::Entry::Occupied(existing_sub) => { - let rx = existing_sub.tx.subscribe(); - let drop_guard = self.spawn_queue_subscription_cleanup_task( - channel.clone(), - existing_sub.tx.clone(), - ); - (rx, drop_guard) - } - scc::hash_map::Entry::Vacant(e) => { - let (tx, rx) = tokio::sync::broadcast::channel(1024); - let subscription = Subscription::new(tx.clone()); - - e.insert_entry(subscription.clone()); - - if let Some(client) = &*self.client.lock().await { - match client - .execute(&format!("LISTEN \"{}\"", channel), &[]) - .instrument(tracing::trace_span!("pg_listen_queue")) - .await - { - Result::Ok(_) => { - tracing::debug!(%channel, "successfully subscribed to queue channel"); - } - Result::Err(e) => { - tracing::warn!(?e, %channel, "failed to LISTEN queue channel, will retry on reconnection"); - } - } - } else { - tracing::debug!(%channel, "client not connected, will LISTEN queue channel on reconnection"); - } - - let drop_guard = - self.spawn_queue_subscription_cleanup_task(channel.clone(), tx.clone()); - (rx, drop_guard) - } - }; + let (rx, drop_guard) = self.ensure_shard_listen(shard).await; // Spawn heartbeat task to keep the registration alive let pool = self.pool.clone(); @@ -644,7 +630,7 @@ impl PubSubDriver for PostgresDriver { queue_hash, sub_id, pool: self.pool.clone(), - rx: Some(rx), + rx, _drop_guard: drop_guard, _heartbeat_token: heartbeat_token, })) @@ -656,78 +642,33 @@ impl PubSubDriver for PostgresDriver { payload: &[u8], _reply_subject: Option<&str>, ) -> Result<()> { - // TODO: See `subscribe` about pipelining - - // Encode payload to base64 and send NOTIFY - let encoded = BASE64.encode(payload); - let hashed = self.hash_subject(subject); - - tracing::trace!("attempting to get connection for publish"); - - // Wait for listen connection to be ready first if this channel has subscribers - // This ensures that if we're reconnecting, the LISTEN is re-registered before NOTIFY - if self.subscriptions.contains_async(&hashed).await { - self.wait_for_client().await?; - } + let subject_hash = self.hash_subject(subject); + let shard = shard_for(&subject_hash); - // Retry getting a connection from the pool with backoff in case the connection is - // currently disconnected + // Persist the message, retrying on transient connection errors. The row is + // committed before the doorbell rings so any wakeup observes it. let mut backoff = Backoff::default(); - let mut last_error; - loop { - match self.pool.get().await { - Result::Ok(conn) => { - // Test the connection with a simple query before using it - match conn.execute("SELECT 1", &[]).await { - Result::Ok(_) => { - // Connection is good; run NOTIFY and queue publish in parallel. - // publish_to_queues acquires its own pool connection so both - // can proceed concurrently. - let notify_sql = format!("NOTIFY \"{hashed}\", '{encoded}'"); - let (notify_result, queue_result) = tokio::join!( - conn.execute(notify_sql.as_str(), &[]) - .instrument(tracing::trace_span!("pg_notify")), - self.publish_to_queues(subject, payload), - ); - match notify_result { - Result::Ok(_) => { - if let Err(e) = queue_result { - tracing::warn!(?e, %subject, "failed to publish to queue subscribers"); - } - return Ok(()); - } - Result::Err(e) => { - tracing::debug!( - ?e, - "NOTIFY failed, retrying with new connection" - ); - last_error = Some(e.into()); - } - } - } - Result::Err(e) => { - tracing::debug!( - ?e, - "connection test failed, retrying with new connection" - ); - last_error = Some(e.into()); - } - } - } + match self + .try_publish_to_db(subject, &subject_hash, payload) + .await + { + Result::Ok(()) => break, Result::Err(e) => { - tracing::debug!(?e, "failed to get connection from pool, retrying"); - last_error = Some(e.into()); + if !backoff.tick().await { + tracing::warn!(?e, %subject, "failed to publish, cannot retry again"); + return Err(e); + } + tracing::debug!(?e, "publish failed, retrying"); } } - - // Check if we should continue retrying - if !backoff.tick().await { - return Err( - last_error.unwrap_or_else(|| anyhow!("failed to publish after retries")) - ); - } } + + // Ring the doorbell. Best-effort: the subscriber poll backstop covers a + // dropped or coalesced wakeup, so publish never blocks on NOTIFY. + self.doorbell.mark_dirty(shard); + + Ok(()) } async fn flush(&self) -> Result<()> { @@ -741,26 +682,80 @@ impl PubSubDriver for PostgresDriver { pub struct PostgresSubscriber { subject: String, - rx: Option>>, + subject_hash: String, + pool: Arc, + cursor: i64, + buffer: VecDeque>, + rx: broadcast::Receiver<()>, _drop_guard: tokio_util::sync::DropGuard, } +impl PostgresSubscriber { + /// Reads new rows past the cursor into the buffer, advancing the cursor. Rows + /// whose stored subject does not match are skipped (DefaultHasher collisions) but + /// still advance the cursor. + async fn fetch(&mut self) -> Result<()> { + let conn = self + .pool + .get() + .await + .context("failed to get connection for poll")?; + let rows = conn + .query( + "SELECT id, subject, payload FROM ups_messages \ + WHERE subject_hash = $1 AND id > $2 ORDER BY id", + &[&self.subject_hash, &self.cursor], + ) + .await + .context("failed to poll broadcast messages")?; + + for row in rows { + let id: i64 = row.get(0); + let subject: String = row.get(1); + let payload: Vec = row.get(2); + self.cursor = id; + if subject == self.subject { + self.buffer.push_back(payload); + } + } + + Ok(()) + } +} + #[async_trait] impl SubscriberDriver for PostgresSubscriber { async fn next(&mut self) -> Result { - let rx = match self.rx.as_mut() { - Some(rx) => rx, - None => return Ok(DriverOutput::Unsubscribed), - }; - match rx.recv().await { - std::result::Result::Ok(payload) => Ok(DriverOutput::Message { - subject: self.subject.clone(), - payload, - }), - Err(tokio::sync::broadcast::error::RecvError::Closed) => Ok(DriverOutput::Unsubscribed), - Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { - // Try again - self.next().await + loop { + if let Some(payload) = self.buffer.pop_front() { + return Ok(DriverOutput::Message { + subject: self.subject.clone(), + payload, + }); + } + + if let Err(e) = self.fetch().await { + // Transient DB errors must not kill the subscriber; the next poll + // tick retries. + tracing::warn!(?e, subject = %self.subject, "failed to poll, will retry"); + } + + if !self.buffer.is_empty() { + continue; + } + + // Wait for a doorbell wakeup or the poll backstop, whichever is first. + tokio::select! { + res = self.rx.recv() => { + match res { + std::result::Result::Ok(()) => {} + Err(broadcast::error::RecvError::Lagged(_)) => {} + Err(broadcast::error::RecvError::Closed) => { + return Ok(DriverOutput::Unsubscribed); + } + } + } + _ = tokio::time::sleep(POLL_INTERVAL) => {} } } } @@ -772,7 +767,7 @@ pub struct PostgresQueueSubscriber { queue_hash: String, sub_id: String, pool: Arc, - rx: Option>>, + rx: broadcast::Receiver<()>, _drop_guard: tokio_util::sync::DropGuard, _heartbeat_token: tokio_util::sync::CancellationToken, } @@ -811,31 +806,32 @@ impl PostgresQueueSubscriber { impl SubscriberDriver for PostgresQueueSubscriber { async fn next(&mut self) -> Result { loop { - // Drain any messages that arrived before or between notifications. - // Do this before borrowing rx so claim_message can borrow self freely. - if let Some(payload) = self.claim_message().await? { - return Ok(DriverOutput::Message { - subject: self.subject.clone(), - payload, - }); - } - - // Wait for a wakeup notification, then loop back to claim. - let rx = match self.rx.as_mut() { - Some(rx) => rx, - None => return Ok(DriverOutput::Unsubscribed), - }; - match rx.recv().await { - std::result::Result::Ok(_) => { - // Wakeup received; loop back to claim + // Drain any messages that arrived before or between wakeups. + match self.claim_message().await { + Result::Ok(Some(payload)) => { + return Ok(DriverOutput::Message { + subject: self.subject.clone(), + payload, + }); } - Err(tokio::sync::broadcast::error::RecvError::Closed) => { - return Ok(DriverOutput::Unsubscribed); + Result::Ok(None) => {} + Result::Err(e) => { + tracing::warn!(?e, subject = %self.subject, "failed to claim, will retry"); } - Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { - // Notifications were dropped while lagged; loop back to claim in case - // messages are waiting + } + + // Wait for a doorbell wakeup or the poll backstop, then loop back to claim. + tokio::select! { + res = self.rx.recv() => { + match res { + std::result::Result::Ok(()) => {} + Err(broadcast::error::RecvError::Lagged(_)) => {} + Err(broadcast::error::RecvError::Closed) => { + return Ok(DriverOutput::Unsubscribed); + } + } } + _ = tokio::time::sleep(POLL_INTERVAL) => {} } } }