diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs index 7166e1f25..9708722f3 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs @@ -1,7 +1,6 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Microsoft.Extensions.DependencyInjection; -using System.Collections.Concurrent; using System.Runtime.InteropServices; using System.Text.Json; using System.Text.Json.Nodes; @@ -179,17 +178,25 @@ public async Task CallToolAsync_AsyncTool_FailedTask_ThrowsMcpException() await using var client = await CreateMcpClientForServer(); var ct = TestContext.Current.CancellationToken; - _ = Task.Run(async () => + var failedTask = new TaskCompletionSource(); + + // Run failure task once the task from the tool call is created + _taskStore.OnTaskCreated += taskId => { - await Task.Delay(100, ct); - var taskId = _taskStore.GetAllTaskIds().Single(); - _taskStore.FailTask(taskId, JsonElement.Parse("""{"code":-32000,"message":"something went wrong"}""")); - }, ct); + _ = Task.Run(async () => + { + await Task.Delay(100, ct); + _taskStore.FailTask(taskId, JsonElement.Parse("""{"code":-32000,"message":"something went wrong"}""")); + failedTask.SetResult(true); + }, ct); + }; await Assert.ThrowsAsync(async () => await client.CallToolAsync( new CallToolRequestParams { Name = "async-tool" }, ct)); + + Assert.True(await failedTask.Task); } [Fact] @@ -198,17 +205,25 @@ public async Task CallToolAsync_AsyncTool_CancelledTask_ThrowsOperationCancelled await using var client = await CreateMcpClientForServer(); var ct = TestContext.Current.CancellationToken; - _ = Task.Run(async () => + var cancelledTask = new TaskCompletionSource(); + + // Run cancellation task once the task from the tool call is created + _taskStore.OnTaskCreated += taskId => { - await Task.Delay(100, ct); - var taskId = _taskStore.GetAllTaskIds().Single(); - _taskStore.CancelTask(taskId); - }, ct); + Task.Run(async () => + { + await Task.Delay(100, ct); + _taskStore.CancelTask(taskId); + cancelledTask.SetResult(true); + }, ct); + }; await Assert.ThrowsAsync(async () => await client.CallToolAsync( new CallToolRequestParams { Name = "async-tool" }, ct)); + + Assert.True(await cancelledTask.Task); } [Fact] @@ -538,106 +553,135 @@ public async Task CallToolHandler_CanBeSetToNull_ThenOtherCanBeSet() /// private sealed class InMemoryTaskStore { - private readonly ConcurrentDictionary _tasks = new(); + private readonly Dictionary _tasks = new(); + + internal Action? OnTaskCreated; public string CreateTask(McpTaskStatus initialStatus = McpTaskStatus.Working) { var taskId = Guid.NewGuid().ToString("N"); - _tasks[taskId] = new TaskEntry + lock (_tasks) { - Status = initialStatus, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow, - }; + _tasks[taskId] = new TaskEntry + { + Status = initialStatus, + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + }; + } + + OnTaskCreated?.Invoke(taskId); + return taskId; } - public IEnumerable GetAllTaskIds() => _tasks.Keys; - - public GetTaskResult GetTask(string taskId) + public IEnumerable GetAllTaskIds() { - if (!_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - throw new McpException($"Unknown task: '{taskId}'"); + return _tasks.Keys.ToArray(); } + } - return entry.Status switch + public GetTaskResult GetTask(string taskId) + { + lock (_tasks) { - McpTaskStatus.Working => new WorkingTaskResult - { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - PollIntervalMs = 50, - }, - McpTaskStatus.Completed => new CompletedTaskResult + if (!_tasks.TryGetValue(taskId, out var entry)) { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - Result = JsonSerializer.SerializeToElement(entry.Result, McpJsonUtilities.DefaultOptions), - }, - McpTaskStatus.Failed => new FailedTaskResult - { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - Error = entry.Error!.Value, - }, - McpTaskStatus.Cancelled => new CancelledTaskResult - { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - }, - McpTaskStatus.InputRequired => new InputRequiredTaskResult + throw new McpException($"Unknown task: '{taskId}'"); + } + + return entry.Status switch { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - InputRequests = entry.InputRequests ?? new Dictionary(), - }, - _ => throw new InvalidOperationException($"Unexpected status: {entry.Status}") - }; + McpTaskStatus.Working => new WorkingTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + PollIntervalMs = 50, + }, + McpTaskStatus.Completed => new CompletedTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + Result = JsonSerializer.SerializeToElement(entry.Result, McpJsonUtilities.DefaultOptions), + }, + McpTaskStatus.Failed => new FailedTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + Error = entry.Error!.Value, + }, + McpTaskStatus.Cancelled => new CancelledTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + }, + McpTaskStatus.InputRequired => new InputRequiredTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + InputRequests = entry.InputRequests ?? new Dictionary(), + }, + _ => throw new InvalidOperationException($"Unexpected status: {entry.Status}") + }; + } } public void CompleteTask(string taskId, CallToolResult result) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.Status = McpTaskStatus.Completed; - entry.Result = result; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.Result = result; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + entry.Status = McpTaskStatus.Completed; + } } } public void FailTask(string taskId, JsonElement error) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.Status = McpTaskStatus.Failed; - entry.Error = error; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.Error = error; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + entry.Status = McpTaskStatus.Failed; + } } } public void CancelTask(string taskId) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.Status = McpTaskStatus.Cancelled; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + entry.Status = McpTaskStatus.Cancelled; + } } } public void ProvideInput(string taskId, IDictionary inputResponses) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.InputResponses = inputResponses; - // Transition back to working after receiving input - entry.Status = McpTaskStatus.Working; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.InputResponses = inputResponses; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + // Transition back to working after receiving input + entry.Status = McpTaskStatus.Working; + } } } diff --git a/tests/ModelContextProtocol.Tests/Server/TaskPollStuckDetectorTests.cs b/tests/ModelContextProtocol.Tests/Server/TaskPollStuckDetectorTests.cs index 2f00d925a..1d17be223 100644 --- a/tests/ModelContextProtocol.Tests/Server/TaskPollStuckDetectorTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/TaskPollStuckDetectorTests.cs @@ -15,6 +15,8 @@ namespace ModelContextProtocol.Tests.Server; /// public class TaskPollStuckDetectorTests : ClientServerTestBase { + private int _pollCount = 0; + public TaskPollStuckDetectorTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { #if !NET @@ -48,6 +50,8 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer // misbehaving-server condition the stuck-detector exists to break out of. options.Handlers.GetTaskHandler = (context, cancellationToken) => { + Interlocked.Increment(ref _pollCount); + return new ValueTask(new InputRequiredTaskResult { TaskId = context.Params!.TaskId, @@ -77,19 +81,13 @@ public async Task CallToolAsync_TaskStuckInInputRequired_WithoutNewRequests_Thro await using var client = await CreateMcpClientForServer(); var ct = TestContext.Current.CancellationToken; - var sw = System.Diagnostics.Stopwatch.StartNew(); - var ex = await Assert.ThrowsAsync(async () => await client.CallToolAsync(new CallToolRequestParams { Name = "any-tool" }, ct)); - sw.Stop(); - Assert.Contains(McpTaskStatus.InputRequired.ToString(), ex.Message); Assert.Contains("consecutive polls", ex.Message); - // 60 polls × 5ms ≈ 300ms; allow generous slack for CI. - Assert.True(sw.Elapsed < TimeSpan.FromSeconds(10), - $"Stuck-detector should give up promptly but took {sw.Elapsed}."); + Assert.Equal(60, _pollCount); } [Fact] @@ -111,6 +109,7 @@ public async Task CallToolAsync_StuckDetector_HonorsConfiguredThreshold() // The message embeds the configured threshold, which is the strongest signal that the // option value (not the 60-default constant) is what governed the loop. Assert.Contains($"{CustomThreshold} consecutive polls", ex.Message); + Assert.Equal(CustomThreshold, _pollCount); } [Theory]