-
Notifications
You must be signed in to change notification settings - Fork 0
fix: make MCP JSON output UTF-8 safe #666
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2a2c583
37ffa3e
0751a5f
4181e38
2feaded
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| """Regression tests for UTF-8-safe MCP JSON serialization.""" | ||
|
|
||
| import asyncio | ||
| from types import SimpleNamespace | ||
|
|
||
|
|
||
| class _AsyncLock: | ||
| async def __aenter__(self): | ||
| return self | ||
|
|
||
| async def __aexit__(self, exc_type, exc, tb): | ||
| return False | ||
|
|
||
|
|
||
| def test_mcp_json_dumps_preserves_cjk_but_escapes_lone_surrogates(): | ||
| from tools.mcp_tool import _mcp_json_dumps | ||
|
|
||
| payload = _mcp_json_dumps({"result": "中文\ud800"}) | ||
|
|
||
| assert "中文" in payload | ||
| assert "\ud800" not in payload | ||
| assert "\\ud800" in payload | ||
| payload.encode("utf-8") | ||
|
|
||
|
|
||
| def test_mcp_tool_handler_returns_utf8_encodable_result_for_surrogate_text(monkeypatch): | ||
| import tools.mcp_tool as mcp_tool | ||
|
|
||
| class FakeSession: | ||
| async def call_tool(self, tool_name, arguments): | ||
| return SimpleNamespace( | ||
| isError=False, | ||
| content=[SimpleNamespace(text="tool output \ud800")], | ||
| ) | ||
|
|
||
| fake_server = SimpleNamespace(session=FakeSession(), _rpc_lock=_AsyncLock()) | ||
| monkeypatch.setitem(mcp_tool._servers, "surrogate-server", fake_server) | ||
| monkeypatch.setattr( | ||
| mcp_tool, | ||
| "_run_on_mcp_loop", | ||
| lambda coro, timeout=None: asyncio.run(coro), | ||
| ) | ||
|
|
||
| handler = mcp_tool._make_tool_handler("surrogate-server", "demo", 1.0) | ||
| result = handler({}) | ||
|
|
||
| assert "\ud800" not in result | ||
| assert "\\ud800" in result | ||
| result.encode("utf-8") | ||
|
|
||
|
|
||
| def test_sampling_tool_arguments_are_utf8_encodable(): | ||
| from tools.mcp_tool import SamplingHandler | ||
|
|
||
| handler = SamplingHandler("surrogate-server", {}) | ||
| tool_use = SimpleNamespace(name="demo", input={"value": "中文\ud800"}, id="call_1") | ||
| message = SimpleNamespace(role="assistant", content=[tool_use], content_as_list=[tool_use]) | ||
|
|
||
| converted = handler._convert_messages(SimpleNamespace(messages=[message])) | ||
| arguments = converted[0]["tool_calls"][0]["function"]["arguments"] | ||
|
|
||
| assert "中文" in arguments | ||
| assert "\ud800" not in arguments | ||
| assert "\\ud800" in arguments | ||
| arguments.encode("utf-8") |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -92,6 +92,31 @@ | |||||||||||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _utf8_safe_text(text: str) -> str: | ||||||||||||||||||||||||||||
| """Return text that can always be encoded as UTF-8.""" | ||||||||||||||||||||||||||||
| return text.encode("utf-8", errors="backslashreplace").decode("utf-8") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _sanitize_json_value(value: Any) -> Any: | ||||||||||||||||||||||||||||
| """Preserve valid Unicode while escaping lone surrogates in JSON data.""" | ||||||||||||||||||||||||||||
| if isinstance(value, str): | ||||||||||||||||||||||||||||
| return _utf8_safe_text(value) | ||||||||||||||||||||||||||||
| if isinstance(value, dict): | ||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||
| _sanitize_json_value(key): _sanitize_json_value(item) | ||||||||||||||||||||||||||||
| for key, item in value.items() | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| if isinstance(value, (list, tuple)): | ||||||||||||||||||||||||||||
| return [_sanitize_json_value(item) for item in value] | ||||||||||||||||||||||||||||
| return value | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _mcp_json_dumps(value: Any, **kwargs: Any) -> str: | ||||||||||||||||||||||||||||
| """Serialize MCP-controlled data without emitting invalid UTF-8 text.""" | ||||||||||||||||||||||||||||
| kwargs.setdefault("ensure_ascii", False) | ||||||||||||||||||||||||||||
| return json.dumps(_sanitize_json_value(value), **kwargs) | ||||||||||||||||||||||||||||
|
Comment on lines
+114
to
+117
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Recursively traversing and copying the entire JSON object structure in
Suggested change
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # --------------------------------------------------------------------------- | ||||||||||||||||||||||||||||
| # Stdio subprocess stderr redirection | ||||||||||||||||||||||||||||
| # --------------------------------------------------------------------------- | ||||||||||||||||||||||||||||
|
|
@@ -669,7 +694,7 @@ def _convert_messages(self, params) -> List[dict]: | |||||||||||||||||||||||||||
| "type": "function", | ||||||||||||||||||||||||||||
| "function": { | ||||||||||||||||||||||||||||
| "name": tu.name, | ||||||||||||||||||||||||||||
| "arguments": json.dumps(tu.input, ensure_ascii=False) if isinstance(tu.input, dict) else str(tu.input), | ||||||||||||||||||||||||||||
| "arguments": _mcp_json_dumps(tu.input) if isinstance(tu.input, dict) else _utf8_safe_text(str(tu.input)), | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||
| msg_dict: dict = {"role": msg.role, "tool_calls": tc_list} | ||||||||||||||||||||||||||||
|
|
@@ -1820,7 +1845,7 @@ async def _recover(): | |||||||||||||||||||||||||||
| # needs_reauth error. Bumps the circuit breaker so the model stops | ||||||||||||||||||||||||||||
| # retrying the tool. | ||||||||||||||||||||||||||||
| _bump_server_error(server_name) | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": ( | ||||||||||||||||||||||||||||
| f"MCP server '{server_name}' requires re-authentication. " | ||||||||||||||||||||||||||||
| f"Run `hermes mcp login {server_name}` (or delete the tokens " | ||||||||||||||||||||||||||||
|
|
@@ -2080,7 +2105,7 @@ def _run_on_mcp_loop(coro, timeout: float = 30): | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _interrupted_call_result() -> str: | ||||||||||||||||||||||||||||
| """Standardized JSON error for a user-interrupted MCP tool call.""" | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": "MCP call interrupted: user sent a new message" | ||||||||||||||||||||||||||||
| }, ensure_ascii=False) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
@@ -2178,7 +2203,7 @@ def _handler(args: dict, **kwargs) -> str: | |||||||||||||||||||||||||||
| age = time.monotonic() - opened_at | ||||||||||||||||||||||||||||
| if age < _CIRCUIT_BREAKER_COOLDOWN_SEC: | ||||||||||||||||||||||||||||
| remaining = max(1, int(_CIRCUIT_BREAKER_COOLDOWN_SEC - age)) | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": ( | ||||||||||||||||||||||||||||
| f"MCP server '{server_name}' is unreachable after " | ||||||||||||||||||||||||||||
| f"{_server_error_counts[server_name]} consecutive " | ||||||||||||||||||||||||||||
|
|
@@ -2193,7 +2218,7 @@ def _handler(args: dict, **kwargs) -> str: | |||||||||||||||||||||||||||
| server = _servers.get(server_name) | ||||||||||||||||||||||||||||
| if not server or not server.session: | ||||||||||||||||||||||||||||
| _bump_server_error(server_name) | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": f"MCP server '{server_name}' is not connected" | ||||||||||||||||||||||||||||
| }, ensure_ascii=False) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
@@ -2206,7 +2231,7 @@ async def _call(): | |||||||||||||||||||||||||||
| for block in (result.content or []): | ||||||||||||||||||||||||||||
| if hasattr(block, "text"): | ||||||||||||||||||||||||||||
| error_text += block.text | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": _sanitize_error( | ||||||||||||||||||||||||||||
| error_text or "MCP tool returned an error" | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
@@ -2240,12 +2265,12 @@ async def _call(): | |||||||||||||||||||||||||||
| structured = getattr(result, "structuredContent", None) | ||||||||||||||||||||||||||||
| if structured is not None: | ||||||||||||||||||||||||||||
| if text_result: | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "result": text_result, | ||||||||||||||||||||||||||||
| "structuredContent": structured, | ||||||||||||||||||||||||||||
| }, ensure_ascii=False) | ||||||||||||||||||||||||||||
| return json.dumps({"result": structured}, ensure_ascii=False) | ||||||||||||||||||||||||||||
| return json.dumps({"result": text_result}, ensure_ascii=False) | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({"result": structured}, ensure_ascii=False) | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({"result": text_result}, ensure_ascii=False) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _call_once(): | ||||||||||||||||||||||||||||
| return _run_on_mcp_loop(_call(), timeout=tool_timeout) | ||||||||||||||||||||||||||||
|
|
@@ -2290,7 +2315,7 @@ def _call_once(): | |||||||||||||||||||||||||||
| "MCP tool %s/%s call failed: %s", | ||||||||||||||||||||||||||||
| server_name, tool_name, exc, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": _sanitize_error( | ||||||||||||||||||||||||||||
| f"MCP call failed: {type(exc).__name__}: {_exc_str(exc)}" | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
@@ -2306,7 +2331,7 @@ def _handler(args: dict, **kwargs) -> str: | |||||||||||||||||||||||||||
| with _lock: | ||||||||||||||||||||||||||||
| server = _servers.get(server_name) | ||||||||||||||||||||||||||||
| if not server or not server.session: | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": f"MCP server '{server_name}' is not connected" | ||||||||||||||||||||||||||||
| }, ensure_ascii=False) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
@@ -2325,7 +2350,7 @@ async def _call(): | |||||||||||||||||||||||||||
| if hasattr(r, "mimeType") and r.mimeType: | ||||||||||||||||||||||||||||
| entry["mimeType"] = r.mimeType | ||||||||||||||||||||||||||||
| resources.append(entry) | ||||||||||||||||||||||||||||
| return json.dumps({"resources": resources}, ensure_ascii=False) | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({"resources": resources}, ensure_ascii=False) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _call_once(): | ||||||||||||||||||||||||||||
| return _run_on_mcp_loop(_call(), timeout=tool_timeout) | ||||||||||||||||||||||||||||
|
|
@@ -2348,7 +2373,7 @@ def _call_once(): | |||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||
| "MCP %s/list_resources failed: %s", server_name, exc, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": _sanitize_error( | ||||||||||||||||||||||||||||
| f"MCP call failed: {type(exc).__name__}: {_exc_str(exc)}" | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
@@ -2366,7 +2391,7 @@ def _handler(args: dict, **kwargs) -> str: | |||||||||||||||||||||||||||
| with _lock: | ||||||||||||||||||||||||||||
| server = _servers.get(server_name) | ||||||||||||||||||||||||||||
| if not server or not server.session: | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": f"MCP server '{server_name}' is not connected" | ||||||||||||||||||||||||||||
| }, ensure_ascii=False) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
@@ -2385,7 +2410,7 @@ async def _call(): | |||||||||||||||||||||||||||
| parts.append(block.text) | ||||||||||||||||||||||||||||
| elif hasattr(block, "blob"): | ||||||||||||||||||||||||||||
| parts.append(f"[binary data, {len(block.blob)} bytes]") | ||||||||||||||||||||||||||||
| return json.dumps({"result": "\n".join(parts) if parts else ""}, ensure_ascii=False) | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({"result": "\n".join(parts) if parts else ""}, ensure_ascii=False) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _call_once(): | ||||||||||||||||||||||||||||
| return _run_on_mcp_loop(_call(), timeout=tool_timeout) | ||||||||||||||||||||||||||||
|
|
@@ -2408,7 +2433,7 @@ def _call_once(): | |||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||
| "MCP %s/read_resource failed: %s", server_name, exc, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": _sanitize_error( | ||||||||||||||||||||||||||||
| f"MCP call failed: {type(exc).__name__}: {_exc_str(exc)}" | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
@@ -2424,7 +2449,7 @@ def _handler(args: dict, **kwargs) -> str: | |||||||||||||||||||||||||||
| with _lock: | ||||||||||||||||||||||||||||
| server = _servers.get(server_name) | ||||||||||||||||||||||||||||
| if not server or not server.session: | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": f"MCP server '{server_name}' is not connected" | ||||||||||||||||||||||||||||
| }, ensure_ascii=False) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
@@ -2448,7 +2473,7 @@ async def _call(): | |||||||||||||||||||||||||||
| for a in p.arguments | ||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||
| prompts.append(entry) | ||||||||||||||||||||||||||||
| return json.dumps({"prompts": prompts}, ensure_ascii=False) | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({"prompts": prompts}, ensure_ascii=False) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _call_once(): | ||||||||||||||||||||||||||||
| return _run_on_mcp_loop(_call(), timeout=tool_timeout) | ||||||||||||||||||||||||||||
|
|
@@ -2471,7 +2496,7 @@ def _call_once(): | |||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||
| "MCP %s/list_prompts failed: %s", server_name, exc, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": _sanitize_error( | ||||||||||||||||||||||||||||
| f"MCP call failed: {type(exc).__name__}: {_exc_str(exc)}" | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
@@ -2489,7 +2514,7 @@ def _handler(args: dict, **kwargs) -> str: | |||||||||||||||||||||||||||
| with _lock: | ||||||||||||||||||||||||||||
| server = _servers.get(server_name) | ||||||||||||||||||||||||||||
| if not server or not server.session: | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": f"MCP server '{server_name}' is not connected" | ||||||||||||||||||||||||||||
| }, ensure_ascii=False) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
@@ -2519,7 +2544,7 @@ async def _call(): | |||||||||||||||||||||||||||
| resp = {"messages": messages} | ||||||||||||||||||||||||||||
| if hasattr(result, "description") and result.description: | ||||||||||||||||||||||||||||
| resp["description"] = result.description | ||||||||||||||||||||||||||||
| return json.dumps(resp, ensure_ascii=False) | ||||||||||||||||||||||||||||
| return _mcp_json_dumps(resp, ensure_ascii=False) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _call_once(): | ||||||||||||||||||||||||||||
| return _run_on_mcp_loop(_call(), timeout=tool_timeout) | ||||||||||||||||||||||||||||
|
|
@@ -2542,7 +2567,7 @@ def _call_once(): | |||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||
| "MCP %s/get_prompt failed: %s", server_name, exc, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| return json.dumps({ | ||||||||||||||||||||||||||||
| return _mcp_json_dumps({ | ||||||||||||||||||||||||||||
| "error": _sanitize_error( | ||||||||||||||||||||||||||||
| f"MCP call failed: {type(exc).__name__}: {_exc_str(exc)}" | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid unnecessary string allocations and encoding/decoding overhead for valid UTF-8 strings (which represent the vast majority of cases), we can perform a fast-path check using a
try...except UnicodeEncodeErrorblock. If the string is already valid UTF-8, we can return it directly.