Skip to content

Commit 2edbe8c

Browse files
olaservoclaudepcarleton
authored
Use correct schema for client sampling validation when tools are present (modelcontextprotocol#1347)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: Paul Carleton <paulcarletonjr@gmail.com>
1 parent 8bf13e4 commit 2edbe8c

File tree

2 files changed

+131
-2
lines changed

2 files changed

+131
-2
lines changed

packages/client/src/client/client.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import {
4343
CompleteResultSchema,
4444
CreateMessageRequestSchema,
4545
CreateMessageResultSchema,
46+
CreateMessageResultWithToolsSchema,
4647
CreateTaskResultSchema,
4748
ElicitRequestSchema,
4849
ElicitResultSchema,
@@ -458,8 +459,10 @@ export class Client<
458459
return taskValidationResult.data;
459460
}
460461

461-
// For non-task requests, validate against CreateMessageResultSchema
462-
const validationResult = safeParse(CreateMessageResultSchema, result);
462+
// For non-task requests, validate against appropriate schema based on tools presence
463+
const hasTools = params.tools || params.toolChoice;
464+
const resultSchema = hasTools ? CreateMessageResultWithToolsSchema : CreateMessageResultSchema;
465+
const validationResult = safeParse(resultSchema, result);
463466
if (!validationResult.success) {
464467
const errorMessage =
465468
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);

test/integration/test/client/client.test.ts

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4132,3 +4132,129 @@ describe('getSupportedElicitationModes', () => {
41324132
expect(result.supportsUrlMode).toBe(false);
41334133
});
41344134
});
4135+
4136+
describe('Client sampling validation with tools', () => {
4137+
test('should validate array content with tool_use when request includes tools', async () => {
4138+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
4139+
4140+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
4141+
4142+
// Handler returns array content with tool_use - should validate with CreateMessageResultWithToolsSchema
4143+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
4144+
model: 'test-model',
4145+
role: 'assistant',
4146+
stopReason: 'toolUse',
4147+
content: [{ type: 'tool_use', id: 'call_1', name: 'test_tool', input: { arg: 'value' } }]
4148+
}));
4149+
4150+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
4151+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
4152+
4153+
const result = await server.createMessage({
4154+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
4155+
maxTokens: 100,
4156+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
4157+
});
4158+
4159+
expect(result.stopReason).toBe('toolUse');
4160+
expect(Array.isArray(result.content)).toBe(true);
4161+
expect((result.content as Array<{ type: string }>)[0].type).toBe('tool_use');
4162+
});
4163+
4164+
test('should validate single content when request includes tools', async () => {
4165+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
4166+
4167+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
4168+
4169+
// Handler returns single content (text) - should still validate with CreateMessageResultWithToolsSchema
4170+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
4171+
model: 'test-model',
4172+
role: 'assistant',
4173+
content: { type: 'text', text: 'No tool needed' }
4174+
}));
4175+
4176+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
4177+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
4178+
4179+
const result = await server.createMessage({
4180+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
4181+
maxTokens: 100,
4182+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
4183+
});
4184+
4185+
expect((result.content as { type: string }).type).toBe('text');
4186+
});
4187+
4188+
test('should validate single content when request has no tools', async () => {
4189+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
4190+
4191+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } });
4192+
4193+
// Handler returns single content - should validate with CreateMessageResultSchema
4194+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
4195+
model: 'test-model',
4196+
role: 'assistant',
4197+
content: { type: 'text', text: 'Response' }
4198+
}));
4199+
4200+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
4201+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
4202+
4203+
const result = await server.createMessage({
4204+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
4205+
maxTokens: 100
4206+
});
4207+
4208+
expect((result.content as { type: string }).type).toBe('text');
4209+
});
4210+
4211+
test('should reject array content when request has no tools', async () => {
4212+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
4213+
4214+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } });
4215+
4216+
// Handler returns array content - should fail validation with CreateMessageResultSchema
4217+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
4218+
model: 'test-model',
4219+
role: 'assistant',
4220+
content: [{ type: 'text', text: 'Array response' }]
4221+
}));
4222+
4223+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
4224+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
4225+
4226+
await expect(
4227+
server.createMessage({
4228+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
4229+
maxTokens: 100
4230+
})
4231+
).rejects.toThrow('Invalid sampling result');
4232+
});
4233+
4234+
test('should validate array content when request includes toolChoice', async () => {
4235+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
4236+
4237+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
4238+
4239+
// Handler returns array content with tool_use
4240+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
4241+
model: 'test-model',
4242+
role: 'assistant',
4243+
stopReason: 'toolUse',
4244+
content: [{ type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} }]
4245+
}));
4246+
4247+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
4248+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
4249+
4250+
const result = await server.createMessage({
4251+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
4252+
maxTokens: 100,
4253+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }],
4254+
toolChoice: { mode: 'auto' }
4255+
});
4256+
4257+
expect(result.stopReason).toBe('toolUse');
4258+
expect(Array.isArray(result.content)).toBe(true);
4259+
});
4260+
});

0 commit comments

Comments
 (0)