diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 16fea6311..a17e85139 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -339,6 +339,9 @@ export async function main() { process.cwd(), argv.extensions, ); + + // Register cleanup for MCP clients as early as possible + // This ensures MCP server subprocesses are properly terminated on exit registerCleanup(() => config.shutdown()); // FIXME: list extensions after the config initialize diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index af2d28555..79b9921dd 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -814,13 +814,6 @@ export class Config { return this.sessionId; } - /** - * Releases resources owned by the config instance. - */ - async shutdown(): Promise { - this.skillManager?.stopWatching(); - } - /** * Starts a new session and resets session-scoped services. */ @@ -1027,6 +1020,28 @@ export class Config { return this.toolRegistry; } + /** + * Shuts down the Config and releases all resources. + * This method is idempotent and safe to call multiple times. + * It handles the case where initialization was not completed. + */ + async shutdown(): Promise { + if (!this.initialized) { + // Nothing to clean up if not initialized + return; + } + try { + this.skillManager?.stopWatching(); + + if (this.toolRegistry) { + await this.toolRegistry.stop(); + } + } catch (error) { + // Log but don't throw - cleanup should be best-effort + console.error('Error during Config shutdown:', error); + } + } + getPromptRegistry(): PromptRegistry { return this.promptRegistry; } diff --git a/packages/core/src/tools/mcp-client-manager.test.ts b/packages/core/src/tools/mcp-client-manager.test.ts index ff2cb60fc..051c9d87a 100644 --- a/packages/core/src/tools/mcp-client-manager.test.ts +++ b/packages/core/src/tools/mcp-client-manager.test.ts @@ -9,15 +9,16 @@ import { McpClientManager } from './mcp-client-manager.js'; import { McpClient } from './mcp-client.js'; import type { ToolRegistry } from './tool-registry.js'; import type { Config } from '../config/config.js'; +import type { PromptRegistry } from '../prompts/prompt-registry.js'; +import type { WorkspaceContext } from '../utils/workspaceContext.js'; vi.mock('./mcp-client.js', async () => { const originalModule = await vi.importActual('./mcp-client.js'); return { ...originalModule, McpClient: vi.fn(), - populateMcpServerCommand: vi.fn(() => ({ - 'test-server': {}, - })), + // Return the input servers unchanged (identity function) + populateMcpServerCommand: vi.fn((servers) => servers), }; }); @@ -73,4 +74,178 @@ describe('McpClientManager', () => { expect(mockedMcpClient.connect).not.toHaveBeenCalled(); expect(mockedMcpClient.discover).not.toHaveBeenCalled(); }); + + it('should disconnect all clients when stop is called', async () => { + // Track disconnect calls across all instances + const disconnectCalls: string[] = []; + vi.mocked(McpClient).mockImplementation( + (name: string) => + ({ + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn().mockImplementation(() => { + disconnectCalls.push(name); + return Promise.resolve(); + }), + getStatus: vi.fn(), + }) as unknown as McpClient, + ); + const mockConfig = { + isTrustedFolder: () => true, + getMcpServers: () => ({ 'test-server': {}, 'another-server': {} }), + getMcpServerCommand: () => undefined, + getPromptRegistry: () => ({}) as PromptRegistry, + getWorkspaceContext: () => ({}) as WorkspaceContext, + getDebugMode: () => false, + } as unknown as Config; + const manager = new McpClientManager(mockConfig, {} as ToolRegistry); + // First connect to create the clients + await manager.discoverAllMcpTools({ + isTrustedFolder: () => true, + } as unknown as Config); + + // Clear the disconnect calls from initial stop() in discoverAllMcpTools + disconnectCalls.length = 0; + + // Then stop + await manager.stop(); + expect(disconnectCalls).toHaveLength(2); + expect(disconnectCalls).toContain('test-server'); + expect(disconnectCalls).toContain('another-server'); + }); + + it('should be idempotent - stop can be called multiple times safely', async () => { + const mockedMcpClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn().mockResolvedValue(undefined), + getStatus: vi.fn(), + }; + vi.mocked(McpClient).mockReturnValue( + mockedMcpClient as unknown as McpClient, + ); + const mockConfig = { + isTrustedFolder: () => true, + getMcpServers: () => ({ 'test-server': {} }), + getMcpServerCommand: () => undefined, + getPromptRegistry: () => ({}) as PromptRegistry, + getWorkspaceContext: () => ({}) as WorkspaceContext, + getDebugMode: () => false, + } as unknown as Config; + const manager = new McpClientManager(mockConfig, {} as ToolRegistry); + await manager.discoverAllMcpTools({ + isTrustedFolder: () => true, + } as unknown as Config); + + // Call stop multiple times - should not throw + await manager.stop(); + await manager.stop(); + await manager.stop(); + }); + + it('should discover tools for a single server and track the client for stop', async () => { + const mockedMcpClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn().mockResolvedValue(undefined), + getStatus: vi.fn(), + }; + vi.mocked(McpClient).mockReturnValue( + mockedMcpClient as unknown as McpClient, + ); + + const mockConfig = { + isTrustedFolder: () => true, + getMcpServers: () => ({ 'test-server': {} }), + getMcpServerCommand: () => undefined, + getPromptRegistry: () => ({}) as PromptRegistry, + getWorkspaceContext: () => ({}) as WorkspaceContext, + getDebugMode: () => false, + } as unknown as Config; + const manager = new McpClientManager(mockConfig, {} as ToolRegistry); + + await manager.discoverMcpToolsForServer( + 'test-server', + {} as unknown as Config, + ); + + expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); + expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + + await manager.stop(); + expect(mockedMcpClient.disconnect).toHaveBeenCalledOnce(); + }); + + it('should replace an existing client when re-discovering a server', async () => { + const firstClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn().mockResolvedValue(undefined), + getStatus: vi.fn(), + }; + const secondClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn().mockResolvedValue(undefined), + getStatus: vi.fn(), + }; + + vi.mocked(McpClient) + .mockReturnValueOnce(firstClient as unknown as McpClient) + .mockReturnValueOnce(secondClient as unknown as McpClient); + + const mockConfig = { + isTrustedFolder: () => true, + getMcpServers: () => ({ 'test-server': {} }), + getMcpServerCommand: () => undefined, + getPromptRegistry: () => ({}) as PromptRegistry, + getWorkspaceContext: () => ({}) as WorkspaceContext, + getDebugMode: () => false, + } as unknown as Config; + const manager = new McpClientManager(mockConfig, {} as ToolRegistry); + + await manager.discoverMcpToolsForServer( + 'test-server', + {} as unknown as Config, + ); + await manager.discoverMcpToolsForServer( + 'test-server', + {} as unknown as Config, + ); + + expect(firstClient.disconnect).toHaveBeenCalledOnce(); + expect(secondClient.connect).toHaveBeenCalledOnce(); + expect(secondClient.discover).toHaveBeenCalledOnce(); + + await manager.stop(); + expect(secondClient.disconnect).toHaveBeenCalledOnce(); + }); + + it('should no-op when discovering an unknown server', async () => { + const mockedMcpClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn().mockResolvedValue(undefined), + getStatus: vi.fn(), + }; + vi.mocked(McpClient).mockReturnValue( + mockedMcpClient as unknown as McpClient, + ); + + const mockConfig = { + isTrustedFolder: () => true, + getMcpServers: () => ({}), + getMcpServerCommand: () => undefined, + getPromptRegistry: () => ({}) as PromptRegistry, + getWorkspaceContext: () => ({}) as WorkspaceContext, + getDebugMode: () => false, + } as unknown as Config; + const manager = new McpClientManager(mockConfig, {} as ToolRegistry); + + await manager.discoverMcpToolsForServer('unknown-server', { + isTrustedFolder: () => true, + } as unknown as Config); + + expect(vi.mocked(McpClient)).not.toHaveBeenCalled(); + }); }); diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index 354776c8d..d72c76ca5 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -100,6 +100,73 @@ export class McpClientManager { this.discoveryState = MCPDiscoveryState.COMPLETED; } + /** + * Connects to a single MCP server and discovers its tools/prompts. + * The connected client is tracked so it can be closed by {@link stop}. + * + * This is primarily used for on-demand re-discovery flows (e.g. after OAuth). + */ + async discoverMcpToolsForServer( + serverName: string, + cliConfig: Config, + ): Promise { + const servers = populateMcpServerCommand( + this.cliConfig.getMcpServers() || {}, + this.cliConfig.getMcpServerCommand(), + ); + const serverConfig = servers[serverName]; + if (!serverConfig) { + return; + } + + // Ensure we don't leak an existing connection for this server. + const existingClient = this.clients.get(serverName); + if (existingClient) { + try { + await existingClient.disconnect(); + } catch (error) { + console.error( + `Error stopping client '${serverName}': ${getErrorMessage(error)}`, + ); + } finally { + this.clients.delete(serverName); + this.eventEmitter?.emit('mcp-client-update', this.clients); + } + } + + // For SDK MCP servers, pass the sendSdkMcpMessage callback. + const sdkCallback = isSdkMcpServerConfig(serverConfig) + ? this.sendSdkMcpMessage + : undefined; + + const client = new McpClient( + serverName, + serverConfig, + this.toolRegistry, + this.cliConfig.getPromptRegistry(), + this.cliConfig.getWorkspaceContext(), + this.cliConfig.getDebugMode(), + sdkCallback, + ); + + this.clients.set(serverName, client); + this.eventEmitter?.emit('mcp-client-update', this.clients); + + try { + await client.connect(); + await client.discover(cliConfig); + } catch (error) { + // Log the error but don't throw: callers expect best-effort discovery. + console.error( + `Error during discovery for server '${serverName}': ${getErrorMessage( + error, + )}`, + ); + } finally { + this.eventEmitter?.emit('mcp-client-update', this.clients); + } + } + /** * Stops all running local MCP servers and closes all client connections. * This is the cleanup method to be called on application exit. diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index 4db7bd789..540851f50 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -15,7 +15,6 @@ import { Kind, BaseDeclarativeTool, BaseToolInvocation } from './tools.js'; import type { Config } from '../config/config.js'; import { spawn } from 'node:child_process'; import { StringDecoder } from 'node:string_decoder'; -import { connectAndDiscover } from './mcp-client.js'; import type { SendSdkMcpMessage } from './mcp-client.js'; import { McpClientManager } from './mcp-client-manager.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; @@ -279,19 +278,10 @@ export class ToolRegistry { this.config.getPromptRegistry().removePromptsByServer(serverName); - const mcpServers = this.config.getMcpServers() ?? {}; - const serverConfig = mcpServers[serverName]; - if (serverConfig) { - await connectAndDiscover( - serverName, - serverConfig, - this, - this.config.getPromptRegistry(), - this.config.getDebugMode(), - this.config.getWorkspaceContext(), - this.config, - ); - } + await this.mcpClientManager.discoverMcpToolsForServer( + serverName, + this.config, + ); } private async discoverAndRegisterToolsFromCommand(): Promise { @@ -479,4 +469,17 @@ export class ToolRegistry { getTool(name: string): AnyDeclarativeTool | undefined { return this.tools.get(name); } + + /** + * Stops all MCP clients and cleans up resources. + * This method is idempotent and safe to call multiple times. + */ + async stop(): Promise { + try { + await this.mcpClientManager.stop(); + } catch (error) { + // Log but don't throw - cleanup should be best-effort + console.error('Error stopping MCP clients:', error); + } + } }