2525from .prompts import DynamicPromptFunction , Prompt , PromptUtil
2626from .run_context import RunContextWrapper , TContext
2727from .tool import FunctionTool , FunctionToolResult , Tool , function_tool
28+ from .tool_context import ToolContext
2829from .util import _transforms
2930from .util ._types import MaybeAwaitable
3031
3132if TYPE_CHECKING :
33+ from openai .types .responses .response_function_tool_call import ResponseFunctionToolCall
34+
3235 from .lifecycle import AgentHooks , RunHooks
3336 from .mcp import MCPServer
3437 from .memory .session import Session
35- from .result import RunResult
38+ from .result import RunResult , RunResultStreaming
3639 from .run import RunConfig
40+ from .stream_events import StreamEvent
3741
3842
3943@dataclass
@@ -58,6 +62,19 @@ class ToolsToFinalOutputResult:
5862"""
5963
6064
65+ class AgentToolStreamEvent (TypedDict ):
66+ """Streaming event emitted when an agent is invoked as a tool."""
67+
68+ event : StreamEvent
69+ """The streaming event from the nested agent run."""
70+
71+ agent : Agent [Any ]
72+ """The nested agent emitting the event."""
73+
74+ tool_call : ResponseFunctionToolCall | None
75+ """The originating tool call, if available."""
76+
77+
6178class StopAtTools (TypedDict ):
6279 stop_at_tool_names : list [str ]
6380 """A list of tool names, any of which will stop the agent from running further."""
@@ -382,9 +399,12 @@ def as_tool(
382399 self ,
383400 tool_name : str | None ,
384401 tool_description : str | None ,
385- custom_output_extractor : Callable [[RunResult ], Awaitable [str ]] | None = None ,
402+ custom_output_extractor : (
403+ Callable [[RunResult | RunResultStreaming ], Awaitable [str ]] | None
404+ ) = None ,
386405 is_enabled : bool
387406 | Callable [[RunContextWrapper [Any ], AgentBase [Any ]], MaybeAwaitable [bool ]] = True ,
407+ on_stream : Callable [[AgentToolStreamEvent ], MaybeAwaitable [None ]] | None = None ,
388408 run_config : RunConfig | None = None ,
389409 max_turns : int | None = None ,
390410 hooks : RunHooks [TContext ] | None = None ,
@@ -409,33 +429,100 @@ def as_tool(
409429 is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
410430 context and agent and returns whether the tool is enabled. Disabled tools are hidden
411431 from the LLM at runtime.
432+ on_stream: Optional callback (sync or async) to receive streaming events from the nested
433+ agent run. The callback receives an `AgentToolStreamEvent` containing the nested
434+ agent, the originating tool call (when available), and each stream event. When
435+ provided, the nested agent is executed in streaming mode.
412436 """
413437
414438 @function_tool (
415439 name_override = tool_name or _transforms .transform_string_function_style (self .name ),
416440 description_override = tool_description or "" ,
417441 is_enabled = is_enabled ,
418442 )
419- async def run_agent (context : RunContextWrapper , input : str ) -> Any :
443+ async def run_agent (context : ToolContext , input : str ) -> Any :
420444 from .run import DEFAULT_MAX_TURNS , Runner
421445
422446 resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS
423-
424- output = await Runner .run (
425- starting_agent = self ,
426- input = input ,
427- context = context .context ,
428- run_config = run_config ,
429- max_turns = resolved_max_turns ,
430- hooks = hooks ,
431- previous_response_id = previous_response_id ,
432- conversation_id = conversation_id ,
433- session = session ,
434- )
447+ run_result : RunResult | RunResultStreaming
448+
449+ if on_stream is not None :
450+ run_result = Runner .run_streamed (
451+ starting_agent = self ,
452+ input = input ,
453+ context = context .context ,
454+ run_config = run_config ,
455+ max_turns = resolved_max_turns ,
456+ hooks = hooks ,
457+ previous_response_id = previous_response_id ,
458+ conversation_id = conversation_id ,
459+ session = session ,
460+ )
461+ # Dispatch callbacks in the background so slow handlers do not block
462+ # event consumption.
463+ event_queue : asyncio .Queue [AgentToolStreamEvent | None ] = asyncio .Queue ()
464+
465+ async def _run_handler (payload : AgentToolStreamEvent ) -> None :
466+ """Execute the user callback while capturing exceptions."""
467+ try :
468+ maybe_result = on_stream (payload )
469+ if inspect .isawaitable (maybe_result ):
470+ await maybe_result
471+ except Exception :
472+ logger .exception (
473+ "Error while handling on_stream event for agent tool %s." ,
474+ self .name ,
475+ )
476+
477+ async def dispatch_stream_events () -> None :
478+ while True :
479+ payload = await event_queue .get ()
480+ is_sentinel = payload is None # None marks the end of the stream.
481+ try :
482+ if payload is not None :
483+ await _run_handler (payload )
484+ finally :
485+ event_queue .task_done ()
486+
487+ if is_sentinel :
488+ break
489+
490+ dispatch_task = asyncio .create_task (dispatch_stream_events ())
491+
492+ try :
493+ from .stream_events import AgentUpdatedStreamEvent
494+
495+ current_agent = run_result .current_agent
496+ async for event in run_result .stream_events ():
497+ if isinstance (event , AgentUpdatedStreamEvent ):
498+ current_agent = event .new_agent
499+
500+ payload : AgentToolStreamEvent = {
501+ "event" : event ,
502+ "agent" : current_agent ,
503+ "tool_call" : context .tool_call ,
504+ }
505+ await event_queue .put (payload )
506+ finally :
507+ await event_queue .put (None )
508+ await event_queue .join ()
509+ await dispatch_task
510+ else :
511+ run_result = await Runner .run (
512+ starting_agent = self ,
513+ input = input ,
514+ context = context .context ,
515+ run_config = run_config ,
516+ max_turns = resolved_max_turns ,
517+ hooks = hooks ,
518+ previous_response_id = previous_response_id ,
519+ conversation_id = conversation_id ,
520+ session = session ,
521+ )
435522 if custom_output_extractor :
436- return await custom_output_extractor (output )
523+ return await custom_output_extractor (run_result )
437524
438- return output .final_output
525+ return run_result .final_output
439526
440527 return run_agent
441528
0 commit comments