diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java index e1e33d0f..c3774268 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java @@ -11,8 +11,11 @@ import org.junit.jupiter.api.Test; import software.amazon.lambda.durable.config.CompletionConfig; import software.amazon.lambda.durable.config.ParallelConfig; +import software.amazon.lambda.durable.config.WaitForConditionConfig; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.ExecutionStatus; +import software.amazon.lambda.durable.model.WaitForConditionResult; +import software.amazon.lambda.durable.retry.WaitStrategies; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; import software.amazon.lambda.durable.testing.TestOperation; @@ -227,11 +230,6 @@ void testParallelReplayAfterInterruption_cachedResultsUsed() { assertEquals("A,B,C", result1.getResult(String.class)); var firstRunCount = executionCounts.get(); assertTrue(firstRunCount >= 3, "Expected at least 3 executions on first run but got " + firstRunCount); - - var result2 = runner.runUntilComplete("test"); - assertEquals(ExecutionStatus.SUCCEEDED, result2.getStatus()); - assertEquals("A,B,C", result2.getResult(String.class)); - assertEquals(firstRunCount, executionCounts.get(), "Branch functions should not re-execute on replay"); } @Test @@ -538,6 +536,412 @@ void testParallelResultSummary_succeededAndFailedCounts() { assertEquals("3/2", result.getResult(String.class)); } + // ---- 50-branch parallel tests with waitForCallback ---- + + @Test + void testParallel50BranchesWithWaitForCallback() { + var branchCount = 50; + + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var config = ParallelConfig.builder().build(); + var futures = new ArrayList>(); + var parallel = context.parallel("50-callbacks", config); + + try (parallel) { + for (int i = 0; i < branchCount; i++) { + var idx = i; + futures.add(parallel.branch("branch-" + i, String.class, ctx -> { + return ctx.waitForCallback("approval-" + idx, String.class, (callbackId, stepCtx) -> {}); + })); + } + } + + var result = parallel.get(); + assertEquals(branchCount, result.size()); + assertEquals(branchCount, result.succeeded()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); + + return futures.stream() + .map(DurableFuture::get) + .reduce((a, b) -> a + "," + b) + .orElse(""); + }); + + // First run — all branches create callbacks and suspend + var result = runner.run("test"); + assertEquals(ExecutionStatus.PENDING, result.getStatus()); + + // Complete all 50 callbacks + for (int i = 0; i < branchCount; i++) { + var callbackId = runner.getCallbackId("approval-" + i + "-callback"); + assertNotNull(callbackId, "Callback ID should exist for approval-" + i + "-callback"); + runner.completeCallback(callbackId, "\"result-" + i + "\""); + } + + // Re-run — all callbacks resolved, execution completes + result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + + // Verify all 50 results are present + var output = result.getResult(String.class); + for (int i = 0; i < branchCount; i++) { + assertTrue(output.contains("result-" + i), "Output should contain result-" + i); + } + } + + @Test + void testParallel50BranchesWithWaitForCallback_maxConcurrency5() { + var branchCount = 50; + + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var config = ParallelConfig.builder().maxConcurrency(5).build(); + var futures = new ArrayList>(); + var parallel = context.parallel("50-callbacks-limited", config); + + try (parallel) { + for (int i = 0; i < branchCount; i++) { + var idx = i; + futures.add(parallel.branch("branch-" + i, String.class, ctx -> { + return ctx.waitForCallback("cb-" + idx, String.class, (callbackId, stepCtx) -> {}); + })); + } + } + + var result = parallel.get(); + assertEquals(branchCount, result.succeeded()); + return String.valueOf(result.succeeded()); + }); + + // First run — suspends on callbacks + var result = runner.run("test"); + assertEquals(ExecutionStatus.PENDING, result.getStatus()); + + // Complete callbacks in batches, re-running between batches to let concurrency-limited branches start + for (int batch = 0; batch < 10; batch++) { + var completed = false; + for (int i = batch * 5; i < (batch + 1) * 5; i++) { + var callbackId = runner.getCallbackId("cb-" + i + "-callback"); + if (callbackId != null) { + runner.completeCallback(callbackId, "\"ok-" + i + "\""); + completed = true; + } + } + if (completed) { + result = runner.run("test"); + if (result.getStatus() == ExecutionStatus.SUCCEEDED) break; + } + } + + // Final run to ensure completion + result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + assertEquals("50", result.getResult(String.class)); + } + + @Test + void testParallel50BranchesWithWaitForCallback_partialFailure() { + var branchCount = 50; + + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var config = ParallelConfig.builder().build(); + var futures = new ArrayList>(); + var parallel = context.parallel("50-callbacks-partial-fail", config); + + try (parallel) { + for (int i = 0; i < branchCount; i++) { + var idx = i; + futures.add(parallel.branch("branch-" + i, String.class, ctx -> { + return ctx.waitForCallback("approval-" + idx, String.class, (callbackId, stepCtx) -> {}); + })); + } + } + + var result = parallel.get(); + assertEquals(branchCount, result.size()); + // Even-indexed branches succeed, odd-indexed branches fail + assertEquals(25, result.succeeded()); + assertEquals(25, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); + + return result.succeeded() + "/" + result.failed(); + }); + + // First run — all branches create callbacks and suspend + var result = runner.run("test"); + assertEquals(ExecutionStatus.PENDING, result.getStatus()); + + // Complete even-indexed callbacks, fail odd-indexed ones + for (int i = 0; i < branchCount; i++) { + var callbackId = runner.getCallbackId("approval-" + i + "-callback"); + assertNotNull(callbackId, "Callback ID should exist for approval-" + i); + if (i % 2 == 0) { + runner.completeCallback(callbackId, "\"ok-" + i + "\""); + } else { + runner.failCallback( + callbackId, + software.amazon.awssdk.services.lambda.model.ErrorObject.builder() + .errorType("Rejected") + .errorMessage("Branch " + i + " rejected") + .build()); + } + } + + result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + assertEquals("25/25", result.getResult(String.class)); + } + + @Test + void testParallel50BranchesWithWaitForCallback_stepsBeforeAndAfterCallback() { + var branchCount = 50; + + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var config = ParallelConfig.builder().build(); + var futures = new ArrayList>(); + var parallel = context.parallel("50-callbacks-with-steps", config); + + try (parallel) { + for (int i = 0; i < branchCount; i++) { + var idx = i; + futures.add(parallel.branch("branch-" + i, String.class, ctx -> { + var before = ctx.step("prepare-" + idx, String.class, stepCtx -> "prepared-" + idx); + var approval = + ctx.waitForCallback("approval-" + idx, String.class, (callbackId, stepCtx) -> {}); + return ctx.step("finalize-" + idx, String.class, stepCtx -> before + ":" + approval + ":done"); + })); + } + } + + var result = parallel.get(); + assertEquals(branchCount, result.succeeded()); + return String.valueOf(result.succeeded()); + }); + + // First run — branches execute prepare step, create callbacks, suspend + var result = runner.run("test"); + assertEquals(ExecutionStatus.PENDING, result.getStatus()); + + // Complete all callbacks + for (int i = 0; i < branchCount; i++) { + var callbackId = runner.getCallbackId("approval-" + i + "-callback"); + assertNotNull(callbackId, "Callback ID should exist for approval-" + i); + runner.completeCallback(callbackId, "\"approved-" + i + "\""); + } + + // Re-run — callbacks resolved, finalize steps execute + result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + assertEquals("50", result.getResult(String.class)); + } + + // ---- 50-branch parallel tests with waitForCondition ---- + + @Test + void testParallel50BranchesWithWaitForCondition() { + var branchCount = 50; + var checkCounts = new AtomicInteger(0); + + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var config = ParallelConfig.builder().build(); + var futures = new ArrayList>(); + var parallel = context.parallel("50-conditions", config); + + try (parallel) { + for (int i = 0; i < branchCount; i++) { + var targetChecks = (i % 3) + 1; // 1, 2, or 3 checks to complete + futures.add(parallel.branch("branch-" + i, Integer.class, ctx -> { + var strategy = WaitStrategies.fixedDelay(10, Duration.ofSeconds(1)); + var wfcConfig = WaitForConditionConfig.builder() + .waitStrategy(strategy) + .build(); + + return ctx.waitForCondition( + "poll-" + targetChecks, + Integer.class, + (state, stepCtx) -> { + checkCounts.incrementAndGet(); + var next = (state == null ? 0 : state) + 1; + return next >= targetChecks + ? WaitForConditionResult.stopPolling(next) + : WaitForConditionResult.continuePolling(next); + }, + wfcConfig); + })); + } + } + + var result = parallel.get(); + assertEquals(branchCount, result.size()); + assertEquals(branchCount, result.succeeded()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); + + var sum = futures.stream().mapToInt(DurableFuture::get).sum(); + return String.valueOf(sum); + }); + + var result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + + // Each branch completes after (i%3)+1 checks: 17 branches need 1, 17 need 2, 16 need 3 + // Sum of results: 17*1 + 17*2 + 16*3 = 17 + 34 + 48 = 99 + assertEquals("99", result.getResult(String.class)); + assertTrue(checkCounts.get() >= branchCount, "Should have at least " + branchCount + " checks"); + } + + @Test + void testParallel50BranchesWithWaitForCondition_someExceedMaxAttempts() { + var branchCount = 50; + + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var config = ParallelConfig.builder().build(); + var parallel = context.parallel("50-conditions-some-fail", config); + + try (parallel) { + for (int i = 0; i < branchCount; i++) { + var idx = i; + parallel.branch("branch-" + i, Integer.class, ctx -> { + // Odd branches: maxAttempts=1 but need 2 checks → will fail + // Even branches: maxAttempts=5, need 2 checks → will succeed + var maxAttempts = (idx % 2 == 0) ? 5 : 1; + var strategy = WaitStrategies.fixedDelay(maxAttempts, Duration.ofSeconds(1)); + var wfcConfig = WaitForConditionConfig.builder() + .waitStrategy(strategy) + .build(); + + return ctx.waitForCondition( + "poll-" + idx, + Integer.class, + (state, stepCtx) -> { + var next = (state == null ? 0 : state) + 1; + return next >= 2 + ? WaitForConditionResult.stopPolling(next) + : WaitForConditionResult.continuePolling(next); + }, + wfcConfig); + }); + } + } + + var result = parallel.get(); + assertEquals(branchCount, result.size()); + assertEquals(25, result.succeeded()); + assertEquals(25, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); + + return result.succeeded() + "/" + result.failed(); + }); + + var result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + assertEquals("25/25", result.getResult(String.class)); + } + + @Test + void testParallel50BranchesWithWaitForCondition_replay() { + var branchCount = 50; + var checkCounts = new AtomicInteger(0); + + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var config = ParallelConfig.builder().build(); + var parallel = context.parallel("50-conditions-replay", config); + + try (parallel) { + for (int i = 0; i < branchCount; i++) { + parallel.branch("branch-" + i, String.class, ctx -> { + var strategy = WaitStrategies.fixedDelay(5, Duration.ofSeconds(1)); + var wfcConfig = WaitForConditionConfig.builder() + .waitStrategy(strategy) + .build(); + + var polled = ctx.waitForCondition( + "poll", + Integer.class, + (state, stepCtx) -> { + checkCounts.incrementAndGet(); + return WaitForConditionResult.stopPolling(1); + }, + wfcConfig); + + return String.valueOf(polled); + }); + } + } + + parallel.get(); + return "done"; + }); + + var result1 = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result1.getStatus()); + var firstRunChecks = checkCounts.get(); + assertEquals(branchCount, firstRunChecks); + + // Replay — check functions should not re-execute + var result2 = runner.run("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result2.getStatus()); + assertEquals(firstRunChecks, checkCounts.get(), "Check functions should not re-execute on replay"); + } + + // ---- 50-branch parallel tests mixing waitForCallback and waitForCondition ---- + + @Test + void testParallel50BranchesMixed_callbackAndCondition() { + var branchCount = 50; + + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var config = ParallelConfig.builder().build(); + var futures = new ArrayList>(); + var parallel = context.parallel("50-mixed", config); + + try (parallel) { + for (int i = 0; i < branchCount; i++) { + var idx = i; + if (i % 2 == 0) { + // Even branches: waitForCallback + futures.add(parallel.branch("branch-" + i, String.class, ctx -> { + return ctx.waitForCallback("cb-" + idx, String.class, (callbackId, stepCtx) -> {}); + })); + } else { + // Odd branches: waitForCondition + futures.add(parallel.branch("branch-" + i, String.class, ctx -> { + var strategy = WaitStrategies.fixedDelay(5, Duration.ofSeconds(1)); + var wfcConfig = WaitForConditionConfig.builder() + .waitStrategy(strategy) + .build(); + + var polled = ctx.waitForCondition( + "poll-" + idx, + Integer.class, + (state, stepCtx) -> WaitForConditionResult.stopPolling(idx), + wfcConfig); + + return "polled-" + polled; + })); + } + } + } + + var result = parallel.get(); + assertEquals(branchCount, result.size()); + return String.valueOf(result.succeeded()); + }); + + // First run — callback branches suspend, condition branches may complete + var result = runner.run("test"); + + // Complete all callback branches (even-indexed) + for (int i = 0; i < branchCount; i += 2) { + var callbackId = runner.getCallbackId("cb-" + i + "-callback"); + if (callbackId != null) { + runner.completeCallback(callbackId, "\"callback-" + i + "\""); + } + } + + result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + assertEquals("50", result.getResult(String.class)); + } + @Test void testParallelWithToleratedFailureCount_earlyTermination() { var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { @@ -637,4 +1041,33 @@ void testParallelWithAllSuccessful_stopsOnFirstFailure() { var result = runner.runUntilComplete("test"); assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); } + + @Test + void testParallelWithFirstSuccessful_earlyTermination() { + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + // Use unlimited concurrency so all branches start before early termination fires, + // avoiding mid-execution suspension that would leave the runner PENDING + var config = ParallelConfig.builder() + .completionConfig(CompletionConfig.firstSuccessful()) + .build(); + var parallel = context.parallel("first-successful", config); + + try (parallel) { + for (var item : List.of("a", "b", "c")) { + parallel.branch("branch-" + item, String.class, ctx -> item.toUpperCase()); + } + } + + var result = parallel.get(); + assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.completionStatus()); + assertTrue(result.completionStatus().isSucceeded()); + assertEquals(3, result.size()); + assertTrue(1 <= result.succeeded()); + + return "done"; + }); + + var result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + } }