Skip to content

Commit 03cf4fe

Browse files
authored
Dev Updates (#30)
## Changes

- All providers removed from `__init__` to (hopefully) eventually expand the provider list without requiring all provider APIs to be installed
- A few miscellaneous code cleanup tasks
- Reworked the observable base class
- Fixed unit tests accordingly
1 parent ef2a7d1 commit 03cf4fe

11 files changed

Lines changed: 278 additions & 425 deletions

File tree

agents/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
ProcessorDF,
1212
ProcessorIterable,
1313
)
14-
from .providers import OpenAIProvider, AzureOpenAIProvider, AzureOpenAIBatchProvider
1514
from .stopping_conditions import (
1615
StoppingCondition,
1716
StopOnStep,
@@ -30,9 +29,6 @@
3029
"BatchProcessorIterable",
3130
"ProcessorDF",
3231
"BatchProcessorDF",
33-
"OpenAIProvider",
34-
"AzureOpenAIProvider",
35-
"AzureOpenAIBatchProvider",
3632
"StoppingCondition",
3733
"StopOnStep",
3834
"StopOnDataModel",

agents/abstract.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from asyncio import Task, create_task, to_thread
1111
from dataclasses import dataclass, field
1212
from typing import (
13+
TYPE_CHECKING,
1314
Any,
1415
Awaitable,
1516
Callable,
@@ -23,13 +24,16 @@
2324
Union,
2425
)
2526

26-
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage
27-
from openai.types.chat.chat_completion import ChatCompletion, Choice
27+
if TYPE_CHECKING:
28+
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage
29+
from openai.types.chat.chat_completion import ChatCompletion, Choice
30+
2831
from pydantic import BaseModel, ValidationError
32+
from .observability import Observable
2933

3034
logger = logging.getLogger(__name__)
3135

32-
Message = Union[dict[str, str], ChatCompletionMessageParam]
36+
Message = Union[dict[str, str], "ChatCompletionMessageParam"]
3337

3438
P = TypeVar("P", bound="_Provider")
3539
A = TypeVar("A", bound="_Agent")
@@ -186,13 +190,13 @@ async def handler(self) -> Dict[str, Union[str, BaseModel]]:
186190
except ValidationError as e:
187191
# Case: Handle pydantic validation errors by passing them back to the
188192
# model to correct
189-
logger.warning("Failed Pydantic Validation.")
193+
logger.debug("Failed Pydantic Validation.")
190194
res = str(e)
191195

192196
return self._construct_return_message(self.id, res)
193197

194198

195-
class _Provider(Generic[A], metaclass=abc.ABCMeta):
199+
class _Provider(Observable, Generic[A], metaclass=abc.ABCMeta):
196200
"""
197201
A LLM Provider which should provide the standard methods for prompting and agent
198202
authenticating, etc.
@@ -201,11 +205,12 @@ class _Provider(Generic[A], metaclass=abc.ABCMeta):
201205
"The tool_call class specific to this provider that will be used to evaluate any tool calls from the model"
202206
tool_call_wrapper: Type[_ToolCall]
203207
"The method that will be used to call the OpenAI API, e.g. openai.chat.completions.create"
204-
endpoint_fn: Callable[..., Awaitable[ChatCompletion]]
208+
endpoint_fn: Callable[..., Awaitable["ChatCompletion"]]
205209

206210
mode: Literal["chat", "batch"]
207211

208212
def __init__(self, model_name: str, **kwargs):
213+
super().__init__()
209214
pass
210215

211216
@abc.abstractmethod
@@ -235,11 +240,11 @@ class _StoppingCondition(metaclass=abc.ABCMeta):
235240
"""
236241

237242
@abc.abstractmethod
238-
def __call__(self, cls: "_Agent", response: Choice) -> Optional[Any]:
243+
def __call__(self, cls: "_Agent", response: "Choice") -> Optional[Any]:
239244
raise NotImplementedError()
240245

241246

242-
class _Agent(metaclass=abc.ABCMeta):
247+
class _Agent(Observable, metaclass=abc.ABCMeta):
243248
terminated: bool = False
244249
truncated: bool = False
245250
curr_step: int = 1
@@ -254,6 +259,7 @@ class _Agent(metaclass=abc.ABCMeta):
254259
callback_output: list
255260
tool_res_payload: List[Message]
256261
provider: _Provider
262+
placeholder: Optional[Any]
257263

258264
def __init__(
259265
self,
@@ -265,6 +271,7 @@ def __init__(
265271
oai_kwargs: Optional[dict[str, Any]] = None,
266272
**fmt_kwargs,
267273
):
274+
super().__init__()
268275
pass
269276

270277
@abc.abstractmethod
@@ -277,7 +284,7 @@ async def step(self):
277284
raise NotImplementedError()
278285

279286
@abc.abstractmethod
280-
def _check_stop_condition(self, response: ChatCompletionMessage) -> None:
287+
def _check_stop_condition(self, response: "ChatCompletionMessage") -> None:
281288
"""
282289
Called from within :func:`step()`.
283290
Checks whether our stop condition has been met and handles assignment of answer, if so.
@@ -298,7 +305,7 @@ def get_next_messages(self) -> List[Message]:
298305
raise NotImplementedError()
299306

300307
@abc.abstractmethod
301-
def _handle_tool_calls(self, response: Choice) -> None:
308+
def _handle_tool_calls(self, response: "Choice") -> None:
302309
raise NotImplementedError()
303310

304311
@property

agents/agent/base.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,16 @@ def __init__(
6363
:param dict[str, any] oai_kwargs: Dict of additional OpenAI arguments to pass thru to chat call
6464
:param fmt_kwargs: Additional named arguments which will be inserted into the :func:`BASE_PROMPT` via fstring
6565
"""
66+
super().__init__(
67+
stopping_condition,
68+
model_name=None,
69+
provider=None,
70+
tools=None,
71+
callbacks=None,
72+
oai_kwargs=None,
73+
**fmt_kwargs,
74+
)
75+
6676
self.fmt_kwargs = fmt_kwargs
6777
self.stopping_condition = stopping_condition
6878
# We default to Azure OpenAI here, but
@@ -129,7 +139,7 @@ def _check_stop_condition(self, response):
129139
if (answer := self.stopping_condition(self, response)) is not None:
130140
self.answer = answer
131141
self.terminated = True
132-
logger.info("Stopping condition signaled, terminating.")
142+
logger.debug("Stopping condition signaled, terminating.")
133143

134144
async def step(self):
135145
"""
@@ -248,7 +258,7 @@ async def _handle_tool_calls(self, response):
248258
for payload, result in zip(tool_calls, tool_call_results):
249259
# Log it
250260
toolcall_str = f"{payload.func_name}({str(payload.kwargs)[:100] + '...(trunc)' if len(str(payload.kwargs)) > 100 else str(payload.kwargs)})"
251-
logger.info(f"Got tool call: {toolcall_str}")
261+
logger.debug(f"Got tool call: {toolcall_str}")
252262
self.scratchpad += f"\t=> {toolcall_str}\n"
253263
self.scratchpad += "\t\t"
254264

agents/agent/prediction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import logging
7-
from typing import Callable, List, Literal, Optional, Any, Type
7+
from typing import Callable, List, Literal, Optional, Any
88

99
import pydantic
1010

@@ -30,7 +30,7 @@ def __init__(
3030
expected_len: Optional[int] = None,
3131
stopping_condition: Optional[_StoppingCondition] = None,
3232
model_name: Optional[str] = None,
33-
provider: Optional[Type[_Provider]] = None,
33+
provider: Optional[_Provider] = None,
3434
tools: Optional[List[dict]] = None,
3535
callbacks: Optional[List[Callable]] = None,
3636
oai_kwargs: Optional[dict[str, Any]] = None,
@@ -44,7 +44,7 @@ def __init__(
4444
:param int expected_len: Optional length constraint on the response_model (OpenAI API doesn't allow maxItems parameter in schema so this is checked post-hoc in the Pydantic BaseModel)
4545
:param _StoppingCondition stopping_condition: A handler that signals when an Agent has completed the task (optional)
4646
:param str model_name: Name of model to use (or deployment name for AzureOpenAI) (optional if provider is passed)
47-
:param Type[_Provider] provider: Instantiated OpenAI instance to use (optional)
47+
:param _Provider provider: Instantiated OpenAI instance to use (optional)
4848
:param List[dict] tools: List of tools the agent can call via response (optional)
4949
:param List[Callable] callbacks: List of callbacks to evaluate at end of run (optional)
5050
:param dict[str, any] oai_kwargs: Dict of additional OpenAI arguments to pass thru to chat call

0 commit comments

Comments (0)