From 4c1a4f687d2ad2d5b7309b36ea1f00596f26f527 Mon Sep 17 00:00:00 2001 From: Dhruv Vats Date: Thu, 5 Mar 2026 12:38:33 +0530 Subject: [PATCH] Add data_batch_probe --- include/cucascade/data/data_batch.hpp | 78 ++++++++++++++++++++++----- src/data/data_batch.cpp | 49 +++++++++++++---- 2 files changed, 104 insertions(+), 23 deletions(-) diff --git a/include/cucascade/data/data_batch.hpp b/include/cucascade/data/data_batch.hpp index 14fbf43..07419da 100644 --- a/include/cucascade/data/data_batch.hpp +++ b/include/cucascade/data/data_batch.hpp @@ -58,6 +58,34 @@ enum class batch_state { in_transit ///< Batch is currently being moved to a different memory tier }; +/** + * @brief Interface for probing the data_batch class. + * + * Applications may implement this interface to hold additional application specific + * data_batch metadata while probing the data_batch by overriding the provided methods + * that expose the data_batch state when certain events occur, like state transitions. + */ +class idata_batch_probe { + public: + virtual ~idata_batch_probe() = default; + + /** + * @brief The function that is called everytime a data_batch transitions into a new batch_state. + * + * @note It is the implementer's responsibility that a call to this function returns quickly, as + * this function is called in a thread-safe manner, and will block other mutating changes + * while this executes. Primarily intended for bookkeeping purposes. + */ + virtual void state_transitioned_to([[maybe_unused]] const batch_state& new_state, + [[maybe_unused]] const uint64_t& batch_id, + [[maybe_unused]] const idata_representation& data, + [[maybe_unused]] const size_t& processing_count, + [[maybe_unused]] const size_t& task_created_count) + { + // The default impl is no-op. + } +}; + // Forward declarations class data_batch; class data_batch_processing_handle; @@ -192,8 +220,11 @@ class data_batch : public std::enable_shared_from_this { * * @param batch_id Unique identifier for this batch * @param data Ownership of the data representation is transferred to this batch + * @param probe @copydoc idata_batch_probe */ - data_batch(uint64_t batch_id, std::unique_ptr data); + data_batch(uint64_t batch_id, + std::unique_ptr data, + std::unique_ptr probe = std::make_unique()); /** * @brief Move constructor - transfers ownership of the batch and its data. @@ -303,14 +334,17 @@ class data_batch : public std::enable_shared_from_this { * @param new_batch_id The batch ID for the cloned batch * @param target_memory_space The memory space where the new representation will be allocated * @param stream CUDA stream for memory operations + * @param probe Optional @copydoc idata_batch_probe * @return std::shared_ptr A new data_batch with cloned data * @throws std::runtime_error if there is active processing on this batch */ template - std::shared_ptr clone_to(representation_converter_registry& registry, - uint64_t new_batch_id, - const cucascade::memory::memory_space* target_memory_space, - rmm::cuda_stream_view stream); + std::shared_ptr clone_to( + representation_converter_registry& registry, + uint64_t new_batch_id, + const cucascade::memory::memory_space* target_memory_space, + rmm::cuda_stream_view stream, + std::optional> probe = std::nullopt); /** * @brief Attempt to create a task for this batch. @@ -399,11 +433,15 @@ class data_batch : public std::enable_shared_from_this { * * @param new_batch_id The batch ID for the cloned batch * @param stream CUDA stream for memory operations + * @param probe Optional @copydoc idata_batch_probe * @return std::shared_ptr A new data_batch with copied data * @throws std::runtime_error if the batch is in in_transit state * @throws std::runtime_error if the underlying data is null */ - std::shared_ptr clone(uint64_t new_batch_id, rmm::cuda_stream_view stream); + std::shared_ptr clone( + uint64_t new_batch_id, + rmm::cuda_stream_view stream, + std::optional> probe = std::nullopt); private: friend class data_batch_processing_handle; @@ -416,12 +454,24 @@ class data_batch : public std::enable_shared_from_this { */ void decrement_processing_count(); + /** + * @brief Handle a state transition by updating _state and invoking the probe callback. + * + * Must be called while the caller holds _mutex. Passes all member fields except _mutex + * and _state_change_cv as const references to the probe callback. + * + * @param new_state The state to transition to + * @param lock Reference to the lock_guard proving the mutex is held + */ + void update_state_to(batch_state new_state, const std::lock_guard& lock); + mutable std::mutex _mutex; ///< Mutex for thread-safe access to state and processing count uint64_t _batch_id; ///< Unique identifier for this data batch - std::unique_ptr _data; ///< Pointer to the actual data representation - size_t _processing_count = 0; ///< Count of active processing handles - size_t _task_created_count = 0; ///< Count of pending task_created requests - batch_state _state = batch_state::idle; ///< Current state of the batch + std::unique_ptr _data; ///< Pointer to the actual data representation + size_t _processing_count = 0; ///< Count of active processing handles + size_t _task_created_count = 0; ///< Count of pending task_created requests + batch_state _state = batch_state::idle; ///< Current state of the batch + std::unique_ptr _probe; ///< Probe for observing state transitions std::condition_variable* _state_change_cv = nullptr; ///< Optional CV to notify on state change }; @@ -447,7 +497,8 @@ std::shared_ptr data_batch::clone_to( representation_converter_registry& registry, uint64_t new_batch_id, const cucascade::memory::memory_space* target_memory_space, - rmm::cuda_stream_view stream) + rmm::cuda_stream_view stream, + std::optional> probe) { std::lock_guard lock(_mutex); @@ -457,7 +508,10 @@ std::shared_ptr data_batch::clone_to( auto new_representation = registry.convert(*_data, target_memory_space, stream); - return std::make_shared(new_batch_id, std::move(new_representation)); + return std::make_shared( + new_batch_id, + std::move(new_representation), + probe ? std::move(*probe) : std::make_unique()); } } // namespace cucascade diff --git a/src/data/data_batch.cpp b/src/data/data_batch.cpp index dc7ad33..e3ec15d 100644 --- a/src/data/data_batch.cpp +++ b/src/data/data_batch.cpp @@ -18,6 +18,8 @@ #include #include +#include + namespace cucascade { // data_batch_processing_handle implementation @@ -34,9 +36,13 @@ void data_batch_processing_handle::release() // data_batch implementation -data_batch::data_batch(uint64_t batch_id, std::unique_ptr data) - : _batch_id(batch_id), _data(std::move(data)) +data_batch::data_batch(uint64_t batch_id, + std::unique_ptr data, + std::unique_ptr probe) + : _batch_id(batch_id), _data(std::move(data)), _probe(std::move(probe)) { + std::lock_guard lock(_mutex); + update_state_to(batch_state::idle, lock); } data_batch::data_batch(data_batch&& other) @@ -116,8 +122,8 @@ bool data_batch::try_to_create_task() { std::lock_guard lock(_mutex); if (_state == batch_state::idle) { - _state = batch_state::task_created; ++_task_created_count; + update_state_to(batch_state::task_created, lock); should_notify = true; cv_to_notify = _state_change_cv; success = true; @@ -155,7 +161,7 @@ bool data_batch::try_to_cancel_task() } --_task_created_count; if (_task_created_count == 0 && _processing_count == 0) { - _state = batch_state::idle; + update_state_to(batch_state::idle, lock); should_notify = true; cv_to_notify = _state_change_cv; } @@ -196,7 +202,7 @@ lock_for_processing_result data_batch::try_to_lock_for_processing( } --_task_created_count; ++_processing_count; - _state = batch_state::processing; + update_state_to(batch_state::processing, lock); should_notify = true; cv_to_notify = _state_change_cv; result = { @@ -216,7 +222,7 @@ bool data_batch::try_to_lock_for_in_transit() if (_processing_count == 0 && ((_state == batch_state::idle) || (_state == batch_state::task_created && _task_created_count > 0))) { - _state = batch_state::in_transit; + update_state_to(batch_state::in_transit, lock); should_notify = true; cv_to_notify = _state_change_cv; success = true; @@ -236,9 +242,16 @@ bool data_batch::try_to_release_in_transit(std::optional target_sta if (_state == batch_state::in_transit) { // Caller can explicitly choose the state to return to; default is idle. if (target_state.has_value()) { - _state = *target_state; + if (*target_state == batch_state::idle) { + update_state_to(batch_state::idle, lock); + } else { + // first transition to idle + update_state_to(batch_state::idle, lock); + // then to the next state, maintaining FSM invariants + update_state_to(*target_state, lock); + } } else { - _state = batch_state::idle; + update_state_to(batch_state::idle, lock); } should_notify = true; cv_to_notify = _state_change_cv; @@ -267,7 +280,8 @@ void data_batch::decrement_processing_count() _processing_count -= 1; if (_processing_count == 0) { // Preserve pending task_created intent if any remain - _state = (_task_created_count > 0) ? batch_state::task_created : batch_state::idle; + update_state_to((_task_created_count > 0) ? batch_state::task_created : batch_state::idle, + lock); should_notify = true; cv_to_notify = _state_change_cv; } @@ -275,7 +289,10 @@ void data_batch::decrement_processing_count() if (should_notify && cv_to_notify) { cv_to_notify->notify_all(); } } -std::shared_ptr data_batch::clone(uint64_t new_batch_id, rmm::cuda_stream_view stream) +std::shared_ptr data_batch::clone( + uint64_t new_batch_id, + rmm::cuda_stream_view stream, + std::optional> probe) { // Create a task and lock for processing to protect data during clone if (!try_to_create_task()) { @@ -295,7 +312,17 @@ std::shared_ptr data_batch::clone(uint64_t new_batch_id, rmm::cuda_s cloned_data = _data->clone(stream); } // Handle destructor will decrement processing count when result goes out of scope - return std::make_shared(new_batch_id, std::move(cloned_data)); + return std::make_shared( + new_batch_id, + std::move(cloned_data), + probe ? std::move(*probe) : std::make_unique()); +} + +void data_batch::update_state_to(batch_state new_state, const std::lock_guard& /*lock*/) +{ + _state = new_state; + _probe->state_transitioned_to( + _state, _batch_id, *(this->get_data()), _processing_count, _task_created_count); } } // namespace cucascade