-
Notifications
You must be signed in to change notification settings - Fork 28
Expand file tree
/
Copy pathagent.py
More file actions
262 lines (228 loc) · 8.57 KB
/
agent.py
File metadata and controls
262 lines (228 loc) · 8.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import json
import copy
import inspect
from openai import OpenAI
from pydantic import BaseModel
from typing_extensions import Literal
from typing import Union, Callable, List, Optional
def pretty_print_messages(messages) -> None:
    """Print assistant messages to stdout with ANSI colors.

    Non-assistant messages are skipped. The sender name is printed in
    blue, the text content (if any) follows on the same line, and each
    tool call is printed in purple as ``name(arg=value, ...)``.
    """
    for msg in messages:
        if msg["role"] != "assistant":
            continue

        # Sender name in blue, content (when present) on the same line.
        print(f"\033[94m{msg['sender']}\033[0m:", end=" ")
        if msg["content"]:
            print(msg["content"])

        calls = msg.get("tool_calls") or []
        # Several tool calls start on a fresh line for readability.
        if len(calls) > 1:
            print()
        for call in calls:
            fn = call["function"]
            # Render the JSON arguments as keyword-style pairs: {"a": 1} -> "a"=1
            rendered = json.dumps(json.loads(fn["arguments"])).replace(":", "=")
            print(f"\033[95m{fn['name']}\033[0m({rendered[1:-1]})")
def function_to_json(func) -> dict:
    """Convert a Python callable into an OpenAI tool (function) schema.

    Parameter JSON types are derived from annotations via a small map of
    builtins; unannotated or unmapped parameters default to "string".
    Parameters without a default value are listed as required.

    Sample input:

        def add_two_numbers(a: int, b: int) -> int:
            '''Adds two numbers together'''
            return a + b

    Sample output:

        {
            'type': 'function',
            'function': {
                'name': 'add_two_numbers',
                'description': 'Adds two numbers together',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'a': {'type': 'integer'},
                        'b': {'type': 'integer'}
                    },
                    'required': ['a', 'b']
                }
            }
        }

    Raises:
        ValueError: if the function's signature cannot be introspected.
    """
    type_map = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        list: "array",
        dict: "object",
        type(None): "null",
    }

    try:
        signature = inspect.signature(func)
    except ValueError as e:
        raise ValueError(
            f"Failed to get signature for function {func.__name__}: {str(e)}"
        ) from e

    # dict.get never raises KeyError, so the original try/except around this
    # lookup was unreachable; unknown annotations simply fall back to "string".
    parameters = {
        param.name: {"type": type_map.get(param.annotation, "string")}
        for param in signature.parameters.values()
    }

    # inspect.Parameter.empty is a sentinel object, so test it by identity.
    required = [
        param.name
        for param in signature.parameters.values()
        if param.default is inspect.Parameter.empty
    ]

    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": func.__doc__ or "",
            "parameters": {
                "type": "object",
                "properties": parameters,
                "required": required,
            },
        },
    }
# An agent tool: a callable returning a string, a hand-off Agent, or a dict.
# NOTE(review): handle_tool_calls invokes tools with keyword arguments parsed
# from the model, so the zero-arg Callable[[], ...] form is looser in
# practice — confirm intended signature.
AgentFunction = Callable[[], Union[str, "Agent", dict]]
class Agent(BaseModel):
    """Declarative agent configuration (no behavior of its own).

    Fields:
        name: display name; also recorded as the message ``sender``.
        model: default chat-completion model (overridable per run).
        instructions: the system prompt — a string, or a zero-arg
            callable producing one.
        functions: tools the model may call.
        tool_choice: forwarded to the API; ``None`` lets the model decide.
        parallel_tool_calls: whether the model may emit several tool
            calls in one assistant turn.
    """

    name: str = "Agent"
    model: str = "gpt-4o"
    instructions: Union[str, Callable[[], str]] = "You are a helpful agent."
    functions: List[AgentFunction] = []
    # Fix: default is None, so the annotation must be Optional[str],
    # not bare str.
    tool_choice: Optional[str] = None
    parallel_tool_calls: bool = True
class Response(BaseModel):
    # Response is used to encapsulate the entire conversation output
    # messages: the new messages produced during a run (assistant + tool).
    messages: List = []
    # agent: the agent active when the run ended, enabling hand-off chaining.
    agent: Optional[Agent] = None
class Function(BaseModel):
    # Mirrors the OpenAI tool-call function payload: the function name plus
    # its arguments as a JSON-encoded string (decoded with json.loads before
    # the tool is invoked).
    arguments: str
    name: str
class ChatCompletionMessageToolCall(BaseModel):
    # Minimal local mirror of the OpenAI SDK's tool-call object; used to type
    # the tool_calls processed by Swarm.handle_tool_calls.
    id: str # The ID of the tool call
    function: Function # The function that the model called
    type: Literal["function"] # The type of the tool. Currently, only `function` is supported
class Result(BaseModel):
    # Result is used to encapsulate the return value of a single function/tool call
    # When `agent` is set, Swarm.run switches the active agent to it (hand-off).
    value: str = "" # The result value as a string.
    agent: Optional[Agent] = None # The agent instance, if applicable.
class Swarm:
# Implements the core logic of orchestrating a single/multi-agent system
def __init__(
self,
client=None,
):
if not client:
client = OpenAI()
self.client = client
def get_chat_completion(
self,
agent: Agent,
history: List,
model_override: str
):
messages = [{"role": "system", "content": agent.instructions}] + history
tools = [function_to_json(f) for f in agent.functions]
create_params = {
"model": model_override or agent.model,
"messages": messages,
"tools": tools or None,
"tool_choice": agent.tool_choice,
}
if tools:
create_params["parallel_tool_calls"] = agent.parallel_tool_calls
return self.client.chat.completions.create(**create_params)
def handle_function_result(self, result) -> Result:
match result:
case Result() as result:
return result
case Agent() as agent:
return Result(
value=json.dumps({"assistant": agent.name}),
agent=agent
)
case _:
try:
return Result(value=str(result))
except Exception as e:
raise TypeError(e)
def handle_tool_calls(
self,
tool_calls: List[ChatCompletionMessageToolCall],
functions: List[AgentFunction]
) -> Response:
function_map = {f.__name__: f for f in functions}
partial_response = Response(messages=[], agent=None)
for tool_call in tool_calls:
name = tool_call.function.name
# handle missing tool case, skip to next tool
if name not in function_map:
partial_response.messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"tool_name": name,
"content": f"Error: Tool {name} not found.",
}
)
continue
args = json.loads(tool_call.function.arguments)
raw_result = function_map[name](**args)
print(f'Called function {name} with args: {args} and obtained result: {raw_result}')
print('#############################################')
result: Result = self.handle_function_result(raw_result)
partial_response.messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"tool_name": name,
"content": result.value,
}
)
if result.agent:
partial_response.agent = result.agent
return partial_response
def run(
self,
agent: Agent,
messages: List,
model_override: str = None,
max_turns: int = float("inf"),
execute_tools: bool = True,
) -> Response:
active_agent = agent
history = copy.deepcopy(messages)
init_len = len(messages)
print('#############################################')
print(f'history: {history}')
print('#############################################')
while len(history) - init_len < max_turns and active_agent:
completion = self.get_chat_completion(
agent=active_agent,
history=history,
model_override=model_override
)
message = completion.choices[0].message
message.sender = active_agent.name
print(f'Active agent: {active_agent.name}')
print(f"message: {message}")
print('#############################################')
history.append(json.loads(message.model_dump_json()))
if not message.tool_calls or not execute_tools:
print('No tool calls hence breaking')
print('#############################################')
break
partial_response = self.handle_tool_calls(message.tool_calls, active_agent.functions)
history.extend(partial_response.messages)
if partial_response.agent:
active_agent = partial_response.agent
message.sender = active_agent.name
return Response(
messages=history[init_len:],
agent=active_agent,
)