Skip to content

Commit a502979

Browse files
committed
First implementation of tasks
1 parent 13488f7 commit a502979

File tree

7 files changed

+1132
-3
lines changed

7 files changed

+1132
-3
lines changed

mcp/client.go

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client {
5050
if options != nil {
5151
opts = *options
5252
}
53-
options = nil // prevent reuse
5453

5554
if opts.Logger == nil { // ensure we have a logger
5655
opts.Logger = ensureLogger(nil)
@@ -129,6 +128,7 @@ type ClientOptions struct {
129128
PromptListChangedHandler func(context.Context, *PromptListChangedRequest)
130129
ResourceListChangedHandler func(context.Context, *ResourceListChangedRequest)
131130
ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest)
131+
TaskStatusHandler func(context.Context, *TaskStatusNotificationRequest)
132132
LoggingMessageHandler func(context.Context, *LoggingMessageRequest)
133133
ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest)
134134
// If non-zero, defines an interval for regular "ping" requests.
@@ -807,6 +807,7 @@ var clientMethodInfos = map[string]methodInfo{
807807
methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0),
808808
methodElicit: newClientMethodInfo(clientMethod((*Client).elicit), missingParamsOK),
809809
notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK),
810+
notificationTaskStatus: newClientMethodInfo(clientMethod((*Client).callTaskStatusHandler), notification),
810811
notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK),
811812
notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK),
812813
notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK),
@@ -888,6 +889,9 @@ func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams)
888889
//
889890
// The params.Arguments can be any value that marshals into a JSON object.
890891
func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) {
892+
if params != nil && params.Task != nil {
893+
return nil, fmt.Errorf("task augmentation requested: use CallToolTask")
894+
}
891895
if params == nil {
892896
params = new(CallToolParams)
893897
}
@@ -898,6 +902,43 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (
898902
return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params)))
899903
}
900904

905+
// CallToolTask calls a tool using task-based execution (tools/call with params.task).
906+
//
907+
// The response is a CreateTaskResult. Use GetTask to poll for task state and
908+
// TaskResult to retrieve the final tool result.
909+
func (cs *ClientSession) CallToolTask(ctx context.Context, params *CallToolParams) (*CreateTaskResult, error) {
910+
if params == nil || params.Task == nil {
911+
return nil, fmt.Errorf("CallToolTask requires params.Task")
912+
}
913+
if params.Arguments == nil {
914+
// Avoid sending nil over the wire.
915+
params.Arguments = map[string]any{}
916+
}
917+
return handleSend[*CreateTaskResult](ctx, methodCallTool, newClientRequest(cs, params))
918+
}
919+
920+
// GetTask polls task status via tasks/get.
921+
func (cs *ClientSession) GetTask(ctx context.Context, params *GetTaskParams) (*GetTaskResult, error) {
922+
return handleSend[*GetTaskResult](ctx, methodGetTask, newClientRequest(cs, orZero[Params](params)))
923+
}
924+
925+
// ListTasks lists tasks via tasks/list.
926+
func (cs *ClientSession) ListTasks(ctx context.Context, params *ListTasksParams) (*ListTasksResult, error) {
927+
return handleSend[*ListTasksResult](ctx, methodListTasks, newClientRequest(cs, orZero[Params](params)))
928+
}
929+
930+
// CancelTask requests cancellation via tasks/cancel.
931+
func (cs *ClientSession) CancelTask(ctx context.Context, params *CancelTaskParams) (*CancelTaskResult, error) {
932+
return handleSend[*CancelTaskResult](ctx, methodCancelTask, newClientRequest(cs, orZero[Params](params)))
933+
}
934+
935+
// TaskResult retrieves the final result of a task via tasks/result.
936+
//
937+
// Currently, this SDK supports tasks/result only for tasks created from tools/call.
938+
func (cs *ClientSession) TaskResult(ctx context.Context, params *TaskResultParams) (*CallToolResult, error) {
939+
return handleSend[*CallToolResult](ctx, methodTaskResult, newClientRequest(cs, orZero[Params](params)))
940+
}
941+
901942
func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLoggingLevelParams) error {
902943
_, err := handleSend[*emptyResult](ctx, methodSetLevel, newClientRequest(cs, orZero[Params](params)))
903944
return err
@@ -971,6 +1012,13 @@ func (c *Client) callLoggingHandler(ctx context.Context, req *LoggingMessageRequ
9711012
return nil, nil
9721013
}
9731014

1015+
func (c *Client) callTaskStatusHandler(ctx context.Context, req *TaskStatusNotificationRequest) (Result, error) {
1016+
if h := c.opts.TaskStatusHandler; h != nil {
1017+
h(ctx, req)
1018+
}
1019+
return nil, nil
1020+
}
1021+
9741022
func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, params *ProgressNotificationParams) (Result, error) {
9751023
if h := cs.client.opts.ProgressNotificationHandler; h != nil {
9761024
h(ctx, clientRequestFor(cs, params))

mcp/protocol.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ type CallToolParams struct {
4848
// Arguments holds the tool arguments. It can hold any value that can be
4949
// marshaled to JSON.
5050
Arguments any `json:"arguments,omitempty"`
51+
// Task optionally requests task-based execution of this tool call.
52+
//
53+
// Note: when Task is present, the wire response is a CreateTaskResult rather
54+
// than a CallToolResult.
55+
Task *TaskParams `json:"task,omitempty"`
5156
}
5257

5358
// CallToolParamsRaw is passed to tool handlers on the server. Its arguments
@@ -63,6 +68,8 @@ type CallToolParamsRaw struct {
6368
// is the responsibility of the tool handler to unmarshal and validate the
6469
// Arguments (see [AddTool]).
6570
Arguments json.RawMessage `json:"arguments,omitempty"`
71+
// Task optionally requests task-based execution of this tool call.
72+
Task *TaskParams `json:"task,omitempty"`
6673
}
6774

6875
// A CallToolResult is the server's response to a tool call.
@@ -210,13 +217,16 @@ type ClientCapabilities struct {
210217
Sampling *SamplingCapabilities `json:"sampling,omitempty"`
211218
// Elicitation is present if the client supports elicitation from the server.
212219
Elicitation *ElicitationCapabilities `json:"elicitation,omitempty"`
220+
// Tasks describes support for task-based execution.
221+
Tasks *TasksCapabilities `json:"tasks,omitempty"`
213222
}
214223

215224
// clone returns a deep copy of the ClientCapabilities.
216225
func (c *ClientCapabilities) clone() *ClientCapabilities {
217226
cp := *c
218227
cp.RootsV2 = shallowClone(c.RootsV2)
219228
cp.Sampling = shallowClone(c.Sampling)
229+
cp.Tasks = shallowClone(c.Tasks)
220230
if c.Elicitation != nil {
221231
x := *c.Elicitation
222232
x.Form = shallowClone(c.Elicitation.Form)
@@ -1092,8 +1102,28 @@ type Tool struct {
10921102
Title string `json:"title,omitempty"`
10931103
// Icons for the tool, if any.
10941104
Icons []Icon `json:"icons,omitempty"`
1105+
// Execution contains optional execution-related settings.
1106+
Execution *ToolExecution `json:"execution,omitempty"`
1107+
}
1108+
1109+
// ToolExecution configures execution behavior for a tool.
1110+
type ToolExecution struct {
1111+
// TaskSupport declares task support for this tool.
1112+
//
1113+
// Valid values are: "required", "optional", or "forbidden".
1114+
// See ToolTaskSupportRequired, ToolTaskSupportOptional, and ToolTaskSupportForbidden.
1115+
TaskSupport string `json:"taskSupport,omitempty"`
10951116
}
10961117

1118+
const (
1119+
// ToolTaskSupportRequired indicates the tool MUST be invoked with task augmentation.
1120+
ToolTaskSupportRequired = "required"
1121+
// ToolTaskSupportOptional indicates the tool MAY be invoked with task augmentation.
1122+
ToolTaskSupportOptional = "optional"
1123+
// ToolTaskSupportForbidden indicates the tool MUST NOT be invoked with task augmentation.
1124+
ToolTaskSupportForbidden = "forbidden"
1125+
)
1126+
10971127
// Additional properties describing a Tool to clients.
10981128
//
10991129
// NOTE: all properties in ToolAnnotations are hints. They are not
@@ -1314,6 +1344,8 @@ type ServerCapabilities struct {
13141344
Resources *ResourceCapabilities `json:"resources,omitempty"`
13151345
// Tools is present if the supports tools.
13161346
Tools *ToolCapabilities `json:"tools,omitempty"`
1347+
// Tasks describes support for task-based execution.
1348+
Tasks *TasksCapabilities `json:"tasks,omitempty"`
13171349
}
13181350

13191351
// clone returns a deep copy of the ServerCapabilities.
@@ -1324,12 +1356,148 @@ func (c *ServerCapabilities) clone() *ServerCapabilities {
13241356
cp.Prompts = shallowClone(c.Prompts)
13251357
cp.Resources = shallowClone(c.Resources)
13261358
cp.Tools = shallowClone(c.Tools)
1359+
cp.Tasks = shallowClone(c.Tasks)
13271360
return &cp
13281361
}
13291362

1363+
// TasksCapabilities describes support for task-based execution.
1364+
type TasksCapabilities struct {
1365+
List *TasksListCapabilities `json:"list,omitempty"`
1366+
Cancel *TasksCancelCapabilities `json:"cancel,omitempty"`
1367+
Requests *TasksRequestsCapabilities `json:"requests,omitempty"`
1368+
}
1369+
1370+
type TasksListCapabilities struct{}
1371+
type TasksCancelCapabilities struct{}
1372+
1373+
type TasksRequestsCapabilities struct {
1374+
Tools *TasksToolsRequestCapabilities `json:"tools,omitempty"`
1375+
Sampling *TasksSamplingRequestCapabilities `json:"sampling,omitempty"`
1376+
Elicitation *TasksElicitationRequestCapabilities `json:"elicitation,omitempty"`
1377+
}
1378+
1379+
type TasksToolsRequestCapabilities struct {
1380+
Call *TasksToolsCallCapabilities `json:"call,omitempty"`
1381+
}
1382+
1383+
type TasksToolsCallCapabilities struct{}
1384+
1385+
type TasksSamplingRequestCapabilities struct {
1386+
CreateMessage *TasksSamplingCreateMessageCapabilities `json:"createMessage,omitempty"`
1387+
}
1388+
1389+
type TasksSamplingCreateMessageCapabilities struct{}
1390+
1391+
type TasksElicitationRequestCapabilities struct {
1392+
Create *TasksElicitationCreateCapabilities `json:"create,omitempty"`
1393+
}
1394+
1395+
type TasksElicitationCreateCapabilities struct{}
1396+
1397+
// TaskParams is included in request parameters to request task-based execution.
1398+
type TaskParams struct {
1399+
TTL *int64 `json:"ttl,omitempty"`
1400+
}
1401+
1402+
type TaskStatus string
1403+
1404+
const (
1405+
TaskStatusWorking TaskStatus = "working"
1406+
TaskStatusInputRequired TaskStatus = "input_required"
1407+
TaskStatusCompleted TaskStatus = "completed"
1408+
TaskStatusFailed TaskStatus = "failed"
1409+
TaskStatusCancelled TaskStatus = "cancelled"
1410+
)
1411+
1412+
// Task describes the state of a task.
1413+
type Task struct {
1414+
Meta `json:"_meta,omitempty"`
1415+
TaskID string `json:"taskId"`
1416+
Status TaskStatus `json:"status"`
1417+
StatusMessage string `json:"statusMessage,omitempty"`
1418+
CreatedAt string `json:"createdAt"`
1419+
LastUpdatedAt string `json:"lastUpdatedAt"`
1420+
TTL *int64 `json:"ttl"`
1421+
PollInterval *int64 `json:"pollInterval,omitempty"`
1422+
}
1423+
1424+
// CreateTaskResult is returned for task-augmented requests.
1425+
type CreateTaskResult struct {
1426+
Meta `json:"_meta,omitempty"`
1427+
Task *Task `json:"task"`
1428+
}
1429+
1430+
func (*CreateTaskResult) isResult() {}
1431+
1432+
type GetTaskParams struct {
1433+
Meta `json:"_meta,omitempty"`
1434+
TaskID string `json:"taskId"`
1435+
}
1436+
1437+
func (*GetTaskParams) isParams() {}
1438+
func (x *GetTaskParams) GetProgressToken() any { return getProgressToken(x) }
1439+
func (x *GetTaskParams) SetProgressToken(t any) { setProgressToken(x, t) }
1440+
1441+
type GetTaskResult Task
1442+
1443+
func (*GetTaskResult) isResult() {}
1444+
1445+
type ListTasksParams struct {
1446+
Meta `json:"_meta,omitempty"`
1447+
Cursor string `json:"cursor,omitempty"`
1448+
}
1449+
1450+
func (x *ListTasksParams) isParams() {}
1451+
func (x *ListTasksParams) GetProgressToken() any { return getProgressToken(x) }
1452+
func (x *ListTasksParams) SetProgressToken(t any) { setProgressToken(x, t) }
1453+
func (x *ListTasksParams) cursorPtr() *string { return &x.Cursor }
1454+
1455+
type ListTasksResult struct {
1456+
Meta `json:"_meta,omitempty"`
1457+
Tasks []*Task `json:"tasks"`
1458+
NextCursor string `json:"nextCursor,omitempty"`
1459+
}
1460+
1461+
func (*ListTasksResult) isResult() {}
1462+
func (x *ListTasksResult) nextCursorPtr() *string { return &x.NextCursor }
1463+
1464+
type CancelTaskParams struct {
1465+
Meta `json:"_meta,omitempty"`
1466+
TaskID string `json:"taskId"`
1467+
}
1468+
1469+
func (*CancelTaskParams) isParams() {}
1470+
func (x *CancelTaskParams) GetProgressToken() any { return getProgressToken(x) }
1471+
func (x *CancelTaskParams) SetProgressToken(t any) { setProgressToken(x, t) }
1472+
1473+
type CancelTaskResult Task
1474+
1475+
func (*CancelTaskResult) isResult() {}
1476+
1477+
type TaskResultParams struct {
1478+
Meta `json:"_meta,omitempty"`
1479+
TaskID string `json:"taskId"`
1480+
}
1481+
1482+
func (*TaskResultParams) isParams() {}
1483+
func (x *TaskResultParams) GetProgressToken() any { return getProgressToken(x) }
1484+
func (x *TaskResultParams) SetProgressToken(t any) { setProgressToken(x, t) }
1485+
1486+
// TaskStatusNotificationParams is sent as notifications/tasks/status.
1487+
type TaskStatusNotificationParams Task
1488+
1489+
func (*TaskStatusNotificationParams) isParams() {}
1490+
func (x *TaskStatusNotificationParams) GetProgressToken() any { return getProgressToken(x) }
1491+
func (x *TaskStatusNotificationParams) SetProgressToken(t any) { setProgressToken(x, t) }
1492+
13301493
const (
13311494
methodCallTool = "tools/call"
1495+
methodGetTask = "tasks/get"
1496+
methodListTasks = "tasks/list"
1497+
methodCancelTask = "tasks/cancel"
1498+
methodTaskResult = "tasks/result"
13321499
notificationCancelled = "notifications/cancelled"
1500+
notificationTaskStatus = "notifications/tasks/status"
13331501
methodComplete = "completion/complete"
13341502
methodCreateMessage = "sampling/createMessage"
13351503
methodElicit = "elicitation/create"

mcp/requests.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@ package mcp
88

99
type (
1010
CallToolRequest = ServerRequest[*CallToolParamsRaw]
11+
CancelTaskRequest = ServerRequest[*CancelTaskParams]
1112
CompleteRequest = ServerRequest[*CompleteParams]
13+
GetTaskRequest = ServerRequest[*GetTaskParams]
1214
GetPromptRequest = ServerRequest[*GetPromptParams]
1315
InitializedRequest = ServerRequest[*InitializedParams]
16+
ListTasksRequest = ServerRequest[*ListTasksParams]
1417
ListPromptsRequest = ServerRequest[*ListPromptsParams]
1518
ListResourcesRequest = ServerRequest[*ListResourcesParams]
1619
ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams]
@@ -19,6 +22,8 @@ type (
1922
ReadResourceRequest = ServerRequest[*ReadResourceParams]
2023
RootsListChangedRequest = ServerRequest[*RootsListChangedParams]
2124
SubscribeRequest = ServerRequest[*SubscribeParams]
25+
TaskStatusNotificationServerRequest = ServerRequest[*TaskStatusNotificationParams]
26+
TaskResultRequest = ServerRequest[*TaskResultParams]
2227
UnsubscribeRequest = ServerRequest[*UnsubscribeParams]
2328
)
2429

@@ -33,6 +38,7 @@ type (
3338
PromptListChangedRequest = ClientRequest[*PromptListChangedParams]
3439
ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams]
3540
ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams]
41+
TaskStatusNotificationRequest = ClientRequest[*TaskStatusNotificationParams]
3642
ToolListChangedRequest = ClientRequest[*ToolListChangedParams]
3743
ElicitationCompleteNotificationRequest = ClientRequest[*ElicitationCompleteParams]
3844
)

0 commit comments

Comments
 (0)