-
Notifications
You must be signed in to change notification settings - Fork 99
Expand file tree
/
Copy pathagent.py
More file actions
139 lines (124 loc) · 5.36 KB
/
agent.py
File metadata and controls
139 lines (124 loc) · 5.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Iterator, List, Optional, Tuple, Union
from llama_stack_client import LlamaStackClient
from llama_stack_client.types import ToolResponseMessage, UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn import Turn
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
from .client_tool import ClientTool
from .output_parser import OutputParser
# Fallback cap on the number of inference iterations within a single turn,
# used when the agent config does not provide "max_infer_iters".
DEFAULT_MAX_ITER = 10
class Agent:
    """Client-side handle for an agent registered through a ``LlamaStackClient``.

    Constructing an ``Agent`` registers the agent configuration with the
    server. Sessions are created with :meth:`create_session`, and
    :meth:`create_turn` drives a turn, running client-side tools locally
    whenever the model's completed turn contains a tool call.
    """

    def __init__(
        self,
        client: LlamaStackClient,
        agent_config: AgentConfig,
        client_tools: Tuple[ClientTool, ...] = (),
        memory_bank_id: Optional[str] = None,
        output_parser: Optional[OutputParser] = None,
    ):
        """Register the agent and index its client-side tools by name.

        Args:
            client: Client used for every agent, session and turn request.
            agent_config: Configuration sent when creating the agent; its
                optional ``"max_infer_iters"`` entry caps the tool-call loop.
            client_tools: Tools executed locally, keyed by ``get_name()``.
            memory_bank_id: Stored on the instance; not used by this class.
            output_parser: Optional parser applied to the turn's output
                message before checking it for tool calls.
        """
        self.client = client
        self.agent_config = agent_config
        self.agent_id = self._create_agent(agent_config)
        self.client_tools = {t.get_name(): t for t in client_tools}
        # Session ids in creation order; the last one is the default session.
        self.sessions = []
        self.memory_bank_id = memory_bank_id
        self.output_parser = output_parser

    def _create_agent(self, agent_config: AgentConfig) -> str:
        """Create the agent on the server, store its id and return it."""
        agentic_system_create_response = self.client.agents.create(
            agent_config=agent_config,
        )
        self.agent_id = agentic_system_create_response.agent_id
        return self.agent_id

    def create_session(self, session_name: str) -> str:
        """Create a session for this agent and make it the default one.

        The new id is stored in ``self.session_id`` and appended to
        ``self.sessions``.

        Returns:
            The id of the newly created session.
        """
        agentic_system_create_session_response = self.client.agents.session.create(
            agent_id=self.agent_id,
            session_name=session_name,
        )
        self.session_id = agentic_system_create_session_response.session_id
        self.sessions.append(self.session_id)
        return self.session_id

    def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
        """Return True if ``chunk`` completes a turn that requests a tool call.

        Only ``"turn_complete"`` events are considered, and a turn that
        stopped with ``"out_of_tokens"`` is never treated as a tool call.
        """
        if chunk.event.payload.event_type != "turn_complete":
            return False
        message = chunk.event.payload.turn.output_message
        if message.stop_reason == "out_of_tokens":
            return False
        if self.output_parser:
            message = self.output_parser.parse(message)
        # Truthiness also covers a missing (None) tool_calls list.
        return bool(message.tool_calls)

    def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
        """Execute the first tool call of a completed turn.

        Returns:
            The first message produced by the matching client tool, or a
            ``ToolResponseMessage`` reporting an unknown tool when no client
            tool with that name is registered.
        """
        # NOTE(review): tool_calls is read from the raw output message, while
        # _has_tool_call inspects the output_parser result; presumably the
        # parser populates tool_calls on the same message object - confirm.
        message = chunk.event.payload.turn.output_message
        tool_call = message.tool_calls[0]
        if tool_call.tool_name not in self.client_tools:
            return ToolResponseMessage(
                call_id=tool_call.call_id,
                tool_name=tool_call.tool_name,
                content=f"Unknown tool `{tool_call.tool_name}` was called.",
                role="tool",
            )
        tool = self.client_tools[tool_call.tool_name]
        result_messages = tool.run([message])
        return result_messages[0]

    def create_turn(
        self,
        messages: List[Union[UserMessage, ToolResponseMessage]],
        session_id: Optional[str] = None,
        toolgroups: Optional[List[Toolgroup]] = None,
        documents: Optional[List[Document]] = None,
        stream: bool = True,
    ) -> Union[Iterator[AgentTurnResponseStreamChunk], Turn]:
        """Run a turn, either streaming its chunks or returning the final turn.

        Args:
            messages: Messages that start the turn.
            session_id: Session to use; defaults to the last session created.
            toolgroups: Forwarded to the turn request.
            documents: Forwarded to the turn request.
            stream: When true, return a lazy iterator over the yielded items;
                otherwise consume the stream and return the completed turn.

        Returns:
            An iterator of stream items when ``stream`` is true, else the
            ``turn`` carried by the final ``"turn_complete"`` event.

        Raises:
            Exception: In non-streaming mode, when nothing was yielded or
                the last item is not a ``"turn_complete"`` event.
        """
        if stream:
            return self._create_turn_streaming(messages, session_id, toolgroups, documents, stream)

        # Non-streaming: drain the stream and keep only the last item.
        chunk = None
        for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents, stream):
            pass
        if not chunk:
            raise Exception("No chunk returned")
        if chunk.event.payload.event_type != "turn_complete":
            raise Exception("Turn did not complete")
        return chunk.event.payload.turn

    def _create_turn_streaming(
        self,
        messages: List[Union[UserMessage, ToolResponseMessage]],
        session_id: Optional[str] = None,
        toolgroups: Optional[List[Toolgroup]] = None,
        documents: Optional[List[Document]] = None,
        stream: bool = True,
    ) -> Iterator[AgentTurnResponseStreamChunk]:
        """Yield the items of a turn, looping while the model calls tools.

        Each request is always issued with ``stream=True``; the ``stream``
        parameter is accepted only for signature parity with ``create_turn``.
        When a completed turn contains a tool call, the tool is run locally,
        its response message is yielded in place of that chunk, and a new
        request is sent with that response as the only message. The loop ends
        on the first turn without a tool call, on a chunk that has an
        ``error`` attribute, or once the iteration cap is reached.

        Raises:
            RuntimeError: If no ``session_id`` is given and no session has
                been created yet.
        """
        # Use the specified session, else the most recently created one.
        active_session_id = session_id or (self.sessions[-1] if self.sessions else None)
        if not active_session_id:
            raise RuntimeError(
                "No session available: call create_session() first or pass session_id."
            )

        max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER)
        n_iter = 0
        stop = False
        while not stop and n_iter < max_iter:
            response = self.client.agents.turn.create(
                agent_id=self.agent_id,
                session_id=active_session_id,
                messages=messages,
                stream=True,
                documents=documents,
                toolgroups=toolgroups,
            )
            # By default, we stop after the first turn.
            stop = True
            for chunk in response:
                if hasattr(chunk, "error"):
                    yield chunk
                    return
                elif not self._has_tool_call(chunk):
                    yield chunk
                else:
                    next_message = self._run_tool(chunk)
                    yield next_message
                    # Continue the turn when there's a tool call.
                    stop = False
                    messages = [next_message]
                    n_iter += 1