Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions crates/agent-gui/src-tauri/src/commands/workspace/fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2356,6 +2356,61 @@ pub async fn fs_read_editable_text(
.await
}

#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PathStatusResponse {
pub path: String,
pub exists: bool,
pub kind: Option<String>,
pub size_bytes: Option<u64>,
pub mtime_ms: Option<u64>,
}

pub(crate) fn fs_path_status_sync(
workdir: String,
path: String,
) -> Result<PathStatusResponse, String> {
let wd = canonicalize_workdir(&workdir).map_err(|e| e.to_string())?;
let rel = sanitize_rel_path(&path).map_err(|e| e.to_string())?;
let logical_path = logical_rel_path(&rel);
let target = wd.join(&rel);

match fs::symlink_metadata(&target) {
Ok(meta) => {
let file_type = meta.file_type();
let kind = if file_type.is_symlink() {
"symlink"
} else if meta.is_file() {
"file"
} else if meta.is_dir() {
"dir"
} else {
"other"
};
Ok(PathStatusResponse {
path: logical_path,
exists: true,
kind: Some(kind.to_string()),
size_bytes: Some(meta.len()),
mtime_ms: Some(metadata_mtime_ms(&meta)),
})
}
Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(PathStatusResponse {
path: logical_path,
exists: false,
kind: None,
size_bytes: None,
mtime_ms: None,
}),
Err(err) => Err(FsError::Io(err).to_string()),
}
}

#[tauri::command(rename_all = "snake_case")]
pub async fn fs_path_status(workdir: String, path: String) -> Result<PathStatusResponse, String> {
run_blocking("fs_path_status", move || fs_path_status_sync(workdir, path)).await
}

#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct WriteTextResponse {
Expand Down Expand Up @@ -4229,6 +4284,36 @@ mod tests {
let _ = fs::remove_dir_all(workdir);
}

#[test]
fn path_status_reports_existing_file_directory_and_missing_path() {
let workdir = unique_test_workdir("path-status");
fs::create_dir_all(workdir.join("src")).expect("create workdir");
fs::write(workdir.join("src/main.rs"), "fn main() {}\n").expect("write file");

let file = fs_path_status_sync(workdir.display().to_string(), "src/main.rs".to_string())
.expect("file status should succeed");
assert_eq!(file.path, "src/main.rs");
assert!(file.exists);
assert_eq!(file.kind.as_deref(), Some("file"));
assert_eq!(file.size_bytes, Some("fn main() {}\n".len() as u64));
assert!(file.mtime_ms.unwrap_or_default() > 0);

let dir = fs_path_status_sync(workdir.display().to_string(), "src".to_string())
.expect("dir status should succeed");
assert_eq!(dir.kind.as_deref(), Some("dir"));
assert!(dir.exists);

let missing = fs_path_status_sync(workdir.display().to_string(), "new.html".to_string())
.expect("missing status should succeed");
assert_eq!(missing.path, "new.html");
assert!(!missing.exists);
assert!(missing.kind.is_none());
assert!(missing.size_bytes.is_none());
assert!(missing.mtime_ms.is_none());

let _ = fs::remove_dir_all(workdir);
}

#[test]
fn read_editable_text_rejects_invalid_targets_and_non_utf8() {
let workdir = unique_test_workdir("read-editable-invalid");
Expand Down
1 change: 1 addition & 0 deletions crates/agent-gui/src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ macro_rules! app_invoke_handler {
// File system
commands::fs::fs_read_text,
commands::fs::fs_read_editable_text,
commands::fs::fs_path_status,
commands::fs::fs_read_image_source,
commands::fs::fs_read_workspace_image,
commands::fs::fs_write_text,
Expand Down
149 changes: 145 additions & 4 deletions crates/agent-gui/src/lib/chat/runner/agentRunner.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Agent, type AgentTool } from "@earendil-works/pi-agent-core";
import type {
AssistantMessage,
AssistantMessageEvent,
Context,
Message,
ToolCall,
Expand Down Expand Up @@ -640,6 +641,10 @@ export async function runAssistantWithTools(params: {
signal?: AbortSignal,
context?: ToolExecutionEventContext,
) => Promise<Message>;
preflightToolCall?: (
toolCall: ToolCall,
signal?: AbortSignal,
) => Promise<{ toolCall?: ToolCall; toolResult: ToolResultMessage } | null>;
onTurnStart?: (round: number) => void;
onTextDelta: (delta: string, round: number) => void;
onThinkingDelta?: (delta: string, round: number) => void;
Expand Down Expand Up @@ -712,6 +717,7 @@ export async function runAssistantWithTools(params: {

const toolResultErrorFlags = new Map<string, boolean>();
const toolCallsById = new Map<string, ToolCall>();
const streamPreflightToolResults = new Map<string, ToolResultMessage>();
const parallelBatchKeyByToolCallId = new Map<string, string>();
const parallelToolBatches = new Map<string, ParallelToolBatch>();
const llmTools = params.tools ?? [];
Expand Down Expand Up @@ -970,7 +976,9 @@ export async function runAssistantWithTools(params: {
}

function replaceAgentStateMessage(target: Message, replacement: Message) {
const stateMessages = getAgentMessages(agent);
const currentAgent = agent;
if (!currentAgent) return false;
const stateMessages = getAgentMessages(currentAgent);
let targetIndex = stateMessages.lastIndexOf(target);
if (targetIndex < 0) {
for (let index = stateMessages.length - 1; index >= 0; index -= 1) {
Expand All @@ -989,7 +997,7 @@ export async function runAssistantWithTools(params: {
}
}
if (targetIndex < 0) return false;
agent!.state.messages = [
currentAgent.state.messages = [
...stateMessages.slice(0, targetIndex),
replacement,
...stateMessages.slice(targetIndex + 1),
Expand Down Expand Up @@ -1079,6 +1087,16 @@ export async function runAssistantWithTools(params: {
});
toolCallsById.set(toolCall.id, toolCall);

const preflightToolResult = streamPreflightToolResults.get(toolCall.id);
if (preflightToolResult) {
streamPreflightToolResults.delete(toolCall.id);
toolResultErrorFlags.set(toolCall.id, Boolean(preflightToolResult.isError));
return {
content: preflightToolResult.content,
details: preflightToolResult.details ?? {},
};
}

if (tool.name === "Bash" || tool.name === "Agent") {
const batchKey = parallelBatchKeyByToolCallId.get(toolCallId);
if (batchKey) {
Expand Down Expand Up @@ -1126,6 +1144,118 @@ export async function runAssistantWithTools(params: {
agentTools = [...visibleAgentTools, ...hiddenProviderNativeWebSearchAgentTools];

let streamRound = 0;
function getToolCallFromStreamEvent(event: AssistantMessageEvent) {
if (
event.type !== "toolcall_start" &&
event.type !== "toolcall_delta" &&
event.type !== "toolcall_end"
) {
return null;
}

const toolCall =
event.type === "toolcall_end" ? event.toolCall : event.partial.content[event.contentIndex];
return toolCall?.type === "toolCall"
? {
contentIndex: event.contentIndex,
toolCall,
partial: event.partial,
}
: null;
}

function buildPreflightToolUseAssistant(
partial: AssistantMessage,
contentIndex: number,
toolCall: ToolCall,
): AssistantMessage {
const content = partial.content.slice();
content[contentIndex] = toolCall;
return {
...partial,
content,
stopReason: "toolUse",
errorMessage: undefined,
};
}

function wrapStreamWithToolPreflight(
source: ReturnType<typeof streamSimpleByApi>,
signal: AbortSignal | undefined,
abortEarly: () => void,
cleanup: () => void,
): ReturnType<typeof streamSimpleByApi> {
let preflightFinalMessage: AssistantMessage | null = null;

return {
async *[Symbol.asyncIterator]() {
const iterator = source[Symbol.asyncIterator]();
try {
while (true) {
const next = await iterator.next();
if (next.done) return;

const event = next.value;
const candidate = getToolCallFromStreamEvent(event);
const effectiveToolCall = candidate
? normalizeToolCallNameForExecution(candidate.toolCall)
: null;
if (!candidate || !effectiveToolCall || !params.preflightToolCall) {
yield event;
continue;
}

const preflight = await params.preflightToolCall(effectiveToolCall, signal);

if (!preflight) {
yield event;
continue;
}

const completedToolCall = normalizeToolCallNameForExecution(
preflight.toolCall ?? effectiveToolCall,
);
toolCallsById.set(completedToolCall.id, completedToolCall);
streamPreflightToolResults.set(completedToolCall.id, {
...preflight.toolResult,
toolCallId: completedToolCall.id,
toolName: completedToolCall.name,
});
preflightFinalMessage = buildPreflightToolUseAssistant(
candidate.partial,
candidate.contentIndex,
completedToolCall,
);

abortEarly();
await iterator.return?.();

yield event;
if (event.type !== "toolcall_end") {
yield {
type: "toolcall_end",
contentIndex: candidate.contentIndex,
toolCall: completedToolCall,
partial: preflightFinalMessage,
};
}
yield {
type: "done",
reason: "toolUse",
message: preflightFinalMessage,
};
return;
}
} finally {
cleanup();
}
},
result() {
return preflightFinalMessage ? Promise.resolve(preflightFinalMessage) : source.result();
},
} as unknown as ReturnType<typeof streamSimpleByApi>;
}

const streamFn = (streamModel: typeof model, streamContext: Context, options?: any) => {
const round = ++streamRound;
const streamTools =
Expand All @@ -1149,6 +1279,11 @@ export async function runAssistantWithTools(params: {
const hostedSearchProbeId = shouldProbeHostedSearch
? createHostedSearchProbeId(params.providerId)
: undefined;
const earlyPreflightAbortController = new AbortController();
const streamAbortSignal = createLinkedAbortSignal([
options?.signal,
earlyPreflightAbortController.signal,
]);
let streamOptions: StreamOptionsEx = {
...(options ?? {}),
apiKey: options?.apiKey ?? params.runtime.apiKey,
Expand All @@ -1159,7 +1294,7 @@ export async function runAssistantWithTools(params: {
},
hostedSearchProbeId,
),
signal: options?.signal,
signal: streamAbortSignal.signal,
sessionId: options?.sessionId ?? params.sessionId,
cacheRetention:
options?.cacheRetention ??
Expand Down Expand Up @@ -1220,7 +1355,13 @@ export async function runAssistantWithTools(params: {
}),
);

return streamSimpleByApi(streamModel, effectiveContext, streamOptions);
const sourceStream = streamSimpleByApi(streamModel, effectiveContext, streamOptions);
return wrapStreamWithToolPreflight(
sourceStream,
options?.signal,
() => earlyPreflightAbortController.abort(),
streamAbortSignal.cleanup,
);
};

agent = new Agent({
Expand Down
18 changes: 18 additions & 0 deletions crates/agent-gui/src/lib/tools/builtinRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import type {
BuiltinToolBundle,
BuiltinToolExecutionContext,
BuiltinToolMetadata,
BuiltinToolPreflightResult,
} from "./builtinTypes";
import { createCronTools } from "./cronTools";
import { createCustomSystemTools } from "./customSystemTools";
Expand All @@ -46,6 +47,10 @@ export type BuiltinToolRegistry = {
signal?: AbortSignal,
context?: BuiltinToolExecutionContext,
) => Promise<ToolResultMessage>;
preflightToolCall: (
toolCall: ToolCall,
signal?: AbortSignal,
) => Promise<BuiltinToolPreflightResult | null>;
metadataByName: Map<string, BuiltinToolMetadata>;
hasTool: (toolName: string) => boolean;
};
Expand Down Expand Up @@ -105,6 +110,7 @@ function createBuiltinToolRegistry(bundles: BuiltinToolBundle[]): BuiltinToolReg
const tools: BuiltinToolBundle["tools"] = [];
const metadataByName = new Map<string, BuiltinToolMetadata>();
const executorsByName = new Map<string, BuiltinToolBundle["executeToolCall"]>();
const preflightsByName = new Map<string, NonNullable<BuiltinToolBundle["preflightToolCall"]>>();
const canonicalToolNameByLookupKey = new Map<string, string | null>();

const registerCanonicalToolName = (toolName: string) => {
Expand All @@ -131,6 +137,9 @@ function createBuiltinToolRegistry(bundles: BuiltinToolBundle[]): BuiltinToolReg
}
tools.push(tool);
executorsByName.set(tool.name, bundle.executeToolCall);
if (bundle.preflightToolCall) {
preflightsByName.set(tool.name, bundle.preflightToolCall);
}
registerCanonicalToolName(tool.name);
const metadata = bundle.metadataByName.get(tool.name);
if (metadata) {
Expand Down Expand Up @@ -172,6 +181,15 @@ function createBuiltinToolRegistry(bundles: BuiltinToolBundle[]): BuiltinToolReg
resolvedToolName === toolCall.name ? toolCall : { ...toolCall, name: resolvedToolName };
return execute(effectiveToolCall, signal, context);
},
async preflightToolCall(toolCall, signal) {
const resolvedToolName = resolveToolName(toolCall.name);
if (!resolvedToolName) return null;
const preflight = preflightsByName.get(resolvedToolName);
if (!preflight) return null;
const effectiveToolCall =
resolvedToolName === toolCall.name ? toolCall : { ...toolCall, name: resolvedToolName };
return preflight(effectiveToolCall, signal);
},
};
}

Expand Down
Loading
Loading