Skip to content

Commit 116e8c8

Browse files
authored
Merge pull request #405 from superfly/somtochi/cancel-pg
Implement CancelRequest for PG
2 parents 3a87a2f + dbcec42 commit 116e8c8

File tree

4 files changed

+119
-6
lines changed

4 files changed

+119
-6
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/corro-pg/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ parking_lot.workspace = true
2020
pgwire = { version = "0.32", default-features = false, features = ["server-api-ring"] }
2121
pin-project-lite.workspace = true
2222
postgres-types = { version = "0.2", features = ["with-time-0_3"] }
23+
rand = { workspace = true }
2324
rusqlite = { workspace = true }
2425
rustls = { workspace = true }
2526
spawn = { path = "../spawn" }

crates/corro-pg/src/lib.rs

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ use std::{
1212
net::SocketAddr,
1313
rc::Rc,
1414
str::{FromStr, Utf8Error},
15-
sync::Arc,
15+
sync::{
16+
atomic::{AtomicI32, Ordering},
17+
Arc,
18+
},
1619
time::Duration,
1720
};
1821

@@ -46,6 +49,7 @@ use pgwire::{
4649
types::FromSqlText,
4750
};
4851
use postgres_types::{FromSql, Type};
52+
use rand::{rngs::StdRng, Rng, SeedableRng};
4953
use rusqlite::{
5054
ffi::SQLITE_CONSTRAINT_UNIQUE, functions::FunctionFlags, types::ValueRef,
5155
vtab::eponymous_only_module, Connection, Statement,
@@ -62,7 +66,7 @@ use tokio::{
6266
net::TcpListener,
6367
sync::{
6468
mpsc::{channel, Sender},
65-
AcquireError, OwnedSemaphorePermit,
69+
AcquireError, OwnedSemaphorePermit, RwLock as TokioRwLock,
6670
},
6771
time::timeout,
6872
};
@@ -436,6 +440,37 @@ enum OpenTxKind {
436440
Explicit,
437441
}
438442

443+
#[derive(Debug, Clone)]
444+
struct CancelInfo {
445+
cancel: CancellationToken,
446+
secret_key: i32,
447+
}
448+
449+
#[derive(Debug, Clone, Default)]
450+
pub struct PgTaskCancellation(Arc<TokioRwLock<HashMap<i32, CancelInfo>>>);
451+
452+
impl PgTaskCancellation {
453+
pub async fn insert(&self, conn_id: i32, cancel: CancellationToken, secret_key: i32) {
454+
self.0
455+
.write()
456+
.await
457+
.insert(conn_id, CancelInfo { cancel, secret_key });
458+
}
459+
460+
pub async fn remove(&self, conn_id: i32) {
461+
self.0.write().await.remove(&conn_id);
462+
}
463+
464+
pub async fn get_and_verify(&self, conn_id: i32, secret_key: i32) -> Option<CancellationToken> {
465+
if let Some(cancel_info) = self.0.read().await.get(&conn_id).cloned() {
466+
if cancel_info.secret_key == secret_key {
467+
return Some(cancel_info.cancel);
468+
}
469+
}
470+
None
471+
}
472+
}
473+
439474
#[derive(Debug, thiserror::Error)]
440475
pub enum PgStartError {
441476
#[error(transparent)]
@@ -521,6 +556,8 @@ pub async fn start(
521556
"protocol" => "pg",
522557
"readonly" => readonly.to_string(),
523558
);
559+
let conn_counter = AtomicI32::new(0);
560+
let task_cancellation = PgTaskCancellation::default();
524561

525562
spawn_counted(async move {
526563
let mut conn_tripwire = tripwire.clone();
@@ -540,6 +577,8 @@ pub async fn start(
540577
let tripwire = tripwire.clone();
541578
// Don't use spawn_counted here
542579
// Until the connection gets fully established we don't need to gracefully close it
580+
let conn_id = conn_counter.fetch_add(1, Ordering::SeqCst);
581+
let task_cancellation = task_cancellation.clone();
543582
tokio::spawn(async move {
544583
conn.stream.set_nodelay(true)?;
545584
{
@@ -601,6 +640,21 @@ pub async fn start(
601640
PgWireFrontendMessage::Startup(startup) => {
602641
debug!("received startup message: {startup:?}");
603642
}
643+
PgWireFrontendMessage::CancelRequest(cancel_request) => {
644+
debug!("received cancel request: {cancel_request:?}");
645+
646+
if let Some(secret_key) = cancel_request.secret_key.as_i32() {
647+
if let Some(cancel) = task_cancellation
648+
.get_and_verify(cancel_request.pid, secret_key)
649+
.await
650+
{
651+
cancel.cancel();
652+
} else {
653+
warn!("invalid secret key for cancel request");
654+
}
655+
}
656+
return Ok(());
657+
}
604658
_ => {
605659
framed
606660
.send(PgWireBackendMessage::ErrorResponse(
@@ -631,6 +685,23 @@ pub async fn start(
631685
)))
632686
.await?;
633687

688+
let mut rng = StdRng::from_os_rng();
689+
let secret_key: i32 = rng.random::<i32>();
690+
691+
let cancel = CancellationToken::new();
692+
task_cancellation
693+
.insert(conn_id, cancel.clone(), secret_key)
694+
.await;
695+
696+
framed
697+
.feed(PgWireBackendMessage::BackendKeyData(
698+
pgwire::messages::startup::BackendKeyData::new(
699+
conn_id,
700+
pgwire::messages::startup::SecretKey::I32(secret_key),
701+
),
702+
))
703+
.await?;
704+
634705
framed
635706
.feed(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new(
636707
TransactionStatus::Idle,
@@ -646,8 +717,6 @@ pub async fn start(
646717

647718
let (mut sink, mut stream) = framed.split();
648719

649-
let cancel = CancellationToken::new();
650-
651720
// If we're shutting down corrosion, both frontend and backend tasks will finish
652721
let mut frontend_task = spawn_counted({
653722
// Use a weak sender here; it should NOT hold the backend channel (and half-connection) open
@@ -1941,14 +2010,14 @@ pub async fn start(
19412010
continue;
19422011
}
19432012
PgWireFrontendMessage::CancelRequest(_) => {
1944-
// cancel.cancel(); ?
2013+
// A CancelRequest is only valid as the first message on a new connection, so one arriving here is unexpected.
19452014
back_tx.blocking_send(
19462015
(
19472016
PgWireBackendMessage::ErrorResponse(
19482017
ErrorInfo::new(
19492018
"ERROR".into(),
19502019
"XX000".to_owned(),
1951-
"Cancel is not implemented".into(),
2020+
"Unexpected Cancel message".into(),
19522021
)
19532022
.into(),
19542023
),
@@ -2034,11 +2103,13 @@ pub async fn start(
20342103
// The message-handling loop has completed, make sure we also abort the tasks
20352104
// handling the TCP connection
20362105
// Firstly we attempt a graceful shutdown -- dropping back_tx will cause
2106+
20372107
// backend_task to complete once it writes all content to the TCP socket
20382108
// Then, frontend_task will eventually receive an EOF if clients behave properly
20392109
// Note that this should be the only reference of back_tx at this point:
20402110
// the one in frontend_task is weak, and the one cloned into the message-handling
20412111
// thread should have been dropped.
2112+
task_cancellation.remove(conn_id).await;
20422113
assert_eq!(back_tx.strong_count(), 1);
20432114
drop(back_tx);
20442115

crates/corro-pg/tests/tests.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,46 @@ async fn test_unnest_max_parameters() {
925925
wait_for_all_pending_handles().await;
926926
}
927927

928+
#[tokio::test(flavor = "multi_thread")]
929+
async fn test_request_cancellation() {
930+
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
931+
932+
let (_ta, server) = setup_pg_test_server(tripwire, None).await;
933+
934+
let conn_str = format!(
935+
"host={} port={} user=testuser",
936+
server.local_addr.ip(),
937+
server.local_addr.port()
938+
);
939+
940+
{
941+
let (client, client_conn) = tokio_postgres::connect(&conn_str, NoTls).await.unwrap();
942+
println!("client is ready!");
943+
tokio::spawn(client_conn);
944+
945+
// cancel the query after 2 seconds
946+
let cancel_token = client.cancel_token();
947+
tokio::spawn(async move {
948+
tokio::time::sleep(Duration::from_secs(2)).await;
949+
cancel_token.cancel_query(NoTls).await.unwrap();
950+
});
951+
952+
let res = client.query("WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x + 1 FROM cnt WHERE x < 1000000000) SELECT MAX(x) FROM cnt", &[]).await;
953+
assert!(res.is_err());
954+
assert!(res
955+
.err()
956+
.unwrap()
957+
.as_db_error()
958+
.unwrap()
959+
.message()
960+
.contains("interrupted"));
961+
}
962+
963+
tripwire_tx.send(()).await.ok();
964+
tripwire_worker.await;
965+
wait_for_all_pending_handles().await;
966+
}
967+
928968
#[tokio::test(flavor = "multi_thread")]
929969
async fn test_unnest_vtab() {
930970
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();

0 commit comments

Comments
 (0)