Skip to content

Commit 0cc3d6d

Browse files
mikeedjonesgemini-code-assist[bot]wukath
authored
Feat/expose mcps streamable http custom httpx factory parameter (google#2997)
* feat: Add support for custom HTTPX client factory in StreamableHTTPConnectionParams * Update src/google/adk/tools/mcp_tool/mcp_session_manager.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * unit tested mock * provide default - httpx client factory can't be none * feat: Enhance StreamableHTTPConnectionParams with httpx_client_factory attribute * fmt * fmt * refactor: Rename test_init_with_streamable_http_none_httpx_factory to test_init_with_streamable_http_default_httpx_factory for clarity * isort * fmt --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Kathy Wu <[email protected]>
1 parent b5f5df9 commit 0cc3d6d

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,22 @@
2525
from typing import Any
2626
from typing import Dict
2727
from typing import Optional
28+
from typing import Protocol
29+
from typing import runtime_checkable
2830
from typing import TextIO
2931
from typing import Union
3032

3133
import anyio
3234
from pydantic import BaseModel
35+
from pydantic import ConfigDict
3336

3437
try:
3538
from mcp import ClientSession
3639
from mcp import StdioServerParameters
3740
from mcp.client.sse import sse_client
3841
from mcp.client.stdio import stdio_client
42+
from mcp.client.streamable_http import create_mcp_http_client
43+
from mcp.client.streamable_http import McpHttpClientFactory
3944
from mcp.client.streamable_http import streamablehttp_client
4045
except ImportError as e:
4146

@@ -84,6 +89,11 @@ class SseConnectionParams(BaseModel):
8489
sse_read_timeout: float = 60 * 5.0
8590

8691

92+
@runtime_checkable
93+
class CheckableMcpHttpClientFactory(McpHttpClientFactory, Protocol):
94+
pass
95+
96+
8797
class StreamableHTTPConnectionParams(BaseModel):
8898
"""Parameters for the MCP Streamable HTTP connection.
8999
@@ -99,13 +109,18 @@ class StreamableHTTPConnectionParams(BaseModel):
99109
Streamable HTTP server.
100110
terminate_on_close: Whether to terminate the MCP Streamable HTTP server
101111
when the connection is closed.
112+
httpx_client_factory: Factory function to create a custom HTTPX client. If
113+
not provided, a default factory will be used.
102114
"""
103115

116+
model_config = ConfigDict(arbitrary_types_allowed=True)
117+
104118
url: str
105119
headers: dict[str, Any] | None = None
106120
timeout: float = 5.0
107121
sse_read_timeout: float = 60 * 5.0
108122
terminate_on_close: bool = True
123+
httpx_client_factory: CheckableMcpHttpClientFactory = create_mcp_http_client
109124

110125

111126
def retry_on_closed_resource(func):
@@ -286,6 +301,7 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
286301
seconds=self._connection_params.sse_read_timeout
287302
),
288303
terminate_on_close=self._connection_params.terminate_on_close,
304+
httpx_client_factory=self._connection_params.httpx_client_factory,
289305
)
290306
else:
291307
raise ValueError(

tests/unittests/tools/mcp_tool/test_mcp_session_manager.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,59 @@ def test_init_with_streamable_http_params(self):
146146

147147
assert manager._connection_params == http_params
148148

149+
@patch("google.adk.tools.mcp_tool.mcp_session_manager.streamablehttp_client")
150+
def test_init_with_streamable_http_custom_httpx_factory(
151+
self, mock_streamablehttp_client
152+
):
153+
"""Test that streamablehttp_client is called with custom httpx_client_factory."""
154+
from datetime import timedelta
155+
156+
custom_httpx_factory = Mock()
157+
158+
http_params = StreamableHTTPConnectionParams(
159+
url="https://example.com/mcp",
160+
timeout=15.0,
161+
httpx_client_factory=custom_httpx_factory,
162+
)
163+
manager = MCPSessionManager(http_params)
164+
165+
manager._create_client()
166+
167+
mock_streamablehttp_client.assert_called_once_with(
168+
url="https://example.com/mcp",
169+
headers=None,
170+
timeout=timedelta(seconds=15.0),
171+
sse_read_timeout=timedelta(seconds=300.0),
172+
terminate_on_close=True,
173+
httpx_client_factory=custom_httpx_factory,
174+
)
175+
176+
@pytest.mark.asyncio
177+
@patch("google.adk.tools.mcp_tool.mcp_session_manager.streamablehttp_client")
178+
async def test_init_with_streamable_http_default_httpx_factory(
179+
self, mock_streamablehttp_client
180+
):
181+
"""Test that streamablehttp_client is called with custom httpx_client_factory."""
182+
from datetime import timedelta
183+
184+
from mcp.client.streamable_http import create_mcp_http_client
185+
186+
http_params = StreamableHTTPConnectionParams(
187+
url="https://example.com/mcp", timeout=15.0
188+
)
189+
manager = MCPSessionManager(http_params)
190+
191+
manager._create_client()
192+
193+
mock_streamablehttp_client.assert_called_once_with(
194+
url="https://example.com/mcp",
195+
headers=None,
196+
timeout=timedelta(seconds=15.0),
197+
sse_read_timeout=timedelta(seconds=300.0),
198+
terminate_on_close=True,
199+
httpx_client_factory=create_mcp_http_client,
200+
)
201+
149202
def test_generate_session_key_stdio(self):
150203
"""Test session key generation for stdio connections."""
151204
manager = MCPSessionManager(self.mock_stdio_connection_params)

0 commit comments

Comments
 (0)