Skip to content

Commit ede925b

Browse files
seanzhougooglecopybara-github
authored andcommitted
chore: Lazy register all streaming tools
Co-authored-by: Xiang (Sean) Zhou <[email protected]> PiperOrigin-RevId: 869480442
1 parent 5269a6b commit ede925b

File tree

4 files changed

+74
-173
lines changed

4 files changed

+74
-173
lines changed

src/google/adk/flows/llm_flows/functions.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -825,11 +825,21 @@ async def run_tool_and_update_queue(tool, function_args, tool_context):
825825
run_tool_and_update_queue(tool, function_args, tool_context)
826826
)
827827

828-
# The tool is already registered in active_streaming_tools by
829-
# runners.py at startup (all async-generator tools are registered
830-
# there). Just attach the background task.
831828
async with streaming_lock:
832-
invocation_context.active_streaming_tools[tool.name].task = task
829+
830+
if invocation_context.active_streaming_tools is None:
831+
invocation_context.active_streaming_tools = {}
832+
if tool.name in invocation_context.active_streaming_tools:
833+
invocation_context.active_streaming_tools[tool.name].task = task
834+
else:
835+
# Register the streaming tool lazily when the model calls it.
836+
# For input-streaming tools (those with `input_stream:
837+
# LiveRequestQueue`), _call_live will set .stream to a new
838+
# LiveRequestQueue so _send_to_model starts duplicating data.
839+
invocation_context.active_streaming_tools[tool.name] = (
840+
ActiveStreamingTool(task=task)
841+
)
842+
logger.debug('Lazily registered streaming tool: %s', tool.name)
833843

834844
# Immediately return a pending response.
835845
# This is required by current live model.

src/google/adk/runners.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from google.adk.artifacts import artifact_util
3333
from google.genai import types
3434

35-
from .agents.active_streaming_tool import ActiveStreamingTool
3635
from .agents.base_agent import BaseAgent
3736
from .agents.base_agent import BaseAgentState
3837
from .agents.context_cache_config import ContextCacheConfig
@@ -1012,42 +1011,6 @@ async def run_live(
10121011
root_agent = self.agent
10131012
invocation_context.agent = self._find_agent_to_run(session, root_agent)
10141013

1015-
# Pre-processing for live streaming tools
1016-
# Inspect the tool's parameters to find if it uses LiveRequestQueue
1017-
invocation_context.active_streaming_tools = {}
1018-
# For shell agents, there is no canonical_tools method so we should skip.
1019-
if hasattr(invocation_context.agent, 'canonical_tools'):
1020-
import inspect
1021-
1022-
# Use canonical_tools to get properly wrapped BaseTool instances
1023-
canonical_tools = await invocation_context.agent.canonical_tools(
1024-
invocation_context
1025-
)
1026-
# Register all async-generator tools as streaming tools.
1027-
# A streaming tool is any tool whose underlying function is an
1028-
# async generator (i.e. uses `yield`). There are two sub-types:
1029-
# 1. Input-streaming tools: accept a `input_stream:
1030-
# LiveRequestQueue` parameter to consume the live audio/video
1031-
# stream. The stream is created lazily in `_call_live` when
1032-
# the model actually calls the tool.
1033-
# 2. Output-streaming tools: async generators that yield results
1034-
# over time but don't consume the live stream. They are run
1035-
# as background tasks when called.
1036-
# Both types are registered here with `stream=None`. The
1037-
# distinction between them is made at call time.
1038-
for tool in canonical_tools:
1039-
callable_to_inspect = tool.func if hasattr(tool, 'func') else tool
1040-
if not callable(callable_to_inspect):
1041-
continue
1042-
if inspect.isasyncgenfunction(callable_to_inspect):
1043-
if not invocation_context.active_streaming_tools:
1044-
invocation_context.active_streaming_tools = {}
1045-
logger.debug('Register streaming tool: %s', tool.name)
1046-
active_streaming_tool = ActiveStreamingTool()
1047-
invocation_context.active_streaming_tools[tool.name] = (
1048-
active_streaming_tool
1049-
)
1050-
10511014
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
10521015
async with Aclosing(ctx.agent.run_live(ctx)) as agen:
10531016
async for event in agen:

tests/unittests/streaming/test_streaming.py

Lines changed: 60 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,10 +1322,10 @@ async def consume(session: testing_utils.Session):
13221322
return collected
13231323

13241324

1325-
def test_input_streaming_tool_stream_is_none_before_model_calls():
1326-
"""Test that input-streaming tools have stream=None until the model calls them."""
1327-
# Add a text response before the function call so we can observe stream
1328-
# state between registration and tool invocation.
1325+
def test_input_streaming_tool_registered_lazily_with_stream():
1326+
"""Test that input-streaming tools are registered lazily when called and receive a stream."""
1327+
# A text response before the function call lets us observe that the
1328+
# tool is NOT registered before the model calls it.
13291329
text_response = LlmResponse(
13301330
content=types.Content(
13311331
role='model',
@@ -1362,7 +1362,7 @@ async def monitor_video_stream(input_stream: LiveRequestQueue):
13621362

13631363
runner = _LiveTestRunner(root_agent=root_agent)
13641364

1365-
# Capture the invocation context to inspect stream state.
1365+
# Capture the invocation context to inspect registration state.
13661366
captured_context = None
13671367
original_method = runner.runner._new_invocation_context_for_live
13681368

@@ -1379,39 +1379,39 @@ def capturing_method(*args, **kwargs):
13791379
blob=types.Blob(data=b'test_data', mime_type='audio/pcm')
13801380
)
13811381

1382-
# Collect events and capture stream state before the tool is called.
1382+
# Collect events and check that the tool is NOT registered before
1383+
# the model calls it.
13831384
collected = []
1384-
stream_states_before_call = []
1385+
not_registered_before_call = None
13851386

13861387
async def consume(session: testing_utils.Session):
1388+
nonlocal not_registered_before_call
13871389
async for response in runner.runner.run_live(
13881390
session=session,
13891391
live_request_queue=live_request_queue,
13901392
):
13911393
collected.append(response)
1392-
# On a non-function-call event, the tool is registered but not
1393-
# yet invoked — capture the stream value at that point.
1394+
# On the first non-function-call event, verify the tool is not
1395+
# yet registered (lazy registration).
13941396
active = (
1395-
captured_context.active_streaming_tools if captured_context else {}
1397+
captured_context.active_streaming_tools if captured_context else None
13961398
)
13971399
if (
1398-
not stream_states_before_call
1400+
not_registered_before_call is None
13991401
and not response.get_function_calls()
1400-
and 'monitor_video_stream' in active
14011402
):
1402-
stream_states_before_call.append(active['monitor_video_stream'].stream)
1403+
not_registered_before_call = (
1404+
active is None or 'monitor_video_stream' not in active
1405+
)
14031406
if len(collected) >= 4:
14041407
return
14051408

14061409
runner._run_with_loop(asyncio.wait_for(consume(runner.session), timeout=5.0))
14071410

1408-
# Before the model calls the tool, stream should be None.
1409-
assert (
1410-
stream_states_before_call
1411-
), 'Stream state was never observed before the tool call'
1411+
# Tool should not be registered before the model calls it.
14121412
assert (
1413-
stream_states_before_call[0] is None
1414-
), 'Expected stream to be None before the model calls the tool'
1413+
not_registered_before_call is True
1414+
), 'Expected tool to NOT be registered before the model calls it'
14151415
# When the model calls the tool, input_stream should be provided.
14161416
assert (
14171417
stream_state_during_call is True
@@ -1458,17 +1458,20 @@ def stop_streaming(function_name: str):
14581458

14591459
runner = _LiveTestRunner(root_agent=root_agent)
14601460

1461-
# Capture invocation context to verify stream is reset.
1462-
captured_context = None
1463-
original_method = runner.runner._new_invocation_context_for_live
1464-
1465-
def capturing_method(*args, **kwargs):
1466-
nonlocal captured_context
1467-
ctx = original_method(*args, **kwargs)
1468-
captured_context = ctx
1461+
# Capture the child invocation context (created by _create_invocation_context
1462+
# inside base_agent.run_live) to inspect active_streaming_tools.
1463+
# We cannot use the parent context from _new_invocation_context_for_live
1464+
# because model_copy creates a separate child object.
1465+
captured_child_context = None
1466+
original_create = root_agent._create_invocation_context
1467+
1468+
def capturing_create(*args, **kwargs):
1469+
nonlocal captured_child_context
1470+
ctx = original_create(*args, **kwargs)
1471+
captured_child_context = ctx
14691472
return ctx
14701473

1471-
runner.runner._new_invocation_context_for_live = capturing_method
1474+
root_agent._create_invocation_context = capturing_create
14721475

14731476
live_request_queue = LiveRequestQueue()
14741477
live_request_queue.send_realtime(
@@ -1488,9 +1491,9 @@ def capturing_method(*args, **kwargs):
14881491

14891492
# Verify that stop_streaming reset the stream to None.
14901493
assert (
1491-
captured_context is not None
1492-
), 'Expected invocation context to be captured'
1493-
active_tools = captured_context.active_streaming_tools or {}
1494+
captured_child_context is not None
1495+
), 'Expected child invocation context to be captured'
1496+
active_tools = captured_child_context.active_streaming_tools or {}
14941497
assert (
14951498
'monitor_stock_price' in active_tools
14961499
), 'Expected monitor_stock_price in active_streaming_tools'
@@ -1499,11 +1502,18 @@ def capturing_method(*args, **kwargs):
14991502
), 'Expected stream to be reset to None after stop_streaming'
15001503

15011504

1502-
def test_output_streaming_tool_registered_at_startup():
1503-
"""Test that output-streaming tools (async generators without LiveRequestQueue) are registered at startup."""
1504-
response1 = LlmResponse(turn_complete=True)
1505+
def test_output_streaming_tool_registered_lazily_without_stream():
1506+
"""Test that output-streaming tools are registered lazily when called, with stream=None."""
1507+
function_call = types.Part.from_function_call(
1508+
name='monitor_stock_price', args={'stock_symbol': 'GOOG'}
1509+
)
1510+
response1 = LlmResponse(
1511+
content=types.Content(role='model', parts=[function_call]),
1512+
turn_complete=False,
1513+
)
1514+
response2 = LlmResponse(turn_complete=True)
15051515

1506-
mock_model = testing_utils.MockModel.create([response1])
1516+
mock_model = testing_utils.MockModel.create([response1, response2])
15071517

15081518
async def monitor_stock_price(stock_symbol: str):
15091519
"""Yield periodic price updates."""
@@ -1517,31 +1527,33 @@ async def monitor_stock_price(stock_symbol: str):
15171527

15181528
runner = _LiveTestRunner(root_agent=root_agent)
15191529

1520-
# Capture invocation context to verify registration.
1521-
captured_context = None
1522-
original_method = runner.runner._new_invocation_context_for_live
1530+
# Capture the child invocation context (created by _create_invocation_context
1531+
# inside base_agent.run_live) to inspect active_streaming_tools.
1532+
captured_child_context = None
1533+
original_create = root_agent._create_invocation_context
15231534

1524-
def capturing_method(*args, **kwargs):
1525-
nonlocal captured_context
1526-
ctx = original_method(*args, **kwargs)
1527-
captured_context = ctx
1535+
def capturing_create(*args, **kwargs):
1536+
nonlocal captured_child_context
1537+
ctx = original_create(*args, **kwargs)
1538+
captured_child_context = ctx
15281539
return ctx
15291540

1530-
runner.runner._new_invocation_context_for_live = capturing_method
1541+
root_agent._create_invocation_context = capturing_create
15311542

15321543
live_request_queue = LiveRequestQueue()
15331544
live_request_queue.send_realtime(
15341545
blob=types.Blob(data=b'test', mime_type='audio/pcm')
15351546
)
15361547

1537-
runner.run_live(live_request_queue, max_responses=1)
1548+
runner.run_live(live_request_queue, max_responses=3)
15381549

1539-
# Output-streaming tool should be registered with stream=None.
1540-
assert captured_context is not None
1541-
active_tools = captured_context.active_streaming_tools or {}
1550+
# After the model calls the tool, it should be registered with
1551+
# stream=None (output-streaming tools don't consume the live stream).
1552+
assert captured_child_context is not None
1553+
active_tools = captured_child_context.active_streaming_tools or {}
15421554
assert (
15431555
'monitor_stock_price' in active_tools
1544-
), 'Expected output-streaming tool to be registered at startup'
1556+
), 'Expected output-streaming tool to be registered when called'
15451557
assert (
15461558
active_tools['monitor_stock_price'].stream is None
15471559
), 'Expected stream to be None for output-streaming tool'

tests/unittests/test_runners.py

Lines changed: 0 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from google.adk.agents.base_agent import BaseAgent
2424
from google.adk.agents.context_cache_config import ContextCacheConfig
2525
from google.adk.agents.invocation_context import InvocationContext
26-
from google.adk.agents.live_request_queue import LiveRequestQueue
2726
from google.adk.agents.llm_agent import LlmAgent
2827
from google.adk.agents.run_config import RunConfig
2928
from google.adk.apps.app import App
@@ -35,7 +34,6 @@
3534
from google.adk.runners import Runner
3635
from google.adk.sessions.in_memory_session_service import InMemorySessionService
3736
from google.adk.sessions.session import Session
38-
from google.adk.tools.function_tool import FunctionTool
3937
from google.genai import types
4038
import pytest
4139

@@ -360,88 +358,6 @@ async def test_run_live_auto_create_session():
360358
assert session is not None
361359

362360

363-
@pytest.mark.asyncio
364-
async def test_run_live_detects_streaming_tools_with_canonical_tools():
365-
"""run_live should detect streaming tools using canonical_tools and tool.name."""
366-
367-
# Define streaming tools - one as raw function, one wrapped in FunctionTool
368-
async def raw_streaming_tool(
369-
input_stream: LiveRequestQueue,
370-
) -> AsyncGenerator[str, None]:
371-
"""A raw streaming tool function."""
372-
yield "test"
373-
374-
async def wrapped_streaming_tool(
375-
input_stream: LiveRequestQueue,
376-
) -> AsyncGenerator[str, None]:
377-
"""A streaming tool wrapped in FunctionTool."""
378-
yield "test"
379-
380-
def non_streaming_tool(param: str) -> str:
381-
"""A regular non-streaming tool."""
382-
return param
383-
384-
# Create a mock LlmAgent that yields an event and captures invocation context
385-
captured_context = {}
386-
387-
class StreamingToolsAgent(LlmAgent):
388-
389-
async def _run_live_impl(
390-
self, invocation_context: InvocationContext
391-
) -> AsyncGenerator[Event, None]:
392-
# Capture the active_streaming_tools for verification
393-
captured_context["active_streaming_tools"] = (
394-
invocation_context.active_streaming_tools
395-
)
396-
yield Event(
397-
invocation_id=invocation_context.invocation_id,
398-
author=self.name,
399-
content=types.Content(
400-
role="model", parts=[types.Part(text="streaming test")]
401-
),
402-
)
403-
404-
agent = StreamingToolsAgent(
405-
name="streaming_agent",
406-
model="gemini-2.0-flash",
407-
tools=[
408-
raw_streaming_tool, # Raw function
409-
FunctionTool(wrapped_streaming_tool), # Wrapped in FunctionTool
410-
non_streaming_tool, # Non-streaming tool (should not be detected)
411-
],
412-
)
413-
414-
session_service = InMemorySessionService()
415-
artifact_service = InMemoryArtifactService()
416-
runner = Runner(
417-
app_name="streaming_test_app",
418-
agent=agent,
419-
session_service=session_service,
420-
artifact_service=artifact_service,
421-
auto_create_session=True,
422-
)
423-
424-
live_queue = LiveRequestQueue()
425-
426-
agen = runner.run_live(
427-
user_id="user",
428-
session_id="test_session",
429-
live_request_queue=live_queue,
430-
)
431-
432-
event = await agen.__anext__()
433-
await agen.aclose()
434-
435-
assert event.author == "streaming_agent"
436-
437-
# Verify streaming tools were detected correctly
438-
active_tools = captured_context.get("active_streaming_tools", {})
439-
assert "raw_streaming_tool" in active_tools
440-
assert "wrapped_streaming_tool" in active_tools
441-
# Non-streaming tool should not be detected
442-
assert "non_streaming_tool" not in active_tools
443-
444-
445361
@pytest.mark.asyncio
446362
async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
447363
project_root = tmp_path / "workspace"

0 commit comments

Comments
 (0)