|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Test script for MCP Toolset OAuth Authentication Flow. |
| 16 | +
|
| 17 | +This script demonstrates the two-phase tool discovery flow: |
| 18 | +1. First invocation: Agent tries to get tools, auth is required, returns auth |
| 19 | + request event (adk_request_credential) |
| 20 | +2. User provides OAuth credentials (simulated) |
| 21 | +3. Second invocation: Agent has credentials, can list and call tools |
| 22 | +
|
| 23 | +Usage: |
| 24 | + # Start the MCP server first (in another terminal): |
| 25 | + PYTHONPATH=src python contributing/samples/mcp_toolset_auth/oauth_mcp_server.py |
| 26 | +
|
| 27 | + # Run the demo: |
| 28 | + PYTHONPATH=src python contributing/samples/mcp_toolset_auth/main.py |
| 29 | +""" |
| 30 | + |
| 31 | +from __future__ import annotations |
| 32 | + |
| 33 | +import asyncio |
| 34 | + |
| 35 | +from agent import auth_credential |
| 36 | +from agent import auth_scheme |
| 37 | +from agent import mcp_toolset |
| 38 | +from agent import root_agent |
| 39 | +from google.adk.auth.auth_credential import AuthCredential |
| 40 | +from google.adk.auth.auth_credential import AuthCredentialTypes |
| 41 | +from google.adk.auth.auth_credential import OAuth2Auth |
| 42 | +from google.adk.auth.auth_tool import AuthConfig |
| 43 | +from google.adk.runners import Runner |
| 44 | +from google.adk.sessions.in_memory_session_service import InMemorySessionService |
| 45 | +from google.genai import types |
| 46 | + |
| 47 | + |
| 48 | +async def run_demo(): |
| 49 | + """Run demo with real MCP server.""" |
| 50 | + print('=' * 60) |
| 51 | + print('MCP Toolset OAuth Authentication Demo') |
| 52 | + print('=' * 60) |
| 53 | + print('\nNote: Make sure the MCP server is running:') |
| 54 | + print(' python oauth_mcp_server.py\n') |
| 55 | + |
| 56 | + # Create session service and runner |
| 57 | + session_service = InMemorySessionService() |
| 58 | + runner = Runner( |
| 59 | + agent=root_agent, |
| 60 | + app_name='toolset_auth_demo', |
| 61 | + session_service=session_service, |
| 62 | + ) |
| 63 | + |
| 64 | + # Create a session |
| 65 | + session = await session_service.create_session( |
| 66 | + app_name='toolset_auth_demo', |
| 67 | + user_id='test_user', |
| 68 | + ) |
| 69 | + |
| 70 | + print(f'Session created: {session.id}') |
| 71 | + print('\n--- Phase 1: Initial request (no credentials) ---\n') |
| 72 | + |
| 73 | + # First invocation - should trigger auth request |
| 74 | + user_message = 'List all users' |
| 75 | + print(f'User: {user_message}') |
| 76 | + |
| 77 | + events = [] |
| 78 | + auth_function_call_id = None |
| 79 | + max_events = 10 |
| 80 | + |
| 81 | + try: |
| 82 | + async for event in runner.run_async( |
| 83 | + session_id=session.id, |
| 84 | + user_id='test_user', |
| 85 | + new_message=types.Content( |
| 86 | + role='user', |
| 87 | + parts=[types.Part(text=user_message)], |
| 88 | + ), |
| 89 | + ): |
| 90 | + events.append(event) |
| 91 | + print(f'\nEvent from {event.author}:') |
| 92 | + if event.content and event.content.parts: |
| 93 | + for part in event.content.parts: |
| 94 | + if part.text: |
| 95 | + print(f' Text: {part.text}') |
| 96 | + if part.function_call: |
| 97 | + print(f' Function call: {part.function_call.name}') |
| 98 | + if part.function_call.name == 'adk_request_credential': |
| 99 | + auth_function_call_id = part.function_call.id |
| 100 | + |
| 101 | + if len(events) >= max_events: |
| 102 | + print(f'\n** SAFETY LIMIT ({max_events} events) **') |
| 103 | + break |
| 104 | + |
| 105 | + except Exception as e: |
| 106 | + print(f'\nError: {e}') |
| 107 | + print('Make sure the MCP server is running!') |
| 108 | + await mcp_toolset.close() |
| 109 | + return |
| 110 | + |
| 111 | + if auth_function_call_id: |
| 112 | + print('\n** Auth request detected! **') |
| 113 | + print('\n--- Phase 2: Provide OAuth credentials ---\n') |
| 114 | + |
| 115 | + # Simulate user providing OAuth credentials after completing OAuth flow |
| 116 | + auth_response = AuthConfig( |
| 117 | + auth_scheme=auth_scheme, |
| 118 | + raw_auth_credential=auth_credential, |
| 119 | + exchanged_auth_credential=AuthCredential( |
| 120 | + auth_type=AuthCredentialTypes.OAUTH2, |
| 121 | + oauth2=OAuth2Auth( |
| 122 | + access_token='test_access_token_12345', |
| 123 | + ), |
| 124 | + ), |
| 125 | + ) |
| 126 | + |
| 127 | + print('Providing access token: test_access_token_12345') |
| 128 | + |
| 129 | + auth_response_message = types.Content( |
| 130 | + role='user', |
| 131 | + parts=[ |
| 132 | + types.Part( |
| 133 | + function_response=types.FunctionResponse( |
| 134 | + name='adk_request_credential', |
| 135 | + id=auth_function_call_id, |
| 136 | + response=auth_response.model_dump(exclude_none=True), |
| 137 | + ) |
| 138 | + ) |
| 139 | + ], |
| 140 | + ) |
| 141 | + |
| 142 | + async for event in runner.run_async( |
| 143 | + session_id=session.id, |
| 144 | + user_id='test_user', |
| 145 | + new_message=auth_response_message, |
| 146 | + ): |
| 147 | + print(f'\nEvent from {event.author}:') |
| 148 | + if event.content and event.content.parts: |
| 149 | + for part in event.content.parts: |
| 150 | + if part.text: |
| 151 | + text = ( |
| 152 | + part.text[:200] + '...' if len(part.text) > 200 else part.text |
| 153 | + ) |
| 154 | + print(f' Text: {text}') |
| 155 | + if part.function_call: |
| 156 | + print(f' Function call: {part.function_call.name}') |
| 157 | + else: |
| 158 | + print('\n** No auth request - credentials may already be available **') |
| 159 | + |
| 160 | + print('\n' + '=' * 60) |
| 161 | + print('Demo completed') |
| 162 | + print('=' * 60) |
| 163 | + |
| 164 | + await mcp_toolset.close() |
| 165 | + |
| 166 | + |
| 167 | +if __name__ == '__main__': |
| 168 | + asyncio.run(run_demo()) |
0 commit comments