diff --git a/Cargo.lock b/Cargo.lock index 4d962a6..a821d50 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -120,6 +120,26 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chat_room" +version = "0.4.5" +dependencies = [ + "spawned-concurrency", + "spawned-macros", + "spawned-rt", + "tracing", +] + +[[package]] +name = "chat_room_threads" +version = "0.4.5" +dependencies = [ + "spawned-concurrency", + "spawned-macros", + "spawned-rt", + "tracing", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -1129,6 +1149,16 @@ dependencies = [ "serde", ] +[[package]] +name = "service_discovery" +version = "0.4.5" +dependencies = [ + "spawned-concurrency", + "spawned-macros", + "spawned-rt", + "tracing", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1210,6 +1240,7 @@ version = "0.4.5" dependencies = [ "futures", "pin-project-lite", + "spawned-macros", "spawned-rt", "thiserror", "tokio", @@ -1217,6 +1248,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "spawned-macros" +version = "0.4.5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "spawned-rt" version = "0.4.5" diff --git a/Cargo.toml b/Cargo.toml index f234fe4..abd1256 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ resolver = "3" members = [ "rt", "concurrency", + "macros", "examples/bank", "examples/bank_threads", "examples/name_server", @@ -15,11 +16,15 @@ members = [ "examples/busy_genserver_warning", "examples/signal_test", "examples/signal_test_threads", + "examples/chat_room", + "examples/chat_room_threads", + "examples/service_discovery", ] [workspace.dependencies] spawned-rt = { path = "rt", version = "0.4.5" } spawned-concurrency = { path = "concurrency", version = "0.4.5" } +spawned-macros = { path = "macros", version = "0.4.5" } tracing = { version = "0.1.41", features = ["log"] } tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } diff --git a/concurrency/Cargo.toml b/concurrency/Cargo.toml index c2845a9..04ba8b8 100644 --- a/concurrency/Cargo.toml +++ b/concurrency/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [dependencies] spawned-rt = { workspace = true } +spawned-macros = { workspace = true } tracing = { workspace = true } futures = "0.3.1" thiserror = "2.0.12" diff --git a/concurrency/src/error.rs b/concurrency/src/error.rs index 3b23e4b..35123ef 100644 --- a/concurrency/src/error.rs +++ b/concurrency/src/error.rs @@ -1,28 +1,20 @@ #[derive(Debug, thiserror::Error)] pub enum ActorError { - #[error("Callback Error")] - Callback, - #[error("Initialization error")] - Initialization, - #[error("Server error")] - Server, - #[error("Unsupported Request on this Actor")] - RequestUnused, - #[error("Unsupported Message on this Actor")] - MessageUnused, + #[error("Actor stopped")] + ActorStopped, #[error("Request to Actor timed out")] RequestTimeout, } impl From> for ActorError { fn from(_value: spawned_rt::threads::mpsc::SendError) -> Self { - Self::Server + Self::ActorStopped } } impl From> for ActorError { fn from(_value: spawned_rt::tasks::mpsc::SendError) -> Self { - Self::Server + Self::ActorStopped } } @@ -32,7 +24,7 @@ mod tests { #[test] fn test_error_into_std_error() { - let error: &dyn std::error::Error = &ActorError::Callback; - assert_eq!(error.to_string(), "Callback Error"); + let error: &dyn std::error::Error = &ActorError::ActorStopped; + assert_eq!(error.to_string(), "Actor stopped"); } } diff --git a/concurrency/src/lib.rs b/concurrency/src/lib.rs index 0edcab8..c470d0b 100644 --- a/concurrency/src/lib.rs +++ b/concurrency/src/lib.rs @@ -1,6 +1,5 @@ -//! spawned concurrency -//! Some basic traits and structs to implement concurrent code à-la-Erlang. pub mod error; -pub mod messages; +pub mod message; +pub mod registry; pub mod tasks; pub mod threads; diff --git a/concurrency/src/message.rs b/concurrency/src/message.rs new file mode 100644 index 0000000..6d010b5 --- /dev/null +++ b/concurrency/src/message.rs @@ -0,0 +1,297 @@ +pub trait Message: Send + 'static { + type Result: Send + 'static; +} + +/// Declarative macro for defining message types. +/// +/// Supports both unit structs and structs with fields, and they can be mixed +/// in a single invocation: +/// +/// ```ignore +/// messages! { +/// GetCount -> u64; +/// Deposit { who: String, amount: i32 } -> Result; +/// Stop -> () +/// } +/// ``` +#[macro_export] +macro_rules! messages { + () => {}; + + // Base: unit message + ($(#[$meta:meta])* $name:ident -> $result:ty) => { + $(#[$meta])* + pub struct $name; + impl $crate::message::Message for $name { + type Result = $result; + } + }; + + // Base: struct message + ($(#[$meta:meta])* $name:ident { $($field:ident : $ftype:ty),* $(,)? } -> $result:ty) => { + $(#[$meta])* + pub struct $name { $(pub $field: $ftype,)* } + impl $crate::message::Message for $name { + type Result = $result; + } + }; + + // Recursive: unit message followed by more + ($(#[$meta:meta])* $name:ident -> $result:ty; $($rest:tt)*) => { + $crate::messages!($(#[$meta])* $name -> $result); + $crate::messages!($($rest)*); + }; + + // Recursive: struct message followed by more + ($(#[$meta:meta])* $name:ident { $($field:ident : $ftype:ty),* $(,)? } -> $result:ty; $($rest:tt)*) => { + $crate::messages!($(#[$meta])* $name { $($field : $ftype),* } -> $result); + $crate::messages!($($rest)*); + }; +} + +/// Fire-and-forget messages (Result type is always `()`). +/// +/// ```ignore +/// send_messages! { +/// Increment; +/// Deposit { who: String, amount: i32 } +/// } +/// ``` +#[macro_export] +macro_rules! send_messages { + () => {}; + + // Base: unit message + ($(#[$meta:meta])* $name:ident) => { + $(#[$meta])* + pub struct $name; + impl $crate::message::Message for $name { + type Result = (); + } + }; + + // Base: struct message + ($(#[$meta:meta])* $name:ident { $($field:ident : $ftype:ty),* $(,)? }) => { + $(#[$meta])* + pub struct $name { $(pub $field: $ftype,)* } + impl $crate::message::Message for $name { + type Result = (); + } + }; + + // Recursive: unit message followed by more + ($(#[$meta:meta])* $name:ident; $($rest:tt)*) => { + $crate::send_messages!($(#[$meta])* $name); + $crate::send_messages!($($rest)*); + }; + + // Recursive: struct message followed by more + ($(#[$meta:meta])* $name:ident { $($field:ident : $ftype:ty),* $(,)? }; $($rest:tt)*) => { + $crate::send_messages!($(#[$meta])* $name { $($field : $ftype),* }); + $crate::send_messages!($($rest)*); + }; +} + +/// Request-response messages (Result type is explicitly specified). +/// +/// ```ignore +/// request_messages! { +/// GetCount -> u64; +/// Lookup { key: String } -> Option +/// } +/// ``` +#[macro_export] +macro_rules! request_messages { + () => {}; + + // Base: unit message + ($(#[$meta:meta])* $name:ident -> $result:ty) => { + $(#[$meta])* + pub struct $name; + impl $crate::message::Message for $name { + type Result = $result; + } + }; + + // Base: struct message + ($(#[$meta:meta])* $name:ident { $($field:ident : $ftype:ty),* $(,)? } -> $result:ty) => { + $(#[$meta])* + pub struct $name { $(pub $field: $ftype,)* } + impl $crate::message::Message for $name { + type Result = $result; + } + }; + + // Recursive: unit message followed by more + ($(#[$meta:meta])* $name:ident -> $result:ty; $($rest:tt)*) => { + $crate::request_messages!($(#[$meta])* $name -> $result); + $crate::request_messages!($($rest)*); + }; + + // Recursive: struct message followed by more + ($(#[$meta:meta])* $name:ident { $($field:ident : $ftype:ty),* $(,)? } -> $result:ty; $($rest:tt)*) => { + $crate::request_messages!($(#[$meta])* $name { $($field : $ftype),* } -> $result); + $crate::request_messages!($($rest)*); + }; +} + +/// Generates an extension trait + impl on `ActorRef` for ergonomic method-call syntax. +/// +/// Parameter names must match message struct field names exactly (ownership transfer). +/// +/// ```ignore +/// actor_api! { +/// pub ChatRoomApi for ActorRef { +/// send fn say(from: String, text: String) => Say; +/// send fn add_member(name: String, inbox: Recipient) => Join; +/// request async fn members() -> Vec => Members; +/// } +/// } +/// ``` +/// +/// For threads (sync), use `request fn` instead of `request async fn`. +#[macro_export] +macro_rules! actor_api { + // Entry: pub trait + (pub $trait_name:ident for $actor_ref:ty { $($body:tt)* }) => { + $crate::actor_api!(@parse [pub] $trait_name $actor_ref [] [] $($body)*); + }; + + // Entry: private trait + ($trait_name:ident for $actor_ref:ty { $($body:tt)* }) => { + $crate::actor_api!(@parse [] $trait_name $actor_ref [] [] $($body)*); + }; + + // Terminal: generate trait + impl + (@parse [$($vis:tt)*] $trait_name:ident $actor_ref:ty + [$($trait_items:tt)*] + [$($impl_items:tt)*] + ) => { + $($vis)* trait $trait_name { + $($trait_items)* + } + impl $trait_name for $actor_ref { + $($impl_items)* + } + }; + + // send fn with params + (@parse [$($vis:tt)*] $trait_name:ident $actor_ref:ty + [$($trait_items:tt)*] + [$($impl_items:tt)*] + send fn $method:ident($($param:ident : $ptype:ty),+ $(,)?) => $msg:ident; + $($rest:tt)* + ) => { + $crate::actor_api!(@parse [$($vis)*] $trait_name $actor_ref + [$($trait_items)* + fn $method(&self, $($param : $ptype),+) -> Result<(), $crate::error::ActorError>; + ] + [$($impl_items)* + fn $method(&self, $($param : $ptype),+) -> Result<(), $crate::error::ActorError> { + self.send($msg { $($param),+ }) + } + ] + $($rest)* + ); + }; + + // send fn without params (unit message) + (@parse [$($vis:tt)*] $trait_name:ident $actor_ref:ty + [$($trait_items:tt)*] + [$($impl_items:tt)*] + send fn $method:ident() => $msg:ident; + $($rest:tt)* + ) => { + $crate::actor_api!(@parse [$($vis)*] $trait_name $actor_ref + [$($trait_items)* + fn $method(&self) -> Result<(), $crate::error::ActorError>; + ] + [$($impl_items)* + fn $method(&self) -> Result<(), $crate::error::ActorError> { + self.send($msg) + } + ] + $($rest)* + ); + }; + + // request async fn with params + (@parse [$($vis:tt)*] $trait_name:ident $actor_ref:ty + [$($trait_items:tt)*] + [$($impl_items:tt)*] + request async fn $method:ident($($param:ident : $ptype:ty),+ $(,)?) -> $ret:ty => $msg:ident; + $($rest:tt)* + ) => { + $crate::actor_api!(@parse [$($vis)*] $trait_name $actor_ref + [$($trait_items)* + async fn $method(&self, $($param : $ptype),+) -> Result<$ret, $crate::error::ActorError>; + ] + [$($impl_items)* + async fn $method(&self, $($param : $ptype),+) -> Result<$ret, $crate::error::ActorError> { + self.request($msg { $($param),+ }).await + } + ] + $($rest)* + ); + }; + + // request async fn without params (unit message) + (@parse [$($vis:tt)*] $trait_name:ident $actor_ref:ty + [$($trait_items:tt)*] + [$($impl_items:tt)*] + request async fn $method:ident() -> $ret:ty => $msg:ident; + $($rest:tt)* + ) => { + $crate::actor_api!(@parse [$($vis)*] $trait_name $actor_ref + [$($trait_items)* + async fn $method(&self) -> Result<$ret, $crate::error::ActorError>; + ] + [$($impl_items)* + async fn $method(&self) -> Result<$ret, $crate::error::ActorError> { + self.request($msg).await + } + ] + $($rest)* + ); + }; + + // request fn with params (sync/threads) + (@parse [$($vis:tt)*] $trait_name:ident $actor_ref:ty + [$($trait_items:tt)*] + [$($impl_items:tt)*] + request fn $method:ident($($param:ident : $ptype:ty),+ $(,)?) -> $ret:ty => $msg:ident; + $($rest:tt)* + ) => { + $crate::actor_api!(@parse [$($vis)*] $trait_name $actor_ref + [$($trait_items)* + fn $method(&self, $($param : $ptype),+) -> Result<$ret, $crate::error::ActorError>; + ] + [$($impl_items)* + fn $method(&self, $($param : $ptype),+) -> Result<$ret, $crate::error::ActorError> { + self.request($msg { $($param),+ }) + } + ] + $($rest)* + ); + }; + + // request fn without params (sync/threads, unit message) + (@parse [$($vis:tt)*] $trait_name:ident $actor_ref:ty + [$($trait_items:tt)*] + [$($impl_items:tt)*] + request fn $method:ident() -> $ret:ty => $msg:ident; + $($rest:tt)* + ) => { + $crate::actor_api!(@parse [$($vis)*] $trait_name $actor_ref + [$($trait_items)* + fn $method(&self) -> Result<$ret, $crate::error::ActorError>; + ] + [$($impl_items)* + fn $method(&self) -> Result<$ret, $crate::error::ActorError> { + self.request($msg) + } + ] + $($rest)* + ); + }; +} diff --git a/concurrency/src/messages.rs b/concurrency/src/messages.rs deleted file mode 100644 index e0aceb8..0000000 --- a/concurrency/src/messages.rs +++ /dev/null @@ -1,2 +0,0 @@ -#[derive(Clone, Debug)] -pub struct Unused; diff --git a/concurrency/src/registry.rs b/concurrency/src/registry.rs new file mode 100644 index 0000000..f37a5ba --- /dev/null +++ b/concurrency/src/registry.rs @@ -0,0 +1,91 @@ +use std::any::Any; +use std::collections::HashMap; +use std::sync::{OnceLock, RwLock}; + +type Store = RwLock>>; + +fn global_store() -> &'static Store { + static STORE: OnceLock = OnceLock::new(); + STORE.get_or_init(|| RwLock::new(HashMap::new())) +} + +#[derive(Debug, thiserror::Error)] +pub enum RegistryError { + #[error("name '{0}' is already registered")] + AlreadyRegistered(String), +} + +pub fn register(name: &str, value: T) -> Result<(), RegistryError> { + let mut store = global_store().write().unwrap_or_else(|p| p.into_inner()); + if store.contains_key(name) { + return Err(RegistryError::AlreadyRegistered(name.to_string())); + } + store.insert(name.to_string(), Box::new(value)); + Ok(()) +} + +pub fn whereis(name: &str) -> Option { + let store = global_store().read().unwrap_or_else(|p| p.into_inner()); + store.get(name)?.downcast_ref::().cloned() +} + +pub fn unregister(name: &str) { + let mut store = global_store().write().unwrap_or_else(|p| p.into_inner()); + store.remove(name); +} + +pub fn registered() -> Vec { + let store = global_store().read().unwrap_or_else(|p| p.into_inner()); + store.keys().cloned().collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + // Use unique names per test to avoid cross-test interference with global state. + + #[test] + fn register_and_whereis() { + register("test_rw_1", 42u64).unwrap(); + let val: Option = whereis("test_rw_1"); + assert_eq!(val, Some(42)); + } + + #[test] + fn whereis_wrong_type_returns_none() { + register("test_wt_1", 42u64).unwrap(); + let val: Option = whereis("test_wt_1"); + assert_eq!(val, None); + } + + #[test] + fn whereis_missing_returns_none() { + let val: Option = whereis("nonexistent_key"); + assert_eq!(val, None); + } + + #[test] + fn duplicate_register_fails() { + register("test_dup_1", 1u32).unwrap(); + let result = register("test_dup_1", 2u32); + assert!(result.is_err()); + } + + #[test] + fn unregister_removes_entry() { + register("test_unreg_1", "hello".to_string()).unwrap(); + unregister("test_unreg_1"); + let val: Option = whereis("test_unreg_1"); + assert_eq!(val, None); + } + + #[test] + fn registered_lists_names() { + register("test_list_a", 1u32).unwrap(); + register("test_list_b", 2u32).unwrap(); + let names = registered(); + assert!(names.contains(&"test_list_a".to_string())); + assert!(names.contains(&"test_list_b".to_string())); + } +} diff --git a/concurrency/src/tasks/actor.rs b/concurrency/src/tasks/actor.rs index d41e3a3..c0a90a9 100644 --- a/concurrency/src/tasks/actor.rs +++ b/concurrency/src/tasks/actor.rs @@ -1,470 +1,452 @@ -//! Actor trait and structs to create an abstraction similar to Erlang gen_server. -//! See examples/name_server for a usage example. -use crate::{ - error::ActorError, - tasks::InitResult::{NoSuccess, Success}, -}; +use crate::error::ActorError; +use crate::message::Message; use core::pin::pin; use futures::future::{self, FutureExt as _}; use spawned_rt::{ tasks::{self as rt, mpsc, oneshot, timeout, watch, CancellationToken, JoinHandle}, threads, }; -use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration}; +use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, pin::Pin, sync::Arc, time::Duration}; const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(5); -/// Execution backend for Actor. -/// -/// Determines how the Actor's async loop is executed. Choose based on -/// the nature of your workload: -/// -/// # Backend Comparison -/// -/// | Backend | Execution Model | Best For | Limitations | -/// |---------|-----------------|----------|-------------| -/// | `Async` | Tokio task | Non-blocking I/O, async operations | Blocks runtime if sync code runs too long | -/// | `Blocking` | Tokio blocking pool | Short blocking operations (file I/O, DNS) | Shared pool with limited threads | -/// | `Thread` | Dedicated OS thread with own runtime | Long-running services, isolation from main runtime | Higher memory overhead per Actor | -/// -/// **Note**: All backends use async internally. For fully synchronous code without any async -/// runtime, use [`threads::Actor`](crate::threads::Actor) instead. -/// -/// # Examples -/// -/// ```ignore -/// // For typical async workloads (HTTP handlers, database queries) -/// let handle = MyServer::new().start(); -/// -/// // For occasional blocking operations (file reads, external commands) -/// let handle = MyServer::new().start_with_backend(Backend::Blocking); -/// -/// // For CPU-intensive or permanently blocking services -/// let handle = MyServer::new().start_with_backend(Backend::Thread); -/// ``` -/// -/// # When to Use Each Backend -/// -/// ## `Backend::Async` (Default) -/// - **Advantages**: Lightweight, efficient, good for high concurrency -/// - **Use when**: Your Actor does mostly async I/O (network, database) -/// - **Avoid when**: Your code blocks (e.g., `std::thread::sleep`, heavy computation) -/// -/// ## `Backend::Blocking` -/// - **Advantages**: Prevents blocking the async runtime, uses tokio's managed pool -/// - **Use when**: You have occasional blocking operations that complete quickly -/// - **Avoid when**: You need guaranteed thread availability or long-running blocks -/// -/// ## `Backend::Thread` -/// - **Advantages**: Isolated from main runtime, dedicated thread won't affect other tasks -/// - **Use when**: Long-running singleton services that shouldn't share the main runtime -/// - **Avoid when**: You need many Actors (each gets its own OS thread + runtime) -/// - **Note**: Still uses async internally (own runtime). For sync code, use `threads::Actor` +// --------------------------------------------------------------------------- +// Backend +// --------------------------------------------------------------------------- + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub enum Backend { - /// Run on tokio async runtime (default). - /// - /// Best for non-blocking, async workloads. The Actor runs as a - /// lightweight tokio task, enabling high concurrency with minimal overhead. - /// - /// **Warning**: If your `handle_request` or `handle_message` blocks synchronously - /// (e.g., `std::thread::sleep`, CPU-heavy loops), it will block the entire - /// tokio runtime thread, affecting other tasks. #[default] Async, - - /// Run on tokio's blocking thread pool. - /// - /// Use for Actors that perform blocking operations like: - /// - Synchronous file I/O - /// - DNS lookups - /// - External process calls - /// - Short CPU-bound computations - /// - /// The pool is shared across all `spawn_blocking` calls and has a default - /// limit of 512 threads. If the pool is exhausted, new blocking tasks wait. Blocking, - - /// Run on a dedicated OS thread with its own async runtime. - /// - /// Use for Actors that: - /// - Need isolation from the main tokio runtime - /// - Are long-running singleton services - /// - Should not compete with other tasks for runtime resources - /// - /// Each Actor gets its own thread with a separate tokio runtime, - /// providing isolation from other async tasks. Higher memory overhead - /// (~2MB stack per thread plus runtime overhead). - /// - /// **Note**: This still uses async internally. For fully synchronous code - /// without any async runtime, use [`threads::Actor`](crate::threads::Actor). Thread, } -#[derive(Debug)] -pub struct ActorRef { - pub tx: mpsc::Sender>, - /// Cancellation token to stop the Actor +// --------------------------------------------------------------------------- +// Actor trait +// --------------------------------------------------------------------------- + +pub trait Actor: Send + Sized + 'static { + fn started(&mut self, _ctx: &Context) -> impl Future + Send { + async {} + } + + fn stopped(&mut self, _ctx: &Context) -> impl Future + Send { + async {} + } +} + +// --------------------------------------------------------------------------- +// Handler trait (per-message, uses RPITIT — NOT object-safe, that's fine) +// --------------------------------------------------------------------------- + +pub trait Handler: Actor { + fn handle( + &mut self, + msg: M, + ctx: &Context, + ) -> impl Future + Send; +} + +// --------------------------------------------------------------------------- +// Envelope (type-erasure on the actor side) +// --------------------------------------------------------------------------- + +trait Envelope: Send { + fn handle<'a>( + self: Box, + actor: &'a mut A, + ctx: &'a Context, + ) -> Pin + Send + 'a>>; +} + +struct MessageEnvelope { + msg: M, + tx: Option>, +} + +impl Envelope for MessageEnvelope +where + A: Actor + Handler, + M: Message, +{ + fn handle<'a>( + self: Box, + actor: &'a mut A, + ctx: &'a Context, + ) -> Pin + Send + 'a>> { + Box::pin(async move { + let result = actor.handle(self.msg, ctx).await; + if let Some(tx) = self.tx { + let _ = tx.send(result); + } + }) + } +} + +// --------------------------------------------------------------------------- +// Context +// --------------------------------------------------------------------------- + +pub struct Context { + sender: mpsc::Sender + Send>>, cancellation_token: CancellationToken, - /// Completion signal for waiting on actor stop (true = stopped) - completion_rx: watch::Receiver, } -impl Clone for ActorRef { +impl Clone for Context { fn clone(&self) -> Self { Self { - tx: self.tx.clone(), + sender: self.sender.clone(), cancellation_token: self.cancellation_token.clone(), - completion_rx: self.completion_rx.clone(), } } } -impl ActorRef { - fn new(actor: A) -> Self { - let (tx, mut rx) = mpsc::channel::>(); - let cancellation_token = CancellationToken::new(); - let (completion_tx, completion_rx) = watch::channel(false); - let handle = ActorRef { - tx, - cancellation_token, - completion_rx, - }; - let handle_clone = handle.clone(); - let inner_future = async move { - if let Err(error) = actor.run(&handle, &mut rx).await { - tracing::trace!(%error, "Actor crashed") - } - // Signal completion to all waiters - let _ = completion_tx.send(true); - }; +impl Debug for Context { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Context").finish_non_exhaustive() + } +} - #[cfg(debug_assertions)] - // Optionally warn if the Actor future blocks for too much time - let inner_future = warn_on_block::WarnOnBlocking::new(inner_future); +impl Context { + pub fn from_ref(actor_ref: &ActorRef) -> Self { + Self { + sender: actor_ref.sender.clone(), + cancellation_token: actor_ref.cancellation_token.clone(), + } + } - let _task_handle = rt::spawn(inner_future); + pub fn stop(&self) { + self.cancellation_token.cancel(); + } - handle_clone + pub fn send(&self, msg: M) -> Result<(), ActorError> + where + A: Handler, + M: Message, + { + let envelope = MessageEnvelope { msg, tx: None }; + self.sender + .send(Box::new(envelope)) + .map_err(|_| ActorError::ActorStopped) } - fn new_blocking(actor: A) -> Self { - let (tx, mut rx) = mpsc::channel::>(); - let cancellation_token = CancellationToken::new(); - let (completion_tx, completion_rx) = watch::channel(false); - let handle = ActorRef { - tx, - cancellation_token, - completion_rx, + pub fn request_raw(&self, msg: M) -> Result, ActorError> + where + A: Handler, + M: Message, + { + let (tx, rx) = oneshot::channel(); + let envelope = MessageEnvelope { + msg, + tx: Some(tx), }; - let handle_clone = handle.clone(); - let _task_handle = rt::spawn_blocking(move || { - rt::block_on(async move { - if let Err(error) = actor.run(&handle, &mut rx).await { - tracing::trace!(%error, "Actor crashed") - }; - // Signal completion to all waiters - let _ = completion_tx.send(true); - }) - }); + self.sender + .send(Box::new(envelope)) + .map_err(|_| ActorError::ActorStopped)?; + Ok(rx) + } + + pub async fn request(&self, msg: M) -> Result + where + A: Handler, + M: Message, + { + let rx = self.request_raw(msg)?; + match timeout(DEFAULT_REQUEST_TIMEOUT, rx).await { + Ok(Ok(result)) => Ok(result), + Ok(Err(_)) => Err(ActorError::ActorStopped), + Err(_) => Err(ActorError::RequestTimeout), + } + } - handle_clone + pub fn recipient(&self) -> Recipient + where + A: Handler, + M: Message, + { + Arc::new(self.clone()) } - fn new_on_thread(actor: A) -> Self { - let (tx, mut rx) = mpsc::channel::>(); - let cancellation_token = CancellationToken::new(); - let (completion_tx, completion_rx) = watch::channel(false); - let handle = ActorRef { - tx, - cancellation_token, - completion_rx, - }; - let handle_clone = handle.clone(); - let _thread_handle = threads::spawn(move || { - threads::block_on(async move { - if let Err(error) = actor.run(&handle, &mut rx).await { - tracing::trace!(%error, "Actor crashed") - }; - // Signal completion to all waiters - let _ = completion_tx.send(true); - }) - }); + pub(crate) fn cancellation_token(&self) -> CancellationToken { + self.cancellation_token.clone() + } +} - handle_clone +// Bridge: Context implements Receiver for any M that A handles +impl Receiver for Context +where + A: Actor + Handler, + M: Message, +{ + fn send(&self, msg: M) -> Result<(), ActorError> { + Context::send(self, msg) } - pub fn sender(&self) -> mpsc::Sender> { - self.tx.clone() + fn request_raw(&self, msg: M) -> Result, ActorError> { + Context::request_raw(self, msg) } +} + +// --------------------------------------------------------------------------- +// Receiver trait (object-safe) + Recipient alias +// --------------------------------------------------------------------------- + +pub trait Receiver: Send + Sync { + fn send(&self, msg: M) -> Result<(), ActorError>; + fn request_raw(&self, msg: M) -> Result, ActorError>; +} + +pub type Recipient = Arc>; - pub async fn request(&mut self, message: A::Request) -> Result { - self.request_with_timeout(message, DEFAULT_REQUEST_TIMEOUT) - .await +pub async fn request( + recipient: &dyn Receiver, + msg: M, + timeout_duration: Duration, +) -> Result { + let rx = recipient.request_raw(msg)?; + match timeout(timeout_duration, rx).await { + Ok(Ok(result)) => Ok(result), + Ok(Err(_)) => Err(ActorError::ActorStopped), + Err(_) => Err(ActorError::RequestTimeout), } +} - pub async fn request_with_timeout( - &mut self, - message: A::Request, +// --------------------------------------------------------------------------- +// ActorRef +// --------------------------------------------------------------------------- + +pub struct ActorRef { + sender: mpsc::Sender + Send>>, + cancellation_token: CancellationToken, + completion_rx: watch::Receiver, +} + +impl Debug for ActorRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ActorRef").finish_non_exhaustive() + } +} + +impl Clone for ActorRef { + fn clone(&self) -> Self { + Self { + sender: self.sender.clone(), + cancellation_token: self.cancellation_token.clone(), + completion_rx: self.completion_rx.clone(), + } + } +} + +impl ActorRef { + pub fn send(&self, msg: M) -> Result<(), ActorError> + where + A: Handler, + M: Message, + { + let envelope = MessageEnvelope { msg, tx: None }; + self.sender + .send(Box::new(envelope)) + .map_err(|_| ActorError::ActorStopped) + } + + pub fn request_raw(&self, msg: M) -> Result, ActorError> + where + A: Handler, + M: Message, + { + let (tx, rx) = oneshot::channel(); + let envelope = MessageEnvelope { + msg, + tx: Some(tx), + }; + self.sender + .send(Box::new(envelope)) + .map_err(|_| ActorError::ActorStopped)?; + Ok(rx) + } + + pub async fn request(&self, msg: M) -> Result + where + A: Handler, + M: Message, + { + self.request_with_timeout(msg, DEFAULT_REQUEST_TIMEOUT).await + } + + pub async fn request_with_timeout( + &self, + msg: M, duration: Duration, - ) -> Result { - let (oneshot_tx, oneshot_rx) = oneshot::channel::>(); - self.tx.send(ActorInMsg::Request { - sender: oneshot_tx, - message, - })?; - - match timeout(duration, oneshot_rx).await { - Ok(Ok(result)) => result, - Ok(Err(_)) => Err(ActorError::Server), + ) -> Result + where + A: Handler, + M: Message, + { + let rx = self.request_raw(msg)?; + match timeout(duration, rx).await { + Ok(Ok(result)) => Ok(result), + Ok(Err(_)) => Err(ActorError::ActorStopped), Err(_) => Err(ActorError::RequestTimeout), } } - pub async fn send(&mut self, message: A::Message) -> Result<(), ActorError> { - self.tx - .send(ActorInMsg::Message { message }) - .map_err(|_error| ActorError::Server) + pub fn recipient(&self) -> Recipient + where + A: Handler, + M: Message, + { + Arc::new(self.clone()) } - pub(crate) fn cancellation_token(&self) -> CancellationToken { - self.cancellation_token.clone() + pub fn context(&self) -> Context { + Context::from_ref(self) } - /// Waits for the actor to stop. - /// - /// This method returns a future that completes when the actor has finished - /// processing and exited its main loop. Can be called multiple times from - /// different clones of the ActorRef - all callers will be notified when - /// the actor stops. pub async fn join(&self) { let mut rx = self.completion_rx.clone(); - // Wait until completion signal is true while !*rx.borrow_and_update() { if rx.changed().await.is_err() { - // Sender dropped, actor must have completed break; } } } } -pub enum ActorInMsg { - Request { - sender: oneshot::Sender>, - message: A::Request, - }, - Message { - message: A::Message, - }, -} - -pub enum RequestResponse { - Reply(A::Reply), - Unused, - Stop(A::Reply), -} +// Bridge: ActorRef implements Receiver for any M that A handles +impl Receiver for ActorRef +where + A: Actor + Handler, + M: Message, +{ + fn send(&self, msg: M) -> Result<(), ActorError> { + ActorRef::send(self, msg) + } -pub enum MessageResponse { - NoReply, - Unused, - Stop, + fn request_raw(&self, msg: M) -> Result, ActorError> { + ActorRef::request_raw(self, msg) + } } -pub enum InitResult { - Success(A), - NoSuccess(A), -} +// --------------------------------------------------------------------------- +// Actor startup + main loop +// --------------------------------------------------------------------------- -pub trait Actor: Send + Sized { - type Request: Clone + Send + Sized + Sync; - type Message: Clone + Send + Sized + Sync; - type Reply: Send + Sized; - type Error: Debug + Send; +impl ActorRef { + fn spawn(actor: A, backend: Backend) -> Self { + let (tx, rx) = mpsc::channel:: + Send>>(); + let cancellation_token = CancellationToken::new(); + let (completion_tx, completion_rx) = watch::channel(false); - /// Start the Actor with the default backend (Async). - fn start(self) -> ActorRef { - self.start_with_backend(Backend::default()) - } + let actor_ref = ActorRef { + sender: tx.clone(), + cancellation_token: cancellation_token.clone(), + completion_rx, + }; - /// Start the Actor with the specified backend. - /// - /// # Arguments - /// * `backend` - The execution backend to use: - /// - `Backend::Async` - Run on tokio async runtime (default, best for non-blocking workloads) - /// - `Backend::Blocking` - Run on tokio's blocking thread pool (for blocking operations) - /// - `Backend::Thread` - Run on a dedicated OS thread (for long-running blocking services) - fn start_with_backend(self, backend: Backend) -> ActorRef { - match backend { - Backend::Async => ActorRef::new(self), - Backend::Blocking => ActorRef::new_blocking(self), - Backend::Thread => ActorRef::new_on_thread(self), - } - } + let ctx = Context { + sender: tx, + cancellation_token: cancellation_token.clone(), + }; - fn run( - self, - handle: &ActorRef, - rx: &mut mpsc::Receiver>, - ) -> impl Future> + Send { - async { - let res = match self.init(handle).await { - Ok(Success(new_state)) => Ok(new_state.main_loop(handle, rx).await), - Ok(NoSuccess(intermediate_state)) => { - // new_state is NoSuccess, this means the initialization failed, but the error was handled - // in callback. No need to report the error. - // Just skip main_loop and return the state to teardown the Actor - Ok(intermediate_state) - } - Err(err) => { - tracing::error!("Initialization failed with unhandled error: {err:?}"); - Err(ActorError::Initialization) - } - }; + let inner_future = async move { + run_actor(actor, ctx, rx, cancellation_token).await; + let _ = completion_tx.send(true); + }; - handle.cancellation_token().cancel(); - if let Ok(final_state) = res { - if let Err(err) = final_state.teardown(handle).await { - tracing::error!("Error during teardown: {err:?}"); - } + match backend { + Backend::Async => { + #[cfg(debug_assertions)] + let inner_future = warn_on_block::WarnOnBlocking::new(inner_future); + let _handle = rt::spawn(inner_future); + } + Backend::Blocking => { + let _handle = rt::spawn_blocking(move || { + rt::block_on(inner_future) + }); + } + Backend::Thread => { + let _handle = threads::spawn(move || { + threads::block_on(inner_future) + }); } - Ok(()) } - } - /// Initialization function. It's called before main loop. It - /// can be overrided on implementations in case initial steps are - /// required. - fn init( - self, - _handle: &ActorRef, - ) -> impl Future, Self::Error>> + Send { - async { Ok(Success(self)) } + actor_ref } +} - fn main_loop( - mut self, - handle: &ActorRef, - rx: &mut mpsc::Receiver>, - ) -> impl Future + Send { - async { - loop { - if !self.receive(handle, rx).await { +async fn run_actor( + mut actor: A, + ctx: Context, + mut rx: mpsc::Receiver + Send>>, + cancellation_token: CancellationToken, +) { + actor.started(&ctx).await; + + if cancellation_token.is_cancelled() { + actor.stopped(&ctx).await; + return; + } + + loop { + let msg = rx.recv().await; + match msg { + Some(envelope) => { + let result = AssertUnwindSafe(envelope.handle(&mut actor, &ctx)) + .catch_unwind() + .await; + if let Err(panic) = result { + tracing::error!("Panic in message handler: {panic:?}"); + break; + } + if cancellation_token.is_cancelled() { break; } } - tracing::trace!("Stopping Actor"); - self + None => break, } } - fn receive( - &mut self, - handle: &ActorRef, - rx: &mut mpsc::Receiver>, - ) -> impl Future + Send { - async move { - let message = rx.recv().await; - - let keep_running = match message { - Some(ActorInMsg::Request { sender, message }) => { - let (keep_running, response) = - match AssertUnwindSafe(self.handle_request(message, handle)) - .catch_unwind() - .await - { - Ok(response) => match response { - RequestResponse::Reply(response) => (true, Ok(response)), - RequestResponse::Stop(response) => (false, Ok(response)), - RequestResponse::Unused => { - tracing::error!("Actor received unexpected Request"); - (false, Err(ActorError::RequestUnused)) - } - }, - Err(error) => { - tracing::error!("Error in callback: '{error:?}'"); - (false, Err(ActorError::Callback)) - } - }; - // Send response back - if sender.send(response).is_err() { - tracing::error!("Actor failed to send response back, client must have died") - }; - keep_running - } - Some(ActorInMsg::Message { message }) => { - match AssertUnwindSafe(self.handle_message(message, handle)) - .catch_unwind() - .await - { - Ok(response) => match response { - MessageResponse::NoReply => true, - MessageResponse::Stop => false, - MessageResponse::Unused => { - tracing::error!("Actor received unexpected Message"); - false - } - }, - Err(error) => { - tracing::trace!("Error in callback: '{error:?}'"); - false - } - } - } - None => { - // Channel has been closed; won't receive further messages. Stop the server. - false - } - }; - keep_running - } - } + cancellation_token.cancel(); + actor.stopped(&ctx).await; +} - fn handle_request( - &mut self, - _message: Self::Request, - _handle: &ActorRef, - ) -> impl Future> + Send { - async { RequestResponse::Unused } - } +// --------------------------------------------------------------------------- +// Actor::start +// --------------------------------------------------------------------------- - fn handle_message( - &mut self, - _message: Self::Message, - _handle: &ActorRef, - ) -> impl Future + Send { - async { MessageResponse::Unused } +pub trait ActorStart: Actor { + fn start(self) -> ActorRef { + self.start_with_backend(Backend::default()) } - /// Teardown function. It's called after the stop message is received. - /// It can be overrided on implementations in case final steps are required, - /// like closing streams, stopping timers, etc. - fn teardown( - self, - _handle: &ActorRef, - ) -> impl Future> + Send { - async { Ok(()) } + fn start_with_backend(self, backend: Backend) -> ActorRef { + ActorRef::spawn(self, backend) } } -/// Spawns a task that awaits on a future and sends a message to an Actor -/// on completion. -/// This function returns a handle to the spawned task. -pub fn send_message_on(handle: ActorRef, future: U, message: T::Message) -> JoinHandle<()> +impl ActorStart for A {} + +// --------------------------------------------------------------------------- +// send_message_on (utility) +// --------------------------------------------------------------------------- + +pub fn send_message_on(ctx: Context, future: U, msg: M) -> JoinHandle<()> where - T: Actor, + A: Actor + Handler, + M: Message, U: Future + Send + 'static, ::Output: Send, { - let cancellation_token = handle.cancellation_token(); - let mut handle_clone = handle.clone(); + let cancellation_token = ctx.cancellation_token(); let join_handle = rt::spawn(async move { let is_cancelled = pin!(cancellation_token.cancelled()); let signal = pin!(future); match future::select(is_cancelled, signal).await { future::Either::Left(_) => tracing::debug!("Actor stopped"), future::Either::Right(_) => { - if let Err(e) = handle_clone.send(message).await { + if let Err(e) = ctx.send(msg) { tracing::error!("Failed to send message: {e:?}") } } @@ -473,10 +455,13 @@ where join_handle } +// --------------------------------------------------------------------------- +// WarnOnBlocking (debug only) +// --------------------------------------------------------------------------- + #[cfg(debug_assertions)] mod warn_on_block { use super::*; - use std::time::Instant; use tracing::warn; @@ -514,229 +499,54 @@ mod warn_on_block { } } +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + #[cfg(test)] mod tests { - use super::*; - use crate::{messages::Unused, tasks::send_after}; + use crate::messages; use std::{ - sync::{Arc, Mutex}, + sync::{atomic, Arc}, thread, time::Duration, }; - struct BadlyBehavedTask; - - #[derive(Clone)] - pub enum InMessage { - GetCount, - Stop, - } - #[derive(Clone)] - pub enum OutMsg { - Count(u64), - } - - impl Actor for BadlyBehavedTask { - type Request = InMessage; - type Message = Unused; - type Reply = Unused; - type Error = Unused; - - async fn handle_request( - &mut self, - _: Self::Request, - _: &ActorRef, - ) -> RequestResponse { - RequestResponse::Stop(Unused) - } - - async fn handle_message( - &mut self, - _: Self::Message, - _: &ActorRef, - ) -> MessageResponse { - rt::sleep(Duration::from_millis(20)).await; - thread::sleep(Duration::from_secs(2)); - MessageResponse::Stop - } - } - - struct WellBehavedTask { - pub count: u64, - } - - impl Actor for WellBehavedTask { - type Request = InMessage; - type Message = Unused; - type Reply = OutMsg; - type Error = Unused; - - async fn handle_request( - &mut self, - message: Self::Request, - _: &ActorRef, - ) -> RequestResponse { - match message { - InMessage::GetCount => RequestResponse::Reply(OutMsg::Count(self.count)), - InMessage::Stop => RequestResponse::Stop(OutMsg::Count(self.count)), - } - } - - async fn handle_message( - &mut self, - _: Self::Message, - handle: &ActorRef, - ) -> MessageResponse { - self.count += 1; - println!("{:?}: good still alive", thread::current().id()); - send_after(Duration::from_millis(100), handle.to_owned(), Unused); - MessageResponse::NoReply - } - } - - const BLOCKING: Backend = Backend::Blocking; - - #[test] - pub fn badly_behaved_thread_non_blocking() { - let runtime = rt::Runtime::new().unwrap(); - runtime.block_on(async move { - let mut badboy = BadlyBehavedTask.start(); - let _ = badboy.send(Unused).await; - let mut goodboy = WellBehavedTask { count: 0 }.start(); - let _ = goodboy.send(Unused).await; - rt::sleep(Duration::from_secs(1)).await; - let count = goodboy.request(InMessage::GetCount).await.unwrap(); + // --- Counter actor for basic tests --- - match count { - OutMsg::Count(num) => { - assert_ne!(num, 10); - } - } - goodboy.request(InMessage::Stop).await.unwrap(); - }); + struct Counter { + count: u64, } - #[test] - pub fn badly_behaved_thread() { - let runtime = rt::Runtime::new().unwrap(); - runtime.block_on(async move { - let mut badboy = BadlyBehavedTask.start_with_backend(BLOCKING); - let _ = badboy.send(Unused).await; - let mut goodboy = WellBehavedTask { count: 0 }.start(); - let _ = goodboy.send(Unused).await; - rt::sleep(Duration::from_secs(1)).await; - let count = goodboy.request(InMessage::GetCount).await.unwrap(); - - match count { - OutMsg::Count(num) => { - assert_eq!(num, 10); - } - } - goodboy.request(InMessage::Stop).await.unwrap(); - }); + messages! { + GetCount -> u64; + Increment -> u64; + StopCounter -> u64 } - const TIMEOUT_DURATION: Duration = Duration::from_millis(100); - - #[derive(Debug, Default)] - struct SomeTask; - - #[derive(Clone)] - enum SomeTaskRequest { - SlowOperation, - FastOperation, - } + impl Actor for Counter {} - impl Actor for SomeTask { - type Request = SomeTaskRequest; - type Message = Unused; - type Reply = Unused; - type Error = Unused; - - async fn handle_request( - &mut self, - message: Self::Request, - _handle: &ActorRef, - ) -> RequestResponse { - match message { - SomeTaskRequest::SlowOperation => { - // Simulate a slow operation that will not resolve in time - rt::sleep(TIMEOUT_DURATION * 2).await; - RequestResponse::Reply(Unused) - } - SomeTaskRequest::FastOperation => { - // Simulate a fast operation that resolves in time - rt::sleep(TIMEOUT_DURATION / 2).await; - RequestResponse::Reply(Unused) - } - } + impl Handler for Counter { + async fn handle(&mut self, _msg: GetCount, _ctx: &Context) -> u64 { + self.count } } - #[test] - pub fn unresolving_task_times_out() { - let runtime = rt::Runtime::new().unwrap(); - runtime.block_on(async move { - let mut unresolving_task = SomeTask.start(); - - let result = unresolving_task - .request_with_timeout(SomeTaskRequest::FastOperation, TIMEOUT_DURATION) - .await; - assert!(matches!(result, Ok(Unused))); - - let result = unresolving_task - .request_with_timeout(SomeTaskRequest::SlowOperation, TIMEOUT_DURATION) - .await; - assert!(matches!(result, Err(ActorError::RequestTimeout))); - }); - } - - struct SomeTaskThatFailsOnInit { - sender_channel: Arc>>, - } - - impl SomeTaskThatFailsOnInit { - pub fn new(sender_channel: Arc>>) -> Self { - Self { sender_channel } + impl Handler for Counter { + async fn handle(&mut self, _msg: Increment, _ctx: &Context) -> u64 { + self.count += 1; + self.count } } - impl Actor for SomeTaskThatFailsOnInit { - type Request = Unused; - type Message = Unused; - type Reply = Unused; - type Error = Unused; - - async fn init(self, _handle: &ActorRef) -> Result, Self::Error> { - // Simulate an initialization failure by returning NoSuccess - Ok(NoSuccess(self)) - } - - async fn teardown(self, _handle: &ActorRef) -> Result<(), Self::Error> { - self.sender_channel.lock().unwrap().close(); - Ok(()) + impl Handler for Counter { + async fn handle(&mut self, _msg: StopCounter, ctx: &Context) -> u64 { + ctx.stop(); + self.count } } - #[test] - pub fn task_fails_with_intermediate_state() { - let runtime = rt::Runtime::new().unwrap(); - runtime.block_on(async move { - let (rx, tx) = mpsc::channel::(); - let sender_channel = Arc::new(Mutex::new(tx)); - let _task = SomeTaskThatFailsOnInit::new(sender_channel).start(); - - // Wait a while to ensure the task has time to run and fail - rt::sleep(Duration::from_secs(1)).await; - - // We assure that the teardown function has ran by checking that the receiver channel is closed - assert!(rx.is_closed()) - }); - } - - // ==================== Backend enum tests ==================== - #[test] pub fn backend_default_is_async() { assert_eq!(Backend::default(), Backend::Async); @@ -746,8 +556,8 @@ mod tests { #[allow(clippy::clone_on_copy)] pub fn backend_enum_is_copy_and_clone() { let backend = Backend::Async; - let copied = backend; // Copy - let cloned = backend.clone(); // Clone - intentionally testing Clone trait + let copied = backend; + let cloned = backend.clone(); assert_eq!(backend, copied); assert_eq!(backend, cloned); } @@ -769,284 +579,183 @@ mod tests { assert_ne!(Backend::Blocking, Backend::Thread); } - // ==================== Backend functionality tests ==================== - - /// Simple counter Actor for testing all backends - struct Counter { - count: u64, - } - - #[derive(Clone)] - enum CounterRequest { - Get, - Increment, - Stop, - } - - #[derive(Clone)] - enum CounterMessage { - Increment, - } - - impl Actor for Counter { - type Request = CounterRequest; - type Message = CounterMessage; - type Reply = u64; - type Error = (); - - async fn handle_request( - &mut self, - message: Self::Request, - _: &ActorRef, - ) -> RequestResponse { - match message { - CounterRequest::Get => RequestResponse::Reply(self.count), - CounterRequest::Increment => { - self.count += 1; - RequestResponse::Reply(self.count) - } - CounterRequest::Stop => RequestResponse::Stop(self.count), - } - } - - async fn handle_message( - &mut self, - message: Self::Message, - _: &ActorRef, - ) -> MessageResponse { - match message { - CounterMessage::Increment => { - self.count += 1; - MessageResponse::NoReply - } - } - } - } - #[test] - pub fn backend_async_handles_call_and_cast() { + pub fn backend_async_handles_send_and_request() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - let mut counter = Counter { count: 0 }.start(); + let counter = Counter { count: 0 }.start(); - // Test call - let result = counter.request(CounterRequest::Get).await.unwrap(); + let result = counter.request(GetCount).await.unwrap(); assert_eq!(result, 0); - let result = counter.request(CounterRequest::Increment).await.unwrap(); + let result = counter.request(Increment).await.unwrap(); assert_eq!(result, 1); - // Test cast - counter.send(CounterMessage::Increment).await.unwrap(); - rt::sleep(Duration::from_millis(10)).await; // Give time for cast to process + // fire-and-forget send + counter.send(Increment).unwrap(); + rt::sleep(Duration::from_millis(10)).await; - let result = counter.request(CounterRequest::Get).await.unwrap(); + let result = counter.request(GetCount).await.unwrap(); assert_eq!(result, 2); - // Stop - let final_count = counter.request(CounterRequest::Stop).await.unwrap(); + let final_count = counter.request(StopCounter).await.unwrap(); assert_eq!(final_count, 2); }); } #[test] - pub fn backend_blocking_handles_call_and_cast() { + pub fn backend_blocking_handles_send_and_request() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - let mut counter = Counter { count: 0 }.start_with_backend(Backend::Blocking); + let counter = Counter { count: 0 }.start_with_backend(Backend::Blocking); - // Test call - let result = counter.request(CounterRequest::Get).await.unwrap(); + let result = counter.request(GetCount).await.unwrap(); assert_eq!(result, 0); - let result = counter.request(CounterRequest::Increment).await.unwrap(); + let result = counter.request(Increment).await.unwrap(); assert_eq!(result, 1); - // Test cast - counter.send(CounterMessage::Increment).await.unwrap(); - rt::sleep(Duration::from_millis(50)).await; // Give time for cast to process + counter.send(Increment).unwrap(); + rt::sleep(Duration::from_millis(50)).await; - let result = counter.request(CounterRequest::Get).await.unwrap(); + let result = counter.request(GetCount).await.unwrap(); assert_eq!(result, 2); - // Stop - let final_count = counter.request(CounterRequest::Stop).await.unwrap(); + let final_count = counter.request(StopCounter).await.unwrap(); assert_eq!(final_count, 2); }); } #[test] - pub fn backend_thread_handles_call_and_cast() { + pub fn backend_thread_handles_send_and_request() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - let mut counter = Counter { count: 0 }.start_with_backend(Backend::Thread); + let counter = Counter { count: 0 }.start_with_backend(Backend::Thread); - // Test call - let result = counter.request(CounterRequest::Get).await.unwrap(); + let result = counter.request(GetCount).await.unwrap(); assert_eq!(result, 0); - let result = counter.request(CounterRequest::Increment).await.unwrap(); + let result = counter.request(Increment).await.unwrap(); assert_eq!(result, 1); - // Test cast - counter.send(CounterMessage::Increment).await.unwrap(); - rt::sleep(Duration::from_millis(50)).await; // Give time for cast to process + counter.send(Increment).unwrap(); + rt::sleep(Duration::from_millis(50)).await; - let result = counter.request(CounterRequest::Get).await.unwrap(); + let result = counter.request(GetCount).await.unwrap(); assert_eq!(result, 2); - // Stop - let final_count = counter.request(CounterRequest::Stop).await.unwrap(); + let final_count = counter.request(StopCounter).await.unwrap(); assert_eq!(final_count, 2); }); } #[test] - pub fn backend_thread_isolates_blocking_work() { - // Similar to badly_behaved_thread but using Backend::Thread + pub fn multiple_backends_concurrent() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - let mut badboy = BadlyBehavedTask.start_with_backend(Backend::Thread); - let _ = badboy.send(Unused).await; - let mut goodboy = WellBehavedTask { count: 0 }.start(); - let _ = goodboy.send(Unused).await; - rt::sleep(Duration::from_secs(1)).await; - let count = goodboy.request(InMessage::GetCount).await.unwrap(); + let async_counter = Counter { count: 0 }.start(); + let blocking_counter = Counter { count: 100 }.start_with_backend(Backend::Blocking); + let thread_counter = Counter { count: 200 }.start_with_backend(Backend::Thread); - // goodboy should have run normally because badboy is on a separate thread - match count { - OutMsg::Count(num) => { - assert_eq!(num, 10); - } - } - goodboy.request(InMessage::Stop).await.unwrap(); + async_counter.request(Increment).await.unwrap(); + blocking_counter.request(Increment).await.unwrap(); + thread_counter.request(Increment).await.unwrap(); + + let async_val = async_counter.request(GetCount).await.unwrap(); + let blocking_val = blocking_counter.request(GetCount).await.unwrap(); + let thread_val = thread_counter.request(GetCount).await.unwrap(); + + assert_eq!(async_val, 1); + assert_eq!(blocking_val, 101); + assert_eq!(thread_val, 201); + + async_counter.request(StopCounter).await.unwrap(); + blocking_counter.request(StopCounter).await.unwrap(); + thread_counter.request(StopCounter).await.unwrap(); }); } #[test] - pub fn multiple_backends_concurrent() { + pub fn request_timeout() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - // Start counters on all three backends - let mut async_counter = Counter { count: 0 }.start(); - let mut blocking_counter = Counter { count: 100 }.start_with_backend(Backend::Blocking); - let mut thread_counter = Counter { count: 200 }.start_with_backend(Backend::Thread); - - // Increment each - async_counter - .request(CounterRequest::Increment) - .await - .unwrap(); - blocking_counter - .request(CounterRequest::Increment) - .await - .unwrap(); - thread_counter - .request(CounterRequest::Increment) - .await - .unwrap(); - - // Verify each has independent state - let async_val = async_counter.request(CounterRequest::Get).await.unwrap(); - let blocking_val = blocking_counter.request(CounterRequest::Get).await.unwrap(); - let thread_val = thread_counter.request(CounterRequest::Get).await.unwrap(); - - assert_eq!(async_val, 1); - assert_eq!(blocking_val, 101); - assert_eq!(thread_val, 201); + struct SlowActor; + messages! { SlowOp -> () } + impl Actor for SlowActor {} + impl Handler for SlowActor { + async fn handle(&mut self, _msg: SlowOp, _ctx: &Context) { + rt::sleep(Duration::from_millis(200)).await; + } + } - // Clean up - async_counter.request(CounterRequest::Stop).await.unwrap(); - blocking_counter - .request(CounterRequest::Stop) - .await - .unwrap(); - thread_counter.request(CounterRequest::Stop).await.unwrap(); + let actor = SlowActor.start(); + let result = actor + .request_with_timeout(SlowOp, Duration::from_millis(50)) + .await; + assert!(matches!(result, Err(ActorError::RequestTimeout))); }); } #[test] - pub fn backend_default_works_in_start() { + pub fn recipient_type_erasure() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - // Using Backend::default() should work the same as Backend::Async - let mut counter = Counter { count: 42 }.start_with_backend(Backend::Async); + let counter = Counter { count: 42 }.start(); + let recipient: Recipient = counter.recipient(); - let result = counter.request(CounterRequest::Get).await.unwrap(); + let rx = recipient.request_raw(GetCount).unwrap(); + let result = rx.await.unwrap(); assert_eq!(result, 42); - counter.request(CounterRequest::Stop).await.unwrap(); + // Also test request helper + let result = request(&*recipient, GetCount, Duration::from_secs(5)).await.unwrap(); + assert_eq!(result, 42); }); } - /// Actor that sleeps during teardown to simulate slow shutdown + // --- SlowShutdownActor for join tests --- + struct SlowShutdownActor; + messages! { StopSlow -> () } + impl Actor for SlowShutdownActor { - type Request = Unused; - type Message = Unused; - type Reply = Unused; - type Error = Unused; - - async fn handle_message( - &mut self, - _message: Self::Message, - _handle: &ActorRef, - ) -> MessageResponse { - MessageResponse::Stop + async fn stopped(&mut self, _ctx: &Context) { + thread::sleep(Duration::from_millis(500)); } + } - async fn teardown(self, _handle: &ActorRef) -> Result<(), Self::Error> { - // Simulate slow shutdown - this runs on the thread - std::thread::sleep(Duration::from_millis(500)); - Ok(()) + impl Handler for SlowShutdownActor { + async fn handle(&mut self, _msg: StopSlow, ctx: &Context) { + ctx.stop(); } } - /// Test that join() on a Backend::Thread actor doesn't block other async tasks. - /// - /// This test verifies that when we call join().await on an actor running on - /// Backend::Thread, it doesn't block the tokio runtime - other async tasks - /// should continue to make progress. - /// - /// Uses a single-threaded runtime to ensure we detect blocking behavior. #[test] pub fn thread_backend_join_does_not_block_runtime() { - // Use current_thread runtime to ensure blocking would be detected let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap(); runtime.block_on(async move { - // Start a thread-backend actor that takes 500ms to teardown - let mut slow_actor = SlowShutdownActor.start_with_backend(Backend::Thread); + let slow_actor = SlowShutdownActor.start_with_backend(Backend::Thread); - // Spawn an async task that increments a counter every 50ms - let tick_count = Arc::new(std::sync::atomic::AtomicU64::new(0)); + let tick_count = Arc::new(atomic::AtomicU64::new(0)); let tick_count_clone = tick_count.clone(); let _ticker = rt::spawn(async move { for _ in 0..20 { rt::sleep(Duration::from_millis(50)).await; - tick_count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + tick_count_clone.fetch_add(1, atomic::Ordering::SeqCst); } }); - // Tell the actor to stop - it will start its slow teardown - slow_actor.send(Unused).await.unwrap(); - - // Small delay to ensure the actor received the message + slow_actor.send(StopSlow).unwrap(); rt::sleep(Duration::from_millis(10)).await; - // Now join the actor - this waits for the 500ms teardown - // If implemented correctly, the ticker should continue running DURING the join slow_actor.join().await; - // Check tick count IMMEDIATELY after join returns, before awaiting ticker. - // The actor teardown takes 500ms. In that time, the ticker should have - // completed about 10 ticks (500ms / 50ms = 10). - // If join() blocked the runtime, the ticker would have 0-1 ticks. - let count_after_join = tick_count.load(std::sync::atomic::Ordering::SeqCst); + let count_after_join = tick_count.load(atomic::Ordering::SeqCst); assert!( count_after_join >= 8, "Ticker should have completed ~10 ticks during the 500ms join(), but only got {}. \ @@ -1056,19 +765,14 @@ mod tests { }); } - /// Test that multiple callers can wait on join() simultaneously. - /// - /// This verifies that the completion signal approach works correctly - /// when multiple tasks want to wait for the same actor to stop. #[test] pub fn multiple_join_callers_all_notified() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - let mut actor = SlowShutdownActor.start(); + let actor = SlowShutdownActor.start(); let actor_clone1 = actor.clone(); let actor_clone2 = actor.clone(); - // Spawn multiple tasks that will all call join() let join1 = rt::spawn(async move { actor_clone1.join().await; 1u32 @@ -1078,19 +782,105 @@ mod tests { 2u32 }); - // Give the join tasks time to start waiting rt::sleep(Duration::from_millis(10)).await; - // Tell the actor to stop - actor.send(Unused).await.unwrap(); + actor.send(StopSlow).unwrap(); - // All join tasks should complete after the actor stops let (r1, r2) = tokio::join!(join1, join2); assert_eq!(r1.unwrap(), 1); assert_eq!(r2.unwrap(), 2); - // Calling join again should return immediately (actor already stopped) actor.join().await; }); } + + // --- Badly behaved actors for blocking tests --- + + struct BadlyBehavedTask; + + messages! { DoBlock -> () } + + impl Actor for BadlyBehavedTask {} + + impl Handler for BadlyBehavedTask { + async fn handle(&mut self, _msg: DoBlock, ctx: &Context) { + rt::sleep(Duration::from_millis(20)).await; + thread::sleep(Duration::from_secs(2)); + ctx.stop(); + } + } + + messages! { IncrementWell -> () } + + struct WellBehavedTask { + pub count: u64, + } + + impl Actor for WellBehavedTask {} + + impl Handler for WellBehavedTask { + async fn handle(&mut self, _msg: GetCount, _ctx: &Context) -> u64 { + self.count + } + } + + impl Handler for WellBehavedTask { + async fn handle(&mut self, _msg: StopCounter, ctx: &Context) -> u64 { + ctx.stop(); + self.count + } + } + + impl Handler for WellBehavedTask { + async fn handle(&mut self, _msg: IncrementWell, ctx: &Context) { + self.count += 1; + use crate::tasks::send_after; + send_after(Duration::from_millis(100), ctx.clone(), IncrementWell); + } + } + + #[test] + pub fn badly_behaved_thread_non_blocking() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + let badboy = BadlyBehavedTask.start(); + badboy.send(DoBlock).unwrap(); + let goodboy = WellBehavedTask { count: 0 }.start(); + goodboy.send(IncrementWell).unwrap(); + rt::sleep(Duration::from_secs(1)).await; + let count = goodboy.request(GetCount).await.unwrap(); + assert_ne!(count, 10); + goodboy.request(StopCounter).await.unwrap(); + }); + } + + #[test] + pub fn badly_behaved_thread() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + let badboy = BadlyBehavedTask.start_with_backend(Backend::Blocking); + badboy.send(DoBlock).unwrap(); + let goodboy = WellBehavedTask { count: 0 }.start(); + goodboy.send(IncrementWell).unwrap(); + rt::sleep(Duration::from_secs(1)).await; + let count = goodboy.request(GetCount).await.unwrap(); + assert_eq!(count, 10); + goodboy.request(StopCounter).await.unwrap(); + }); + } + + #[test] + pub fn backend_thread_isolates_blocking_work() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + let badboy = BadlyBehavedTask.start_with_backend(Backend::Thread); + badboy.send(DoBlock).unwrap(); + let goodboy = WellBehavedTask { count: 0 }.start(); + goodboy.send(IncrementWell).unwrap(); + rt::sleep(Duration::from_secs(1)).await; + let count = goodboy.request(GetCount).await.unwrap(); + assert_eq!(count, 10); + goodboy.request(StopCounter).await.unwrap(); + }); + } } diff --git a/concurrency/src/tasks/mod.rs b/concurrency/src/tasks/mod.rs index dbbc269..c92e4cf 100644 --- a/concurrency/src/tasks/mod.rs +++ b/concurrency/src/tasks/mod.rs @@ -1,7 +1,4 @@ -//! spawned concurrency -//! Runtime tasks-based traits and structs to implement concurrent code à-la-Erlang. - -mod actor; +pub(crate) mod actor; mod process; mod stream; mod time; @@ -12,9 +9,11 @@ mod stream_tests; mod timer_tests; pub use actor::{ - send_message_on, Actor, ActorInMsg, ActorRef, Backend, InitResult, InitResult::NoSuccess, - InitResult::Success, MessageResponse, RequestResponse, + send_message_on, Actor, ActorRef, ActorStart, Backend, Context, Handler, Receiver, Recipient, + request, }; pub use process::{send, Process, ProcessInfo}; pub use stream::spawn_listener; -pub use time::{send_after, send_interval}; +pub use time::{send_after, send_interval, TimerHandle}; + +pub use crate::registry; diff --git a/concurrency/src/tasks/stream.rs b/concurrency/src/tasks/stream.rs index ebf09a3..afc5ce6 100644 --- a/concurrency/src/tasks/stream.rs +++ b/concurrency/src/tasks/stream.rs @@ -1,26 +1,23 @@ -use crate::tasks::{Actor, ActorRef}; +use crate::message::Message; use futures::{future::select, Stream, StreamExt}; use spawned_rt::tasks::JoinHandle; -/// Spawns a listener that listens to a stream and sends messages to an Actor. -/// -/// Items sent through the stream are required to be wrapped in a Result type. -/// -/// This function returns a handle to the spawned task and a cancellation token -/// to stop it. -pub fn spawn_listener(mut handle: ActorRef, stream: S) -> JoinHandle<()> +use super::actor::{Actor, Context, Handler}; + +pub fn spawn_listener(ctx: Context, stream: S) -> JoinHandle<()> where - T: Actor, - S: Send + Stream + 'static, + A: Actor + Handler, + M: Message, + S: Send + Stream + 'static, { - let cancellation_token = handle.cancellation_token(); + let cancellation_token = ctx.cancellation_token(); let join_handle = spawned_rt::tasks::spawn(async move { let mut pinned_stream = core::pin::pin!(stream); let is_cancelled = core::pin::pin!(cancellation_token.cancelled()); let listener_loop = core::pin::pin!(async { loop { match pinned_stream.next().await { - Some(msg) => match handle.send(msg).await { + Some(msg) => match ctx.send(msg) { Ok(_) => tracing::trace!("Message sent successfully"), Err(e) => { tracing::error!("Failed to send message: {e:?}"); @@ -36,7 +33,7 @@ where }); match select(is_cancelled, listener_loop).await { futures::future::Either::Left(_) => tracing::trace!("Actor stopped"), - futures::future::Either::Right(_) => (), // Stream finished or errored out + futures::future::Either::Right(_) => (), } }); join_handle diff --git a/concurrency/src/tasks/stream_tests.rs b/concurrency/src/tasks/stream_tests.rs index d270002..75d1000 100644 --- a/concurrency/src/tasks/stream_tests.rs +++ b/concurrency/src/tasks/stream_tests.rs @@ -1,11 +1,30 @@ use crate::tasks::{ - send_after, stream::spawn_listener, Actor, ActorRef, MessageResponse, RequestResponse, + send_after, Actor, ActorStart, Context, Handler, + stream::spawn_listener, }; +use crate::message::Message; use futures::{stream, StreamExt}; use spawned_rt::tasks::{self as rt, BroadcastStream, ReceiverStream}; use std::time::Duration; -type SummatoryHandle = ActorRef; +// --- Messages --- + +#[derive(Debug)] +enum StreamMsg { + Add(u16), + Error, +} +impl Message for StreamMsg { type Result = (); } + +#[derive(Debug)] +struct StopSum; +impl Message for StopSum { type Result = (); } + +#[derive(Debug)] +struct GetValue; +impl Message for GetValue { type Result = u16; } + +// --- Summatory Actor --- struct Summatory { count: u16, @@ -17,49 +36,26 @@ impl Summatory { } } -type SummatoryOutMessage = u16; - -#[derive(Clone)] -enum SummatoryCastMessage { - Add(u16), - StreamError, - Stop, -} +impl Actor for Summatory {} -impl Summatory { - pub async fn get_value(server: &mut SummatoryHandle) -> Result { - server.request(()).await.map_err(|_| ()) +impl Handler for Summatory { + async fn handle(&mut self, msg: StreamMsg, ctx: &Context) { + match msg { + StreamMsg::Add(val) => self.count += val, + StreamMsg::Error => ctx.stop(), + } } } -impl Actor for Summatory { - type Request = (); // We only handle one type of call, so there is no need for a specific message type. - type Message = SummatoryCastMessage; - type Reply = SummatoryOutMessage; - type Error = (); - - async fn handle_message( - &mut self, - message: Self::Message, - _handle: &ActorRef, - ) -> MessageResponse { - match message { - SummatoryCastMessage::Add(val) => { - self.count += val; - MessageResponse::NoReply - } - SummatoryCastMessage::StreamError => MessageResponse::Stop, - SummatoryCastMessage::Stop => MessageResponse::Stop, - } +impl Handler for Summatory { + async fn handle(&mut self, _msg: StopSum, ctx: &Context) { + ctx.stop(); } +} - async fn handle_request( - &mut self, - _message: Self::Request, - _handle: &SummatoryHandle, - ) -> RequestResponse { - let current_value = self.count; - RequestResponse::Reply(current_value) +impl Handler for Summatory { + async fn handle(&mut self, _msg: GetValue, _ctx: &Context) -> u16 { + self.count } } @@ -67,18 +63,18 @@ impl Actor for Summatory { pub fn test_sum_numbers_from_stream() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - let mut summatory_handle = Summatory::new(0).start(); + let summatory = Summatory::new(0).start(); let stream = stream::iter(vec![1u16, 2, 3, 4, 5].into_iter().map(Ok::)); + let ctx = Context::from_ref(&summatory); spawn_listener( - summatory_handle.clone(), - stream.filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }), + ctx, + stream.filter_map(|result| async move { result.ok().map(StreamMsg::Add) }), ); - // Wait for 1 second so the whole stream is processed rt::sleep(Duration::from_secs(1)).await; - let val = Summatory::get_value(&mut summatory_handle).await.unwrap(); + let val = summatory.request(GetValue).await.unwrap(); assert_eq!(val, 15); }) } @@ -87,26 +83,25 @@ pub fn test_sum_numbers_from_stream() { pub fn test_sum_numbers_from_channel() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - let mut summatory_handle = Summatory::new(0).start(); + let summatory = Summatory::new(0).start(); let (tx, rx) = spawned_rt::tasks::mpsc::channel::>(); - // Spawn a task to send numbers to the channel spawned_rt::tasks::spawn(async move { for i in 1..=5 { tx.send(Ok(i)).unwrap(); } }); + let ctx = Context::from_ref(&summatory); spawn_listener( - summatory_handle.clone(), + ctx, ReceiverStream::new(rx) - .filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }), + .filter_map(|result| async move { result.ok().map(StreamMsg::Add) }), ); - // Wait for 1 second so the whole stream is processed rt::sleep(Duration::from_secs(1)).await; - let val = Summatory::get_value(&mut summatory_handle).await.unwrap(); + let val = summatory.request(GetValue).await.unwrap(); assert_eq!(val, 15); }) } @@ -115,44 +110,40 @@ pub fn test_sum_numbers_from_channel() { pub fn test_sum_numbers_from_broadcast_channel() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - let mut summatory_handle = Summatory::new(0).start(); + let summatory = Summatory::new(0).start(); let (tx, rx) = tokio::sync::broadcast::channel::(5); - // Spawn a task to send numbers to the channel spawned_rt::tasks::spawn(async move { for i in 1u16..=5 { tx.send(i).unwrap(); } }); + let ctx = Context::from_ref(&summatory); spawn_listener( - summatory_handle.clone(), + ctx, BroadcastStream::new(rx) - .filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }), + .filter_map(|result| async move { result.ok().map(StreamMsg::Add) }), ); - // Wait for 1 second so the whole stream is processed rt::sleep(Duration::from_secs(1)).await; - let val = Summatory::get_value(&mut summatory_handle).await.unwrap(); + let val = summatory.request(GetValue).await.unwrap(); assert_eq!(val, 15); }) } #[test] pub fn test_stream_cancellation() { - // Messages sent at: t=0, t=250, t=500, t=750, t=1000ms - // We read at t=850ms (after 4th message at t=750, before 5th at t=1000) const MESSAGE_INTERVAL: u64 = 250; const READ_TIME: u64 = 850; const STOP_TIME: u64 = 1100; let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - let mut summatory_handle = Summatory::new(0).start(); + let summatory = Summatory::new(0).start(); let (tx, rx) = spawned_rt::tasks::mpsc::channel::>(); - // Spawn a task to send numbers to the channel spawned_rt::tasks::spawn(async move { for i in 1..=5 { tx.send(Ok(i)).unwrap(); @@ -160,34 +151,28 @@ pub fn test_stream_cancellation() { } }); + let ctx = Context::from_ref(&summatory); let listener_handle = spawn_listener( - summatory_handle.clone(), + ctx.clone(), ReceiverStream::new(rx) - .filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }), + .filter_map(|result| async move { result.ok().map(StreamMsg::Add) }), ); - // Start a timer to stop the actor after all messages would be sent - let summatory_handle_clone = summatory_handle.clone(); let _ = send_after( Duration::from_millis(STOP_TIME), - summatory_handle_clone, - SummatoryCastMessage::Stop, + ctx, + StopSum, ); - // Read value after 4th message (t=750) but before 5th (t=1000). - // Expected sum: 1+2+3+4 = 10, but allow some slack for timing variations. rt::sleep(Duration::from_millis(READ_TIME)).await; - let val = Summatory::get_value(&mut summatory_handle).await.unwrap(); + let val = summatory.request(GetValue).await.unwrap(); - // At t=850ms, we expect 4 messages processed (sum=10), but timing variations - // could result in 3 messages (sum=6) or occasionally all 5 (sum=15). assert!((1..=15).contains(&val)); assert!(listener_handle.await.is_ok()); - // Finally, we check that the server is stopped, by getting an error when trying to call it. rt::sleep(Duration::from_millis(10)).await; - assert!(Summatory::get_value(&mut summatory_handle).await.is_err()); + assert!(summatory.request(GetValue).await.is_err()); }) } @@ -195,22 +180,21 @@ pub fn test_stream_cancellation() { pub fn test_halting_on_stream_error() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - let mut summatory_handle = Summatory::new(0).start(); + let summatory = Summatory::new(0).start(); let stream = tokio_stream::iter(vec![Ok(1u16), Ok(2), Ok(3), Err(()), Ok(4), Ok(5)]); let msg_stream = stream.filter_map(|value| async move { match value { - Ok(number) => Some(SummatoryCastMessage::Add(number)), - Err(_) => Some(SummatoryCastMessage::StreamError), + Ok(number) => Some(StreamMsg::Add(number)), + Err(_) => Some(StreamMsg::Error), } }); - spawn_listener(summatory_handle.clone(), msg_stream); + let ctx = Context::from_ref(&summatory); + spawn_listener(ctx, msg_stream); - // Wait for 1 second so the whole stream is processed rt::sleep(Duration::from_secs(1)).await; - let result = Summatory::get_value(&mut summatory_handle).await; - // Actor should have been terminated, hence the result should be an error + let result = summatory.request(GetValue).await; assert!(result.is_err()); }) } @@ -219,21 +203,21 @@ pub fn test_halting_on_stream_error() { pub fn test_skipping_on_stream_error() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - let mut summatory_handle = Summatory::new(0).start(); + let summatory = Summatory::new(0).start(); let stream = tokio_stream::iter(vec![Ok(1u16), Ok(2), Ok(3), Err(()), Ok(4), Ok(5)]); let msg_stream = stream.filter_map(|value| async move { match value { - Ok(number) => Some(SummatoryCastMessage::Add(number)), + Ok(number) => Some(StreamMsg::Add(number)), Err(_) => None, } }); - spawn_listener(summatory_handle.clone(), msg_stream); + let ctx = Context::from_ref(&summatory); + spawn_listener(ctx, msg_stream); - // Wait for 1 second so the whole stream is processed rt::sleep(Duration::from_secs(1)).await; - let val = Summatory::get_value(&mut summatory_handle).await.unwrap(); + let val = summatory.request(GetValue).await.unwrap(); assert_eq!(val, 15); }) } diff --git a/concurrency/src/tasks/time.rs b/concurrency/src/tasks/time.rs index e334c81..69871e8 100644 --- a/concurrency/src/tasks/time.rs +++ b/concurrency/src/tasks/time.rs @@ -3,7 +3,8 @@ use std::time::Duration; use spawned_rt::tasks::{self as rt, CancellationToken, JoinHandle}; -use super::{Actor, ActorRef}; +use super::actor::{Actor, Context, Handler}; +use crate::message::Message; use core::pin::pin; pub struct TimerHandle { @@ -11,24 +12,22 @@ pub struct TimerHandle { pub cancellation_token: CancellationToken, } -// Sends a message after a given period to the specified Actor. The task terminates -// once the send has completed -pub fn send_after(period: Duration, mut handle: ActorRef, message: T::Message) -> TimerHandle +pub fn send_after(period: Duration, ctx: Context, msg: M) -> TimerHandle where - T: Actor + 'static, + A: Actor + Handler, + M: Message, { let cancellation_token = CancellationToken::new(); let cloned_token = cancellation_token.clone(); - let actor_cancellation_token = handle.cancellation_token(); + let actor_cancellation_token = ctx.cancellation_token(); let join_handle = rt::spawn(async move { - // Timer action is ignored if it was either cancelled or the associated Actor is no longer running. let cancel_token_fut = pin!(cloned_token.cancelled()); let actor_cancel_fut = pin!(actor_cancellation_token.cancelled()); let cancel_conditions = select(cancel_token_fut, actor_cancel_fut); let async_block = pin!(async { rt::sleep(period).await; - let _ = handle.send(message.clone()).await; + let _ = ctx.send(msg); }); let _ = select(cancel_conditions, async_block).await; }); @@ -38,28 +37,24 @@ where } } -// Sends a message to the specified Actor repeatedly after `Time` milliseconds. -pub fn send_interval( - period: Duration, - mut handle: ActorRef, - message: T::Message, -) -> TimerHandle +pub fn send_interval(period: Duration, ctx: Context, msg: M) -> TimerHandle where - T: Actor + 'static, + A: Actor + Handler, + M: Message + Clone, { let cancellation_token = CancellationToken::new(); let cloned_token = cancellation_token.clone(); - let actor_cancellation_token = handle.cancellation_token(); + let actor_cancellation_token = ctx.cancellation_token(); let join_handle = rt::spawn(async move { loop { - // Timer action is ignored if it was either cancelled or the associated Actor is no longer running. let cancel_token_fut = pin!(cloned_token.cancelled()); let actor_cancel_fut = pin!(actor_cancellation_token.cancelled()); let cancel_conditions = select(cancel_token_fut, actor_cancel_fut); + let msg_clone = msg.clone(); let async_block = pin!(async { rt::sleep(period).await; - let _ = handle.send(message.clone()).await; + let _ = ctx.send(msg_clone); }); let result = select(cancel_conditions, async_block).await; match result { diff --git a/concurrency/src/tasks/timer_tests.rs b/concurrency/src/tasks/timer_tests.rs index 46eb664..a1ddff7 100644 --- a/concurrency/src/tasks/timer_tests.rs +++ b/concurrency/src/tasks/timer_tests.rs @@ -1,31 +1,27 @@ use super::{ - send_after, send_interval, Actor, ActorRef, InitResult, InitResult::Success, MessageResponse, - RequestResponse, + send_after, send_interval, Actor, ActorStart, Context, Handler, }; +use crate::message::Message; use spawned_rt::tasks::{self as rt, CancellationToken}; use std::time::Duration; -type RepeaterHandle = ActorRef; +// --- Repeater (interval timer test) --- -#[derive(Clone)] -enum RepeaterCastMessage { - Inc, - StopTimer, -} +#[derive(Clone, Debug)] +struct Inc; +impl Message for Inc { type Result = (); } -#[derive(Clone)] -enum RepeaterCallMessage { - GetCount, -} +#[derive(Clone, Debug)] +struct StopTimer; +impl Message for StopTimer { type Result = (); } -#[derive(PartialEq, Debug)] -enum RepeaterOutMessage { - Count(i32), -} +#[derive(Debug)] +struct GetRepCount; +impl Message for GetRepCount { type Result = i32; } struct Repeater { - pub(crate) count: i32, - pub(crate) cancellation_token: Option, + count: i32, + cancellation_token: Option, } impl Repeater { @@ -37,63 +33,34 @@ impl Repeater { } } -impl Repeater { - pub async fn stop_timer(server: &mut RepeaterHandle) -> Result<(), ()> { - server - .send(RepeaterCastMessage::StopTimer) - .await - .map_err(|_| ()) - } - - pub async fn get_count(server: &mut RepeaterHandle) -> Result { - server - .request(RepeaterCallMessage::GetCount) - .await - .map_err(|_| ()) - } -} - impl Actor for Repeater { - type Request = RepeaterCallMessage; - type Message = RepeaterCastMessage; - type Reply = RepeaterOutMessage; - type Error = (); - - async fn init(mut self, handle: &RepeaterHandle) -> Result, Self::Error> { + async fn started(&mut self, ctx: &Context) { let timer = send_interval( Duration::from_millis(100), - handle.clone(), - RepeaterCastMessage::Inc, + ctx.clone(), + Inc, ); self.cancellation_token = Some(timer.cancellation_token); - Ok(Success(self)) } +} + +impl Handler for Repeater { + async fn handle(&mut self, _msg: Inc, _ctx: &Context) { + self.count += 1; + } +} - async fn handle_request( - &mut self, - _message: Self::Request, - _handle: &RepeaterHandle, - ) -> RequestResponse { - let count = self.count; - RequestResponse::Reply(RepeaterOutMessage::Count(count)) +impl Handler for Repeater { + async fn handle(&mut self, _msg: StopTimer, _ctx: &Context) { + if let Some(ct) = self.cancellation_token.clone() { + ct.cancel(); + } } +} - async fn handle_message( - &mut self, - message: Self::Message, - _handle: &ActorRef, - ) -> MessageResponse { - match message { - RepeaterCastMessage::Inc => { - self.count += 1; - } - RepeaterCastMessage::StopTimer => { - if let Some(ct) = self.cancellation_token.clone() { - ct.cancel() - }; - } - }; - MessageResponse::NoReply +impl Handler for Repeater { + async fn handle(&mut self, _msg: GetRepCount, _ctx: &Context) -> i32 { + self.count } } @@ -101,109 +68,60 @@ impl Actor for Repeater { pub fn test_send_interval_and_cancellation() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - // Start a Repeater - let mut repeater = Repeater::new(0).start(); + let repeater = Repeater::new(0).start(); - // Wait for 1 second rt::sleep(Duration::from_secs(1)).await; - // Check count - let count = Repeater::get_count(&mut repeater).await.unwrap(); + let count = repeater.request(GetRepCount).await.unwrap(); + assert_eq!(9, count); - // 9 messages in 1 second (after first 100 milliseconds sleep) - assert_eq!(RepeaterOutMessage::Count(9), count); + repeater.send(StopTimer).unwrap(); - // Pause timer - Repeater::stop_timer(&mut repeater).await.unwrap(); - - // Wait another second rt::sleep(Duration::from_secs(1)).await; - // Check count again - let count2 = Repeater::get_count(&mut repeater).await.unwrap(); - - // As timer was paused, count should remain at 9 - assert_eq!(RepeaterOutMessage::Count(9), count2); + let count2 = repeater.request(GetRepCount).await.unwrap(); + assert_eq!(9, count2); }); } -type DelayedHandle = ActorRef; - -#[derive(Clone)] -enum DelayedCastMessage { - Inc, -} +// --- Delayed (send_after test) --- -#[derive(Clone)] -enum DelayedCallMessage { - GetCount, - Stop, -} +#[derive(Debug)] +struct GetDelCount; +impl Message for GetDelCount { type Result = i32; } -#[derive(PartialEq, Debug)] -enum DelayedOutMessage { - Count(i32), -} +#[derive(Debug)] +struct StopDelayed; +impl Message for StopDelayed { type Result = i32; } struct Delayed { - pub(crate) count: i32, + count: i32, } impl Delayed { pub fn new(initial_count: i32) -> Self { - Delayed { - count: initial_count, - } + Delayed { count: initial_count } } } -impl Delayed { - pub async fn get_count(server: &mut DelayedHandle) -> Result { - server - .request(DelayedCallMessage::GetCount) - .await - .map_err(|_| ()) - } +impl Actor for Delayed {} - pub async fn stop(server: &mut DelayedHandle) -> Result { - server - .request(DelayedCallMessage::Stop) - .await - .map_err(|_| ()) +impl Handler for Delayed { + async fn handle(&mut self, _msg: Inc, _ctx: &Context) { + self.count += 1; } } -impl Actor for Delayed { - type Request = DelayedCallMessage; - type Message = DelayedCastMessage; - type Reply = DelayedOutMessage; - type Error = (); - - async fn handle_request( - &mut self, - message: Self::Request, - _handle: &DelayedHandle, - ) -> RequestResponse { - match message { - DelayedCallMessage::GetCount => { - let count = self.count; - RequestResponse::Reply(DelayedOutMessage::Count(count)) - } - DelayedCallMessage::Stop => RequestResponse::Stop(DelayedOutMessage::Count(self.count)), - } +impl Handler for Delayed { + async fn handle(&mut self, _msg: GetDelCount, _ctx: &Context) -> i32 { + self.count } +} - async fn handle_message( - &mut self, - message: Self::Message, - _handle: &DelayedHandle, - ) -> MessageResponse { - match message { - DelayedCastMessage::Inc => { - self.count += 1; - } - }; - MessageResponse::NoReply +impl Handler for Delayed { + async fn handle(&mut self, _msg: StopDelayed, ctx: &Context) -> i32 { + ctx.stop(); + self.count } } @@ -211,43 +129,33 @@ impl Actor for Delayed { pub fn test_send_after_and_cancellation() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - // Start a Delayed - let mut repeater = Delayed::new(0).start(); + let repeater = Delayed::new(0).start(); - // Set a just once timed message + let ctx = Context::from_ref(&repeater); let _ = send_after( Duration::from_millis(100), - repeater.clone(), - DelayedCastMessage::Inc, + ctx, + Inc, ); - // Wait for 200 milliseconds rt::sleep(Duration::from_millis(200)).await; - // Check count - let count = Delayed::get_count(&mut repeater).await.unwrap(); + let count = repeater.request(GetDelCount).await.unwrap(); + assert_eq!(1, count); - // Only one message (no repetition) - assert_eq!(DelayedOutMessage::Count(1), count); - - // New timer + let ctx = Context::from_ref(&repeater); let timer = send_after( Duration::from_millis(100), - repeater.clone(), - DelayedCastMessage::Inc, + ctx, + Inc, ); - // Cancel the new timer before timeout timer.cancellation_token.cancel(); - // Wait another 200 milliseconds rt::sleep(Duration::from_millis(200)).await; - // Check count again - let count2 = Delayed::get_count(&mut repeater).await.unwrap(); - - // As timer was cancelled, count should remain at 1 - assert_eq!(DelayedOutMessage::Count(1), count2); + let count2 = repeater.request(GetDelCount).await.unwrap(); + assert_eq!(1, count2); }); } @@ -255,39 +163,31 @@ pub fn test_send_after_and_cancellation() { pub fn test_send_after_gen_server_teardown() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { - // Start a Delayed - let mut repeater = Delayed::new(0).start(); + let repeater = Delayed::new(0).start(); - // Set a just once timed message + let ctx = Context::from_ref(&repeater); let _ = send_after( Duration::from_millis(100), - repeater.clone(), - DelayedCastMessage::Inc, + ctx, + Inc, ); - // Wait for 200 milliseconds rt::sleep(Duration::from_millis(200)).await; - // Check count - let count = Delayed::get_count(&mut repeater).await.unwrap(); - - // Only one message (no repetition) - assert_eq!(DelayedOutMessage::Count(1), count); + let count = repeater.request(GetDelCount).await.unwrap(); + assert_eq!(1, count); - // New timer + let ctx = Context::from_ref(&repeater); let _ = send_after( Duration::from_millis(100), - repeater.clone(), - DelayedCastMessage::Inc, + ctx, + Inc, ); - // Stop the Actor before timeout - let count2 = Delayed::stop(&mut repeater).await.unwrap(); + let count2 = repeater.request(StopDelayed).await.unwrap(); - // Wait another 200 milliseconds rt::sleep(Duration::from_millis(200)).await; - // As timer was cancelled, count should remain at 1 - assert_eq!(DelayedOutMessage::Count(1), count2); + assert_eq!(1, count2); }); } diff --git a/concurrency/src/threads/actor.rs b/concurrency/src/threads/actor.rs index 04796b7..f7e7a53 100644 --- a/concurrency/src/threads/actor.rs +++ b/concurrency/src/threads/actor.rs @@ -1,5 +1,3 @@ -//! Actor trait and structs to create an abstraction similar to Erlang gen_server. -//! See examples/name_server for a usage example. use spawned_rt::threads::{ self as rt, mpsc, oneshot, oneshot::RecvTimeoutError, CancellationToken, }; @@ -11,11 +9,197 @@ use std::{ }; use crate::error::ActorError; +use crate::message::Message; const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(5); -/// Guard that signals completion when dropped. -/// Ensures waiters are notified even if the actor thread panics. +// --------------------------------------------------------------------------- +// Actor trait +// --------------------------------------------------------------------------- + +pub trait Actor: Send + Sized + 'static { + fn started(&mut self, _ctx: &Context) {} + fn stopped(&mut self, _ctx: &Context) {} +} + +// --------------------------------------------------------------------------- +// Handler trait (per-message, sync version) +// --------------------------------------------------------------------------- + +pub trait Handler: Actor { + fn handle(&mut self, msg: M, ctx: &Context) -> M::Result; +} + +// --------------------------------------------------------------------------- +// Envelope (type-erasure) +// --------------------------------------------------------------------------- + +trait Envelope: Send { + fn handle(self: Box, actor: &mut A, ctx: &Context); +} + +struct MessageEnvelope { + msg: M, + tx: Option>, +} + +impl Envelope for MessageEnvelope +where + A: Actor + Handler, + M: Message, +{ + fn handle(self: Box, actor: &mut A, ctx: &Context) { + let result = actor.handle(self.msg, ctx); + if let Some(tx) = self.tx { + let _ = tx.send(result); + } + } +} + +// --------------------------------------------------------------------------- +// Context +// --------------------------------------------------------------------------- + +pub struct Context { + sender: mpsc::Sender + Send>>, + cancellation_token: CancellationToken, +} + +impl Clone for Context { + fn clone(&self) -> Self { + Self { + sender: self.sender.clone(), + cancellation_token: self.cancellation_token.clone(), + } + } +} + +impl Debug for Context { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Context").finish_non_exhaustive() + } +} + +impl Context { + pub fn from_ref(actor_ref: &ActorRef) -> Self { + Self { + sender: actor_ref.sender.clone(), + cancellation_token: actor_ref.cancellation_token.clone(), + } + } + + pub fn stop(&self) { + self.cancellation_token.cancel(); + } + + pub fn send(&self, msg: M) -> Result<(), ActorError> + where + A: Handler, + M: Message, + { + let envelope = MessageEnvelope { msg, tx: None }; + self.sender + .send(Box::new(envelope)) + .map_err(|_| ActorError::ActorStopped) + } + + pub fn request_raw(&self, msg: M) -> Result, ActorError> + where + A: Handler, + M: Message, + { + let (tx, rx) = oneshot::channel(); + let envelope = MessageEnvelope { + msg, + tx: Some(tx), + }; + self.sender + .send(Box::new(envelope)) + .map_err(|_| ActorError::ActorStopped)?; + Ok(rx) + } + + pub fn request(&self, msg: M) -> Result + where + A: Handler, + M: Message, + { + self.request_with_timeout(msg, DEFAULT_REQUEST_TIMEOUT) + } + + pub fn request_with_timeout( + &self, + msg: M, + duration: Duration, + ) -> Result + where + A: Handler, + M: Message, + { + let rx = self.request_raw(msg)?; + match rx.recv_timeout(duration) { + Ok(result) => Ok(result), + Err(RecvTimeoutError::Timeout) => Err(ActorError::RequestTimeout), + Err(RecvTimeoutError::Disconnected) => Err(ActorError::ActorStopped), + } + } + + pub fn recipient(&self) -> Recipient + where + A: Handler, + M: Message, + { + Arc::new(self.clone()) + } + + pub(crate) fn cancellation_token(&self) -> CancellationToken { + self.cancellation_token.clone() + } +} + +// Bridge: Context implements Receiver for any M that A handles +impl Receiver for Context +where + A: Actor + Handler, + M: Message, +{ + fn send(&self, msg: M) -> Result<(), ActorError> { + Context::send(self, msg) + } + + fn request_raw(&self, msg: M) -> Result, ActorError> { + Context::request_raw(self, msg) + } +} + +// --------------------------------------------------------------------------- +// Receiver trait (object-safe) + Recipient alias +// --------------------------------------------------------------------------- + +pub trait Receiver: Send + Sync { + fn send(&self, msg: M) -> Result<(), ActorError>; + fn request_raw(&self, msg: M) -> Result, ActorError>; +} + +pub type Recipient = Arc>; + +pub fn request( + recipient: &dyn Receiver, + msg: M, + timeout: Duration, +) -> Result { + let rx = recipient.request_raw(msg)?; + match rx.recv_timeout(timeout) { + Ok(result) => Ok(result), + Err(RecvTimeoutError::Timeout) => Err(ActorError::RequestTimeout), + Err(RecvTimeoutError::Disconnected) => Err(ActorError::ActorStopped), + } +} + +// --------------------------------------------------------------------------- +// ActorRef +// --------------------------------------------------------------------------- + struct CompletionGuard(Arc<(Mutex, Condvar)>); impl Drop for CompletionGuard { @@ -27,94 +211,93 @@ impl Drop for CompletionGuard { } } -pub struct ActorRef { - pub tx: mpsc::Sender>, +pub struct ActorRef { + sender: mpsc::Sender + Send>>, cancellation_token: CancellationToken, - /// Completion signal: (is_completed, condvar for waiters) completion: Arc<(Mutex, Condvar)>, } +impl Debug for ActorRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ActorRef").finish_non_exhaustive() + } +} + impl Clone for ActorRef { fn clone(&self) -> Self { Self { - tx: self.tx.clone(), + sender: self.sender.clone(), cancellation_token: self.cancellation_token.clone(), completion: self.completion.clone(), } } } -impl std::fmt::Debug for ActorRef { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ActorRef") - .field("tx", &self.tx) - .field("cancellation_token", &self.cancellation_token) - .finish_non_exhaustive() - } -} - impl ActorRef { - pub(crate) fn new(actor: A) -> Self { - let (tx, mut rx) = mpsc::channel::>(); - let cancellation_token = CancellationToken::new(); - let completion = Arc::new((Mutex::new(false), Condvar::new())); - let handle = ActorRef { - tx, - cancellation_token, - completion: completion.clone(), - }; - let handle_clone = handle.clone(); - let _thread_handle = rt::spawn(move || { - // Guard ensures completion is signaled even if actor panics - let _guard = CompletionGuard(completion); - if actor.run(&handle, &mut rx).is_err() { - tracing::trace!("Actor crashed") - }; - }); - handle_clone + pub fn send(&self, msg: M) -> Result<(), ActorError> + where + A: Handler, + M: Message, + { + let envelope = MessageEnvelope { msg, tx: None }; + self.sender + .send(Box::new(envelope)) + .map_err(|_| ActorError::ActorStopped) } - pub fn sender(&self) -> mpsc::Sender> { - self.tx.clone() + pub fn request_raw(&self, msg: M) -> Result, ActorError> + where + A: Handler, + M: Message, + { + let (tx, rx) = oneshot::channel(); + let envelope = MessageEnvelope { + msg, + tx: Some(tx), + }; + self.sender + .send(Box::new(envelope)) + .map_err(|_| ActorError::ActorStopped)?; + Ok(rx) } - pub fn request(&mut self, message: A::Request) -> Result { - self.request_with_timeout(message, DEFAULT_REQUEST_TIMEOUT) + pub fn request(&self, msg: M) -> Result + where + A: Handler, + M: Message, + { + self.request_with_timeout(msg, DEFAULT_REQUEST_TIMEOUT) } - pub fn request_with_timeout( - &mut self, - message: A::Request, + pub fn request_with_timeout( + &self, + msg: M, duration: Duration, - ) -> Result { - let (oneshot_tx, oneshot_rx) = oneshot::channel::>(); - self.tx.send(ActorInMsg::Request { - sender: oneshot_tx, - message, - })?; - match oneshot_rx.recv_timeout(duration) { - Ok(result) => result, + ) -> Result + where + A: Handler, + M: Message, + { + let rx = self.request_raw(msg)?; + match rx.recv_timeout(duration) { + Ok(result) => Ok(result), Err(RecvTimeoutError::Timeout) => Err(ActorError::RequestTimeout), - Err(RecvTimeoutError::Disconnected) => Err(ActorError::Server), + Err(RecvTimeoutError::Disconnected) => Err(ActorError::ActorStopped), } } - pub fn send(&mut self, message: A::Message) -> Result<(), ActorError> { - self.tx - .send(ActorInMsg::Message { message }) - .map_err(|_error| ActorError::Server) + pub fn recipient(&self) -> Recipient + where + A: Handler, + M: Message, + { + Arc::new(self.clone()) } - pub(crate) fn cancellation_token(&self) -> CancellationToken { - self.cancellation_token.clone() + pub fn context(&self) -> Context { + Context::from_ref(self) } - /// Blocks until the actor has stopped. - /// - /// This method blocks the current thread until the actor has finished - /// processing and exited its main loop. Can be called multiple times from - /// different clones of the ActorRef - all callers will be notified when - /// the actor stops. pub fn join(&self) { let (lock, cvar) = &*self.completion; let mut completed = lock.lock().unwrap_or_else(|p| p.into_inner()); @@ -124,191 +307,114 @@ impl ActorRef { } } -pub enum ActorInMsg { - Request { - sender: oneshot::Sender>, - message: A::Request, - }, - Message { - message: A::Message, - }, -} - -pub enum RequestResponse { - Reply(A::Reply), - Unused, - Stop(A::Reply), -} - -pub enum MessageResponse { - NoReply, - Unused, - Stop, -} +// Bridge: ActorRef implements Receiver for any M that A handles +impl Receiver for ActorRef +where + A: Actor + Handler, + M: Message, +{ + fn send(&self, msg: M) -> Result<(), ActorError> { + ActorRef::send(self, msg) + } -pub enum InitResult { - Success(A), - NoSuccess(A), + fn request_raw(&self, msg: M) -> Result, ActorError> { + ActorRef::request_raw(self, msg) + } } -pub trait Actor: Send + Sized { - type Request: Clone + Send + Sized + Sync; - type Message: Clone + Send + Sized + Sync; - type Reply: Send + Sized; - type Error: Debug + Send; +// --------------------------------------------------------------------------- +// Actor startup + main loop +// --------------------------------------------------------------------------- - fn start(self) -> ActorRef { - ActorRef::new(self) - } +impl ActorRef { + fn spawn(actor: A) -> Self { + let (tx, rx) = mpsc::channel:: + Send>>(); + let cancellation_token = CancellationToken::new(); + let completion = Arc::new((Mutex::new(false), Condvar::new())); - fn run( - self, - handle: &ActorRef, - rx: &mut mpsc::Receiver>, - ) -> Result<(), ActorError> { - let cancellation_token = handle.cancellation_token.clone(); - - let res = match self.init(handle) { - Ok(InitResult::Success(new_state)) => { - let final_state = new_state.main_loop(handle, rx)?; - Ok(final_state) - } - Ok(InitResult::NoSuccess(intermediate_state)) => { - // Initialization failed but error was handled in callback. - // Skip main_loop and return state for teardown. - Ok(intermediate_state) - } - Err(err) => { - tracing::error!("Initialization failed with unhandled error: {err:?}"); - Err(ActorError::Initialization) - } + let actor_ref = ActorRef { + sender: tx.clone(), + cancellation_token: cancellation_token.clone(), + completion: completion.clone(), }; - cancellation_token.cancel(); + let ctx = Context { + sender: tx, + cancellation_token: cancellation_token.clone(), + }; - if let Ok(final_state) = res { - if let Err(err) = final_state.teardown(handle) { - tracing::error!("Error during teardown: {err:?}"); - } - } + let _thread_handle = rt::spawn(move || { + let _guard = CompletionGuard(completion); + run_actor(actor, ctx, rx, cancellation_token); + }); - Ok(()) + actor_ref } +} - /// Initialization function. It's called before main loop. It - /// can be overrided on implementations in case initial steps are - /// required. - fn init(self, _handle: &ActorRef) -> Result, Self::Error> { - Ok(InitResult::Success(self)) - } +fn run_actor( + mut actor: A, + ctx: Context, + rx: mpsc::Receiver + Send>>, + cancellation_token: CancellationToken, +) { + actor.started(&ctx); - fn main_loop( - mut self, - handle: &ActorRef, - rx: &mut mpsc::Receiver>, - ) -> Result { - loop { - if !self.receive(handle, rx)? { - break; - } - } - tracing::trace!("Stopping Actor"); - Ok(self) + if cancellation_token.is_cancelled() { + actor.stopped(&ctx); + return; } - fn receive( - &mut self, - handle: &ActorRef, - rx: &mut mpsc::Receiver>, - ) -> Result { - let message = rx.recv().ok(); - - let keep_running = match message { - Some(ActorInMsg::Request { sender, message }) => { - let (keep_running, response) = match catch_unwind(AssertUnwindSafe(|| { - self.handle_request(message, handle) - })) { - Ok(response) => match response { - RequestResponse::Reply(response) => (true, Ok(response)), - RequestResponse::Stop(response) => (false, Ok(response)), - RequestResponse::Unused => { - tracing::error!("Actor received unexpected Request"); - (false, Err(ActorError::RequestUnused)) - } - }, - Err(error) => { - tracing::error!("Error in callback: '{error:?}'"); - (true, Err(ActorError::Callback)) - } - }; - // Send response back - if sender.send(response).is_err() { - tracing::trace!("Actor failed to send response back, client must have died") - }; - keep_running - } - Some(ActorInMsg::Message { message }) => { - match catch_unwind(AssertUnwindSafe(|| self.handle_message(message, handle))) { - Ok(response) => match response { - MessageResponse::NoReply => true, - MessageResponse::Stop => false, - MessageResponse::Unused => { - tracing::error!("Actor received unexpected Message"); - false - } - }, - Err(error) => { - tracing::error!("Error in callback: '{error:?}'"); - true - } + loop { + let msg = rx.recv().ok(); + match msg { + Some(envelope) => { + let result = catch_unwind(AssertUnwindSafe(|| { + envelope.handle(&mut actor, &ctx); + })); + if let Err(panic) = result { + tracing::error!("Panic in message handler: {panic:?}"); + break; + } + if cancellation_token.is_cancelled() { + break; } } - None => { - // Channel has been closed; won't receive further messages. Stop the server. - false - } - }; - Ok(keep_running) + None => break, + } } - fn handle_request( - &mut self, - _message: Self::Request, - _handle: &ActorRef, - ) -> RequestResponse { - RequestResponse::Unused - } + cancellation_token.cancel(); + actor.stopped(&ctx); +} - fn handle_message( - &mut self, - _message: Self::Message, - _handle: &ActorRef, - ) -> MessageResponse { - MessageResponse::Unused - } +// --------------------------------------------------------------------------- +// Actor::start +// --------------------------------------------------------------------------- - /// Teardown function. It's called after the stop message is received. - /// It can be overrided on implementations in case final steps are required, - /// like closing streams, stopping timers, etc. - fn teardown(self, _handle: &ActorRef) -> Result<(), Self::Error> { - Ok(()) +pub trait ActorStart: Actor { + fn start(self) -> ActorRef { + ActorRef::spawn(self) } } -/// Spawns a thread that runs a blocking operation and sends a message to an Actor -/// on completion. This is the sync equivalent of tasks::send_message_on. -/// This function returns a handle to the spawned thread. -pub fn send_message_on(handle: ActorRef, f: F, message: T::Message) -> rt::JoinHandle<()> +impl ActorStart for A {} + +// --------------------------------------------------------------------------- +// send_message_on (utility) +// --------------------------------------------------------------------------- + +pub fn send_message_on(ctx: Context, f: F, msg: M) -> rt::JoinHandle<()> where - T: Actor, + A: Actor + Handler, + M: Message, F: FnOnce() + Send + 'static, { - let cancellation_token = handle.cancellation_token(); - let mut handle_clone = handle.clone(); + let cancellation_token = ctx.cancellation_token(); rt::spawn(move || { f(); if !cancellation_token.is_cancelled() { - if let Err(e) = handle_clone.send(message) { + if let Err(e) = ctx.send(msg) { tracing::error!("Failed to send message: {e:?}") } } diff --git a/concurrency/src/threads/mod.rs b/concurrency/src/threads/mod.rs index 9643a13..0e6758c 100644 --- a/concurrency/src/threads/mod.rs +++ b/concurrency/src/threads/mod.rs @@ -1,7 +1,4 @@ -//! spawned concurrency -//! IO threads-based traits and structs to implement concurrent code à-la-Erlang. - -mod actor; +pub(crate) mod actor; mod process; mod stream; mod time; @@ -10,9 +7,11 @@ mod time; mod timer_tests; pub use actor::{ - send_message_on, Actor, ActorInMsg, ActorRef, InitResult, InitResult::NoSuccess, - InitResult::Success, MessageResponse, RequestResponse, + send_message_on, Actor, ActorRef, ActorStart, Context, Handler, Receiver, Recipient, + request, }; pub use process::{send, Process, ProcessInfo}; pub use stream::spawn_listener; -pub use time::{send_after, send_interval}; +pub use time::{send_after, send_interval, TimerHandle}; + +pub use crate::registry; diff --git a/concurrency/src/threads/stream.rs b/concurrency/src/threads/stream.rs index 696c3cf..9249246 100644 --- a/concurrency/src/threads/stream.rs +++ b/concurrency/src/threads/stream.rs @@ -1,21 +1,21 @@ use std::thread::JoinHandle; -use crate::threads::{Actor, ActorRef}; +use crate::message::Message; -/// Spawns a listener that listens to a stream and sends messages to an Actor. -/// -/// Items sent through the stream are required to be wrapped in a Result type. -pub fn spawn_listener(mut handle: ActorRef, stream: I) -> JoinHandle<()> +use super::actor::{Actor, Context, Handler}; + +pub fn spawn_listener(ctx: Context, stream: I) -> JoinHandle<()> where - T: Actor, - I: IntoIterator, + A: Actor + Handler, + M: Message, + I: IntoIterator, ::IntoIter: std::marker::Send + 'static, { let mut iter = stream.into_iter(); - let cancellation_token = handle.cancellation_token(); + let cancellation_token = ctx.cancellation_token(); let join_handle = spawned_rt::threads::spawn(move || loop { match iter.next() { - Some(msg) => match handle.send(msg) { + Some(msg) => match ctx.send(msg) { Ok(_) => tracing::trace!("Message sent successfully"), Err(e) => { tracing::error!("Failed to send message: {e:?}"); diff --git a/concurrency/src/threads/time.rs b/concurrency/src/threads/time.rs index 5b4ebb8..78fb1cd 100644 --- a/concurrency/src/threads/time.rs +++ b/concurrency/src/threads/time.rs @@ -3,36 +3,30 @@ use std::time::Duration; use spawned_rt::threads::{self as rt, CancellationToken, JoinHandle}; -use super::{Actor, ActorRef}; +use super::actor::{Actor, Context, Handler}; +use crate::message::Message; pub struct TimerHandle { pub join_handle: JoinHandle<()>, pub cancellation_token: CancellationToken, } -/// Sends a message after a given period to the specified Actor. -/// -/// The timer respects both its own cancellation token and the Actor's -/// cancellation token. If either is cancelled, the timer wakes up immediately -/// and exits without sending the message. -pub fn send_after(period: Duration, mut handle: ActorRef, message: T::Message) -> TimerHandle +pub fn send_after(period: Duration, ctx: Context, msg: M) -> TimerHandle where - T: Actor + 'static, + A: Actor + Handler, + M: Message, { let cancellation_token = CancellationToken::new(); let timer_token = cancellation_token.clone(); - let actor_token = handle.cancellation_token(); + let actor_token = ctx.cancellation_token(); - // Channel to wake the timer thread on cancellation let (wake_tx, wake_rx) = mpsc::channel::<()>(); - // Register wake-up on timer cancellation let wake_tx1 = wake_tx.clone(); timer_token.on_cancel(Box::new(move || { let _ = wake_tx1.send(()); })); - // Register wake-up on actor cancellation actor_token.on_cancel(Box::new(move || { let _ = wake_tx.send(()); })); @@ -40,14 +34,11 @@ where let join_handle = rt::spawn(move || { match wake_rx.recv_timeout(period) { Err(RecvTimeoutError::Timeout) => { - // Timer expired - send if still valid if !timer_token.is_cancelled() && !actor_token.is_cancelled() { - let _ = handle.send(message); + let _ = ctx.send(msg); } } - Ok(()) | Err(RecvTimeoutError::Disconnected) => { - // Woken early by cancellation - exit without sending - } + Ok(()) | Err(RecvTimeoutError::Disconnected) => {} } }); @@ -57,46 +48,33 @@ where } } -/// Sends a message to the specified Actor repeatedly at the given interval. -/// -/// The timer respects both its own cancellation token and the Actor's -/// cancellation token. If either is cancelled, the timer wakes up immediately -/// and exits. -pub fn send_interval( - period: Duration, - mut handle: ActorRef, - message: T::Message, -) -> TimerHandle +pub fn send_interval(period: Duration, ctx: Context, msg: M) -> TimerHandle where - T: Actor + 'static, + A: Actor + Handler, + M: Message + Clone, { let cancellation_token = CancellationToken::new(); let timer_token = cancellation_token.clone(); - let actor_token = handle.cancellation_token(); + let actor_token = ctx.cancellation_token(); - // Channel to wake the timer thread on cancellation let (wake_tx, wake_rx) = mpsc::channel::<()>(); - // Register wake-up on timer cancellation let wake_tx1 = wake_tx.clone(); timer_token.on_cancel(Box::new(move || { let _ = wake_tx1.send(()); })); - // Register wake-up on actor cancellation actor_token.on_cancel(Box::new(move || { let _ = wake_tx.send(()); })); let join_handle = rt::spawn(move || { while let Err(RecvTimeoutError::Timeout) = wake_rx.recv_timeout(period) { - // Timer expired - send if still valid if timer_token.is_cancelled() || actor_token.is_cancelled() { break; } - let _ = handle.send(message.clone()); + let _ = ctx.send(msg.clone()); } - // If we exit the loop via Ok(()) or Disconnected, cancellation occurred }); TimerHandle { diff --git a/concurrency/src/threads/timer_tests.rs b/concurrency/src/threads/timer_tests.rs index e023a78..0339d02 100644 --- a/concurrency/src/threads/timer_tests.rs +++ b/concurrency/src/threads/timer_tests.rs @@ -1,33 +1,27 @@ use crate::threads::{ - send_interval, Actor, ActorRef, InitResult, MessageResponse, RequestResponse, + send_after, send_interval, Actor, ActorStart, Context, Handler, }; +use crate::message::Message; use spawned_rt::threads::{self as rt, CancellationToken}; use std::time::Duration; -use super::send_after; +// --- Repeater (interval timer test) --- -type RepeaterHandle = ActorRef; +#[derive(Clone, Debug)] +struct Inc; +impl Message for Inc { type Result = (); } -#[derive(Clone)] -enum RepeaterCastMessage { - Inc, - StopTimer, -} - -#[derive(Clone)] -enum RepeaterCallMessage { - GetCount, -} +#[derive(Clone, Debug)] +struct StopTimer; +impl Message for StopTimer { type Result = (); } -#[derive(PartialEq, Debug)] -enum RepeaterOutMessage { - Count(i32), -} +#[derive(Debug)] +struct GetRepCount; +impl Message for GetRepCount { type Result = i32; } -#[derive(Clone)] struct Repeater { - pub(crate) count: i32, - pub(crate) cancellation_token: Option, + count: i32, + cancellation_token: Option, } impl Repeater { @@ -39,240 +33,136 @@ impl Repeater { } } -impl Repeater { - pub fn stop_timer(server: &mut RepeaterHandle) -> Result<(), ()> { - server.send(RepeaterCastMessage::StopTimer).map_err(|_| ()) - } - - pub fn get_count(server: &mut RepeaterHandle) -> Result { - server - .request(RepeaterCallMessage::GetCount) - .map_err(|_| ()) - } -} - impl Actor for Repeater { - type Request = RepeaterCallMessage; - type Message = RepeaterCastMessage; - type Reply = RepeaterOutMessage; - type Error = (); - - fn init(mut self, handle: &RepeaterHandle) -> Result, Self::Error> { + fn started(&mut self, ctx: &Context) { let timer = send_interval( Duration::from_millis(100), - handle.clone(), - RepeaterCastMessage::Inc, + ctx.clone(), + Inc, ); self.cancellation_token = Some(timer.cancellation_token); - Ok(InitResult::Success(self)) } +} + +impl Handler for Repeater { + fn handle(&mut self, _msg: Inc, _ctx: &Context) { + self.count += 1; + } +} - fn handle_request( - &mut self, - _message: Self::Request, - _handle: &RepeaterHandle, - ) -> RequestResponse { - let count = self.count; - RequestResponse::Reply(RepeaterOutMessage::Count(count)) +impl Handler for Repeater { + fn handle(&mut self, _msg: StopTimer, _ctx: &Context) { + if let Some(ct) = self.cancellation_token.clone() { + ct.cancel(); + } } +} - fn handle_message( - &mut self, - message: Self::Message, - _handle: &ActorRef, - ) -> MessageResponse { - match message { - RepeaterCastMessage::Inc => { - self.count += 1; - } - RepeaterCastMessage::StopTimer => { - if let Some(ct) = self.cancellation_token.clone() { - ct.cancel() - }; - } - }; - MessageResponse::NoReply +impl Handler for Repeater { + fn handle(&mut self, _msg: GetRepCount, _ctx: &Context) -> i32 { + self.count } } #[test] pub fn test_send_interval_and_cancellation() { - // Start a Repeater - let mut repeater = Repeater::new(0).start(); + let repeater = Repeater::new(0).start(); - // Wait for 1 second rt::sleep(Duration::from_secs(1)); - // Check count - let count = Repeater::get_count(&mut repeater).unwrap(); - - // 9 messages in 1 second (after first 100 milliseconds sleep) - assert_eq!(RepeaterOutMessage::Count(9), count); + let count = repeater.request(GetRepCount).unwrap(); + assert_eq!(9, count); - // Pause timer - Repeater::stop_timer(&mut repeater).unwrap(); + repeater.send(StopTimer).unwrap(); - // Wait another second rt::sleep(Duration::from_secs(1)); - // Check count again - let count2 = Repeater::get_count(&mut repeater).unwrap(); - - // As timer was paused, count should remain at 9 - assert_eq!(RepeaterOutMessage::Count(9), count2); + let count2 = repeater.request(GetRepCount).unwrap(); + assert_eq!(9, count2); } -type DelayedHandle = ActorRef; - -#[derive(Clone)] -enum DelayedCastMessage { - Inc, -} +// --- Delayed (send_after test) --- -#[derive(Clone)] -enum DelayedCallMessage { - GetCount, - Stop, -} +#[derive(Debug)] +struct GetDelCount; +impl Message for GetDelCount { type Result = i32; } -#[derive(PartialEq, Debug)] -enum DelayedOutMessage { - Count(i32), -} +#[derive(Debug)] +struct StopDelayed; +impl Message for StopDelayed { type Result = i32; } -#[derive(Clone)] struct Delayed { - pub(crate) count: i32, + count: i32, } impl Delayed { pub fn new(initial_count: i32) -> Self { - Delayed { - count: initial_count, - } + Delayed { count: initial_count } } } -impl Delayed { - pub fn get_count(server: &mut DelayedHandle) -> Result { - server.request(DelayedCallMessage::GetCount).map_err(|_| ()) - } +impl Actor for Delayed {} - pub fn stop(server: &mut DelayedHandle) -> Result { - server.request(DelayedCallMessage::Stop).map_err(|_| ()) +impl Handler for Delayed { + fn handle(&mut self, _msg: Inc, _ctx: &Context) { + self.count += 1; } } -impl Actor for Delayed { - type Request = DelayedCallMessage; - type Message = DelayedCastMessage; - type Reply = DelayedOutMessage; - type Error = (); - - fn handle_request( - &mut self, - message: Self::Request, - _handle: &DelayedHandle, - ) -> RequestResponse { - match message { - DelayedCallMessage::GetCount => { - RequestResponse::Reply(DelayedOutMessage::Count(self.count)) - } - DelayedCallMessage::Stop => { - RequestResponse::Stop(DelayedOutMessage::Count(self.count)) - } - } +impl Handler for Delayed { + fn handle(&mut self, _msg: GetDelCount, _ctx: &Context) -> i32 { + self.count } +} - fn handle_message( - &mut self, - message: Self::Message, - _handle: &DelayedHandle, - ) -> MessageResponse { - match message { - DelayedCastMessage::Inc => { - self.count += 1; - } - }; - MessageResponse::NoReply +impl Handler for Delayed { + fn handle(&mut self, _msg: StopDelayed, ctx: &Context) -> i32 { + ctx.stop(); + self.count } } #[test] pub fn test_send_after_and_cancellation() { - // Start a Delayed - let mut repeater = Delayed::new(0).start(); + let actor = Delayed::new(0).start(); - // Set a just once timed message - let _ = send_after( - Duration::from_millis(100), - repeater.clone(), - DelayedCastMessage::Inc, - ); + let ctx = Context::from_ref(&actor); + let _ = send_after(Duration::from_millis(100), ctx, Inc); - // Wait for 200 milliseconds rt::sleep(Duration::from_millis(200)); - // Check count - let count = Delayed::get_count(&mut repeater).unwrap(); - - // Only one message (no repetition) - assert_eq!(DelayedOutMessage::Count(1), count); + let count = actor.request(GetDelCount).unwrap(); + assert_eq!(1, count); - // New timer - let timer = send_after( - Duration::from_millis(100), - repeater.clone(), - DelayedCastMessage::Inc, - ); + let ctx = Context::from_ref(&actor); + let timer = send_after(Duration::from_millis(100), ctx, Inc); - // Cancel the new timer before timeout timer.cancellation_token.cancel(); - // Wait another 200 milliseconds rt::sleep(Duration::from_millis(200)); - // Check count again - let count2 = Delayed::get_count(&mut repeater).unwrap(); - - // As timer was cancelled, count should remain at 1 - assert_eq!(DelayedOutMessage::Count(1), count2); + let count2 = actor.request(GetDelCount).unwrap(); + assert_eq!(1, count2); } #[test] pub fn test_send_after_actor_shutdown() { - // Start a Delayed - let mut actor = Delayed::new(0).start(); + let actor = Delayed::new(0).start(); - // Set a just once timed message - let _ = send_after( - Duration::from_millis(100), - actor.clone(), - DelayedCastMessage::Inc, - ); + let ctx = Context::from_ref(&actor); + let _ = send_after(Duration::from_millis(100), ctx, Inc); - // Wait for 200 milliseconds rt::sleep(Duration::from_millis(200)); - // Check count - let count = Delayed::get_count(&mut actor).unwrap(); - - // Only one message (no repetition) - assert_eq!(DelayedOutMessage::Count(1), count); + let count = actor.request(GetDelCount).unwrap(); + assert_eq!(1, count); - // New timer with long delay - let _ = send_after( - Duration::from_millis(100), - actor.clone(), - DelayedCastMessage::Inc, - ); + let ctx = Context::from_ref(&actor); + let _ = send_after(Duration::from_millis(100), ctx, Inc); - // Stop the Actor before timeout - this should wake up the timer immediately - let count2 = Delayed::stop(&mut actor).unwrap(); + let count2 = actor.request(StopDelayed).unwrap(); - // Wait another 200 milliseconds rt::sleep(Duration::from_millis(200)); - // As actor was stopped, count should remain at 1 (timer didn't fire) - assert_eq!(DelayedOutMessage::Count(1), count2); + assert_eq!(1, count2); } diff --git a/examples/bank/src/main.rs b/examples/bank/src/main.rs index d3321af..a71d158 100644 --- a/examples/bank/src/main.rs +++ b/examples/bank/src/main.rs @@ -24,16 +24,16 @@ mod server; use messages::{BankError, BankOutMessage}; use server::Bank; -use spawned_concurrency::tasks::Actor as _; +use spawned_concurrency::tasks::ActorStart; use spawned_rt::tasks as rt; fn main() { rt::run(async { // Starting the bank - let mut name_server = Bank::new().start(); + let name_server = Bank::new().start(); // Testing initial balance for "main" account - let result = Bank::withdraw(&mut name_server, "main".to_string(), 15).await; + let result = Bank::withdraw(&name_server, "main".to_string(), 15).await; tracing::info!("Withdraw result {result:?}"); assert_eq!( result, @@ -46,17 +46,17 @@ fn main() { let joe = "Joe".to_string(); // Error on deposit for an unexistent account - let result = Bank::deposit(&mut name_server, joe.clone(), 10).await; + let result = Bank::deposit(&name_server, joe.clone(), 10).await; tracing::info!("Deposit result {result:?}"); assert_eq!(result, Err(BankError::NotACustomer { who: joe.clone() })); // Account creation - let result = Bank::new_account(&mut name_server, "Joe".to_string()).await; + let result = Bank::new_account(&name_server, "Joe".to_string()).await; tracing::info!("New account result {result:?}"); assert_eq!(result, Ok(BankOutMessage::Welcome { who: joe.clone() })); // Deposit - let result = Bank::deposit(&mut name_server, "Joe".to_string(), 10).await; + let result = Bank::deposit(&name_server, "Joe".to_string(), 10).await; tracing::info!("Deposit result {result:?}"); assert_eq!( result, @@ -67,7 +67,7 @@ fn main() { ); // Deposit - let result = Bank::deposit(&mut name_server, "Joe".to_string(), 30).await; + let result = Bank::deposit(&name_server, "Joe".to_string(), 30).await; tracing::info!("Deposit result {result:?}"); assert_eq!( result, @@ -78,7 +78,7 @@ fn main() { ); // Withdrawal - let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 15).await; + let result = Bank::withdraw(&name_server, "Joe".to_string(), 15).await; tracing::info!("Withdraw result {result:?}"); assert_eq!( result, @@ -89,7 +89,7 @@ fn main() { ); // Withdrawal with not enough balance - let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 45).await; + let result = Bank::withdraw(&name_server, "Joe".to_string(), 45).await; tracing::info!("Withdraw result {result:?}"); assert_eq!( result, @@ -100,7 +100,7 @@ fn main() { ); // Full withdrawal - let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 25).await; + let result = Bank::withdraw(&name_server, "Joe".to_string(), 25).await; tracing::info!("Withdraw result {result:?}"); assert_eq!( result, @@ -111,7 +111,7 @@ fn main() { ); // Stopping the bank - let result = Bank::stop(&mut name_server).await; + let result = Bank::stop(&name_server).await; tracing::info!("Stop result {result:?}"); assert_eq!(result, Ok(BankOutMessage::Stopped)); }) diff --git a/examples/bank/src/messages.rs b/examples/bank/src/messages.rs index d58ae9d..bbee592 100644 --- a/examples/bank/src/messages.rs +++ b/examples/bank/src/messages.rs @@ -1,12 +1,5 @@ -#[derive(Debug, Clone)] -pub enum BankInMessage { - New { who: String }, - Add { who: String, amount: i32 }, - Remove { who: String, amount: i32 }, - Stop, -} +use spawned_concurrency::message::Message; -#[allow(dead_code)] #[derive(Debug, Clone, PartialEq)] pub enum BankOutMessage { Welcome { who: String }, @@ -15,7 +8,6 @@ pub enum BankOutMessage { Stopped, } -#[allow(dead_code)] #[derive(Debug, Clone, PartialEq)] pub enum BankError { AlreadyACustomer { who: String }, @@ -23,3 +15,37 @@ pub enum BankError { InsufficientBalance { who: String, amount: i32 }, ServerError, } + +type MsgResult = Result; + +#[derive(Debug)] +pub struct NewAccount { + pub who: String, +} +impl Message for NewAccount { + type Result = MsgResult; +} + +#[derive(Debug)] +pub struct Deposit { + pub who: String, + pub amount: i32, +} +impl Message for Deposit { + type Result = MsgResult; +} + +#[derive(Debug)] +pub struct Withdraw { + pub who: String, + pub amount: i32, +} +impl Message for Withdraw { + type Result = MsgResult; +} + +#[derive(Debug)] +pub struct Stop; +impl Message for Stop { + type Result = MsgResult; +} diff --git a/examples/bank/src/server.rs b/examples/bank/src/server.rs index bd2bfed..e56cc7b 100644 --- a/examples/bank/src/server.rs +++ b/examples/bank/src/server.rs @@ -1,18 +1,10 @@ use std::collections::HashMap; -use spawned_concurrency::{ - messages::Unused, - tasks::{ - Actor, ActorRef, - InitResult::{self, Success}, - RequestResponse, - }, -}; +use spawned_concurrency::tasks::{Actor, ActorRef, Context, Handler}; -use crate::messages::{BankError, BankInMessage as InMessage, BankOutMessage as OutMessage}; +use crate::messages::{BankError, BankOutMessage as OutMessage, Deposit, NewAccount, Stop, Withdraw}; type MsgResult = Result; -type BankHandle = ActorRef; pub struct Bank { accounts: HashMap, @@ -27,89 +19,95 @@ impl Bank { } impl Bank { - pub async fn stop(server: &mut BankHandle) -> MsgResult { + pub async fn stop(server: &ActorRef) -> MsgResult { server - .request(InMessage::Stop) + .request(Stop) .await .unwrap_or(Err(BankError::ServerError)) } - pub async fn new_account(server: &mut BankHandle, who: String) -> MsgResult { + pub async fn new_account(server: &ActorRef, who: String) -> MsgResult { server - .request(InMessage::New { who }) + .request(NewAccount { who }) .await .unwrap_or(Err(BankError::ServerError)) } - pub async fn deposit(server: &mut BankHandle, who: String, amount: i32) -> MsgResult { + pub async fn deposit(server: &ActorRef, who: String, amount: i32) -> MsgResult { server - .request(InMessage::Add { who, amount }) + .request(Deposit { who, amount }) .await .unwrap_or(Err(BankError::ServerError)) } - pub async fn withdraw(server: &mut BankHandle, who: String, amount: i32) -> MsgResult { + pub async fn withdraw(server: &ActorRef, who: String, amount: i32) -> MsgResult { server - .request(InMessage::Remove { who, amount }) + .request(Withdraw { who, amount }) .await .unwrap_or(Err(BankError::ServerError)) } } impl Actor for Bank { - type Request = InMessage; - type Message = Unused; - type Reply = MsgResult; - type Error = BankError; - - // Initializing "main" account with 1000 in balance to test init() callback. - async fn init(mut self, _handle: &ActorRef) -> Result, Self::Error> { + async fn started(&mut self, _ctx: &Context) { self.accounts.insert("main".to_string(), 1000); - Ok(Success(self)) } +} - async fn handle_request( - &mut self, - message: Self::Request, - _handle: &BankHandle, - ) -> RequestResponse { - match message.clone() { - Self::Request::New { who } => match self.accounts.get(&who) { - Some(_amount) => RequestResponse::Reply(Err(BankError::AlreadyACustomer { who })), - None => { - self.accounts.insert(who.clone(), 0); - RequestResponse::Reply(Ok(OutMessage::Welcome { who })) - } - }, - Self::Request::Add { who, amount } => match self.accounts.get(&who) { - Some(current) => { - let new_amount = current + amount; - self.accounts.insert(who.clone(), new_amount); - RequestResponse::Reply(Ok(OutMessage::Balance { - who, +impl Handler for Bank { + async fn handle(&mut self, msg: NewAccount, _ctx: &Context) -> MsgResult { + match self.accounts.get(&msg.who) { + Some(_) => Err(BankError::AlreadyACustomer { who: msg.who }), + None => { + self.accounts.insert(msg.who.clone(), 0); + Ok(OutMessage::Welcome { who: msg.who }) + } + } + } +} + +impl Handler for Bank { + async fn handle(&mut self, msg: Deposit, _ctx: &Context) -> MsgResult { + match self.accounts.get(&msg.who) { + Some(current) => { + let new_amount = current + msg.amount; + self.accounts.insert(msg.who.clone(), new_amount); + Ok(OutMessage::Balance { + who: msg.who, + amount: new_amount, + }) + } + None => Err(BankError::NotACustomer { who: msg.who }), + } + } +} + +impl Handler for Bank { + async fn handle(&mut self, msg: Withdraw, _ctx: &Context) -> MsgResult { + match self.accounts.get(&msg.who) { + Some(¤t) => { + if current < msg.amount { + Err(BankError::InsufficientBalance { + who: msg.who, + amount: current, + }) + } else { + let new_amount = current - msg.amount; + self.accounts.insert(msg.who.clone(), new_amount); + Ok(OutMessage::WidrawOk { + who: msg.who, amount: new_amount, - })) + }) } - None => RequestResponse::Reply(Err(BankError::NotACustomer { who })), - }, - Self::Request::Remove { who, amount } => match self.accounts.get(&who) { - Some(¤t) => match current < amount { - true => RequestResponse::Reply(Err(BankError::InsufficientBalance { - who, - amount: current, - })), - false => { - let new_amount = current - amount; - self.accounts.insert(who.clone(), new_amount); - RequestResponse::Reply(Ok(OutMessage::WidrawOk { - who, - amount: new_amount, - })) - } - }, - None => RequestResponse::Reply(Err(BankError::NotACustomer { who })), - }, - Self::Request::Stop => RequestResponse::Stop(Ok(OutMessage::Stopped)), + } + None => Err(BankError::NotACustomer { who: msg.who }), } } } + +impl Handler for Bank { + async fn handle(&mut self, _msg: Stop, ctx: &Context) -> MsgResult { + ctx.stop(); + Ok(OutMessage::Stopped) + } +} diff --git a/examples/bank_threads/src/main.rs b/examples/bank_threads/src/main.rs index 9b89c54..aa67b4b 100644 --- a/examples/bank_threads/src/main.rs +++ b/examples/bank_threads/src/main.rs @@ -24,16 +24,16 @@ mod server; use messages::{BankError, BankOutMessage}; use server::Bank; -use spawned_concurrency::threads::Actor as _; +use spawned_concurrency::threads::ActorStart; use spawned_rt::threads as rt; fn main() { rt::run(|| { // Starting the bank - let mut name_server = Bank::new().start(); + let name_server = Bank::new().start(); // Testing initial balance for "main" account - let result = Bank::withdraw(&mut name_server, "main".to_string(), 15); + let result = Bank::withdraw(&name_server, "main".to_string(), 15); tracing::info!("Withdraw result {result:?}"); assert_eq!( result, @@ -46,17 +46,17 @@ fn main() { let joe = "Joe".to_string(); // Error on deposit for an unexistent account - let result = Bank::deposit(&mut name_server, joe.clone(), 10); + let result = Bank::deposit(&name_server, joe.clone(), 10); tracing::info!("Deposit result {result:?}"); assert_eq!(result, Err(BankError::NotACustomer { who: joe.clone() })); // Account creation - let result = Bank::new_account(&mut name_server, "Joe".to_string()); + let result = Bank::new_account(&name_server, "Joe".to_string()); tracing::info!("New account result {result:?}"); assert_eq!(result, Ok(BankOutMessage::Welcome { who: joe.clone() })); // Deposit - let result = Bank::deposit(&mut name_server, "Joe".to_string(), 10); + let result = Bank::deposit(&name_server, "Joe".to_string(), 10); tracing::info!("Deposit result {result:?}"); assert_eq!( result, @@ -67,7 +67,7 @@ fn main() { ); // Deposit - let result = Bank::deposit(&mut name_server, "Joe".to_string(), 30); + let result = Bank::deposit(&name_server, "Joe".to_string(), 30); tracing::info!("Deposit result {result:?}"); assert_eq!( result, @@ -78,7 +78,7 @@ fn main() { ); // Withdrawal - let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 15); + let result = Bank::withdraw(&name_server, "Joe".to_string(), 15); tracing::info!("Withdraw result {result:?}"); assert_eq!( result, @@ -89,7 +89,7 @@ fn main() { ); // Withdrawal with not enough balance - let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 45); + let result = Bank::withdraw(&name_server, "Joe".to_string(), 45); tracing::info!("Withdraw result {result:?}"); assert_eq!( result, @@ -100,7 +100,7 @@ fn main() { ); // Full withdrawal - let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 25); + let result = Bank::withdraw(&name_server, "Joe".to_string(), 25); tracing::info!("Withdraw result {result:?}"); assert_eq!( result, @@ -111,7 +111,7 @@ fn main() { ); // Stopping the bank - let result = Bank::stop(&mut name_server); + let result = Bank::stop(&name_server); tracing::info!("Stop result {result:?}"); assert_eq!(result, Ok(BankOutMessage::Stopped)); }) diff --git a/examples/bank_threads/src/messages.rs b/examples/bank_threads/src/messages.rs index d58ae9d..bbee592 100644 --- a/examples/bank_threads/src/messages.rs +++ b/examples/bank_threads/src/messages.rs @@ -1,12 +1,5 @@ -#[derive(Debug, Clone)] -pub enum BankInMessage { - New { who: String }, - Add { who: String, amount: i32 }, - Remove { who: String, amount: i32 }, - Stop, -} +use spawned_concurrency::message::Message; -#[allow(dead_code)] #[derive(Debug, Clone, PartialEq)] pub enum BankOutMessage { Welcome { who: String }, @@ -15,7 +8,6 @@ pub enum BankOutMessage { Stopped, } -#[allow(dead_code)] #[derive(Debug, Clone, PartialEq)] pub enum BankError { AlreadyACustomer { who: String }, @@ -23,3 +15,37 @@ pub enum BankError { InsufficientBalance { who: String, amount: i32 }, ServerError, } + +type MsgResult = Result; + +#[derive(Debug)] +pub struct NewAccount { + pub who: String, +} +impl Message for NewAccount { + type Result = MsgResult; +} + +#[derive(Debug)] +pub struct Deposit { + pub who: String, + pub amount: i32, +} +impl Message for Deposit { + type Result = MsgResult; +} + +#[derive(Debug)] +pub struct Withdraw { + pub who: String, + pub amount: i32, +} +impl Message for Withdraw { + type Result = MsgResult; +} + +#[derive(Debug)] +pub struct Stop; +impl Message for Stop { + type Result = MsgResult; +} diff --git a/examples/bank_threads/src/server.rs b/examples/bank_threads/src/server.rs index 5edf5f7..c36f884 100644 --- a/examples/bank_threads/src/server.rs +++ b/examples/bank_threads/src/server.rs @@ -1,16 +1,11 @@ use std::collections::HashMap; -use spawned_concurrency::{ - messages::Unused, - threads::{Actor, ActorRef, InitResult, RequestResponse}, -}; +use spawned_concurrency::threads::{Actor, ActorRef, Context, Handler}; -use crate::messages::{BankError, BankInMessage as InMessage, BankOutMessage as OutMessage}; +use crate::messages::{BankError, BankOutMessage as OutMessage, Deposit, NewAccount, Stop, Withdraw}; type MsgResult = Result; -type BankHandle = ActorRef; -#[derive(Clone)] pub struct Bank { accounts: HashMap, } @@ -24,85 +19,91 @@ impl Bank { } impl Bank { - pub fn stop(server: &mut BankHandle) -> MsgResult { + pub fn stop(server: &ActorRef) -> MsgResult { server - .request(InMessage::Stop) + .request(Stop) .unwrap_or(Err(BankError::ServerError)) } - pub fn new_account(server: &mut BankHandle, who: String) -> MsgResult { + pub fn new_account(server: &ActorRef, who: String) -> MsgResult { server - .request(InMessage::New { who }) + .request(NewAccount { who }) .unwrap_or(Err(BankError::ServerError)) } - pub fn deposit(server: &mut BankHandle, who: String, amount: i32) -> MsgResult { + pub fn deposit(server: &ActorRef, who: String, amount: i32) -> MsgResult { server - .request(InMessage::Add { who, amount }) + .request(Deposit { who, amount }) .unwrap_or(Err(BankError::ServerError)) } - pub fn withdraw(server: &mut BankHandle, who: String, amount: i32) -> MsgResult { + pub fn withdraw(server: &ActorRef, who: String, amount: i32) -> MsgResult { server - .request(InMessage::Remove { who, amount }) + .request(Withdraw { who, amount }) .unwrap_or(Err(BankError::ServerError)) } } impl Actor for Bank { - type Request = InMessage; - type Message = Unused; - type Reply = MsgResult; - type Error = BankError; - - // Initializing "main" account with 1000 in balance to test init() callback. - fn init(mut self, _handle: &ActorRef) -> Result, Self::Error> { + fn started(&mut self, _ctx: &Context) { self.accounts.insert("main".to_string(), 1000); - Ok(InitResult::Success(self)) } +} - fn handle_request( - &mut self, - message: Self::Request, - _handle: &BankHandle, - ) -> RequestResponse { - match message.clone() { - Self::Request::New { who } => match self.accounts.get(&who) { - Some(_amount) => RequestResponse::Reply(Err(BankError::AlreadyACustomer { who })), - None => { - self.accounts.insert(who.clone(), 0); - RequestResponse::Reply(Ok(OutMessage::Welcome { who })) - } - }, - Self::Request::Add { who, amount } => match self.accounts.get(&who) { - Some(current) => { - let new_amount = current + amount; - self.accounts.insert(who.clone(), new_amount); - RequestResponse::Reply(Ok(OutMessage::Balance { - who, +impl Handler for Bank { + fn handle(&mut self, msg: NewAccount, _ctx: &Context) -> MsgResult { + match self.accounts.get(&msg.who) { + Some(_) => Err(BankError::AlreadyACustomer { who: msg.who }), + None => { + self.accounts.insert(msg.who.clone(), 0); + Ok(OutMessage::Welcome { who: msg.who }) + } + } + } +} + +impl Handler for Bank { + fn handle(&mut self, msg: Deposit, _ctx: &Context) -> MsgResult { + match self.accounts.get(&msg.who) { + Some(current) => { + let new_amount = current + msg.amount; + self.accounts.insert(msg.who.clone(), new_amount); + Ok(OutMessage::Balance { + who: msg.who, + amount: new_amount, + }) + } + None => Err(BankError::NotACustomer { who: msg.who }), + } + } +} + +impl Handler for Bank { + fn handle(&mut self, msg: Withdraw, _ctx: &Context) -> MsgResult { + match self.accounts.get(&msg.who) { + Some(¤t) => { + if current < msg.amount { + Err(BankError::InsufficientBalance { + who: msg.who, + amount: current, + }) + } else { + let new_amount = current - msg.amount; + self.accounts.insert(msg.who.clone(), new_amount); + Ok(OutMessage::WidrawOk { + who: msg.who, amount: new_amount, - })) + }) } - None => RequestResponse::Reply(Err(BankError::NotACustomer { who })), - }, - Self::Request::Remove { who, amount } => match self.accounts.get(&who) { - Some(¤t) => match current < amount { - true => RequestResponse::Reply(Err(BankError::InsufficientBalance { - who, - amount: current, - })), - false => { - let new_amount = current - amount; - self.accounts.insert(who.clone(), new_amount); - RequestResponse::Reply(Ok(OutMessage::WidrawOk { - who, - amount: new_amount, - })) - } - }, - None => RequestResponse::Reply(Err(BankError::NotACustomer { who })), - }, - Self::Request::Stop => RequestResponse::Stop(Ok(OutMessage::Stopped)), + } + None => Err(BankError::NotACustomer { who: msg.who }), } } } + +impl Handler for Bank { + fn handle(&mut self, _msg: Stop, ctx: &Context) -> MsgResult { + ctx.stop(); + Ok(OutMessage::Stopped) + } +} diff --git a/examples/blocking_genserver/main.rs b/examples/blocking_genserver/main.rs index f1ec820..1535d24 100644 --- a/examples/blocking_genserver/main.rs +++ b/examples/blocking_genserver/main.rs @@ -2,11 +2,16 @@ use spawned_rt::tasks as rt; use std::time::Duration; use std::{process::exit, thread}; -use spawned_concurrency::tasks::{ - Actor, ActorRef, Backend, MessageResponse, RequestResponse, send_after, -}; +use spawned_concurrency::messages; +use spawned_concurrency::tasks::{Actor, ActorStart, Backend, Context, Handler, send_after}; + +messages! { + GetCount -> u64; + StopActor -> u64; + BadWork -> (); + GoodWork -> () +} -// We test a scenario with a badly behaved task struct BadlyBehavedTask; impl BadlyBehavedTask { @@ -15,32 +20,17 @@ impl BadlyBehavedTask { } } -#[derive(Clone)] -pub enum InMessage { - GetCount, - Stop, -} - -#[derive(Clone)] -pub enum OutMsg { - Count(u64), -} +impl Actor for BadlyBehavedTask {} -impl Actor for BadlyBehavedTask { - type Request = InMessage; - type Message = (); - type Reply = (); - type Error = (); - - async fn handle_request( - &mut self, - _: Self::Request, - _: &ActorRef, - ) -> RequestResponse { - RequestResponse::Stop(()) +impl Handler for BadlyBehavedTask { + async fn handle(&mut self, _: StopActor, ctx: &Context) -> u64 { + ctx.stop(); + 0 } +} - async fn handle_message(&mut self, _: Self::Message, _: &ActorRef) -> MessageResponse { +impl Handler for BadlyBehavedTask { + async fn handle(&mut self, _: BadWork, _ctx: &Context) { rt::sleep(Duration::from_millis(20)).await; loop { println!("{:?}: bad still alive", thread::current().id()); @@ -61,35 +51,26 @@ impl WellBehavedTask { } } -impl Actor for WellBehavedTask { - type Request = InMessage; - type Message = (); - type Reply = OutMsg; - type Error = (); - - async fn handle_request( - &mut self, - message: Self::Request, - _: &ActorRef, - ) -> RequestResponse { - match message { - InMessage::GetCount => { - let count = self.count; - RequestResponse::Reply(OutMsg::Count(count)) - } - InMessage::Stop => RequestResponse::Stop(OutMsg::Count(self.count)), - } +impl Actor for WellBehavedTask {} + +impl Handler for WellBehavedTask { + async fn handle(&mut self, _: GetCount, _ctx: &Context) -> u64 { + self.count + } +} + +impl Handler for WellBehavedTask { + async fn handle(&mut self, _: StopActor, ctx: &Context) -> u64 { + ctx.stop(); + self.count } +} - async fn handle_message( - &mut self, - _: Self::Message, - handle: &ActorRef, - ) -> MessageResponse { +impl Handler for WellBehavedTask { + async fn handle(&mut self, _: GoodWork, ctx: &Context) { self.count += 1; println!("{:?}: good still alive", thread::current().id()); - send_after(Duration::from_millis(100), handle.to_owned(), ()); - MessageResponse::NoReply + send_after(Duration::from_millis(100), ctx.clone(), GoodWork); } } @@ -99,20 +80,16 @@ impl Actor for WellBehavedTask { pub fn main() { rt::run(async move { // If we change BadlyBehavedTask to Backend::Async instead, it can stop the entire program - let mut badboy = BadlyBehavedTask::new().start_with_backend(Backend::Thread); - let _ = badboy.send(()).await; - let mut goodboy = WellBehavedTask::new(0).start(); - let _ = goodboy.send(()).await; + let badboy = BadlyBehavedTask::new().start_with_backend(Backend::Thread); + let _ = badboy.send(BadWork); + let goodboy = WellBehavedTask::new(0).start(); + let _ = goodboy.send(GoodWork); rt::sleep(Duration::from_secs(1)).await; - let count = goodboy.request(InMessage::GetCount).await.unwrap(); + let count = goodboy.request(GetCount).await.unwrap(); - match count { - OutMsg::Count(num) => { - assert!(num == 10); - } - } + assert!(count == 10); - goodboy.request(InMessage::Stop).await.unwrap(); + goodboy.request(StopActor).await.unwrap(); exit(0); }) } diff --git a/examples/busy_genserver_warning/main.rs b/examples/busy_genserver_warning/main.rs index cf83573..7eddd19 100644 --- a/examples/busy_genserver_warning/main.rs +++ b/examples/busy_genserver_warning/main.rs @@ -3,9 +3,14 @@ use std::time::Duration; use std::{process::exit, thread}; use tracing::info; -use spawned_concurrency::tasks::{Actor, ActorRef, MessageResponse, RequestResponse}; +use spawned_concurrency::messages; +use spawned_concurrency::tasks::{Actor, ActorStart, Context, Handler}; + +messages! { + StopBusy -> (); + BusyWork -> () +} -// We test a scenario with a badly behaved task struct BusyWorker; impl BusyWorker { @@ -14,43 +19,22 @@ impl BusyWorker { } } -#[derive(Clone)] -pub enum InMessage { - GetCount, - Stop, -} +impl Actor for BusyWorker {} -#[derive(Clone)] -pub enum OutMsg { - Count(u64), -} - -impl Actor for BusyWorker { - type Request = InMessage; - type Message = (); - type Reply = (); - type Error = (); - - async fn handle_request( - &mut self, - _: Self::Request, - _: &ActorRef, - ) -> RequestResponse { - RequestResponse::Stop(()) +impl Handler for BusyWorker { + async fn handle(&mut self, _: StopBusy, ctx: &Context) { + ctx.stop(); } +} - async fn handle_message( - &mut self, - _: Self::Message, - handle: &ActorRef, - ) -> MessageResponse { +impl Handler for BusyWorker { + async fn handle(&mut self, _: BusyWork, ctx: &Context) { info!(taskid = ?rt::task_id(), "sleeping"); thread::sleep(Duration::from_millis(542)); - handle.clone().send(()).await.unwrap(); + ctx.send(BusyWork).unwrap(); // This sleep is needed to yield control to the runtime. // If not, the future never returns and the warning isn't emitted. rt::sleep(Duration::from_millis(0)).await; - MessageResponse::NoReply } } @@ -64,8 +48,8 @@ impl Actor for BusyWorker { pub fn main() { rt::run(async move { // If we change BusyWorker to Backend::Blocking instead, it won't print the warning - let mut badboy = BusyWorker::new().start(); - let _ = badboy.send(()).await; + let badboy = BusyWorker::new().start(); + let _ = badboy.send(BusyWork); rt::sleep(Duration::from_secs(5)).await; exit(0); diff --git a/examples/chat_room/Cargo.toml b/examples/chat_room/Cargo.toml new file mode 100644 index 0000000..e36a671 --- /dev/null +++ b/examples/chat_room/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "chat_room" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +spawned-rt = { workspace = true } +spawned-concurrency = { workspace = true } +spawned-macros = { workspace = true } +tracing = { workspace = true } diff --git a/examples/chat_room/src/main.rs b/examples/chat_room/src/main.rs new file mode 100644 index 0000000..45511d8 --- /dev/null +++ b/examples/chat_room/src/main.rs @@ -0,0 +1,38 @@ +mod room; +mod user; + +use std::time::Duration; + +use room::{ChatRoom, ChatRoomApi}; +use spawned_concurrency::tasks::ActorStart; +use spawned_rt::tasks as rt; +use user::{User, UserApi}; + +fn main() { + rt::run(async { + let room = ChatRoom::new().start(); + + let alice = User::new("Alice".into()).start(); + let bob = User::new("Bob".into()).start(); + + // Register users in the room (send — fire-and-forget) + alice.join_room(room.clone()).unwrap(); + bob.join_room(room.clone()).unwrap(); + + // Let join messages propagate (user → room) + rt::sleep(Duration::from_millis(10)).await; + + // Query members (request — awaits a response) + let members = room.members().await.unwrap(); + tracing::info!("Members in room: {:?}", members); + + // Chat (send — fire-and-forget) + alice.say("Hello everyone!".into()).unwrap(); + bob.say("Hi Alice!".into()).unwrap(); + + // Give time for messages to propagate + rt::sleep(Duration::from_millis(100)).await; + + tracing::info!("Chat room demo complete"); + }); +} diff --git a/examples/chat_room/src/room.rs b/examples/chat_room/src/room.rs new file mode 100644 index 0000000..d0651aa --- /dev/null +++ b/examples/chat_room/src/room.rs @@ -0,0 +1,67 @@ +use spawned_concurrency::actor_api; +use spawned_concurrency::send_messages; +use spawned_concurrency::request_messages; +use spawned_concurrency::tasks::{Actor, ActorRef, Context, Handler, Recipient}; +use spawned_macros::actor; + +// -- Messages -- + +send_messages! { + Say { from: String, text: String }; + Deliver { from: String, text: String }; + Join { name: String, inbox: Recipient } +} + +request_messages! { + Members -> Vec +} + +// -- API -- + +actor_api! { + pub ChatRoomApi for ActorRef { + send fn say(from: String, text: String) => Say; + send fn add_member(name: String, inbox: Recipient) => Join; + request async fn members() -> Vec => Members; + } +} + +// -- Actor -- + +pub struct ChatRoom { + members: Vec<(String, Recipient)>, +} + +impl Actor for ChatRoom {} + +#[actor] +impl ChatRoom { + pub fn new() -> Self { + Self { + members: Vec::new(), + } + } + + #[send_handler] + async fn handle_say(&mut self, msg: Say, _ctx: &Context) { + for (name, inbox) in &self.members { + if *name != msg.from { + let _ = inbox.send(Deliver { + from: msg.from.clone(), + text: msg.text.clone(), + }); + } + } + } + + #[send_handler] + async fn handle_join(&mut self, msg: Join, _ctx: &Context) { + tracing::info!("[room] {} joined", msg.name); + self.members.push((msg.name, msg.inbox)); + } + + #[request_handler] + async fn handle_members(&mut self, _msg: Members, _ctx: &Context) -> Vec { + self.members.iter().map(|(name, _)| name.clone()).collect() + } +} diff --git a/examples/chat_room/src/user.rs b/examples/chat_room/src/user.rs new file mode 100644 index 0000000..bcfc5b2 --- /dev/null +++ b/examples/chat_room/src/user.rs @@ -0,0 +1,56 @@ +use spawned_concurrency::actor_api; +use spawned_concurrency::send_messages; +use spawned_concurrency::tasks::{Actor, ActorRef, Context, Handler}; +use spawned_macros::actor; + +use crate::room::{ChatRoom, ChatRoomApi, Deliver}; + +// -- Messages -- + +send_messages! { + SayToRoom { text: String }; + JoinRoom { room: ActorRef } +} + +// -- API -- + +actor_api! { + pub UserApi for ActorRef { + send fn say(text: String) => SayToRoom; + send fn join_room(room: ActorRef) => JoinRoom; + } +} + +// -- Actor -- + +pub struct User { + pub name: String, + room: Option>, +} + +impl Actor for User {} + +#[actor] +impl User { + pub fn new(name: String) -> Self { + Self { name, room: None } + } + + #[send_handler] + async fn handle_say_to_room(&mut self, msg: SayToRoom, _ctx: &Context) { + if let Some(ref room) = self.room { + let _ = room.say(self.name.clone(), msg.text); + } + } + + #[send_handler] + async fn handle_join_room(&mut self, msg: JoinRoom, ctx: &Context) { + let _ = msg.room.add_member(self.name.clone(), ctx.recipient::()); + self.room = Some(msg.room); + } + + #[send_handler] + async fn handle_deliver(&mut self, msg: Deliver, _ctx: &Context) { + tracing::info!("[{}] got: {} says '{}'", self.name, msg.from, msg.text); + } +} diff --git a/examples/chat_room_threads/Cargo.toml b/examples/chat_room_threads/Cargo.toml new file mode 100644 index 0000000..b04a3b5 --- /dev/null +++ b/examples/chat_room_threads/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "chat_room_threads" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +spawned-rt = { workspace = true } +spawned-concurrency = { workspace = true } +spawned-macros = { workspace = true } +tracing = { workspace = true } diff --git a/examples/chat_room_threads/src/main.rs b/examples/chat_room_threads/src/main.rs new file mode 100644 index 0000000..db491bd --- /dev/null +++ b/examples/chat_room_threads/src/main.rs @@ -0,0 +1,39 @@ +mod room; +mod user; + +use std::thread; +use std::time::Duration; + +use room::{ChatRoom, ChatRoomApi}; +use spawned_concurrency::threads::ActorStart; +use spawned_rt::threads as rt; +use user::{User, UserApi}; + +fn main() { + rt::run(|| { + let room = ChatRoom::new().start(); + + let alice = User::new("Alice".into()).start(); + let bob = User::new("Bob".into()).start(); + + // Register users in the room (send — fire-and-forget) + alice.join_room(room.clone()).unwrap(); + bob.join_room(room.clone()).unwrap(); + + // Let join messages propagate (user → room) + thread::sleep(Duration::from_millis(10)); + + // Query members (request — blocking) + let members = room.members().unwrap(); + tracing::info!("Members in room: {:?}", members); + + // Chat (send — fire-and-forget) + alice.say("Hello everyone!".into()).unwrap(); + bob.say("Hi Alice!".into()).unwrap(); + + // Give time for messages to propagate + thread::sleep(Duration::from_millis(100)); + + tracing::info!("Chat room demo complete"); + }); +} diff --git a/examples/chat_room_threads/src/room.rs b/examples/chat_room_threads/src/room.rs new file mode 100644 index 0000000..7108abf --- /dev/null +++ b/examples/chat_room_threads/src/room.rs @@ -0,0 +1,67 @@ +use spawned_concurrency::actor_api; +use spawned_concurrency::request_messages; +use spawned_concurrency::send_messages; +use spawned_concurrency::threads::{Actor, ActorRef, Context, Handler, Recipient}; +use spawned_macros::actor; + +// -- Messages -- + +send_messages! { + Say { from: String, text: String }; + Deliver { from: String, text: String }; + Join { name: String, inbox: Recipient } +} + +request_messages! { + Members -> Vec +} + +// -- API -- + +actor_api! { + pub ChatRoomApi for ActorRef { + send fn say(from: String, text: String) => Say; + send fn add_member(name: String, inbox: Recipient) => Join; + request fn members() -> Vec => Members; + } +} + +// -- Actor -- + +pub struct ChatRoom { + members: Vec<(String, Recipient)>, +} + +impl Actor for ChatRoom {} + +#[actor] +impl ChatRoom { + pub fn new() -> Self { + Self { + members: Vec::new(), + } + } + + #[send_handler] + fn handle_say(&mut self, msg: Say, _ctx: &Context) { + for (name, inbox) in &self.members { + if *name != msg.from { + let _ = inbox.send(Deliver { + from: msg.from.clone(), + text: msg.text.clone(), + }); + } + } + } + + #[send_handler] + fn handle_join(&mut self, msg: Join, _ctx: &Context) { + tracing::info!("[room] {} joined", msg.name); + self.members.push((msg.name, msg.inbox)); + } + + #[request_handler] + fn handle_members(&mut self, _msg: Members, _ctx: &Context) -> Vec { + self.members.iter().map(|(name, _)| name.clone()).collect() + } +} diff --git a/examples/chat_room_threads/src/user.rs b/examples/chat_room_threads/src/user.rs new file mode 100644 index 0000000..4dfe6a5 --- /dev/null +++ b/examples/chat_room_threads/src/user.rs @@ -0,0 +1,56 @@ +use spawned_concurrency::actor_api; +use spawned_concurrency::send_messages; +use spawned_concurrency::threads::{Actor, ActorRef, Context, Handler}; +use spawned_macros::actor; + +use crate::room::{ChatRoom, ChatRoomApi, Deliver}; + +// -- Messages -- + +send_messages! { + SayToRoom { text: String }; + JoinRoom { room: ActorRef } +} + +// -- API -- + +actor_api! { + pub UserApi for ActorRef { + send fn say(text: String) => SayToRoom; + send fn join_room(room: ActorRef) => JoinRoom; + } +} + +// -- Actor -- + +pub struct User { + pub name: String, + room: Option>, +} + +impl Actor for User {} + +#[actor] +impl User { + pub fn new(name: String) -> Self { + Self { name, room: None } + } + + #[send_handler] + fn handle_say_to_room(&mut self, msg: SayToRoom, _ctx: &Context) { + if let Some(ref room) = self.room { + let _ = room.say(self.name.clone(), msg.text); + } + } + + #[send_handler] + fn handle_join_room(&mut self, msg: JoinRoom, ctx: &Context) { + let _ = msg.room.add_member(self.name.clone(), ctx.recipient::()); + self.room = Some(msg.room); + } + + #[send_handler] + fn handle_deliver(&mut self, msg: Deliver, _ctx: &Context) { + tracing::info!("[{}] got: {} says '{}'", self.name, msg.from, msg.text); + } +} diff --git a/examples/name_server/src/main.rs b/examples/name_server/src/main.rs index 85fab9e..810d540 100644 --- a/examples/name_server/src/main.rs +++ b/examples/name_server/src/main.rs @@ -16,19 +16,19 @@ mod server; use messages::NameServerOutMessage; use server::NameServer; -use spawned_concurrency::tasks::Actor as _; +use spawned_concurrency::tasks::ActorStart; use spawned_rt::tasks as rt; fn main() { rt::run(async { - let mut name_server = NameServer::new().start(); + let name_server = NameServer::new().start(); let result = - NameServer::add(&mut name_server, "Joe".to_string(), "At Home".to_string()).await; + NameServer::add(&name_server, "Joe".to_string(), "At Home".to_string()).await; tracing::info!("Storing value result: {result:?}"); assert_eq!(result, NameServerOutMessage::Ok); - let result = NameServer::find(&mut name_server, "Joe".to_string()).await; + let result = NameServer::find(&name_server, "Joe".to_string()).await; tracing::info!("Retrieving value result: {result:?}"); assert_eq!( result, @@ -37,7 +37,7 @@ fn main() { } ); - let result = NameServer::find(&mut name_server, "Bob".to_string()).await; + let result = NameServer::find(&name_server, "Bob".to_string()).await; tracing::info!("Retrieving value result: {result:?}"); assert_eq!(result, NameServerOutMessage::NotFound); }) diff --git a/examples/name_server/src/messages.rs b/examples/name_server/src/messages.rs index b011cb2..6324c8f 100644 --- a/examples/name_server/src/messages.rs +++ b/examples/name_server/src/messages.rs @@ -1,10 +1,5 @@ -#[derive(Debug, Clone)] -pub enum NameServerInMessage { - Add { key: String, value: String }, - Find { key: String }, -} +use spawned_concurrency::message::Message; -#[allow(dead_code)] #[derive(Debug, Clone, PartialEq)] pub enum NameServerOutMessage { Ok, @@ -12,3 +7,20 @@ pub enum NameServerOutMessage { NotFound, Error, } + +#[derive(Debug)] +pub struct Add { + pub key: String, + pub value: String, +} +impl Message for Add { + type Result = NameServerOutMessage; +} + +#[derive(Debug)] +pub struct Find { + pub key: String, +} +impl Message for Find { + type Result = NameServerOutMessage; +} diff --git a/examples/name_server/src/server.rs b/examples/name_server/src/server.rs index 59a5c96..ae75f03 100644 --- a/examples/name_server/src/server.rs +++ b/examples/name_server/src/server.rs @@ -1,13 +1,8 @@ use std::collections::HashMap; -use spawned_concurrency::{ - messages::Unused, - tasks::{Actor, ActorRef, RequestResponse}, -}; +use spawned_concurrency::tasks::{Actor, ActorRef, Context, Handler}; -use crate::messages::{NameServerInMessage as InMessage, NameServerOutMessage as OutMessage}; - -type NameServerHandle = ActorRef; +use crate::messages::{Add, Find, NameServerOutMessage as OutMessage}; pub struct NameServer { inner: HashMap, @@ -22,44 +17,37 @@ impl NameServer { } impl NameServer { - pub async fn add(server: &mut NameServerHandle, key: String, value: String) -> OutMessage { - match server.request(InMessage::Add { key, value }).await { + pub async fn add(server: &ActorRef, key: String, value: String) -> OutMessage { + match server.request(Add { key, value }).await { Ok(_) => OutMessage::Ok, Err(_) => OutMessage::Error, } } - pub async fn find(server: &mut NameServerHandle, key: String) -> OutMessage { + pub async fn find(server: &ActorRef, key: String) -> OutMessage { server - .request(InMessage::Find { key }) + .request(Find { key }) .await .unwrap_or(OutMessage::Error) } } -impl Actor for NameServer { - type Request = InMessage; - type Message = Unused; - type Reply = OutMessage; - type Error = std::fmt::Error; +impl Actor for NameServer {} + +impl Handler for NameServer { + async fn handle(&mut self, msg: Add, _ctx: &Context) -> OutMessage { + self.inner.insert(msg.key, msg.value); + OutMessage::Ok + } +} - async fn handle_request( - &mut self, - message: Self::Request, - _handle: &NameServerHandle, - ) -> RequestResponse { - match message.clone() { - Self::Request::Add { key, value } => { - self.inner.insert(key, value); - RequestResponse::Reply(Self::Reply::Ok) - } - Self::Request::Find { key } => match self.inner.get(&key) { - Some(result) => { - let value = result.to_string(); - RequestResponse::Reply(Self::Reply::Found { value }) - } - None => RequestResponse::Reply(Self::Reply::NotFound), +impl Handler for NameServer { + async fn handle(&mut self, msg: Find, _ctx: &Context) -> OutMessage { + match self.inner.get(&msg.key) { + Some(result) => OutMessage::Found { + value: result.to_string(), }, + None => OutMessage::NotFound, } } } diff --git a/examples/service_discovery/Cargo.toml b/examples/service_discovery/Cargo.toml new file mode 100644 index 0000000..39860fe --- /dev/null +++ b/examples/service_discovery/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "service_discovery" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +spawned-rt = { workspace = true } +spawned-concurrency = { workspace = true } +spawned-macros = { workspace = true } +tracing = { workspace = true } diff --git a/examples/service_discovery/src/main.rs b/examples/service_discovery/src/main.rs new file mode 100644 index 0000000..7b282d4 --- /dev/null +++ b/examples/service_discovery/src/main.rs @@ -0,0 +1,102 @@ +use std::collections::HashMap; +use std::time::Duration; + +use spawned_concurrency::messages; +use spawned_concurrency::registry; +use spawned_concurrency::tasks::{Actor, ActorStart, Context, Handler, Recipient, request}; +use spawned_macros::actor; +use spawned_rt::tasks as rt; + +// --- Messages --- + +messages! { + Register { name: String, address: String } -> (); + Lookup { name: String } -> Option; + ListAll -> Vec<(String, String)> +} + +// --- ServiceRegistry actor --- + +struct ServiceRegistry { + services: HashMap, +} + +impl ServiceRegistry { + fn new() -> Self { + Self { + services: HashMap::new(), + } + } +} + +impl Actor for ServiceRegistry {} + +#[actor] +impl ServiceRegistry { + #[handler] + async fn handle_register(&mut self, msg: Register, _ctx: &Context) { + tracing::info!("Registered service '{}' at {}", msg.name, msg.address); + self.services.insert(msg.name, msg.address); + } + + #[handler] + async fn handle_lookup(&mut self, msg: Lookup, _ctx: &Context) -> Option { + self.services.get(&msg.name).cloned() + } + + #[handler] + async fn handle_list_all(&mut self, _msg: ListAll, _ctx: &Context) -> Vec<(String, String)> { + self.services.iter().map(|(k, v)| (k.clone(), v.clone())).collect() + } +} + +fn main() { + rt::run(async { + // Start the service registry actor + let svc = ServiceRegistry::new().start(); + + // Register it by name — other actors can discover it + registry::register("service_registry", svc.recipient::()).unwrap(); + + // Register some services + svc.send(Register { + name: "web".into(), + address: "http://localhost:8080".into(), + }) + .unwrap(); + + svc.send(Register { + name: "db".into(), + address: "postgres://localhost:5432".into(), + }) + .unwrap(); + + // A consumer discovers the registry by name (doesn't need to know ServiceRegistry type) + let lookup_recipient: Recipient = registry::whereis("service_registry").unwrap(); + + // Look up a service + let addr = request( + &*lookup_recipient, + Lookup { + name: "web".into(), + }, + Duration::from_secs(5), + ) + .await + .unwrap(); + tracing::info!("Looked up 'web': {:?}", addr); + + // List all registered names in the registry + let names = registry::registered(); + tracing::info!("Registry contains: {:?}", names); + + // Direct request for all services + let all = svc.request(ListAll).await.unwrap(); + tracing::info!("All services: {:?}", all); + + // Clean up + registry::unregister("service_registry"); + + tracing::info!("Service discovery demo complete"); + }); +} diff --git a/examples/signal_test/src/main.rs b/examples/signal_test/src/main.rs index 90e6d6b..5ecb404 100644 --- a/examples/signal_test/src/main.rs +++ b/examples/signal_test/src/main.rs @@ -1,27 +1,24 @@ //! Test to verify signal handling across different Actor backends (tasks version). //! //! This example demonstrates using `send_message_on` to handle Ctrl+C signals. -//! The signal handler is set up in the Actor's `init()` function. +//! The signal handler is set up in the Actor's `started()` function. //! //! Run with: cargo run --bin signal_test -- [async|blocking|thread] //! //! Then press Ctrl+C and observe: //! - Does the actor stop gracefully? -//! - Does teardown run? +//! - Does stopped run? -use spawned_concurrency::{ - messages::Unused, - tasks::{ - send_interval, send_message_on, Actor, ActorRef, Backend, InitResult, MessageResponse, - }, +use spawned_concurrency::tasks::{ + send_interval, send_message_on, Actor, ActorStart, Backend, Context, Handler, TimerHandle, }; -use spawned_rt::tasks::{self as rt, CancellationToken}; +use spawned_rt::tasks as rt; use std::{env, time::Duration}; struct TickingActor { name: String, count: u64, - timer_token: Option, + timer: Option, } impl TickingActor { @@ -29,61 +26,51 @@ impl TickingActor { Self { name: name.to_string(), count: 0, - timer_token: None, + timer: None, } } } -#[derive(Clone)] -enum Msg { - Tick, - Shutdown, +use spawned_concurrency::messages; + +messages! { + #[derive(Clone)] + Tick -> (); + Shutdown -> () } impl Actor for TickingActor { - type Request = Unused; - type Message = Msg; - type Reply = Unused; - type Error = (); - - async fn init(mut self, handle: &ActorRef) -> Result, Self::Error> { + async fn started(&mut self, ctx: &Context) { tracing::info!("[{}] Actor initialized", self.name); // Set up periodic ticking - let timer = send_interval(Duration::from_secs(1), handle.clone(), Msg::Tick); - self.timer_token = Some(timer.cancellation_token); + let timer = send_interval(Duration::from_secs(1), ctx.clone(), Tick); + self.timer = Some(timer); // Set up Ctrl+C handler using send_message_on - send_message_on(handle.clone(), rt::ctrl_c(), Msg::Shutdown); - - Ok(InitResult::Success(self)) - } - - async fn handle_message( - &mut self, - message: Self::Message, - _handle: &ActorRef, - ) -> MessageResponse { - match message { - Msg::Tick => { - self.count += 1; - tracing::info!("[{}] Tick #{}", self.name, self.count); - MessageResponse::NoReply - } - Msg::Shutdown => { - tracing::info!("[{}] Received shutdown signal", self.name); - MessageResponse::Stop - } - } + send_message_on(ctx.clone(), rt::ctrl_c(), Shutdown); } - async fn teardown(self, _handle: &ActorRef) -> Result<(), Self::Error> { + async fn stopped(&mut self, _ctx: &Context) { tracing::info!( - "[{}] Teardown called! Final count: {}", + "[{}] Stopped called! Final count: {}", self.name, self.count ); - Ok(()) + } +} + +impl Handler for TickingActor { + async fn handle(&mut self, _msg: Tick, _ctx: &Context) { + self.count += 1; + tracing::info!("[{}] Tick #{}", self.name, self.count); + } +} + +impl Handler for TickingActor { + async fn handle(&mut self, _msg: Shutdown, ctx: &Context) { + tracing::info!("[{}] Received shutdown signal", self.name); + ctx.stop(); } } diff --git a/examples/signal_test_threads/src/main.rs b/examples/signal_test_threads/src/main.rs index a0da2a0..47e84a9 100644 --- a/examples/signal_test_threads/src/main.rs +++ b/examples/signal_test_threads/src/main.rs @@ -1,25 +1,24 @@ //! Test to verify signal handling for threads Actor. //! //! This example demonstrates using `send_message_on` to handle Ctrl+C signals. -//! The signal handler is set up in the Actor's `init()` function. +//! The signal handler is set up in the Actor's `started()` function. //! //! Run with: cargo run --bin signal_test_threads //! //! Then press Ctrl+C and observe: //! - Does the actor stop gracefully? -//! - Does teardown run? +//! - Does stopped run? -use spawned_concurrency::{ - messages::Unused, - threads::{send_interval, send_message_on, Actor, ActorRef, InitResult, MessageResponse}, +use spawned_concurrency::threads::{ + send_interval, send_message_on, Actor, ActorStart, Context, Handler, TimerHandle, }; -use spawned_rt::threads::{self as rt, CancellationToken}; +use spawned_rt::threads as rt; use std::time::Duration; struct TickingActor { name: String, count: u64, - timer_token: Option, + timer: Option, } impl TickingActor { @@ -27,61 +26,51 @@ impl TickingActor { Self { name: name.to_string(), count: 0, - timer_token: None, + timer: None, } } } -#[derive(Clone)] -enum Msg { - Tick, - Shutdown, +use spawned_concurrency::messages; + +messages! { + #[derive(Clone)] + Tick -> (); + Shutdown -> () } impl Actor for TickingActor { - type Request = Unused; - type Message = Msg; - type Reply = Unused; - type Error = (); - - fn init(mut self, handle: &ActorRef) -> Result, Self::Error> { + fn started(&mut self, ctx: &Context) { tracing::info!("[{}] Actor initialized", self.name); // Set up periodic ticking - let timer = send_interval(Duration::from_secs(1), handle.clone(), Msg::Tick); - self.timer_token = Some(timer.cancellation_token); + let timer = send_interval(Duration::from_secs(1), ctx.clone(), Tick); + self.timer = Some(timer); // Set up Ctrl+C handler using send_message_on - send_message_on(handle.clone(), rt::ctrl_c(), Msg::Shutdown); - - Ok(InitResult::Success(self)) - } - - fn handle_message( - &mut self, - message: Self::Message, - _handle: &ActorRef, - ) -> MessageResponse { - match message { - Msg::Tick => { - self.count += 1; - tracing::info!("[{}] Tick #{}", self.name, self.count); - MessageResponse::NoReply - } - Msg::Shutdown => { - tracing::info!("[{}] Received shutdown signal", self.name); - MessageResponse::Stop - } - } + send_message_on(ctx.clone(), rt::ctrl_c(), Shutdown); } - fn teardown(self, _handle: &ActorRef) -> Result<(), Self::Error> { + fn stopped(&mut self, _ctx: &Context) { tracing::info!( - "[{}] Teardown called! Final count: {}", + "[{}] Stopped called! Final count: {}", self.name, self.count ); - Ok(()) + } +} + +impl Handler for TickingActor { + fn handle(&mut self, _msg: Tick, _ctx: &Context) { + self.count += 1; + tracing::info!("[{}] Tick #{}", self.name, self.count); + } +} + +impl Handler for TickingActor { + fn handle(&mut self, _msg: Shutdown, ctx: &Context) { + tracing::info!("[{}] Received shutdown signal", self.name); + ctx.stop(); } } diff --git a/examples/updater/src/main.rs b/examples/updater/src/main.rs index 0a6aaf0..d046b33 100644 --- a/examples/updater/src/main.rs +++ b/examples/updater/src/main.rs @@ -9,7 +9,7 @@ mod server; use std::{thread, time::Duration}; use server::UpdaterServer; -use spawned_concurrency::tasks::Actor as _; +use spawned_concurrency::tasks::ActorStart; use spawned_rt::tasks as rt; fn main() { diff --git a/examples/updater/src/messages.rs b/examples/updater/src/messages.rs index daa0589..2450b76 100644 --- a/examples/updater/src/messages.rs +++ b/examples/updater/src/messages.rs @@ -1,11 +1,6 @@ -#[derive(Debug, Clone)] -pub enum UpdaterInMessage { - Check, -} +use spawned_concurrency::messages; -#[allow(dead_code)] -#[derive(Debug, Clone, PartialEq)] -pub enum UpdaterOutMessage { - Ok, - Error, +messages! { + #[derive(Clone)] + Check -> () } diff --git a/examples/updater/src/server.rs b/examples/updater/src/server.rs index 2c1b02e..ce4021b 100644 --- a/examples/updater/src/server.rs +++ b/examples/updater/src/server.rs @@ -1,23 +1,13 @@ use std::time::Duration; -use spawned_concurrency::{ - messages::Unused, - tasks::{ - send_interval, Actor, ActorRef, - InitResult::{self, Success}, - MessageResponse, - }, -}; -use spawned_rt::tasks::CancellationToken; +use spawned_concurrency::tasks::{send_interval, Actor, Context, Handler, TimerHandle}; -use crate::messages::{UpdaterInMessage as InMessage, UpdaterOutMessage as OutMessage}; - -type UpdateServerHandle = ActorRef; +use crate::messages::Check; pub struct UpdaterServer { pub url: String, pub periodicity: Duration, - pub timer_token: Option, + pub timer: Option, } impl UpdaterServer { @@ -25,38 +15,24 @@ impl UpdaterServer { UpdaterServer { url, periodicity, - timer_token: None, + timer: None, } } } impl Actor for UpdaterServer { - type Request = Unused; - type Message = InMessage; - type Reply = OutMessage; - type Error = std::fmt::Error; - - // Initializing Actor to start periodic checks. - async fn init(mut self, handle: &ActorRef) -> Result, Self::Error> { - let timer = send_interval(self.periodicity, handle.clone(), InMessage::Check); - self.timer_token = Some(timer.cancellation_token); - Ok(Success(self)) + async fn started(&mut self, ctx: &Context) { + let timer = send_interval(self.periodicity, ctx.clone(), Check); + self.timer = Some(timer); } +} - async fn handle_message( - &mut self, - message: Self::Message, - _handle: &UpdateServerHandle, - ) -> MessageResponse { - match message { - Self::Message::Check => { - let url = self.url.clone(); - tracing::info!("Fetching: {url}"); - let resp = req(url).await; - tracing::info!("Response: {resp:?}"); - MessageResponse::NoReply - } - } +impl Handler for UpdaterServer { + async fn handle(&mut self, _msg: Check, _ctx: &Context) { + let url = self.url.clone(); + tracing::info!("Fetching: {url}"); + let resp = req(url).await; + tracing::info!("Response: {resp:?}"); } } diff --git a/examples/updater_threads/src/main.rs b/examples/updater_threads/src/main.rs index 5b7ceb3..50255fe 100644 --- a/examples/updater_threads/src/main.rs +++ b/examples/updater_threads/src/main.rs @@ -9,7 +9,7 @@ mod server; use std::{thread, time::Duration}; use server::UpdaterServer; -use spawned_concurrency::threads::Actor as _; +use spawned_concurrency::threads::ActorStart; use spawned_rt::threads as rt; fn main() { diff --git a/examples/updater_threads/src/messages.rs b/examples/updater_threads/src/messages.rs index daa0589..2450b76 100644 --- a/examples/updater_threads/src/messages.rs +++ b/examples/updater_threads/src/messages.rs @@ -1,11 +1,6 @@ -#[derive(Debug, Clone)] -pub enum UpdaterInMessage { - Check, -} +use spawned_concurrency::messages; -#[allow(dead_code)] -#[derive(Debug, Clone, PartialEq)] -pub enum UpdaterOutMessage { - Ok, - Error, +messages! { + #[derive(Clone)] + Check -> () } diff --git a/examples/updater_threads/src/server.rs b/examples/updater_threads/src/server.rs index 2a931ff..b427dfa 100644 --- a/examples/updater_threads/src/server.rs +++ b/examples/updater_threads/src/server.rs @@ -1,50 +1,28 @@ use std::time::Duration; -use spawned_concurrency::{ - messages::Unused, - threads::{send_after, Actor, ActorRef, InitResult, MessageResponse}, -}; +use spawned_concurrency::threads::{send_after, Actor, Context, Handler}; use spawned_rt::threads::block_on; -use crate::messages::{UpdaterInMessage as InMessage, UpdaterOutMessage as OutMessage}; +use crate::messages::Check; -type UpdateServerHandle = ActorRef; - -#[derive(Clone)] pub struct UpdaterServer { pub url: String, pub periodicity: Duration, } impl Actor for UpdaterServer { - type Request = Unused; - type Message = InMessage; - type Reply = OutMessage; - type Error = std::fmt::Error; - - // Initializing Actor to start periodic checks. - fn init(self, handle: &ActorRef) -> Result, Self::Error> { - send_after(self.periodicity, handle.clone(), InMessage::Check); - Ok(InitResult::Success(self)) + fn started(&mut self, ctx: &Context) { + send_after(self.periodicity, ctx.clone(), Check); } +} - fn handle_message( - &mut self, - message: Self::Message, - handle: &UpdateServerHandle, - ) -> MessageResponse { - match message { - Self::Message::Check => { - send_after(self.periodicity, handle.clone(), InMessage::Check); - let url = self.url.clone(); - tracing::info!("Fetching: {url}"); - let resp = block_on(req(url)); - - tracing::info!("Response: {resp:?}"); - - MessageResponse::NoReply - } - } +impl Handler for UpdaterServer { + fn handle(&mut self, _msg: Check, ctx: &Context) { + send_after(self.periodicity, ctx.clone(), Check); + let url = self.url.clone(); + tracing::info!("Fetching: {url}"); + let resp = block_on(req(url)); + tracing::info!("Response: {resp:?}"); } } diff --git a/macros/Cargo.toml b/macros/Cargo.toml new file mode 100644 index 0000000..c1e3e15 --- /dev/null +++ b/macros/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "spawned-macros" +description = "Proc macros for the Spawned actor framework" +version.workspace = true +edition.workspace = true +license.workspace = true + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "2", features = ["full"] } +quote = "1" +proc-macro2 = "1" diff --git a/macros/src/lib.rs b/macros/src/lib.rs new file mode 100644 index 0000000..f8dda17 --- /dev/null +++ b/macros/src/lib.rs @@ -0,0 +1,112 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, FnArg, ImplItem, ItemImpl, Pat, ReturnType, Type}; + +/// Attribute macro for actor impl blocks. +/// +/// Place `#[actor]` on an `impl MyActor` block containing methods annotated +/// with `#[send_handler]` or `#[request_handler]`. For each annotated method, +/// the macro generates a corresponding `impl Handler for MyActor` block. +/// +/// Use `#[send_handler]` for fire-and-forget messages (no return value): +/// +/// ```ignore +/// #[send_handler] +/// async fn on_deposit(&mut self, msg: Deposit, ctx: &Context) { ... } +/// ``` +/// +/// Use `#[request_handler]` for request-response messages (returns a value): +/// +/// ```ignore +/// #[request_handler] +/// async fn on_balance(&mut self, msg: GetBalance, ctx: &Context) -> u64 { ... } +/// ``` +/// +/// Sync handlers (for the `threads` module) omit `async`: +/// +/// ```ignore +/// #[send_handler] +/// fn on_deposit(&mut self, msg: Deposit, ctx: &Context) { ... } +/// ``` +/// +/// The generic `#[handler]` attribute is also supported for backwards +/// compatibility and works for both send and request handlers. +#[proc_macro_attribute] +pub fn actor(_attr: TokenStream, item: TokenStream) -> TokenStream { + let mut impl_block = parse_macro_input!(item as ItemImpl); + + let self_ty = &impl_block.self_ty; + let (impl_generics, _, where_clause) = impl_block.generics.split_for_impl(); + + let mut handler_impls = Vec::new(); + + for item in &mut impl_block.items { + if let ImplItem::Fn(method) = item { + let handler_idx = method.attrs.iter().position(|attr| { + attr.path().is_ident("handler") + || attr.path().is_ident("send_handler") + || attr.path().is_ident("request_handler") + }); + + if let Some(idx) = handler_idx { + method.attrs.remove(idx); + + let method_name = &method.sig.ident; + let is_async = method.sig.asyncness.is_some(); + + // Extract message type from 2nd parameter (index 1, after &mut self) + let msg_ty = match method.sig.inputs.iter().nth(1) { + Some(FnArg::Typed(pat_type)) => { + if let Pat::Ident(pat_ident) = &*pat_type.pat { + if pat_ident.ident == "_" || pat_ident.ident.to_string().starts_with('_') { + // Still use the type + } + } + &*pat_type.ty + } + _ => { + return syn::Error::new_spanned( + &method.sig, + "handler method must have signature: fn(&mut self, msg: M, ctx: &Context) -> R", + ) + .to_compile_error() + .into(); + } + }; + + // Extract return type (default to () if omitted) + let ret_ty: Box = match &method.sig.output { + ReturnType::Default => syn::parse_quote! { () }, + ReturnType::Type(_, ty) => ty.clone(), + }; + + let handler_impl = if is_async { + quote! { + impl #impl_generics Handler<#msg_ty> for #self_ty #where_clause { + async fn handle(&mut self, msg: #msg_ty, ctx: &Context) -> #ret_ty { + self.#method_name(msg, ctx).await + } + } + } + } else { + quote! { + impl #impl_generics Handler<#msg_ty> for #self_ty #where_clause { + fn handle(&mut self, msg: #msg_ty, ctx: &Context) -> #ret_ty { + self.#method_name(msg, ctx) + } + } + } + }; + + handler_impls.push(handler_impl); + } + } + } + + let output = quote! { + #impl_block + #(#handler_impls)* + }; + + output.into() +}