Skip to content

Commit 4197c8f

Browse files
Copilotstephentoub
andcommitted
Add tool name to structured logging for tool calls
Co-authored-by: stephentoub <[email protected]>
1 parent 359bbc9 commit 4197c8f

File tree

3 files changed

+92
-14
lines changed

3 files changed

+92
-14
lines changed

src/ModelContextProtocol.Core/McpSessionHandler.cs

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,8 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
436436
// Now that the request has been sent, register for cancellation. If we registered before,
437437
// a cancellation request could arrive before the server knew about that request ID, in which
438438
// case the server could ignore it.
439-
LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id);
439+
string? target = GetRequestTarget(request);
440+
LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id, toolName: target);
440441
JsonRpcMessage? response;
441442
using (var registration = RegisterCancellation(cancellationToken, request))
442443
{
@@ -445,7 +446,7 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
445446

446447
if (response is JsonRpcError error)
447448
{
448-
LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code);
449+
LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code, toolName: target);
449450
throw new McpException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code);
450451
}
451452

@@ -458,11 +459,11 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
458459

459460
if (_logger.IsEnabled(LogLevel.Trace))
460461
{
461-
LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null");
462+
LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null", toolName: target);
462463
}
463464
else
464465
{
465-
LogRequestResponseReceived(EndpointName, request.Method);
466+
LogRequestResponseReceived(EndpointName, request.Method, toolName: target);
466467
}
467468

468469
return success;
@@ -763,6 +764,29 @@ private static TimeSpan GetElapsed(long startingTimestamp) =>
763764
return null;
764765
}
765766

767+
/// <summary>
768+
/// Extracts the target identifier (tool name, prompt name, or resource URI) from a request.
769+
/// </summary>
770+
/// <param name="request">The JSON-RPC request.</param>
771+
/// <returns>The target identifier if available; otherwise, null.</returns>
772+
private static string? GetRequestTarget(JsonRpcRequest request)
773+
{
774+
if (request.Params is not JsonObject paramsObj)
775+
{
776+
return null;
777+
}
778+
779+
return request.Method switch
780+
{
781+
RequestMethods.ToolsCall => GetStringProperty(paramsObj, "name"),
782+
RequestMethods.PromptsGet => GetStringProperty(paramsObj, "name"),
783+
RequestMethods.ResourcesRead => GetStringProperty(paramsObj, "uri"),
784+
RequestMethods.ResourcesSubscribe => GetStringProperty(paramsObj, "uri"),
785+
RequestMethods.ResourcesUnsubscribe => GetStringProperty(paramsObj, "uri"),
786+
_ => null
787+
};
788+
}
789+
766790
[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} message processing canceled.")]
767791
private partial void LogEndpointMessageProcessingCanceled(string endpointName);
768792

@@ -778,8 +802,8 @@ private static TimeSpan GetElapsed(long startingTimestamp) =>
778802
[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received request for unknown request ID '{RequestId}'.")]
779803
private partial void LogNoRequestFoundForMessageWithId(string endpointName, RequestId requestId);
780804

781-
[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}': {ErrorMessage} ({ErrorCode}).")]
782-
private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode);
805+
[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}' (tool: '{ToolName}'): {ErrorMessage} ({ErrorCode}).")]
806+
private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode, string? toolName = null);
783807

784808
[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received invalid response for method '{Method}'.")]
785809
private partial void LogSendingRequestInvalidResponseType(string endpointName, string method);
@@ -793,11 +817,11 @@ private static TimeSpan GetElapsed(long startingTimestamp) =>
793817
[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} canceled request '{RequestId}' per client notification. Reason: '{Reason}'.")]
794818
private partial void LogRequestCanceled(string endpointName, RequestId requestId, string? reason);
795819

796-
[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method}")]
797-
private partial void LogRequestResponseReceived(string endpointName, string method);
820+
[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method} (tool: '{ToolName}')")]
821+
private partial void LogRequestResponseReceived(string endpointName, string method, string? toolName = null);
798822

799-
[LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method}. Response: '{Response}'.")]
800-
private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response);
823+
[LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method} (tool: '{ToolName}'). Response: '{Response}'.")]
824+
private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response, string? toolName = null);
801825

802826
[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} read {MessageType} message from channel.")]
803827
private partial void LogMessageRead(string endpointName, string messageType);
@@ -814,8 +838,8 @@ private static TimeSpan GetElapsed(long startingTimestamp) =>
814838
[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received request for method '{Method}', but no handler is available.")]
815839
private partial void LogNoHandlerFoundForRequest(string endpointName, string method);
816840

817-
[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}'.")]
818-
private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId);
841+
[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}' (tool: '{ToolName}').")]
842+
private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId, string? toolName = null);
819843

820844
[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending message.")]
821845
private partial void LogSendingMessage(string endpointName);

tests/Common/Utils/MockLoggerProvider.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ namespace ModelContextProtocol.Tests.Utils;
66
public class MockLoggerProvider() : ILoggerProvider
77
{
88
public ConcurrentQueue<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception)> LogMessages { get; } = [];
9+
public ConcurrentQueue<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception, object? State)> LogMessagesWithState { get; } = [];
910

1011
public ILogger CreateLogger(string categoryName)
1112
{
@@ -22,6 +23,7 @@ public void Log<TState>(
2223
LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func<TState, Exception?, string> formatter)
2324
{
2425
mockProvider.LogMessages.Enqueue((category, logLevel, eventId, formatter(state, exception), exception));
26+
mockProvider.LogMessagesWithState.Enqueue((category, logLevel, eventId, formatter(state, exception), exception, state));
2527
}
2628

2729
public bool IsEnabled(LogLevel logLevel) => true;

tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,20 @@ namespace ModelContextProtocol.Tests.Configuration;
2020

2121
public partial class McpServerBuilderExtensionsToolsTests : ClientServerTestBase
2222
{
23+
private MockLoggerProvider _mockLoggerProvider = new();
24+
2325
public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper)
2426
: base(testOutputHelper)
2527
{
28+
// Configure LoggerFactory to use Debug level and add MockLoggerProvider
29+
LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder =>
30+
{
31+
builder.AddProvider(XunitLoggerProvider);
32+
builder.AddProvider(_mockLoggerProvider);
33+
builder.SetMinimumLevel(LogLevel.Debug);
34+
});
2635
}
2736

28-
private MockLoggerProvider _mockLoggerProvider = new();
29-
3037
protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder)
3138
{
3239
mcpServerBuilder
@@ -733,6 +740,51 @@ await client.SendNotificationAsync(
733740
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await invokeTask);
734741
}
735742

743+
[Fact]
744+
public async Task ToolName_Captured_In_Structured_Logging()
745+
{
746+
await using McpClient client = await CreateMcpClientForServer();
747+
748+
// Call a tool that will succeed
749+
var result = await client.CallToolAsync(
750+
"echo",
751+
new Dictionary<string, object?> { ["message"] = "test" },
752+
cancellationToken: TestContext.Current.CancellationToken);
753+
754+
Assert.NotNull(result);
755+
756+
// Verify that the tool name is captured in structured logging
757+
// The LogMessagesWithState should contain log entries with tool name in the state
758+
var allLogs = _mockLoggerProvider.LogMessagesWithState.ToList();
759+
TestOutputHelper.WriteLine($"Total logs captured: {allLogs.Count}");
760+
foreach (var log in allLogs)
761+
{
762+
TestOutputHelper.WriteLine($"Log: Category={log.Category}, Level={log.LogLevel}, Message={log.Message}");
763+
}
764+
765+
var relevantLogs = allLogs
766+
.Where(m => m.Category == "ModelContextProtocol.Client.McpClient" &&
767+
m.Message.Contains("tools/call"))
768+
.ToList();
769+
770+
TestOutputHelper.WriteLine($"Relevant logs: {relevantLogs.Count}");
771+
Assert.NotEmpty(relevantLogs);
772+
773+
// Check that at least one log entry has the tool name in its structured state
774+
bool foundToolName = relevantLogs.Any(log =>
775+
{
776+
if (log.State is IReadOnlyList<KeyValuePair<string, object?>> stateList)
777+
{
778+
return stateList.Any(kvp =>
779+
kvp.Key == "ToolName" &&
780+
kvp.Value?.ToString() == "echo");
781+
}
782+
return false;
783+
});
784+
785+
Assert.True(foundToolName, "Tool name 'echo' was not found in structured logging state");
786+
}
787+
736788
[McpServerToolType]
737789
public sealed class EchoTool(ObjectWithId objectFromDI)
738790
{

0 commit comments

Comments
 (0)