From ad7be99888103c13eac31eb8180141a95236ef05 Mon Sep 17 00:00:00 2001
From: Amin Aramoon
Date: Wed, 4 Mar 2026 15:04:34 -0800
Subject: [PATCH] create blocking state change apis

---
 include/cucascade/data/data_batch.hpp |  53 ++-
 src/data/data_batch.cpp               | 129 ++++++
 test/data/test_data_batch.cpp         | 542 ++++++++++++++++++++++++++
 3 files changed, 723 insertions(+), 1 deletion(-)

diff --git a/include/cucascade/data/data_batch.hpp b/include/cucascade/data/data_batch.hpp
index 14fbf43..b0b873f 100644
--- a/include/cucascade/data/data_batch.hpp
+++ b/include/cucascade/data/data_batch.hpp
@@ -271,6 +271,56 @@ class data_batch : public std::enable_shared_from_this<data_batch> {
    * The CV is notified outside of the batch mutex.
    */
   void set_state_change_cv(std::condition_variable* cv);
+
+  /**
+   * @brief Blocking call: wait until the batch can accept a new task, then create it.
+   *
+   * Blocks until the batch leaves in_transit state, then performs the same transition
+   * as try_to_create_task(). Always succeeds when it returns.
+   */
+  void wait_to_create_task();
+
+  /**
+   * @brief Blocking call: wait until a created task can be cancelled, then cancel it.
+   *
+   * Blocks until the batch is in task_created or processing state, then performs
+   * the same transition as try_to_cancel_task(). Always succeeds when it returns.
+   */
+  void wait_to_cancel_task();
+
+  /**
+   * @brief Blocking call: wait until the batch can be locked for processing, then lock it.
+   *
+   * Blocks until the batch is in task_created or processing state with a pending
+   * task_created_count. Non-waitable failures (missing_data, memory_space_mismatch)
+   * are returned immediately with the appropriate status.
+   *
+   * @param requested_memory_space The memory space the caller expects to process from.
+   * @return lock_for_processing_result success=true with handle on success; success=false
+   *         with status describing non-waitable failure.
+   */
+  lock_for_processing_result wait_to_lock_for_processing(
+    memory::memory_space_id requested_memory_space);
+
+  /**
+   * @brief Blocking call: wait until the batch can be locked for in-transit, then lock it.
+   *
+   * Blocks until processing_count == 0 and the batch is in idle or task_created state,
+   * then performs the same transition as try_to_lock_for_in_transit(). Always succeeds
+   * when it returns.
+   */
+  void wait_to_lock_for_in_transit();
+
+  /**
+   * @brief Blocking call: wait until the batch is in in_transit state, then release it.
+   *
+   * Blocks until the batch is in in_transit state, then performs the same transition
+   * as try_to_release_in_transit(). Always succeeds when it returns.
+   *
+   * @param target_state Optional state to transition to when releasing in_transit. If not set,
+   *                     the batch returns to idle.
+   */
+  void wait_to_release_in_transit(std::optional<batch_state> target_state = std::nullopt);
   /**
    * @brief Replace the underlying data representation.
    * Requires no active processing.
@@ -417,7 +467,8 @@ class data_batch : public std::enable_shared_from_this<data_batch> {
   void decrement_processing_count();
 
   mutable std::mutex _mutex;  ///< Mutex for thread-safe access to state and processing count
-  uint64_t _batch_id;         ///< Unique identifier for this data batch
+  std::condition_variable _internal_cv;  ///< CV used by blocking wait_to_* calls
+  uint64_t _batch_id;                    ///< Unique identifier for this data batch
   std::unique_ptr<data_representation> _data;  ///< Pointer to the actual data representation
   size_t _processing_count = 0;    ///< Count of active processing handles
   size_t _task_created_count = 0;  ///< Count of pending task_created requests
diff --git a/src/data/data_batch.cpp b/src/data/data_batch.cpp
index dc7ad33..f1284eb 100644
--- a/src/data/data_batch.cpp
+++ b/src/data/data_batch.cpp
@@ -130,10 +130,36 @@ bool data_batch::try_to_create_task()
       success = true;
     }
   }
+  if (should_notify) { _internal_cv.notify_all(); }
   if (should_notify && cv_to_notify) { cv_to_notify->notify_all(); }
   return success;
 }
 
+void data_batch::wait_to_create_task()
+{
+  std::condition_variable* cv_to_notify = nullptr;
+  bool should_notify = false;
+  {
+    std::unique_lock lock(_mutex);
+    _internal_cv.wait(lock, [&] { return _state != batch_state::in_transit; });
+    if (_state == batch_state::idle) {
+      _state = batch_state::task_created;
+      ++_task_created_count;
+      should_notify = true;
+      cv_to_notify = _state_change_cv;
+    } else {
+      // task_created or processing: only the counter changes, but _internal_cv
+      // waiters (e.g. wait_to_lock_for_processing) poll _task_created_count, so
+      // they must still be woken here or they can miss the update and deadlock.
+      // NOTE(review): try_to_create_task() has the same counter-only path and
+      // likely needs the same wake-up; its interior is outside this diff.
+      ++_task_created_count;
+      should_notify = true;  // wakes _internal_cv only; cv_to_notify stays null
+    }
+  }
+  if (should_notify) { _internal_cv.notify_all(); }
+  if (should_notify && cv_to_notify) { cv_to_notify->notify_all(); }
+}
+
 size_t data_batch::get_task_created_count() const
 {
   std::lock_guard lock(_mutex);
@@ -162,9 +188,35 @@ bool data_batch::try_to_cancel_task()
       success = true;
     }
   }
+  if (should_notify) { _internal_cv.notify_all(); }
   if (should_notify && cv_to_notify) { cv_to_notify->notify_all(); }
   return success;
 }
+
+void data_batch::wait_to_cancel_task()
+{
+  std::condition_variable* cv_to_notify = nullptr;
+  bool should_notify = false;
+  {
+    std::unique_lock lock(_mutex);
+    _internal_cv.wait(lock, [&] {
+      return _state == batch_state::task_created || _state == batch_state::processing;
+    });
+    if (_task_created_count == 0) {
+      throw std::runtime_error(
+        "Cannot cancel task: task_created_count is zero. "
+        "try_to_create_task() must be called before wait_to_cancel_task()");
+    }
+    --_task_created_count;
+    if (_task_created_count == 0 && _processing_count == 0) {
+      _state = batch_state::idle;
+      should_notify = true;
+      cv_to_notify = _state_change_cv;
+    }
+  }
+  if (should_notify) { _internal_cv.notify_all(); }
+  if (should_notify && cv_to_notify) { cv_to_notify->notify_all(); }
+}
 lock_for_processing_result data_batch::try_to_lock_for_processing(
   memory::memory_space_id requested_memory_space)
 {
@@ -202,10 +254,58 @@ lock_for_processing_result data_batch::try_to_lock_for_processing(
     result = {
       true, data_batch_processing_handle{shared_from_this()}, lock_for_processing_status::success};
   }
+  if (should_notify) { _internal_cv.notify_all(); }
   if (should_notify && cv_to_notify) { cv_to_notify->notify_all(); }
   return result;
 }
+
+lock_for_processing_result data_batch::wait_to_lock_for_processing(
+  memory::memory_space_id requested_memory_space)
+{
+  std::condition_variable* cv_to_notify = nullptr;
+  lock_for_processing_result result{
+    false, data_batch_processing_handle{}, lock_for_processing_status::not_attempted};
+  {
+    std::unique_lock lock(_mutex);
+
+    // Return immediately for failures that cannot be resolved by waiting
+    if (_data == nullptr) {
+      result.status = lock_for_processing_status::missing_data;
+      return result;
+    }
+    if (_data->get_memory_space().get_id() != requested_memory_space) {
+      result.status = lock_for_processing_status::memory_space_mismatch;
+      return result;
+    }
+
+    // Wait until the state allows locking for processing
+    _internal_cv.wait(lock, [&] {
+      return (_state == batch_state::task_created || _state == batch_state::processing) &&
+             _task_created_count > 0;
+    });
+
+    // Re-check non-waitable conditions after waking (data may have changed)
+    if (_data == nullptr) {
+      result.status = lock_for_processing_status::missing_data;
+      return result;
+    }
+    if (_data->get_memory_space().get_id() != requested_memory_space) {
+      result.status = lock_for_processing_status::memory_space_mismatch;
+      return result;
+    }
+
+    --_task_created_count;
+    ++_processing_count;
+    _state = batch_state::processing;
+    cv_to_notify = _state_change_cv;
+    result = {
+      true, data_batch_processing_handle{shared_from_this()}, lock_for_processing_status::success};
+  }
+  _internal_cv.notify_all();
+  if (cv_to_notify) { cv_to_notify->notify_all(); }
+  return result;
+}
 
 bool data_batch::try_to_lock_for_in_transit()
 {
   std::condition_variable* cv_to_notify = nullptr;
@@ -222,10 +322,28 @@ bool data_batch::try_to_lock_for_in_transit()
       success = true;
     }
   }
+  if (should_notify) { _internal_cv.notify_all(); }
   if (should_notify && cv_to_notify) { cv_to_notify->notify_all(); }
   return success;
 }
 
+void data_batch::wait_to_lock_for_in_transit()
+{
+  std::condition_variable* cv_to_notify = nullptr;
+  {
+    std::unique_lock lock(_mutex);
+    _internal_cv.wait(lock, [&] {
+      return _processing_count == 0 &&
+             ((_state == batch_state::idle) ||
+              (_state == batch_state::task_created && _task_created_count > 0));
+    });
+    _state = batch_state::in_transit;
+    cv_to_notify = _state_change_cv;
+  }
+  _internal_cv.notify_all();
+  if (cv_to_notify) { cv_to_notify->notify_all(); }
+}
+
 bool data_batch::try_to_release_in_transit(std::optional<batch_state> target_state)
 {
   std::condition_variable* cv_to_notify = nullptr;
@@ -245,10 +363,24 @@ bool data_batch::try_to_release_in_transit(std::optional<batch_state> target_sta
       success = true;
     }
   }
+  if (should_notify) { _internal_cv.notify_all(); }
   if (should_notify && cv_to_notify) { cv_to_notify->notify_all(); }
   return success;
 }
 
+void data_batch::wait_to_release_in_transit(std::optional<batch_state> target_state)
+{
+  std::condition_variable* cv_to_notify = nullptr;
+  {
+    std::unique_lock lock(_mutex);
+    _internal_cv.wait(lock, [&] { return _state == batch_state::in_transit; });
+    _state = target_state.has_value() ? *target_state : batch_state::idle;
+    cv_to_notify = _state_change_cv;
+  }
+  _internal_cv.notify_all();
+  if (cv_to_notify) { cv_to_notify->notify_all(); }
+}
+
 void data_batch::decrement_processing_count()
 {
   std::condition_variable* cv_to_notify = nullptr;
@@ -272,6 +404,7 @@ void data_batch::decrement_processing_count()
       cv_to_notify = _state_change_cv;
     }
   }
+  if (should_notify) { _internal_cv.notify_all(); }
   if (should_notify && cv_to_notify) { cv_to_notify->notify_all(); }
 }
 
diff --git a/test/data/test_data_batch.cpp b/test/data/test_data_batch.cpp
index e874360..a00b051 100644
--- a/test/data/test_data_batch.cpp
+++ b/test/data/test_data_batch.cpp
@@ -25,6 +25,8 @@
 #include <catch2/catch_test_macros.hpp>
 
+#include <atomic>
+#include <thread>
 #include <chrono>
 #include <memory>
 #include <string>
 #include <vector>
@@ -1103,3 +1105,547 @@ TEST_CASE("data_batch clone state transitions correctly", "[data_batch][gpu]")
     REQUIRE(cloned->get_state() == batch_state::idle);
   }
 }
+
+// =============================================================================
+// Blocking wait_to_* method tests
+// =============================================================================
+
+// NOTE(review): template arguments inside <...> were stripped when this patch
+// was pasted; <bool>, <data_batch>, and the mock data type name below are
+// reconstructed -- verify against the original commit.
+
+// Helper: wait for a flag with a short timeout, used to confirm blocking
+static void wait_for_flag(const std::atomic<bool>& flag)
+{
+  for (int i = 0; i < 200; ++i) {
+    if (flag.load()) return;
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+  }
+}
+
+// --- wait_to_create_task ---
+
+TEST_CASE("data_batch wait_to_create_task from idle succeeds immediately", "[data_batch][blocking]")
+{
+  auto data = std::make_unique<mock_data_representation>(memory::Tier::GPU, 1024);
+  auto batch = std::make_shared<data_batch>(1, std::move(data));
+
+  REQUIRE(batch->get_state() == batch_state::idle);
+  batch->wait_to_create_task();
+  REQUIRE(batch->get_state() == batch_state::task_created);
+  REQUIRE(batch->get_task_created_count() ==
1); +} + +TEST_CASE("data_batch wait_to_create_task from task_created increments counter", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + + REQUIRE(batch->try_to_create_task() == true); + REQUIRE(batch->get_task_created_count() == 1); + + batch->wait_to_create_task(); + REQUIRE(batch->get_state() == batch_state::task_created); + REQUIRE(batch->get_task_created_count() == 2); +} + +TEST_CASE("data_batch wait_to_create_task from processing increments counter", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + auto space_id = batch->get_memory_space()->get_id(); + + REQUIRE(batch->try_to_create_task() == true); + auto r = batch->try_to_lock_for_processing(space_id); + REQUIRE(r.success == true); + auto handle = std::move(r.handle); + REQUIRE(batch->get_state() == batch_state::processing); + + batch->wait_to_create_task(); + REQUIRE(batch->get_state() == batch_state::processing); + REQUIRE(batch->get_task_created_count() == 1); +} + +TEST_CASE("data_batch wait_to_create_task blocks while in_transit then succeeds", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + + REQUIRE(batch->try_to_lock_for_in_transit() == true); + REQUIRE(batch->get_state() == batch_state::in_transit); + + std::atomic wait_started{false}; + std::atomic wait_done{false}; + + std::thread waiter([batch, &wait_started, &wait_done]() { + wait_started.store(true); + batch->wait_to_create_task(); + wait_done.store(true); + }); + + wait_for_flag(wait_started); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + REQUIRE(wait_done.load() == false); // still blocked + + REQUIRE(batch->try_to_release_in_transit() == true); + waiter.join(); + + REQUIRE(wait_done.load() == true); + REQUIRE(batch->get_state() == 
batch_state::task_created); + REQUIRE(batch->get_task_created_count() == 1); +} + +// --- wait_to_cancel_task --- + +TEST_CASE("data_batch wait_to_cancel_task from task_created returns to idle", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + + REQUIRE(batch->try_to_create_task() == true); + REQUIRE(batch->get_state() == batch_state::task_created); + + batch->wait_to_cancel_task(); + REQUIRE(batch->get_state() == batch_state::idle); + REQUIRE(batch->get_task_created_count() == 0); +} + +TEST_CASE("data_batch wait_to_cancel_task with multiple tasks only decrements by one", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + + REQUIRE(batch->try_to_create_task() == true); + REQUIRE(batch->try_to_create_task() == true); + REQUIRE(batch->get_task_created_count() == 2); + + batch->wait_to_cancel_task(); + REQUIRE(batch->get_state() == batch_state::task_created); + REQUIRE(batch->get_task_created_count() == 1); +} + +TEST_CASE("data_batch wait_to_cancel_task from processing decrements counter", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + auto space_id = batch->get_memory_space()->get_id(); + + REQUIRE(batch->try_to_create_task() == true); + REQUIRE(batch->try_to_create_task() == true); + auto r = batch->try_to_lock_for_processing(space_id); + REQUIRE(r.success == true); + auto handle = std::move(r.handle); + REQUIRE(batch->get_state() == batch_state::processing); + REQUIRE(batch->get_task_created_count() == 1); + + batch->wait_to_cancel_task(); + REQUIRE(batch->get_state() == batch_state::processing); + REQUIRE(batch->get_task_created_count() == 0); +} + +TEST_CASE("data_batch wait_to_cancel_task blocks in idle until task is created", + "[data_batch][blocking]") +{ + auto data = 
std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + + // Pre-create a task so cancel has something to cancel (count > 0 after unblocking) + REQUIRE(batch->try_to_create_task() == true); + // Force it back to in_transit so wait_to_cancel_task has to wait + REQUIRE(batch->try_to_lock_for_in_transit() == true); + REQUIRE(batch->get_state() == batch_state::in_transit); + + std::atomic wait_started{false}; + std::atomic wait_done{false}; + + std::thread waiter([batch, &wait_started, &wait_done]() { + wait_started.store(true); + batch->wait_to_cancel_task(); + wait_done.store(true); + }); + + wait_for_flag(wait_started); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + REQUIRE(wait_done.load() == false); // still blocked + + // Release in_transit back to task_created so the cancel can proceed + REQUIRE(batch->try_to_release_in_transit(batch_state::task_created) == true); + waiter.join(); + + REQUIRE(wait_done.load() == true); + REQUIRE(batch->get_state() == batch_state::idle); + REQUIRE(batch->get_task_created_count() == 0); +} + +// --- wait_to_lock_for_processing --- + +TEST_CASE("data_batch wait_to_lock_for_processing from task_created succeeds immediately", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + auto space_id = batch->get_memory_space()->get_id(); + + REQUIRE(batch->try_to_create_task() == true); + + auto result = batch->wait_to_lock_for_processing(space_id); + REQUIRE(result.success == true); + REQUIRE(result.status == lock_for_processing_status::success); + REQUIRE(result.handle.valid() == true); + REQUIRE(batch->get_state() == batch_state::processing); + REQUIRE(batch->get_processing_count() == 1); + REQUIRE(batch->get_task_created_count() == 0); +} + +TEST_CASE("data_batch wait_to_lock_for_processing returns immediately for missing_data", + "[data_batch][blocking]") +{ + auto data = 
std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + batch->set_data(nullptr); + + // Fabricate a space_id; doesn't matter since it should fail before checking state + auto dummy_space = make_mock_memory_space(memory::Tier::GPU, 0); + auto dummy_id = dummy_space->get_id(); + + auto result = batch->wait_to_lock_for_processing(dummy_id); + REQUIRE(result.success == false); + REQUIRE(result.status == lock_for_processing_status::missing_data); +} + +TEST_CASE("data_batch wait_to_lock_for_processing returns immediately for memory_space_mismatch", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + auto wrong_space = make_mock_memory_space(memory::Tier::HOST, 0); + + auto result = batch->wait_to_lock_for_processing(wrong_space->get_id()); + REQUIRE(result.success == false); + REQUIRE(result.status == lock_for_processing_status::memory_space_mismatch); +} + +TEST_CASE("data_batch wait_to_lock_for_processing blocks in idle until task is created", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + auto space_id = batch->get_memory_space()->get_id(); + + REQUIRE(batch->get_state() == batch_state::idle); + + std::atomic wait_started{false}; + std::atomic wait_done{false}; + lock_for_processing_result result; + + std::thread waiter([batch, space_id, &wait_started, &wait_done, &result]() { + wait_started.store(true); + result = batch->wait_to_lock_for_processing(space_id); + wait_done.store(true); + }); + + wait_for_flag(wait_started); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + REQUIRE(wait_done.load() == false); // still blocked + + REQUIRE(batch->try_to_create_task() == true); + waiter.join(); + + REQUIRE(wait_done.load() == true); + REQUIRE(result.success == true); + REQUIRE(result.status == lock_for_processing_status::success); + 
REQUIRE(batch->get_state() == batch_state::processing); + REQUIRE(batch->get_processing_count() == 1); +} + +TEST_CASE("data_batch wait_to_lock_for_processing blocks in in_transit until released", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + auto space_id = batch->get_memory_space()->get_id(); + + REQUIRE(batch->try_to_create_task() == true); + REQUIRE(batch->try_to_lock_for_in_transit() == true); + REQUIRE(batch->get_state() == batch_state::in_transit); + + std::atomic wait_started{false}; + std::atomic wait_done{false}; + lock_for_processing_result result; + + std::thread waiter([batch, space_id, &wait_started, &wait_done, &result]() { + wait_started.store(true); + result = batch->wait_to_lock_for_processing(space_id); + wait_done.store(true); + }); + + wait_for_flag(wait_started); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + REQUIRE(wait_done.load() == false); // still blocked + + // Release transit back to task_created (count is still 1) + REQUIRE(batch->try_to_release_in_transit(batch_state::task_created) == true); + waiter.join(); + + REQUIRE(wait_done.load() == true); + REQUIRE(result.success == true); + REQUIRE(batch->get_state() == batch_state::processing); +} + +// --- wait_to_lock_for_in_transit --- + +TEST_CASE("data_batch wait_to_lock_for_in_transit from idle succeeds immediately", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + + REQUIRE(batch->get_state() == batch_state::idle); + batch->wait_to_lock_for_in_transit(); + REQUIRE(batch->get_state() == batch_state::in_transit); +} + +TEST_CASE("data_batch wait_to_lock_for_in_transit from task_created succeeds immediately", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + + 
REQUIRE(batch->try_to_create_task() == true); + REQUIRE(batch->get_state() == batch_state::task_created); + REQUIRE(batch->get_task_created_count() == 1); + + batch->wait_to_lock_for_in_transit(); + REQUIRE(batch->get_state() == batch_state::in_transit); +} + +TEST_CASE("data_batch wait_to_lock_for_in_transit blocks while processing then succeeds", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + auto space_id = batch->get_memory_space()->get_id(); + + REQUIRE(batch->try_to_create_task() == true); + auto r = batch->try_to_lock_for_processing(space_id); + REQUIRE(r.success == true); + auto handle = std::move(r.handle); + REQUIRE(batch->get_state() == batch_state::processing); + REQUIRE(batch->get_processing_count() == 1); + + std::atomic wait_started{false}; + std::atomic wait_done{false}; + + std::thread waiter([batch, &wait_started, &wait_done]() { + wait_started.store(true); + batch->wait_to_lock_for_in_transit(); + wait_done.store(true); + }); + + wait_for_flag(wait_started); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + REQUIRE(wait_done.load() == false); // still blocked + + handle.release(); // decrement processing count → idle + waiter.join(); + + REQUIRE(wait_done.load() == true); + REQUIRE(batch->get_state() == batch_state::in_transit); + REQUIRE(batch->get_processing_count() == 0); +} + +TEST_CASE("data_batch wait_to_lock_for_in_transit blocks with multiple processing handles", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + auto space_id = batch->get_memory_space()->get_id(); + + // Acquire two processing handles + REQUIRE(batch->try_to_create_task() == true); + REQUIRE(batch->try_to_create_task() == true); + auto r1 = batch->try_to_lock_for_processing(space_id); + auto r2 = batch->try_to_lock_for_processing(space_id); + REQUIRE(r1.success == true); + 
REQUIRE(r2.success == true); + auto h1 = std::move(r1.handle); + auto h2 = std::move(r2.handle); + REQUIRE(batch->get_processing_count() == 2); + + std::atomic wait_started{false}; + std::atomic wait_done{false}; + + std::thread waiter([batch, &wait_started, &wait_done]() { + wait_started.store(true); + batch->wait_to_lock_for_in_transit(); + wait_done.store(true); + }); + + wait_for_flag(wait_started); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + REQUIRE(wait_done.load() == false); + + h1.release(); // count → 1, still blocked + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + REQUIRE(wait_done.load() == false); + + h2.release(); // count → 0, unblocks + waiter.join(); + + REQUIRE(wait_done.load() == true); + REQUIRE(batch->get_state() == batch_state::in_transit); +} + +// --- wait_to_release_in_transit --- + +TEST_CASE("data_batch wait_to_release_in_transit from in_transit returns to idle", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + + REQUIRE(batch->try_to_lock_for_in_transit() == true); + REQUIRE(batch->get_state() == batch_state::in_transit); + + batch->wait_to_release_in_transit(); + REQUIRE(batch->get_state() == batch_state::idle); +} + +TEST_CASE("data_batch wait_to_release_in_transit with explicit target_state", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + + REQUIRE(batch->try_to_create_task() == true); + REQUIRE(batch->try_to_lock_for_in_transit() == true); + REQUIRE(batch->get_state() == batch_state::in_transit); + + batch->wait_to_release_in_transit(batch_state::task_created); + REQUIRE(batch->get_state() == batch_state::task_created); +} + +TEST_CASE("data_batch wait_to_release_in_transit blocks until in_transit", "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = 
std::make_shared(1, std::move(data)); + + REQUIRE(batch->get_state() == batch_state::idle); + + std::atomic wait_started{false}; + std::atomic wait_done{false}; + + std::thread waiter([batch, &wait_started, &wait_done]() { + wait_started.store(true); + batch->wait_to_release_in_transit(); + wait_done.store(true); + }); + + wait_for_flag(wait_started); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + REQUIRE(wait_done.load() == false); // still blocked + + REQUIRE(batch->try_to_lock_for_in_transit() == true); + waiter.join(); + + REQUIRE(wait_done.load() == true); + REQUIRE(batch->get_state() == batch_state::idle); +} + +TEST_CASE("data_batch wait_to_release_in_transit blocks until in_transit with target_state", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + + std::atomic wait_done{false}; + + std::thread waiter([batch, &wait_done]() { + batch->wait_to_release_in_transit(batch_state::task_created); + wait_done.store(true); + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + REQUIRE(wait_done.load() == false); + + REQUIRE(batch->try_to_lock_for_in_transit() == true); + waiter.join(); + + REQUIRE(wait_done.load() == true); + REQUIRE(batch->get_state() == batch_state::task_created); +} + +// --- Full state-machine round trips using blocking calls --- + +TEST_CASE("data_batch blocking calls full round trip idle->task_created->processing->idle", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + auto space_id = batch->get_memory_space()->get_id(); + + REQUIRE(batch->get_state() == batch_state::idle); + + batch->wait_to_create_task(); + REQUIRE(batch->get_state() == batch_state::task_created); + + auto result = batch->wait_to_lock_for_processing(space_id); + REQUIRE(result.success == true); + REQUIRE(batch->get_state() == batch_state::processing); + + 
result.handle.release(); + REQUIRE(batch->get_state() == batch_state::idle); +} + +TEST_CASE("data_batch blocking calls full round trip idle->in_transit->idle", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + + REQUIRE(batch->get_state() == batch_state::idle); + + batch->wait_to_lock_for_in_transit(); + REQUIRE(batch->get_state() == batch_state::in_transit); + + batch->wait_to_release_in_transit(); + REQUIRE(batch->get_state() == batch_state::idle); +} + +TEST_CASE("data_batch concurrent wait_to_create_task and wait_to_lock_for_processing", + "[data_batch][blocking]") +{ + auto data = std::make_unique(memory::Tier::GPU, 1024); + auto batch = std::make_shared(1, std::move(data)); + auto space_id = batch->get_memory_space()->get_id(); + + constexpr int num_rounds = 50; + std::vector handles; + handles.reserve(num_rounds); + + for (int i = 0; i < num_rounds; ++i) { + std::atomic creator_done{false}; + + std::thread creator([batch, &creator_done]() { + batch->wait_to_create_task(); + creator_done.store(true); + }); + + auto result = batch->wait_to_lock_for_processing(space_id); + creator.join(); + + REQUIRE(result.success == true); + handles.push_back(std::move(result.handle)); + } + + REQUIRE(batch->get_processing_count() == num_rounds); + handles.clear(); + REQUIRE(batch->get_processing_count() == 0); + REQUIRE(batch->get_state() == batch_state::idle); +}