-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathloop.py
More file actions
300 lines (249 loc) · 9.67 KB
/
loop.py
File metadata and controls
300 lines (249 loc) · 9.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
"""
Gemini Computer Use sampling loop.
Based on Google's computer-use-preview reference implementation.
"""
from datetime import datetime
from typing import Any, Dict, List
from google import genai
from google.genai import types
from google.genai.types import (
Content,
FunctionResponse,
GenerateContentConfig,
Part,
)
from kernel import Kernel
from tools import ComputerTool, PREDEFINED_COMPUTER_USE_FUNCTIONS
# Build the system prompt used for browser-based computer use.
def get_system_prompt() -> str:
    """Return the browser-agent system prompt, stamped with today's local date.

    The prompt is a fixed set of instructions, one per line, followed by a
    final "Current date: <Weekday, Month DD, YYYY>." line.
    """
    today = datetime.now()
    prompt_lines = (
        "You are a helpful assistant that can use a web browser.",
        "You are operating a Chrome browser through computer use tools.",
        "The browser is already open and ready for use.",
        "When you need to navigate to a page, use the navigate action with a full URL.",
        "When you need to interact with elements, use click_at, type_text_at, etc.",
        "After each action, carefully evaluate the screenshot to determine your next step.",
        f"Current date: {today:%A, %B %d, %Y}.",
    )
    return "\n".join(prompt_lines)
# Maximum number of most-recent user turns whose computer-use screenshots are
# kept in the conversation history. _prune_old_screenshots drops the image
# parts from any older turn so the request context does not keep growing.
MAX_RECENT_TURN_WITH_SCREENSHOTS: int = 3
async def sampling_loop(
    *,
    model: str,
    query: str,
    api_key: str,
    kernel: Kernel,
    session_id: str,
    max_iterations: int = 50,
    system_prompt_suffix: str = "",
) -> Dict[str, Any]:
    """
    Run the Gemini computer use sampling loop.

    Each iteration sends the conversation so far to Gemini, executes every
    function call in the reply through ``ComputerTool``, and appends the
    results as a user turn. The loop ends when the model replies without
    function calls, when the reply has no usable content, when an exception
    is raised, or after ``max_iterations`` iterations.

    Args:
        model: The Gemini model to use
        query: The user's query/task
        api_key: Google API key
        kernel: Kernel client instance
        session_id: Browser session ID
        max_iterations: Maximum number of loop iterations
        system_prompt_suffix: Additional system prompt text, appended to the
            base prompt after a blank line when non-empty

    Returns:
        Dict with 'final_response' (text of the model's final call-free
        reply, or "" if it never produced one), 'iterations' (number of
        iterations started) and 'error' (the stringified exception that
        aborted the loop, or None).
    """
    # Initialize the Gemini client
    client = genai.Client(api_key=api_key)
    computer_tool = ComputerTool(kernel, session_id)

    # Initialize conversation with user query
    contents: List[Content] = [
        Content(
            role="user",
            parts=[Part(text=query)],
        )
    ]

    base_prompt = get_system_prompt()
    system_prompt = (
        f"{base_prompt}\n\n{system_prompt_suffix}"
        if system_prompt_suffix
        else base_prompt
    )

    # Generate content config
    generate_content_config = GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        system_instruction=system_prompt,
        tools=[
            types.Tool(
                computer_use=types.ComputerUse(
                    environment=types.Environment.ENVIRONMENT_BROWSER,
                ),
            ),
        ],
        thinking_config=types.ThinkingConfig(include_thoughts=True),
    )

    iteration = 0
    final_response = ""
    error = None

    while iteration < max_iterations:
        iteration += 1
        print(f"\n=== Iteration {iteration} ===")

        try:
            # Use the SDK's async client so waiting on the model does not
            # block the event loop (this coroutine also awaits browser actions).
            response = await client.aio.models.generate_content(
                model=model,
                contents=contents,
                config=generate_content_config,
            )

            if not response.candidates:
                print("No candidates in response")
                break

            candidate = response.candidates[0]
            content = candidate.content
            function_calls = _extract_function_calls(content) if content else []

            # Check the finish reason BEFORE the empty-content bail-out: a
            # malformed function call may come back without usable content,
            # and the retry must still happen. The malformed turn is not added
            # to the history; the same conversation is simply re-sent.
            if (
                candidate.finish_reason == types.FinishReason.MALFORMED_FUNCTION_CALL
                and not function_calls
            ):
                print("Malformed function call, retrying...")
                continue

            # Nothing usable to record or act on (also avoids appending a
            # model turn that has no parts).
            if not content or not content.parts:
                print("No content in candidate")
                break

            # Add assistant response to conversation
            contents.append(content)

            reasoning = _extract_text(content)

            # Log the response
            print(f"Reasoning: {reasoning or '(none)'}")
            print(f"Function calls: {len(function_calls)}")
            for fc in function_calls:
                print(f" - {fc.name}: {fc.args}")

            # If no function calls, the model is done
            if not function_calls:
                print("Agent loop complete")
                final_response = reasoning or ""
                break

            # Execute function calls and collect results
            function_responses: List[Part] = []
            for fc in function_calls:
                args = dict(fc.args) if fc.args else {}
                extra_fields = _acknowledge_safety_decision(args)

                print(f"Executing action: {fc.name}")
                result = await computer_tool.execute_action(fc.name, args)
                function_responses.append(
                    _build_function_response(fc.name, result, extra_fields)
                )

            # Add function responses to conversation
            contents.append(
                Content(
                    role="user",
                    parts=function_responses,
                )
            )

            # Manage screenshot history to avoid context overflow
            _prune_old_screenshots(contents)

        except Exception as e:
            # Top-level boundary: report the failure to the caller via 'error'.
            error = str(e)
            print(f"Error in sampling loop: {error}")
            break
    else:
        # Runs only when the loop condition fails, never after a `break`, so
        # an agent that finishes on its last allowed iteration is not
        # misreported as having hit the limit.
        print("Max iterations reached")

    return {
        "final_response": final_response,
        "iterations": iteration,
        "error": error,
    }


def _acknowledge_safety_decision(args: Dict[str, Any]) -> Dict[str, Any]:
    """Return extra FunctionResponse fields for a safety decision in *args*.

    If the model attached a ``safety_decision`` whose ``decision`` is
    ``"require_confirmation"``, the check is auto-acknowledged and
    ``{"safety_acknowledgement": "true"}`` is returned so the acknowledgement
    is actually sent back with the function response. Otherwise (including a
    missing or non-dict ``safety_decision``) an empty dict is returned.
    """
    safety_decision = args.get("safety_decision")
    if (
        not isinstance(safety_decision, dict)
        or safety_decision.get("decision") != "require_confirmation"
    ):
        return {}

    print(f"Safety confirmation required: {safety_decision.get('explanation')}")
    print("Auto-acknowledging safety check")
    return {"safety_acknowledgement": "true"}


def _build_function_response(
    name: str, result: Any, extra_fields: Dict[str, Any]
) -> Part:
    """Wrap one ``ComputerTool.execute_action`` result as a FunctionResponse part.

    Every response includes the page URL ("about:blank" when the result has
    none), because the Computer Use API requires it, plus *extra_fields*
    (e.g. a safety acknowledgement). A failed action reports its message
    under "error". A successful predefined computer-use function also gets
    its screenshot attached as an inline PNG part.
    """
    url = result.url or "about:blank"

    if result.error:
        print(f"Action error: {result.error}")
        return Part(
            function_response=FunctionResponse(
                name=name,
                response={"error": result.error, "url": url, **extra_fields},
            )
        )

    # Include screenshot for predefined functions
    parts = None
    if result.base64_image and _is_predefined_function(name):
        # NOTE(review): FunctionResponseBlob.data is a bytes field; confirm
        # that passing base64_image here is not double-encoded by the SDK.
        parts = [
            types.FunctionResponsePart(
                inline_data=types.FunctionResponseBlob(
                    mime_type="image/png",
                    data=result.base64_image,
                )
            )
        ]

    return Part(
        function_response=FunctionResponse(
            name=name,
            response={"url": url, **extra_fields},
            parts=parts,
        )
    )
def _extract_text(content: Content) -> str:
    """Return the text of every part in *content* that has non-empty text.

    The pieces are joined with single spaces. Returns "" when the content
    has no parts or none of them carries text.
    """
    part_texts = (getattr(part, "text", None) for part in content.parts or [])
    return " ".join(text for text in part_texts if text)
def _extract_function_calls(content: Content) -> List[types.FunctionCall]:
    """Collect the function calls carried by *content*'s parts, in order.

    Parts without a (truthy) ``function_call`` are skipped. Returns an empty
    list when the content has no parts.
    """
    candidates = (getattr(part, "function_call", None) for part in content.parts or [])
    return [call for call in candidates if call]
def _is_predefined_function(name: str) -> bool:
    """Return True if *name* equals the ``value`` of some entry in PREDEFINED_COMPUTER_USE_FUNCTIONS."""
    return any(action.value == name for action in PREDEFINED_COMPUTER_USE_FUNCTIONS)
def _prune_old_screenshots(contents: List[Content]) -> None:
    """Drop screenshots from all but the newest MAX_RECENT_TURN_WITH_SCREENSHOTS turns.

    Walks the history from newest to oldest, looking only at user turns that
    have parts. A turn counts as "having a screenshot" when at least one of
    its function responses comes from a predefined computer-use function and
    has a non-empty ``parts`` list. For every such turn beyond the first
    MAX_RECENT_TURN_WITH_SCREENSHOTS, ``parts`` is set to None on all of its
    predefined-function responses. *contents* is modified in place.
    """
    screenshot_turns_seen = 0

    for turn in reversed(contents):
        if turn.role != "user" or not turn.parts:
            continue

        # Function responses in this turn that belong to predefined actions.
        predefined_responses = [
            part.function_response
            for part in turn.parts
            if getattr(part, "function_response", None)
            and _is_predefined_function(part.function_response.name or "")
        ]

        # Screenshots travel in a function response's `parts` list.
        if not any(getattr(resp, "parts", None) for resp in predefined_responses):
            continue

        screenshot_turns_seen += 1
        if screenshot_turns_seen > MAX_RECENT_TURN_WITH_SCREENSHOTS:
            for resp in predefined_responses:
                resp.parts = None