@@ -12,7 +12,10 @@ use std::{
1212 net:: SocketAddr ,
1313 rc:: Rc ,
1414 str:: { FromStr , Utf8Error } ,
15- sync:: Arc ,
15+ sync:: {
16+ atomic:: { AtomicI32 , Ordering } ,
17+ Arc ,
18+ } ,
1619 time:: Duration ,
1720} ;
1821
@@ -46,6 +49,7 @@ use pgwire::{
4649 types:: FromSqlText ,
4750} ;
4851use postgres_types:: { FromSql , Type } ;
52+ use rand:: { rngs:: StdRng , Rng , SeedableRng } ;
4953use rusqlite:: {
5054 ffi:: SQLITE_CONSTRAINT_UNIQUE , functions:: FunctionFlags , types:: ValueRef ,
5155 vtab:: eponymous_only_module, Connection , Statement ,
@@ -62,7 +66,7 @@ use tokio::{
6266 net:: TcpListener ,
6367 sync:: {
6468 mpsc:: { channel, Sender } ,
65- AcquireError , OwnedSemaphorePermit ,
69+ AcquireError , OwnedSemaphorePermit , RwLock as TokioRwLock ,
6670 } ,
6771 time:: timeout,
6872} ;
@@ -436,6 +440,37 @@ enum OpenTxKind {
436440 Explicit ,
437441}
438442
443+ #[ derive( Debug , Clone ) ]
444+ struct CancelInfo {
445+ cancel : CancellationToken ,
446+ secret_key : i32 ,
447+ }
448+
449+ #[ derive( Debug , Clone , Default ) ]
450+ pub struct PgTaskCancellation ( Arc < TokioRwLock < HashMap < i32 , CancelInfo > > > ) ;
451+
452+ impl PgTaskCancellation {
453+ pub async fn insert ( & self , conn_id : i32 , cancel : CancellationToken , secret_key : i32 ) {
454+ self . 0
455+ . write ( )
456+ . await
457+ . insert ( conn_id, CancelInfo { cancel, secret_key } ) ;
458+ }
459+
460+ pub async fn remove ( & self , conn_id : i32 ) {
461+ self . 0 . write ( ) . await . remove ( & conn_id) ;
462+ }
463+
464+ pub async fn get_and_verify ( & self , conn_id : i32 , secret_key : i32 ) -> Option < CancellationToken > {
465+ if let Some ( cancel_info) = self . 0 . read ( ) . await . get ( & conn_id) . cloned ( ) {
466+ if cancel_info. secret_key == secret_key {
467+ return Some ( cancel_info. cancel ) ;
468+ }
469+ }
470+ None
471+ }
472+ }
473+
439474#[ derive( Debug , thiserror:: Error ) ]
440475pub enum PgStartError {
441476 #[ error( transparent) ]
@@ -521,6 +556,8 @@ pub async fn start(
521556 "protocol" => "pg" ,
522557 "readonly" => readonly. to_string( ) ,
523558 ) ;
559+ let conn_counter = AtomicI32 :: new ( 0 ) ;
560+ let task_cancellation = PgTaskCancellation :: default ( ) ;
524561
525562 spawn_counted ( async move {
526563 let mut conn_tripwire = tripwire. clone ( ) ;
@@ -540,6 +577,8 @@ pub async fn start(
540577 let tripwire = tripwire. clone ( ) ;
541578 // Don't use spawn_counted here
542579 // Until the connection gets fully established we don't need to gracefully close it
580+ let conn_id = conn_counter. fetch_add ( 1 , Ordering :: SeqCst ) ;
581+ let task_cancellation = task_cancellation. clone ( ) ;
543582 tokio:: spawn ( async move {
544583 conn. stream . set_nodelay ( true ) ?;
545584 {
@@ -601,6 +640,21 @@ pub async fn start(
601640 PgWireFrontendMessage :: Startup ( startup) => {
602641 debug ! ( "received startup message: {startup:?}" ) ;
603642 }
643+ PgWireFrontendMessage :: CancelRequest ( cancel_request) => {
644+ debug ! ( "received cancel request: {cancel_request:?}" ) ;
645+
646+ if let Some ( secret_key) = cancel_request. secret_key . as_i32 ( ) {
647+ if let Some ( cancel) = task_cancellation
648+ . get_and_verify ( cancel_request. pid , secret_key)
649+ . await
650+ {
651+ cancel. cancel ( ) ;
652+ } else {
653+ warn ! ( "invalid secret key for cancel request" ) ;
654+ }
655+ }
656+ return Ok ( ( ) ) ;
657+ }
604658 _ => {
605659 framed
606660 . send ( PgWireBackendMessage :: ErrorResponse (
@@ -631,6 +685,23 @@ pub async fn start(
631685 ) ) )
632686 . await ?;
633687
688+ let mut rng = StdRng :: from_os_rng ( ) ;
689+ let secret_key: i32 = rng. random :: < i32 > ( ) ;
690+
691+ let cancel = CancellationToken :: new ( ) ;
692+ task_cancellation
693+ . insert ( conn_id, cancel. clone ( ) , secret_key)
694+ . await ;
695+
696+ framed
697+ . feed ( PgWireBackendMessage :: BackendKeyData (
698+ pgwire:: messages:: startup:: BackendKeyData :: new (
699+ conn_id,
700+ pgwire:: messages:: startup:: SecretKey :: I32 ( secret_key) ,
701+ ) ,
702+ ) )
703+ . await ?;
704+
634705 framed
635706 . feed ( PgWireBackendMessage :: ReadyForQuery ( ReadyForQuery :: new (
636707 TransactionStatus :: Idle ,
@@ -646,8 +717,6 @@ pub async fn start(
646717
647718 let ( mut sink, mut stream) = framed. split ( ) ;
648719
649- let cancel = CancellationToken :: new ( ) ;
650-
651720 // If we're shutting down corrosion, both frontend and backend tasks will finish
652721 let mut frontend_task = spawn_counted ( {
653722 // Use a weak sender here; it should NOT hold the backend channel (and half-connection) open
@@ -1941,14 +2010,14 @@ pub async fn start(
19412010 continue ;
19422011 }
19432012 PgWireFrontendMessage :: CancelRequest ( _) => {
1944- // cancel.cancel(); ?
2013+ // cancel should be sent as first message on a new connection.
19452014 back_tx. blocking_send (
19462015 (
19472016 PgWireBackendMessage :: ErrorResponse (
19482017 ErrorInfo :: new (
19492018 "ERROR" . into ( ) ,
19502019 "XX000" . to_owned ( ) ,
1951- "Cancel is not implemented " . into ( ) ,
2020+ "Unexpected Cancel message " . into ( ) ,
19522021 )
19532022 . into ( ) ,
19542023 ) ,
@@ -2034,11 +2103,13 @@ pub async fn start(
20342103 // The message-handling loop has completed, make sure we also abort the tasks
20352104 // handling the TCP connection
20362105 // Firstly we attempt a graceful shutdown -- dropping back_tx will cause
2106+
20372107 // backend_task to complete once it writes all content to the TCP socket
20382108 // Then, frontend_task will eventually receive an EOF if clients behave properly
20392109 // Note that this should be the only reference of back_tx at this point:
20402110 // the one in frontend_task is weak, and the one cloned into the message-handling
20412111 // thread should have been dropped.
2112+ task_cancellation. remove ( conn_id) . await ;
20422113 assert_eq ! ( back_tx. strong_count( ) , 1 ) ;
20432114 drop ( back_tx) ;
20442115
0 commit comments