diff --git a/mcp/client.go b/mcp/client.go index 57c30fb3..51d0fc14 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -50,7 +50,6 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { if options != nil { opts = *options } - options = nil // prevent reuse if opts.Logger == nil { // ensure we have a logger opts.Logger = ensureLogger(nil) @@ -129,6 +128,7 @@ type ClientOptions struct { PromptListChangedHandler func(context.Context, *PromptListChangedRequest) ResourceListChangedHandler func(context.Context, *ResourceListChangedRequest) ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest) + TaskStatusHandler func(context.Context, *TaskStatusNotificationRequest) LoggingMessageHandler func(context.Context, *LoggingMessageRequest) ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest) // If non-zero, defines an interval for regular "ping" requests. @@ -807,6 +807,7 @@ var clientMethodInfos = map[string]methodInfo{ methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0), methodElicit: newClientMethodInfo(clientMethod((*Client).elicit), missingParamsOK), notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK), + notificationTaskStatus: newClientMethodInfo(clientMethod((*Client).callTaskStatusHandler), notification), notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK), notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK), notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK), @@ -888,6 +889,9 @@ func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) // // The params.Arguments can be any value that marshals into a JSON object. func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) { + if params != nil && params.Task != nil { + return nil, fmt.Errorf("task augmentation requested: use CallToolTask") + } if params == nil { params = new(CallToolParams) } @@ -898,6 +902,43 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) ( return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) } +// CallToolTask calls a tool using task-based execution (tools/call with params.task). +// +// The response is a CreateTaskResult. Use GetTask to poll for task state and +// TaskResult to retrieve the final tool result. +func (cs *ClientSession) CallToolTask(ctx context.Context, params *CallToolParams) (*CreateTaskResult, error) { + if params == nil || params.Task == nil { + return nil, fmt.Errorf("CallToolTask requires params.Task") + } + if params.Arguments == nil { + // Avoid sending nil over the wire. + params.Arguments = map[string]any{} + } + return handleSend[*CreateTaskResult](ctx, methodCallTool, newClientRequest(cs, params)) +} + +// GetTask polls task status via tasks/get. +func (cs *ClientSession) GetTask(ctx context.Context, params *GetTaskParams) (*GetTaskResult, error) { + return handleSend[*GetTaskResult](ctx, methodGetTask, newClientRequest(cs, orZero[Params](params))) +} + +// ListTasks lists tasks via tasks/list. +func (cs *ClientSession) ListTasks(ctx context.Context, params *ListTasksParams) (*ListTasksResult, error) { + return handleSend[*ListTasksResult](ctx, methodListTasks, newClientRequest(cs, orZero[Params](params))) +} + +// CancelTask requests cancellation via tasks/cancel. +func (cs *ClientSession) CancelTask(ctx context.Context, params *CancelTaskParams) (*CancelTaskResult, error) { + return handleSend[*CancelTaskResult](ctx, methodCancelTask, newClientRequest(cs, orZero[Params](params))) +} + +// TaskResult retrieves the final result of a task via tasks/result. +// +// Currently, this SDK supports tasks/result only for tasks created from tools/call. +func (cs *ClientSession) TaskResult(ctx context.Context, params *TaskResultParams) (*CallToolResult, error) { + return handleSend[*CallToolResult](ctx, methodTaskResult, newClientRequest(cs, orZero[Params](params))) +} + func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLoggingLevelParams) error { _, err := handleSend[*emptyResult](ctx, methodSetLevel, newClientRequest(cs, orZero[Params](params))) return err @@ -971,6 +1012,13 @@ func (c *Client) callLoggingHandler(ctx context.Context, req *LoggingMessageRequ return nil, nil } +func (c *Client) callTaskStatusHandler(ctx context.Context, req *TaskStatusNotificationRequest) (Result, error) { + if h := c.opts.TaskStatusHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, params *ProgressNotificationParams) (Result, error) { if h := cs.client.opts.ProgressNotificationHandler; h != nil { h(ctx, clientRequestFor(cs, params)) diff --git a/mcp/protocol.go b/mcp/protocol.go index bea776f9..a69d6f67 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -48,6 +48,11 @@ type CallToolParams struct { // Arguments holds the tool arguments. It can hold any value that can be // marshaled to JSON. Arguments any `json:"arguments,omitempty"` + // Task optionally requests task-based execution of this tool call. + // + // Note: when Task is present, the wire response is a CreateTaskResult rather + // than a CallToolResult. + Task *TaskParams `json:"task,omitempty"` } // CallToolParamsRaw is passed to tool handlers on the server. Its arguments @@ -63,6 +68,8 @@ type CallToolParamsRaw struct { // is the responsibility of the tool handler to unmarshal and validate the // Arguments (see [AddTool]). Arguments json.RawMessage `json:"arguments,omitempty"` + // Task optionally requests task-based execution of this tool call. + Task *TaskParams `json:"task,omitempty"` } // A CallToolResult is the server's response to a tool call. @@ -210,6 +217,8 @@ type ClientCapabilities struct { Sampling *SamplingCapabilities `json:"sampling,omitempty"` // Elicitation is present if the client supports elicitation from the server. Elicitation *ElicitationCapabilities `json:"elicitation,omitempty"` + // Tasks describes support for task-based execution. + Tasks *TasksCapabilities `json:"tasks,omitempty"` } // clone returns a deep copy of the ClientCapabilities. @@ -217,6 +226,7 @@ func (c *ClientCapabilities) clone() *ClientCapabilities { cp := *c cp.RootsV2 = shallowClone(c.RootsV2) cp.Sampling = shallowClone(c.Sampling) + cp.Tasks = shallowClone(c.Tasks) if c.Elicitation != nil { x := *c.Elicitation x.Form = shallowClone(c.Elicitation.Form) @@ -1092,8 +1102,28 @@ type Tool struct { Title string `json:"title,omitempty"` // Icons for the tool, if any. Icons []Icon `json:"icons,omitempty"` + // Execution contains optional execution-related settings. + Execution *ToolExecution `json:"execution,omitempty"` +} + +// ToolExecution configures execution behavior for a tool. +type ToolExecution struct { + // TaskSupport declares task support for this tool. + // + // Valid values are: "required", "optional", or "forbidden". + // See ToolTaskSupportRequired, ToolTaskSupportOptional, and ToolTaskSupportForbidden. + TaskSupport string `json:"taskSupport,omitempty"` } +const ( + // ToolTaskSupportRequired indicates the tool MUST be invoked with task augmentation. + ToolTaskSupportRequired = "required" + // ToolTaskSupportOptional indicates the tool MAY be invoked with task augmentation. + ToolTaskSupportOptional = "optional" + // ToolTaskSupportForbidden indicates the tool MUST NOT be invoked with task augmentation. + ToolTaskSupportForbidden = "forbidden" +) + // Additional properties describing a Tool to clients. // // NOTE: all properties in ToolAnnotations are hints. They are not @@ -1314,6 +1344,8 @@ type ServerCapabilities struct { Resources *ResourceCapabilities `json:"resources,omitempty"` // Tools is present if the supports tools. Tools *ToolCapabilities `json:"tools,omitempty"` + // Tasks describes support for task-based execution. + Tasks *TasksCapabilities `json:"tasks,omitempty"` } // clone returns a deep copy of the ServerCapabilities. @@ -1324,12 +1356,148 @@ func (c *ServerCapabilities) clone() *ServerCapabilities { cp.Prompts = shallowClone(c.Prompts) cp.Resources = shallowClone(c.Resources) cp.Tools = shallowClone(c.Tools) + cp.Tasks = shallowClone(c.Tasks) return &cp } +// TasksCapabilities describes support for task-based execution. +type TasksCapabilities struct { + List *TasksListCapabilities `json:"list,omitempty"` + Cancel *TasksCancelCapabilities `json:"cancel,omitempty"` + Requests *TasksRequestsCapabilities `json:"requests,omitempty"` +} + +type TasksListCapabilities struct{} +type TasksCancelCapabilities struct{} + +type TasksRequestsCapabilities struct { + Tools *TasksToolsRequestCapabilities `json:"tools,omitempty"` + Sampling *TasksSamplingRequestCapabilities `json:"sampling,omitempty"` + Elicitation *TasksElicitationRequestCapabilities `json:"elicitation,omitempty"` +} + +type TasksToolsRequestCapabilities struct { + Call *TasksToolsCallCapabilities `json:"call,omitempty"` +} + +type TasksToolsCallCapabilities struct{} + +type TasksSamplingRequestCapabilities struct { + CreateMessage *TasksSamplingCreateMessageCapabilities `json:"createMessage,omitempty"` +} + +type TasksSamplingCreateMessageCapabilities struct{} + +type TasksElicitationRequestCapabilities struct { + Create *TasksElicitationCreateCapabilities `json:"create,omitempty"` +} + +type TasksElicitationCreateCapabilities struct{} + +// TaskParams is included in request parameters to request task-based execution. +type TaskParams struct { + TTL *int64 `json:"ttl,omitempty"` +} + +type TaskStatus string + +const ( + TaskStatusWorking TaskStatus = "working" + TaskStatusInputRequired TaskStatus = "input_required" + TaskStatusCompleted TaskStatus = "completed" + TaskStatusFailed TaskStatus = "failed" + TaskStatusCancelled TaskStatus = "cancelled" +) + +// Task describes the state of a task. +type Task struct { + Meta `json:"_meta,omitempty"` + TaskID string `json:"taskId"` + Status TaskStatus `json:"status"` + StatusMessage string `json:"statusMessage,omitempty"` + CreatedAt string `json:"createdAt"` + LastUpdatedAt string `json:"lastUpdatedAt"` + TTL *int64 `json:"ttl"` + PollInterval *int64 `json:"pollInterval,omitempty"` +} + +// CreateTaskResult is returned for task-augmented requests. +type CreateTaskResult struct { + Meta `json:"_meta,omitempty"` + Task *Task `json:"task"` +} + +func (*CreateTaskResult) isResult() {} + +type GetTaskParams struct { + Meta `json:"_meta,omitempty"` + TaskID string `json:"taskId"` +} + +func (*GetTaskParams) isParams() {} +func (x *GetTaskParams) GetProgressToken() any { return getProgressToken(x) } +func (x *GetTaskParams) SetProgressToken(t any) { setProgressToken(x, t) } + +type GetTaskResult Task + +func (*GetTaskResult) isResult() {} + +type ListTasksParams struct { + Meta `json:"_meta,omitempty"` + Cursor string `json:"cursor,omitempty"` +} + +func (x *ListTasksParams) isParams() {} +func (x *ListTasksParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListTasksParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListTasksParams) cursorPtr() *string { return &x.Cursor } + +type ListTasksResult struct { + Meta `json:"_meta,omitempty"` + Tasks []*Task `json:"tasks"` + NextCursor string `json:"nextCursor,omitempty"` +} + +func (*ListTasksResult) isResult() {} +func (x *ListTasksResult) nextCursorPtr() *string { return &x.NextCursor } + +type CancelTaskParams struct { + Meta `json:"_meta,omitempty"` + TaskID string `json:"taskId"` +} + +func (*CancelTaskParams) isParams() {} +func (x *CancelTaskParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CancelTaskParams) SetProgressToken(t any) { setProgressToken(x, t) } + +type CancelTaskResult Task + +func (*CancelTaskResult) isResult() {} + +type TaskResultParams struct { + Meta `json:"_meta,omitempty"` + TaskID string `json:"taskId"` +} + +func (*TaskResultParams) isParams() {} +func (x *TaskResultParams) GetProgressToken() any { return getProgressToken(x) } +func (x *TaskResultParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// TaskStatusNotificationParams is sent as notifications/tasks/status. +type TaskStatusNotificationParams Task + +func (*TaskStatusNotificationParams) isParams() {} +func (x *TaskStatusNotificationParams) GetProgressToken() any { return getProgressToken(x) } +func (x *TaskStatusNotificationParams) SetProgressToken(t any) { setProgressToken(x, t) } + const ( methodCallTool = "tools/call" + methodGetTask = "tasks/get" + methodListTasks = "tasks/list" + methodCancelTask = "tasks/cancel" + methodTaskResult = "tasks/result" notificationCancelled = "notifications/cancelled" + notificationTaskStatus = "notifications/tasks/status" methodComplete = "completion/complete" methodCreateMessage = "sampling/createMessage" methodElicit = "elicitation/create" diff --git a/mcp/requests.go b/mcp/requests.go index f64d6fb6..9a75962f 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -8,9 +8,12 @@ package mcp type ( CallToolRequest = ServerRequest[*CallToolParamsRaw] + CancelTaskRequest = ServerRequest[*CancelTaskParams] CompleteRequest = ServerRequest[*CompleteParams] + GetTaskRequest = ServerRequest[*GetTaskParams] GetPromptRequest = ServerRequest[*GetPromptParams] InitializedRequest = ServerRequest[*InitializedParams] + ListTasksRequest = ServerRequest[*ListTasksParams] ListPromptsRequest = ServerRequest[*ListPromptsParams] ListResourcesRequest = ServerRequest[*ListResourcesParams] ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams] @@ -19,6 +22,8 @@ type ( ReadResourceRequest = ServerRequest[*ReadResourceParams] RootsListChangedRequest = ServerRequest[*RootsListChangedParams] SubscribeRequest = ServerRequest[*SubscribeParams] + TaskStatusNotificationServerRequest = ServerRequest[*TaskStatusNotificationParams] + TaskResultRequest = ServerRequest[*TaskResultParams] UnsubscribeRequest = ServerRequest[*UnsubscribeParams] ) @@ -33,6 +38,7 @@ type ( PromptListChangedRequest = ClientRequest[*PromptListChangedParams] ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams] ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams] + TaskStatusNotificationRequest = ClientRequest[*TaskStatusNotificationParams] ToolListChangedRequest = ClientRequest[*ToolListChangedParams] ElicitationCompleteNotificationRequest = ClientRequest[*ElicitationCompleteParams] ) diff --git a/mcp/server.go b/mcp/server.go index 1485b889..af00674b 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -52,6 +52,7 @@ type Server struct { receivingMethodHandler_ MethodHandler resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send + tasks *serverTasks } // ServerOptions is used to configure behavior of the server. @@ -71,6 +72,8 @@ type ServerOptions struct { RootsListChangedHandler func(context.Context, *RootsListChangedRequest) // If non-nil, called when "notifications/progress" is received. ProgressNotificationHandler func(context.Context, *ProgressNotificationServerRequest) + // If non-nil, called when "notifications/tasks/status" is received. + TaskStatusHandler func(context.Context, *TaskStatusNotificationServerRequest) // If non-nil, called when "completion/complete" is received. CompletionHandler func(context.Context, *CompleteRequest) (*CompleteResult, error) // If non-zero, defines an interval for regular "ping" requests. @@ -155,7 +158,6 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { if options != nil { opts = *options } - options = nil // prevent reuse if opts.PageSize < 0 { panic(fmt.Errorf("invalid page size %d", opts.PageSize)) } @@ -188,6 +190,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), pendingNotifications: make(map[string]*time.Timer), + tasks: newServerTasks(), } } @@ -1011,6 +1014,13 @@ func (s *Server) callRootsListChangedHandler(ctx context.Context, req *RootsList return nil, nil } +func (s *Server) callTaskStatusHandler(ctx context.Context, req *TaskStatusNotificationServerRequest) (Result, error) { + if h := s.opts.TaskStatusHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, p *ProgressNotificationParams) (Result, error) { if h := ss.server.opts.ProgressNotificationHandler; h != nil { h(ctx, serverRequestFor(ss, p)) @@ -1254,7 +1264,11 @@ var serverMethodInfos = map[string]methodInfo{ methodListPrompts: newServerMethodInfo(serverMethod((*Server).listPrompts), missingParamsOK), methodGetPrompt: newServerMethodInfo(serverMethod((*Server).getPrompt), 0), methodListTools: newServerMethodInfo(serverMethod((*Server).listTools), missingParamsOK), - methodCallTool: newServerMethodInfo(serverMethod((*Server).callTool), 0), + methodCallTool: callToolMethodInfo(), + methodGetTask: newServerMethodInfo(serverMethod((*Server).getTask), 0), + methodListTasks: newServerMethodInfo(serverMethod((*Server).listTasks), missingParamsOK), + methodCancelTask: newServerMethodInfo(serverMethod((*Server).cancelTask), 0), + methodTaskResult: newServerMethodInfo(serverMethod((*Server).taskResult), 0), methodListResources: newServerMethodInfo(serverMethod((*Server).listResources), missingParamsOK), methodListResourceTemplates: newServerMethodInfo(serverMethod((*Server).listResourceTemplates), missingParamsOK), methodReadResource: newServerMethodInfo(serverMethod((*Server).readResource), 0), @@ -1262,6 +1276,7 @@ var serverMethodInfos = map[string]methodInfo{ methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0), methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0), notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK), + notificationTaskStatus: newServerMethodInfo(serverMethod((*Server).callTaskStatusHandler), notification), notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK), notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), notificationProgress: newServerMethodInfo(serverSessionMethod((*ServerSession).callProgressNotificationHandler), notification), @@ -1286,6 +1301,22 @@ func initializeMethodInfo() methodInfo { return info } +func callToolMethodInfo() methodInfo { + // Start with the standard tools/call method info so that we preserve the + // wire format (CallToolParamsRaw) and the normal result type. + info := newServerMethodInfo(serverMethod((*Server).callTool), 0) + + // Override receive-side behavior to be task-aware. + info.handleMethod = func(ctx context.Context, _ string, req Request) (Result, error) { + r, ok := req.(*CallToolRequest) + if !ok { + return nil, fmt.Errorf("internal error: unexpected request type %T for tools/call", req) + } + return r.Session.server.callToolAny(ctx, r) + } + return info +} + func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } diff --git a/mcp/shared.go b/mcp/shared.go index d83eae7d..ebe44e02 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -109,6 +109,15 @@ func defaultSendingMethodHandler(ctx context.Context, method string, req Request // Create the result to unmarshal into. // The concrete type of the result is the return type of the receiving function. res := info.newResult() + // Task-augmented requests may change the result schema. + // + // Currently, the only task-augmentable request supported by this SDK is + // tools/call. + if method == methodCallTool { + if p, ok := params.(*CallToolParams); ok && p.Task != nil { + res = &CreateTaskResult{} + } + } if err := call(ctx, req.GetSession().getConn(), method, params, res); err != nil { return nil, err } diff --git a/mcp/tasks_server.go b/mcp/tasks_server.go new file mode 100644 index 00000000..b5588781 --- /dev/null +++ b/mcp/tasks_server.go @@ -0,0 +1,455 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "sort" + "strconv" + "sync" + "time" + + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +const relatedTaskMetaKey = "io.modelcontextprotocol/related-task" + +type serverTasks struct { + mu sync.Mutex + next uint64 + tasks map[string]*serverTaskEntry +} + +type serverTaskEntry struct { + seq uint64 + session *ServerSession + meta Meta + args []byte + + task Task + expiresAt *time.Time + + cancel context.CancelFunc + done chan struct{} + + result *CallToolResult + err error +} + +func newServerTasks() *serverTasks { + return &serverTasks{tasks: make(map[string]*serverTaskEntry)} +} + +func (s *Server) tasksEnabledForToolsCall() bool { + caps := s.capabilities() + return caps.Tasks != nil && + caps.Tasks.Requests != nil && + caps.Tasks.Requests.Tools != nil && + caps.Tasks.Requests.Tools.Call != nil +} + +func (s *Server) tasksEnabled() bool { + return s.capabilities().Tasks != nil +} + +func (s *Server) tasksListEnabled() bool { + caps := s.capabilities() + return caps.Tasks != nil && caps.Tasks.List != nil +} + +func (s *Server) tasksCancelEnabled() bool { + caps := s.capabilities() + return caps.Tasks != nil && caps.Tasks.Cancel != nil +} + +func (s *Server) callToolAny(ctx context.Context, req *CallToolRequest) (Result, error) { + s.mu.Lock() + st, ok := s.tools.get(req.Params.Name) + s.mu.Unlock() + if !ok { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: fmt.Sprintf("unknown tool %q", req.Params.Name), + } + } + + // If the server hasn't advertised task augmentation for tools/call, ignore any + // task request and process normally. + if !s.tasksEnabledForToolsCall() { + return s.callToolNow(ctx, req, st) + } + + taskSupport := "forbidden" + if st.tool.Execution != nil && st.tool.Execution.TaskSupport != "" { + taskSupport = st.tool.Execution.TaskSupport + } + + if req.Params.Task == nil { + if taskSupport == "required" { + return nil, fmt.Errorf("%w: task augmentation required for tools/call", jsonrpc2.ErrMethodNotFound) + } + return s.callToolNow(ctx, req, st) + } + + // Task requested. + if taskSupport == "forbidden" || taskSupport == "" { + return nil, fmt.Errorf("%w: tool does not support task execution", jsonrpc2.ErrMethodNotFound) + } + if taskSupport != "optional" && taskSupport != "required" { + return nil, fmt.Errorf("%w: invalid tool execution.taskSupport %q", jsonrpc2.ErrInvalidParams, taskSupport) + } + + entry, err := s.tasks.createToolTask(req.Session, req.Params.Meta, req.Params.Arguments, req.Params.Task) + if err != nil { + return nil, err + } + + // Run the tool asynchronously. + go func() { + defer func() { + // Ensure we never leak a task wait. + select { + case <-entry.done: + default: + close(entry.done) + } + }() + + res, runErr := s.runToolTask(entry, st) + + s.tasks.finishToolTask(entry, res, runErr) + }() + + t := entry.task // copy + return &CreateTaskResult{Task: &t}, nil +} + +func (s *Server) runToolTask(entry *serverTaskEntry, st *serverTool) (*CallToolResult, error) { + // Tasks are durable relative to the initiating request lifetime. + taskCtx, cancel := context.WithCancel(context.Background()) + s.tasks.setCancel(entry, cancel) + defer cancel() + + paramsCopy := CallToolParamsRaw{ + Meta: entry.meta, + Name: st.tool.Name, + Arguments: append([]byte(nil), entry.args...), + Task: nil, + } + + // The tool handler expects a CallToolRequest. + toolReq := &CallToolRequest{Session: entry.session, Params: ¶msCopy} + res, err := st.handler(taskCtx, toolReq) + if err == nil && res != nil && res.Content == nil { + res2 := *res + res2.Content = []Content{} + res = &res2 + } + if err == nil && res == nil { + res = &CallToolResult{Content: []Content{}} + } + return res, err +} + +func (s *serverTasks) createToolTask(session *ServerSession, meta Meta, rawArgs []byte, tp *TaskParams) (*serverTaskEntry, error) { + if session == nil { + return nil, fmt.Errorf("%w: missing session", jsonrpc2.ErrInvalidRequest) + } + if meta != nil { + cp := make(Meta, len(meta)) + for k, v := range meta { + cp[k] = v + } + meta = cp + } + + now := time.Now().UTC() + createdAt := now.Format(time.RFC3339) + + var ttl *int64 + var expiresAt *time.Time + if tp != nil && tp.TTL != nil { + v := *tp.TTL + ttl = &v + exp := now.Add(time.Duration(v) * time.Millisecond) + expiresAt = &exp + } else { + // Explicitly include null TTL in tasks/get responses. + ttl = nil + } + + taskID, err := newTaskID() + if err != nil { + return nil, fmt.Errorf("%w: generating task id: %v", jsonrpc2.ErrInternal, err) + } + + e := &serverTaskEntry{ + session: session, + meta: meta, + args: append([]byte(nil), rawArgs...), + task: Task{ + Meta: nil, + TaskID: taskID, + Status: TaskStatusWorking, + StatusMessage: "The operation is now in progress.", + CreatedAt: createdAt, + LastUpdatedAt: createdAt, + TTL: ttl, + }, + expiresAt: expiresAt, + done: make(chan struct{}), + } + + s.mu.Lock() + s.next++ + e.seq = s.next + s.tasks[taskID] = e + s.mu.Unlock() + + return e, nil +} + +func (s *serverTasks) setCancel(entry *serverTaskEntry, cancel context.CancelFunc) { + s.mu.Lock() + defer s.mu.Unlock() + if cur, ok := s.tasks[entry.task.TaskID]; ok { + cur.cancel = cancel + } +} + +func (s *serverTasks) finishToolTask(entry *serverTaskEntry, res *CallToolResult, err error) { + s.mu.Lock() + cur := s.tasks[entry.task.TaskID] + if cur == nil { + s.mu.Unlock() + return + } + cur.result = res + cur.err = err + + // Respect terminal cancellation: do not transition away from cancelled. + if cur.task.Status != TaskStatusCancelled { + now := time.Now().UTC().Format(time.RFC3339) + cur.task.LastUpdatedAt = now + switch { + case err != nil: + cur.task.Status = TaskStatusFailed + cur.task.StatusMessage = err.Error() + case res != nil && res.IsError: + cur.task.Status = TaskStatusFailed + cur.task.StatusMessage = "tool execution failed" + default: + cur.task.Status = TaskStatusCompleted + cur.task.StatusMessage = "" + } + } + t := cur.task + s.mu.Unlock() + + // Best-effort status notification. + _ = handleNotify(context.Background(), notificationTaskStatus, newServerRequest(entry.session, (*TaskStatusNotificationParams)(&t))) +} + +func (s *serverTasks) get(session *ServerSession, taskID string) (*serverTaskEntry, error) { + s.mu.Lock() + defer s.mu.Unlock() + e := s.tasks[taskID] + if e == nil || e.session != session { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "Failed to retrieve task: Task not found"} + } + if e.expiresAt != nil && time.Now().After(*e.expiresAt) { + delete(s.tasks, taskID) + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "Failed to retrieve task: Task has expired"} + } + return e, nil +} + +func (s *Server) getTask(_ context.Context, req *GetTaskRequest) (*GetTaskResult, error) { + if !s.tasksEnabled() { + return nil, jsonrpc2.ErrMethodNotFound + } + e, err := s.tasks.get(req.Session, req.Params.TaskID) + if err != nil { + return nil, err + } + t := GetTaskResult(e.task) + return &t, nil +} + +func (s *Server) listTasks(_ context.Context, req *ListTasksRequest) (*ListTasksResult, error) { + if !s.tasksListEnabled() { + return nil, jsonrpc2.ErrMethodNotFound + } + if req.Params == nil { + req.Params = &ListTasksParams{} + } + cursor, err := decodeTaskCursor(req.Params.Cursor) + if err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "Invalid cursor"} + } + + entries := s.tasks.listForSession(req.Session) + sort.Slice(entries, func(i, j int) bool { return entries[i].seq < entries[j].seq }) + + start := 0 + if cursor != 0 { + for i, e := range entries { + if e.seq == cursor { + start = i + 1 + break + } + } + if start == 0 { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "Invalid cursor"} + } + } + + pageSize := s.opts.PageSize + end := start + pageSize + if end > len(entries) { + end = len(entries) + } + + res := &ListTasksResult{Tasks: []*Task{}} + for _, e := range entries[start:end] { + t := e.task + res.Tasks = append(res.Tasks, &t) + } + if end < len(entries) { + res.NextCursor = encodeTaskCursor(entries[end-1].seq) + } + return res, nil +} + +func (s *serverTasks) listForSession(session *ServerSession) []*serverTaskEntry { + s.mu.Lock() + defer s.mu.Unlock() + var out []*serverTaskEntry + now := time.Now() + for id, e := range s.tasks { + if e.session != session { + continue + } + if e.expiresAt != nil && now.After(*e.expiresAt) { + delete(s.tasks, id) + continue + } + out = append(out, e) + } + return out +} + +func (s *Server) cancelTask(_ context.Context, req *CancelTaskRequest) (*CancelTaskResult, error) { + if !s.tasksCancelEnabled() { + return nil, jsonrpc2.ErrMethodNotFound + } + e, err := s.tasks.get(req.Session, req.Params.TaskID) + if err != nil { + return nil, err + } + + // Terminal tasks cannot be cancelled. + s.tasks.mu.Lock() + cur := s.tasks.tasks[e.task.TaskID] + if cur == nil { + s.tasks.mu.Unlock() + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "Failed to cancel task: Task not found"} + } + switch cur.task.Status { + case TaskStatusCompleted, TaskStatusFailed, TaskStatusCancelled: + s.tasks.mu.Unlock() + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("Cannot cancel task: already in terminal status %q", cur.task.Status)} + default: + } + now := time.Now().UTC().Format(time.RFC3339) + cur.task.Status = TaskStatusCancelled + cur.task.StatusMessage = "The task was cancelled by request." + cur.task.LastUpdatedAt = now + cancel := cur.cancel + t := cur.task + s.tasks.mu.Unlock() + + if cancel != nil { + cancel() + } + // Best-effort status notification. + _ = handleNotify(context.Background(), notificationTaskStatus, newServerRequest(req.Session, (*TaskStatusNotificationParams)(&t))) + + res := CancelTaskResult(t) + return &res, nil +} + +func (s *Server) taskResult(ctx context.Context, req *TaskResultRequest) (*CallToolResult, error) { + if !s.tasksEnabled() { + return nil, jsonrpc2.ErrMethodNotFound + } + e, err := s.tasks.get(req.Session, req.Params.TaskID) + if err != nil { + return nil, err + } + + <-e.done + + s.tasks.mu.Lock() + cur := s.tasks.tasks[e.task.TaskID] + res := cur.result + err = cur.err + s.tasks.mu.Unlock() + + if err != nil { + return nil, err + } + if res == nil { + res = &CallToolResult{Content: []Content{}} + } + + m := res.GetMeta() + if m == nil { + m = map[string]any{} + } + m[relatedTaskMetaKey] = map[string]any{"taskId": req.Params.TaskID} + res.SetMeta(m) + return res, nil +} + +func (s *Server) callToolNow(ctx context.Context, req *CallToolRequest, st *serverTool) (*CallToolResult, error) { + // Ensure tasks are not propagated into the underlying call. + paramsCopy := *req.Params + paramsCopy.Task = nil + localReq := *req + localReq.Params = ¶msCopy + + res, err := st.handler(ctx, &localReq) + if err == nil && res != nil && res.Content == nil { + res2 := *res + res2.Content = []Content{} // avoid "null" + res = &res2 + } + return res, err +} + +func newTaskID() (string, error) { + var b [16]byte + if _, err := rand.Read(b[:]); err != nil { + return "", err + } + // Hex is fine; spec only requires a unique string. + return hex.EncodeToString(b[:]), nil +} + +func encodeTaskCursor(seq uint64) string { + return strconv.FormatUint(seq, 10) +} + +func decodeTaskCursor(cursor string) (uint64, error) { + if cursor == "" { + return 0, nil + } + return strconv.ParseUint(cursor, 10, 64) +} diff --git a/mcp/tasks_test.go b/mcp/tasks_test.go new file mode 100644 index 00000000..88981e05 --- /dev/null +++ b/mcp/tasks_test.go @@ -0,0 +1,412 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +func TestToolTasksBasicLifecycle(t *testing.T) { + ctx := context.Background() + + // Enable task support for tools/call. + server := mcp.NewServer(&mcp.Implementation{Name: "testServer", Version: "v1.0.0"}, &mcp.ServerOptions{ + Capabilities: &mcp.ServerCapabilities{ + Tasks: &mcp.TasksCapabilities{ + List: &mcp.TasksListCapabilities{}, + Cancel: &mcp.TasksCancelCapabilities{}, + Requests: &mcp.TasksRequestsCapabilities{ + Tools: &mcp.TasksToolsRequestCapabilities{Call: &mcp.TasksToolsCallCapabilities{}}, + }, + }, + }, + }) + + start := make(chan struct{}) + server.AddTool(&mcp.Tool{ + Name: "slow", + InputSchema: &jsonschema.Schema{Type: "object"}, + Execution: &mcp.ToolExecution{TaskSupport: mcp.ToolTaskSupportOptional}, + }, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + select { + case <-start: + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "ok"}}}, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + }) + + cTransport, sTransport := mcp.NewInMemoryTransports() + ss, err := server.Connect(ctx, sTransport, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + client := mcp.NewClient(&mcp.Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + cs, err := client.Connect(ctx, cTransport, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + ttl := int64(60_000) + createRes, err := cs.CallToolTask(ctx, &mcp.CallToolParams{ + Name: "slow", + Arguments: map[string]any{}, + Task: &mcp.TaskParams{TTL: &ttl}, + }) + if err != nil { + t.Fatalf("CallToolTask failed: %v", err) + } + if createRes.Task == nil || createRes.Task.TaskID == "" { + t.Fatalf("CreateTaskResult missing task/taskId: %#v", createRes) + } + if got, want := createRes.Task.Status, mcp.TaskStatusWorking; got != want { + t.Fatalf("initial status: got %q want %q", got, want) + } + + close(start) + + // TaskResult should block until completion and then return the tool result. + resultCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + toolRes, err := cs.TaskResult(resultCtx, &mcp.TaskResultParams{TaskID: createRes.Task.TaskID}) + if err != nil { + t.Fatalf("TaskResult failed: %v", err) + } + if toolRes == nil || len(toolRes.Content) != 1 { + t.Fatalf("unexpected tool result: %#v", toolRes) + } + + getRes, err := cs.GetTask(ctx, &mcp.GetTaskParams{TaskID: createRes.Task.TaskID}) + if err != nil { + t.Fatalf("GetTask failed: %v", err) + } + if got, want := getRes.Status, mcp.TaskStatusCompleted; got != want { + t.Fatalf("final status: got %q want %q", got, want) + } +} + +func TestToolTasksCancel(t *testing.T) { + ctx := context.Background() + + server := mcp.NewServer(&mcp.Implementation{Name: "testServer", Version: "v1.0.0"}, &mcp.ServerOptions{ + Capabilities: &mcp.ServerCapabilities{ + Tasks: &mcp.TasksCapabilities{ + List: &mcp.TasksListCapabilities{}, + Cancel: &mcp.TasksCancelCapabilities{}, + Requests: &mcp.TasksRequestsCapabilities{ + Tools: &mcp.TasksToolsRequestCapabilities{Call: &mcp.TasksToolsCallCapabilities{}}, + }, + }, + }, + }) + + block := make(chan struct{}) + server.AddTool(&mcp.Tool{ + Name: "block", + InputSchema: &jsonschema.Schema{Type: "object"}, + Execution: &mcp.ToolExecution{TaskSupport: mcp.ToolTaskSupportOptional}, + }, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + select { + case <-block: + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + }) + + cTransport, sTransport := mcp.NewInMemoryTransports() + ss, err := server.Connect(ctx, sTransport, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + client := mcp.NewClient(&mcp.Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + cs, err := client.Connect(ctx, cTransport, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + ttl := int64(60_000) + createRes, err := cs.CallToolTask(ctx, &mcp.CallToolParams{ + Name: "block", + Arguments: map[string]any{}, + Task: &mcp.TaskParams{TTL: &ttl}, + }) + if err != nil { + t.Fatalf("CallToolTask failed: %v", err) + } + + cancelRes, err := cs.CancelTask(ctx, &mcp.CancelTaskParams{TaskID: createRes.Task.TaskID}) + if err != nil { + t.Fatalf("CancelTask failed: %v", err) + } + if got, want := cancelRes.Status, mcp.TaskStatusCancelled; got != want { + t.Fatalf("cancel status: got %q want %q", got, want) + } + + resultCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + _, err = cs.TaskResult(resultCtx, &mcp.TaskResultParams{TaskID: createRes.Task.TaskID}) + if err == nil { + t.Fatalf("TaskResult unexpectedly succeeded after cancel") + } +} + +func TestTasksListPaginationAndCursor(t *testing.T) { + ctx := context.Background() + + server := mcp.NewServer(&mcp.Implementation{Name: "testServer", Version: "v1.0.0"}, &mcp.ServerOptions{ + PageSize: 1, + Capabilities: &mcp.ServerCapabilities{ + Tasks: &mcp.TasksCapabilities{ + List: &mcp.TasksListCapabilities{}, + Cancel: &mcp.TasksCancelCapabilities{}, + Requests: &mcp.TasksRequestsCapabilities{ + Tools: &mcp.TasksToolsRequestCapabilities{Call: &mcp.TasksToolsCallCapabilities{}}, + }, + }, + }, + }) + + block := make(chan struct{}) + server.AddTool(&mcp.Tool{ + Name: "block", + InputSchema: &jsonschema.Schema{Type: "object"}, + Execution: &mcp.ToolExecution{TaskSupport: mcp.ToolTaskSupportOptional}, + }, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + select { + case <-block: + return &mcp.CallToolResult{Content: []mcp.Content{}}, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + }) + + cTransport, sTransport := mcp.NewInMemoryTransports() + ss, err := server.Connect(ctx, sTransport, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + client := mcp.NewClient(&mcp.Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + cs, err := client.Connect(ctx, cTransport, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + ttl := int64(60_000) + create1, err := cs.CallToolTask(ctx, &mcp.CallToolParams{Name: "block", Arguments: map[string]any{}, Task: &mcp.TaskParams{TTL: &ttl}}) + if err != nil { + t.Fatalf("CallToolTask #1 failed: %v", err) + } + create2, err := cs.CallToolTask(ctx, &mcp.CallToolParams{Name: "block", Arguments: map[string]any{}, Task: &mcp.TaskParams{TTL: &ttl}}) + if err != nil { + t.Fatalf("CallToolTask #2 failed: %v", err) + } + if create1.Task.TaskID == create2.Task.TaskID { + t.Fatalf("expected distinct task IDs") + } + + page1, err := cs.ListTasks(ctx, &mcp.ListTasksParams{}) + if err != nil { + t.Fatalf("ListTasks page1 failed: %v", err) + } + if got := len(page1.Tasks); got != 1 { + t.Fatalf("ListTasks page1: got %d tasks, want 1", got) + } + if page1.NextCursor == "" { + t.Fatalf("ListTasks page1: expected nextCursor") + } + + page2, err := cs.ListTasks(ctx, &mcp.ListTasksParams{Cursor: page1.NextCursor}) + if err != nil { + t.Fatalf("ListTasks page2 failed: %v", err) + } + if got := len(page2.Tasks); got != 1 { + t.Fatalf("ListTasks page2: got %d tasks, want 1", got) + } + if page2.NextCursor != "" { + t.Fatalf("ListTasks page2: expected empty nextCursor, got %q", page2.NextCursor) + } + + _, err = cs.ListTasks(ctx, &mcp.ListTasksParams{Cursor: "999999999"}) + if err == nil { + t.Fatalf("ListTasks with invalid cursor unexpectedly succeeded") + } + var rpcErr *jsonrpc.Error + if !errors.As(err, &rpcErr) || rpcErr.Code != jsonrpc.CodeInvalidParams { + t.Fatalf("ListTasks invalid cursor: got %T/%v, want jsonrpc invalid params", err, err) + } + + close(block) + _, _ = cs.TaskResult(ctx, &mcp.TaskResultParams{TaskID: create1.Task.TaskID}) + _, _ = cs.TaskResult(ctx, &mcp.TaskResultParams{TaskID: create2.Task.TaskID}) +} + +func TestTasksGetNotFound(t *testing.T) { + ctx := context.Background() + + server := mcp.NewServer(&mcp.Implementation{Name: "testServer", Version: "v1.0.0"}, &mcp.ServerOptions{ + Capabilities: &mcp.ServerCapabilities{Tasks: &mcp.TasksCapabilities{}}, + }) + + cTransport, sTransport := mcp.NewInMemoryTransports() + ss, err := server.Connect(ctx, sTransport, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + client := mcp.NewClient(&mcp.Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + cs, err := client.Connect(ctx, cTransport, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + _, err = cs.GetTask(ctx, &mcp.GetTaskParams{TaskID: "does-not-exist"}) + if err == nil { + t.Fatalf("GetTask unexpectedly succeeded") + } + var rpcErr *jsonrpc.Error + if !errors.As(err, &rpcErr) || rpcErr.Code != jsonrpc.CodeInvalidParams { + t.Fatalf("GetTask not found: got %T/%v, want jsonrpc invalid params", err, err) + } +} + +func TestTasksCancelTerminalRejected(t *testing.T) { + ctx := context.Background() + + server := mcp.NewServer(&mcp.Implementation{Name: "testServer", Version: "v1.0.0"}, &mcp.ServerOptions{ + Capabilities: &mcp.ServerCapabilities{ + Tasks: &mcp.TasksCapabilities{ + Cancel: &mcp.TasksCancelCapabilities{}, + Requests: &mcp.TasksRequestsCapabilities{Tools: &mcp.TasksToolsRequestCapabilities{Call: &mcp.TasksToolsCallCapabilities{}}}, + }, + }, + }) + + start := make(chan struct{}) + server.AddTool(&mcp.Tool{ + Name: "finish", + InputSchema: &jsonschema.Schema{Type: "object"}, + Execution: &mcp.ToolExecution{TaskSupport: mcp.ToolTaskSupportOptional}, + }, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + <-start + return &mcp.CallToolResult{Content: []mcp.Content{}}, nil + }) + + cTransport, sTransport := mcp.NewInMemoryTransports() + ss, err := server.Connect(ctx, sTransport, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + client := mcp.NewClient(&mcp.Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + cs, err := client.Connect(ctx, cTransport, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + ttl := int64(60_000) + createRes, err := cs.CallToolTask(ctx, &mcp.CallToolParams{Name: "finish", Arguments: map[string]any{}, Task: &mcp.TaskParams{TTL: &ttl}}) + if err != nil { + t.Fatalf("CallToolTask failed: %v", err) + } + + close(start) + resultCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + _, err = cs.TaskResult(resultCtx, &mcp.TaskResultParams{TaskID: createRes.Task.TaskID}) + if err != nil { + t.Fatalf("TaskResult failed: %v", err) + } + + _, err = cs.CancelTask(ctx, &mcp.CancelTaskParams{TaskID: createRes.Task.TaskID}) + if err == nil { + t.Fatalf("CancelTask unexpectedly succeeded on terminal task") + } + var rpcErr *jsonrpc.Error + if !errors.As(err, &rpcErr) || rpcErr.Code != jsonrpc.CodeInvalidParams { + t.Fatalf("CancelTask terminal: got %T/%v, want jsonrpc invalid params", err, err) + } +} + +func TestTasksResultIncludesRelatedTaskMeta(t *testing.T) { + ctx := context.Background() + + server := mcp.NewServer(&mcp.Implementation{Name: "testServer", Version: "v1.0.0"}, &mcp.ServerOptions{ + Capabilities: &mcp.ServerCapabilities{ + Tasks: &mcp.TasksCapabilities{ + Requests: &mcp.TasksRequestsCapabilities{Tools: &mcp.TasksToolsRequestCapabilities{Call: &mcp.TasksToolsCallCapabilities{}}}, + }, + }, + }) + + start := make(chan struct{}) + server.AddTool(&mcp.Tool{ + Name: "slow", + InputSchema: &jsonschema.Schema{Type: "object"}, + Execution: &mcp.ToolExecution{TaskSupport: mcp.ToolTaskSupportOptional}, + }, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + <-start + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "ok"}}}, nil + }) + + cTransport, sTransport := mcp.NewInMemoryTransports() + ss, err := server.Connect(ctx, sTransport, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + client := mcp.NewClient(&mcp.Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + cs, err := client.Connect(ctx, cTransport, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + ttl := int64(60_000) + createRes, err := cs.CallToolTask(ctx, &mcp.CallToolParams{Name: "slow", Arguments: map[string]any{}, Task: &mcp.TaskParams{TTL: &ttl}}) + if err != nil { + t.Fatalf("CallToolTask failed: %v", err) + } + + close(start) + resultCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + toolRes, err := cs.TaskResult(resultCtx, &mcp.TaskResultParams{TaskID: createRes.Task.TaskID}) + if err != nil { + t.Fatalf("TaskResult failed: %v", err) + } + meta := toolRes.GetMeta() + if meta == nil { + t.Fatalf("TaskResult missing _meta") + } + related, ok := meta["io.modelcontextprotocol/related-task"].(map[string]any) + if !ok { + t.Fatalf("TaskResult missing related-task metadata: %#v", meta) + } + if got, _ := related["taskId"].(string); got != createRes.Task.TaskID { + t.Fatalf("related-task.taskId: got %q want %q", got, createRes.Task.TaskID) + } +}