From d2d57d9b3ecf90731e77166038ac9d4d79796240 Mon Sep 17 00:00:00 2001 From: Pranav Senthilnathan Date: Thu, 11 Jun 2026 20:56:19 -0700 Subject: [PATCH 1/2] Fix race in test InMemoryTaskStore (#1648) The test helper InMemoryTaskStore mutates TaskEntry fields without synchronization. FailTask/CompleteTask/CancelTask set Status before their payload field (Error/Result), so a concurrent reader in GetTask could observe Status == Failed with Error == null and throw 'Nullable object must have a value', surfacing to the client as a McpProtocolException instead of the expected McpException. Serialize all reads and writes under a lock on the backing dictionary, and reorder mutations so Status is assigned last. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../Server/McpServerTaskTests.cs | 149 ++++++++++-------- 1 file changed, 86 insertions(+), 63 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs index 7166e1f25..66304ed5f 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs @@ -1,7 +1,6 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Microsoft.Extensions.DependencyInjection; -using System.Collections.Concurrent; using System.Runtime.InteropServices; using System.Text.Json; using System.Text.Json.Nodes; @@ -538,106 +537,130 @@ public async Task CallToolHandler_CanBeSetToNull_ThenOtherCanBeSet() /// private sealed class InMemoryTaskStore { - private readonly ConcurrentDictionary _tasks = new(); + private readonly Dictionary _tasks = new(); public string CreateTask(McpTaskStatus initialStatus = McpTaskStatus.Working) { var taskId = Guid.NewGuid().ToString("N"); - _tasks[taskId] = new TaskEntry + lock (_tasks) { - Status = initialStatus, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow, - }; + _tasks[taskId] = new TaskEntry + { + Status = initialStatus, + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + }; + } return taskId; } - public IEnumerable GetAllTaskIds() => _tasks.Keys; - - public GetTaskResult GetTask(string taskId) + public IEnumerable GetAllTaskIds() { - if (!_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - throw new McpException($"Unknown task: '{taskId}'"); + return _tasks.Keys.ToArray(); } + } - return entry.Status switch + public GetTaskResult GetTask(string taskId) + { + lock (_tasks) { - McpTaskStatus.Working => new WorkingTaskResult - { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - PollIntervalMs = 50, - }, - McpTaskStatus.Completed => new CompletedTaskResult + if (!_tasks.TryGetValue(taskId, out var entry)) { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - Result = JsonSerializer.SerializeToElement(entry.Result, McpJsonUtilities.DefaultOptions), - }, - McpTaskStatus.Failed => new FailedTaskResult - { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - Error = entry.Error!.Value, - }, - McpTaskStatus.Cancelled => new CancelledTaskResult - { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - }, - McpTaskStatus.InputRequired => new InputRequiredTaskResult + throw new McpException($"Unknown task: '{taskId}'"); + } + + return entry.Status switch { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - InputRequests = entry.InputRequests ?? new Dictionary(), - }, - _ => throw new InvalidOperationException($"Unexpected status: {entry.Status}") - }; + McpTaskStatus.Working => new WorkingTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + PollIntervalMs = 50, + }, + McpTaskStatus.Completed => new CompletedTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + Result = JsonSerializer.SerializeToElement(entry.Result, McpJsonUtilities.DefaultOptions), + }, + McpTaskStatus.Failed => new FailedTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + Error = entry.Error!.Value, + }, + McpTaskStatus.Cancelled => new CancelledTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + }, + McpTaskStatus.InputRequired => new InputRequiredTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + InputRequests = entry.InputRequests ?? new Dictionary(), + }, + _ => throw new InvalidOperationException($"Unexpected status: {entry.Status}") + }; + } } public void CompleteTask(string taskId, CallToolResult result) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.Status = McpTaskStatus.Completed; - entry.Result = result; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.Result = result; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + entry.Status = McpTaskStatus.Completed; + } } } public void FailTask(string taskId, JsonElement error) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.Status = McpTaskStatus.Failed; - entry.Error = error; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.Error = error; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + entry.Status = McpTaskStatus.Failed; + } } } public void CancelTask(string taskId) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.Status = McpTaskStatus.Cancelled; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + entry.Status = McpTaskStatus.Cancelled; + } } } public void ProvideInput(string taskId, IDictionary inputResponses) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.InputResponses = inputResponses; - // Transition back to working after receiving input - entry.Status = McpTaskStatus.Working; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.InputResponses = inputResponses; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + // Transition back to working after receiving input + entry.Status = McpTaskStatus.Working; + } } } From 5d633a2654baf0bcb88c2ad2e1be2e9018b585b3 Mon Sep 17 00:00:00 2001 From: Pranav Senthilnathan Date: Thu, 18 Jun 2026 15:53:46 -0700 Subject: [PATCH 2/2] Reduce absolute time dependencies in tests --- .../Server/McpServerTaskTests.cs | 41 ++++++++++++++----- .../Server/TaskPollStuckDetectorTests.cs | 13 +++--- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs index 66304ed5f..9708722f3 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs @@ -178,17 +178,25 @@ public async Task CallToolAsync_AsyncTool_FailedTask_ThrowsMcpException() await using var client = await CreateMcpClientForServer(); var ct = TestContext.Current.CancellationToken; - _ = Task.Run(async () => + var failedTask = new TaskCompletionSource(); + + // Run failure task once the task from the tool call is created + _taskStore.OnTaskCreated += taskId => { - await Task.Delay(100, ct); - var taskId = _taskStore.GetAllTaskIds().Single(); - _taskStore.FailTask(taskId, JsonElement.Parse("""{"code":-32000,"message":"something went wrong"}""")); - }, ct); + _ = Task.Run(async () => + { + await Task.Delay(100, ct); + _taskStore.FailTask(taskId, JsonElement.Parse("""{"code":-32000,"message":"something went wrong"}""")); + failedTask.SetResult(true); + }, ct); + }; await Assert.ThrowsAsync(async () => await client.CallToolAsync( new CallToolRequestParams { Name = "async-tool" }, ct)); + + Assert.True(await failedTask.Task); } [Fact] @@ -197,17 +205,25 @@ public async Task CallToolAsync_AsyncTool_CancelledTask_ThrowsOperationCancelled await using var client = await CreateMcpClientForServer(); var ct = TestContext.Current.CancellationToken; - _ = Task.Run(async () => + var cancelledTask = new TaskCompletionSource(); + + // Run cancellation task once the task from the tool call is created + _taskStore.OnTaskCreated += taskId => { - await Task.Delay(100, ct); - var taskId = _taskStore.GetAllTaskIds().Single(); - _taskStore.CancelTask(taskId); - }, ct); + Task.Run(async () => + { + await Task.Delay(100, ct); + _taskStore.CancelTask(taskId); + cancelledTask.SetResult(true); + }, ct); + }; await Assert.ThrowsAsync(async () => await client.CallToolAsync( new CallToolRequestParams { Name = "async-tool" }, ct)); + + Assert.True(await cancelledTask.Task); } [Fact] @@ -539,6 +555,8 @@ private sealed class InMemoryTaskStore { private readonly Dictionary _tasks = new(); + internal Action? OnTaskCreated; + public string CreateTask(McpTaskStatus initialStatus = McpTaskStatus.Working) { var taskId = Guid.NewGuid().ToString("N"); @@ -551,6 +569,9 @@ public string CreateTask(McpTaskStatus initialStatus = McpTaskStatus.Working) LastUpdatedAt = DateTimeOffset.UtcNow, }; } + + OnTaskCreated?.Invoke(taskId); + return taskId; } diff --git a/tests/ModelContextProtocol.Tests/Server/TaskPollStuckDetectorTests.cs b/tests/ModelContextProtocol.Tests/Server/TaskPollStuckDetectorTests.cs index 2f00d925a..1d17be223 100644 --- a/tests/ModelContextProtocol.Tests/Server/TaskPollStuckDetectorTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/TaskPollStuckDetectorTests.cs @@ -15,6 +15,8 @@ namespace ModelContextProtocol.Tests.Server; /// public class TaskPollStuckDetectorTests : ClientServerTestBase { + private int _pollCount = 0; + public TaskPollStuckDetectorTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { #if !NET @@ -48,6 +50,8 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer // misbehaving-server condition the stuck-detector exists to break out of. options.Handlers.GetTaskHandler = (context, cancellationToken) => { + Interlocked.Increment(ref _pollCount); + return new ValueTask(new InputRequiredTaskResult { TaskId = context.Params!.TaskId, @@ -77,19 +81,13 @@ public async Task CallToolAsync_TaskStuckInInputRequired_WithoutNewRequests_Thro await using var client = await CreateMcpClientForServer(); var ct = TestContext.Current.CancellationToken; - var sw = System.Diagnostics.Stopwatch.StartNew(); - var ex = await Assert.ThrowsAsync(async () => await client.CallToolAsync(new CallToolRequestParams { Name = "any-tool" }, ct)); - sw.Stop(); - Assert.Contains(McpTaskStatus.InputRequired.ToString(), ex.Message); Assert.Contains("consecutive polls", ex.Message); - // 60 polls × 5ms ≈ 300ms; allow generous slack for CI. - Assert.True(sw.Elapsed < TimeSpan.FromSeconds(10), - $"Stuck-detector should give up promptly but took {sw.Elapsed}."); + Assert.Equal(60, _pollCount); } [Fact] @@ -111,6 +109,7 @@ public async Task CallToolAsync_StuckDetector_HonorsConfiguredThreshold() // The message embeds the configured threshold, which is the strongest signal that the // option value (not the 60-default constant) is what governed the loop. Assert.Contains($"{CustomThreshold} consecutive polls", ex.Message); + Assert.Equal(CustomThreshold, _pollCount); } [Theory]