-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathagent.py
More file actions
166 lines (122 loc) · 4.53 KB
/
agent.py
File metadata and controls
166 lines (122 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from dataclasses import dataclass
from typing import Any, TypedDict
from uuid import UUID, uuid4
from langchain.agents import create_agent, AgentState
from dotenv import load_dotenv
from langchain_core.messages import AIMessageChunk
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.state import CompiledStateGraph
from langchain.agents.middleware.types import (
AgentState,
_InputAgentState,
_OutputAgentState,
)
from langchain.tools import tool, ToolRuntime
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from fastapi import Depends, FastAPI
from contextlib import asynccontextmanager
from psycopg_pool import AsyncConnectionPool
import uvicorn
from starlette.responses import StreamingResponse
# Populate the process environment from a local .env file (python-dotenv).
# NOTE(review): presumably this supplies the model provider credentials — confirm.
load_dotenv()
@dataclass
class Context:
    """Per-run context object handed to the agent and readable from tools."""

    # City that ``weather_tool`` falls back to when it is called without one.
    city: str
class AgentConfig(TypedDict):
    """Settings from which ``create_configurable_agent`` assembles an agent."""

    # Model identifier forwarded as ``model=`` to ``create_agent``.
    model: str
    # Names of the tools to enable; resolved by name when the agent is built.
    tools: list[str]
    # System prompt forwarded as ``system_prompt=`` to ``create_agent``.
    system_prompt: str
    # Optional run context; the HTTP handler passes it along with each run.
    context: Context | None
@tool
def weather_tool(runtime: ToolRuntime[Context], city: str | None = None) -> str:
    """Get the weather for a city. If no city is provided, use the city from context."""
    # NOTE: the docstring above is kept verbatim on purpose — ``@tool`` uses it
    # as the tool description, so editing it would change what the model sees.
    #
    # Parameters:
    #   runtime: tool runtime whose ``context`` may carry a ``Context`` with a
    #            default ``city``.
    #   city:    city to report on; optional (defaults to ``None``) so the tool
    #            can actually be called without one, as the description states.
    # Returns:
    #   A canned weather sentence for the resolved city, or an explanatory
    #   message when no city can be determined.
    if not city:
        # The run context is optional (callers may pass ``context=None``), so
        # read the fallback city defensively instead of raising AttributeError.
        city = getattr(runtime.context, "city", None)
    if not city:
        return "No city was provided and no default city is available in the context."
    return f"it's sunny and 70 degrees in {city}"
# Default agent configuration used by the HTTP handler in this module.
agentConfig: AgentConfig = {
    "model": "openai:openai/gpt-oss-20b",
    "tools": ["weather_tool"],
    "system_prompt": "You are a helpful research assistant.",
    "context": Context(city="New York"),
}
async def create_configurable_agent(
    config: AgentConfig,
    postgresCheckpointer: AsyncPostgresSaver,
) -> CompiledStateGraph[AgentState, Any, _InputAgentState, _OutputAgentState]:
    """Build an agent from ``config`` that checkpoints through ``postgresCheckpointer``.

    Args:
        config: model name, tool names and system prompt for the agent.
        postgresCheckpointer: checkpointer passed straight to ``create_agent``.

    Returns:
        The compiled agent graph returned by ``create_agent``, using
        ``Context`` as its context schema.
    """
    # Name -> tool object registry. Names in the config that are not listed
    # here are skipped silently.
    known_tools = {"weather_tool": weather_tool}
    selected_tools = [
        known_tools[name] for name in config["tools"] if name in known_tools
    ]

    return create_agent(
        model=config["model"],
        tools=selected_tools,
        system_prompt=config["system_prompt"],
        checkpointer=postgresCheckpointer,
        context_schema=Context,
    )
async def run_agent_streaming_response(
    agent: CompiledStateGraph[AgentState, Any, _InputAgentState, _OutputAgentState],
    session_id: UUID,
    input_text: str,
    context: Context | None = None,
):
    """Run one user turn through ``agent`` and yield the AI message content as it streams.

    Args:
        agent: compiled agent graph to run.
        session_id: conversation id; its string form is used as ``thread_id``.
        input_text: the user's message for this turn.
        context: optional run context forwarded to the agent.

    Yields:
        ``content`` of every ``AIMessageChunk`` received on the "messages"
        stream. "updates" events are requested as well but are not forwarded.
    """
    user_turn = {"role": "user", "content": input_text}
    run_config = {"configurable": {"thread_id": str(session_id)}}

    stream = agent.astream(
        input={"messages": [user_turn]},
        stream_mode=["messages", "updates"],
        config=run_config,
        context=context,
    )

    async for mode, payload in stream:
        if mode != "messages":
            continue
        # "messages" events are pairs; only the first element is used here.
        message, _metadata = payload
        if isinstance(message, AIMessageChunk):
            yield message.content
async def initialize_resources():
    """Startup hook: log a message and set the module-global ``postgresCheckpointer``.

    The global is bound to whatever ``AsyncPostgresSaver.from_conn_string``
    returns; the result is neither awaited nor entered here.
    NOTE(review): that return value appears to be an async context manager,
    so no connection would be opened until it is entered — confirm.
    """
    global postgresCheckpointer

    # e.g., connect to DB, preload cache
    print("Initializing resources...")
    conn_string = "postgresql://postgres:postgres@localhost/postgres"
    postgresCheckpointer = AsyncPostgresSaver.from_conn_string(conn_string)
async def cleanup_resources():
    """Shutdown hook; at present it only prints a message and releases nothing."""
    # e.g., close DB connections
    farewell = "Cleaning up resources..."
    print(farewell)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: run startup hooks, hand control to the app, then clean up.

    Args:
        app: the FastAPI application this lifespan is attached to (unused here).

    Yields:
        Nothing; the application serves requests while this generator is
        suspended at the ``yield``.
    """
    # Startup logic
    print("🚀 Server is starting up...")
    await initialize_resources()
    try:
        yield  # App runs here
    finally:
        # Shutdown logic. Kept in ``finally`` so cleanup still runs when an
        # exception or cancellation is thrown into the generator at the yield.
        print("🛑 Server is shutting down...")
        await cleanup_resources()
# ASGI application; startup and shutdown are driven by ``lifespan``.
app = FastAPI(lifespan=lifespan)

# Create a connection pool
# ``open=False``: do not open the pool in the constructor. This statement runs
# at import time, when no event loop is running, and opening an async pool
# implicitly in the constructor is deprecated by psycopg_pool. Whoever starts
# using the pool must first ``await pool.open()`` (or enter ``async with pool``)
# and close it again on shutdown.
pool = AsyncConnectionPool(
    conninfo="postgresql://postgres:postgres@localhost/postgres",
    open=False,
)
# Create the saver context manager
def get_saver():
    """Return a new ``AsyncPostgresSaver.from_conn_string(...)`` result for the local database.

    Each call produces a fresh object; nothing is awaited or entered here.
    """
    conn_string = "postgresql://postgres:postgres@localhost/postgres"
    return AsyncPostgresSaver.from_conn_string(conn_string)
@app.get("/")
async def root(input: str = "Hello World", session_id: UUID | None = None):
    """Run one agent turn and stream the reply as Server-Sent Events.

    Args:
        input: the user's message (query parameter).
        session_id: conversation id used as the checkpoint thread; a new
            random one is generated when omitted.

    Returns:
        A ``text/event-stream`` response with one event per streamed chunk.
    """
    if session_id is None:
        session_id = uuid4()

    async def event_generator():
        # The checkpointer is opened inside the generator because the response
        # body is produced after this handler has returned. Wrapping only the
        # handler in ``async with`` would close the database connection before
        # the first chunk is streamed. A fresh context manager is obtained per
        # request since the object returned by ``from_conn_string`` cannot be
        # entered a second time.
        # NOTE(review): AsyncPostgresSaver expects its tables to exist
        # (``setup()`` run once against the database) — confirm that happens.
        async with get_saver() as checkpointer:
            agent = await create_configurable_agent(agentConfig, checkpointer)
            async for chunk in run_agent_streaming_response(
                agent, session_id, input, context=agentConfig["context"]
            ):
                # In SSE every payload line needs its own "data:" prefix and a
                # blank line terminates the event, so a chunk containing
                # newlines must be split or part of the text is lost.
                data_lines = "".join(
                    f"data: {line}\n" for line in str(chunk).split("\n")
                )
                yield f"{data_lines}\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
# When the file is executed as a script, serve ``app`` with uvicorn on localhost:8000.
if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8000)