Skip to content

Commit 0ac51e2

Browse files
authored
feat: Add on_stream to agents as tools (#2169)
1 parent f7c8c27 commit 0ac51e2

File tree

7 files changed

+753
-22
lines changed

7 files changed

+753
-22
lines changed

examples/agent_patterns/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ The mental model for handoffs is that the new agent "takes over". It sees the pr
2828
For example, you could model the translation task above as tool calls instead: rather than handing over to the language-specific agent, you could call the agent as a tool, and then use the result in the next step. This enables things like translating multiple languages at once.
2929

3030
See the [`agents_as_tools.py`](./agents_as_tools.py) file for an example of this.
31+
See the [`agents_as_tools_streaming.py`](./agents_as_tools_streaming.py) file for a streaming variant that taps into nested agent events via `on_stream`.
3132

3233
## LLM-as-a-judge
3334

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import asyncio
2+
3+
from agents import Agent, AgentToolStreamEvent, ModelSettings, Runner, function_tool, trace
4+
5+
6+
@function_tool(
7+
name_override="billing_status_checker",
8+
description_override="Answer questions about customer billing status.",
9+
)
10+
def billing_status_checker(customer_id: str | None = None, question: str = "") -> str:
11+
"""Return a canned billing answer or a fallback when the question is unrelated."""
12+
normalized = question.lower()
13+
if "bill" in normalized or "billing" in normalized:
14+
return f"This customer (ID: {customer_id})'s bill is $100"
15+
return "I can only answer questions about billing."
16+
17+
18+
def handle_stream(event: AgentToolStreamEvent) -> None:
19+
"""Print streaming events emitted by the nested billing agent."""
20+
stream = event["event"]
21+
tool_call = event.get("tool_call")
22+
tool_call_info = tool_call.call_id if tool_call is not None else "unknown"
23+
print(f"[stream] agent={event['agent'].name} call={tool_call_info} type={stream.type} {stream}")
24+
25+
26+
async def main() -> None:
27+
with trace("Agents as tools streaming example"):
28+
billing_agent = Agent(
29+
name="Billing Agent",
30+
instructions="You are a billing agent that answers billing questions.",
31+
model_settings=ModelSettings(tool_choice="required"),
32+
tools=[billing_status_checker],
33+
)
34+
35+
billing_agent_tool = billing_agent.as_tool(
36+
tool_name="billing_agent",
37+
tool_description="You are a billing agent that answers billing questions.",
38+
on_stream=handle_stream,
39+
)
40+
41+
main_agent = Agent(
42+
name="Customer Support Agent",
43+
instructions=(
44+
"You are a customer support agent. Always call the billing agent to answer billing "
45+
"questions and return the billing agent response to the user."
46+
),
47+
tools=[billing_agent_tool],
48+
)
49+
50+
result = await Runner.run(
51+
main_agent,
52+
"Hello, my customer ID is ABC123. How much is my bill for this month?",
53+
)
54+
55+
print(f"\nFinal response:\n{result.final_output}")
56+
57+
58+
if __name__ == "__main__":
59+
asyncio.run(main())

examples/financial_research_agent/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from rich.console import Console
88

9-
from agents import Runner, RunResult, custom_span, gen_trace_id, trace
9+
from agents import Runner, RunResult, RunResultStreaming, custom_span, gen_trace_id, trace
1010

1111
from .agents.financials_agent import financials_agent
1212
from .agents.planner_agent import FinancialSearchItem, FinancialSearchPlan, planner_agent
@@ -17,7 +17,7 @@
1717
from .printer import Printer
1818

1919

20-
async def _summary_extractor(run_result: RunResult) -> str:
20+
async def _summary_extractor(run_result: RunResult | RunResultStreaming) -> str:
2121
"""Custom output extractor for sub‑agents that return an AnalysisSummary."""
2222
# The financial/risk analyst agents emit an AnalysisSummary with a `summary` field.
2323
# We want the tool call to return just that summary text so the writer can drop it inline.

src/agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .agent import (
99
Agent,
1010
AgentBase,
11+
AgentToolStreamEvent,
1112
StopAtTools,
1213
ToolsToFinalOutputFunction,
1314
ToolsToFinalOutputResult,
@@ -214,6 +215,7 @@ def enable_verbose_stdout_logging():
214215
__all__ = [
215216
"Agent",
216217
"AgentBase",
218+
"AgentToolStreamEvent",
217219
"StopAtTools",
218220
"ToolsToFinalOutputFunction",
219221
"ToolsToFinalOutputResult",

src/agents/agent.py

Lines changed: 104 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,19 @@
2525
from .prompts import DynamicPromptFunction, Prompt, PromptUtil
2626
from .run_context import RunContextWrapper, TContext
2727
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
28+
from .tool_context import ToolContext
2829
from .util import _transforms
2930
from .util._types import MaybeAwaitable
3031

3132
if TYPE_CHECKING:
33+
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
34+
3235
from .lifecycle import AgentHooks, RunHooks
3336
from .mcp import MCPServer
3437
from .memory.session import Session
35-
from .result import RunResult
38+
from .result import RunResult, RunResultStreaming
3639
from .run import RunConfig
40+
from .stream_events import StreamEvent
3741

3842

3943
@dataclass
@@ -58,6 +62,19 @@ class ToolsToFinalOutputResult:
5862
"""
5963

6064

65+
class AgentToolStreamEvent(TypedDict):
66+
"""Streaming event emitted when an agent is invoked as a tool."""
67+
68+
event: StreamEvent
69+
"""The streaming event from the nested agent run."""
70+
71+
agent: Agent[Any]
72+
"""The nested agent emitting the event."""
73+
74+
tool_call: ResponseFunctionToolCall | None
75+
"""The originating tool call, if available."""
76+
77+
6178
class StopAtTools(TypedDict):
6279
stop_at_tool_names: list[str]
6380
"""A list of tool names, any of which will stop the agent from running further."""
@@ -382,9 +399,12 @@ def as_tool(
382399
self,
383400
tool_name: str | None,
384401
tool_description: str | None,
385-
custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
402+
custom_output_extractor: (
403+
Callable[[RunResult | RunResultStreaming], Awaitable[str]] | None
404+
) = None,
386405
is_enabled: bool
387406
| Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
407+
on_stream: Callable[[AgentToolStreamEvent], MaybeAwaitable[None]] | None = None,
388408
run_config: RunConfig | None = None,
389409
max_turns: int | None = None,
390410
hooks: RunHooks[TContext] | None = None,
@@ -409,33 +429,100 @@ def as_tool(
409429
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
410430
context and agent and returns whether the tool is enabled. Disabled tools are hidden
411431
from the LLM at runtime.
432+
on_stream: Optional callback (sync or async) to receive streaming events from the nested
433+
agent run. The callback receives an `AgentToolStreamEvent` containing the nested
434+
agent, the originating tool call (when available), and each stream event. When
435+
provided, the nested agent is executed in streaming mode.
412436
"""
413437

414438
@function_tool(
415439
name_override=tool_name or _transforms.transform_string_function_style(self.name),
416440
description_override=tool_description or "",
417441
is_enabled=is_enabled,
418442
)
419-
async def run_agent(context: RunContextWrapper, input: str) -> Any:
443+
async def run_agent(context: ToolContext, input: str) -> Any:
420444
from .run import DEFAULT_MAX_TURNS, Runner
421445

422446
resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS
423-
424-
output = await Runner.run(
425-
starting_agent=self,
426-
input=input,
427-
context=context.context,
428-
run_config=run_config,
429-
max_turns=resolved_max_turns,
430-
hooks=hooks,
431-
previous_response_id=previous_response_id,
432-
conversation_id=conversation_id,
433-
session=session,
434-
)
447+
run_result: RunResult | RunResultStreaming
448+
449+
if on_stream is not None:
450+
run_result = Runner.run_streamed(
451+
starting_agent=self,
452+
input=input,
453+
context=context.context,
454+
run_config=run_config,
455+
max_turns=resolved_max_turns,
456+
hooks=hooks,
457+
previous_response_id=previous_response_id,
458+
conversation_id=conversation_id,
459+
session=session,
460+
)
461+
# Dispatch callbacks in the background so slow handlers do not block
462+
# event consumption.
463+
event_queue: asyncio.Queue[AgentToolStreamEvent | None] = asyncio.Queue()
464+
465+
async def _run_handler(payload: AgentToolStreamEvent) -> None:
466+
"""Execute the user callback while capturing exceptions."""
467+
try:
468+
maybe_result = on_stream(payload)
469+
if inspect.isawaitable(maybe_result):
470+
await maybe_result
471+
except Exception:
472+
logger.exception(
473+
"Error while handling on_stream event for agent tool %s.",
474+
self.name,
475+
)
476+
477+
async def dispatch_stream_events() -> None:
478+
while True:
479+
payload = await event_queue.get()
480+
is_sentinel = payload is None # None marks the end of the stream.
481+
try:
482+
if payload is not None:
483+
await _run_handler(payload)
484+
finally:
485+
event_queue.task_done()
486+
487+
if is_sentinel:
488+
break
489+
490+
dispatch_task = asyncio.create_task(dispatch_stream_events())
491+
492+
try:
493+
from .stream_events import AgentUpdatedStreamEvent
494+
495+
current_agent = run_result.current_agent
496+
async for event in run_result.stream_events():
497+
if isinstance(event, AgentUpdatedStreamEvent):
498+
current_agent = event.new_agent
499+
500+
payload: AgentToolStreamEvent = {
501+
"event": event,
502+
"agent": current_agent,
503+
"tool_call": context.tool_call,
504+
}
505+
await event_queue.put(payload)
506+
finally:
507+
await event_queue.put(None)
508+
await event_queue.join()
509+
await dispatch_task
510+
else:
511+
run_result = await Runner.run(
512+
starting_agent=self,
513+
input=input,
514+
context=context.context,
515+
run_config=run_config,
516+
max_turns=resolved_max_turns,
517+
hooks=hooks,
518+
previous_response_id=previous_response_id,
519+
conversation_id=conversation_id,
520+
session=session,
521+
)
435522
if custom_output_extractor:
436-
return await custom_output_extractor(output)
523+
return await custom_output_extractor(run_result)
437524

438-
return output.final_output
525+
return run_result.final_output
439526

440527
return run_agent
441528

src/agents/tool_context.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ class ToolContext(RunContextWrapper[TContext]):
3131
tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments)
3232
"""The raw arguments string of the tool call."""
3333

34+
tool_call: Optional[ResponseFunctionToolCall] = None
35+
"""The tool call object associated with this invocation."""
36+
3437
@classmethod
3538
def from_agent_context(
3639
cls,
@@ -50,6 +53,11 @@ def from_agent_context(
5053
tool_call.arguments if tool_call is not None else _assert_must_pass_tool_arguments()
5154
)
5255

53-
return cls(
54-
tool_name=tool_name, tool_call_id=tool_call_id, tool_arguments=tool_args, **base_values
56+
tool_context = cls(
57+
tool_name=tool_name,
58+
tool_call_id=tool_call_id,
59+
tool_arguments=tool_args,
60+
tool_call=tool_call,
61+
**base_values,
5562
)
63+
return tool_context

0 commit comments

Comments
 (0)