Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 117 additions & 73 deletions tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using Microsoft.Extensions.DependencyInjection;
using System.Collections.Concurrent;
using System.Runtime.InteropServices;
using System.Text.Json;
using System.Text.Json.Nodes;
Expand Down Expand Up @@ -179,17 +178,25 @@ public async Task CallToolAsync_AsyncTool_FailedTask_ThrowsMcpException()
await using var client = await CreateMcpClientForServer();
var ct = TestContext.Current.CancellationToken;

_ = Task.Run(async () =>
var failedTask = new TaskCompletionSource<bool>();

// Run failure task once the task from the tool call is created
_taskStore.OnTaskCreated += taskId =>
{
await Task.Delay(100, ct);
var taskId = _taskStore.GetAllTaskIds().Single();
_taskStore.FailTask(taskId, JsonElement.Parse("""{"code":-32000,"message":"something went wrong"}"""));
}, ct);
_ = Task.Run(async () =>
{
await Task.Delay(100, ct);
_taskStore.FailTask(taskId, JsonElement.Parse("""{"code":-32000,"message":"something went wrong"}"""));
failedTask.SetResult(true);
}, ct);
};

await Assert.ThrowsAsync<McpException>(async () =>
await client.CallToolAsync(
new CallToolRequestParams { Name = "async-tool" },
ct));

Assert.True(await failedTask.Task);
}

[Fact]
Expand All @@ -198,17 +205,25 @@ public async Task CallToolAsync_AsyncTool_CancelledTask_ThrowsOperationCancelled
await using var client = await CreateMcpClientForServer();
var ct = TestContext.Current.CancellationToken;

_ = Task.Run(async () =>
var cancelledTask = new TaskCompletionSource<bool>();

// Run cancellation task once the task from the tool call is created
_taskStore.OnTaskCreated += taskId =>
{
await Task.Delay(100, ct);
var taskId = _taskStore.GetAllTaskIds().Single();
_taskStore.CancelTask(taskId);
}, ct);
Task.Run(async () =>
{
await Task.Delay(100, ct);
_taskStore.CancelTask(taskId);
cancelledTask.SetResult(true);
}, ct);
};

await Assert.ThrowsAsync<OperationCanceledException>(async () =>
await client.CallToolAsync(
new CallToolRequestParams { Name = "async-tool" },
ct));

Assert.True(await cancelledTask.Task);
}

[Fact]
Expand Down Expand Up @@ -538,106 +553,135 @@ public async Task CallToolHandler_CanBeSetToNull_ThenOtherCanBeSet()
/// </summary>
private sealed class InMemoryTaskStore
{
private readonly ConcurrentDictionary<string, TaskEntry> _tasks = new();
private readonly Dictionary<string, TaskEntry> _tasks = new();

internal Action<string>? OnTaskCreated;

public string CreateTask(McpTaskStatus initialStatus = McpTaskStatus.Working)
{
var taskId = Guid.NewGuid().ToString("N");
_tasks[taskId] = new TaskEntry
lock (_tasks)
{
Status = initialStatus,
CreatedAt = DateTimeOffset.UtcNow,
LastUpdatedAt = DateTimeOffset.UtcNow,
};
_tasks[taskId] = new TaskEntry
{
Status = initialStatus,
CreatedAt = DateTimeOffset.UtcNow,
LastUpdatedAt = DateTimeOffset.UtcNow,
};
}

OnTaskCreated?.Invoke(taskId);

return taskId;
}

public IEnumerable<string> GetAllTaskIds() => _tasks.Keys;

public GetTaskResult GetTask(string taskId)
public IEnumerable<string> GetAllTaskIds()
{
if (!_tasks.TryGetValue(taskId, out var entry))
lock (_tasks)
{
throw new McpException($"Unknown task: '{taskId}'");
return _tasks.Keys.ToArray();
}
}

return entry.Status switch
public GetTaskResult GetTask(string taskId)
{
lock (_tasks)
{
McpTaskStatus.Working => new WorkingTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
PollIntervalMs = 50,
},
McpTaskStatus.Completed => new CompletedTaskResult
if (!_tasks.TryGetValue(taskId, out var entry))
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
Result = JsonSerializer.SerializeToElement(entry.Result, McpJsonUtilities.DefaultOptions),
},
McpTaskStatus.Failed => new FailedTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
Error = entry.Error!.Value,
},
McpTaskStatus.Cancelled => new CancelledTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
},
McpTaskStatus.InputRequired => new InputRequiredTaskResult
throw new McpException($"Unknown task: '{taskId}'");
}

return entry.Status switch
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
InputRequests = entry.InputRequests ?? new Dictionary<string, InputRequest>(),
},
_ => throw new InvalidOperationException($"Unexpected status: {entry.Status}")
};
McpTaskStatus.Working => new WorkingTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
PollIntervalMs = 50,
},
McpTaskStatus.Completed => new CompletedTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
Result = JsonSerializer.SerializeToElement(entry.Result, McpJsonUtilities.DefaultOptions),
},
McpTaskStatus.Failed => new FailedTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
Error = entry.Error!.Value,
},
McpTaskStatus.Cancelled => new CancelledTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
},
McpTaskStatus.InputRequired => new InputRequiredTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
InputRequests = entry.InputRequests ?? new Dictionary<string, InputRequest>(),
},
_ => throw new InvalidOperationException($"Unexpected status: {entry.Status}")
};
}
}

public void CompleteTask(string taskId, CallToolResult result)
{
if (_tasks.TryGetValue(taskId, out var entry))
lock (_tasks)
{
entry.Status = McpTaskStatus.Completed;
entry.Result = result;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
if (_tasks.TryGetValue(taskId, out var entry))
{
entry.Result = result;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
entry.Status = McpTaskStatus.Completed;
}
}
}

public void FailTask(string taskId, JsonElement error)
{
if (_tasks.TryGetValue(taskId, out var entry))
lock (_tasks)
{
entry.Status = McpTaskStatus.Failed;
entry.Error = error;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
if (_tasks.TryGetValue(taskId, out var entry))
{
entry.Error = error;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
entry.Status = McpTaskStatus.Failed;
}
}
}

public void CancelTask(string taskId)
{
if (_tasks.TryGetValue(taskId, out var entry))
lock (_tasks)
{
entry.Status = McpTaskStatus.Cancelled;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
if (_tasks.TryGetValue(taskId, out var entry))
{
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
entry.Status = McpTaskStatus.Cancelled;
}
}
}

public void ProvideInput(string taskId, IDictionary<string, InputResponse> inputResponses)
{
if (_tasks.TryGetValue(taskId, out var entry))
lock (_tasks)
{
entry.InputResponses = inputResponses;
// Transition back to working after receiving input
entry.Status = McpTaskStatus.Working;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
if (_tasks.TryGetValue(taskId, out var entry))
{
entry.InputResponses = inputResponses;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
// Transition back to working after receiving input
entry.Status = McpTaskStatus.Working;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ namespace ModelContextProtocol.Tests.Server;
/// </summary>
public class TaskPollStuckDetectorTests : ClientServerTestBase
{
private int _pollCount = 0;

public TaskPollStuckDetectorTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper)
{
#if !NET
Expand Down Expand Up @@ -48,6 +50,8 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer
// misbehaving-server condition the stuck-detector exists to break out of.
options.Handlers.GetTaskHandler = (context, cancellationToken) =>
{
Interlocked.Increment(ref _pollCount);

return new ValueTask<GetTaskResult>(new InputRequiredTaskResult
{
TaskId = context.Params!.TaskId,
Expand Down Expand Up @@ -77,19 +81,13 @@ public async Task CallToolAsync_TaskStuckInInputRequired_WithoutNewRequests_Thro
await using var client = await CreateMcpClientForServer();
var ct = TestContext.Current.CancellationToken;

var sw = System.Diagnostics.Stopwatch.StartNew();

var ex = await Assert.ThrowsAsync<McpException>(async () =>
await client.CallToolAsync(new CallToolRequestParams { Name = "any-tool" }, ct));

sw.Stop();

Assert.Contains(McpTaskStatus.InputRequired.ToString(), ex.Message);
Assert.Contains("consecutive polls", ex.Message);

// 60 polls × 5ms ≈ 300ms; allow generous slack for CI.
Assert.True(sw.Elapsed < TimeSpan.FromSeconds(10),
$"Stuck-detector should give up promptly but took {sw.Elapsed}.");
Assert.Equal(60, _pollCount);
}

[Fact]
Expand All @@ -111,6 +109,7 @@ public async Task CallToolAsync_StuckDetector_HonorsConfiguredThreshold()
// The message embeds the configured threshold, which is the strongest signal that the
// option value (not the 60-default constant) is what governed the loop.
Assert.Contains($"{CustomThreshold} consecutive polls", ex.Message);
Assert.Equal(CustomThreshold, _pollCount);
}

[Theory]
Expand Down