-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathDRIFTToolsExecutionLoop.py
More file actions
43 lines (38 loc) · 1.91 KB
/
DRIFTToolsExecutionLoop.py
File metadata and controls
43 lines (38 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from import_lib import *
class DRIFTToolsExecutionLoop(ToolsExecutionLoop):
    """Executes in loop a sequence of pipeline elements related to tool execution.

    Before every iteration the last message is inspected. The loop stops as soon as
    that message is not an assistant message or carries no tool calls -- unless the
    message content contains the ``"[CALL ERROR]"`` marker, in which case those stop
    checks are skipped and the elements are executed again.

    Args:
        elements: a sequence of pipeline elements to be executed in loop. One of them should be
            an LLM, and one of them should be a [ToolsExecutor][agentdojo.agent_pipeline.ToolsExecutor] (or
            something that behaves similarly by executing function calls). You can find an example usage
            of this class [here](../../concepts/agent_pipeline.md#combining-pipeline-components).
        max_iters: maximum number of iterations to execute the pipeline elements in loop.
    """

    def __init__(self, elements: Sequence[BasePipelineElement], max_iters: int = 15) -> None:
        self.elements = elements
        self.max_iters = max_iters

    def query(
        self,
        query: str,
        runtime: FunctionsRuntime,
        env: Env = EmptyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, FunctionsRuntime, Env, Sequence[ChatMessage], dict]:
        """Run the configured elements repeatedly, at most ``max_iters`` times.

        Args:
            query: the query string, threaded through every element.
            runtime: the functions runtime, threaded through every element.
            env: the environment, threaded through every element.
            messages: the conversation so far; must contain at least one message.
            extra_args: extra arguments, threaded through every element.

        Returns:
            The ``(query, runtime, env, messages, extra_args)`` tuple as returned by
            the last element that ran, or the inputs unchanged if no iteration ran.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        # NOTE: the defaults above are kept as-is for interface compatibility; this
        # method only rebinds these names and never mutates the default objects itself.
        if len(messages) == 0:
            raise ValueError("Messages should not be empty when calling ToolsExecutionLoop")

        logger = Logger().get()
        for _ in range(self.max_iters):
            last_message = messages[-1]
            content = last_message["content"]
            # A message without content (None) cannot carry the marker. Checking for
            # None first avoids the TypeError that `"..." in None` would raise.
            has_call_error = content is not None and "[CALL ERROR]" in content
            if not has_call_error:
                # Regular stop conditions: nothing left for the tool executor to do.
                if last_message["role"] != "assistant":
                    break
                tool_calls = last_message["tool_calls"]
                if tool_calls is None or len(tool_calls) == 0:
                    break
            # Either tool calls are pending or a call error was reported: run one
            # more pass over all elements, logging the messages after each element.
            for element in self.elements:
                query, runtime, env, messages, extra_args = element.query(query, runtime, env, messages, extra_args)
                logger.log(messages)
        return query, runtime, env, messages, extra_args