Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions patch_pr.py

This file was deleted.

66 changes: 57 additions & 9 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import collections
import unittest
from src.tools import (
load_tool_snapshot,
Expand All @@ -11,7 +12,7 @@
find_tools,
execute_tool,
render_tool_index,
PORTED_TOOLS
PORTED_TOOLS,
)
from src.models import PortingBacklog, PortingModule
from src.permissions import ToolPermissionContext
Expand All @@ -24,12 +25,12 @@ def test_load_tool_snapshot(self) -> None:
self.assertTrue(len(tools) > 0)
for tool in tools:
self.assertIsInstance(tool, PortingModule)
self.assertEqual(tool.status, 'mirrored')
self.assertEqual(tool.status, "mirrored")

def test_build_tool_backlog(self) -> None:
backlog = build_tool_backlog()
self.assertIsInstance(backlog, PortingBacklog)
self.assertEqual(backlog.title, 'Tool surface')
self.assertEqual(backlog.title, "Tool surface")
self.assertEqual(len(backlog.modules), len(PORTED_TOOLS))
self.assertEqual(backlog.modules, list(PORTED_TOOLS))

Expand All @@ -49,9 +50,49 @@ def test_get_tool(self) -> None:
# Case-insensitive match
self.assertEqual(get_tool(first_tool.name.lower()), first_tool)
self.assertEqual(get_tool(first_tool.name.upper()), first_tool)
# Mixed casing
mixed_case_name = "".join(
c.upper() if i % 2 == 0 else c.lower()
for i, c in enumerate(first_tool.name)
)
self.assertEqual(get_tool(mixed_case_name), first_tool)

# Edge cases
self.assertIsNone(get_tool(""))
self.assertIsNone(get_tool(" "))
self.assertIsNone(get_tool("\n"))

# Unknown tool
self.assertIsNone(get_tool("NonExistentToolNamexyz123"))

# First-match priority on duplicates
# Find the first duplicated name using Counter for O(n) detection
name_counts = collections.Counter(t.name for t in PORTED_TOOLS)
dupe_name = next(
(name for name, count in name_counts.items() if count > 1),
None,
)

if dupe_name is not None:
# Find the actual first module in PORTED_TOOLS with this name
expected_first_module = next(
t for t in PORTED_TOOLS if t.name.lower() == dupe_name.lower()
)

# get_tool should return this exact expected_first_module
self.assertEqual(get_tool(dupe_name), expected_first_module)
self.assertEqual(get_tool(dupe_name.upper()), expected_first_module)

# verify there are other modules with the same name but a different source_hint
other_modules = [
t
for t in PORTED_TOOLS
if t.name.lower() == dupe_name.lower() and t != expected_first_module
]
self.assertTrue(len(other_modules) > 0)
Comment on lines +90 to +92
self.assertNotEqual(other_modules[0].source_hint, expected_first_module.source_hint)
self.assertNotEqual(get_tool(dupe_name), other_modules[0])

def test_filter_tools_by_permission_context(self) -> None:
tools = PORTED_TOOLS[:5]
# No context
Expand All @@ -70,21 +111,27 @@ def test_get_tools(self) -> None:
self.assertEqual(len(all_tools), len(PORTED_TOOLS))

# simple_mode
simple_mode_names = {'BashTool', 'FileReadTool', 'FileEditTool'}
expected_simple_names = {t.name for t in PORTED_TOOLS if t.name in simple_mode_names}
simple_mode_names = {"BashTool", "FileReadTool", "FileEditTool"}
expected_simple_names = {
t.name for t in PORTED_TOOLS if t.name in simple_mode_names
}
simple_tools = get_tools(simple_mode=True)
simple_tool_names = {tool.name for tool in simple_tools}
self.assertEqual(simple_tool_names, expected_simple_names)

# include_mcp=False
# First, find if there are any MCP tools to test the filter
mcp_tools = [t for t in PORTED_TOOLS if 'mcp' in t.name.lower() or 'mcp' in t.source_hint.lower()]
mcp_tools = [
t
for t in PORTED_TOOLS
if "mcp" in t.name.lower() or "mcp" in t.source_hint.lower()
]
if mcp_tools:
no_mcp_tools = get_tools(include_mcp=False)
self.assertTrue(len(no_mcp_tools) < len(PORTED_TOOLS))
for tool in no_mcp_tools:
self.assertNotIn('mcp', tool.name.lower())
self.assertNotIn('mcp', tool.source_hint.lower())
self.assertNotIn("mcp", tool.name.lower())
self.assertNotIn("mcp", tool.source_hint.lower())

# With permission context
if len(PORTED_TOOLS) > 0:
Expand Down Expand Up @@ -146,5 +193,6 @@ def test_render_tool_index(self) -> None:
self.assertIn(f"Filtered by: {tool.name}", output)
self.assertIn(tool.name, output)

if __name__ == '__main__':

if __name__ == "__main__":
unittest.main()