From 521582ecabb7a20707a41ff74904ef4212b9bee0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Wed, 11 Feb 2026 11:05:50 +0800 Subject: [PATCH] refactor: improve code structure and maintainability --- .gitignore | 1 + crates/coco-tui/src/components/chat.rs | 15 +- crates/coco-tui/src/components/chat/state.rs | 5 +- src/agent.rs | 90 +++--- src/agent/prompt.rs | 68 ++-- src/combo/runner.rs | 228 +++++-------- src/combo/starter.rs | 251 ++++++--------- src/config.rs | 3 +- src/config/env.rs | 9 +- src/config/presets.rs | 16 +- src/logging.rs | 27 +- src/mcp.rs | 111 ++++--- src/provider.rs | 318 ++++++++++--------- src/tools.rs | 37 ++- src/tools/list.rs | 9 +- src/tools/read.rs | 10 +- src/tools/str_replace.rs | 11 +- 17 files changed, 592 insertions(+), 617 deletions(-) diff --git a/.gitignore b/.gitignore index 1a4595e2..311ae1dd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target *.log +.context/ diff --git a/crates/coco-tui/src/components/chat.rs b/crates/coco-tui/src/components/chat.rs index 37d7a2c9..8988ab06 100644 --- a/crates/coco-tui/src/components/chat.rs +++ b/crates/coco-tui/src/components/chat.rs @@ -197,14 +197,15 @@ impl Default for Inner { enum ChatState { #[default] Ready, - Procesing, + #[serde(alias = "Procesing")] + Processing, } impl std::fmt::Display for ChatState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Ready => f.write_str("Ready"), - Self::Procesing => f.write_str("Procesing"), + Self::Processing => f.write_str("Processing"), } } } @@ -428,7 +429,7 @@ impl Chat<'static> { if state.focus == Focus::ShortcutHints { state.focus = self.prev_focus.clone().unwrap_or_default(); } - if state.state == ChatState::Procesing { + if state.state == ChatState::Processing { // Persist Ready to avoid restoring a stale processing state. 
state.state = ChatState::Ready; } @@ -1230,8 +1231,8 @@ impl Chat<'static> { } fn set_processing(&mut self) { - if self.state.state != ChatState::Procesing { - self.state.write().state = ChatState::Procesing; + if self.state.state != ChatState::Processing { + self.state.write().state = ChatState::Processing; } } @@ -1645,7 +1646,7 @@ impl Chat<'static> { ChatState::Ready => { Line::from(Span::styled(format!(" {state} "), theme.ui.status_ready)) } - ChatState::Procesing => Line::from(vec![ + ChatState::Processing => Line::from(vec![ Span::raw(" "), Throbber::default() .throbber_set(BRAILLE_EIGHT_DOUBLE) @@ -2468,7 +2469,7 @@ impl Component for Chat<'static> { fn on_tick(&mut self) { self.cancellation_guard.on_trick(); - if self.state.state == ChatState::Procesing { + if self.state.state == ChatState::Processing { self.indicator.calc_next(); global::signal_dirty(); } diff --git a/crates/coco-tui/src/components/chat/state.rs b/crates/coco-tui/src/components/chat/state.rs index 14dc924e..04266945 100644 --- a/crates/coco-tui/src/components/chat/state.rs +++ b/crates/coco-tui/src/components/chat/state.rs @@ -18,14 +18,15 @@ const CTRL_C_WINDOW: Duration = Duration::from_secs(2); pub enum ChatState { #[default] Ready, - Procesing, + #[serde(alias = "Procesing")] + Processing, } impl std::fmt::Display for ChatState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Ready => f.write_str("Ready"), - Self::Procesing => f.write_str("Procesing"), + Self::Processing => f.write_str("Processing"), } } } diff --git a/src/agent.rs b/src/agent.rs index b3ba2cb7..e17eca5f 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -343,6 +343,31 @@ impl Agent { *self.messages.lock().await = messages.to_vec(); } + async fn build_chat_response_from_blocks( + &mut self, + blocks: Vec<Block>, + stop_reason: Option<StopReason>, + usage: Option<Usage>, + request_options: &RequestOptions, + ) -> ChatResponse { + let message = if blocks.is_empty() { + Message::assistant(Content::Multiple(Vec::default())) + } else { + let mut msg = Message::assistant(Content::Multiple(blocks)); + if request_options.stringify_nested_tool_inputs { + parse_stringified_tool_inputs_in_message(&mut msg, &self.executor); + } + self.messages.lock().await.push(msg.clone()); + msg + }; + self.mark_thinking_cleanup_pending(stop_reason.as_ref()); + ChatResponse { + message, + stop_reason, + usage, + } + } + pub async fn chat(&mut self, message: Message) -> Result<ChatResponse> { let request_options = self.request_options_for_current_model(); let (_, client) = self.pick_provider()?; @@ -370,24 +395,15 @@ impl Agent { }) .whatever_context_display("failed to send messages")?; - let stop_reason = response.stop_reason.clone(); - let usage = response.usage.clone(); - let message = if response.content.is_empty() { - Message::assistant(Content::Multiple(Vec::default())) - } else { - let mut msg = Message::assistant(Content::Multiple(response.content)); - if request_options.stringify_nested_tool_inputs { - parse_stringified_tool_inputs_in_message(&mut msg, &self.executor); - } - self.messages.lock().await.push(msg.clone()); - msg - }; - self.mark_thinking_cleanup_pending(stop_reason.as_ref()); - Ok(ChatResponse { - message, + let crate::provider::MessagesResponse { + content, stop_reason, usage, - }) + ..
+ } = response; + Ok(self + .build_chat_response_from_blocks(content, stop_reason, usage, &request_options) + .await) } pub async fn chat_with_history(&mut self) -> Result<ChatResponse> { @@ -412,24 +428,15 @@ impl Agent { }) .whatever_context_display("failed to send messages")?; - let stop_reason = response.stop_reason.clone(); - let usage = response.usage.clone(); - let message = if response.content.is_empty() { - Message::assistant(Content::Multiple(Vec::default())) - } else { - let mut msg = Message::assistant(Content::Multiple(response.content)); - if request_options.stringify_nested_tool_inputs { - parse_stringified_tool_inputs_in_message(&mut msg, &self.executor); - } - self.messages.lock().await.push(msg.clone()); - msg - }; - self.mark_thinking_cleanup_pending(stop_reason.as_ref()); - Ok(ChatResponse { - message, + let crate::provider::MessagesResponse { + content, stop_reason, usage, - }) + .. + } = response; + Ok(self + .build_chat_response_from_blocks(content, stop_reason, usage, &request_options) + .await) } pub async fn chat_stream( @@ -574,25 +581,12 @@ impl Agent { } let (blocks, stop_reason, usage) = accumulator.finish(); - let message = if blocks.is_empty() { - Message::assistant(Content::Multiple(Vec::default())) - } else { - let mut msg = Message::assistant(Content::Multiple(blocks)); - if request_options.stringify_nested_tool_inputs { - parse_stringified_tool_inputs_in_message(&mut msg, &self.executor); - } - self.messages.lock().await.push(msg.clone()); - msg - }; - self.mark_thinking_cleanup_pending(stop_reason.as_ref()); if retried { notify_stream_retry_finished(request_options, true); } - return Ok(ChatResponse { - message, - stop_reason, - usage, - }); + return Ok(self + .build_chat_response_from_blocks(blocks, stop_reason, usage, request_options) + .await); } } diff --git a/src/agent/prompt.rs b/src/agent/prompt.rs index da1bdc30..7d1eae51 100644 --- a/src/agent/prompt.rs +++ b/src/agent/prompt.rs @@ -170,6 +170,34 @@ pub fn substitute_template(template: &str, args: &HashMap<String, String>) -> St result } +fn append_prompt_content(current_prompt: &mut String, content: &str) { + if content.trim().is_empty() { + return; + } + if !current_prompt.trim().is_empty() { + current_prompt.push_str("\n\n"); + } + current_prompt.push_str(content); +} + +fn append_prompt_file_if_exists(current_prompt: &mut String, path: &Path) { + if !path.exists() { + return; + } + if let Ok(content) = std::fs::read_to_string(path) { + append_prompt_content(current_prompt, &content); + } +} + +async fn append_prompt_file_if_exists_async(current_prompt: &mut String, path: &Path) { + if !path.exists() { + return; + } + if let Ok(content) = tokio::fs::read_to_string(path).await { + append_prompt_content(current_prompt, &content); + } +} + /// Build system prompt from SystemPromptConfig. /// /// Combines agent.toml system_prompt with AGENTS.md files from multiple layers: @@ -201,27 +229,11 @@ pub fn build_system_prompt_from_config( // 2. Append: global AGENTS.md let global_agents_md = config_dir.join("AGENTS.md"); - if global_agents_md.exists() - && let Ok(content) = std::fs::read_to_string(&global_agents_md) - && !content.trim().is_empty() - { - if !current_prompt.trim().is_empty() { - current_prompt.push_str("\n\n"); - } - current_prompt.push_str(&content); - } + append_prompt_file_if_exists(&mut current_prompt, &global_agents_md); // 3.
Append: workspace AGENTS.md let workspace_agents_md = workspace_dir.join("AGENTS.md"); - if workspace_agents_md.exists() - && let Ok(content) = std::fs::read_to_string(&workspace_agents_md) - && !content.trim().is_empty() - { - if !current_prompt.trim().is_empty() { - current_prompt.push_str("\n\n"); - } - current_prompt.push_str(&content); - } + append_prompt_file_if_exists(&mut current_prompt, &workspace_agents_md); current_prompt } @@ -259,27 +271,11 @@ pub async fn build_system_prompt_from_config_async( // 2. Append: global AGENTS.md let global_agents_md = config_dir.join("AGENTS.md"); - if global_agents_md.exists() - && let Ok(content) = tokio::fs::read_to_string(&global_agents_md).await - && !content.trim().is_empty() - { - if !current_prompt.trim().is_empty() { - current_prompt.push_str("\n\n"); - } - current_prompt.push_str(&content); - } + append_prompt_file_if_exists_async(&mut current_prompt, &global_agents_md).await; // 3. Append: workspace AGENTS.md let workspace_agents_md = workspace_dir.join("AGENTS.md"); - if workspace_agents_md.exists() - && let Ok(content) = tokio::fs::read_to_string(&workspace_agents_md).await - && !content.trim().is_empty() - { - if !current_prompt.trim().is_empty() { - current_prompt.push_str("\n\n"); - } - current_prompt.push_str(&content); - } + append_prompt_file_if_exists_async(&mut current_prompt, &workspace_agents_md).await; current_prompt } diff --git a/src/combo/runner.rs b/src/combo/runner.rs index dd46ceef..e7254f14 100644 --- a/src/combo/runner.rs +++ b/src/combo/runner.rs @@ -30,9 +30,9 @@ use crate::tools::{ pub type ExecuteResult = Result; use crate::{ - Agent, Block, ChatStreamUpdate, Config, Content, Message, OutputChunk, PromptResponseSender, - PromptSchema, SESSION_SOCKET_ENV, SessionEnv, Starter, StarterCommand, StarterError, - StarterEvent, ToolUse, bash_unsafe_ranges, bash_unsafe_reason, discover_starters, + Agent, Block, ChatResponse, ChatStreamUpdate, Config, Content, Message, OutputChunk, + PromptResponseSender, PromptSchema, SESSION_SOCKET_ENV, SessionEnv, Starter, StarterCommand, + StarterError, StarterEvent, ToolUse, bash_unsafe_ranges, bash_unsafe_reason, discover_starters, exec::StreamKind, parse_primary_command, workspace_dir, }; @@ -493,46 +493,24 @@ async fn execute_combo( thinking.clone(), cancel_token.clone(), move |update| { - match update { - ChatStreamUpdate::Reset => { - thinking_seen_stream.store( - false, - std::sync::atomic::Ordering::Relaxed, - ); - emit_combo_event( - &on_event_stream, - ComboEvent::PromptStreamReset { - name: stream_name.clone(), - }, - ); - } - ChatStreamUpdate::Plain { index, text } => { - emit_combo_event( - &on_event_stream, - ComboEvent::PromptStream { - name: stream_name.clone(), - index, - kind: ComboEventStreamKind::Plain, - text, - }, - ); - } - ChatStreamUpdate::Thinking { index, text } => { - thinking_seen_stream.store( - true, - std::sync::atomic::Ordering::Relaxed, - ); - emit_combo_event( - &on_event_stream, - ComboEvent::PromptStream { - name: stream_name.clone(), - index, - kind: ComboEventStreamKind::Thinking, - text, - }, - ); - } + if matches!(&update, ChatStreamUpdate::Thinking { .. 
}) + { + thinking_seen_stream.store( + true, + std::sync::atomic::Ordering::Relaxed, + ); } + if matches!(&update, ChatStreamUpdate::Reset) { + thinking_seen_stream.store( + false, + std::sync::atomic::Ordering::Relaxed, + ); + } + emit_prompt_stream_update( + &on_event_stream, + &stream_name, + update, + ); }, ) .await @@ -737,6 +715,69 @@ fn emit_combo_event(on_event: &ComboEventCallback, event: ComboEvent) { } } +fn emit_prompt_stream_update( + on_event: &ComboEventCallback, + combo_name: &str, + update: ChatStreamUpdate, +) { + match update { + ChatStreamUpdate::Reset => emit_combo_event( + on_event, + ComboEvent::PromptStreamReset { + name: combo_name.to_string(), + }, + ), + ChatStreamUpdate::Plain { index, text } => emit_combo_event( + on_event, + ComboEvent::PromptStream { + name: combo_name.to_string(), + index, + kind: ComboEventStreamKind::Plain, + text, + }, + ), + ChatStreamUpdate::Thinking { index, text } => emit_combo_event( + on_event, + ComboEvent::PromptStream { + name: combo_name.to_string(), + index, + kind: ComboEventStreamKind::Thinking, + text, + }, + ), + } +} + +async fn chat_with_history_for_combo( + agent: &mut Agent, + combo_name: &str, + cancel_token: CancellationToken, + on_event: &ComboEventCallback, +) -> Result<(ChatResponse, bool), ComboReplyError> { + if agent.disable_stream_for_current_model() { + let response = + agent + .chat_with_history() + .await + .map_err(|e| ComboReplyError::ChatFailed { + message: e.to_string(), + })?; + return Ok((response, false)); + } + + let on_event_stream = on_event.clone(); + let stream_name = combo_name.to_string(); + let response = agent + .chat_stream_with_history(cancel_token, move |update| { + emit_prompt_stream_update(&on_event_stream, &stream_name, update); + }) + .await + .map_err(|e| ComboReplyError::ChatFailed { + message: e.to_string(), + })?; + Ok((response, true)) +} + async fn emit_combo_transcript(on_event: &ComboEventCallback, name: &str, agent: &Agent) { let messages = agent.dump_messages().await; if messages.is_empty() { @@ -1200,62 +1241,15 @@ async fn handle_interactive_combo_reply( if cancel_token.is_cancelled() { return Err(ComboReplyError::Cancelled); } - let disable_stream = agent.disable_stream_for_current_model(); - let response = if disable_stream { - agent - .chat_with_history() - .await - .map_err(|e| ComboReplyError::ChatFailed { - message: e.to_string(), - })? - } else { - let on_event_stream = on_event.clone(); - let stream_name = combo_name.to_string(); - agent - .chat_stream_with_history(cancel_token.clone(), move |update| match update { - ChatStreamUpdate::Reset => { - emit_combo_event( - &on_event_stream, - ComboEvent::PromptStreamReset { - name: stream_name.clone(), - }, - ); - } - ChatStreamUpdate::Plain { index, text } => { - emit_combo_event( - &on_event_stream, - ComboEvent::PromptStream { - name: stream_name.clone(), - index, - kind: ComboEventStreamKind::Plain, - text, - }, - ); - } - ChatStreamUpdate::Thinking { index, text } => { - emit_combo_event( - &on_event_stream, - ComboEvent::PromptStream { - name: stream_name.clone(), - index, - kind: ComboEventStreamKind::Thinking, - text, - }, - ); - } - }) - .await - .map_err(|e| ComboReplyError::ChatFailed { - message: e.to_string(), - })? 
- }; + let (response, used_stream) = + chat_with_history_for_combo(agent, combo_name, cancel_token.clone(), on_event).await?; let tool_uses = extract_tool_uses(&response.message); if tool_uses.is_empty() { empty_tool_use_turns += 1; let reminder = build_interactive_offload_reply_reminder(schemas); let should_pause_for_feedback = empty_tool_use_turns > AUTO_IMPLICIT_NUDGE_LIMIT; - if disable_stream { + if !used_stream { let response_text = extract_text_response(&response.message); if should_pause_for_feedback && !response_text.trim().is_empty() { emit_combo_event( @@ -1683,54 +1677,8 @@ async fn handle_offload_combo_reply( .append_message(Message::user(Content::Text(directive.to_string()))) .await; - let chat_response = if agent.disable_stream_for_current_model() { - agent - .chat_with_history() - .await - .map_err(|e| ComboReplyError::ChatFailed { - message: e.to_string(), - })? - } else { - let on_event_stream = on_event.clone(); - let stream_name = combo_name.to_string(); - agent - .chat_stream_with_history(cancel_token.clone(), move |update| match update { - ChatStreamUpdate::Reset => { - emit_combo_event( - &on_event_stream, - ComboEvent::PromptStreamReset { - name: stream_name.clone(), - }, - ); - } - ChatStreamUpdate::Plain { index, text } => { - emit_combo_event( - &on_event_stream, - ComboEvent::PromptStream { - name: stream_name.clone(), - index, - kind: ComboEventStreamKind::Plain, - text, - }, - ); - } - ChatStreamUpdate::Thinking { index, text } => { - emit_combo_event( - &on_event_stream, - ComboEvent::PromptStream { - name: stream_name.clone(), - index, - kind: ComboEventStreamKind::Thinking, - text, - }, - ); - } - }) - .await - .map_err(|e| ComboReplyError::ChatFailed { - message: e.to_string(), - })? - }; + let (chat_response, _) = + chat_with_history_for_combo(agent, combo_name, cancel_token.clone(), on_event).await?; let blocks = match &chat_response.message.content { Content::Multiple(blocks) => blocks.as_slice(), diff --git a/src/combo/starter.rs b/src/combo/starter.rs index 664270f7..98386a9e 100644 --- a/src/combo/starter.rs +++ b/src/combo/starter.rs @@ -372,12 +372,9 @@ fn build_combo_from_session( session_state: Option, ) -> Result<Combo, StarterError> { if let Some(state) = session_state { - let metadata = state.metadata.ok_or_else(|| { - InvalidSnafu { - reason: "metadata not received from session".to_string(), - } - .build() - })?; + let metadata = state + .metadata + .ok_or_else(|| invalid_error("metadata not received from session"))?; return Ok(Combo { metadata: ComboMetadata { name: metadata.name, @@ -389,6 +386,38 @@ fn build_combo_from_session( Ok(parse_combo(command)) } +fn invalid_error(reason: impl Into<String>) -> StarterError { + InvalidSnafu { + reason: reason.into(), + } + .build() +} + +fn reply_validation_error(error: impl Into<String>) -> ReplyValidation { + ReplyValidation { + success: false, + error: Some(error.into()), + response: None, + } +} + +fn reply_validation_success(response: String) -> ReplyValidation { + ReplyValidation { + success: true, + error: None, + response: Some(response), + } +} + +async fn send_reply_validation_best_effort( + conn: &mut ServerConnection, + validation: ReplyValidation, +) { + let _ = conn + .send_server_message(&ServerMessage::ReplyValidation(validation)) + .await; +} + /// Helper macro to send a failed event and return early with the error. /// Used within the execute_command async block where `event_tx` and `command` are in scope. macro_rules!
fail_early { @@ -437,10 +466,7 @@ fn execute_command( Err(error) => fail_early!(event_tx, command, error), }, None => { - let error = InvalidSnafu { - reason: "session env is required for starter execution".to_string(), - } - .build(); + let error = invalid_error("session env is required for starter execution"); fail_early!(event_tx, command, error); } }; @@ -460,10 +486,7 @@ fn execute_command( Err(err) => { let error = match err.kind() { ErrorKind::PermissionDenied => NotExcutableSnafu.build(), - _ => InvalidSnafu { - reason: format!("excuting error: {err}"), - } - .build(), + _ => invalid_error(format!("executing error: {err}")), }; fail_early!(event_tx, command, error); } } @@ -531,13 +554,10 @@ fn execute_command( if let Some(task) = session_server.take() { task.abort(); } - Err(InvalidSnafu { - reason: format!( - "starter exited with status {:?} during discovery", - status.code() - ), - } - .build()) + Err(invalid_error(format!( + "starter exited with status {:?} during discovery", + status.code() + ))) } else { if discovery { if let Some(task) = session_server.take() { @@ -555,10 +575,8 @@ fn execute_command( Ok(Ok(state)) => Some(state), Ok(Err(err)) => fail_early!(event_tx, command, err), Err(err) => { - let error = InvalidSnafu { - reason: format!("session server join error: {err}"), - } - .build(); + let error = + invalid_error(format!("session server join error: {err}")); fail_early!(event_tx, command, error); } } } } Err(err) => { - let error = InvalidSnafu { - reason: format!("excuting error: {err}"), - } - .build(); + let error = invalid_error(format!("executing error: {err}")); if let Some(task) = session_server.take() { task.abort(); } @@ -699,10 +714,9 @@ async fn handle_session_connection( { break; } - return Err(InvalidSnafu { - reason: format!("failed to read session message: {err}"), - } - .build()); + return Err(invalid_error(format!( + "failed to read session message: {err}" + ))); } }; @@ -710,41 +724,33 @@ async fn handle_session_connection( ClientMessage::Metadata(payload) => { if !first_message { let _ = conn.interrupt().await; - return Err(InvalidSnafu { - reason: "metadata must be the first and only metadata message".to_string(), - } - .build()); + return Err(invalid_error( + "metadata must be the first and only metadata message", + )); } { let mut guard = state.lock().await; if guard.metadata.is_some() { let _ = conn.interrupt().await; - return Err(InvalidSnafu { - reason: "metadata must be the first and only metadata message" - .to_string(), - } - .build()); + return Err(invalid_error( + "metadata must be the first and only metadata message", + )); } guard.metadata = Some(payload); } conn.send_server_message(&ServerMessage::Metadata(MetadataResponse { discovery })) .await .map_err(|err| { - InvalidSnafu { - reason: format!("failed to send metadata response: {err}"), - } - .build() + invalid_error(format!("failed to send metadata response: {err}")) })?; first_message = false; } ClientMessage::RecordStart(payload) => { if discovery { let _ = conn.interrupt().await; - return Err(InvalidSnafu { - reason: "record commands are not allowed in discovery or before metadata" - .to_string(), - } - .build()); + return Err(invalid_error( + "record commands are not allowed in discovery or before metadata", + )); } let (record_index, name) = { let mut guard = state.lock().await; @@ -755,12 +761,9 @@ async fn handle_session_connection( if metadata_name.is_none() { drop(guard); let _ = conn.interrupt().await; - return Err(InvalidSnafu {
reason: - "record commands are not allowed in discovery or before metadata" - .to_string(), - } - .build()); + return Err(invalid_error( + "record commands are not allowed in discovery or before metadata", + )); } let record_index = guard.next_event_index(); (record_index, metadata_name.unwrap()) @@ -793,10 +796,9 @@ async fn handle_session_connection( ClientMessage::RecordChunk(chunk) => { if discovery { let _ = conn.interrupt().await; - return Err(InvalidSnafu { - reason: "record chunk is not allowed during discovery".to_string(), - } - .build()); + return Err(invalid_error( + "record chunk is not allowed during discovery", + )); } let stream = chunk.stream; let lines = chunk.lines; @@ -829,10 +831,7 @@ async fn handle_session_connection( }) => { if discovery { let _ = conn.interrupt().await; - return Err(InvalidSnafu { - reason: "record end is not allowed during discovery".to_string(), - } - .build()); + return Err(invalid_error("record end is not allowed during discovery")); } let Some(mut record) = current_record.take() else { continue; @@ -862,39 +861,26 @@ async fn handle_session_connection( ClientMessage::Prompt(payload) => { let metadata = { state.lock().await.metadata.clone() }; if metadata.is_none() { - return Err(InvalidSnafu { - reason: "prompt is not allowed before metadata".to_string(), - } - .build()); + return Err(invalid_error("prompt is not allowed before metadata")); } if payload.interactive && !payload.reply { - return Err(InvalidSnafu { - reason: "interactive prompt requires reply mode".to_string(), - } - .build()); + return Err(invalid_error("interactive prompt requires reply mode")); } if payload.reply { if discovery { - return Err(InvalidSnafu { - reason: "prompt reply is not allowed during discovery".to_string(), - } - .build()); + return Err(invalid_error( + "prompt reply is not allowed during discovery", + )); } if payload.schemas.is_empty() { - return Err(InvalidSnafu { - reason: "prompt reply requires schemas".to_string(), - } - .build()); + return Err(invalid_error("prompt reply requires schemas")); } let (response_tx, response_rx) = oneshot::channel(); let responder = PromptResponseSender::new(response_tx); { let mut guard = state.lock().await; if guard.pending_reply_schemas.is_some() { - return Err(InvalidSnafu { - reason: "prompt reply already in progress".to_string(), - } - .build()); + return Err(invalid_error("prompt reply already in progress")); } guard.pending_reply_schemas = Some(payload.schemas.clone()); guard.pending_reply_responder = Some(responder.clone()); @@ -912,10 +898,7 @@ async fn handle_session_connection( .is_err() { clear_pending_reply(&state).await; - return Err(InvalidSnafu { - reason: "prompt responder is not available".to_string(), - } - .build()); + return Err(invalid_error("prompt responder is not available")); } { let mut guard = state.lock().await; @@ -925,30 +908,21 @@ async fn handle_session_connection( Ok(response) => response, Err(_) => { clear_pending_reply(&state).await; - return Err(InvalidSnafu { - reason: "prompt responder dropped response".to_string(), - } - .build()); + return Err(invalid_error("prompt responder dropped response")); } }; let response = match response { Ok(response) => response, Err(err) => { clear_pending_reply(&state).await; - return Err(InvalidSnafu { - reason: format!("prompt responder failed: {err}"), - } - .build()); + return Err(invalid_error(format!("prompt responder failed: {err}"))); } }; clear_pending_reply(&state).await; conn.send_server_message(&ServerMessage::PromptResponse(response)) 
.await .map_err(|err| { - InvalidSnafu { - reason: format!("failed to send prompt response: {err}"), - } - .build() + invalid_error(format!("failed to send prompt response: {err}")) })?; } else if !discovery { event_tx @@ -966,35 +940,26 @@ async fn handle_session_connection( } ClientMessage::Reply(payload) => { if discovery { - let _ = conn - .send_server_message(&ServerMessage::ReplyValidation(ReplyValidation { - success: false, - error: Some("reply is not allowed during discovery".to_string()), - response: None, - })) - .await; + send_reply_validation_best_effort( + &mut conn, + reply_validation_error("reply is not allowed during discovery"), + ) + .await; continue; } let schemas = { state.lock().await.pending_reply_schemas.clone() }; let Some(schemas) = schemas else { - let _ = conn - .send_server_message(&ServerMessage::ReplyValidation(ReplyValidation { - success: false, - error: Some("reply is not expected in combo session".to_string()), - response: None, - })) - .await; + send_reply_validation_best_effort( + &mut conn, + reply_validation_error("reply is not expected in combo session"), + ) + .await; continue; }; let parsed = match parse_reply_fields(&payload.fields, &schemas) { Ok(parsed) => parsed, Err(err) => { - let _ = conn - .send_server_message(&ServerMessage::ReplyValidation(ReplyValidation { - success: false, - error: Some(err), - response: None, - })) + send_reply_validation_best_effort(&mut conn, reply_validation_error(err)) .await; continue; } @@ -1002,13 +967,13 @@ async fn handle_session_connection( let response = match serde_json::to_string(&parsed) { Ok(value) => value, Err(err) => { - let _ = conn - .send_server_message(&ServerMessage::ReplyValidation(ReplyValidation { - success: false, - error: Some(format!("failed to serialize reply output: {err}")), - response: None, - })) - .await; + send_reply_validation_best_effort( + &mut conn, + reply_validation_error(format!( + "failed to serialize reply output: {err}" + )), + ) + .await; continue; } }; @@ -1019,31 +984,18 @@ async fn handle_session_connection( let _ = responder.send(Ok(response.clone())); } } - conn.send_server_message(&ServerMessage::ReplyValidation(ReplyValidation { - success: true, - error: None, - response: Some(response), - })) + conn.send_server_message(&ServerMessage::ReplyValidation( + reply_validation_success(response), + )) .await - .map_err(|err| { - InvalidSnafu { - reason: format!("failed to send reply validation: {err}"), - } - .build() - })?; + .map_err(|err| invalid_error(format!("failed to send reply validation: {err}")))?; first_message = false; } ClientMessage::ComboRun(_) => { - return Err(InvalidSnafu { - reason: "combo run is not allowed in combo session".to_string(), - } - .build()); + return Err(invalid_error("combo run is not allowed in combo session")); } ClientMessage::Mcp(_) => { - return Err(InvalidSnafu { - reason: "mcp request is not allowed in combo session".to_string(), - } - .build()); + return Err(invalid_error("mcp request is not allowed in combo session")); } } } @@ -1106,10 +1058,7 @@ async fn run_session_server( let state = state.lock().await.clone(); if state.metadata.is_none() { - return Err(InvalidSnafu { - reason: "metadata not received from session".to_string(), - } - .build()); + return Err(invalid_error("metadata not received from session")); } Ok(state) diff --git a/src/config.rs b/src/config.rs index d279920f..76703324 100644 --- a/src/config.rs +++ b/src/config.rs @@ -102,8 +102,7 @@ impl Config { pub fn request_options_for_model(&self, model: &str) -> 
RequestOptions { let mut options = RequestOptions::default(); - let builtin = presets::builtin_model_presets(); - apply_model_presets(&mut options, &builtin, model); + apply_model_presets(&mut options, presets::builtin_model_presets(), model); apply_model_presets(&mut options, &self.model_presets, model); options } diff --git a/src/config/env.rs b/src/config/env.rs index ccfccf37..b0299117 100644 --- a/src/config/env.rs +++ b/src/config/env.rs @@ -15,14 +15,15 @@ impl EnvString { pub fn get(&mut self) -> Result<&str> { match self { EnvString::EnvVar { name, value } => { - if let Some(value) = value { - return Ok(value.as_str()); - } else { + if value.is_none() { let value_from_env = std::env::var(&name) .whatever_context(format!("failed to get {name} from env"))?; *value = Some(value_from_env); } - Ok(value.as_ref().unwrap().as_str()) + match value.as_deref() { + Some(cached) => Ok(cached), + None => unreachable!("env value should be cached after loading"), + } } EnvString::String(value) => Ok(value), } diff --git a/src/config/presets.rs b/src/config/presets.rs index 8be3f626..a4ad0ea7 100644 --- a/src/config/presets.rs +++ b/src/config/presets.rs @@ -10,7 +10,7 @@ struct PresetFile { model_presets: Vec<ModelPreset>, } -pub(crate) fn builtin_model_presets() -> Vec<ModelPreset> { +pub(crate) fn builtin_model_presets() -> &'static [ModelPreset] { static PRESETS: OnceLock<Vec<ModelPreset>> = OnceLock::new(); PRESETS .get_or_init(|| { @@ -19,5 +19,17 @@ pub(crate) fn builtin_model_presets() -> Vec<ModelPreset> { toml::from_str(content).expect("failed to parse builtin presets"); parsed.model_presets }) - .clone() + .as_slice() } + +#[cfg(test)] +mod tests { + use super::builtin_model_presets; + + #[test] + fn builtin_presets_are_cached_once() { + let first = builtin_model_presets(); + let second = builtin_model_presets(); + assert_eq!(first.as_ptr(), second.as_ptr()); + } +} diff --git a/src/logging.rs b/src/logging.rs index e1bc8679..84f7e4ce 100644 --- a/src/logging.rs +++ b/src/logging.rs @@ -12,10 +12,10 @@ fn default_logs_dir() -> PathBuf { PathBuf::from(".coco").join("logs") } -fn sanitize_file_stem(stem: &str) -> String { +pub(crate) fn sanitize_log_stem(stem: &str) -> String { let trimmed = stem.trim(); if trimmed.is_empty() { - return "coco".to_string(); + return String::new(); } let mut out = String::with_capacity(trimmed.len()); @@ -42,7 +42,12 @@ pub fn init_file_logging(log_name: &str) -> Result { fs::create_dir_all(&logs_dir) .whatever_context(format!("failed to create logs dir {}", logs_dir.display()))?; - let file_stem = sanitize_file_stem(log_name); + let file_stem = sanitize_log_stem(log_name); + let file_stem = if file_stem.is_empty() { + "coco".to_string() + } else { + file_stem + }; let file_name = ensure_log_extension(&file_stem); let log_path = logs_dir.join(file_name); @@ -81,3 +86,19 @@ pub fn init_file_logging_best_effort(log_name: &str) -> Option { pub fn logs_dir() -> &'static Path { Path::new(".coco/logs") } + +#[cfg(test)] +mod tests { + use super::sanitize_log_stem; + + #[test] + fn sanitize_log_stem_replaces_unsupported_characters() { + assert_eq!(sanitize_log_stem(" mcp/server:dev "), "mcp_server_dev"); + } + + #[test] + fn sanitize_log_stem_trims_edge_separators() { + assert_eq!(sanitize_log_stem(".__-name-__."), "name"); + assert_eq!(sanitize_log_stem(" "), ""); + } +} diff --git a/src/mcp.rs b/src/mcp.rs index 341d0809..f41b1f8c 100644 --- a/src/mcp.rs +++ b/src/mcp.rs @@ -23,7 +23,7 @@ use crate::{ Config, SessionSocketServer, combo::{ClientMessage, ServerMessage}, config::McpConfig, - logging::logs_dir, +
logging::{logs_dir, sanitize_log_stem}, }; pub const MCP_SOCKET_ENV: &str = "COCO_MCP_SOCK"; @@ -150,6 +150,40 @@ pub enum McpManagerError { } type ManagerResult<T> = std::result::Result<T, McpManagerError>; +#[allow(clippy::result_large_err)] +fn serialize_to_json_value<T: Serialize>(value: T, label: &str) -> ManagerResult<Value> { + serde_json::to_value(value).map_err(|err| McpManagerError::InvalidArguments { + reason: format!("failed to serialize {label}: {err}"), + }) +} + +#[allow(clippy::result_large_err)] +fn parse_call_tool_payload(payload: Option<Value>) -> ManagerResult { + match payload { + Some(value) => serde_json::from_value::(value).map_err(|err| { + McpManagerError::InvalidArguments { + reason: format!("invalid call_tool payload: {err}"), + } + }), + None => Err(McpManagerError::InvalidArguments { + reason: "missing call_tool payload".to_string(), + }), + } +} + +fn require_server_for_action<'a>( + request_id: &str, + action_name: &str, + server: Option<&'a str>, +) -> Result<&'a str, McpResponse> { + server.ok_or_else(|| { + McpResponse::err( + request_id.to_string(), + format!("server is required for {action_name}"), + ) + }) +} + struct McpClientEntry { service: RunningService, peer: Peer, } @@ -218,10 +252,7 @@ impl McpManager { }), ) .await?; - let value = - serde_json::to_value(result).map_err(|err| McpManagerError::InvalidArguments { - reason: format!("failed to serialize tool result: {err}"), - })?; + let value = serialize_to_json_value(result, "tool result")?; Ok(value) } @@ -390,23 +421,6 @@ impl McpManager { } } -fn sanitize_log_stem(stem: &str) -> String { - let trimmed = stem.trim(); - if trimmed.is_empty() { - return String::new(); - } - - let mut out = String::with_capacity(trimmed.len()); - for ch in trimmed.chars() { - match ch { - 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' | '.'
=> out.push(ch), - _ => out.push('_'), - } - } - - out.trim_matches(['.', '-', '_']).to_string() -} - fn tool_to_info(tool: Tool) -> McpToolInfo { let input_schema = Value::Object((*tool.input_schema).clone()); McpToolInfo { @@ -524,46 +538,41 @@ async fn handle_connection( } async fn handle_request(manager: std::sync::Arc<McpManager>, request: McpRequest) -> McpResponse { - let request_id = request.request_id.clone(); - let result = match request.action { + let McpRequest { + request_id, + server, + action, + payload, + timeout_ms, + } = request; + let result = match action { McpAction::ListServers => { let list = manager.list_servers(); - serde_json::to_value(list).map_err(|err| McpManagerError::InvalidArguments { - reason: format!("failed to serialize servers: {err}"), - }) + serialize_to_json_value(list, "servers") } McpAction::ListTools => { - let Some(server) = request.server.as_deref() else { - return McpResponse::err(request_id, "server is required for list_tools"); - }; + let server = + match require_server_for_action(&request_id, "list_tools", server.as_deref()) { + Ok(server) => server, + Err(response) => return response, + }; manager - .list_tools(server, request.timeout_ms) + .list_tools(server, timeout_ms) .await - .and_then(|tools| { - serde_json::to_value(tools).map_err(|err| McpManagerError::InvalidArguments { - reason: format!("failed to serialize tools: {err}"), - }) - }) + .and_then(|tools| serialize_to_json_value(tools, "tools")) } McpAction::CallTool => { - let Some(server) = request.server.as_deref() else { - return McpResponse::err(request_id, "server is required for call_tool"); - }; - let payload = match request.payload { - Some(value) => serde_json::from_value::(value).map_err(|err| { - McpManagerError::InvalidArguments { - reason: format!("invalid call_tool payload: {err}"), - } - }), - None => Err(McpManagerError::InvalidArguments { - reason: "missing call_tool payload".to_string(), - }), - }; + let server = + match require_server_for_action(&request_id, "call_tool", server.as_deref()) { + Ok(server) => server, + Err(response) => return response, + }; + let payload = parse_call_tool_payload(payload); let payload = match payload { Ok(payload) => payload, - Err(err) => return McpResponse::err(request_id, err.to_string()), + Err(err) => return McpResponse::err(request_id.clone(), err.to_string()), }; - manager.call_tool(server, payload, request.timeout_ms).await + manager.call_tool(server, payload, timeout_ms).await } }; diff --git a/src/provider.rs b/src/provider.rs index 6585e171..36909c0b 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -2,7 +2,7 @@ mod anthropic; mod openai; mod types; -use std::{pin::Pin, time::Duration}; +use std::{future::Future, pin::Pin, time::Duration}; use crate::StreamError; use futures_core::Stream; @@ -58,43 +58,43 @@ impl Client { thinking: Option<Thinking>, request_options: &RequestOptions, ) -> Result<MessagesResponse> { - let result: Result<MessagesResponse> = match self { - Client::Anthropic(client) => { - let temperature = request_options.temperature.map(f64::from); - let max_tokens = request_options.max_tokens; - let response = client - .messages() - .maybe_system_prompt(system_prompt) - .conversations( - conversations - .into_iter() - .map(anthropic_api::Message::from) - .collect(), - ) - .tools(tools.into_iter().map(anthropic_api::Tool::from).collect()) - .maybe_thinking(thinking.map(anthropic_api::Thinking::from)) - .maybe_temperature(temperature) - .maybe_max_tokens(max_tokens) - .retry_config(anthropic_retry_config(request_options)) - .call() - .await
.whatever_context_display("failed to send messages")?; - Ok(response.into()) + complete_with_retry_status(request_options, async { + match self { + Client::Anthropic(client) => { + let (temperature, max_tokens) = anthropic_request_params(request_options); + let response = client + .messages() + .maybe_system_prompt(system_prompt) + .conversations( + conversations + .into_iter() + .map(anthropic_api::Message::from) + .collect(), + ) + .tools(tools.into_iter().map(anthropic_api::Tool::from).collect()) + .maybe_thinking(thinking.map(anthropic_api::Thinking::from)) + .maybe_temperature(temperature) + .maybe_max_tokens(max_tokens) + .retry_config(anthropic_retry_config(request_options)) + .call() + .await + .whatever_context_display("failed to send messages")?; + Ok(response.into()) + } + Client::OpenAI(client) => openai::messages( + client, + system_prompt, + conversations, + tools, + None, + thinking, + request_options, + ) + .await + .whatever_context_display("failed to send messages"), } - Client::OpenAI(client) => openai::messages( - client, - system_prompt, - conversations, - tools, - None, - thinking, - request_options, - ) - .await - .whatever_context_display("failed to send messages"), - }; - notify_retry_finished(request_options, result.is_ok()); - result + }) + .await } pub async fn messages_stream( @@ -105,48 +105,46 @@ impl Client { thinking: Option<Thinking>, request_options: &RequestOptions, ) -> Result<MessagesStream> { - let result: Result<MessagesStream> = match self { - Client::Anthropic(client) => { - let temperature = request_options.temperature.map(f64::from); - let max_tokens = request_options.max_tokens; - let stream = client - .messages_stream() - .maybe_system_prompt(system_prompt) - .conversations( - conversations - .into_iter() - .map(anthropic_api::Message::from) - .collect(), + complete_with_retry_status(request_options, async { + match self { + Client::Anthropic(client) => { + let (temperature, max_tokens) = anthropic_request_params(request_options); + let stream = client + .messages_stream() + .maybe_system_prompt(system_prompt) + .conversations( + conversations + .into_iter() + .map(anthropic_api::Message::from) + .collect(), + ) + .tools(tools.into_iter().map(anthropic_api::Tool::from).collect()) + .maybe_thinking(thinking.map(anthropic_api::Thinking::from)) + .maybe_temperature(temperature) + .maybe_max_tokens(max_tokens) + .retry_config(anthropic_retry_config(request_options)) + .call() + .await + .whatever_context_display("failed to send messages stream")?; + Ok(map_anthropic_stream(stream)) + } + Client::OpenAI(client) => { + let stream = openai::messages_stream( + client, + system_prompt, + conversations, + tools, + None, + thinking, + request_options, ) - .tools(tools.into_iter().map(anthropic_api::Tool::from).collect()) - .maybe_thinking(thinking.map(anthropic_api::Thinking::from)) - .maybe_temperature(temperature) - .maybe_max_tokens(max_tokens) - .retry_config(anthropic_retry_config(request_options)) - .call() .await - .whatever_context_display("failed to send messages stream")?; - let mapped = - stream.map(|event| event.map(Into::into).map_err(map_anthropic_stream_error)); - Ok(Box::pin(mapped)) + .whatever_context_display("failed to send messages stream")?; + Ok(Box::pin(stream) as MessagesStream) + } } - Client::OpenAI(client) => { - let stream = openai::messages_stream( - client, - system_prompt, - conversations, - tools, - None, - thinking, - request_options, - ) - .await - .whatever_context_display("failed to send messages stream")?; - Ok(Box::pin(stream)) - } - }; - notify_retry_finished(request_options, result.is_ok()); - result + }) + .await } pub
async fn messages_with_tool_choice( @@ -158,42 +156,42 @@ impl Client { thinking: Option<Thinking>, request_options: &RequestOptions, ) -> Result<MessagesResponse> { - let result: Result<MessagesResponse> = match self { - Client::Anthropic(client) => { - let temperature = request_options.temperature.map(f64::from); - let max_tokens = request_options.max_tokens; - let response = client - .messages_with_tool_choice( - system_prompt, - conversations - .into_iter() - .map(anthropic_api::Message::from) - .collect(), - tools.into_iter().map(anthropic_api::Tool::from).collect(), - tool_choice.into(), - thinking.map(anthropic_api::Thinking::from), - temperature, - max_tokens, - anthropic_retry_config(request_options), - ) - .await - .whatever_context_display("failed to request tool choice")?; - Ok(response.into()) + complete_with_retry_status(request_options, async { + match self { + Client::Anthropic(client) => { + let (temperature, max_tokens) = anthropic_request_params(request_options); + let response = client + .messages_with_tool_choice( + system_prompt, + conversations + .into_iter() + .map(anthropic_api::Message::from) + .collect(), + tools.into_iter().map(anthropic_api::Tool::from).collect(), + tool_choice.into(), + thinking.map(anthropic_api::Thinking::from), + temperature, + max_tokens, + anthropic_retry_config(request_options), + ) + .await + .whatever_context_display("failed to request tool choice")?; + Ok(response.into()) + } + Client::OpenAI(client) => openai::messages( + client, + system_prompt, + conversations, + tools, + Some(tool_choice), + thinking, + request_options, + ) + .await + .whatever_context_display("failed to request tool choice"), } - Client::OpenAI(client) => openai::messages( - client, - system_prompt, - conversations, - tools, - Some(tool_choice), - thinking, - request_options, - ) - .await - .whatever_context_display("failed to request tool choice"), - }; - notify_retry_finished(request_options, result.is_ok()); - result + }) + .await } pub async fn messages_stream_with_tool_choice( @@ -205,50 +203,57 @@ impl Client { thinking: Option<Thinking>, request_options: &RequestOptions, ) -> Result<MessagesStream> { - let result: Result<MessagesStream> = match self { - Client::Anthropic(client) => { - let temperature = request_options.temperature.map(f64::from); - let max_tokens = request_options.max_tokens; - let stream = client - .messages_stream_with_tool_choice( + complete_with_retry_status(request_options, async { + match self { + Client::Anthropic(client) => { + let (temperature, max_tokens) = anthropic_request_params(request_options); + let stream = client + .messages_stream_with_tool_choice( + system_prompt, + conversations + .into_iter() + .map(anthropic_api::Message::from) + .collect(), + tools.into_iter().map(anthropic_api::Tool::from).collect(), + tool_choice.into(), + thinking.map(anthropic_api::Thinking::from), + temperature, + max_tokens, + anthropic_retry_config(request_options), + ) + .await + .whatever_context_display("failed to request tool choice stream")?; + Ok(map_anthropic_stream(stream)) + } + Client::OpenAI(client) => { + let stream = openai::messages_stream( + client, system_prompt, - conversations - .into_iter() - .map(anthropic_api::Message::from) - .collect(), - tools.into_iter().map(anthropic_api::Tool::from).collect(), - tool_choice.into(), - thinking.map(anthropic_api::Thinking::from), - temperature, - max_tokens, - anthropic_retry_config(request_options), + conversations, + tools, + Some(tool_choice), + thinking, + request_options, ) .await .whatever_context_display("failed to request tool choice stream")?; - let mapped =
stream.map(|event| event.map(Into::into).map_err(map_anthropic_stream_error)); - Ok(Box::pin(mapped)) + Ok(map_anthropic_stream(stream)) + } - Client::OpenAI(client) => { - let stream = openai::messages_stream( - client, - system_prompt, - conversations, - tools, - Some(tool_choice), - thinking, - request_options, - ) - .await - .whatever_context_display("failed to request tool choice stream")?; - Ok(Box::pin(stream)) - } - }; - notify_retry_finished(request_options, result.is_ok()); - result + Ok(Box::pin(stream) as MessagesStream) + } } - }; - notify_retry_finished(request_options, result.is_ok()); - result + }) + .await } } +async fn complete_with_retry_status<T, F>(request_options: &RequestOptions, op: F) -> Result<T> +where + F: Future<Output = Result<T>>, +{ + let result = op.await; + notify_retry_finished(request_options, result.is_ok()); + result +} + fn notify_retry_finished(request_options: &RequestOptions, success: bool) { if let Some(notifier) = &request_options.retry_notifier { notifier.notify(RetryUpdate::Finished { success }); } } @@ -275,6 +280,27 @@ fn anthropic_retry_config(request_options: &RequestOptions) -> anthropic_api::Re { } } +fn anthropic_request_params(request_options: &RequestOptions) -> (Option, Option) { + ( + request_options.temperature.map(f64::from), + request_options.max_tokens, + ) +} + +fn map_anthropic_stream<S>(stream: S) -> MessagesStream +where + S: Stream< + Item = std::result::Result< + anthropic_api::MessagesStreamEvent, + anthropic_api::StreamError, + >, + > + Send + + 'static, +{ + let mapped = stream.map(|event| event.map(Into::into).map_err(map_anthropic_stream_error)); + Box::pin(mapped) +} + fn map_anthropic_stream_error(err: anthropic_api::StreamError) -> StreamError { match err.kind { anthropic_api::StreamErrorKind::Transport => StreamError::transport(err.message), diff --git a/src/tools.rs b/src/tools.rs index a1ca286b..ce842a87 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,4 +1,4 @@ -use std::any::Any; +use std::{any::Any, path::PathBuf}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -103,6 +103,14 @@ impl From for Output { /// Result for LLM pub type ExecuteResult = Result<Output, Final>; +pub(crate) fn parse_relative_path(path: String) -> Result<PathBuf, Final> { + let path = PathBuf::from(path); + if path.is_absolute() { + return Err("Path must be relative to the working directory, not absolute".into()); + } + Ok(path) +} + impl TryFrom<&Final> for crate::provider::Content { type Error = serde_json::Error; @@ -141,3 +149,30 @@ impl Final { Err(self) } } + +#[cfg(test)] +mod tests { + use super::parse_relative_path; + + #[test] + fn parse_relative_path_accepts_relative_path() { + let path = parse_relative_path("src/main.rs".to_string()).expect("relative path accepted"); + assert_eq!(path.to_string_lossy(), "src/main.rs"); + } + + #[test] + fn parse_relative_path_rejects_absolute_path() { + let absolute = if cfg!(windows) { + r"C:\tmp\file.txt" + } else { + "/tmp/file.txt" + }; + let err = parse_relative_path(absolute.to_string()).expect_err("absolute path rejected"); + match err { + super::Final::Message(message) => { + assert!(message.contains("Path must be relative")); + } + super::Final::Json(value) => panic!("unexpected json error: {value}"), + } + } +} diff --git a/src/tools/list.rs b/src/tools/list.rs index dfdb7af4..744938ce 100644 --- a/src/tools/list.rs +++ b/src/tools/list.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::fs; -use super::{ExecuteResult, Final, Input, Tool}; +use super::{ExecuteResult, Final, Input, Tool, parse_relative_path}; #[derive(Default)] pub struct ListTool {} @@ -70,12 +70,7 @@ impl Tool for ListTool { return
err_msg!("Exceeded maximum entry limit for listing"); } - let path = path - .parse::<PathBuf>() - .map_err(|err| format!("Invalid path format: {err}"))?; - if path.is_absolute() { - return err_msg!("Path must be relative to the working directory, not absolute"); - } + let path = parse_relative_path(path)?; let metadata = fs::symlink_metadata(&path) .await diff --git a/src/tools/read.rs b/src/tools/read.rs index 215263f8..5b6f49b7 100644 --- a/src/tools/read.rs +++ b/src/tools/read.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::io::{AsyncBufReadExt, BufReader}; -use super::{ExecuteResult, Final, Input, Tool}; +use super::{ExecuteResult, Final, Input, Tool, parse_relative_path}; #[derive(Default)] pub struct ReadTool {} @@ -91,13 +91,7 @@ impl Tool for ReadTool { return err_msg!("Exceeded maximum line limit for reading file"); } - // Check if the path is absolute - let path = path - .parse::<PathBuf>() - .map_err(|err| format!("Invalid path format: {err}"))?; - if path.is_absolute() { - return err_msg!("Path must be relative to the working directory, not absolute"); - } + let path = parse_relative_path(path)?; let fh = tokio::fs::File::open(path) .await diff --git a/src/tools/str_replace.rs b/src/tools/str_replace.rs index 75035903..1f947c2e 100644 --- a/src/tools/str_replace.rs +++ b/src/tools/str_replace.rs @@ -7,7 +7,7 @@ use tokio::{ io::{AsyncReadExt, AsyncWriteExt, BufReader}, }; -use super::{ExecuteResult, Final, Input, Output, TextEdit, Tool}; +use super::{ExecuteResult, Final, Input, Output, TextEdit, Tool, parse_relative_path}; use crate::AppliedTextEdit; #[derive(Default)] pub struct StrReplaceTool {} @@ -94,14 +94,7 @@ impl StrReplaceTool { expected_replacements, } = serde_json::from_value(input).map_err(|err| format!("Invalid input format: {err}"))?; - // Check if the path is absolute - let path = path - .parse::<PathBuf>() - .map_err(|err| format!("Invalid path format: {err}"))?; - - if path.is_absolute() { - return Err("Path must be relative to the working directory, not absolute".into()); - } + let path = parse_relative_path(path)?; if old_str.is_empty() { Ok(TextEdit::new(path, String::new(), new_str.to_string()))
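
Note on the ChatState rename: the #[serde(alias = "Procesing")] attribute is what keeps previously persisted sessions loadable after the misspelled variant is corrected, while serialization switches to the new spelling. A minimal standalone sketch of that behavior (hypothetical crate setup assuming serde with the derive feature plus serde_json; apart from ChatState itself, none of these names come from this repository):

use serde::{Deserialize, Serialize};

#[derive(Debug, PartialEq, Serialize, Deserialize)]
enum ChatState {
    Ready,
    #[serde(alias = "Procesing")]
    Processing,
}

fn main() {
    // Payloads persisted before the rename still deserialize via the alias.
    let old: ChatState = serde_json::from_str("\"Procesing\"").unwrap();
    assert_eq!(old, ChatState::Processing);
    // New payloads always serialize with the corrected spelling.
    assert_eq!(serde_json::to_string(&old).unwrap(), "\"Processing\"");
}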
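
Note on the presets change: builtin_model_presets() now leans on OnceLock::get_or_init, which runs its initializer at most once and afterwards always returns the same allocation, so callers can borrow a &'static [ModelPreset] instead of cloning a Vec on every call. A simplified sketch of the pattern (the inline preset data stands in for the real embedded-TOML parsing and is not taken from the repository):

use std::sync::OnceLock;

#[derive(Debug)]
struct ModelPreset {
    name: String,
}

fn builtin_model_presets() -> &'static [ModelPreset] {
    static PRESETS: OnceLock<Vec<ModelPreset>> = OnceLock::new();
    PRESETS
        .get_or_init(|| {
            // In the real code this is parsed from an embedded TOML file.
            vec![ModelPreset { name: "example-model".to_string() }]
        })
        .as_slice()
}

fn main() {
    // Both calls observe the same cached allocation, as the new unit test asserts.
    assert_eq!(builtin_model_presets().as_ptr(), builtin_model_presets().as_ptr());
}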
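
Note on the provider refactor: all four request paths now run through complete_with_retry_status, so the retry-finished notification fires exactly once per request and cannot be forgotten when a new path is added. The wrapper's shape, reduced to a self-contained sketch (generic over the error type; the tokio runtime and every name below are assumptions for illustration, not this crate's API):

use std::future::Future;

// Await the operation, then emit a completion signal derived from its outcome.
async fn with_completion_signal<T, E, F>(op: F, notify: impl Fn(bool)) -> Result<T, E>
where
    F: Future<Output = Result<T, E>>,
{
    let result = op.await;
    notify(result.is_ok());
    result
}

#[tokio::main]
async fn main() {
    let outcome: Result<u32, String> = with_completion_signal(
        async { Ok(7) },
        |success| println!("request finished, success = {success}"),
    )
    .await;
    assert_eq!(outcome, Ok(7));
}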