diff --git a/crates/tui/src/mcp.rs b/crates/tui/src/mcp.rs index b4d287978..e10a051c8 100644 --- a/crates/tui/src/mcp.rs +++ b/crates/tui/src/mcp.rs @@ -14,12 +14,15 @@ use std::time::Duration; use anyhow::{Context, Result}; use reqwest::StatusCode; -use reqwest::header::{ACCEPT, CONTENT_TYPE}; +use reqwest::header::CONTENT_TYPE; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; use tokio::process::{Child, ChildStdin, ChildStdout}; use tokio::sync::Mutex as TokioMutex; +mod headers; + +use self::headers::{apply_safe_custom_headers, with_default_mcp_http_headers}; use crate::child_env; use crate::network_policy::{Decision, NetworkPolicyDecider, host_from_url}; use crate::utils::write_atomic; @@ -28,19 +31,6 @@ use crate::utils::write_atomic; /// Bytes of a non-2xx response body to surface in connection errors. const ERROR_BODY_PREVIEW_BYTES: usize = 200; -const MCP_HTTP_ACCEPT: &str = "application/json, text/event-stream"; - -fn with_default_mcp_http_headers( - request: reqwest::RequestBuilder, - json_body: bool, -) -> reqwest::RequestBuilder { - let request = request.header(ACCEPT, MCP_HTTP_ACCEPT); - if json_body { - request.header(CONTENT_TYPE, "application/json") - } else { - request - } -} fn validate_mcp_config_path(path: &Path) -> Result<()> { if path.as_os_str().is_empty() { @@ -55,54 +45,6 @@ fn validate_mcp_config_path(path: &Path) -> Result<()> { Ok(()) } -/// Predicate for [`StreamableHttpTransport::send`]'s custom-header pass. -/// -/// We accept whatever reqwest's `HeaderName::try_from` / -/// `HeaderValue::try_from` would accept, but with three extra rules: -/// -/// 1. Reject empty / whitespace-only keys — these would surface as a -/// request-builder error mid-send and abort the whole connection. -/// 2. Reject keys that duplicate the framing we already emit -/// (`Accept`, `Content-Type`). The MCP Streamable HTTP transport -/// relies on those exact values for protocol negotiation; a stray -/// user override could silently break tool discovery. -/// 3. Reject values containing ASCII CR or LF. reqwest already -/// rejects those, but the explicit check makes the failure path -/// visible (a `tracing::warn!` instead of an obscure -/// builder error) and documents the response-splitting -/// defense. -/// -/// Returning `false` means "skip this header"; the rest of the -/// request still goes out. -fn is_safe_custom_header(key: &str, value: &str) -> bool { - let trimmed = key.trim(); - if trimmed.is_empty() { - return false; - } - if trimmed.eq_ignore_ascii_case("accept") || trimmed.eq_ignore_ascii_case("content-type") { - return false; - } - !value.contains('\r') && !value.contains('\n') -} - -fn apply_safe_custom_headers( - mut request: reqwest::RequestBuilder, - headers: &HashMap, -) -> reqwest::RequestBuilder { - for (key, value) in headers { - if !is_safe_custom_header(key, value) { - tracing::warn!( - target: "mcp", - "skipping unsafe MCP header {:?} (empty/control-char/reserved)", - key - ); - continue; - } - request = request.header(key.as_str(), value.as_str()); - } - request -} - /// Mask a URL so any embedded credentials in the userinfo portion (e.g. /// `https://user:secret@host`) are replaced with `***`. Failures fall back to /// the original string so we don't lose context — we never want masking to @@ -3129,7 +3071,9 @@ pub fn format_tool_result(result: &serde_json::Value) -> String { #[cfg(test)] mod tests { + use super::headers::{MCP_HTTP_ACCEPT, is_safe_custom_header}; use super::*; + use reqwest::header::{ACCEPT, CONTENT_TYPE}; use std::collections::VecDeque; use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering}; use std::sync::{Arc, Mutex, OnceLock}; diff --git a/crates/tui/src/mcp/headers.rs b/crates/tui/src/mcp/headers.rs new file mode 100644 index 000000000..ad0957111 --- /dev/null +++ b/crates/tui/src/mcp/headers.rs @@ -0,0 +1,65 @@ +use std::collections::HashMap; + +use reqwest::header::{ACCEPT, CONTENT_TYPE}; + +pub(super) const MCP_HTTP_ACCEPT: &str = "application/json, text/event-stream"; + +pub(super) fn with_default_mcp_http_headers( + request: reqwest::RequestBuilder, + json_body: bool, +) -> reqwest::RequestBuilder { + let request = request.header(ACCEPT, MCP_HTTP_ACCEPT); + if json_body { + request.header(CONTENT_TYPE, "application/json") + } else { + request + } +} + +/// Predicate for the custom-header pass used by MCP HTTP transports. +/// +/// We accept whatever reqwest's `HeaderName::try_from` / +/// `HeaderValue::try_from` would accept, but with three extra rules: +/// +/// 1. Reject empty / whitespace-only keys - these would surface as a +/// request-builder error mid-send and abort the whole connection. +/// 2. Reject keys that duplicate the framing we already emit +/// (`Accept`, `Content-Type`). The MCP Streamable HTTP transport +/// relies on those exact values for protocol negotiation; a stray +/// user override could silently break tool discovery. +/// 3. Reject values containing ASCII CR or LF. reqwest already +/// rejects those, but the explicit check makes the failure path +/// visible (a `tracing::warn!` instead of an obscure +/// builder error) and documents the response-splitting +/// defense. +/// +/// Returning `false` means "skip this header"; the rest of the +/// request still goes out. +pub(super) fn is_safe_custom_header(key: &str, value: &str) -> bool { + let trimmed = key.trim(); + if trimmed.is_empty() { + return false; + } + if trimmed.eq_ignore_ascii_case("accept") || trimmed.eq_ignore_ascii_case("content-type") { + return false; + } + !value.contains('\r') && !value.contains('\n') +} + +pub(super) fn apply_safe_custom_headers( + mut request: reqwest::RequestBuilder, + headers: &HashMap, +) -> reqwest::RequestBuilder { + for (key, value) in headers { + if !is_safe_custom_header(key, value) { + tracing::warn!( + target: "mcp", + "skipping unsafe MCP header {:?} (empty/control-char/reserved)", + key + ); + continue; + } + request = request.header(key.as_str(), value.as_str()); + } + request +}