diff --git a/Cargo.lock b/Cargo.lock index 8f036f1dc7..f1a4280d57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5646,6 +5646,7 @@ dependencies = [ "bytes", "compio-buf", "enumset", + "secrecy", "smallvec", "thiserror 2.0.18", ] diff --git a/core/binary_protocol/Cargo.toml b/core/binary_protocol/Cargo.toml index cab10cc8c6..d550c5354a 100644 --- a/core/binary_protocol/Cargo.toml +++ b/core/binary_protocol/Cargo.toml @@ -34,6 +34,7 @@ bytemuck = { workspace = true } bytes = { workspace = true } compio-buf = { workspace = true } enumset = { workspace = true } +secrecy = { workspace = true } smallvec = { workspace = true } thiserror = { workspace = true } diff --git a/core/binary_protocol/src/codes.rs b/core/binary_protocol/src/codes.rs index a088b71fb2..f663c85d0d 100644 --- a/core/binary_protocol/src/codes.rs +++ b/core/binary_protocol/src/codes.rs @@ -39,6 +39,8 @@ pub const UPDATE_PERMISSIONS_CODE: u32 = 36; pub const CHANGE_PASSWORD_CODE: u32 = 37; pub const LOGIN_USER_CODE: u32 = 38; pub const LOGOUT_USER_CODE: u32 = 39; +pub const LOGIN_REGISTER_CODE: u32 = 40; +pub const LOGIN_REGISTER_WITH_PAT_CODE: u32 = 45; // -- Personal Access Tokens -- pub const GET_PERSONAL_ACCESS_TOKENS_CODE: u32 = 41; @@ -119,6 +121,8 @@ mod tests { CHANGE_PASSWORD_CODE, LOGIN_USER_CODE, LOGOUT_USER_CODE, + LOGIN_REGISTER_CODE, + LOGIN_REGISTER_WITH_PAT_CODE, GET_PERSONAL_ACCESS_TOKENS_CODE, CREATE_PERSONAL_ACCESS_TOKEN_CODE, DELETE_PERSONAL_ACCESS_TOKEN_CODE, diff --git a/core/binary_protocol/src/consensus/header.rs b/core/binary_protocol/src/consensus/header.rs index ee09693786..728974b365 100644 --- a/core/binary_protocol/src/consensus/header.rs +++ b/core/binary_protocol/src/consensus/header.rs @@ -105,7 +105,8 @@ pub struct RequestHeader { pub operation: Operation, pub operation_padding: [u8; 7], pub namespace: u64, - pub reserved: [u8; 64], + pub session: u64, + pub reserved: [u8; 56], } const _: () = { assert!(size_of::() == HEADER_SIZE); @@ -113,7 +114,7 @@ const _: () = { 
offset_of!(RequestHeader, client) == offset_of!(RequestHeader, reserved_frame) + size_of::<[u8; 66]>() ); - assert!(offset_of!(RequestHeader, reserved) + size_of::<[u8; 64]>() == HEADER_SIZE); + assert!(offset_of!(RequestHeader, reserved) + size_of::<[u8; 56]>() == HEADER_SIZE); }; impl Default for RequestHeader { @@ -135,7 +136,8 @@ impl Default for RequestHeader { operation: Operation::Reserved, operation_padding: [0; 7], namespace: 0, - reserved: [0; 64], + session: 0, + reserved: [0; 56], } } } @@ -159,6 +161,31 @@ impl ConsensusHeader for RequestHeader { found: self.command, }); } + // Register: session must be 0, request must be 0. + // Non-register: session must be > 0, request must be > 0. + if self.operation == Operation::Register { + if self.session != 0 { + return Err(ConsensusError::InvalidField( + "register: session must be 0".to_string(), + )); + } + if self.request != 0 { + return Err(ConsensusError::InvalidField( + "register: request must be 0".to_string(), + )); + } + } else if self.operation != Operation::Reserved { + if self.session == 0 { + return Err(ConsensusError::InvalidField( + "non-register: session must be > 0".to_string(), + )); + } + if self.request == 0 { + return Err(ConsensusError::InvalidField( + "non-register: request must be > 0".to_string(), + )); + } + } Ok(()) } } @@ -670,8 +697,9 @@ impl ConsensusHeader for StartViewHeader { #[cfg(test)] mod tests { use super::{ - Command2, CommitHeader, ConsensusHeader, DoViewChangeHeader, GenericHeader, PrepareHeader, - PrepareOkHeader, ReplyHeader, RequestHeader, StartViewChangeHeader, StartViewHeader, + Command2, CommitHeader, ConsensusHeader, DoViewChangeHeader, GenericHeader, Operation, + PrepareHeader, PrepareOkHeader, ReplyHeader, RequestHeader, StartViewChangeHeader, + StartViewHeader, }; fn aligned_zeroed(size: usize) -> bytes::BytesMut { @@ -715,6 +743,66 @@ mod tests { assert!(header.validate().is_err()); } + #[test] + fn request_register_nonzero_session_rejected() { + let header = 
RequestHeader { + command: Command2::Request, + operation: Operation::Register, + session: 5, + request: 0, + ..RequestHeader::default() + }; + assert!(header.validate().is_err()); + } + + #[test] + fn request_register_nonzero_request_rejected() { + let header = RequestHeader { + command: Command2::Request, + operation: Operation::Register, + session: 0, + request: 1, + ..RequestHeader::default() + }; + assert!(header.validate().is_err()); + } + + #[test] + fn request_non_register_valid() { + let header = RequestHeader { + command: Command2::Request, + operation: Operation::SendMessages, + session: 10, + request: 1, + ..RequestHeader::default() + }; + assert!(header.validate().is_ok()); + } + + #[test] + fn request_non_register_zero_session_rejected() { + let header = RequestHeader { + command: Command2::Request, + operation: Operation::SendMessages, + session: 0, + request: 1, + ..RequestHeader::default() + }; + assert!(header.validate().is_err()); + } + + #[test] + fn request_non_register_zero_request_rejected() { + let header = RequestHeader { + command: Command2::Request, + operation: Operation::SendMessages, + session: 10, + request: 0, + ..RequestHeader::default() + }; + assert!(header.validate().is_err()); + } + #[test] fn reply_header_zero_copy() { let mut buf = aligned_zeroed(256); diff --git a/core/binary_protocol/src/consensus/operation.rs b/core/binary_protocol/src/consensus/operation.rs index d275269aeb..7f0a7f5a9b 100644 --- a/core/binary_protocol/src/consensus/operation.rs +++ b/core/binary_protocol/src/consensus/operation.rs @@ -22,9 +22,17 @@ use bytemuck::{CheckedBitPattern, NoUninit}; #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, NoUninit, CheckedBitPattern)] #[repr(u8)] pub enum Operation { + /// The value 0 is reserved to prevent a spurious zero from being + /// interpreted as a valid operation. #[default] Reserved = 0, + /// Register a client session with the cluster. 
Goes through the same + /// consensus pipeline (prepare/replicate/commit) as normal operations + /// but skips state machine dispatch at commit time, the metadata + /// plane calls `commit_register` directly. Session number = commit op. + Register = 1, + // Metadata operations (shard 0) CreateStream = 128, UpdateStream = 129, @@ -85,6 +93,14 @@ impl Operation { ) } + /// VSR protocol-level operations that go through consensus but skip + /// state machine dispatch at commit time. + #[must_use] + #[inline] + pub const fn is_vsr_reserved(&self) -> bool { + matches!(self, Self::Reserved | Self::Register) + } + /// Data-plane operations routed to the shard owning the partition. #[must_use] #[inline] @@ -104,7 +120,7 @@ impl Operation { #[must_use] pub const fn to_command_code(&self) -> Option { match self { - Self::Reserved => None, + Self::Reserved | Self::Register => None, Self::CreateStream | Self::UpdateStream | Self::DeleteStream @@ -188,8 +204,19 @@ mod tests { } #[test] - fn reserved_has_no_code() { + fn vsr_reserved_have_no_code() { assert_eq!(Operation::Reserved.to_command_code(), None); + assert_eq!(Operation::Register.to_command_code(), None); + } + + #[test] + fn vsr_reserved_classification() { + assert!(Operation::Reserved.is_vsr_reserved()); + assert!(Operation::Register.is_vsr_reserved()); + assert!(!Operation::CreateStream.is_vsr_reserved()); + assert!(!Operation::SendMessages.is_vsr_reserved()); + assert!(!Operation::Register.is_metadata()); + assert!(!Operation::Register.is_partition()); } #[test] diff --git a/core/binary_protocol/src/dispatch.rs b/core/binary_protocol/src/dispatch.rs index 1aa9eb8289..e79432f386 100644 --- a/core/binary_protocol/src/dispatch.rs +++ b/core/binary_protocol/src/dispatch.rs @@ -90,6 +90,7 @@ pub const COMMAND_TABLE: &[CommandMeta] = &[ ), CommandMeta::non_replicated(LOGIN_USER_CODE, "user.login"), CommandMeta::non_replicated(LOGOUT_USER_CODE, "user.logout"), + CommandMeta::non_replicated(LOGIN_REGISTER_CODE, 
"user.login_register"), // Personal Access Tokens CommandMeta::non_replicated( GET_PERSONAL_ACCESS_TOKENS_CODE, @@ -171,6 +172,8 @@ pub const COMMAND_TABLE: &[CommandMeta] = &[ ), CommandMeta::non_replicated(JOIN_CONSUMER_GROUP_CODE, "consumer_group.join"), CommandMeta::non_replicated(LEAVE_CONSUMER_GROUP_CODE, "consumer_group.leave"), + // Login + Register (PAT - Personal Access Token variant) + CommandMeta::non_replicated(LOGIN_REGISTER_WITH_PAT_CODE, "user.login_register_with_pat"), ]; /// Lookup command metadata by command code. @@ -198,37 +201,39 @@ pub const fn lookup_command(code: u32) -> Option<&'static CommandMeta> { CHANGE_PASSWORD_CODE => 13, LOGIN_USER_CODE => 14, LOGOUT_USER_CODE => 15, - GET_PERSONAL_ACCESS_TOKENS_CODE => 16, - CREATE_PERSONAL_ACCESS_TOKEN_CODE => 17, - DELETE_PERSONAL_ACCESS_TOKEN_CODE => 18, - LOGIN_WITH_PERSONAL_ACCESS_TOKEN_CODE => 19, - POLL_MESSAGES_CODE => 20, - SEND_MESSAGES_CODE => 21, - FLUSH_UNSAVED_BUFFER_CODE => 22, - GET_CONSUMER_OFFSET_CODE => 23, - STORE_CONSUMER_OFFSET_CODE => 24, - DELETE_CONSUMER_OFFSET_CODE => 25, - GET_STREAM_CODE => 26, - GET_STREAMS_CODE => 27, - CREATE_STREAM_CODE => 28, - DELETE_STREAM_CODE => 29, - UPDATE_STREAM_CODE => 30, - PURGE_STREAM_CODE => 31, - GET_TOPIC_CODE => 32, - GET_TOPICS_CODE => 33, - CREATE_TOPIC_CODE => 34, - DELETE_TOPIC_CODE => 35, - UPDATE_TOPIC_CODE => 36, - PURGE_TOPIC_CODE => 37, - CREATE_PARTITIONS_CODE => 38, - DELETE_PARTITIONS_CODE => 39, - DELETE_SEGMENTS_CODE => 40, - GET_CONSUMER_GROUP_CODE => 41, - GET_CONSUMER_GROUPS_CODE => 42, - CREATE_CONSUMER_GROUP_CODE => 43, - DELETE_CONSUMER_GROUP_CODE => 44, - JOIN_CONSUMER_GROUP_CODE => 45, - LEAVE_CONSUMER_GROUP_CODE => 46, + LOGIN_REGISTER_CODE => 16, + GET_PERSONAL_ACCESS_TOKENS_CODE => 17, + CREATE_PERSONAL_ACCESS_TOKEN_CODE => 18, + DELETE_PERSONAL_ACCESS_TOKEN_CODE => 19, + LOGIN_WITH_PERSONAL_ACCESS_TOKEN_CODE => 20, + POLL_MESSAGES_CODE => 21, + SEND_MESSAGES_CODE => 22, + FLUSH_UNSAVED_BUFFER_CODE => 23, + 
GET_CONSUMER_OFFSET_CODE => 24, + STORE_CONSUMER_OFFSET_CODE => 25, + DELETE_CONSUMER_OFFSET_CODE => 26, + GET_STREAM_CODE => 27, + GET_STREAMS_CODE => 28, + CREATE_STREAM_CODE => 29, + DELETE_STREAM_CODE => 30, + UPDATE_STREAM_CODE => 31, + PURGE_STREAM_CODE => 32, + GET_TOPIC_CODE => 33, + GET_TOPICS_CODE => 34, + CREATE_TOPIC_CODE => 35, + DELETE_TOPIC_CODE => 36, + UPDATE_TOPIC_CODE => 37, + PURGE_TOPIC_CODE => 38, + CREATE_PARTITIONS_CODE => 39, + DELETE_PARTITIONS_CODE => 40, + DELETE_SEGMENTS_CODE => 41, + GET_CONSUMER_GROUP_CODE => 42, + GET_CONSUMER_GROUPS_CODE => 43, + CREATE_CONSUMER_GROUP_CODE => 44, + DELETE_CONSUMER_GROUP_CODE => 45, + JOIN_CONSUMER_GROUP_CODE => 46, + LEAVE_CONSUMER_GROUP_CODE => 47, + LOGIN_REGISTER_WITH_PAT_CODE => 48, _ => return None, }; Some(&COMMAND_TABLE[idx]) @@ -242,30 +247,30 @@ pub const fn lookup_command(code: u32) -> Option<&'static CommandMeta> { pub const fn lookup_by_operation(op: Operation) -> Option<&'static CommandMeta> { // Indices must match the order of entries in COMMAND_TABLE above. 
let idx = match op { - Operation::CreateStream => 28, - Operation::UpdateStream => 30, - Operation::DeleteStream => 29, - Operation::PurgeStream => 31, - Operation::CreateTopic => 34, - Operation::UpdateTopic => 36, - Operation::DeleteTopic => 35, - Operation::PurgeTopic => 37, - Operation::CreatePartitions => 38, - Operation::DeletePartitions => 39, - Operation::DeleteSegments => 40, - Operation::CreateConsumerGroup => 43, - Operation::DeleteConsumerGroup => 44, + Operation::CreateStream => 29, + Operation::UpdateStream => 31, + Operation::DeleteStream => 30, + Operation::PurgeStream => 32, + Operation::CreateTopic => 35, + Operation::UpdateTopic => 37, + Operation::DeleteTopic => 36, + Operation::PurgeTopic => 38, + Operation::CreatePartitions => 39, + Operation::DeletePartitions => 40, + Operation::DeleteSegments => 41, + Operation::CreateConsumerGroup => 44, + Operation::DeleteConsumerGroup => 45, Operation::CreateUser => 9, Operation::UpdateUser => 11, Operation::DeleteUser => 10, Operation::ChangePassword => 13, Operation::UpdatePermissions => 12, - Operation::CreatePersonalAccessToken => 17, - Operation::DeletePersonalAccessToken => 18, - Operation::SendMessages => 21, - Operation::StoreConsumerOffset => 24, - Operation::DeleteConsumerOffset => 25, - Operation::Reserved => return None, + Operation::CreatePersonalAccessToken => 18, + Operation::DeletePersonalAccessToken => 19, + Operation::SendMessages => 22, + Operation::StoreConsumerOffset => 25, + Operation::DeleteConsumerOffset => 26, + Operation::Reserved | Operation::Register => return None, }; Some(&COMMAND_TABLE[idx]) } @@ -293,6 +298,8 @@ mod tests { CHANGE_PASSWORD_CODE, LOGIN_USER_CODE, LOGOUT_USER_CODE, + LOGIN_REGISTER_CODE, + LOGIN_REGISTER_WITH_PAT_CODE, GET_PERSONAL_ACCESS_TOKENS_CODE, CREATE_PERSONAL_ACCESS_TOKEN_CODE, DELETE_PERSONAL_ACCESS_TOKEN_CODE, diff --git a/core/binary_protocol/src/framing.rs b/core/binary_protocol/src/framing.rs index f303a61aed..6ead7d6e09 100644 --- 
a/core/binary_protocol/src/framing.rs +++ b/core/binary_protocol/src/framing.rs @@ -25,7 +25,7 @@ //! switch to `consensus::header::GenericHeader` (256-byte fixed header) //! while the command payload codec stays the same. -use crate::codec::{read_bytes, read_u32_le}; +use crate::codec::{read_bytes, read_u32_le, read_u64_le}; use crate::error::WireError; use bytes::{BufMut, BytesMut}; use std::borrow::Cow; @@ -183,6 +183,194 @@ impl<'a> ResponseFrame<'a> { } } +/// Decoded request frame with request ID for request-response correlation +/// and consensus-level duplicate detection (server-ng framing). +/// +/// Wire format: `[length:4 LE][code:4 LE][request_id:8 LE][payload:N]` +/// where `length` = 4 (code) + 8 (`request_id`) + N (payload). +#[derive(Debug)] +pub struct RequestFrame2<'a> { + pub code: u32, + pub request_id: u64, + pub payload: &'a [u8], +} + +impl<'a> RequestFrame2<'a> { + /// Size of the frame header: `[length:4][code:4][request_id:8]`. + pub const HEADER_SIZE: usize = 16; + + /// Validate a frame length field and return the payload size. + /// + /// # Errors + /// Returns `WireError::Validation` if `frame_length < 12` (must contain + /// code + `request_id`). + pub fn payload_length(frame_length: u32) -> Result { + frame_length + .checked_sub(12) + .ok_or(WireError::Validation(Cow::Borrowed( + "request frame length must be at least 12 (code + request_id)", + ))) + } + + /// Construct a frame from pre-parsed header fields and a payload slice. + #[must_use] + pub const fn from_parts(code: u32, request_id: u64, payload: &'a [u8]) -> Self { + Self { + code, + request_id, + payload, + } + } + + /// Decode a request frame from a complete buffer. + /// + /// # Errors + /// Returns `WireError::UnexpectedEof` if the buffer is too short. + pub fn decode(buf: &'a [u8]) -> Result<(Self, usize), WireError> { + let frame_length = read_u32_le(buf, 0)?; + let payload_len = Self::payload_length(frame_length)? 
as usize; + let code = read_u32_le(buf, 4)?; + let request_id = read_u64_le(buf, 8)?; + let payload = read_bytes(buf, Self::HEADER_SIZE, payload_len)?; + let total = Self::HEADER_SIZE + payload_len; + Ok(( + Self { + code, + request_id, + payload, + }, + total, + )) + } + + /// Encode a request frame into `out`. + /// + /// Writes `[length:4 LE][code:4 LE][request_id:8 LE][payload]` where + /// length = 4 (code) + 8 (`request_id`) + `payload.len()`. + /// + /// # Errors + /// Returns `WireError::PayloadTooLarge` if payload exceeds u32 capacity. + pub fn encode( + code: u32, + request_id: u64, + payload: &[u8], + out: &mut BytesMut, + ) -> Result<(), WireError> { + let length = payload + .len() + .checked_add(12) + .and_then(|n| u32::try_from(n).ok()) + .ok_or(WireError::PayloadTooLarge { + size: payload.len(), + max: u32::MAX as usize - 12, + })?; + out.reserve(Self::HEADER_SIZE + payload.len()); + out.put_u32_le(length); + out.put_u32_le(code); + out.put_u64_le(request_id); + out.put_slice(payload); + Ok(()) + } + + /// Total encoded size for a given payload length. + /// + /// Returns `None` if `HEADER_SIZE + payload_len` overflows `usize`. + #[must_use] + pub const fn encoded_size(payload_len: usize) -> Option { + Self::HEADER_SIZE.checked_add(payload_len) + } +} + +/// Decoded response frame with request ID for request-response correlation +/// (server-ng framing). +/// +/// Wire format: `[status:4 LE][length:4 LE][request_id:8 LE][payload:N]` +/// where `status` = 0 for success, non-zero for error code. +#[derive(Debug)] +pub struct ResponseFrame2<'a> { + pub status: u32, + pub request_id: u64, + pub payload: &'a [u8], +} + +impl<'a> ResponseFrame2<'a> { + /// Size of the frame header: `[status:4][length:4][request_id:8]`. + pub const HEADER_SIZE: usize = 16; + + /// Decode a response frame from a complete buffer. + /// + /// The `length` field covers `request_id(8) + payload(N)`. 
+ /// + /// # Errors + /// Returns `WireError::UnexpectedEof` if the buffer is too short. + pub fn decode(buf: &'a [u8]) -> Result<(Self, usize), WireError> { + let status = read_u32_le(buf, 0)?; + let length = read_u32_le(buf, 4)? as usize; + if length < 8 { + return Err(WireError::Validation(Cow::Borrowed( + "response frame length must be at least 8 (request_id)", + ))); + } + let request_id = read_u64_le(buf, 8)?; + let payload_len = length - 8; + let payload = read_bytes(buf, Self::HEADER_SIZE, payload_len)?; + let total = Self::HEADER_SIZE + payload_len; + Ok(( + Self { + status, + request_id, + payload, + }, + total, + )) + } + + /// Encode a successful response with payload. + /// + /// The `length` field = 8 (`request_id`) + `payload.len()`. + /// + /// # Errors + /// Returns `WireError::PayloadTooLarge` if payload exceeds u32 capacity. + pub fn encode_ok(request_id: u64, payload: &[u8], out: &mut BytesMut) -> Result<(), WireError> { + let length = payload + .len() + .checked_add(8) + .and_then(|n| u32::try_from(n).ok()) + .ok_or(WireError::PayloadTooLarge { + size: payload.len(), + max: u32::MAX as usize - 8, + })?; + out.reserve(Self::HEADER_SIZE + payload.len()); + out.put_u32_le(STATUS_OK); + out.put_u32_le(length); + out.put_u64_le(request_id); + out.put_slice(payload); + Ok(()) + } + + /// Encode an error response (status code, no payload, preserves `request_id`). + pub fn encode_error(status: NonZeroU32, request_id: u64, out: &mut BytesMut) { + out.reserve(Self::HEADER_SIZE); + out.put_u32_le(status.get()); + out.put_u32_le(8); // length = request_id only + out.put_u64_le(request_id); + } + + /// Returns `true` if this is a success response. + #[must_use] + pub const fn is_ok(&self) -> bool { + self.status == STATUS_OK + } + + /// Total encoded size for a given payload length. + /// + /// Returns `None` if `HEADER_SIZE + payload_len` overflows `usize`. 
+ #[must_use] + pub const fn encoded_size(payload_len: usize) -> Option { + Self::HEADER_SIZE.checked_add(payload_len) + } +} + #[cfg(test)] mod tests { use super::*; @@ -324,4 +512,156 @@ mod tests { assert_eq!(ResponseFrame::encoded_size(256), Some(264)); assert_eq!(ResponseFrame::encoded_size(usize::MAX), None); } + + // RequestFrame2 tests + + #[test] + fn request2_roundtrip() { + let payload = b"hello world"; + let mut buf = BytesMut::with_capacity(RequestFrame2::encoded_size(payload.len()).unwrap()); + RequestFrame2::encode(42, 7, payload, &mut buf).unwrap(); + + let (frame, consumed) = RequestFrame2::decode(&buf).unwrap(); + assert_eq!(consumed, buf.len()); + assert_eq!(frame.code, 42); + assert_eq!(frame.request_id, 7); + assert_eq!(frame.payload, payload); + } + + #[test] + fn request2_empty_payload() { + let mut buf = BytesMut::with_capacity(RequestFrame2::HEADER_SIZE); + RequestFrame2::encode(1, 99, &[], &mut buf).unwrap(); + + let (frame, consumed) = RequestFrame2::decode(&buf).unwrap(); + assert_eq!(consumed, 16); + assert_eq!(frame.code, 1); + assert_eq!(frame.request_id, 99); + assert!(frame.payload.is_empty()); + } + + #[test] + fn request2_length_field_includes_code_and_request_id() { + let payload = b"test"; + let mut buf = BytesMut::new(); + RequestFrame2::encode(99, 1, payload, &mut buf).unwrap(); + + let length = u32::from_le_bytes(buf[0..4].try_into().unwrap()); + assert_eq!(length, 4 + 8 + 4); // code(4) + request_id(8) + payload(4) + } + + #[test] + fn request2_truncated_header() { + let buf = [0u8; 15]; // less than HEADER_SIZE (16) + assert!(RequestFrame2::decode(&buf).is_err()); + } + + #[test] + fn request2_truncated_payload() { + let mut buf = BytesMut::new(); + buf.put_u32_le(112); // length = 112 (code + request_id + 100 bytes payload) + buf.put_u32_le(1); // code + buf.put_u64_le(1); // request_id + buf.put_slice(&[0u8; 50]); // only 50 of 100 bytes + assert!(RequestFrame2::decode(&buf).is_err()); + } + + #[test] + fn 
request2_length_too_small() { + let mut buf = BytesMut::new(); + buf.put_u32_le(11); // length < 12 (must include code + request_id) + buf.put_u32_le(1); + buf.put_u64_le(1); + assert!(RequestFrame2::decode(&buf).is_err()); + } + + #[test] + fn request2_payload_length_valid() { + assert_eq!(RequestFrame2::payload_length(12).unwrap(), 0); + assert_eq!(RequestFrame2::payload_length(112).unwrap(), 100); + } + + #[test] + fn request2_payload_length_too_small() { + assert!(RequestFrame2::payload_length(0).is_err()); + assert!(RequestFrame2::payload_length(11).is_err()); + } + + #[test] + fn request2_encoded_size() { + assert_eq!(RequestFrame2::encoded_size(0), Some(16)); + assert_eq!(RequestFrame2::encoded_size(100), Some(116)); + assert_eq!(RequestFrame2::encoded_size(usize::MAX), None); + } + + #[test] + fn request2_from_parts() { + let payload = b"data"; + let frame = RequestFrame2::from_parts(5, 42, payload); + assert_eq!(frame.code, 5); + assert_eq!(frame.request_id, 42); + assert_eq!(frame.payload, payload); + } + + // ResponseFrame2 tests + + #[test] + fn response2_ok_roundtrip() { + let payload = b"response data"; + let mut buf = BytesMut::with_capacity(ResponseFrame2::encoded_size(payload.len()).unwrap()); + ResponseFrame2::encode_ok(7, payload, &mut buf).unwrap(); + + let (frame, consumed) = ResponseFrame2::decode(&buf).unwrap(); + assert_eq!(consumed, buf.len()); + assert!(frame.is_ok()); + assert_eq!(frame.request_id, 7); + assert_eq!(frame.payload, payload); + } + + #[test] + fn response2_ok_empty_payload() { + let mut buf = BytesMut::new(); + ResponseFrame2::encode_ok(42, &[], &mut buf).unwrap(); + + let (frame, consumed) = ResponseFrame2::decode(&buf).unwrap(); + assert_eq!(consumed, 16); + assert!(frame.is_ok()); + assert_eq!(frame.request_id, 42); + assert!(frame.payload.is_empty()); + } + + #[test] + fn response2_error_roundtrip() { + let mut buf = BytesMut::new(); + ResponseFrame2::encode_error(NonZeroU32::new(1001).unwrap(), 55, &mut buf); + + let 
(frame, consumed) = ResponseFrame2::decode(&buf).unwrap(); + assert_eq!(consumed, 16); + assert!(!frame.is_ok()); + assert_eq!(frame.status, 1001); + assert_eq!(frame.request_id, 55); + assert!(frame.payload.is_empty()); + } + + #[test] + fn response2_truncated_header() { + let buf = [0u8; 15]; + assert!(ResponseFrame2::decode(&buf).is_err()); + } + + #[test] + fn response2_length_too_small() { + let mut buf = BytesMut::new(); + buf.put_u32_le(0); // status + buf.put_u32_le(7); // length < 8 (must include request_id) + buf.put_u64_le(1); // request_id + assert!(ResponseFrame2::decode(&buf).is_err()); + } + + #[test] + fn response2_encoded_size() { + assert_eq!(ResponseFrame2::encoded_size(0), Some(16)); + assert_eq!(ResponseFrame2::encoded_size(256), Some(272)); + assert_eq!(ResponseFrame2::encoded_size(usize::MAX), None); + } } diff --git a/core/binary_protocol/src/lib.rs b/core/binary_protocol/src/lib.rs index a63aad99af..4571bdc6ee 100644 --- a/core/binary_protocol/src/lib.rs +++ b/core/binary_protocol/src/lib.rs @@ -76,7 +76,7 @@ pub use consensus::{ }; pub use dispatch::{COMMAND_TABLE, CommandMeta, lookup_by_operation, lookup_command}; pub use error::WireError; -pub use framing::{RequestFrame, ResponseFrame, STATUS_OK}; +pub use framing::{RequestFrame, RequestFrame2, ResponseFrame, ResponseFrame2, STATUS_OK}; pub use message_view::{ WireMessageIterator, WireMessageIteratorMut, WireMessageView, WireMessageViewMut, }; diff --git a/core/binary_protocol/src/requests/users/login_register.rs b/core/binary_protocol/src/requests/users/login_register.rs new file mode 100644 index 0000000000..72ceb147cd --- /dev/null +++ b/core/binary_protocol/src/requests/users/login_register.rs @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::WireError; +use crate::codec::{WireDecode, WireEncode, read_str, read_u8, read_u32_le, read_u128_le}; +use crate::primitives::identifier::WireName; +use bytes::{BufMut, BytesMut}; +use secrecy::{ExposeSecret, SecretString}; + +/// Combined login + register request for server-ng. +/// +/// The client sends credentials and its ephemeral `client_id`. The server +/// verifies credentials locally, then submits `Operation::Register` through +/// consensus. The response carries `user_id` + `session` (commit op number). +/// +/// Wire format: +/// ```text +/// [client_id:16 LE][username_len:u8][username:N][password_len:u8][password:N] +/// [version_len:u32_le][version:N?][context_len:u32_le][context:N?] 
+/// ``` +#[derive(Debug, Clone)] +pub struct LoginRegisterRequest { + pub client_id: u128, + pub username: WireName, + pub password: SecretString, + pub version: Option, + pub client_context: Option, +} + +impl WireEncode for LoginRegisterRequest { + fn encoded_size(&self) -> usize { + 16 + self.username.encoded_size() + + 1 + + self.password.expose_secret().len() + + 4 + + self.version.as_ref().map_or(0, String::len) + + 4 + + self.client_context.as_ref().map_or(0, String::len) + } + + fn encode(&self, buf: &mut BytesMut) { + buf.put_u128_le(self.client_id); + self.username.encode(buf); + let password = self.password.expose_secret(); + #[allow(clippy::cast_possible_truncation)] + buf.put_u8(password.len() as u8); + buf.put_slice(password.as_bytes()); + match &self.version { + Some(v) => { + #[allow(clippy::cast_possible_truncation)] + buf.put_u32_le(v.len() as u32); + buf.put_slice(v.as_bytes()); + } + None => buf.put_u32_le(0), + } + match &self.client_context { + Some(c) => { + #[allow(clippy::cast_possible_truncation)] + buf.put_u32_le(c.len() as u32); + buf.put_slice(c.as_bytes()); + } + None => buf.put_u32_le(0), + } + } +} + +impl WireDecode for LoginRegisterRequest { + fn decode(buf: &[u8]) -> Result<(Self, usize), WireError> { + let client_id = read_u128_le(buf, 0)?; + let mut pos = 16; + + let (username, name_len) = WireName::decode(&buf[pos..])?; + pos += name_len; + + let password_len = read_u8(buf, pos)? as usize; + pos += 1; + let password = SecretString::from(read_str(buf, pos, password_len)?); + pos += password_len; + + let version_len = read_u32_le(buf, pos)? as usize; + pos += 4; + let version = if version_len > 0 { + let v = read_str(buf, pos, version_len)?; + pos += version_len; + Some(v) + } else { + None + }; + + let client_context_len = read_u32_le(buf, pos)? 
as usize; + pos += 4; + let client_context = if client_context_len > 0 { + let c = read_str(buf, pos, client_context_len)?; + pos += client_context_len; + Some(c) + } else { + None + }; + + Ok(( + Self { + client_id, + username, + password, + version, + client_context, + }, + pos, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_req_eq(a: &LoginRegisterRequest, b: &LoginRegisterRequest) { + assert_eq!(a.client_id, b.client_id); + assert_eq!(a.username, b.username); + assert_eq!(a.password.expose_secret(), b.password.expose_secret()); + assert_eq!(a.version, b.version); + assert_eq!(a.client_context, b.client_context); + } + + #[test] + fn roundtrip_full() { + let req = LoginRegisterRequest { + client_id: 0xDEAD_BEEF_CAFE_BABE_1234_5678_9ABC_DEF0, + username: WireName::new("admin").unwrap(), + password: SecretString::from("secret"), + version: Some("1.0.0".to_string()), + client_context: Some("rust-sdk".to_string()), + }; + let bytes = req.to_bytes(); + let (decoded, consumed) = LoginRegisterRequest::decode(&bytes).unwrap(); + assert_eq!(consumed, bytes.len()); + assert_req_eq(&decoded, &req); + } + + #[test] + fn roundtrip_no_optionals() { + let req = LoginRegisterRequest { + client_id: 42, + username: WireName::new("user").unwrap(), + password: SecretString::from("pass"), + version: None, + client_context: None, + }; + let bytes = req.to_bytes(); + let (decoded, consumed) = LoginRegisterRequest::decode(&bytes).unwrap(); + assert_eq!(consumed, bytes.len()); + assert_req_eq(&decoded, &req); + } + + #[test] + fn encoded_size_matches_output() { + let req = LoginRegisterRequest { + client_id: 1, + username: WireName::new("admin").unwrap(), + password: SecretString::from("p"), + version: Some("v1".to_string()), + client_context: Some("ctx".to_string()), + }; + assert_eq!(req.encoded_size(), req.to_bytes().len()); + } + + #[test] + fn truncated_returns_error() { + let req = LoginRegisterRequest { + client_id: 1, + username: 
WireName::new("u").unwrap(), + password: SecretString::from("p"), + version: Some("v".to_string()), + client_context: Some("c".to_string()), + }; + let bytes = req.to_bytes(); + for i in 0..bytes.len() { + assert!( + LoginRegisterRequest::decode(&bytes[..i]).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + + #[test] + fn wire_layout_client_id_first() { + let req = LoginRegisterRequest { + client_id: 0x0102_0304_0506_0708_090A_0B0C_0D0E_0F10, + username: WireName::new("u").unwrap(), + password: SecretString::from("p"), + version: None, + client_context: None, + }; + let bytes = req.to_bytes(); + // First 16 bytes are client_id in LE. + let client_id = u128::from_le_bytes(bytes[..16].try_into().unwrap()); + assert_eq!(client_id, req.client_id); + // Then username: [1, b'u'], password: [1, b'p'], version: [0,0,0,0], client_context: [0,0,0,0] + assert_eq!(bytes[16], 1); // username len + assert_eq!(bytes[17], b'u'); + } +} diff --git a/core/binary_protocol/src/requests/users/login_register_with_pat.rs b/core/binary_protocol/src/requests/users/login_register_with_pat.rs new file mode 100644 index 0000000000..359bc4bf28 --- /dev/null +++ b/core/binary_protocol/src/requests/users/login_register_with_pat.rs @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::WireError; +use crate::codec::{WireDecode, WireEncode, read_str, read_u8, read_u32_le, read_u128_le}; +use bytes::{BufMut, BytesMut}; +use secrecy::{ExposeSecret, SecretString}; + +/// Combined login-with-PAT + register request for server-ng. +/// +/// The client sends a personal access token and its ephemeral `client_id`. +/// The server verifies the token locally, then submits `Operation::Register` +/// through consensus. The response carries `user_id` + `session` (commit op +/// number). +/// +/// Wire format: +/// ```text +/// [client_id:16 LE][token_len:u8][token:N] +/// [version_len:u32_le][version:N?][context_len:u32_le][context:N?] +/// ``` +#[derive(Debug, Clone)] +pub struct LoginRegisterWithPatRequest { + pub client_id: u128, + pub token: SecretString, + pub version: Option, + pub client_context: Option, +} + +impl WireEncode for LoginRegisterWithPatRequest { + fn encoded_size(&self) -> usize { + 16 + 1 + + self.token.expose_secret().len() + + 4 + + self.version.as_ref().map_or(0, String::len) + + 4 + + self.client_context.as_ref().map_or(0, String::len) + } + + fn encode(&self, buf: &mut BytesMut) { + buf.put_u128_le(self.client_id); + let token = self.token.expose_secret(); + #[allow(clippy::cast_possible_truncation)] + buf.put_u8(token.len() as u8); + buf.put_slice(token.as_bytes()); + match &self.version { + Some(v) => { + #[allow(clippy::cast_possible_truncation)] + buf.put_u32_le(v.len() as u32); + buf.put_slice(v.as_bytes()); + } + None => buf.put_u32_le(0), + } + match &self.client_context { + Some(c) => { + #[allow(clippy::cast_possible_truncation)] + buf.put_u32_le(c.len() as u32); + buf.put_slice(c.as_bytes()); + } + None => buf.put_u32_le(0), + } + } +} + +impl WireDecode for LoginRegisterWithPatRequest { + fn decode(buf: &[u8]) -> Result<(Self, usize), WireError> { + let client_id = read_u128_le(buf, 0)?; + let mut pos = 
16; + + let token_len = read_u8(buf, pos)? as usize; + pos += 1; + let token = SecretString::from(read_str(buf, pos, token_len)?); + pos += token_len; + + let version_len = read_u32_le(buf, pos)? as usize; + pos += 4; + let version = if version_len > 0 { + let v = read_str(buf, pos, version_len)?; + pos += version_len; + Some(v) + } else { + None + }; + + let client_context_len = read_u32_le(buf, pos)? as usize; + pos += 4; + let client_context = if client_context_len > 0 { + let c = read_str(buf, pos, client_context_len)?; + pos += client_context_len; + Some(c) + } else { + None + }; + + Ok(( + Self { + client_id, + token, + version, + client_context, + }, + pos, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_req_eq(a: &LoginRegisterWithPatRequest, b: &LoginRegisterWithPatRequest) { + assert_eq!(a.client_id, b.client_id); + assert_eq!(a.token.expose_secret(), b.token.expose_secret()); + assert_eq!(a.version, b.version); + assert_eq!(a.client_context, b.client_context); + } + + #[test] + fn roundtrip_full() { + let req = LoginRegisterWithPatRequest { + client_id: 0xDEAD_BEEF_CAFE_BABE_1234_5678_9ABC_DEF0, + token: SecretString::from("pat-abc123def456"), + version: Some("1.0.0".to_string()), + client_context: Some("rust-sdk".to_string()), + }; + let bytes = req.to_bytes(); + let (decoded, consumed) = LoginRegisterWithPatRequest::decode(&bytes).unwrap(); + assert_eq!(consumed, bytes.len()); + assert_req_eq(&decoded, &req); + } + + #[test] + fn roundtrip_no_optionals() { + let req = LoginRegisterWithPatRequest { + client_id: 42, + token: SecretString::from("tok"), + version: None, + client_context: None, + }; + let bytes = req.to_bytes(); + let (decoded, consumed) = LoginRegisterWithPatRequest::decode(&bytes).unwrap(); + assert_eq!(consumed, bytes.len()); + assert_req_eq(&decoded, &req); + } + + #[test] + fn encoded_size_matches_output() { + let req = LoginRegisterWithPatRequest { + client_id: 1, + token: SecretString::from("t"), + version: 
Some("v1".to_string()), + client_context: Some("ctx".to_string()), + }; + assert_eq!(req.encoded_size(), req.to_bytes().len()); + } + + #[test] + fn truncated_returns_error() { + let req = LoginRegisterWithPatRequest { + client_id: 1, + token: SecretString::from("t"), + version: Some("v".to_string()), + client_context: Some("c".to_string()), + }; + let bytes = req.to_bytes(); + for i in 0..bytes.len() { + assert!( + LoginRegisterWithPatRequest::decode(&bytes[..i]).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + + #[test] + fn wire_layout_client_id_first() { + let req = LoginRegisterWithPatRequest { + client_id: 0x0102_0304_0506_0708_090A_0B0C_0D0E_0F10, + token: SecretString::from("t"), + version: None, + client_context: None, + }; + let bytes = req.to_bytes(); + // First 16 bytes are client_id in LE. + let client_id = u128::from_le_bytes(bytes[..16].try_into().unwrap()); + assert_eq!(client_id, req.client_id); + // Then token: [1, b't'], version: [0,0,0,0], client_context: [0,0,0,0] + assert_eq!(bytes[16], 1); // token len + assert_eq!(bytes[17], b't'); + } +} diff --git a/core/binary_protocol/src/requests/users/mod.rs b/core/binary_protocol/src/requests/users/mod.rs index d836344150..59860b5ebf 100644 --- a/core/binary_protocol/src/requests/users/mod.rs +++ b/core/binary_protocol/src/requests/users/mod.rs @@ -20,6 +20,8 @@ pub mod create_user; pub mod delete_user; pub mod get_user; pub mod get_users; +pub mod login_register; +pub mod login_register_with_pat; pub mod login_user; pub mod logout_user; pub mod update_permissions; @@ -30,6 +32,8 @@ pub use create_user::CreateUserRequest; pub use delete_user::DeleteUserRequest; pub use get_user::GetUserRequest; pub use get_users::GetUsersRequest; +pub use login_register::LoginRegisterRequest; +pub use login_register_with_pat::LoginRegisterWithPatRequest; pub use login_user::LoginUserRequest; pub use logout_user::LogoutUserRequest; pub use update_permissions::UpdatePermissionsRequest; diff --git 
a/core/binary_protocol/src/responses/users/login_register.rs b/core/binary_protocol/src/responses/users/login_register.rs new file mode 100644 index 0000000000..59bdd578d2 --- /dev/null +++ b/core/binary_protocol/src/responses/users/login_register.rs @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::WireError; +use crate::codec::{WireDecode, WireEncode, read_u32_le, read_u64_le}; +use bytes::{BufMut, BytesMut}; + +/// Combined login + register response for server-ng. +/// +/// Returns the authenticated user's ID and the consensus session number +/// (commit op number from the Register operation). 
+/// +/// Wire format (12 bytes): +/// ```text +/// [user_id:4 LE][session:8 LE] +/// ``` +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LoginRegisterResponse { + pub user_id: u32, + pub session: u64, +} + +impl WireEncode for LoginRegisterResponse { + fn encoded_size(&self) -> usize { + 12 + } + + fn encode(&self, buf: &mut BytesMut) { + buf.put_u32_le(self.user_id); + buf.put_u64_le(self.session); + } +} + +impl WireDecode for LoginRegisterResponse { + fn decode(buf: &[u8]) -> Result<(Self, usize), WireError> { + let user_id = read_u32_le(buf, 0)?; + let session = read_u64_le(buf, 4)?; + Ok((Self { user_id, session }, 12)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn roundtrip() { + let resp = LoginRegisterResponse { + user_id: 42, + session: 100, + }; + let bytes = resp.to_bytes(); + assert_eq!(bytes.len(), 12); + let (decoded, consumed) = LoginRegisterResponse::decode(&bytes).unwrap(); + assert_eq!(consumed, 12); + assert_eq!(decoded, resp); + } + + #[test] + fn truncated_returns_error() { + let resp = LoginRegisterResponse { + user_id: 1, + session: 1, + }; + let bytes = resp.to_bytes(); + for i in 0..bytes.len() { + assert!( + LoginRegisterResponse::decode(&bytes[..i]).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + + #[test] + fn wire_layout() { + let resp = LoginRegisterResponse { + user_id: 0x0102_0304, + session: 0x0506_0708_090A_0B0C, + }; + let bytes = resp.to_bytes(); + assert_eq!( + u32::from_le_bytes(bytes[..4].try_into().unwrap()), + 0x0102_0304 + ); + assert_eq!( + u64::from_le_bytes(bytes[4..12].try_into().unwrap()), + 0x0506_0708_090A_0B0C + ); + } +} diff --git a/core/binary_protocol/src/responses/users/mod.rs b/core/binary_protocol/src/responses/users/mod.rs index 08cfd7d0e2..cb95cdd71f 100644 --- a/core/binary_protocol/src/responses/users/mod.rs +++ b/core/binary_protocol/src/responses/users/mod.rs @@ -20,6 +20,7 @@ mod create_user; mod delete_user; pub mod get_user; pub mod get_users; +pub 
mod login_register; pub mod login_user; mod logout_user; mod update_permissions; @@ -32,6 +33,7 @@ pub use create_user::CreateUserResponse; pub use delete_user::DeleteUserResponse; pub use get_user::UserDetailsResponse; pub use get_users::GetUsersResponse; +pub use login_register::LoginRegisterResponse; pub use login_user::IdentityResponse; pub use logout_user::LogoutUserResponse; pub use update_permissions::UpdatePermissionsResponse; diff --git a/core/consensus/src/client_table.rs b/core/consensus/src/client_table.rs index 620951fa0f..54a888034c 100644 --- a/core/consensus/src/client_table.rs +++ b/core/consensus/src/client_table.rs @@ -103,10 +103,16 @@ impl Default for Notify { /// Per-client entry in the clients table (VR paper Section 4, Figure 2). /// -/// Stores the reply for the client's latest committed request. The client ID, -/// request number, and commit number are all read from `reply.header()`. +/// Stores the session number and the reply for the client's latest committed +/// request. The session number is assigned once at registration (from the +/// commit op number) and never changes for the lifetime of the entry. #[derive(Debug)] pub struct ClientEntry { + /// Session number assigned at registration time (= commit op number of the + /// register operation). Monotonically increasing across registrations. + /// Used to distinguish between successive registrations from different + /// client processes, a new register always gets a higher session. + pub session: u64, /// The cached reply for the client's latest committed request (header + body). pub reply: Message, } @@ -123,6 +129,18 @@ pub enum RequestStatus { /// Request number is older than the client's latest committed request. /// Already handled in a prior commit cycle, drop silently. Stale, + /// No session exists for this client. Client must register first. + NoSession, + /// Client's session number doesn't match the entry. 
+ SessionMismatch { expected: u64, received: u64 }, + /// Request number is not exactly `committed + 1`. Client skipped + /// request numbers, which would permanently lose the skipped range. + RequestGap { expected: u64, received: u64 }, + /// Client already has a session. Returned by `check_register` when + /// the client is already registered. Carries the existing session + /// number so the caller can synthesize the correct register reply + /// without type-confusing the latest app reply as a register reply. + AlreadyRegistered { session: u64 }, } /// VSR client-table: tracks per-client request state for duplicate detection, @@ -182,32 +200,39 @@ impl ClientTable { /// Check a request against the table. /// - /// Returns: - /// - [`RequestStatus::New`]: not seen before, proceed with consensus - /// - [`RequestStatus::Duplicate`]: already committed, re-send cached reply - /// - [`RequestStatus::InProgress`]: in the pipeline awaiting commit - /// - [`RequestStatus::Stale`]: older than the client's latest committed request + /// Validates session number first, then request number progression. + /// For `Register` operations, use [`check_register`] instead. /// /// # Panics /// Panics if the internal index points to an empty slot (invariant violation). #[must_use] - pub fn check_request(&self, client_id: u128, request: u64) -> RequestStatus { - // TODO: Once client sessions are added (register/evict protocol like - // validate client_id at the session layer instead of - // panicking here. Unregistered or invalid clients should be rejected - // gracefully at ingress, not inside the client table. + pub fn check_request(&self, client_id: u128, session: u64, request: u64) -> RequestStatus { assert!(client_id != 0, "client_id 0 is reserved for internal use"); + // Non-register: session and request must both be > 0. + // Header validation enforces this at the wire layer. 
+ debug_assert!(session > 0, "check_request: session must be > 0"); + debug_assert!(request > 0, "check_request: request must be > 0"); + + // Session validation first, then pipeline check (like TigerBeetle). + // A wrong-session request must be rejected even if the same + // (client_id, request) happens to be pending from the correct session. + let Some(&slot_idx) = self.index.get(&client_id) else { + return RequestStatus::NoSession; + }; + let entry = self.slots[slot_idx].as_ref().expect("index/slot mismatch"); + + if session != entry.session { + return RequestStatus::SessionMismatch { + expected: entry.session, + received: session, + }; + } - // Check if already pending in the pipeline. let key = ClientRequest { client_id, request }; if self.pending.contains_key(&key) { return RequestStatus::InProgress; } - let Some(&slot_idx) = self.index.get(&client_id) else { - return RequestStatus::New; - }; - let entry = self.slots[slot_idx].as_ref().expect("index/slot mismatch"); let committed_request = entry.reply.header().request; if request < committed_request { @@ -216,10 +241,45 @@ impl ClientTable { if request == committed_request { return RequestStatus::Duplicate(entry.reply.clone()); } + if request != committed_request + 1 { + return RequestStatus::RequestGap { + expected: committed_request + 1, + received: request, + }; + } RequestStatus::New } + /// Check whether a register request should be processed. + /// + /// Register is valid even without an existing session. If the client + /// already has a session, returns `AlreadyRegistered` with the session + /// number so the caller can synthesize the correct register reply. + /// + /// # Panics + /// Panics if `client_id` is 0 or if the internal index points to an empty slot. 
+ #[must_use] + pub fn check_register(&self, client_id: u128) -> RequestStatus { + assert!(client_id != 0, "client_id 0 is reserved for internal use"); + + let key = ClientRequest { + client_id, + request: 0, + }; + if self.pending.contains_key(&key) { + return RequestStatus::InProgress; + } + + let Some(&slot_idx) = self.index.get(&client_id) else { + return RequestStatus::New; + }; + let entry = self.slots[slot_idx].as_ref().expect("index/slot mismatch"); + RequestStatus::AlreadyRegistered { + session: entry.session, + } + } + /// Register interest in a pending request's commit. /// /// Returns a [`Notify`] the caller can `.notified().await` on. The `Notify` @@ -242,56 +302,112 @@ impl ClientTable { notify } - /// Record a committed reply and cache it. + /// Record a committed register reply, creates or updates a session. /// - /// If the client already has a slot, updates it in place. Otherwise allocates - /// a free slot, evicting the client with the oldest commit number if the table - /// is full. + /// Session number = `reply.header().commit` (the commit op number), + /// for deterministic, monotonically increasing session numbers. /// - /// Wakes the pending [`Notify`] for this `(client_id, request)` if one exists. + /// Idempotent: if the client already has a slot (e.g. WAL replay after + /// view change), the entry is updated in place. The session number must + /// match, a different session for the same `client_id` is a bug. /// - /// Called in `on_ack` after `build_reply_message`. + /// If the table is full and the client is new, the client with the oldest + /// commit number is evicted. /// /// # Panics - /// Panics if the internal index points to an empty slot (invariant violation). - pub fn commit_reply(&mut self, client_id: u128, reply: Message) { + /// Panics if `client_id` is 0, or if an existing entry has a different + /// session number (indicates a protocol violation). 
+ pub fn commit_register(&mut self, client_id: u128, reply: Message) { assert!(client_id != 0, "client_id 0 is reserved for internal use"); assert_eq!( client_id, reply.header().client, - "commit_reply: client_id mismatch (arg={client_id}, header={})", + "commit_register: client_id mismatch (arg={client_id}, header={})", reply.header().client ); - let request = reply.header().request; + + let session = reply.header().commit; + assert!(session > 0, "commit_register: session must be > 0"); if let Some(&slot_idx) = self.index.get(&client_id) { + // Re-registration during WAL replay, update in place. let slot = self.slots[slot_idx].as_mut().expect("index/slot mismatch"); - // Monotonicity: both commit (op) and request must not regress. - assert!( - reply.header().commit >= slot.reply.header().commit, - "commit_reply: commit regression for client {client_id}: {} -> {}", - slot.reply.header().commit, - reply.header().commit - ); - assert!( - reply.header().request >= slot.reply.header().request, - "commit_reply: request regression for client {client_id}: {} -> {}", - slot.reply.header().request, - reply.header().request + assert_eq!( + slot.session, session, + "commit_register: session mismatch for client {client_id}: \ + existing={}, new={session}", + slot.session ); slot.reply = reply; } else { - // Need a free slot. Evict if full. + // New client, allocate a slot. if self.index.len() >= self.slots.len() { self.evict_oldest(); } let slot_idx = self.first_free_slot().expect("eviction must free a slot"); - self.slots[slot_idx] = Some(ClientEntry { reply }); + self.slots[slot_idx] = Some(ClientEntry { session, reply }); self.index.insert(client_id, slot_idx); } - // Wake the waiter, if any. + let key = ClientRequest { + client_id, + request: 0, + }; + if let Some(notify) = self.pending.remove(&key) { + notify.notify(); + } + } + + /// Record a committed reply and cache it. + /// + /// Updates the existing entry in place. 
The client must already be + /// registered via [`commit_register`]. + /// + /// The `session` parameter is the session from the prepare/request header. + /// It is asserted against the stored session to guard against WAL replay + /// committing a stale reply from a previous session to the wrong entry. + /// + /// Wakes the pending [`Notify`] for this `(client_id, request)` if one exists. + /// + /// # Panics + /// Panics if the client has no slot, if session doesn't match, or if + /// commit/request regresses. + pub fn commit_reply(&mut self, client_id: u128, session: u64, reply: Message) { + assert!(client_id != 0, "client_id 0 is reserved for internal use"); + assert_eq!( + client_id, + reply.header().client, + "commit_reply: client_id mismatch (arg={client_id}, header={})", + reply.header().client + ); + let request = reply.header().request; + + let &slot_idx = self + .index + .get(&client_id) + .unwrap_or_else(|| panic!("commit_reply: client {client_id} not registered")); + let slot = self.slots[slot_idx].as_mut().expect("index/slot mismatch"); + assert_eq!( + slot.session, session, + "commit_reply: session mismatch for client {client_id}: \ + entry={}, prepare={session}", + slot.session + ); + assert!( + reply.header().commit >= slot.reply.header().commit, + "commit_reply: commit regression for client {client_id}: {} -> {}", + slot.reply.header().commit, + reply.header().commit + ); + assert!( + reply.header().request >= slot.reply.header().request, + "commit_reply: request regression for client {client_id}: {} -> {}", + slot.reply.header().request, + reply.header().request + ); + slot.reply = reply; + let key = ClientRequest { client_id, request }; if let Some(notify) = self.pending.remove(&key) { notify.notify(); @@ -341,6 +457,13 @@ impl ClientTable { self.slots[slot_idx].as_ref().map(|entry| &entry.reply) } + /// Get the session number for a registered client. 
+ #[must_use] + pub fn get_session(&self, client_id: u128) -> Option { + let &slot_idx = self.index.get(&client_id)?; + self.slots[slot_idx].as_ref().map(|entry| entry.session) + } + /// Number of active committed client entries. #[must_use] pub fn count(&self) -> usize { @@ -368,6 +491,24 @@ mod tests { use super::*; use iggy_binary_protocol::{Command2, Operation}; + fn make_register_reply(client: u128, commit: u64) -> Message { + let header_size = std::mem::size_of::(); + let mut msg = Message::::new(header_size); + let header = bytemuck::checked::try_from_bytes_mut::( + &mut msg.as_mut_slice()[..header_size], + ) + .expect("zeroed bytes are valid"); + *header = ReplyHeader { + client, + request: 0, + commit, + command: Command2::Reply, + operation: Operation::Register, + ..ReplyHeader::default() + }; + msg + } + fn make_reply_for(client: u128, request: u64, commit: u64) -> Message { let header_size = std::mem::size_of::(); let mut msg = Message::::new(header_size); @@ -386,8 +527,12 @@ mod tests { msg } - fn make_reply(request: u64, commit: u64) -> Message { - make_reply_for(1, request, commit) + /// Register client 1 at commit 10, return (table, session=10). + fn table_with_client() -> (ClientTable, u64) { + let mut table = ClientTable::new(10); + let session = 10; + table.commit_register(1, make_register_reply(1, session)); + (table, session) } // Notify tests @@ -396,8 +541,6 @@ mod tests { fn notify_after_await_registration() { let notify = Notify::new(); let waiter = notify.clone(); - - // Notify before anyone polls, should resolve immediately on first poll. notify.notify(); let waker = futures::task::noop_waker(); @@ -414,14 +557,9 @@ mod tests { let waker = futures::task::noop_waker(); let mut cx = std::task::Context::from_waker(&waker); let mut fut = std::pin::pin!(waiter.notified()); - - // First poll returns Pending (not yet notified). assert!(fut.as_mut().poll(&mut cx).is_pending()); - // Notify fires. notify.notify(); - - // Next poll resolves. 
assert!(fut.as_mut().poll(&mut cx).is_ready()); } @@ -432,86 +570,182 @@ mod tests { let waker = futures::task::noop_waker(); let mut cx = std::task::Context::from_waker(&waker); - - // First notified() consumes the permit. let mut fut1 = std::pin::pin!(notify.notified()); assert!(fut1.as_mut().poll(&mut cx).is_ready()); - // Second notified() should be pending (permit consumed). let mut fut2 = std::pin::pin!(notify.notified()); assert!(fut2.as_mut().poll(&mut cx).is_pending()); } - // ClientTable tests + // Registration tests + + #[test] + fn register_creates_session() { + let mut table = ClientTable::new(10); + table.commit_register(1, make_register_reply(1, 42)); + assert_eq!(table.get_session(1), Some(42)); + assert_eq!(table.count(), 1); + } #[test] - fn check_request_new() { + fn check_register_new_client() { let table = ClientTable::new(10); - assert!(matches!(table.check_request(1, 1), RequestStatus::New)); + assert!(matches!(table.check_register(1), RequestStatus::New)); } #[test] - fn check_request_duplicate_after_commit() { + fn check_register_already_registered() { + let (table, session) = table_with_client(); + match table.check_register(1) { + RequestStatus::AlreadyRegistered { session: s } => assert_eq!(s, session), + other => panic!("expected AlreadyRegistered, got {other:?}"), + } + } + + #[test] + fn check_register_already_registered_after_progress() { + let (mut table, session) = table_with_client(); + // Client progresses past registration. + table.commit_reply(1, 10, make_reply_for(1, 1, 11)); + table.commit_reply(1, 10, make_reply_for(1, 2, 12)); + // Re-register still returns the session, not the latest app reply. 
+ match table.check_register(1) { + RequestStatus::AlreadyRegistered { session: s } => assert_eq!(s, session), + other => panic!("expected AlreadyRegistered, got {other:?}"), + } + } + + #[test] + fn commit_register_notifies() { let mut table = ClientTable::new(10); - table.commit_reply(1, make_reply(1, 10)); + let notify = table.register_pending(1, 0); + table.commit_register(1, make_register_reply(1, 5)); + + let waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&waker); + let mut fut = std::pin::pin!(notify.notified()); + assert!(fut.as_mut().poll(&mut cx).is_ready()); + } - match table.check_request(1, 1) { - RequestStatus::Duplicate(cached) => { - assert_eq!(cached.header().request, 1); + // Session validation tests + + #[test] + fn check_request_no_session() { + let table = ClientTable::new(10); + // Client 1 not registered — valid session/request but no entry. + assert!(matches!( + table.check_request(1, 99, 1), + RequestStatus::NoSession + )); + } + + #[test] + fn check_request_session_mismatch() { + let (table, session) = table_with_client(); + match table.check_request(1, session + 1, 1) { + RequestStatus::SessionMismatch { expected, received } => { + assert_eq!(expected, session); + assert_eq!(received, session + 1); } - _ => panic!("expected Duplicate"), + other => panic!("expected SessionMismatch, got {other:?}"), + } + } + + #[test] + fn check_request_correct_session_new() { + let (mut table, session) = table_with_client(); + table.commit_reply(1, 10, make_reply_for(1, 1, 11)); + assert!(matches!( + table.check_request(1, session, 2), + RequestStatus::New + )); + } + + #[test] + fn check_request_duplicate_after_commit() { + let (mut table, session) = table_with_client(); + table.commit_reply(1, 10, make_reply_for(1, 1, 11)); + match table.check_request(1, session, 1) { + RequestStatus::Duplicate(cached) => assert_eq!(cached.header().request, 1), + other => panic!("expected Duplicate, got {other:?}"), } } #[test] fn 
check_request_stale() { - let mut table = ClientTable::new(10); - table.commit_reply(1, make_reply(5, 10)); + let (mut table, session) = table_with_client(); + table.commit_reply(1, 10, make_reply_for(1, 5, 15)); + assert!(matches!( + table.check_request(1, session, 3), + RequestStatus::Stale + )); + } - assert!(matches!(table.check_request(1, 3), RequestStatus::Stale)); + #[test] + fn check_request_gap_rejected() { + let (mut table, session) = table_with_client(); + table.commit_reply(1, 10, make_reply_for(1, 1, 11)); + // Request 3 skips request 2 — must be rejected. + match table.check_request(1, session, 3) { + RequestStatus::RequestGap { expected, received } => { + assert_eq!(expected, 2); + assert_eq!(received, 3); + } + other => panic!("expected RequestGap, got {other:?}"), + } } #[test] fn check_request_in_progress_while_pending() { - let mut table = ClientTable::new(10); + let (mut table, session) = table_with_client(); let _notify = table.register_pending(1, 1); - assert!(matches!( - table.check_request(1, 1), + table.check_request(1, session, 1), RequestStatus::InProgress )); } #[test] - fn commit_caches_reply() { - let mut table = ClientTable::new(10); - table.commit_reply(1, make_reply(1, 10)); + fn check_request_wrong_session_even_if_pending() { + let (mut table, session) = table_with_client(); + let _notify = table.register_pending(1, 1); + // Same (client_id, request) is pending, but session is wrong. + // Must return SessionMismatch, not InProgress. 
+ match table.check_request(1, session + 1, 1) { + RequestStatus::SessionMismatch { expected, received } => { + assert_eq!(expected, session); + assert_eq!(received, session + 1); + } + other => panic!("expected SessionMismatch, got {other:?}"), + } + } + // Commit tests + + #[test] + fn commit_caches_reply() { + let (mut table, _) = table_with_client(); + table.commit_reply(1, 10, make_reply_for(1, 1, 11)); let cached = table.get_reply(1).expect("should have cached reply"); assert_eq!(cached.header().request, 1); } #[test] - fn commit_updates_existing_entry() { - let mut table = ClientTable::new(10); - table.commit_reply(1, make_reply(1, 10)); - table.commit_reply(1, make_reply(2, 20)); - - let cached = table.get_reply(1).expect("should have cached reply"); - assert_eq!(cached.header().request, 2); + fn commit_updates_preserves_session() { + let (mut table, session) = table_with_client(); + table.commit_reply(1, 10, make_reply_for(1, 1, 11)); + table.commit_reply(1, 10, make_reply_for(1, 2, 12)); + assert_eq!(table.get_reply(1).unwrap().header().request, 2); + assert_eq!(table.get_session(1), Some(session)); assert_eq!(table.count(), 1); } #[test] fn register_and_commit_notifies() { - let mut table = ClientTable::new(10); + let (mut table, _) = table_with_client(); let notify = table.register_pending(1, 1); - assert_eq!(table.pending_count(), 1); - - // Commit fires the notify. 
- table.commit_reply(1, make_reply(1, 10)); - + table.commit_reply(1, 10, make_reply_for(1, 1, 11)); assert_eq!(table.pending_count(), 0); let waker = futures::task::noop_waker(); @@ -520,14 +754,14 @@ mod tests { assert!(fut.as_mut().poll(&mut cx).is_ready()); } + // Eviction tests + #[test] fn eviction_removes_oldest_commit() { let mut table = ClientTable::new(2); - - table.commit_reply(100, make_reply_for(100, 1, 10)); - table.commit_reply(200, make_reply_for(200, 1, 20)); - table.commit_reply(300, make_reply_for(300, 1, 30)); - + table.commit_register(100, make_register_reply(100, 10)); + table.commit_register(200, make_register_reply(200, 20)); + table.commit_register(300, make_register_reply(300, 30)); assert!(table.get_reply(100).is_none()); assert!(table.get_reply(200).is_some()); assert!(table.get_reply(300).is_some()); @@ -537,36 +771,26 @@ mod tests { #[test] fn eviction_is_deterministic_by_slot_index() { let mut table = ClientTable::new(2); - - table.commit_reply(100, make_reply_for(100, 1, 10)); - table.commit_reply(200, make_reply_for(200, 1, 10)); - table.commit_reply(300, make_reply_for(300, 1, 30)); - + table.commit_register(100, make_register_reply(100, 10)); + table.commit_register(200, make_register_reply(200, 10)); + table.commit_register(300, make_register_reply(300, 30)); assert!(table.get_reply(100).is_none()); assert!(table.get_reply(200).is_some()); assert!(table.get_reply(300).is_some()); } - #[test] - fn new_request_after_commit_is_new() { - let mut table = ClientTable::new(10); - table.commit_reply(1, make_reply(1, 10)); - - assert!(matches!(table.check_request(1, 2), RequestStatus::New)); - } - #[test] fn slot_reuse_after_eviction() { let mut table = ClientTable::new(1); - - table.commit_reply(100, make_reply_for(100, 1, 10)); - table.commit_reply(200, make_reply_for(200, 1, 20)); - + table.commit_register(100, make_register_reply(100, 10)); + table.commit_register(200, make_register_reply(200, 20)); 
assert!(table.get_reply(100).is_none()); assert!(table.get_reply(200).is_some()); assert_eq!(table.count(), 1); } + // Edge cases + #[test] #[should_panic(expected = "already has a pending waiter")] fn register_pending_twice_panics() { @@ -574,4 +798,62 @@ mod tests { let _n1 = table.register_pending(1, 1); let _n2 = table.register_pending(1, 1); } + + #[test] + fn commit_register_idempotent_on_replay() { + let mut table = ClientTable::new(10); + table.commit_register(1, make_register_reply(1, 10)); + // Same client_id, same session — idempotent (WAL replay). + table.commit_register(1, make_register_reply(1, 10)); + assert_eq!(table.get_session(1), Some(10)); + assert_eq!(table.count(), 1); + } + + #[test] + #[should_panic(expected = "session mismatch")] + fn commit_register_different_session_panics() { + let mut table = ClientTable::new(10); + table.commit_register(1, make_register_reply(1, 10)); + // Same client_id, different session — protocol violation. + table.commit_register(1, make_register_reply(1, 20)); + } + + #[test] + #[should_panic(expected = "not registered")] + fn commit_reply_without_register_panics() { + let mut table = ClientTable::new(10); + table.commit_reply(1, 10, make_reply_for(1, 1, 10)); + } + + #[test] + #[should_panic(expected = "session mismatch")] + fn commit_reply_wrong_session_panics() { + let (mut table, _session) = table_with_client(); + // Session 10 is registered, but commit with session 99. 
+ table.commit_reply(1, 99, make_reply_for(1, 1, 11)); + } + + #[test] + fn different_clients_independent_sessions() { + let mut table = ClientTable::new(10); + table.commit_register(1, make_register_reply(1, 10)); + table.commit_register(2, make_register_reply(2, 20)); + assert_eq!(table.get_session(1), Some(10)); + assert_eq!(table.get_session(2), Some(20)); + assert!(matches!(table.check_request(1, 10, 1), RequestStatus::New)); + assert!(matches!(table.check_request(2, 20, 1), RequestStatus::New)); + assert!(matches!( + table.check_request(1, 20, 1), + RequestStatus::SessionMismatch { .. } + )); + } + + #[test] + fn clear_pending_removes_all() { + let mut table = ClientTable::new(10); + let _n1 = table.register_pending(1, 1); + let _n2 = table.register_pending(2, 1); + table.clear_pending(); + assert_eq!(table.pending_count(), 0); + } } diff --git a/core/consensus/src/observability.rs b/core/consensus/src/observability.rs index 697c8e81aa..3306985da9 100644 --- a/core/consensus/src/observability.rs +++ b/core/consensus/src/observability.rs @@ -638,6 +638,7 @@ pub const fn operation_as_str(operation: Operation) -> &'static str { Operation::UpdatePermissions => "update_permissions", Operation::CreatePersonalAccessToken => "create_personal_access_token", Operation::DeletePersonalAccessToken => "delete_personal_access_token", + Operation::Register => "register", Operation::SendMessages => "send_messages", Operation::StoreConsumerOffset => "store_consumer_offset", Operation::DeleteConsumerOffset => "delete_consumer_offset", diff --git a/core/consensus/src/plane_helpers.rs b/core/consensus/src/plane_helpers.rs index e3fa54b3a1..5a5fecfa96 100644 --- a/core/consensus/src/plane_helpers.rs +++ b/core/consensus/src/plane_helpers.rs @@ -33,11 +33,13 @@ use std::ops::AsyncFnOnce; /// /// Returns `Some(Notify)` if the request is new and should proceed through /// consensus. 
Returns `None` if the request was already handled (duplicate -/// reply sent, in-progress, or stale), the caller should return early. +/// reply sent, in-progress, stale, or session error), the caller should +/// return early. #[allow(clippy::future_not_send)] pub async fn request_preflight( consensus: &VsrConsensus, client_id: u128, + session: u64, request: u64, ) -> Option where @@ -47,7 +49,7 @@ where let status = consensus .client_table() .borrow() - .check_request(client_id, request); + .check_request(client_id, session, request); match status { RequestStatus::Duplicate(cached_reply) => { // Best-effort resend, client may have disconnected. @@ -57,7 +59,12 @@ where .await; None } - RequestStatus::InProgress | RequestStatus::Stale => None, + RequestStatus::InProgress + | RequestStatus::Stale + | RequestStatus::NoSession + | RequestStatus::SessionMismatch { .. } + | RequestStatus::RequestGap { .. } + | RequestStatus::AlreadyRegistered { .. } => None, RequestStatus::New => { let notify = consensus .client_table() @@ -68,6 +75,48 @@ where } } +/// Shared register preflight: duplicate detection for `Operation::Register`. +/// +/// Returns `Some(Notify)` if the register is new and should proceed through +/// consensus. Returns `None` if the client is already registered (session +/// number sent back) or the register is already in progress. +#[allow(clippy::future_not_send, clippy::unused_async)] +pub async fn register_preflight( + consensus: &VsrConsensus, + client_id: u128, +) -> Option +where + B: MessageBus, Client = u128>, + P: Pipeline, +{ + let status = consensus.client_table().borrow().check_register(client_id); + match status { + RequestStatus::AlreadyRegistered { session } => { + // Synthesize a register reply with the existing session. + // The caller can extract session from reply.header().commit. 
+ tracing::debug!( + client_id, + session, + "register_preflight: client already registered, ignoring" + ); + None + } + RequestStatus::InProgress => None, + RequestStatus::New => { + let notify = consensus + .client_table() + .borrow_mut() + .register_pending(client_id, 0); + Some(notify) + } + // check_register only returns AlreadyRegistered, InProgress, or New. + other => { + tracing::warn!(client_id, ?other, "register_preflight: unexpected status"); + None + } + } +} + /// Shared pipeline-first request flow used by metadata and partitions. /// /// # Panics diff --git a/core/integration/tests/server/scenarios/authentication_scenario.rs b/core/integration/tests/server/scenarios/authentication_scenario.rs index 6e476cf4c1..ab5245e3c1 100644 --- a/core/integration/tests/server/scenarios/authentication_scenario.rs +++ b/core/integration/tests/server/scenarios/authentication_scenario.rs @@ -115,12 +115,16 @@ async fn test_all_commands_require_auth(client: &IggyClient) { let name = entry.name; // ================================================================ - // SKIPPED COMMANDS (8 total) + // SKIPPED COMMANDS (10 total) // ================================================================ // No auth required if matches!( code, - PING_CODE | LOGIN_USER_CODE | LOGIN_WITH_PERSONAL_ACCESS_TOKEN_CODE + PING_CODE + | LOGIN_USER_CODE + | LOGIN_WITH_PERSONAL_ACCESS_TOKEN_CODE + | LOGIN_REGISTER_CODE + | LOGIN_REGISTER_WITH_PAT_CODE ) { continue; } diff --git a/core/metadata/src/impls/metadata.rs b/core/metadata/src/impls/metadata.rs index 79b3d98e18..c362821ec4 100644 --- a/core/metadata/src/impls/metadata.rs +++ b/core/metadata/src/impls/metadata.rs @@ -21,11 +21,11 @@ use consensus::{ ReplicaLogContext, RequestLogEvent, Sequencer, SimEventKind, VsrConsensus, ack_preflight, ack_quorum_reached, build_reply_message, drain_committable_prefix, emit_sim_event, fence_old_prepare_by_commit, panic_if_hash_chain_would_break_in_same_view, - pipeline_prepare_common, 
replicate_preflight, replicate_to_next_in_chain, request_preflight, - send_prepare_ok as send_prepare_ok_common, + pipeline_prepare_common, register_preflight, replicate_preflight, replicate_to_next_in_chain, + request_preflight, send_prepare_ok as send_prepare_ok_common, }; use iggy_binary_protocol::{ - Command2, ConsensusHeader, GenericHeader, Message, PrepareHeader, PrepareOkHeader, + Command2, ConsensusHeader, GenericHeader, Message, Operation, PrepareHeader, PrepareOkHeader, RequestHeader, }; use journal::{Journal, JournalHandle}; @@ -292,7 +292,9 @@ where async fn on_request(&self, message: as Consensus>::Message) { let consensus = self.consensus.as_ref().unwrap(); let client_id = message.header().client; + let session = message.header().session; let request = message.header().request; + let operation = message.header().operation; // TODO: Add a bounded request queue instead of dropping here. // When the prepare queue (8 max) is full, buffer @@ -311,9 +313,18 @@ where return; } - let Some(_notify) = request_preflight(consensus, client_id, request).await else { - return; + // Register uses a dedicated preflight (check_register, request=0). + // Normal metadata ops use request_preflight (check_request with session). + let preflight_ok = if operation == Operation::Register { + register_preflight(consensus, client_id).await.is_some() + } else { + request_preflight(consensus, client_id, session, request) + .await + .is_some() }; + if !preflight_ok { + return; + } emit_sim_event( SimEventKind::ClientRequestReceived, @@ -537,14 +548,6 @@ where ) }); - // Committed ops must be infallible — if the state machine cannot - // apply a committed op, replicas will diverge. 
- let response = self.mux_stm.update(prepare).unwrap_or_else(|err| { - panic!( - "on_ack: committed metadata op={} failed to apply: {err}", - prepare_header.op - ); - }); consensus.advance_commit_min(prepare_header.op); let pipeline_depth = consensus.pipeline().borrow().len(); let event = CommitLogEvent { @@ -555,14 +558,40 @@ where operation: prepare_header.operation, pipeline_depth, }; - emit_sim_event(SimEventKind::OperationCommitted, &event); - let reply = build_reply_message(consensus, &prepare_header, response); - // Cache reply for duplicate detection: - consensus - .client_table() - .borrow_mut() - .commit_reply(prepare_header.client, reply.clone()); + let reply = if prepare_header.operation == Operation::Register { + // Register: no state machine, commit_register creates session. + let reply = + build_reply_message(consensus, &prepare_header, bytes::Bytes::new()); + consensus + .client_table() + .borrow_mut() + .commit_register(prepare_header.client, reply.clone()); + reply + } else { + // Normal metadata op: apply state machine, commit_reply. 
+ let response = self.mux_stm.update(prepare).unwrap_or_else(|err| { + panic!( + "on_ack: committed metadata op={} failed to apply: {err}", + prepare_header.op + ); + }); + let reply = build_reply_message(consensus, &prepare_header, response); + let session = consensus + .client_table() + .borrow() + .get_session(prepare_header.client) + .unwrap_or_else(|| { + panic!("on_ack: client {} not registered", prepare_header.client) + }); + consensus.client_table().borrow_mut().commit_reply( + prepare_header.client, + session, + reply.clone(), + ); + reply + }; + emit_sim_event(SimEventKind::OperationCommitted, &event); let generic_reply = reply.into_generic(); let reply_buffers = freeze_client_reply(generic_reply); @@ -599,7 +628,8 @@ where message.header().command(), Command2::Request | Command2::Prepare | Command2::PrepareOk )); - message.header().operation().is_metadata() + let op = message.header().operation(); + op.is_metadata() || op == Operation::Register } } @@ -667,18 +697,33 @@ where break; }; - // Committed ops must be infallible (see on_ack comment). - let response = self.mux_stm.update(prepare).unwrap_or_else(|err| { - panic!("commit_journal: committed metadata op={op} failed to apply: {err}"); - }); - consensus.advance_commit_min(op); - let reply = build_reply_message(consensus, &header, response); - consensus - .client_table() - .borrow_mut() - .commit_reply(header.client, reply); + if header.operation == Operation::Register { + // Register: no state machine, commit_register creates session. + let reply = build_reply_message(consensus, &header, bytes::Bytes::new()); + consensus + .client_table() + .borrow_mut() + .commit_register(header.client, reply); + } else { + // Normal metadata op: apply state machine, commit_reply. 
+ let response = self.mux_stm.update(prepare).unwrap_or_else(|err| { + panic!("commit_journal: committed metadata op={op} failed to apply: {err}"); + }); + let reply = build_reply_message(consensus, &header, response); + let session = consensus + .client_table() + .borrow() + .get_session(header.client) + .unwrap_or_else(|| { + panic!("commit_journal: client {} not registered", header.client) + }); + consensus + .client_table() + .borrow_mut() + .commit_reply(header.client, session, reply); + } debug!("commit_journal: committed op={op}"); } diff --git a/core/partitions/src/iggy_partition.rs b/core/partitions/src/iggy_partition.rs index 5f8a97c183..729cc21328 100644 --- a/core/partitions/src/iggy_partition.rs +++ b/core/partitions/src/iggy_partition.rs @@ -34,7 +34,7 @@ use consensus::{ ack_quorum_reached, build_reply_message, drain_committable_prefix, emit_namespace_progress_event, emit_partition_diag, emit_sim_event, fence_old_prepare_by_commit, replicate_preflight, replicate_to_next_in_chain, - send_prepare_ok as send_prepare_ok_common, + request_preflight, send_prepare_ok as send_prepare_ok_common, }; use iggy_binary_protocol::consensus::iobuf::Frozen; use iggy_binary_protocol::{GenericHeader, PrepareOkHeader, RequestHeader}; @@ -587,14 +587,38 @@ where pub async fn on_request(&mut self, message: Message) { self.clear_pending_consumer_offset_commits_if_view_changed(); let namespace = IggyNamespace::from_raw(message.header().namespace); + let client_id = message.header().client; + let session = message.header().session; + let request = message.header().request; + + // TODO: Add a bounded request queue instead of dropping here. + // When the prepare queue (8 max) is full, buffer incoming requests + // in a request queue. On commit, pop the next request from the + // request queue and begin preparing it. Only drop when both queues + // are full. 
+ { + let consensus = self.consensus(); + if consensus.pipeline().borrow().is_full() { + emit_partition_diag( + tracing::Level::WARN, + &PartitionDiagEvent::new( + ReplicaLogContext::from_consensus(consensus, PlaneKind::Partitions), + "on_request: pipeline full, dropping request", + ) + .with_operation(message.header().operation), + ); + return; + } + } + let prepare = { let consensus = self.consensus(); emit_sim_event( SimEventKind::ClientRequestReceived, &RequestLogEvent { replica: ReplicaLogContext::from_consensus(consensus, PlaneKind::Partitions), - client_id: message.header().client, - request_id: message.header().request, + client_id, + request_id: request, operation: message.header().operation, }, ); @@ -640,6 +664,13 @@ where } } + if request_preflight(consensus, client_id, session, request) + .await + .is_none() + { + return; + } + assert!(!consensus.is_follower(), "on_request: primary only"); assert!(consensus.is_normal(), "on_request: status must be normal"); assert!(!consensus.is_syncing(), "on_request: must not be syncing"); @@ -1080,10 +1111,29 @@ where pipeline_depth, ); + // Cache reply in client_table (both primary and backups) to + // preserve idempotency/dedup across view changes. Only the + // primary actually sends the reply to the client. 
+ let reply = build_reply_message(&self.consensus, &prepare_header, bytes::Bytes::new()); + let session = self + .consensus + .client_table() + .borrow() + .get_session(prepare_header.client) + .unwrap_or_else(|| { + panic!( + "handle_committed_entries: client {} not registered", + prepare_header.client + ) + }); + self.consensus.client_table().borrow_mut().commit_reply( + prepare_header.client, + session, + reply.clone(), + ); + if send_client_replies { - let reply_buffers = - build_reply_message(&self.consensus, &prepare_header, bytes::Bytes::new()) - .into_generic(); + let reply_buffers = reply.into_generic(); emit_sim_event(SimEventKind::ClientReplyEmitted, &event); if let Err(error) = self diff --git a/core/sdk/src/lib.rs b/core/sdk/src/lib.rs index 9fa07df91a..35f0e1b732 100644 --- a/core/sdk/src/lib.rs +++ b/core/sdk/src/lib.rs @@ -25,6 +25,7 @@ pub mod http; mod leader_aware; pub mod prelude; pub mod quic; +pub mod session; pub mod stream_builder; pub mod tcp; pub mod websocket; diff --git a/core/sdk/src/session.rs b/core/sdk/src/session.rs new file mode 100644 index 0000000000..3479d5eb61 --- /dev/null +++ b/core/sdk/src/session.rs @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +//! Consensus-level session state for the SDK. +//! +//! Each SDK client instance generates an ephemeral random `client_id: u128` +//! on construction. After login, the server registers this client through +//! consensus and returns a `session` number (the commit op number). +//! +//! The SDK tracks the `(client_id, session)` pair and a monotonically +//! increasing `request_id` counter. These values populate the consensus +//! headers (`RequestHeader.client`, `.session`, `.request`) when the +//! transport sends requests through server-ng. +//! +//! ## Lifecycle +//! +//! ```text +//! new() - fresh client_id generated, session = None +//! register_request_id() - returns 0 (for the register operation, once only) +//! bind(session) - session assigned by server after register commits +//! next_request_id() - returns 1, 2, 3, ... (application requests) +//! drop + new() - on disconnect/crash, create a fresh session +//! ``` +//! +//! A `ConsensusSession` is **not reusable** across connections. On disconnect +//! or crash, drop it and create a new one. This generates a fresh `client_id` +//! and avoids ambiguous re-register semantics (TigerBeetle requires a fresh +//! client_id per process). The old session stays in the server's `ClientTable` +//! until evicted. + +/// Consensus-level session state. +/// +/// Single-threaded: owned by whatever drives the request loop (connection +/// handler or SDK transport). All methods take `&mut self`. If shared +/// access is needed, wrap in `Arc>` at the call site. +#[derive(Debug)] +pub struct ConsensusSession { + /// Ephemeral random client identifier. Generated once per process, + /// never persisted. Each SDK instance gets a unique value. + client_id: u128, + /// Session number assigned by the server after register commits + /// through consensus. `None` until bound. + session: Option, + /// Monotonically increasing request counter for application requests. + /// Starts at 1 after registration. 
Register itself always uses request=0. + request_counter: u64, + /// Whether `register_request_id()` has been called. + register_consumed: bool, +} + +impl ConsensusSession { + /// Create a new session with a random `client_id`. + #[must_use] + pub fn new() -> Self { + Self::with_client_id(generate_client_id()) + } + + /// Create a session with a specific `client_id` (for testing). + #[must_use] + pub fn with_client_id(client_id: u128) -> Self { + Self { + client_id, + session: None, + request_counter: 1, + register_consumed: false, + } + } + + /// The ephemeral client identifier for this session. + #[must_use] + pub fn client_id(&self) -> u128 { + self.client_id + } + + /// The session number, if registered. + #[must_use] + pub fn session(&self) -> Option { + self.session + } + + /// Whether the session is bound (register committed). + #[must_use] + pub fn is_bound(&self) -> bool { + self.session.is_some() + } + + /// Bind the session after register commits through consensus. + /// The `session` value is the commit op number from the server's reply. + /// + /// # Panics + /// Panics if already bound (drop and create a new session instead). + pub fn bind(&mut self, session: u64) { + assert!( + self.session.is_none(), + "session already bound (session={})", + self.session.unwrap() + ); + assert!(session > 0, "session must be > 0"); + self.session = Some(session); + } + + /// Returns the request ID for the register operation (always 0). + /// + /// Must be called exactly once, before [`bind`](Self::bind). + /// + /// # Panics + /// Panics if called more than once or after the session is bound. + pub fn register_request_id(&mut self) -> u64 { + assert!( + !self.register_consumed, + "register_request_id already called" + ); + assert!(!self.is_bound(), "register_request_id called after bind"); + self.register_consumed = true; + 0 + } + + /// Get the next application request ID and advance the counter. + /// + /// Returns 1, 2, 3, ... 
(request 0 is reserved for register). + /// + /// # Panics + /// Panics if the session is not bound. + pub fn next_request_id(&mut self) -> u64 { + assert!(self.is_bound(), "next_request_id called before bind"); + let id = self.request_counter; + self.request_counter = self + .request_counter + .checked_add(1) + .expect("request counter overflow (u64::MAX requests on a single session)"); + id + } + + /// Current request counter value (the next ID that will be returned). + #[must_use] + pub fn current_request_id(&self) -> u64 { + self.request_counter + } +} + +impl Default for ConsensusSession { + fn default() -> Self { + Self::new() + } +} + +/// Generate an ephemeral random u128 client ID using UUID v4. +/// +/// Non-zero by construction (UUID v4 has fixed bits that prevent all-zeros). +fn generate_client_id() -> u128 { + iggy_common::random_id::get_uuid() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_session_is_unbound() { + let session = ConsensusSession::new(); + assert!(!session.is_bound()); + assert!(session.session().is_none()); + assert_ne!(session.client_id(), 0); + } + + #[test] + fn client_id_is_unique() { + let s1 = ConsensusSession::new(); + let s2 = ConsensusSession::new(); + assert_ne!(s1.client_id(), s2.client_id()); + } + + #[test] + fn bind_sets_session() { + let mut session = ConsensusSession::with_client_id(42); + session.bind(100); + assert!(session.is_bound()); + assert_eq!(session.session(), Some(100)); + } + + #[test] + fn register_request_id_returns_zero() { + let mut session = ConsensusSession::with_client_id(1); + assert_eq!(session.register_request_id(), 0); + } + + #[test] + fn request_ids_are_monotonic_after_bind() { + let mut session = ConsensusSession::with_client_id(1); + let _ = session.register_request_id(); + session.bind(10); + assert_eq!(session.next_request_id(), 1); + assert_eq!(session.next_request_id(), 2); + assert_eq!(session.next_request_id(), 3); + assert_eq!(session.current_request_id(), 4); + } 
+ + #[test] + #[should_panic(expected = "register_request_id already called")] + fn double_register_request_id_panics() { + let mut session = ConsensusSession::with_client_id(1); + let _ = session.register_request_id(); + let _ = session.register_request_id(); + } + + #[test] + #[should_panic(expected = "next_request_id called before bind")] + fn next_request_id_before_bind_panics() { + let mut session = ConsensusSession::with_client_id(1); + let _ = session.next_request_id(); + } + + #[test] + #[should_panic(expected = "already bound")] + fn double_bind_panics() { + let mut session = ConsensusSession::with_client_id(1); + session.bind(10); + session.bind(20); + } + + #[test] + fn reconnect_uses_fresh_session() { + let s1 = ConsensusSession::new(); + let s2 = ConsensusSession::new(); + // Each new session gets a different client_id. No reuse. + assert_ne!(s1.client_id(), s2.client_id()); + assert!(!s1.is_bound()); + assert!(!s2.is_bound()); + } + + #[test] + fn with_client_id_deterministic() { + let session = ConsensusSession::with_client_id(0xDEAD_BEEF); + assert_eq!(session.client_id(), 0xDEAD_BEEF); + } +} diff --git a/core/server-ng/src/lib.rs b/core/server-ng/src/lib.rs index 042f3ce1f3..caae94a3ab 100644 --- a/core/server-ng/src/lib.rs +++ b/core/server-ng/src/lib.rs @@ -16,3 +16,6 @@ * specific language governing permissions and limitations * under the License. */ + +pub mod login_register; +pub mod session_manager; diff --git a/core/server-ng/src/login_register.rs b/core/server-ng/src/login_register.rs new file mode 100644 index 0000000000..b4baebb64f --- /dev/null +++ b/core/server-ng/src/login_register.rs @@ -0,0 +1,467 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! Combined login + register handler for server-ng. +//! +//! One client-facing command, two internal phases: +//! 1. Verify credentials locally (Argon2). NOT through consensus +//! 2. Submit `Operation::Register` through consensus. All replicas create `ClientTable` entry +//! +//! The handler is trait-based so it can be tested via mocking. + +use crate::session_manager::{SessionError, SessionManager}; +use iggy_binary_protocol::requests::users::{LoginRegisterRequest, LoginRegisterWithPatRequest}; +use iggy_binary_protocol::responses::users::LoginRegisterResponse; +use secrecy::ExposeSecret; + +/// Credential verification abstraction. +/// +/// The real implementation delegates to the shard's metadata user store +/// and Argon2 password hashing. Test implementations return fixed results. +pub trait CredentialVerifier { + /// Verify username/password. Returns `user_id` on success. + /// + /// # Errors + /// Returns `LoginRegisterError` if credentials are invalid. + fn verify(&self, username: &str, password: &str) -> Result; +} + +/// Personal access token verification abstraction. +/// +/// The real implementation looks up the PAT by hash in the user store, +/// checks expiry, and returns the owning user's ID. +pub trait TokenVerifier { + /// Verify a personal access token. Returns `user_id` on success. 
+ /// + /// # Errors + /// Returns `LoginRegisterError` if the token is invalid or expired. + fn verify_token(&self, token: &str) -> Result; +} + +/// Consensus register submission abstraction. +/// +/// The real implementation builds a `RequestHeader { operation: Register }`, +/// calls `check_register` on the `ClientTable`, submits through the consensus +/// pipeline, and awaits the `Notify` for commit. Returns the session number +/// (commit op number). +pub trait RegisterSubmitter { + /// Submit a register for `client_id` through consensus and await commit. + /// Returns the session number. + /// + /// # Errors + /// Returns `LoginRegisterError` if consensus fails or pipeline is full. + fn submit_register( + &self, + client_id: u128, + ) -> impl std::future::Future> + '_; +} + +/// Handle a combined login + register request (username/password). +/// +/// 1. Validates input +/// 2. Verifies credentials locally (no consensus) +/// 3. Submits `Register` through consensus. +/// 4. Returns `user_id` + `session` +/// +/// # Errors +/// Returns `LoginRegisterError` on auth failure, consensus failure, or +/// session state errors. +#[allow(clippy::future_not_send)] +pub async fn handle_login_register( + request: &LoginRegisterRequest, + verifier: &V, + submitter: &R, + session_manager: &mut SessionManager, + connection_id: u64, +) -> Result { + if request.client_id == 0 { + return Err(LoginRegisterError::InvalidClientId); + } + + // Phase 1: Local credential verification (NOT replicated). + let user_id = verifier.verify(request.username.as_str(), request.password.expose_secret())?; + + // Phase 2: Register through consensus. + complete_register( + request.client_id, + user_id, + submitter, + session_manager, + connection_id, + ) + .await +} + +/// Handle a combined login + register request (personal access token). +/// +/// Same two-phase flow as [`handle_login_register`], but Phase 1 verifies +/// a personal access token instead of username/password. 
+/// +/// # Errors +/// Returns `LoginRegisterError` on token failure, consensus failure, or +/// session state errors. +#[allow(clippy::future_not_send)] +pub async fn handle_login_register_with_pat( + request: &LoginRegisterWithPatRequest, + token_verifier: &T, + submitter: &R, + session_manager: &mut SessionManager, + connection_id: u64, +) -> Result { + if request.client_id == 0 { + return Err(LoginRegisterError::InvalidClientId); + } + + // Phase 1: Token verification (local, not replicated). + let user_id = token_verifier.verify_token(request.token.expose_secret())?; + + // Phase 2: Register through consensus (shared). + complete_register( + request.client_id, + user_id, + submitter, + session_manager, + connection_id, + ) + .await +} + +/// Phase 2 - transition session state and register through consensus. +/// +/// Called by both password and PAT handlers after their respective Phase 1 +/// credential verification succeeds. +#[allow(clippy::future_not_send)] +async fn complete_register( + client_id: u128, + user_id: u32, + submitter: &R, + session_manager: &mut SessionManager, + connection_id: u64, +) -> Result { + // Transition: Connected -> Authenticated. + session_manager + .login(connection_id, user_id) + .map_err(LoginRegisterError::Session)?; + + // Submit Register through consensus. + let session = match submitter.submit_register(client_id).await { + Ok(session) => session, + Err(e) => { + // Rollback: Authenticated -> Connected so the client can retry + // the full login+register on the same connection. + let _ = session_manager.reset_to_connected(connection_id); + return Err(e); + } + }; + + // Transition: Authenticated -> Bound. 
+ session_manager + .bind_session(connection_id, client_id, session) + .map_err(LoginRegisterError::Session)?; + + Ok(LoginRegisterResponse { user_id, session }) +} + +#[derive(Debug)] +pub enum LoginRegisterError { + InvalidClientId, + InvalidCredentials, + InvalidToken, + UserInactive, + Session(SessionError), + PipelineFull, + ConsensusFailed(String), +} + +impl std::fmt::Display for LoginRegisterError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidClientId => write!(f, "client_id must be non-zero"), + Self::InvalidCredentials => write!(f, "invalid username or password"), + Self::InvalidToken => write!(f, "invalid or expired personal access token"), + Self::UserInactive => write!(f, "user account is inactive"), + Self::Session(e) => write!(f, "session error: {e}"), + Self::PipelineFull => write!(f, "consensus pipeline full, try again later"), + Self::ConsensusFailed(msg) => write!(f, "consensus failed: {msg}"), + } + } +} + +impl std::error::Error for LoginRegisterError {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::session_manager::SessionManager; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + fn addr(port: u16) -> SocketAddr { + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port) + } + + struct MockVerifier { + result: Result, + } + + impl CredentialVerifier for MockVerifier { + fn verify(&self, _username: &str, _password: &str) -> Result { + match &self.result { + Ok(uid) => Ok(*uid), + Err(LoginRegisterError::InvalidCredentials) => { + Err(LoginRegisterError::InvalidCredentials) + } + Err(LoginRegisterError::UserInactive) => Err(LoginRegisterError::UserInactive), + _ => Err(LoginRegisterError::InvalidCredentials), + } + } + } + + struct MockSubmitter { + session: Result, + } + + impl RegisterSubmitter for MockSubmitter { + async fn submit_register(&self, _client_id: u128) -> Result { + match &self.session { + Ok(s) => Ok(*s), + Err(LoginRegisterError::PipelineFull) => 
Err(LoginRegisterError::PipelineFull), + Err(LoginRegisterError::ConsensusFailed(msg)) => { + Err(LoginRegisterError::ConsensusFailed(msg.clone())) + } + _ => Err(LoginRegisterError::ConsensusFailed( + "mock error".to_string(), + )), + } + } + } + + fn make_request(client_id: u128) -> LoginRegisterRequest { + LoginRegisterRequest { + client_id, + username: iggy_binary_protocol::WireName::new("admin").unwrap(), + password: secrecy::SecretString::from("secret"), + version: None, + client_context: None, + } + } + + macro_rules! block_on { + ($e:expr) => { + futures::executor::block_on($e) + }; + } + + #[test] + fn happy_path() { + block_on!(async { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + let verifier = MockVerifier { result: Ok(42) }; + let submitter = MockSubmitter { session: Ok(100) }; + let req = make_request(0xDEAD); + + let resp = handle_login_register(&req, &verifier, &submitter, &mut mgr, conn) + .await + .unwrap(); + + assert_eq!(resp.user_id, 42); + assert_eq!(resp.session, 100); + assert_eq!(mgr.get_session(conn), Some((0xDEAD, 100))); + assert_eq!(mgr.bound_count(), 1); + }); + } + + #[test] + fn auth_failure_stays_connected() { + block_on!(async { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + let verifier = MockVerifier { + result: Err(LoginRegisterError::InvalidCredentials), + }; + let submitter = MockSubmitter { session: Ok(100) }; + let req = make_request(0xDEAD); + + let err = handle_login_register(&req, &verifier, &submitter, &mut mgr, conn) + .await + .unwrap_err(); + + assert!(matches!(err, LoginRegisterError::InvalidCredentials)); + assert!(mgr.get_session(conn).is_none()); + }); + } + + #[test] + fn consensus_failure_rolls_back_to_connected() { + block_on!(async { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + let verifier = MockVerifier { result: Ok(42) }; + let submitter = MockSubmitter { + session: 
Err(LoginRegisterError::PipelineFull), + }; + let req = make_request(0xDEAD); + + let err = handle_login_register(&req, &verifier, &submitter, &mut mgr, conn) + .await + .unwrap_err(); + + assert!(matches!(err, LoginRegisterError::PipelineFull)); + assert!(mgr.get_session(conn).is_none()); + + // Connection rolled back to Connected. Retry. + let submitter_ok = MockSubmitter { session: Ok(100) }; + let resp = handle_login_register(&req, &verifier, &submitter_ok, &mut mgr, conn) + .await + .unwrap(); + assert_eq!(resp.user_id, 42); + assert_eq!(resp.session, 100); + assert_eq!(mgr.get_session(conn), Some((0xDEAD, 100))); + }); + } + + #[test] + fn zero_client_id_rejected() { + block_on!(async { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + let verifier = MockVerifier { result: Ok(42) }; + let submitter = MockSubmitter { session: Ok(100) }; + let req = make_request(0); + + let err = handle_login_register(&req, &verifier, &submitter, &mut mgr, conn) + .await + .unwrap_err(); + + assert!(matches!(err, LoginRegisterError::InvalidClientId)); + }); + } + + // PAT tests + + struct MockTokenVerifier { + result: Result, + } + + impl TokenVerifier for MockTokenVerifier { + fn verify_token(&self, _token: &str) -> Result { + match &self.result { + Ok(uid) => Ok(*uid), + Err(LoginRegisterError::UserInactive) => Err(LoginRegisterError::UserInactive), + _ => Err(LoginRegisterError::InvalidToken), + } + } + } + + fn make_pat_request(client_id: u128) -> LoginRegisterWithPatRequest { + LoginRegisterWithPatRequest { + client_id, + token: secrecy::SecretString::from("test-pat-token"), + version: None, + client_context: None, + } + } + + #[test] + fn pat_happy_path() { + block_on!(async { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + let verifier = MockTokenVerifier { result: Ok(42) }; + let submitter = MockSubmitter { session: Ok(100) }; + let req = make_pat_request(0xDEAD); + + let resp = 
handle_login_register_with_pat(&req, &verifier, &submitter, &mut mgr, conn) + .await + .unwrap(); + + assert_eq!(resp.user_id, 42); + assert_eq!(resp.session, 100); + assert_eq!(mgr.get_session(conn), Some((0xDEAD, 100))); + assert_eq!(mgr.bound_count(), 1); + }); + } + + #[test] + fn pat_auth_failure_stays_connected() { + block_on!(async { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + let verifier = MockTokenVerifier { + result: Err(LoginRegisterError::InvalidToken), + }; + let submitter = MockSubmitter { session: Ok(100) }; + let req = make_pat_request(0xDEAD); + + let err = handle_login_register_with_pat(&req, &verifier, &submitter, &mut mgr, conn) + .await + .unwrap_err(); + + assert!(matches!(err, LoginRegisterError::InvalidToken)); + assert!(mgr.get_session(conn).is_none()); + }); + } + + #[test] + fn pat_consensus_failure_rolls_back_to_connected() { + block_on!(async { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + let verifier = MockTokenVerifier { result: Ok(42) }; + let submitter = MockSubmitter { + session: Err(LoginRegisterError::PipelineFull), + }; + let req = make_pat_request(0xDEAD); + + let err = handle_login_register_with_pat(&req, &verifier, &submitter, &mut mgr, conn) + .await + .unwrap_err(); + + assert!(matches!(err, LoginRegisterError::PipelineFull)); + assert!(mgr.get_session(conn).is_none()); + + // Connection rolled back to Connected. Retry. 
+ let submitter_ok = MockSubmitter { session: Ok(100) }; + let resp = + handle_login_register_with_pat(&req, &verifier, &submitter_ok, &mut mgr, conn) + .await + .unwrap(); + assert_eq!(resp.user_id, 42); + assert_eq!(resp.session, 100); + assert_eq!(mgr.get_session(conn), Some((0xDEAD, 100))); + }); + } + + #[test] + fn pat_zero_client_id_rejected() { + block_on!(async { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + let verifier = MockTokenVerifier { result: Ok(42) }; + let submitter = MockSubmitter { session: Ok(100) }; + let req = make_pat_request(0); + + let err = handle_login_register_with_pat(&req, &verifier, &submitter, &mut mgr, conn) + .await + .unwrap_err(); + + assert!(matches!(err, LoginRegisterError::InvalidClientId)); + }); + } +} diff --git a/core/server-ng/src/session_manager.rs b/core/server-ng/src/session_manager.rs new file mode 100644 index 0000000000..b89ae4a000 --- /dev/null +++ b/core/server-ng/src/session_manager.rs @@ -0,0 +1,432 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! Transport-to-consensus session bridge for server-ng. +//! +//! Maps ephemeral transport connections to durable consensus sessions. +//! 
Each connection goes through: `connect → login → register → bound`. +//! +//! The [`SessionManager`] is the server-side counterpart of the SDK's +//! session lifecycle. It does **not** own the `ClientTable`. That lives +//! in the consensus layer. This module tracks the binding between a +//! transport connection and the consensus-level `(client_id, session)` pair. + +use std::collections::HashMap; +use std::net::SocketAddr; + +/// Connection lifecycle states. +/// +/// ```text +/// Connected ──login──> Authenticated ──register──> Bound +/// +/// Bound ──evict──> Connected (another conn binds same client_id) +/// {any} ──disconnect──> ∅ +/// ``` +#[derive(Debug, Clone)] +pub enum ConnectionState { + /// Connection established, not yet authenticated. + Connected, + /// Login succeeded (credentials verified). `user_id` is known. + /// Waiting for register to establish a consensus session. + Authenticated { user_id: u32 }, + /// Register committed through consensus. Connection is bound to a + /// `(client_id, session)` pair. Requests on this connection use + /// these values to populate `RequestHeader.client` and + /// `RequestHeader.session`. + Bound { + user_id: u32, + client_id: u128, + session: u64, + }, +} + +/// Per-connection metadata tracked by the session manager. +#[derive(Debug, Clone)] +pub struct Connection { + pub address: SocketAddr, + pub state: ConnectionState, +} + +/// Bridges transport connections to consensus sessions. +/// +/// Not internally synchronized: mutating methods take `&mut self`, so sharing +/// across connection handler tasks relies on the caller's synchronization +/// (single-threaded shard model like the rest of iggy's server-ng). +/// +/// ## Invariants +/// +/// - A `connection_id` maps to at most one entry in `connections`. +/// - A `client_id` appears in at most one `Bound` connection (one connection +/// per consensus session). If a client reconnects with the same `client_id`, +/// `bind_session` evicts the old connection (resets it to `Connected`). 
+pub struct SessionManager { + connections: HashMap, + /// Reverse index: `client_id` → `connection_id` for fast lookup when + /// a consensus reply arrives and needs routing to the right connection. + client_to_connection: HashMap, + next_connection_id: u64, +} + +impl SessionManager { + #[must_use] + pub fn new() -> Self { + Self { + connections: HashMap::new(), + client_to_connection: HashMap::new(), + next_connection_id: 1, + } + } + + /// Register a new transport connection. Returns the assigned connection ID. + /// + /// # Panics + /// Panics if the connection ID counter overflows `u64::MAX`. + pub fn add_connection(&mut self, address: SocketAddr) -> u64 { + let id = self.next_connection_id; + self.next_connection_id = self + .next_connection_id + .checked_add(1) + .expect("connection ID overflow (u64::MAX connections without restart)"); + self.connections.insert( + id, + Connection { + address, + state: ConnectionState::Connected, + }, + ); + id + } + + /// Remove a connection (disconnect). Cleans up the reverse index if bound. + pub fn remove_connection(&mut self, connection_id: u64) { + if let Some(conn) = self.connections.remove(&connection_id) + && let ConnectionState::Bound { client_id, .. } = conn.state + { + self.client_to_connection.remove(&client_id); + } + } + + /// Transition to `Authenticated` after successful login. + /// + /// # Errors + /// Returns `Err` if the connection doesn't exist or isn't in `Connected` state. + pub fn login(&mut self, connection_id: u64, user_id: u32) -> Result<(), SessionError> { + let conn = self + .connections + .get_mut(&connection_id) + .ok_or(SessionError::ConnectionNotFound(connection_id))?; + match conn.state { + ConnectionState::Connected => { + conn.state = ConnectionState::Authenticated { user_id }; + Ok(()) + } + _ => Err(SessionError::InvalidTransition { + connection_id, + from: state_name(&conn.state), + to: "Authenticated", + }), + } + } + + /// Reset a connection back to `Connected` state. 
+ /// + /// Used to roll back a failed register attempt so the client can retry + /// the full login+register flow on the same connection without + /// reconnecting. + /// + /// # Errors + /// Returns `Err` if the connection doesn't exist or isn't `Authenticated`. + pub fn reset_to_connected(&mut self, connection_id: u64) -> Result<(), SessionError> { + let conn = self + .connections + .get_mut(&connection_id) + .ok_or(SessionError::ConnectionNotFound(connection_id))?; + match conn.state { + ConnectionState::Authenticated { .. } => { + conn.state = ConnectionState::Connected; + Ok(()) + } + _ => Err(SessionError::InvalidTransition { + connection_id, + from: state_name(&conn.state), + to: "Connected", + }), + } + } + + /// Transition to `Bound` after register commits through consensus. + /// + /// The `client_id` is the ephemeral u128 the client generated. + /// The `session` is the commit op number assigned by the consensus layer. + /// + /// If another connection was previously bound to this `client_id`, it is + /// forcibly unbound (set back to `Connected`). Only one connection per + /// session at a time. + /// + /// # Errors + /// Returns `Err` if the connection doesn't exist or isn't `Authenticated`. + /// + /// # Panics + /// Panics if the connection disappears between validation and mutation + /// (impossible in single-threaded use). + pub fn bind_session( + &mut self, + connection_id: u64, + client_id: u128, + session: u64, + ) -> Result<(), SessionError> { + // Validate state first (immutable borrow). + let conn = self + .connections + .get(&connection_id) + .ok_or(SessionError::ConnectionNotFound(connection_id))?; + let ConnectionState::Authenticated { user_id } = conn.state else { + return Err(SessionError::InvalidTransition { + connection_id, + from: state_name(&conn.state), + to: "Bound", + }); + }; + + // Evict any previous connection bound to this client_id. 
+ if let Some(&old_conn_id) = self.client_to_connection.get(&client_id) + && old_conn_id != connection_id + && let Some(old_conn) = self.connections.get_mut(&old_conn_id) + { + old_conn.state = ConnectionState::Connected; + } + + // Now mutate the target connection. + self.connections.get_mut(&connection_id).unwrap().state = ConnectionState::Bound { + user_id, + client_id, + session, + }; + self.client_to_connection.insert(client_id, connection_id); + Ok(()) + } + + /// Look up the consensus session for a connection. + /// + /// Returns `(client_id, session)` if the connection is `Bound`, `None` otherwise. + #[must_use] + pub fn get_session(&self, connection_id: u64) -> Option<(u128, u64)> { + let conn = self.connections.get(&connection_id)?; + match conn.state { + ConnectionState::Bound { + client_id, session, .. + } => Some((client_id, session)), + _ => None, + } + } + + /// Look up the connection ID for a client (for routing consensus replies). + #[must_use] + pub fn connection_for_client(&self, client_id: u128) -> Option { + self.client_to_connection.get(&client_id).copied() + } + + /// Get connection metadata. + #[must_use] + pub fn get_connection(&self, connection_id: u64) -> Option<&Connection> { + self.connections.get(&connection_id) + } + + /// Number of active connections. + #[must_use] + pub fn connection_count(&self) -> usize { + self.connections.len() + } + + /// Number of bound (registered) sessions. 
+ #[must_use] + pub fn bound_count(&self) -> usize { + self.client_to_connection.len() + } +} + +impl Default for SessionManager { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +pub enum SessionError { + ConnectionNotFound(u64), + InvalidTransition { + connection_id: u64, + from: &'static str, + to: &'static str, + }, +} + +impl std::fmt::Display for SessionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ConnectionNotFound(id) => write!(f, "connection {id} not found"), + Self::InvalidTransition { + connection_id, + from, + to, + } => write!( + f, + "connection {connection_id}: invalid transition {from} -> {to}" + ), + } + } +} + +impl std::error::Error for SessionError {} + +const fn state_name(state: &ConnectionState) -> &'static str { + match state { + ConnectionState::Connected => "Connected", + ConnectionState::Authenticated { .. } => "Authenticated", + ConnectionState::Bound { .. } => "Bound", + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr}; + + fn addr(port: u16) -> SocketAddr { + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port) + } + + #[test] + fn full_lifecycle() { + let mut mgr = SessionManager::new(); + + let conn = mgr.add_connection(addr(5000)); + assert_eq!(mgr.connection_count(), 1); + assert!(mgr.get_session(conn).is_none()); + + // Login + mgr.login(conn, 42).unwrap(); + assert!(mgr.get_session(conn).is_none()); // not bound yet + + // Register committed. 
Bind session + let client_id: u128 = 0xDEAD_BEEF; + let session: u64 = 100; + mgr.bind_session(conn, client_id, session).unwrap(); + + assert_eq!(mgr.get_session(conn), Some((client_id, session))); + assert_eq!(mgr.connection_for_client(client_id), Some(conn)); + assert_eq!(mgr.bound_count(), 1); + + // Disconnect + mgr.remove_connection(conn); + assert_eq!(mgr.connection_count(), 0); + assert_eq!(mgr.bound_count(), 0); + assert!(mgr.connection_for_client(client_id).is_none()); + } + + #[test] + fn login_requires_connected_state() { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + mgr.login(conn, 1).unwrap(); + + // Double login should fail. Already Authenticated. + assert!(mgr.login(conn, 2).is_err()); + } + + #[test] + fn bind_requires_authenticated_state() { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + + // Bind without login should fail. + assert!(mgr.bind_session(conn, 1, 1).is_err()); + } + + #[test] + fn bind_evicts_old_connection_for_same_client() { + let mut mgr = SessionManager::new(); + + // First connection binds to client_id 99. + let conn1 = mgr.add_connection(addr(5000)); + mgr.login(conn1, 1).unwrap(); + mgr.bind_session(conn1, 99, 10).unwrap(); + assert_eq!(mgr.connection_for_client(99), Some(conn1)); + + // Second connection binds to same client_id. Evicts conn1. + let conn2 = mgr.add_connection(addr(5001)); + mgr.login(conn2, 1).unwrap(); + mgr.bind_session(conn2, 99, 20).unwrap(); + + assert_eq!(mgr.connection_for_client(99), Some(conn2)); + // conn1 reverted to Connected. 
+ assert!(mgr.get_session(conn1).is_none()); + } + + #[test] + fn remove_nonexistent_connection_is_noop() { + let mut mgr = SessionManager::new(); + mgr.remove_connection(999); // should not panic + } + + #[test] + fn login_nonexistent_connection_errors() { + let mut mgr = SessionManager::new(); + assert!(mgr.login(999, 1).is_err()); + } + + #[test] + fn reset_to_connected_from_authenticated() { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + mgr.login(conn, 1).unwrap(); + mgr.reset_to_connected(conn).unwrap(); + // Back to Connected. Can login again. + mgr.login(conn, 2).unwrap(); + } + + #[test] + fn reset_to_connected_rejects_wrong_state() { + let mut mgr = SessionManager::new(); + let conn = mgr.add_connection(addr(5000)); + // Connected - reset should fail. + assert!(mgr.reset_to_connected(conn).is_err()); + } + + #[test] + fn multiple_independent_sessions() { + let mut mgr = SessionManager::new(); + + let c1 = mgr.add_connection(addr(5000)); + let c2 = mgr.add_connection(addr(5001)); + mgr.login(c1, 1).unwrap(); + mgr.login(c2, 2).unwrap(); + mgr.bind_session(c1, 100, 10).unwrap(); + mgr.bind_session(c2, 200, 20).unwrap(); + + assert_eq!(mgr.get_session(c1), Some((100, 10))); + assert_eq!(mgr.get_session(c2), Some((200, 20))); + assert_eq!(mgr.bound_count(), 2); + + mgr.remove_connection(c1); + assert_eq!(mgr.bound_count(), 1); + assert!(mgr.connection_for_client(100).is_none()); + assert_eq!(mgr.connection_for_client(200), Some(c2)); + } +} diff --git a/core/simulator/src/client.rs b/core/simulator/src/client.rs index 7928e053e5..96a20fa928 100644 --- a/core/simulator/src/client.rs +++ b/core/simulator/src/client.rs @@ -31,6 +31,7 @@ use std::cell::Cell; pub struct SimClient { client_id: u128, request_counter: Cell, + session: Cell, } impl SimClient { @@ -39,13 +40,64 @@ impl SimClient { Self { client_id, request_counter: Cell::new(0), + session: Cell::new(0), } } + #[must_use] + pub const fn client_id(&self) -> u128 { + 
self.client_id + } + + /// Bind the session assigned by the consensus layer after registration. + /// + /// # Panics + /// Panics if `session` is 0. + pub fn bind_session(&self, session: u64) { + assert!(session > 0, "bind_session: session must be > 0"); + self.session.set(session); + } + fn next_request_number(&self) -> u64 { - let current = self.request_counter.get(); - self.request_counter.set(current + 1); - current + let next = self.request_counter.get() + 1; + self.request_counter.set(next); + next + } + + fn session_id(&self) -> u64 { + let s = self.session.get(); + assert!( + s > 0, + "session not bound — call register() + bind_session() first" + ); + s + } + + /// Build a `Register` request for this client. + /// + /// Register uses `session=0, request=0` per the protocol spec. + /// The consensus layer assigns a session on commit. + /// + /// # Panics + /// Panics if the register request buffer is invalid. + #[allow(clippy::cast_possible_truncation)] + pub fn register(&self) -> Message { + let header_size = std::mem::size_of::(); + let header = RequestHeader { + command: iggy_binary_protocol::Command2::Request, + operation: Operation::Register, + size: header_size as u32, + client: self.client_id, + session: 0, + request: 0, + ..Default::default() + }; + + let header_bytes = bytemuck::bytes_of(&header); + let buffer = header_bytes.to_vec(); + + Message::try_from(Owned::<4096>::copy_from_slice(&buffer)) + .expect("register request must be valid") } /// # Panics @@ -169,6 +221,7 @@ impl SimClient { client: self.client_id, request_checksum: 0, timestamp: 0, // TODO: Use actual timestamp + session: self.session_id(), request: self.next_request_number(), ..Default::default() }; @@ -203,6 +256,7 @@ impl SimClient { client: self.client_id, request_checksum: 0, timestamp: 0, + session: self.session_id(), request: self.next_request_number(), namespace: namespace.inner(), ..Default::default() diff --git a/core/simulator/src/lib.rs b/core/simulator/src/lib.rs index 
6d79485ab7..9cfc09b46a 100644 --- a/core/simulator/src/lib.rs +++ b/core/simulator/src/lib.rs @@ -24,8 +24,10 @@ pub mod ready_queue; pub mod replica; use bus::SimOutbox; +use client::SimClient; use consensus::PartitionsHandle; -use iggy_binary_protocol::{GenericHeader, Message, ReplyHeader}; +use iggy_binary_protocol::consensus::iobuf::Owned; +use iggy_binary_protocol::{Command2, GenericHeader, Message, Operation, ReplyHeader}; use iggy_common::IggyError; use iggy_common::sharding::IggyNamespace; use message_bus::MessageBus; @@ -36,6 +38,27 @@ use replica::{Replica, new_replica}; use std::collections::HashSet; use std::sync::Arc; +/// Build a minimal `ReplyHeader` message for a Register operation. +/// +/// Used to seed partition-level client tables, which don't yet receive +/// Register through the network protocol. +#[allow(clippy::cast_possible_truncation)] +fn build_register_reply(client_id: u128, session: u64) -> Message { + let header_size = std::mem::size_of::(); + let header = ReplyHeader { + command: Command2::Reply, + operation: Operation::Register, + size: header_size as u32, + client: client_id, + commit: session, + request: 0, + ..Default::default() + }; + let header_bytes = bytemuck::bytes_of(&header); + Message::try_from(Owned::<4096>::copy_from_slice(header_bytes)) + .expect("register reply must be valid") +} + pub struct Simulator { /// All replicas, indexed by replica id. Always fully populated — crashed /// replicas are kept alive but skipped during dispatch. @@ -206,6 +229,61 @@ impl Simulator { ); } + /// Register a client with the consensus cluster via the primary (replica 0). + /// + /// This sends a `Register` through consensus (metadata plane) and also + /// directly registers the client in each live replica's partition consensus + /// client table. The partition plane doesn't yet handle `Register` via + /// the network protocol, so we seed it manually. + /// + /// Binds the assigned session on the `SimClient`. 
+ /// + /// # Panics + /// Panics if no reply arrives within 100 steps. + #[allow(clippy::cast_possible_truncation)] + pub fn register_client_with_primary(&mut self, client: &SimClient) { + let msg = client.register(); + self.submit_request(client.client_id(), 0, msg.into_generic()); + let mut session = 0u64; + let mut got_reply = false; + for _ in 0..100 { + let replies = self.step(); + if !replies.is_empty() { + session = replies[0].header().commit; + got_reply = true; + break; + } + } + assert!( + got_reply, + "register_client_with_primary: no reply within 100 steps" + ); + client.bind_session(session); + + // Seed the partition consensus client tables on all live replicas. + // The partition plane doesn't route Register operations yet, so we + // do it directly to match what the full server will do once the + // partition-level registration is wired up. With per-partition + // consensus, every partition's client_table needs the entry. + for (i, replica) in self.replicas.iter().enumerate() { + if self.crashed.contains(&(i as u8)) { + continue; + } + let partitions = replica.plane.partitions(); + let namespaces: Vec<_> = partitions.namespaces().copied().collect(); + for ns in namespaces { + if let Some(partition) = partitions.get_by_ns(&ns) { + let reply = build_register_reply(client.client_id(), session); + partition + .consensus() + .client_table() + .borrow_mut() + .commit_register(client.client_id(), reply); + } + } + } + } + /// Crash a replica: disable its network links and discard its outbox. /// /// The replica object is kept alive but will not receive any messages or @@ -340,6 +418,9 @@ mod tests { let ns = IggyNamespace::new(1, 1, 0); sim.init_partition(ns); + // Register the client with the consensus cluster. + sim.register_client_with_primary(&client); + // Send a message through the primary (replica 0) to verify normal operation. 
let msg = client.send_messages(ns, &[b"before crash"]); sim.submit_request(client_id, 0, msg.into_generic()); @@ -435,6 +516,9 @@ mod tests { let ns = IggyNamespace::new(1, 1, 0); sim.init_partition(ns); + // Register the client with the consensus cluster. + sim.register_client_with_primary(&client); + // Send several messages so the primary commits ahead of backups. // Backups receive prepares but may not have committed all of them // (commit_max lags behind the primary's commit_min because the