Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/cli/src/gemini.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ export async function main() {
process.cwd(),
argv.extensions,
);

// Register cleanup for MCP clients as early as possible
// This ensures MCP server subprocesses are properly terminated on exit
registerCleanup(() => config.shutdown());

// FIXME: list extensions after the config initialize
Expand Down
29 changes: 22 additions & 7 deletions packages/core/src/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -814,13 +814,6 @@ export class Config {
return this.sessionId;
}

/**
* Releases resources owned by the config instance.
*/
async shutdown(): Promise<void> {
this.skillManager?.stopWatching();
}

/**
* Starts a new session and resets session-scoped services.
*/
Expand Down Expand Up @@ -1027,6 +1020,28 @@ export class Config {
return this.toolRegistry;
}

/**
* Shuts down the Config and releases all resources.
* This method is idempotent and safe to call multiple times.
* It handles the case where initialization was not completed.
*/
async shutdown(): Promise<void> {
if (!this.initialized) {
// Nothing to clean up if not initialized
return;
}
try {
this.skillManager?.stopWatching();

if (this.toolRegistry) {
await this.toolRegistry.stop();
}
} catch (error) {
// Log but don't throw - cleanup should be best-effort
console.error('Error during Config shutdown:', error);
}
}

getPromptRegistry(): PromptRegistry {
return this.promptRegistry;
}
Expand Down
181 changes: 178 additions & 3 deletions packages/core/src/tools/mcp-client-manager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@ import { McpClientManager } from './mcp-client-manager.js';
import { McpClient } from './mcp-client.js';
import type { ToolRegistry } from './tool-registry.js';
import type { Config } from '../config/config.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js';
import type { WorkspaceContext } from '../utils/workspaceContext.js';

vi.mock('./mcp-client.js', async () => {
const originalModule = await vi.importActual('./mcp-client.js');
return {
...originalModule,
McpClient: vi.fn(),
populateMcpServerCommand: vi.fn(() => ({
'test-server': {},
})),
// Return the input servers unchanged (identity function)
populateMcpServerCommand: vi.fn((servers) => servers),
};
});

Expand Down Expand Up @@ -73,4 +74,178 @@ describe('McpClientManager', () => {
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
});

it('should disconnect all clients when stop is called', async () => {
// Track disconnect calls across all instances
const disconnectCalls: string[] = [];
vi.mocked(McpClient).mockImplementation(
(name: string) =>
({
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn().mockImplementation(() => {
disconnectCalls.push(name);
return Promise.resolve();
}),
getStatus: vi.fn(),
}) as unknown as McpClient,
);
const mockConfig = {
isTrustedFolder: () => true,
getMcpServers: () => ({ 'test-server': {}, 'another-server': {} }),
getMcpServerCommand: () => undefined,
getPromptRegistry: () => ({}) as PromptRegistry,
getWorkspaceContext: () => ({}) as WorkspaceContext,
getDebugMode: () => false,
} as unknown as Config;
const manager = new McpClientManager(mockConfig, {} as ToolRegistry);
// First connect to create the clients
await manager.discoverAllMcpTools({
isTrustedFolder: () => true,
} as unknown as Config);

// Clear the disconnect calls from initial stop() in discoverAllMcpTools
disconnectCalls.length = 0;

// Then stop
await manager.stop();
expect(disconnectCalls).toHaveLength(2);
expect(disconnectCalls).toContain('test-server');
expect(disconnectCalls).toContain('another-server');
});

it('should be idempotent - stop can be called multiple times safely', async () => {
const mockedMcpClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn().mockResolvedValue(undefined),
getStatus: vi.fn(),
};
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);
const mockConfig = {
isTrustedFolder: () => true,
getMcpServers: () => ({ 'test-server': {} }),
getMcpServerCommand: () => undefined,
getPromptRegistry: () => ({}) as PromptRegistry,
getWorkspaceContext: () => ({}) as WorkspaceContext,
getDebugMode: () => false,
} as unknown as Config;
const manager = new McpClientManager(mockConfig, {} as ToolRegistry);
await manager.discoverAllMcpTools({
isTrustedFolder: () => true,
} as unknown as Config);

// Call stop multiple times - should not throw
await manager.stop();
await manager.stop();
await manager.stop();
});

it('should discover tools for a single server and track the client for stop', async () => {
const mockedMcpClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn().mockResolvedValue(undefined),
getStatus: vi.fn(),
};
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);

const mockConfig = {
isTrustedFolder: () => true,
getMcpServers: () => ({ 'test-server': {} }),
getMcpServerCommand: () => undefined,
getPromptRegistry: () => ({}) as PromptRegistry,
getWorkspaceContext: () => ({}) as WorkspaceContext,
getDebugMode: () => false,
} as unknown as Config;
const manager = new McpClientManager(mockConfig, {} as ToolRegistry);

await manager.discoverMcpToolsForServer(
'test-server',
{} as unknown as Config,
);

expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();

await manager.stop();
expect(mockedMcpClient.disconnect).toHaveBeenCalledOnce();
});

it('should replace an existing client when re-discovering a server', async () => {
const firstClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn().mockResolvedValue(undefined),
getStatus: vi.fn(),
};
const secondClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn().mockResolvedValue(undefined),
getStatus: vi.fn(),
};

vi.mocked(McpClient)
.mockReturnValueOnce(firstClient as unknown as McpClient)
.mockReturnValueOnce(secondClient as unknown as McpClient);

const mockConfig = {
isTrustedFolder: () => true,
getMcpServers: () => ({ 'test-server': {} }),
getMcpServerCommand: () => undefined,
getPromptRegistry: () => ({}) as PromptRegistry,
getWorkspaceContext: () => ({}) as WorkspaceContext,
getDebugMode: () => false,
} as unknown as Config;
const manager = new McpClientManager(mockConfig, {} as ToolRegistry);

await manager.discoverMcpToolsForServer(
'test-server',
{} as unknown as Config,
);
await manager.discoverMcpToolsForServer(
'test-server',
{} as unknown as Config,
);

expect(firstClient.disconnect).toHaveBeenCalledOnce();
expect(secondClient.connect).toHaveBeenCalledOnce();
expect(secondClient.discover).toHaveBeenCalledOnce();

await manager.stop();
expect(secondClient.disconnect).toHaveBeenCalledOnce();
});

it('should no-op when discovering an unknown server', async () => {
const mockedMcpClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn().mockResolvedValue(undefined),
getStatus: vi.fn(),
};
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);

const mockConfig = {
isTrustedFolder: () => true,
getMcpServers: () => ({}),
getMcpServerCommand: () => undefined,
getPromptRegistry: () => ({}) as PromptRegistry,
getWorkspaceContext: () => ({}) as WorkspaceContext,
getDebugMode: () => false,
} as unknown as Config;
const manager = new McpClientManager(mockConfig, {} as ToolRegistry);

await manager.discoverMcpToolsForServer('unknown-server', {
isTrustedFolder: () => true,
} as unknown as Config);

expect(vi.mocked(McpClient)).not.toHaveBeenCalled();
});
});
67 changes: 67 additions & 0 deletions packages/core/src/tools/mcp-client-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,73 @@ export class McpClientManager {
this.discoveryState = MCPDiscoveryState.COMPLETED;
}

/**
* Connects to a single MCP server and discovers its tools/prompts.
* The connected client is tracked so it can be closed by {@link stop}.
*
* This is primarily used for on-demand re-discovery flows (e.g. after OAuth).
*/
async discoverMcpToolsForServer(
serverName: string,
cliConfig: Config,
): Promise<void> {
const servers = populateMcpServerCommand(
this.cliConfig.getMcpServers() || {},
this.cliConfig.getMcpServerCommand(),
);
const serverConfig = servers[serverName];
if (!serverConfig) {
return;
}

// Ensure we don't leak an existing connection for this server.
const existingClient = this.clients.get(serverName);
if (existingClient) {
try {
await existingClient.disconnect();
} catch (error) {
console.error(
`Error stopping client '${serverName}': ${getErrorMessage(error)}`,
);
} finally {
this.clients.delete(serverName);
this.eventEmitter?.emit('mcp-client-update', this.clients);
}
}

// For SDK MCP servers, pass the sendSdkMcpMessage callback.
const sdkCallback = isSdkMcpServerConfig(serverConfig)
? this.sendSdkMcpMessage
: undefined;

const client = new McpClient(
serverName,
serverConfig,
this.toolRegistry,
this.cliConfig.getPromptRegistry(),
this.cliConfig.getWorkspaceContext(),
this.cliConfig.getDebugMode(),
sdkCallback,
);

this.clients.set(serverName, client);
this.eventEmitter?.emit('mcp-client-update', this.clients);

try {
await client.connect();
await client.discover(cliConfig);
} catch (error) {
// Log the error but don't throw: callers expect best-effort discovery.
console.error(
`Error during discovery for server '${serverName}': ${getErrorMessage(
error,
)}`,
);
} finally {
this.eventEmitter?.emit('mcp-client-update', this.clients);
}
}

/**
* Stops all running local MCP servers and closes all client connections.
* This is the cleanup method to be called on application exit.
Expand Down
31 changes: 17 additions & 14 deletions packages/core/src/tools/tool-registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import { Kind, BaseDeclarativeTool, BaseToolInvocation } from './tools.js';
import type { Config } from '../config/config.js';
import { spawn } from 'node:child_process';
import { StringDecoder } from 'node:string_decoder';
import { connectAndDiscover } from './mcp-client.js';
import type { SendSdkMcpMessage } from './mcp-client.js';
import { McpClientManager } from './mcp-client-manager.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
Expand Down Expand Up @@ -279,19 +278,10 @@ export class ToolRegistry {

this.config.getPromptRegistry().removePromptsByServer(serverName);

const mcpServers = this.config.getMcpServers() ?? {};
const serverConfig = mcpServers[serverName];
if (serverConfig) {
await connectAndDiscover(
serverName,
serverConfig,
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
this.config.getWorkspaceContext(),
this.config,
);
}
await this.mcpClientManager.discoverMcpToolsForServer(
serverName,
this.config,
);
}

private async discoverAndRegisterToolsFromCommand(): Promise<void> {
Expand Down Expand Up @@ -479,4 +469,17 @@ export class ToolRegistry {
getTool(name: string): AnyDeclarativeTool | undefined {
return this.tools.get(name);
}

/**
* Stops all MCP clients and cleans up resources.
* This method is idempotent and safe to call multiple times.
*/
async stop(): Promise<void> {
try {
await this.mcpClientManager.stop();
} catch (error) {
// Log but don't throw - cleanup should be best-effort
console.error('Error stopping MCP clients:', error);
}
}
}