diff --git a/Makefile b/Makefile
index a58a8f7..48114d7 100644
--- a/Makefile
+++ b/Makefile
@@ -115,8 +115,15 @@ test-docs: ## Verify docs site builds cleanly
cd site && npx vitepress build docs
.PHONY: test-coverage
-test-coverage: ## Run all tests with coverage reports
- cd $(BACKEND) && uv run pytest tests/ --tb=short -v
+test-coverage: test-coverage-backend test-coverage-frontend ## Run all tests with coverage reports
+
+.PHONY: test-coverage-backend
+test-coverage-backend: ## Backend tests with coverage
+ cd $(BACKEND) && uv run pytest tests/ --cov --cov-report=term-missing --tb=short -v
+
+.PHONY: test-coverage-frontend
+test-coverage-frontend: ## Frontend tests with coverage
+ cd $(FRONTEND) && pnpm test --coverage
cd $(FRONTEND) && pnpm test
# ─── Docker ───────────────────────────────────────────────
diff --git a/README.md b/README.md
index 39f29bd..4d97df9 100644
--- a/README.md
+++ b/README.md
@@ -68,7 +68,7 @@ See `.env.example` for all options.
## Testing
```bash
-make test # Run all tests (149 backend + 29 frontend + docs build)
+make test # Run all tests (591 backend + 182 frontend + docs build)
make test-backend # Backend tests only
make test-frontend # Frontend tests only
make test-docs # Docs build check
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index d7c2876..7b7824f 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -63,6 +63,18 @@ openmlr = "openmlr.main:main"
dev-dependencies = [
"pytest>=8.0.0",
"pytest-asyncio>=0.24.0",
+ "pytest-cov>=6.0.0",
"aiosqlite>=0.20.0",
"httpx>=0.28.0",
]
+
+[tool.coverage.run]
+source = ["openmlr"]
+omit = ["openmlr/db/migrations/*"]
+
+[tool.coverage.report]
+exclude_lines = [
+ "pragma: no cover",
+ "if __name__ == .__main__.:",
+ "raise NotImplementedError",
+]
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
index e7eb1c4..0731f80 100644
--- a/backend/tests/conftest.py
+++ b/backend/tests/conftest.py
@@ -105,11 +105,15 @@ async def client() -> AsyncGenerator[httpx.AsyncClient, None]:
from openmlr.app import app
from openmlr.db.engine import get_db as engine_get_db
from openmlr.dependencies import get_db as dep_get_db
+ from openmlr.config import AgentConfig
# Override both the canonical get_db *and* the re-export in dependencies
app.dependency_overrides[engine_get_db] = _override_get_db
app.dependency_overrides[dep_get_db] = _override_get_db
+ # Set minimal app state since lifespan is skipped
+ app.state.config = AgentConfig()
+
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(
transport=transport,
diff --git a/backend/tests/test_agent_loop.py b/backend/tests/test_agent_loop.py
new file mode 100644
index 0000000..77a6e77
--- /dev/null
+++ b/backend/tests/test_agent_loop.py
@@ -0,0 +1,374 @@
+"""Tests for agent loop — tool execution, approval, undo, compact, submissions."""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from openmlr.agent.types import AgentEvent, Message, ToolCall, ToolSpec, Submission, OpType, LLMResult
+from openmlr.agent.session import Session
+from openmlr.agent.context import ContextManager
+from openmlr.agent.loop import (
+ _execute_tool, _handle_approval, _undo, _compact,
+ _run_agent, run_agent_turn, submission_loop,
+ _stream_llm_call, _non_stream_llm_call, _compact_llm_call,
+)
+from openmlr.config import AgentConfig
+from openmlr.tools.registry import ToolRouter
+
+pytestmark = pytest.mark.asyncio
+
+# ── Test fixtures ──────────────────────────────────────────
+
+@pytest.fixture
+def config():
+ return AgentConfig(model_name="test/model", max_iterations=10, stream=False)
+
+@pytest.fixture
+def mock_session(config):
+ session = MagicMock(spec=Session)
+ session.config = config
+ session.submission_queue = MagicMock()
+ session.context_manager = MagicMock(spec=ContextManager)
+ session.emit = AsyncMock()
+ session.cancel = MagicMock()
+ session.clear_cancel = MagicMock()
+ session.is_cancelled = MagicMock(return_value=False)
+ session.pending_approval = None
+ session.pending_answers = None
+ session.turn_count = 0
+ session.on_event = MagicMock()
+ session.update_model = MagicMock()
+ return session
+
+@pytest.fixture
+def mock_router():
+ router = MagicMock(spec=ToolRouter)
+ router.call_tool = AsyncMock(return_value=("tool output", True))
+ router.set_mode = MagicMock()
+ router.get_tool_specs_for_llm = MagicMock(return_value=[])
+ router.get_tool = MagicMock(return_value=None)
+ router.get_mode = MagicMock(return_value="execute")
+ return router
+
+
+# ── Tool Execution ─────────────────────────────────────────
+
+class TestExecuteTool:
+ async def test_executes_tool_and_returns_output(self, mock_session, mock_router):
+ tc = ToolCall(id="tc1", name="bash", arguments={"cmd": "ls"})
+ mock_router.call_tool.return_value = ("file list output", True)
+
+ output, success = await _execute_tool(mock_session, mock_router, tc)
+
+ assert output == "file list output"
+ assert success is True
+ mock_router.call_tool.assert_called_once_with("bash", {"cmd": "ls"}, session=mock_session)
+ assert mock_session.emit.called
+
+ async def test_handles_tool_error(self, mock_session, mock_router):
+ tc = ToolCall(id="tc1", name="bad_tool", arguments={})
+ mock_router.call_tool.side_effect = RuntimeError("execution failed")
+
+ output, success = await _execute_tool(mock_session, mock_router, tc)
+
+ assert success is False
+ assert "Tool execution error" in output
+
+ async def test_emits_both_state_changes(self, mock_session, mock_router):
+ tc = ToolCall(id="tc1", name="test", arguments={})
+
+ await _execute_tool(mock_session, mock_router, tc)
+
+ emitted_event_types = []
+ for call in mock_session.emit.call_args_list:
+ args = call[0]
+ if args and hasattr(args[0], 'event_type'):
+ emitted_event_types.append(args[0].event_type)
+ assert len(emitted_event_types) >= 2
+ assert "tool_state_change" in emitted_event_types
+
+
+# ── Approval Handling ──────────────────────────────────────
+
+class TestHandleApproval:
+ async def test_approves_tool_calls(self, mock_session, mock_router):
+ tcs = [ToolCall(id="tc1", name="bash", arguments={"cmd": "ls"})]
+ mock_session.pending_approval = {"tool_calls": tcs, "tool_router": mock_router}
+ mock_session.context_manager = MagicMock()
+ mock_session.context_manager.add_message = MagicMock()
+ mock_session.context_manager.get_messages = MagicMock(return_value=[])
+ mock_session.config.stream = False
+
+ await _handle_approval(mock_session, mock_router, {"tc1": True})
+
+ mock_router.call_tool.assert_called_once_with("bash", {"cmd": "ls"}, session=mock_session)
+
+ async def test_rejects_tool_calls(self, mock_session, mock_router):
+ tcs = [ToolCall(id="tc1", name="dangerous", arguments={})]
+ mock_session.pending_approval = {"tool_calls": tcs, "tool_router": mock_router}
+ mock_session.context_manager = MagicMock()
+ mock_session.context_manager.add_message = MagicMock()
+ mock_session.context_manager.get_messages = MagicMock(return_value=[])
+ mock_session.config.stream = False
+
+ await _handle_approval(mock_session, mock_router, {"tc1": False})
+
+ assert mock_router.call_tool.called == False
+
+ async def test_no_pending_approval_returns(self, mock_session, mock_router):
+ mock_session.pending_approval = None
+ await _handle_approval(mock_session, mock_router, {})
+ # No exception, nothing happens
+
+
+# ── Undo ───────────────────────────────────────────────────
+
+class TestUndo:
+ async def test_undo_calls_context_manager(self, mock_session):
+ mock_session.context_manager.undo_last_turn.return_value = 3
+
+ await _undo(mock_session)
+
+ mock_session.context_manager.undo_last_turn.assert_called_once()
+ mock_session.emit.assert_called()
+
+
+# ── Compaction ─────────────────────────────────────────────
+
+class TestCompact:
+ async def test_compact_calls_context_manager(self, mock_session):
+ mock_session.context_manager.compact = AsyncMock(return_value="Summary of conversation")
+
+ await _compact(mock_session)
+
+ mock_session.context_manager.compact.assert_called_once()
+ mock_session.emit.assert_called()
+
+
+# ── Run Agent ──────────────────────────────────────────────
+
+class TestRunAgent:
+ async def test_runs_with_no_tool_calls(self, mock_session, mock_router):
+ """Agent processes a message, LLM returns content with no tool calls."""
+ mock_session.context_manager.get_messages.return_value = []
+ mock_session.context_manager.needs_compaction.return_value = False
+ mock_session.context_manager.get_token_usage.return_value = {"used": 100, "max": 200000, "ratio": 0.0}
+ mock_session.config.stream = False
+
+ with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen:
+ mock_gen.return_value = LLMResult(
+ content="I can help with that!",
+ tool_calls=[],
+ finish_reason="stop",
+ usage={"total_tokens": 50},
+ )
+
+ await _run_agent(mock_session, mock_router, "help me")
+
+ assert mock_session.context_manager.add_message.called
+ assert mock_session.emit.called
+
+ async def test_handles_error_gracefully(self, mock_session, mock_router):
+ mock_session.context_manager.get_messages.side_effect = RuntimeError("Something broke")
+
+ await _run_agent(mock_session, mock_router, "test")
+
+ # Should not raise, emits error event
+ assert mock_session.emit.called
+
+ async def test_cancelled_stops_early(self, mock_session, mock_router):
+ mock_session.is_cancelled.return_value = True
+ mock_session.context_manager.get_messages.return_value = []
+ mock_session.context_manager.needs_compaction.return_value = False
+ mock_session.context_manager.get_token_usage.return_value = {"used": 0, "max": 200000, "ratio": 0.0}
+
+ await _run_agent(mock_session, mock_router, "test")
+
+ mock_session.emit.assert_any_call(AgentEvent(event_type="interrupted"))
+
+
+class TestRunAgentTurn:
+ async def test_delegates_to_run_agent(self, mock_session, mock_router):
+ mock_session.context_manager.get_messages.return_value = []
+ mock_session.context_manager.needs_compaction.return_value = False
+ mock_session.context_manager.get_token_usage.return_value = {"ratio": 0.0}
+ mock_session.config.stream = False
+
+ with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen:
+ mock_gen.return_value = LLMResult(
+ content="Hello!", tool_calls=[], finish_reason="stop",
+ )
+
+ await run_agent_turn(mock_session, mock_router, "Hi", mode="plan")
+
+ mock_router.set_mode.assert_called_with("plan")
+
+ async def test_default_mode_is_execute(self, mock_session, mock_router):
+ mock_session.context_manager.get_messages.return_value = []
+ mock_session.context_manager.needs_compaction.return_value = False
+ mock_session.context_manager.get_token_usage.return_value = {"ratio": 0.0}
+ mock_session.config.stream = False
+
+ with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen:
+ mock_gen.return_value = LLMResult(
+ content="Ok", tool_calls=[], finish_reason="stop",
+ )
+ await run_agent_turn(mock_session, mock_router, "test", mode="unknown")
+
+ mock_router.set_mode.assert_called_with("execute")
+
+
+# ── Submissions ────────────────────────────────────────────
+
+class TestSubmissionLoop:
+ async def test_processes_user_input(self, mock_session, mock_router):
+ mock_session.submission_queue.get = AsyncMock(side_effect=[
+ Submission(op=OpType.USER_INPUT, data="hello"),
+ Submission(op=OpType.SHUTDOWN),
+ ])
+ mock_session.context_manager.get_messages.return_value = []
+ mock_session.context_manager.needs_compaction.return_value = False
+ mock_session.context_manager.get_token_usage.return_value = {"ratio": 0.0}
+ mock_session.config.stream = False
+
+ with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen:
+ mock_gen.return_value = LLMResult(
+ content="Hi!", tool_calls=[], finish_reason="stop",
+ )
+ await submission_loop(mock_session, mock_router)
+
+ assert mock_session.emit.called
+
+ async def test_processes_compact(self, mock_session, mock_router):
+ mock_session.submission_queue.get = AsyncMock(side_effect=[
+ Submission(op=OpType.COMPACT),
+ Submission(op=OpType.SHUTDOWN),
+ ])
+ mock_session.context_manager.compact = AsyncMock(return_value="Summary")
+
+ await submission_loop(mock_session, mock_router)
+
+ mock_session.context_manager.compact.assert_called_once()
+
+ async def test_processes_undo(self, mock_session, mock_router):
+ mock_session.submission_queue.get = AsyncMock(side_effect=[
+ Submission(op=OpType.UNDO),
+ Submission(op=OpType.SHUTDOWN),
+ ])
+ mock_session.context_manager.undo_last_turn.return_value = 3
+
+ await submission_loop(mock_session, mock_router)
+
+ mock_session.context_manager.undo_last_turn.assert_called_once()
+
+ async def test_processes_interrupt(self, mock_session, mock_router):
+ mock_session.submission_queue.get = AsyncMock(side_effect=[
+ Submission(op=OpType.INTERRUPT),
+ Submission(op=OpType.SHUTDOWN),
+ ])
+
+ await submission_loop(mock_session, mock_router)
+
+ mock_session.cancel.assert_called_once()
+
+ async def test_shutdown_exits(self, mock_session, mock_router):
+ mock_session.submission_queue.get = AsyncMock(return_value=Submission(op=OpType.SHUTDOWN))
+
+ await submission_loop(mock_session, mock_router)
+
+ mock_session.emit.assert_any_call(AgentEvent(event_type="shutdown"))
+
+
+# ── LLM Call Helpers ───────────────────────────────────────
+
+class TestNonStreamLLMCall:
+ async def test_returns_llm_result(self, mock_session):
+ mock_session.is_cancelled.return_value = False
+ messages = [{"role": "user", "content": "help"}]
+ tools = []
+
+ with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen:
+ mock_gen.return_value = LLMResult(
+ content="Response", tool_calls=[], finish_reason="stop",
+ )
+ result = await _non_stream_llm_call(mock_session, messages, tools)
+
+ assert result is not None
+ assert result.content == "Response"
+
+ async def test_emits_chunk_and_end(self, mock_session):
+ mock_session.is_cancelled.return_value = False
+
+ with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen:
+ mock_gen.return_value = LLMResult(
+ content="Output", tool_calls=[], finish_reason="stop",
+ )
+ await _non_stream_llm_call(mock_session, [], [])
+
+ assert mock_session.emit.called
+
+
+class TestStreamLLMCall:
+ async def test_returns_llm_result_from_chunks(self, mock_session):
+ mock_session.is_cancelled.return_value = False
+ mock_session.config = AgentConfig(model_name="test", stream=True)
+
+ async def mock_stream(messages, config, tools):
+ yield "Hello"
+ yield " world"
+
+ with patch("openmlr.agent.loop.LLMProvider.generate_stream") as mock_str:
+ mock_str.return_value = mock_stream(None, None, None)
+ result = await _stream_llm_call(mock_session, [], [])
+
+ assert result is not None
+ assert result.content == "Hello world"
+ assert result.finish_reason == "stop"
+
+ async def test_cancelled_returns_none(self, mock_session):
+ mock_session.is_cancelled.return_value = True
+
+ async def mock_stream(messages, config, tools):
+ yield "Hello"
+ if False:
+ yield
+
+ with patch("openmlr.agent.loop.LLMProvider.generate_stream") as mock_str:
+ mock_str.return_value = mock_stream(None, None, None)
+ result = await _stream_llm_call(mock_session, [], [])
+
+ assert result is None
+
+ async def test_handles_stream_with_tool_calls(self, mock_session):
+ mock_session.is_cancelled.return_value = False
+ mock_session.config = AgentConfig(model_name="test", stream=True)
+ tc = ToolCall(id="call_1", name="search", arguments={"query": "test"})
+
+ async def mock_stream(messages, config, tools):
+ yield "Finding..."
+ yield tc
+
+ with patch("openmlr.agent.loop.LLMProvider.generate_stream") as mock_str:
+ mock_str.return_value = mock_stream(None, None, None)
+ result = await _stream_llm_call(mock_session, [], [])
+
+ assert result is not None
+ assert result.content == "Finding..."
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].name == "search"
+ assert result.finish_reason == "tool_calls"
+
+
+# ── Compact LLM Call ───────────────────────────────────────
+
+class TestCompactLLMCall:
+ async def test_returns_content(self):
+ messages = [{"role": "user", "content": "summarize"}]
+ config = AgentConfig(model_name="test/title", stream=False)
+
+ with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen:
+ mock_gen.return_value = LLMResult(
+ content="A summary.", tool_calls=[], finish_reason="stop",
+ )
+ result = await _compact_llm_call(messages, config)
+
+ assert result == "A summary."
diff --git a/backend/tests/test_app.py b/backend/tests/test_app.py
new file mode 100644
index 0000000..222eeb2
--- /dev/null
+++ b/backend/tests/test_app.py
@@ -0,0 +1,43 @@
+"""Tests for app entrypoint and main module."""
+
+import pytest
+from openmlr.app import app
+
+pytestmark = pytest.mark.asyncio
+
+
+class TestAppCreation:
+ async def test_app_title(self):
+ assert app.title == "OpenMLR"
+
+ async def test_app_version(self):
+ assert app.version == "2.0.0"
+
+ async def test_app_routers_registered(self):
+ route_paths = [r.path for r in app.routes]
+ assert "/api/auth/register" in route_paths
+ assert "/api/auth/login" in route_paths
+ assert "/api/health" in route_paths
+ assert "/api/message" in route_paths
+
+ async def test_cors_middleware_configured(self):
+ from fastapi.middleware.cors import CORSMiddleware
+ middlewares = [m.cls for m in app.user_middleware]
+ assert CORSMiddleware in middlewares
+
+ async def test_global_exception_handler_configured(self):
+ from starlette.responses import JSONResponse
+ handlers = app.exception_handlers
+ assert Exception in handlers
+
+
+class TestMainModule:
+ async def test_main_is_callable(self):
+ from openmlr.main import main
+ assert callable(main)
+
+ async def test_main_contains_uvicorn_import(self):
+ from openmlr.main import main
+ import inspect
+ source = inspect.getsource(main)
+ assert "uvicorn" in source
diff --git a/backend/tests/test_celery_app.py b/backend/tests/test_celery_app.py
new file mode 100644
index 0000000..21c2cda
--- /dev/null
+++ b/backend/tests/test_celery_app.py
@@ -0,0 +1,42 @@
+"""Tests for Celery app configuration."""
+
+import pytest
+from celery import Celery
+
+
+class TestCeleryApp:
+ def test_is_celery_instance(self):
+ from openmlr.celery_app import celery_app
+ assert isinstance(celery_app, Celery)
+
+ def test_has_correct_name(self):
+ from openmlr.celery_app import celery_app
+ assert celery_app.main == "openmlr"
+
+ def test_config_has_serializer(self):
+ from openmlr.celery_app import celery_app
+ assert celery_app.conf.task_serializer == "json"
+ assert "json" in celery_app.conf.accept_content
+
+ def test_config_has_timezone(self):
+ from openmlr.celery_app import celery_app
+ assert celery_app.conf.timezone == "UTC"
+ assert celery_app.conf.enable_utc is True
+
+ def test_config_worker_settings(self):
+ from openmlr.celery_app import celery_app
+ assert celery_app.conf.worker_prefetch_multiplier == 1
+ assert celery_app.conf.task_acks_late is True
+
+ def test_config_result_expiry(self):
+ from openmlr.celery_app import celery_app
+ assert celery_app.conf.result_expires == 3600
+
+ def test_task_routing_configured(self):
+ from openmlr.celery_app import celery_app
+ routes = celery_app.conf.task_routes
+ assert routes is not None
+
+ def test_get_celery_app(self):
+ from openmlr.celery_app import get_celery_app, celery_app
+ assert get_celery_app() is celery_app
diff --git a/backend/tests/test_db_engine.py b/backend/tests/test_db_engine.py
new file mode 100644
index 0000000..aa47ab3
--- /dev/null
+++ b/backend/tests/test_db_engine.py
@@ -0,0 +1,43 @@
+"""Tests for database engine factory and configuration."""
+
+import pytest
+
+
+class TestEngineConfig:
+ def test_database_url_exists(self):
+ from openmlr.db.engine import DATABASE_URL
+ assert DATABASE_URL is not None
+ assert len(DATABASE_URL) > 0
+
+ def test_engine_created(self):
+ from openmlr.db.engine import engine
+ assert engine is not None
+
+ def test_async_session_created(self):
+ from openmlr.db.engine import async_session
+ assert async_session is not None
+
+ def test_worker_engine_context_var(self):
+ from openmlr.db.engine import _worker_engine
+ from contextvars import ContextVar
+ assert isinstance(_worker_engine, ContextVar)
+
+
+@pytest.mark.asyncio
+class TestGetWorkerSession:
+ async def test_returns_sessionmaker(self):
+ from openmlr.db.engine import get_worker_session
+ result = get_worker_session()
+ from sqlalchemy.ext.asyncio import async_sessionmaker
+ assert isinstance(result, async_sessionmaker)
+
+
+@pytest.mark.asyncio
+class TestGetDB:
+ async def test_yields_session(self):
+ from openmlr.db.engine import get_db
+ sessions = []
+ async for s in get_db():
+ sessions.append(s)
+ break
+ assert len(sessions) == 1
diff --git a/backend/tests/test_db_operations.py b/backend/tests/test_db_operations.py
new file mode 100644
index 0000000..ba7263d
--- /dev/null
+++ b/backend/tests/test_db_operations.py
@@ -0,0 +1,393 @@
+"""Tests for database CRUD operations."""
+
+import pytest
+import pytest_asyncio
+from sqlalchemy.ext.asyncio import AsyncSession
+
+pytestmark = pytest.mark.asyncio
+
+from openmlr.db import operations as ops
+from openmlr.db.models import UserSetting
+
+
+class TestConversationOperations:
+ async def test_create_conversation(self, db_session: AsyncSession, test_user):
+ conv = await ops.create_conversation(db_session, test_user.id, title="Test Conv")
+ assert conv.id is not None
+ assert conv.title == "Test Conv"
+ assert conv.user_id == test_user.id
+ assert conv.mode == "general"
+
+ async def test_create_conversation_with_model(self, db_session: AsyncSession, test_user):
+ conv = await ops.create_conversation(db_session, test_user.id, model="gpt-4o", mode="coding")
+ assert conv.model == "gpt-4o"
+ assert conv.mode == "coding"
+
+ async def test_get_conversations_empty(self, db_session: AsyncSession, test_user):
+ convs = await ops.get_conversations(db_session, test_user.id)
+ assert convs == []
+
+ async def test_get_conversations(self, db_session: AsyncSession, test_user):
+ await ops.create_conversation(db_session, test_user.id, title="Conv 1")
+ await ops.create_conversation(db_session, test_user.id, title="Conv 2")
+ convs = await ops.get_conversations(db_session, test_user.id)
+ assert len(convs) == 2
+ assert convs[0].title == "Conv 2" # most recent first
+
+ async def test_get_conversation_by_id(self, db_session: AsyncSession, test_user):
+ conv = await ops.create_conversation(db_session, test_user.id, title="Find Me")
+ found = await ops.get_conversation_by_id(db_session, conv.id)
+ assert found is not None
+ assert found.title == "Find Me"
+
+ async def test_get_conversation_by_id_not_found(self, db_session: AsyncSession):
+ found = await ops.get_conversation_by_id(db_session, 9999)
+ assert found is None
+
+ async def test_get_conversation_by_uuid(self, db_session: AsyncSession, test_user):
+ conv = await ops.create_conversation(db_session, test_user.id)
+ found = await ops.get_conversation_by_uuid(db_session, conv.uuid)
+ assert found is not None
+ assert found.id == conv.id
+
+ async def test_get_conversation_by_uuid_not_found(self, db_session: AsyncSession):
+ found = await ops.get_conversation_by_uuid(db_session, "nonexistent")
+ assert found is None
+
+ async def test_delete_conversation(self, db_session: AsyncSession, test_user):
+ conv = await ops.create_conversation(db_session, test_user.id)
+ deleted = await ops.delete_conversation(db_session, conv.id)
+ assert deleted is True
+ found = await ops.get_conversation_by_id(db_session, conv.id)
+ assert found is None
+
+ async def test_delete_nonexistent(self, db_session: AsyncSession):
+ deleted = await ops.delete_conversation(db_session, 9999)
+ assert deleted is False
+
+ async def test_update_title(self, db_session: AsyncSession, test_user):
+ conv = await ops.create_conversation(db_session, test_user.id, title="Old")
+ await ops.update_conversation_title(db_session, conv.id, "New Title")
+ found = await ops.get_conversation_by_id(db_session, conv.id)
+ assert found.title == "New Title"
+
+ async def test_update_model(self, db_session: AsyncSession, test_user):
+ conv = await ops.create_conversation(db_session, test_user.id)
+ await ops.update_conversation_model(db_session, conv.id, "claude-sonnet-4")
+ found = await ops.get_conversation_by_id(db_session, conv.id)
+ assert found.model == "claude-sonnet-4"
+
+ async def test_increment_user_message_count(self, db_session: AsyncSession, test_user):
+ conv = await ops.create_conversation(db_session, test_user.id)
+ assert conv.user_message_count == 0
+ await ops.increment_user_message_count(db_session, conv.id)
+ found = await ops.get_conversation_by_id(db_session, conv.id)
+ assert found.user_message_count == 1
+
+ async def test_conversations_isolated_by_user(self, db_session: AsyncSession, test_user):
+ # Create another user
+ from openmlr.db.models import User
+ from openmlr.auth.security import hash_password
+ user2 = User(username="user2", password_hash=hash_password("pwd"), is_active=True)
+ db_session.add(user2)
+ await db_session.flush()
+
+ await ops.create_conversation(db_session, test_user.id, title="User 1 Conv")
+ await ops.create_conversation(db_session, user2.id, title="User 2 Conv")
+
+ convs_u1 = await ops.get_conversations(db_session, test_user.id)
+ convs_u2 = await ops.get_conversations(db_session, user2.id)
+ assert len(convs_u1) == 1
+ assert len(convs_u2) == 1
+ assert convs_u1[0].title == "User 1 Conv"
+ assert convs_u2[0].title == "User 2 Conv"
+
+
+class TestMessageOperations:
+ @pytest_asyncio.fixture(autouse=True)
+ async def _conv(self, db_session: AsyncSession, test_user):
+ self.conv = await ops.create_conversation(db_session, test_user.id, title="Msg Test")
+
+ async def test_add_message(self, db_session: AsyncSession):
+ msg = await ops.add_message(db_session, self.conv.id, "user", "Hello!")
+ assert msg.id is not None
+ assert msg.role == "user"
+ assert msg.content == "Hello!"
+ assert msg.conversation_id == self.conv.id
+
+ async def test_add_message_with_metadata(self, db_session: AsyncSession):
+ msg = await ops.add_message(
+ db_session, self.conv.id, "assistant", "Done",
+ metadata={"tool": "search", "duration": 1.5},
+ )
+ assert msg.meta == {"tool": "search", "duration": 1.5}
+
+ async def test_get_messages_empty(self, db_session: AsyncSession):
+ msgs = await ops.get_messages(db_session, self.conv.id)
+ assert msgs == []
+
+ async def test_get_messages(self, db_session: AsyncSession):
+ await ops.add_message(db_session, self.conv.id, "user", "First")
+ await ops.add_message(db_session, self.conv.id, "assistant", "Second")
+ msgs = await ops.get_messages(db_session, self.conv.id)
+ assert len(msgs) == 2
+ assert msgs[0].content == "First"
+ assert msgs[1].content == "Second"
+
+ async def test_clear_messages(self, db_session: AsyncSession):
+ await ops.add_message(db_session, self.conv.id, "user", "A")
+ await ops.add_message(db_session, self.conv.id, "assistant", "B")
+ await ops.clear_messages(db_session, self.conv.id)
+ msgs = await ops.get_messages(db_session, self.conv.id)
+ assert msgs == []
+
+
+class TestSettingsOperations:
+ async def test_set_and_get_user_setting(self, db_session: AsyncSession, test_user):
+ await ops.set_user_setting(db_session, test_user.id, "agent", "default_model", "gpt-4o")
+ val = await ops.get_user_setting(db_session, test_user.id, "agent", "default_model")
+ assert val == "gpt-4o"
+
+ async def test_set_user_setting_update(self, db_session: AsyncSession, test_user):
+ await ops.set_user_setting(db_session, test_user.id, "agent", "yolo_mode", True)
+ await ops.set_user_setting(db_session, test_user.id, "agent", "yolo_mode", False)
+ val = await ops.get_user_setting(db_session, test_user.id, "agent", "yolo_mode")
+ assert val is False
+
+ async def test_get_nonexistent_setting(self, db_session: AsyncSession, test_user):
+ val = await ops.get_user_setting(db_session, test_user.id, "agent", "nonexistent")
+ assert val is None
+
+ async def test_get_all_settings(self, db_session: AsyncSession, test_user):
+ await ops.set_user_setting(db_session, test_user.id, "agent", "model", "gpt-4o")
+ await ops.set_user_setting(db_session, test_user.id, "providers", "openai_key", "sk-123")
+ settings = await ops.get_all_settings(db_session, test_user.id)
+ assert "agent" in settings
+ assert "providers" in settings
+ assert settings["agent"]["model"] == "gpt-4o"
+ assert settings["providers"]["openai_key"] == "sk-123"
+
+ async def test_get_all_settings_by_category(self, db_session: AsyncSession, test_user):
+ await ops.set_user_setting(db_session, test_user.id, "agent", "model", "gpt-4o")
+ await ops.set_user_setting(db_session, test_user.id, "providers", "key", "val")
+ agent_settings = await ops.get_all_settings(db_session, test_user.id, category="agent")
+ assert "agent" in agent_settings
+ assert "providers" not in agent_settings
+
+ async def test_delete_user_setting(self, db_session: AsyncSession, test_user):
+ await ops.set_user_setting(db_session, test_user.id, "agent", "model", "gpt-4o")
+ await ops.delete_user_setting(db_session, test_user.id, "agent", "model")
+ val = await ops.get_user_setting(db_session, test_user.id, "agent", "model")
+ assert val is None
+
+ async def test_set_user_setting_int_value(self, db_session: AsyncSession, test_user):
+ await ops.set_user_setting(db_session, test_user.id, "agent", "max_iterations", 100)
+ val = await ops.get_user_setting(db_session, test_user.id, "agent", "max_iterations")
+ assert val == 100
+
+ async def test_set_user_setting_float_value(self, db_session: AsyncSession, test_user):
+ await ops.set_user_setting(db_session, test_user.id, "agent", "threshold", 0.85)
+ val = await ops.get_user_setting(db_session, test_user.id, "agent", "threshold")
+ assert val == 0.85
+
+ async def test_get_user_agent_settings(self, db_session: AsyncSession, test_user):
+ await ops.set_user_setting(db_session, test_user.id, "agent", "default_model", "claude")
+ await ops.set_user_setting(db_session, test_user.id, "agent", "yolo_mode", True)
+ settings = await ops.get_user_agent_settings(db_session, test_user.id)
+ assert settings["default_model"] == "claude"
+ assert settings["yolo_mode"] is True
+
+
+class TestTaskOperations:
+ @pytest_asyncio.fixture(autouse=True)
+ async def _conv(self, db_session: AsyncSession, test_user):
+ self.conv = await ops.create_conversation(db_session, test_user.id)
+
+ async def test_upsert_tasks_create(self, db_session: AsyncSession):
+ tasks = [
+ {"title": "Task 1", "status": "pending"},
+ {"title": "Task 2", "status": "in_progress"},
+ ]
+ result = await ops.upsert_conversation_tasks(db_session, self.conv.id, tasks)
+ assert len(result) == 2
+ assert result[0].title == "Task 1"
+ assert result[0].order_index == 0
+ assert result[1].order_index == 1
+
+ async def test_upsert_tasks_replace(self, db_session: AsyncSession):
+ await ops.upsert_conversation_tasks(db_session, self.conv.id, [
+ {"title": "Old Task"},
+ ])
+ result = await ops.upsert_conversation_tasks(db_session, self.conv.id, [
+ {"title": "New Task"},
+ ])
+ assert len(result) == 1
+ assert result[0].title == "New Task"
+
+ async def test_get_tasks_empty(self, db_session: AsyncSession):
+ tasks = await ops.get_conversation_tasks(db_session, self.conv.id)
+ assert tasks == []
+
+ async def test_get_tasks(self, db_session: AsyncSession):
+ await ops.upsert_conversation_tasks(db_session, self.conv.id, [
+ {"title": "T1", "status": "pending", "priority": "high"},
+ {"title": "T2", "status": "completed"},
+ ])
+ tasks = await ops.get_conversation_tasks(db_session, self.conv.id)
+ assert len(tasks) == 2
+ assert tasks[0].title == "T1"
+ assert tasks[0].priority == "high"
+
+ async def test_update_task_status(self, db_session: AsyncSession):
+ await ops.upsert_conversation_tasks(db_session, self.conv.id, [
+ {"title": "Do this", "status": "pending"},
+ ])
+ ok = await ops.update_task_status(db_session, self.conv.id, 0, "completed")
+ assert ok is True
+ tasks = await ops.get_conversation_tasks(db_session, self.conv.id)
+ assert tasks[0].status == "completed"
+
+ async def test_update_task_status_out_of_range(self, db_session: AsyncSession):
+ await ops.upsert_conversation_tasks(db_session, self.conv.id, [
+ {"title": "Only one"},
+ ])
+ ok = await ops.update_task_status(db_session, self.conv.id, 5, "completed")
+ assert ok is False
+
+
+class TestResourceOperations:
+ @pytest_asyncio.fixture(autouse=True)
+ async def _conv(self, db_session: AsyncSession, test_user):
+ self.conv = await ops.create_conversation(db_session, test_user.id)
+
+ async def test_add_resource(self, db_session: AsyncSession):
+ res = await ops.add_conversation_resource(
+ db_session, self.conv.id,
+ title="Paper 1", resource_type="paper", url="https://example.com",
+ )
+ assert res.id is not None
+ assert res.title == "Paper 1"
+ assert res.type == "paper"
+ assert res.url == "https://example.com"
+
+ async def test_get_resources(self, db_session: AsyncSession):
+ await ops.add_conversation_resource(db_session, self.conv.id, title="R1", resource_type="doc")
+ await ops.add_conversation_resource(db_session, self.conv.id, title="R2", resource_type="code")
+ resources = await ops.get_conversation_resources(db_session, self.conv.id)
+ assert len(resources) == 2
+
+ async def test_get_resource_by_id(self, db_session: AsyncSession):
+ res = await ops.add_conversation_resource(
+ db_session, self.conv.id, title="Test", resource_type="doc",
+ )
+ found = await ops.get_resource_by_id(db_session, res.resource_id)
+ assert found is not None
+ assert found.title == "Test"
+
+ async def test_upsert_resources_replace(self, db_session: AsyncSession):
+ await ops.add_conversation_resource(db_session, self.conv.id, title="Old", resource_type="doc")
+ result = await ops.upsert_conversation_resources(db_session, self.conv.id, [
+ {"title": "New", "type": "doc"},
+ ])
+ assert len(result) == 1
+ assert result[0].title == "New"
+
+ async def test_upsert_plan_resource_new(self, db_session: AsyncSession):
+ res = await ops.upsert_plan_resource(db_session, self.conv.id, "# Plan content")
+ assert res.title == "PLAN.md"
+ assert res.type == "plan"
+ assert res.content == "# Plan content"
+
+ async def test_upsert_plan_resource_update(self, db_session: AsyncSession):
+ await ops.upsert_plan_resource(db_session, self.conv.id, "First version")
+ res = await ops.upsert_plan_resource(db_session, self.conv.id, "Updated version")
+ assert res.content == "Updated version"
+
+ async def test_upsert_paper_resource(self, db_session: AsyncSession):
+ res = await ops.upsert_paper_resource(
+ db_session, self.conv.id, "My Paper", "## Abstract\nContent",
+ )
+ assert res.title == "My Paper"
+ assert res.type == "paper"
+ assert "Abstract" in res.content
+
+ async def test_upsert_resource_create(self, db_session: AsyncSession):
+ res = await ops.upsert_resource(
+ db_session, self.conv.id,
+ resource_id="custom-id", title="Custom", resource_type="report",
+ content="Report content",
+ )
+ assert res.resource_id == "custom-id"
+ assert res.content == "Report content"
+
+ async def test_upsert_resource_update(self, db_session: AsyncSession):
+ await ops.upsert_resource(
+ db_session, self.conv.id,
+ resource_id="rid", title="Old Title", resource_type="doc", content="Old",
+ )
+ res = await ops.upsert_resource(
+ db_session, self.conv.id,
+ resource_id="rid", title="New Title", resource_type="doc", content="New",
+ )
+ assert res.title == "New Title"
+ assert res.content == "New"
+
+
+class TestAgentJobOperations:
+ @pytest_asyncio.fixture(autouse=True)
+ async def _conv(self, db_session: AsyncSession, test_user):
+ self.conv = await ops.create_conversation(db_session, test_user.id)
+
+ async def test_create_agent_job(self, db_session: AsyncSession):
+ job = await ops.create_agent_job(
+ db_session, self.conv.id, self.conv.user_id, "Process this",
+ )
+ assert job.job_id is not None
+ assert job.status == "queued"
+ assert job.message == "Process this"
+
+ async def test_get_agent_job(self, db_session: AsyncSession):
+ job = await ops.create_agent_job(
+ db_session, self.conv.id, self.conv.user_id, "Test",
+ )
+ found = await ops.get_agent_job(db_session, job.job_id)
+ assert found is not None
+ assert found.status == "queued"
+
+ async def test_get_agent_job_not_found(self, db_session: AsyncSession):
+ found = await ops.get_agent_job(db_session, "nonexistent")
+ assert found is None
+
+ async def test_get_active_jobs(self, db_session: AsyncSession):
+ job1 = await ops.create_agent_job(db_session, self.conv.id, self.conv.user_id, "Job 1")
+ job2 = await ops.create_agent_job(db_session, self.conv.id, self.conv.user_id, "Job 2")
+ # Complete one
+ await ops.update_job_status(db_session, job1.job_id, "completed")
+ active = await ops.get_active_jobs_for_conversation(db_session, self.conv.id)
+ assert len(active) == 1
+ assert active[0].job_id == job2.job_id
+
+ async def test_update_job_status_to_running(self, db_session: AsyncSession):
+ job = await ops.create_agent_job(db_session, self.conv.id, self.conv.user_id, "Test")
+ ok = await ops.update_job_status(db_session, job.job_id, "running")
+ assert ok is True
+ found = await ops.get_agent_job(db_session, job.job_id)
+ assert found.status == "running"
+ assert found.started_at is not None
+
+ async def test_update_job_status_to_completed(self, db_session: AsyncSession):
+ job = await ops.create_agent_job(db_session, self.conv.id, self.conv.user_id, "Test")
+ await ops.update_job_status(db_session, job.job_id, "completed")
+ found = await ops.get_agent_job(db_session, job.job_id)
+ assert found.status == "completed"
+ assert found.completed_at is not None
+
+ async def test_update_job_status_not_found(self, db_session: AsyncSession):
+ ok = await ops.update_job_status(db_session, "nonexistent", "completed")
+ assert ok is False
+
+ async def test_update_job_status_with_error(self, db_session: AsyncSession):
+ job = await ops.create_agent_job(db_session, self.conv.id, self.conv.user_id, "Test")
+ await ops.update_job_status(db_session, job.job_id, "failed", error="Something broke")
+ found = await ops.get_agent_job(db_session, job.job_id)
+ assert found.error == "Something broke"
diff --git a/backend/tests/test_dependencies.py b/backend/tests/test_dependencies.py
new file mode 100644
index 0000000..c5fc776
--- /dev/null
+++ b/backend/tests/test_dependencies.py
@@ -0,0 +1,48 @@
+"""Tests for FastAPI dependencies — get_config, get_current_user."""
+
+import pytest
+from httpx import AsyncClient
+
+pytestmark = pytest.mark.asyncio
+
+
+class TestGetConfig:
+ async def test_get_config_returns_agent_config(self):
+ from openmlr.dependencies import get_config
+ config = get_config()
+ assert config is not None
+ assert hasattr(config, "model_name")
+ assert hasattr(config, "max_iterations")
+
+ async def test_get_config_is_cached(self):
+ from openmlr.dependencies import get_config
+ config1 = get_config()
+ config2 = get_config()
+ assert config1 is config2
+
+
+class TestGetCurrentUser:
+ async def test_valid_token_returns_user(self, auth_client: AsyncClient):
+ resp = await auth_client.get("/api/auth/me")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["username"] == "testuser"
+
+ async def test_no_token_returns_401(self, client: AsyncClient):
+ resp = await client.get("/api/auth/me")
+ assert resp.status_code == 401
+
+ async def test_invalid_token_returns_401(self, client: AsyncClient):
+ client.headers["Authorization"] = "Bearer invalid.token.here"
+ resp = await client.get("/api/auth/me")
+ assert resp.status_code == 401
+
+
+class TestGetDB:
+ async def test_db_session_yielded(self, client: AsyncClient):
+ from openmlr.dependencies import get_db
+ sessions = []
+ async for s in get_db():
+ sessions.append(s)
+ break
+ assert len(sessions) == 1
diff --git a/backend/tests/test_job_manager.py b/backend/tests/test_job_manager.py
new file mode 100644
index 0000000..8305cb8
--- /dev/null
+++ b/backend/tests/test_job_manager.py
@@ -0,0 +1,84 @@
+"""Tests for JobManager — background job creation and status tracking."""
+
+import pytest
+import pytest_asyncio
+from sqlalchemy.ext.asyncio import AsyncSession
+
+pytestmark = pytest.mark.asyncio
+
+from openmlr.services.job_manager import JobManager, get_job_manager
+from openmlr.db import operations as ops
+
+
+@pytest_asyncio.fixture
+async def conversation(db_session: AsyncSession, test_user):
+ return await ops.create_conversation(db_session, test_user.id, title="Job Test")
+
+
+class TestJobManager:
+ async def test_get_job_manager_singleton(self):
+ jm1 = get_job_manager()
+ jm2 = get_job_manager()
+ assert jm1 is jm2
+
+ async def test_create_job_disabled_by_default(self, db_session: AsyncSession, conversation, test_user):
+ jm = JobManager()
+ # USE_BACKGROUND_JOBS is controlled by env — test without making assumptions
+ from openmlr.services.job_manager import USE_BACKGROUND_JOBS
+ job = await jm.create_job(
+ db=db_session,
+ conversation_id=conversation.id,
+ user_id=test_user.id,
+ message="Test message",
+ )
+ if not USE_BACKGROUND_JOBS:
+ assert job is None
+ else:
+ assert job is not None
+ assert job.status == "queued"
+
+ async def test_get_job_status_nonexistent(self, db_session: AsyncSession):
+ jm = JobManager()
+ status = await jm.get_job_status(db_session, "nonexistent")
+ assert status is None
+
+ async def test_get_job_status_from_db(self, db_session: AsyncSession, conversation, test_user):
+ job = await ops.create_agent_job(
+ db_session, conversation.id, test_user.id, "Test",
+ )
+ jm = JobManager()
+ status = await jm.get_job_status(db_session, job.job_id)
+ assert status is not None
+ assert status["job_id"] == job.job_id
+ assert status["status"] == "queued"
+ assert "created_at" in status
+
+ async def test_get_active_jobs(self, db_session: AsyncSession, conversation, test_user):
+ job1 = await ops.create_agent_job(db_session, conversation.id, test_user.id, "J1")
+ job2 = await ops.create_agent_job(db_session, conversation.id, test_user.id, "J2")
+ await ops.update_job_status(db_session, job2.job_id, "completed")
+
+ jm = JobManager()
+ active = await jm.get_active_jobs(db_session, conversation.id)
+ assert len(active) == 1
+ assert active[0]["job_id"] == job1.job_id
+
+ async def test_cancel_queued_job(self, db_session: AsyncSession, conversation, test_user):
+ job = await ops.create_agent_job(db_session, conversation.id, test_user.id, "Test")
+ jm = JobManager()
+ cancelled = await jm.cancel_job(db_session, job.job_id)
+ assert cancelled is True
+ found = await ops.get_agent_job(db_session, job.job_id)
+ assert found.status == "cancelled"
+
+ async def test_cancel_nonexistent_job(self, db_session: AsyncSession):
+ jm = JobManager()
+ cancelled = await jm.cancel_job(db_session, "nonexistent")
+ assert cancelled is False
+
+ async def test_cancel_already_running_job(self, db_session: AsyncSession, conversation, test_user):
+ job = await ops.create_agent_job(db_session, conversation.id, test_user.id, "Test")
+ await ops.update_job_status(db_session, job.job_id, "running")
+ jm = JobManager()
+ cancelled = await jm.cancel_job(db_session, job.job_id)
+ assert cancelled is False
diff --git a/backend/tests/test_llm.py b/backend/tests/test_llm.py
new file mode 100644
index 0000000..e636ab1
--- /dev/null
+++ b/backend/tests/test_llm.py
@@ -0,0 +1,246 @@
+"""Unit tests for LLMProvider — API key resolution, model normalization, tool param conversion, retry logic."""
+
+import pytest
+
+from openmlr.agent.llm import LLMProvider
+
+
+class TestGetApiKey:
+ @pytest.mark.parametrize("model_name,env_var", [
+ ("openai/gpt-4o", "OPENAI_API_KEY"),
+ ("anthropic/claude-sonnet-4", "ANTHROPIC_API_KEY"),
+ ("openrouter/openai/gpt-4o", "OPENROUTER_API_KEY"),
+ ("opencode-go/glm-5.1", "OPENCODE_GO_API_KEY"),
+ ])
+ def test_model_prefix_maps_to_env_var(self, monkeypatch, model_name, env_var):
+ monkeypatch.setenv(env_var, f"test-key-{env_var}")
+ key = LLMProvider._get_api_key(model_name)
+ assert key == f"test-key-{env_var}"
+
+ def test_local_model_uses_local_api_key(self, monkeypatch):
+ monkeypatch.setenv("LOCAL_API_KEY", "local")
+ key = LLMProvider._get_api_key("local/default")
+ assert key == "local"
+
+ def test_ollama_defaults_to_not_needed(self, monkeypatch):
+ monkeypatch.delenv("LOCAL_API_KEY", raising=False)
+ monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+ monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
+ monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
+ key = LLMProvider._get_api_key("ollama/llama3.1")
+ assert key == "not-needed"
+
+ def test_lmstudio_defaults_to_not_needed(self, monkeypatch):
+ monkeypatch.delenv("LOCAL_API_KEY", raising=False)
+ monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+ monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
+ monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
+ key = LLMProvider._get_api_key("lmstudio/default")
+ assert key == "not-needed"
+
+ def test_fallback_to_any_available_key(self, monkeypatch):
+ monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+ monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-anthro")
+ monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
+ key = LLMProvider._get_api_key("unknown/model")
+ assert key == "sk-anthro"
+
+
+class TestNormalizeModel:
+ @pytest.mark.parametrize("full_name,normalized", [
+ ("openai/gpt-4o", "gpt-4o"),
+ ("anthropic/claude-sonnet-4", "claude-sonnet-4"),
+ ("openrouter/anthropic/claude-3-sonnet", "anthropic/claude-3-sonnet"),
+ ("ollama/llama3.1", "llama3.1"),
+ ("lmstudio/default", "default"),
+ ("local/custom-model", "custom-model"),
+ ("opencode-go/glm-5.1", "glm-5.1"),
+ ])
+ def test_normalize_strips_prefix(self, full_name, normalized):
+ result = LLMProvider._normalize_model(full_name)
+ assert result == normalized
+
+ def test_no_prefix_passes_through(self):
+ result = LLMProvider._normalize_model("gpt-4o")
+ assert result == "gpt-4o"
+
+
+class TestGetBaseUrl:
+ def test_local_base_url(self):
+ assert LLMProvider._get_base_url("local/default") == "http://localhost:8000/v1"
+
+ def test_ollama_base_url(self):
+ assert LLMProvider._get_base_url("ollama/llama3") == "http://localhost:11434/v1"
+
+ def test_lmstudio_base_url(self):
+ assert LLMProvider._get_base_url("lmstudio/default") == "http://localhost:1234/v1"
+
+ def test_openrouter_base_url(self):
+ assert LLMProvider._get_base_url("openrouter/gpt-4o") == "https://openrouter.ai/api/v1"
+
+ def test_opencode_go_base_url(self):
+ assert LLMProvider._get_base_url("opencode-go/glm-5.1") == "https://opencode.ai/zen/go/v1"
+
+ def test_standard_provider_no_base_url(self):
+ assert LLMProvider._get_base_url("openai/gpt-4o") is None
+ assert LLMProvider._get_base_url("anthropic/claude-4") is None
+
+ def test_local_custom_base_url(self, monkeypatch):
+ monkeypatch.setenv("LOCAL_API_BASE", "http://custom:9999/v1")
+ assert LLMProvider._get_base_url("local/model") == "http://custom:9999/v1"
+
+ def test_ollama_custom_base_url(self, monkeypatch):
+ monkeypatch.setenv("OLLAMA_API_BASE", "http://remote:11435/v1")
+ assert LLMProvider._get_base_url("ollama/model") == "http://remote:11435/v1"
+
+ def test_lmstudio_custom_base_url(self, monkeypatch):
+ monkeypatch.setenv("LMSTUDIO_API_BASE", "http://other:1235/v1")
+ assert LLMProvider._get_base_url("lmstudio/model") == "http://other:1235/v1"
+
+
+class TestAnthropicFormatDetection:
+ def test_native_anthropic(self):
+ assert LLMProvider._is_anthropic_model("anthropic/claude-sonnet-4") is True
+
+ def test_openrouter_is_not_anthropic(self):
+ assert LLMProvider._is_anthropic_model("openrouter/anthropic/claude-4") is False
+
+ def test_opencode_go_deepseek_is_anthropic_format(self):
+ assert LLMProvider._is_opencode_go_anthropic_format("opencode-go/deepseek-v4-pro") is True
+
+ def test_opencode_go_deepseek_flash_is_anthropic_format(self):
+ assert LLMProvider._is_opencode_go_anthropic_format("opencode-go/deepseek-v4-flash") is True
+
+ def test_opencode_go_minimax_is_anthropic_format(self):
+ assert LLMProvider._is_opencode_go_anthropic_format("opencode-go/minimax-m2.7") is True
+
+ def test_opencode_go_glm_is_not_anthropic_format(self):
+ assert LLMProvider._is_opencode_go_anthropic_format("opencode-go/glm-5.1") is False
+
+ def test_uses_anthropic_format_native(self):
+ assert LLMProvider._uses_anthropic_format("anthropic/claude-4") is True
+
+ def test_uses_anthropic_format_opencode_go_deepseek(self):
+ assert LLMProvider._uses_anthropic_format("opencode-go/deepseek-v4-pro") is True
+
+ def test_uses_anthropic_format_openai(self):
+ assert LLMProvider._uses_anthropic_format("openai/gpt-4o") is False
+
+
+class TestToolParamConversion:
+ def test_openai_tool_param_none(self):
+ assert LLMProvider._openai_tool_param(None) is None
+
+ def test_openai_tool_param_empty(self):
+ assert LLMProvider._openai_tool_param([]) is None
+
+ def test_openai_tool_param_raw_format(self):
+ tools = [
+ {"name": "search", "description": "Search web", "parameters": {"type": "object", "properties": {}}},
+ ]
+ result = LLMProvider._openai_tool_param(tools)
+ assert len(result) == 1
+ assert result[0]["type"] == "function"
+ assert result[0]["function"]["name"] == "search"
+
+ def test_openai_tool_param_already_formatted(self):
+ tools = [
+ {"type": "function", "function": {"name": "bash", "description": "Run cmd", "parameters": {}}},
+ ]
+ result = LLMProvider._openai_tool_param(tools)
+ assert len(result) == 1
+ assert result[0]["type"] == "function"
+
+ def test_anthropic_tool_param_none(self):
+ assert LLMProvider._anthropic_tool_param(None) is None
+
+ def test_anthropic_tool_param_conversion(self):
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "read",
+ "description": "Read file",
+ "parameters": {
+ "type": "object",
+ "properties": {"path": {"type": "string"}},
+ "required": ["path"],
+ },
+ },
+ },
+ ]
+ result = LLMProvider._anthropic_tool_param(tools)
+ assert len(result) == 1
+ assert result[0]["name"] == "read"
+ assert result[0]["input_schema"]["type"] == "object"
+ assert "path" in result[0]["input_schema"]["properties"]
+ assert result[0]["input_schema"]["required"] == ["path"]
+
+ def test_anthropic_tool_param_unwrapped(self):
+ tools = [
+ {"name": "bash", "description": "Run cmd", "parameters": {"type": "object", "properties": {}}},
+ ]
+ result = LLMProvider._anthropic_tool_param(tools)
+ assert len(result) == 1
+ assert result[0]["name"] == "bash"
+
+
+class TestToAnthropicMessages:
+ def test_separates_system_prompt(self):
+ messages = [
+ {"role": "system", "content": "You are an assistant."},
+ {"role": "user", "content": "Hello"},
+ ]
+ system, chat = LLMProvider._to_anthropic_messages(messages)
+ assert "You are an assistant" in system
+ assert len(chat) == 1
+ assert chat[0]["role"] == "user"
+ assert chat[0]["content"] == "Hello"
+
+ def test_converts_assistant_with_tool_calls(self):
+ messages = [
+ {"role": "user", "content": "Read file"},
+ {
+ "role": "assistant",
+ "content": "Let me read that.",
+ "tool_calls": [
+ {"id": "tc1", "name": "read_file", "arguments": {"path": "/tmp/test"}},
+ ],
+ },
+ ]
+ _system, chat = LLMProvider._to_anthropic_messages(messages)
+ assert len(chat) == 2
+
+ def test_converts_tool_result(self):
+ messages = [
+ {"role": "tool", "content": "file contents here", "tool_call_id": "tc1"},
+ ]
+ _system, chat = LLMProvider._to_anthropic_messages(messages)
+ assert len(chat) == 1
+ assert chat[0]["role"] == "user"
+
+
+class TestRetryLogic:
+ def test_is_retryable_rate_limit(self):
+ assert LLMProvider._is_retryable(Exception("429 Rate limit exceeded")) is True
+
+ def test_is_retryable_timeout(self):
+ assert LLMProvider._is_retryable(Exception("Connection timeout")) is True
+
+ def test_is_retryable_server_error(self):
+ assert LLMProvider._is_retryable(Exception("server_error 503")) is True
+
+ def test_is_not_retryable_auth_error(self):
+ assert LLMProvider._is_retryable(Exception("Invalid API key")) is False
+
+ def test_is_not_retryable_validation_error(self):
+ assert LLMProvider._is_retryable(Exception("Model not found")) is False
+
+ def test_is_retryable_overloaded(self):
+ assert LLMProvider._is_retryable(Exception("Server is overloaded")) is True
+
+ def test_is_retryable_capacity(self):
+ assert LLMProvider._is_retryable(Exception("Exceeded capacity")) is True
+
+ def test_is_retryable_502(self):
+ assert LLMProvider._is_retryable(Exception("502 Bad Gateway")) is True
diff --git a/backend/tests/test_models.py b/backend/tests/test_models.py
new file mode 100644
index 0000000..c1031ee
--- /dev/null
+++ b/backend/tests/test_models.py
@@ -0,0 +1,186 @@
+"""Tests for Pydantic API models — validation, defaults, serialization."""
+
+import pytest
+from pydantic import ValidationError
+
+from openmlr.models import (
+ UserRegister, UserLogin, TokenResponse, UserInfo,
+ ConversationCreate, ConversationResponse, MessageResponse, ConversationDetail,
+ MessageSend, ApprovalRequest,
+ SettingUpdate, ProviderConfig, ModelSwitch, AgentEvent,
+)
+
+
+class TestUserRegister:
+ def test_valid(self):
+ u = UserRegister(username="testuser", password="testpassword123")
+ assert u.username == "testuser"
+ assert u.password == "testpassword123"
+ assert u.display_name is None
+
+ def test_with_display_name(self):
+ u = UserRegister(username="tester", password="pass123", display_name="Test User")
+ assert u.display_name == "Test User"
+
+ def test_username_too_short(self):
+ with pytest.raises(ValidationError):
+ UserRegister(username="ab", password="password123")
+
+ def test_username_too_long(self):
+ with pytest.raises(ValidationError):
+ UserRegister(username="a" * 51, password="password123")
+
+ def test_password_too_short(self):
+ with pytest.raises(ValidationError):
+ UserRegister(username="testuser", password="12345")
+
+ def test_username_exactly_min(self):
+ u = UserRegister(username="abc", password="123456")
+ assert u.username == "abc"
+
+
+class TestUserLogin:
+ def test_valid(self):
+ u = UserLogin(username="testuser", password="secret")
+ assert u.username == "testuser"
+ assert u.password == "secret"
+
+
+class TestTokenResponse:
+ def test_defaults(self):
+ t = TokenResponse(access_token="abc123", user={"id": 1, "username": "test"})
+ assert t.access_token == "abc123"
+ assert t.token_type == "bearer"
+ assert t.user == {"id": 1, "username": "test"}
+
+
+class TestConversationCreate:
+ def test_defaults(self):
+ c = ConversationCreate()
+ assert c.title == "New conversation"
+ assert c.model is None
+ assert c.mode == "general"
+
+ def test_custom(self):
+ c = ConversationCreate(title="Research Q1", model="gpt-4o", mode="research")
+ assert c.title == "Research Q1"
+ assert c.model == "gpt-4o"
+ assert c.mode == "research"
+
+
+class TestConversationResponse:
+ def test_creation(self):
+ from datetime import datetime, timezone
+ now = datetime.now(timezone.utc)
+ c = ConversationResponse(
+ id=1, uuid="abc-def", title="Test Conv", model="gpt-4o",
+ mode="general", user_message_count=5,
+ created_at=now, updated_at=now,
+ )
+ assert c.id == 1
+ assert c.uuid == "abc-def"
+ assert c.user_message_count == 5
+
+
+class TestMessageResponse:
+ def test_creation(self):
+ from datetime import datetime, timezone
+ now = datetime.now(timezone.utc)
+ m = MessageResponse(id=1, role="user", content="Hello", metadata=None, created_at=now)
+ assert m.id == 1
+ assert m.role == "user"
+ assert m.content == "Hello"
+
+ def test_with_metadata(self):
+ from datetime import datetime, timezone
+ now = datetime.now(timezone.utc)
+ m = MessageResponse(id=2, role="assistant", content="Hi", metadata={"tool": "search"}, created_at=now)
+ assert m.metadata == {"tool": "search"}
+
+
+class TestConversationDetail:
+ def test_creation(self):
+ from datetime import datetime, timezone
+ now = datetime.now(timezone.utc)
+ conv = ConversationResponse(
+ id=1, uuid="x", title="C", model=None, mode="general",
+ user_message_count=0, created_at=now, updated_at=now,
+ )
+ msgs = [MessageResponse(id=1, role="user", content="Hi", metadata=None, created_at=now)]
+ cd = ConversationDetail(conversation=conv, messages=msgs)
+ assert len(cd.messages) == 1
+ assert cd.conversation.id == 1
+
+
+class TestMessageSend:
+ def test_basic(self):
+ m = MessageSend(message="Hello world")
+ assert m.message == "Hello world"
+ assert m.mode is None
+
+ def test_with_mode(self):
+ m = MessageSend(message="Research this", mode="research")
+ assert m.mode == "research"
+
+
+class TestApprovalRequest:
+ def test_valid(self):
+ a = ApprovalRequest(approvals={"call_1": True, "call_2": False})
+ assert a.approvals == {"call_1": True, "call_2": False}
+
+ def test_empty(self):
+ a = ApprovalRequest(approvals={})
+ assert a.approvals == {}
+
+
+class TestSettingUpdate:
+ def test_str_value(self):
+ s = SettingUpdate(value="hello")
+ assert s.value == "hello"
+
+ def test_int_value(self):
+ s = SettingUpdate(value=42)
+ assert s.value == 42
+
+ def test_bool_value(self):
+ s = SettingUpdate(value=True)
+ assert s.value is True
+
+ def test_dict_value(self):
+ s = SettingUpdate(value={"key": "val"})
+ assert s.value == {"key": "val"}
+
+
+class TestProviderConfig:
+ def test_empty(self):
+ p = ProviderConfig()
+ assert p.openai_api_key is None
+ assert p.anthropic_api_key is None
+
+ def test_with_keys(self):
+ p = ProviderConfig(openai_api_key="sk-123", brave_api_key="bsk-456")
+ assert p.openai_api_key == "sk-123"
+ assert p.brave_api_key == "bsk-456"
+
+ def test_modal_config(self):
+ p = ProviderConfig(modal_token_id="tid", modal_token_secret="tsec")
+ assert p.modal_token_id == "tid"
+ assert p.modal_token_secret == "tsec"
+
+
+class TestModelSwitch:
+ def test_valid(self):
+ m = ModelSwitch(model="gpt-4o")
+ assert m.model == "gpt-4o"
+
+
+class TestAgentEventPydantic:
+ def test_valid(self):
+ e = AgentEvent(event_type="status", data={"key": "val"})
+ assert e.event_type == "status"
+ assert e.data == {"key": "val"}
+
+ def test_no_data(self):
+ e = AgentEvent(event_type="ping")
+ assert e.event_type == "ping"
+ assert e.data is None
diff --git a/backend/tests/test_models_orm.py b/backend/tests/test_models_orm.py
new file mode 100644
index 0000000..bd2e5ce
--- /dev/null
+++ b/backend/tests/test_models_orm.py
@@ -0,0 +1,299 @@
+"""Tests for SQLAlchemy ORM models — creation, relationships, constraints."""
+
+import pytest
+import pytest_asyncio
+from sqlalchemy.ext.asyncio import AsyncSession
+
+pytestmark = pytest.mark.asyncio
+from sqlalchemy import select
+
+from openmlr.db.models import (
+ User, Conversation, Message, ResearchCorpus,
+ WritingProject, SandboxConfig, ConversationTask,
+ ConversationResource, AgentJob, UserSetting,
+)
+from openmlr.auth.security import hash_password
+
+
+class TestUserModel:
+ async def test_create_user(self, db_session: AsyncSession):
+ user = User(
+ username="newuser",
+ password_hash=hash_password("password123"),
+ display_name="New User",
+ is_active=True,
+ )
+ db_session.add(user)
+ await db_session.commit()
+ assert user.id is not None
+ assert user.username == "newuser"
+ assert user.is_active is True
+ assert user.created_at is not None
+
+ async def test_user_unique_username(self, db_session: AsyncSession, test_user):
+ dup = User(
+ username="testuser",
+ password_hash=hash_password("another"),
+ is_active=True,
+ )
+ db_session.add(dup)
+ with pytest.raises(Exception):
+ await db_session.commit()
+
+ async def test_user_relationships_exist(self, db_session: AsyncSession, test_user):
+ # Async SQLAlchemy requires relationship attributes to exist on the class
+ assert hasattr(User, "conversations")
+ assert hasattr(User, "settings")
+
+
+class TestConversationModel:
+ async def test_create_conversation(self, db_session: AsyncSession, test_user):
+ conv = Conversation(
+ user_id=test_user.id,
+ title="Test Conversation",
+ model="gpt-4o",
+ mode="research",
+ )
+ db_session.add(conv)
+ await db_session.commit()
+ assert conv.id is not None
+ assert conv.uuid is not None
+ assert len(conv.uuid) > 0
+ assert conv.user_message_count == 0
+ assert conv.created_at is not None
+ assert conv.updated_at is not None
+
+ async def test_conversation_user_relationship(self, db_session: AsyncSession, test_user):
+ conv = Conversation(user_id=test_user.id, title="Rel Test")
+ db_session.add(conv)
+ await db_session.commit()
+ await db_session.refresh(conv)
+ assert conv.user_id == test_user.id
+
+
+class TestMessageModel:
+ @pytest_asyncio.fixture(autouse=True)
+ async def _conv(self, db_session: AsyncSession, test_user):
+ conv = Conversation(user_id=test_user.id, title="Msg Model Test")
+ db_session.add(conv)
+ await db_session.commit()
+ self.conv = conv
+
+ async def test_create_message(self, db_session: AsyncSession):
+ msg = Message(
+ conversation_id=self.conv.id,
+ role="user",
+ content="Hello world",
+ )
+ db_session.add(msg)
+ await db_session.commit()
+ assert msg.id is not None
+ assert msg.role == "user"
+ assert msg.content == "Hello world"
+
+ async def test_message_with_metadata(self, db_session: AsyncSession):
+ msg = Message(
+ conversation_id=self.conv.id,
+ role="assistant",
+ content="Done",
+ meta={"tool": "search", "usage": {"tokens": 100}},
+ )
+ db_session.add(msg)
+ await db_session.commit()
+ assert msg.meta == {"tool": "search", "usage": {"tokens": 100}}
+
+ async def test_message_has_created_at(self, db_session: AsyncSession):
+ msg = Message(conversation_id=self.conv.id, role="system", content="Init")
+ db_session.add(msg)
+ await db_session.commit()
+ assert msg.created_at is not None
+
+
+class TestConversationTask:
+ @pytest_asyncio.fixture(autouse=True)
+ async def _conv(self, db_session: AsyncSession, test_user):
+ conv = Conversation(user_id=test_user.id, title="Task Model Test")
+ db_session.add(conv)
+ await db_session.commit()
+ self.conv = conv
+
+ async def test_create_task(self, db_session: AsyncSession):
+ task = ConversationTask(
+ conversation_id=self.conv.id,
+ title="Test task",
+ status="pending",
+ order_index=0,
+ )
+ db_session.add(task)
+ await db_session.commit()
+ assert task.id is not None
+ assert task.title == "Test task"
+ assert task.status == "pending"
+ assert task.order_index == 0
+
+ async def test_task_with_priority(self, db_session: AsyncSession):
+ task = ConversationTask(
+ conversation_id=self.conv.id,
+ title="Urgent",
+ status="pending",
+ priority="high",
+ order_index=0,
+ )
+ db_session.add(task)
+ await db_session.commit()
+ assert task.priority == "high"
+
+
+class TestConversationResource:
+ @pytest_asyncio.fixture(autouse=True)
+ async def _conv(self, db_session: AsyncSession, test_user):
+ conv = Conversation(user_id=test_user.id, title="Resource Model Test")
+ db_session.add(conv)
+ await db_session.commit()
+ self.conv = conv
+
+ async def test_create_resource(self, db_session: AsyncSession):
+ res = ConversationResource(
+ conversation_id=self.conv.id,
+ resource_id="res-001",
+ title="Paper 1",
+ type="paper",
+ url="https://arxiv.org/abs/1234.5678",
+ )
+ db_session.add(res)
+ await db_session.commit()
+ assert res.id is not None
+ assert res.type == "paper"
+ assert res.url == "https://arxiv.org/abs/1234.5678"
+
+ async def test_resource_with_content(self, db_session: AsyncSession):
+ res = ConversationResource(
+ conversation_id=self.conv.id,
+ resource_id="rpt-001",
+ title="Report",
+ type="report",
+ content="## Findings\nKey results...",
+ )
+ db_session.add(res)
+ await db_session.commit()
+ assert res.content is not None
+
+
+class TestAgentJob:
+ @pytest_asyncio.fixture(autouse=True)
+ async def _conv(self, db_session: AsyncSession, test_user):
+ conv = Conversation(user_id=test_user.id, title="Job Model Test")
+ db_session.add(conv)
+ await db_session.commit()
+ self.conv = conv
+
+ async def test_create_job(self, db_session: AsyncSession):
+ job = AgentJob(
+ job_id="job-123",
+ conversation_id=self.conv.id,
+ user_id=self.conv.user_id,
+ message="Process this",
+ status="queued",
+ )
+ db_session.add(job)
+ await db_session.commit()
+ assert job.id is not None
+ assert job.job_id == "job-123"
+ assert job.status == "queued"
+ assert job.created_at is not None
+
+ async def test_job_optional_fields(self, db_session: AsyncSession):
+ job = AgentJob(
+ job_id="job-456",
+ conversation_id=self.conv.id,
+ user_id=self.conv.user_id,
+ message="Test",
+ status="queued",
+ mode="research",
+ )
+ db_session.add(job)
+ await db_session.commit()
+ assert job.mode == "research"
+
+
+class TestUserSetting:
+ async def test_create_setting(self, db_session: AsyncSession, test_user):
+ setting = UserSetting(
+ user_id=test_user.id,
+ category="agent",
+ key="default_model",
+ value="gpt-4o",
+ )
+ db_session.add(setting)
+ await db_session.commit()
+ assert setting.id is not None
+ assert setting.category == "agent"
+ assert setting.key == "default_model"
+ assert setting.value == "gpt-4o"
+
+ async def test_setting_uniqueness(self, db_session: AsyncSession, test_user):
+ s1 = UserSetting(user_id=test_user.id, category="agent", key="model", value="gpt-4o")
+ db_session.add(s1)
+ await db_session.commit()
+
+ # Verify we can query the setting back
+ from sqlalchemy import select
+ result = await db_session.execute(
+ select(UserSetting).where(
+ UserSetting.user_id == test_user.id,
+ UserSetting.category == "agent",
+ UserSetting.key == "model",
+ )
+ )
+ settings = result.scalars().all()
+ assert len(settings) == 1
+ assert settings[0].value == "gpt-4o"
+
+
+class TestResearchCorpus:
+ async def test_create_corpus(self, db_session: AsyncSession, test_user):
+ corpus = ResearchCorpus(
+ user_id=test_user.id,
+ paper_id="arxiv-123",
+ title="Test Paper",
+ abstract="This is a test abstract",
+ source="arxiv",
+ tags=["ml", "nlp"],
+ )
+ db_session.add(corpus)
+ await db_session.commit()
+ assert corpus.id is not None
+ assert corpus.paper_id == "arxiv-123"
+ assert corpus.source == "arxiv"
+
+
+class TestWritingProject:
+ async def test_create_project(self, db_session: AsyncSession, test_user):
+ project = WritingProject(
+ user_id=test_user.id,
+ title="My Paper",
+ status="drafting",
+ sections={"abstract": ""},
+ )
+ db_session.add(project)
+ await db_session.commit()
+ assert project.id is not None
+ assert project.title == "My Paper"
+ assert project.status == "drafting"
+ assert "abstract" in project.sections
+
+
+class TestSandboxConfig:
+ async def test_create_config(self, db_session: AsyncSession, test_user):
+ config = SandboxConfig(
+ user_id=test_user.id,
+ name="My Modal Sandbox",
+ type="modal",
+ config={"gpu": "a100"},
+ is_default=True,
+ )
+ db_session.add(config)
+ await db_session.commit()
+ assert config.id is not None
+ assert config.type == "modal"
+ assert config.is_default is True
diff --git a/backend/tests/test_prompts.py b/backend/tests/test_prompts.py
new file mode 100644
index 0000000..6ccd6f1
--- /dev/null
+++ b/backend/tests/test_prompts.py
@@ -0,0 +1,81 @@
+"""Tests for system prompt builder."""
+
+import pytest
+
+from openmlr.agent.prompts import build_system_prompt, COMPACT_PROMPT
+from openmlr.agent.types import ToolSpec
+
+
+class TestBuildSystemPrompt:
+ def test_renders_with_tools(self):
+ tools = [
+ ToolSpec(name="read_file", description="Read a file", parameters={"type": "object"}),
+ ToolSpec(name="write_file", description="Write a file", parameters={"type": "object"}),
+ ]
+ prompt = build_system_prompt(tool_specs=tools, mode="general", username="tester")
+ assert isinstance(prompt, str)
+ assert len(prompt) > 0
+ assert "read_file" in prompt or "read_file" in prompt
+
+ def test_renders_with_username(self):
+ tools = [ToolSpec(name="test_tool", description="Test", parameters={"type": "object"})]
+ prompt = build_system_prompt(tool_specs=tools, username="alice")
+ assert isinstance(prompt, str)
+ assert len(prompt) > 0
+
+ def test_renders_with_sandbox_info(self):
+ tools = [ToolSpec(name="bash", description="Run commands", parameters={"type": "object"})]
+ prompt = build_system_prompt(tool_specs=tools, sandbox_info="local")
+ assert isinstance(prompt, str)
+
+ def test_renders_with_mode_plan(self):
+ tools = [ToolSpec(name="ask_user", description="Ask questions", parameters={"type": "object"})]
+ prompt = build_system_prompt(tool_specs=tools, mode="plan")
+ assert isinstance(prompt, str)
+
+ def test_renders_with_mode_research(self):
+ tools = [ToolSpec(name="papers", description="Search papers", parameters={"type": "object"})]
+ prompt = build_system_prompt(tool_specs=tools, mode="research")
+ assert isinstance(prompt, str)
+
+ def test_renders_with_config(self):
+ from openmlr.config import AgentConfig
+ config = AgentConfig(model_name="test/model", max_iterations=10)
+ tools = [ToolSpec(name="test", description="Test tool", parameters={"type": "object"})]
+ prompt = build_system_prompt(tool_specs=tools, config=config)
+ assert isinstance(prompt, str)
+
+ def test_multiple_tools_appear_in_prompt(self):
+ tools = [
+ ToolSpec(name="tool_a", description="First tool", parameters={"type": "object"}),
+ ToolSpec(name="tool_b", description="Second tool", parameters={"type": "object"}),
+ ToolSpec(name="tool_c", description="Third tool", parameters={"type": "object"}),
+ ]
+ prompt = build_system_prompt(tool_specs=tools)
+ assert isinstance(prompt, str)
+
+ def test_empty_tools(self):
+ prompt = build_system_prompt(tool_specs=[])
+ assert isinstance(prompt, str)
+ assert len(prompt) > 0
+
+ def test_contains_date_and_time(self):
+ tools = [ToolSpec(name="t", description="d", parameters={"type": "object"})]
+ prompt = build_system_prompt(tool_specs=tools)
+ # Contains date in YYYY-MM-DD format
+ import re
+ assert re.search(r"\d{4}-\d{2}-\d{2}", prompt)
+
+ def test_contains_sandbox_info_in_prompt(self):
+ tools = [ToolSpec(name="bash", description="Run", parameters={"type": "object"})]
+ prompt = build_system_prompt(tool_specs=tools, sandbox_info="SSH remote (4 GPUs)")
+ assert isinstance(prompt, str)
+
+
+class TestCompactPrompt:
+ def test_is_string(self):
+ assert isinstance(COMPACT_PROMPT, str)
+ assert len(COMPACT_PROMPT) > 0
+
+ def test_mentions_summary(self):
+ assert "summary" in COMPACT_PROMPT.lower()
diff --git a/backend/tests/test_redis_pubsub.py b/backend/tests/test_redis_pubsub.py
new file mode 100644
index 0000000..80fd6d9
--- /dev/null
+++ b/backend/tests/test_redis_pubsub.py
@@ -0,0 +1,184 @@
+"""Tests for Redis pub/sub — event publishing, answer relay, interrupt signaling."""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+
+@pytest.mark.asyncio
+class TestPublishEvent:
+ async def test_publishes_json_to_redis(self):
+ from openmlr.agent.types import AgentEvent
+ from openmlr.services.redis_pubsub import publish_event
+
+ mock_redis = AsyncMock()
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ event = AgentEvent(event_type="status", data={"status": "ready"})
+ await publish_event(event)
+
+ mock_redis.publish.assert_called_once()
+ call_args = mock_redis.publish.call_args[0]
+ assert call_args[0] == "openmlr:events"
+ assert "status" in call_args[1]
+
+ async def test_handles_redis_error(self):
+ from openmlr.agent.types import AgentEvent
+ from openmlr.services.redis_pubsub import publish_event
+
+ mock_redis = AsyncMock()
+ mock_redis.publish.side_effect = Exception("Redis down")
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ event = AgentEvent(event_type="status")
+ await publish_event(event) # should not raise
+
+
+@pytest.mark.asyncio
+class TestPublishAnswers:
+ async def test_sets_answers_key_in_redis(self):
+ from openmlr.services.redis_pubsub import publish_answers
+
+ mock_redis = AsyncMock()
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ await publish_answers(conversation_id=42, answers={"q1": "Option A"})
+
+ assert mock_redis.set.called
+ key = mock_redis.set.call_args[0][0]
+ assert "openmlr:answers:" in key
+ assert "42" in key
+
+ async def test_handles_redis_error(self):
+ from openmlr.services.redis_pubsub import publish_answers
+
+ mock_redis = AsyncMock()
+ mock_redis.set.side_effect = Exception("Redis down")
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ await publish_answers(conversation_id=1, answers={"q": "a"}) # should not raise
+
+
+@pytest.mark.asyncio
+class TestPublishInterrupt:
+ async def test_sets_interrupt_key(self):
+ from openmlr.services.redis_pubsub import publish_interrupt
+
+ mock_redis = AsyncMock()
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ await publish_interrupt(conversation_id=99)
+
+ assert mock_redis.set.called
+ key = mock_redis.set.call_args[0][0]
+ assert "openmlr:interrupt:" in key
+ assert "99" in key
+
+ async def test_uses_60s_expiry(self):
+ from openmlr.services.redis_pubsub import publish_interrupt
+
+ mock_redis = AsyncMock()
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ await publish_interrupt(conversation_id=1)
+
+ call_kwargs = mock_redis.set.call_args[1]
+ assert call_kwargs.get("ex") == 60
+
+
+@pytest.mark.asyncio
+class TestCheckInterrupt:
+ async def test_returns_true_when_key_exists(self):
+ from openmlr.services.redis_pubsub import check_interrupt
+
+ mock_redis = AsyncMock()
+ mock_redis.exists.return_value = 1
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ result = await check_interrupt(conversation_id=5)
+
+ assert result is True
+
+ async def test_returns_false_when_not_found(self):
+ from openmlr.services.redis_pubsub import check_interrupt
+
+ mock_redis = AsyncMock()
+ mock_redis.exists.return_value = 0
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ result = await check_interrupt(conversation_id=5)
+
+ assert result is False
+
+ async def test_returns_false_on_redis_error(self):
+ from openmlr.services.redis_pubsub import check_interrupt
+
+ mock_redis = AsyncMock()
+ mock_redis.exists.side_effect = Exception("Redis down")
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ result = await check_interrupt(conversation_id=5)
+
+ assert result is False
+
+
+@pytest.mark.asyncio
+class TestClearInterrupt:
+ async def test_deletes_key(self):
+ from openmlr.services.redis_pubsub import clear_interrupt
+
+ mock_redis = AsyncMock()
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ await clear_interrupt(conversation_id=5)
+
+ assert mock_redis.delete.called
+
+
+@pytest.mark.asyncio
+class TestWaitForAnswers:
+ async def test_returns_answers_when_set(self):
+ from openmlr.services.redis_pubsub import wait_for_answers
+
+ mock_redis = AsyncMock()
+ mock_redis.get.return_value = '{"q1": "Option A"}'
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ result = await wait_for_answers(conversation_id=1, timeout=0.5)
+
+ assert result == {"q1": "Option A"}
+ mock_redis.delete.assert_called_once()
+
+ async def test_returns_none_on_timeout(self):
+ import time
+ from openmlr.services.redis_pubsub import wait_for_answers
+
+ mock_redis = AsyncMock()
+ mock_redis.get.return_value = None # never gets set
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ with patch("openmlr.services.redis_pubsub.asyncio.sleep", return_value=None):
+ result = await wait_for_answers(conversation_id=1, timeout=0.1)
+
+ assert result is None
+
+ async def test_returns_none_on_redis_error(self):
+ from openmlr.services.redis_pubsub import wait_for_answers
+
+ mock_redis = AsyncMock()
+ mock_redis.get.side_effect = Exception("Redis down")
+ with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis):
+ result = await wait_for_answers(conversation_id=1, timeout=0.1)
+
+ assert result is None
+
+
+class TestModuleConstants:
+ def test_channel_name(self):
+ from openmlr.services.redis_pubsub import CHANNEL_NAME
+ assert CHANNEL_NAME == "openmlr:events"
+
+ def test_answers_key_prefix(self):
+ from openmlr.services.redis_pubsub import ANSWERS_KEY_PREFIX
+ assert ANSWERS_KEY_PREFIX == "openmlr:answers:"
+
+ def test_interrupt_key_prefix(self):
+ from openmlr.services.redis_pubsub import INTERRUPT_KEY_PREFIX
+ assert INTERRUPT_KEY_PREFIX == "openmlr:interrupt:"
+
+ def test_redis_url_from_env(self, monkeypatch):
+ monkeypatch.setenv("REDIS_URL", "redis://custom:6379/1")
+ from importlib import reload
+ import openmlr.services.redis_pubsub
+ reload(openmlr.services.redis_pubsub)
+ assert openmlr.services.redis_pubsub.REDIS_URL == "redis://custom:6379/1"
+ # Restore
+ monkeypatch.delenv("REDIS_URL", raising=False)
+ reload(openmlr.services.redis_pubsub)
diff --git a/backend/tests/test_routes_health.py b/backend/tests/test_routes_health.py
new file mode 100644
index 0000000..6af06e6
--- /dev/null
+++ b/backend/tests/test_routes_health.py
@@ -0,0 +1,42 @@
+"""Tests for the health check endpoints."""
+
+import pytest
+from httpx import AsyncClient
+
+pytestmark = pytest.mark.asyncio
+
+
+class TestHealthEndpoints:
+ async def test_api_health_returns_ok(self, client: AsyncClient):
+ resp = await client.get("/api/health")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["status"] == "ok"
+ assert "version" in data
+ assert "timestamp" in data
+ assert isinstance(data["version"], str)
+
+ async def test_health_returns_ok(self, client: AsyncClient):
+ resp = await client.get("/health")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["status"] == "ok"
+ assert "version" in data
+ assert "timestamp" in data
+
+ async def test_both_endpoints_return_same_structure(self, client: AsyncClient):
+ resp1 = await client.get("/api/health")
+ resp2 = await client.get("/health")
+ assert resp1.json()["status"] == resp2.json()["status"]
+ assert resp1.json()["version"] == resp2.json()["version"]
+
+ async def test_health_not_require_auth(self, client: AsyncClient):
+ resp = await client.get("/api/health")
+ assert resp.status_code == 200
+
+ async def test_health_timestamp_is_iso_format(self, client: AsyncClient):
+ resp = await client.get("/health")
+ data = resp.json()
+ import re
+ iso_pattern = r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}"
+ assert re.search(iso_pattern, data["timestamp"])
diff --git a/backend/tests/test_routes_settings.py b/backend/tests/test_routes_settings.py
new file mode 100644
index 0000000..91db400
--- /dev/null
+++ b/backend/tests/test_routes_settings.py
@@ -0,0 +1,146 @@
+"""Tests for settings API routes."""
+
+import os
+import pytest
+from httpx import AsyncClient
+
+pytestmark = pytest.mark.asyncio
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from openmlr.db import operations as ops
+
+
+class TestAgentSettings:
+ async def test_get_all_settings_empty(self, auth_client: AsyncClient):
+ resp = await auth_client.get("/api/settings")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "settings" in data
+
+ async def test_get_all_settings_after_set(self, auth_client: AsyncClient, db_session: AsyncSession, test_user):
+ await ops.set_user_setting(db_session, test_user.id, "agent", "default_model", "claude")
+ resp = await auth_client.get("/api/settings")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "agent" in data["settings"]
+ assert data["settings"]["agent"]["default_model"] == "claude"
+
+ async def test_get_settings_category(self, auth_client: AsyncClient, db_session: AsyncSession, test_user):
+ await ops.set_user_setting(db_session, test_user.id, "agent", "yolo_mode", True)
+ resp = await auth_client.get("/api/settings/agent")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "settings" in data
+ assert data["settings"]["yolo_mode"] is True
+
+ async def test_update_setting(self, auth_client: AsyncClient, db_session: AsyncSession, test_user):
+ resp = await auth_client.put(
+ "/api/settings/agent/test_key",
+ json={"value": "test_value"},
+ )
+ assert resp.status_code == 200
+ assert resp.json() == {"ok": True}
+
+ async def test_update_setting_missing_value(self, auth_client: AsyncClient):
+ resp = await auth_client.put(
+ "/api/settings/agent/test_key",
+ json={},
+ )
+ assert resp.status_code == 400
+
+ async def test_update_setting_with_dict(self, auth_client: AsyncClient):
+ resp = await auth_client.put(
+ "/api/settings/mcp/config",
+ json={"value": {"server": "test", "port": 1234}},
+ )
+ assert resp.status_code == 200
+
+ async def test_delete_setting(self, auth_client: AsyncClient, db_session: AsyncSession, test_user):
+ await ops.set_user_setting(db_session, test_user.id, "agent", "remove_me", "yes")
+ resp = await auth_client.delete("/api/settings/agent/remove_me")
+ assert resp.status_code == 200
+ assert resp.json() == {"ok": True}
+
+ async def test_update_provider_key_sets_env(self, auth_client: AsyncClient, db_session: AsyncSession, test_user):
+ resp = await auth_client.put(
+ "/api/settings/providers/openai_api_key",
+ json={"value": "sk-test-key"},
+ )
+ assert resp.status_code == 200
+ assert os.environ["OPENAI_API_KEY"] == "sk-test-key"
+
+ async def test_settings_require_auth(self, client: AsyncClient):
+ resp = await client.get("/api/settings")
+ assert resp.status_code == 401
+
+
+class TestProviders:
+ async def test_list_providers(self, auth_client: AsyncClient):
+ resp = await auth_client.get("/api/providers")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "providers" in data
+ providers = data["providers"]
+ assert isinstance(providers, list)
+ assert len(providers) > 0
+ provider_ids = [p["id"] for p in providers]
+ assert "openai" in provider_ids
+ assert "anthropic" in provider_ids
+ assert "openrouter" in provider_ids
+
+ async def test_provider_has_required_fields(self, auth_client: AsyncClient):
+ resp = await auth_client.get("/api/providers")
+ data = resp.json()
+ for p in data["providers"]:
+ assert "id" in p
+ assert "name" in p
+ assert "key_env" in p
+ assert "configured" in p
+
+
+class TestAppStatus:
+ async def test_get_status(self, auth_client: AsyncClient):
+ resp = await auth_client.get("/api/status")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "model" in data
+ assert "research_model" in data
+ assert "yolo_mode" in data
+ assert "needs_onboarding" in data
+
+
+class TestModels:
+ async def test_list_models(self, auth_client: AsyncClient):
+ resp = await auth_client.get("/api/models")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "models" in data
+ models = data["models"]
+ assert isinstance(models, list)
+ assert len(models) > 0
+
+ async def test_models_have_required_fields(self, auth_client: AsyncClient):
+ resp = await auth_client.get("/api/models")
+ models = resp.json()["models"]
+ for m in models:
+ assert "id" in m
+ assert "name" in m
+ assert "provider" in m
+
+
+class TestConfigEndpoint:
+ async def test_save_config(self, auth_client: AsyncClient):
+ resp = await auth_client.post(
+ "/api/config",
+ json={"OPENAI_API_KEY": "sk-from-config"},
+ )
+ assert resp.status_code == 200
+ assert resp.json() == {"ok": True}
+
+ async def test_save_config_ignores_non_whitelisted(self, auth_client: AsyncClient):
+ resp = await auth_client.post(
+ "/api/config",
+ json={"RANDOM_KEY": "should-be-ignored"},
+ )
+ assert resp.status_code == 200
+ assert "RANDOM_KEY" not in os.environ
diff --git a/backend/tests/test_sandbox_manager.py b/backend/tests/test_sandbox_manager.py
new file mode 100644
index 0000000..ca0a751
--- /dev/null
+++ b/backend/tests/test_sandbox_manager.py
@@ -0,0 +1,54 @@
+"""Tests for SandboxManager — lifecycle management and provider selection."""
+
+import pytest
+from openmlr.sandbox.manager import SandboxManager
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture
+def manager():
+ return SandboxManager()
+
+
+class TestSandboxManager:
+ async def test_initial_state(self, manager):
+ assert manager.get_active() is None
+ assert manager.active_type == "none"
+
+ async def test_create_local(self, manager):
+ sandbox = await manager.create("local")
+ assert sandbox is not None
+ assert manager.active_type == "local"
+ assert manager.get_active() is sandbox
+
+ async def test_create_then_destroy(self, manager):
+ sandbox = await manager.create("local")
+ await manager.destroy()
+ assert manager.get_active() is None
+ assert manager.active_type == "none"
+
+ async def test_create_replaces_existing(self, manager):
+ local1 = await manager.create("local")
+ local2 = await manager.create("local")
+ assert manager.get_active() is local2
+
+ async def test_ensure_local_creates_if_none(self, manager):
+ sandbox = await manager.ensure_local()
+ assert sandbox is not None
+ assert manager.active_type == "local"
+
+ async def test_ensure_local_returns_existing(self, manager):
+ sandbox1 = await manager.create("local")
+ sandbox2 = await manager.ensure_local()
+ assert sandbox1 is sandbox2
+
+ async def test_create_unknown_provider_raises(self, manager):
+ with pytest.raises(ValueError, match="Unknown sandbox provider"):
+ await manager.create("invalid_provider")
+
+ async def test_destroy_when_none_different_provider(self, manager):
+ await manager.create("local")
+ # Simulate a provider mismatch — destroy should still work
+ await manager.destroy()
+ assert manager.active_type == "none"
diff --git a/backend/tests/test_sandbox_types.py b/backend/tests/test_sandbox_types.py
new file mode 100644
index 0000000..d4cbd5f
--- /dev/null
+++ b/backend/tests/test_sandbox_types.py
@@ -0,0 +1,139 @@
+"""Tests for sandbox interface types and LocalSandbox."""
+
+import tempfile
+from pathlib import Path
+import pytest
+
+from openmlr.sandbox.interface import EnvironmentInfo, ExecutionResult, SandboxInterface
+from openmlr.sandbox.local import LocalSandbox
+from openmlr.sandbox.manager import SandboxManager
+
+
+class TestEnvironmentInfo:
+ def test_defaults(self):
+ info = EnvironmentInfo()
+ assert info.os == "unknown"
+ assert info.python_version == "unknown"
+ assert info.gpu_available is False
+ assert info.gpu_info is None
+ assert info.installed_packages == []
+ assert info.available_disk_gb == 0.0
+ assert info.available_ram_gb == 0.0
+
+ def test_custom_values(self):
+ info = EnvironmentInfo(
+ os="Linux",
+ python_version="3.12.0",
+ gpu_available=True,
+ gpu_info="NVIDIA A100",
+ installed_packages=["torch", "numpy"],
+ available_disk_gb=50.0,
+ available_ram_gb=32.0,
+ )
+ assert info.os == "Linux"
+ assert info.gpu_available is True
+ assert "torch" in info.installed_packages
+
+
+class TestExecutionResult:
+ def test_defaults(self):
+ r = ExecutionResult(output="done", success=True)
+ assert r.output == "done"
+ assert r.success is True
+ assert r.exit_code == 0
+ assert r.duration_seconds == 0.0
+
+ def test_failure(self):
+ r = ExecutionResult(output="error", success=False, exit_code=1, duration_seconds=2.5)
+ assert r.success is False
+ assert r.exit_code == 1
+ assert r.duration_seconds == 2.5
+
+ def test_truncation_handled_by_caller(self):
+ # Truncation is done by the tools, not the dataclass
+ r = ExecutionResult(output="x" * 100000, success=True)
+ assert len(r.output) == 100000
+
+
+class TestSandboxInterface:
+ def test_is_abstract(self):
+ with pytest.raises(TypeError):
+ SandboxInterface()
+
+
+@pytest.mark.asyncio
+class TestLocalSandbox:
+ async def test_create_default(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ sb = LocalSandbox()
+ await sb.create({})
+ assert sb.workdir == str(tmp_path)
+
+ async def test_create_with_workdir(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ sb = LocalSandbox()
+ await sb.create({"workdir": str(tmp_path)})
+ assert sb.workdir == str(tmp_path)
+
+ async def test_write_and_read_file(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ sb = LocalSandbox(str(tmp_path))
+ await sb.create({})
+ f = tmp_path / "test.txt"
+ ok = await sb.write_file(str(f), "hello")
+ assert ok is True
+ content = await sb.read_file("test.txt")
+ assert content == "hello"
+
+ async def test_file_exists(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ sb = LocalSandbox(str(tmp_path))
+ await sb.create({})
+ f = tmp_path / "exists.txt"
+ f.write_text("data")
+ assert await sb.file_exists("exists.txt") is True
+ assert await sb.file_exists("nope.txt") is False
+
+ async def test_edit_file(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ sb = LocalSandbox(str(tmp_path))
+ await sb.create({})
+ f = tmp_path / "edit.txt"
+ f.write_text("old text here")
+ ok = await sb.edit_file("edit.txt", "old", "new")
+ assert ok is True
+ assert f.read_text() == "new text here"
+
+ async def test_edit_nonexistent(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ sb = LocalSandbox(str(tmp_path))
+ await sb.create({})
+ # edit_file tries to read the file first, which raises FileNotFoundError
+ with pytest.raises(FileNotFoundError):
+ await sb.edit_file("nope.txt", "a", "b")
+
+ async def test_list_files(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ sb = LocalSandbox(str(tmp_path))
+ await sb.create({})
+ (tmp_path / "a.txt").write_text("a")
+ (tmp_path / "b.txt").write_text("b")
+ (tmp_path / "subdir").mkdir()
+ files = await sb.list_files(".")
+ assert "a.txt" in files
+ assert "b.txt" in files
+ assert "subdir/" in files
+
+ async def test_destroy_is_noop(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ sb = LocalSandbox()
+ await sb.create({})
+ await sb.destroy()
+
+ async def test_execute_simple_command(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ sb = LocalSandbox(str(tmp_path))
+ await sb.create({})
+ result = await sb.execute("echo hello", timeout=10)
+ assert result.success is True
+ assert "hello" in result.output
diff --git a/backend/tests/test_session_manager.py b/backend/tests/test_session_manager.py
new file mode 100644
index 0000000..6990d70
--- /dev/null
+++ b/backend/tests/test_session_manager.py
@@ -0,0 +1,99 @@
+"""Tests for SessionManager — multi-session lifecycle and message queuing."""
+
+import pytest
+from openmlr.services.session_manager import SessionManager, ActiveSession
+
+pytestmark = pytest.mark.asyncio
+from openmlr.services.event_bus import EventBus
+from openmlr.config import AgentConfig
+
+
+@pytest.fixture
+def config():
+ return AgentConfig(
+ model_name="test/model",
+ max_iterations=10,
+ stream=False,
+ yolo_mode=False,
+ )
+
+
+@pytest.fixture
+def event_bus():
+ return EventBus()
+
+
+@pytest.fixture
+def session_manager(event_bus, config):
+ return SessionManager(event_bus=event_bus, default_config=config)
+
+
+class TestSessionManager:
+ async def test_initial_state(self, session_manager):
+ assert session_manager.current_conversation_id is None
+ assert session_manager.is_processing is False
+ assert session_manager.get_current_session() is None
+
+ async def test_get_session_nonexistent(self, session_manager):
+ assert session_manager.get_session(999) is None
+
+ async def test_get_or_create_session(self, session_manager):
+ active = await session_manager.get_or_create_session(
+ conversation_id=1,
+ uuid="test-uuid-1",
+ model="test/model",
+ mode="general",
+ username="tester",
+ )
+ assert active is not None
+ assert active.conversation_id == 1
+ assert active.uuid == "test-uuid-1"
+ assert active.session is not None
+ assert active.tool_router is not None
+
+ async def test_get_or_create_returns_existing(self, session_manager):
+ s1 = await session_manager.get_or_create_session(1, "u1")
+ s2 = await session_manager.get_or_create_session(1, "u1")
+ assert s1 is s2
+
+ async def test_get_session_after_create(self, session_manager):
+ await session_manager.get_or_create_session(1, "u1")
+ s = session_manager.get_session(1)
+ assert s is not None
+ assert s.conversation_id == 1
+
+ async def test_remove_session(self, session_manager):
+ await session_manager.get_or_create_session(1, "u1")
+ await session_manager.remove_session(1)
+ assert session_manager.get_session(1) is None
+
+ async def test_remove_nonexistent_session(self, session_manager):
+ await session_manager.remove_session(999)
+
+ async def test_multiple_sessions(self, session_manager):
+ s1 = await session_manager.get_or_create_session(1, "u1")
+ s2 = await session_manager.get_or_create_session(2, "u2")
+ assert s1.conversation_id == 1
+ assert s2.conversation_id == 2
+ assert s1 is not s2
+
+ async def test_session_loads_existing_messages(self, session_manager):
+ messages = [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi there!"},
+ ]
+ active = await session_manager.get_or_create_session(
+ 1, "u1", existing_messages=messages,
+ )
+ msgs = active.session.context_manager.get_messages()
+ assert len(msgs) >= 2 # includes system prompt + existing messages
+
+ async def test_generate_title_no_session(self, session_manager):
+ title = await session_manager.generate_title(999, [])
+ assert title is None
+
+ async def test_remove_session_clears_current(self, session_manager):
+ await session_manager.get_or_create_session(42, "u42")
+ session_manager.current_conversation_id = 42
+ await session_manager.remove_session(42)
+ assert session_manager.current_conversation_id is None
diff --git a/backend/tests/test_tool_registry.py b/backend/tests/test_tool_registry.py
new file mode 100644
index 0000000..2f835ef
--- /dev/null
+++ b/backend/tests/test_tool_registry.py
@@ -0,0 +1,233 @@
+"""Tests for ToolRouter — registration, dispatch, mode filtering."""
+
+import pytest
+from openmlr.tools.registry import ToolRouter, MODE_TOOL_RESTRICTIONS
+
+pytestmark = pytest.mark.asyncio
+from openmlr.agent.types import ToolSpec, ToolCall
+
+
+@pytest.fixture
+def router():
+ return ToolRouter()
+
+
+@pytest.fixture
+def dummy_tool():
+ async def handler(arg1: str = "default") -> tuple[str, bool]:
+ return f"handled: {arg1}", True
+
+ return ToolSpec(
+ name="dummy",
+ description="A dummy tool",
+ parameters={
+ "type": "object",
+ "properties": {
+ "arg1": {"type": "string", "description": "Test arg"},
+ },
+ "required": ["arg1"],
+ },
+ handler=handler,
+ )
+
+
+@pytest.fixture
+def read_tool():
+ async def handler(path: str) -> tuple[str, bool]:
+ return f"read: {path}", True
+
+ return ToolSpec(
+ name="read_file",
+ description="Read a file",
+ parameters={
+ "type": "object",
+ "properties": {
+ "path": {"type": "string"},
+ },
+ "required": ["path"],
+ },
+ handler=handler,
+ )
+
+
+@pytest.fixture
+def bash_tool():
+ async def handler(cmd: str) -> tuple[str, bool]:
+ return f"executed: {cmd}", True
+
+ return ToolSpec(
+ name="bash",
+ description="Execute a command",
+ parameters={
+ "type": "object",
+ "properties": {
+ "cmd": {"type": "string"},
+ },
+ "required": ["cmd"],
+ },
+ handler=handler,
+ )
+
+
+class TestToolRegistration:
+ async def test_register_single(self, router, dummy_tool):
+ router.register(dummy_tool)
+ assert "dummy" in router.tools
+
+ async def test_register_many(self, router, dummy_tool, read_tool):
+ router.register_many([dummy_tool, read_tool])
+ assert len(router.tools) == 2
+ assert "dummy" in router.tools
+ assert "read_file" in router.tools
+
+ async def test_get_tool(self, router, dummy_tool):
+ router.register(dummy_tool)
+ tool = router.get_tool("dummy")
+ assert tool is not None
+ assert tool.name == "dummy"
+
+ async def test_get_tool_not_found(self, router):
+ assert router.get_tool("nonexistent") is None
+
+ async def test_register_blocked_tool(self, router, dummy_tool):
+ router._blocklist = {"dummy"}
+ router.register(dummy_tool)
+ assert "dummy" not in router.tools
+
+
+class TestToolDispatch:
+ async def test_call_tool_simple(self, router, dummy_tool):
+ router.register(dummy_tool)
+ result, success = await router.call_tool("dummy", {"arg1": "hello"})
+ assert success is True
+ assert "handled: hello" in result
+
+ async def test_call_tool_unknown(self, router):
+ result, success = await router.call_tool("unknown", {})
+ assert success is False
+ assert "Unknown tool" in result
+
+ async def test_call_tool_default_args(self, router, dummy_tool):
+ router.register(dummy_tool)
+ result, success = await router.call_tool("dummy", {})
+ assert success is True
+ assert "handled: default" in result
+
+ async def test_call_tool_type_error(self, router):
+ async def handler(required_arg: str) -> tuple[str, bool]:
+ return "ok", True
+
+ tool = ToolSpec(
+ name="strict", description="Needs arg", parameters={"type": "object"},
+ handler=handler,
+ )
+ router.register(tool)
+ result, success = await router.call_tool("strict", {"wrong_param": "val"})
+ assert success is False
+ assert "Tool argument error" in result
+
+
+class TestToolSpecsForLLM:
+ async def test_get_specs_in_openai_format(self, router, dummy_tool, read_tool):
+ router.register(dummy_tool)
+ router.register(read_tool)
+ specs = router.get_tool_specs_for_llm()
+ assert len(specs) == 2
+ for spec in specs:
+ assert spec["type"] == "function"
+ assert "function" in spec
+ assert "name" in spec["function"]
+ assert "description" in spec["function"]
+ assert "parameters" in spec["function"]
+
+ async def test_get_raw_specs(self, router, dummy_tool, read_tool):
+ router.register(dummy_tool)
+ router.register(read_tool)
+ raw = router.get_raw_specs()
+ assert len(raw) == 2
+ assert all(isinstance(s, ToolSpec) for s in raw)
+
+
+class TestModeFiltering:
+ async def test_default_mode_allows_all(self, router, dummy_tool, bash_tool):
+ router.register(dummy_tool)
+ router.register(bash_tool)
+ allowed, msg = router.is_tool_allowed("bash")
+ assert allowed is True
+ assert msg == ""
+
+ async def test_plan_mode_blocks_bash(self, router, bash_tool):
+ router.register(bash_tool)
+ router.set_mode("plan")
+ allowed, msg = router.is_tool_allowed("bash")
+ assert allowed is False
+ assert "PLAN mode" in msg
+
+ async def test_plan_mode_allows_ask_user(self, router):
+ router.set_mode("plan")
+ allowed, msg = router.is_tool_allowed("ask_user")
+ assert allowed is True
+
+ async def test_plan_mode_allows_read_file(self, router):
+ router.set_mode("plan")
+ allowed, msg = router.is_tool_allowed("read_file")
+ assert allowed is True
+
+ async def test_execute_mode_blocks_ask_user(self, router):
+ router.set_mode("execute")
+ allowed, msg = router.is_tool_allowed("ask_user")
+ assert allowed is False
+ assert "EXECUTE mode" in msg
+
+ async def test_execute_mode_allows_bash(self, router, bash_tool):
+ router.register(bash_tool)
+ router.set_mode("execute")
+ allowed, msg = router.is_tool_allowed("bash")
+ assert allowed is True
+
+ async def test_mode_filtering_in_get_specs(self, router, dummy_tool, bash_tool, read_tool):
+ router.register(dummy_tool)
+ router.register(bash_tool)
+ router.register(read_tool)
+ router.set_mode("plan")
+ specs = router.get_tool_specs_for_llm(filter_by_mode=True)
+ spec_names = [s["function"]["name"] for s in specs]
+ assert "bash" not in spec_names
+
+ async def test_no_filter_in_get_specs(self, router, dummy_tool, bash_tool):
+ router.register(dummy_tool)
+ router.register(bash_tool)
+ router.set_mode("plan")
+ specs = router.get_tool_specs_for_llm(filter_by_mode=False)
+ spec_names = [s["function"]["name"] for s in specs]
+ assert "bash" in spec_names
+ assert "dummy" in spec_names
+
+ async def test_call_tool_mode_enforcement(self, router, bash_tool):
+ router.register(bash_tool)
+ router.set_mode("plan")
+ result, success = await router.call_tool("bash", {"cmd": "ls"}, enforce_mode=True)
+ assert success is False
+ assert "MODE VIOLATION" in result
+
+ async def test_get_mode(self, router):
+ assert router.get_mode() == "general"
+ router.set_mode("research")
+ assert router.get_mode() == "research"
+
+
+class TestModeRestrictionsConfig:
+ async def test_plan_has_allowed_list(self):
+ assert "allowed" in MODE_TOOL_RESTRICTIONS["plan"]
+
+ async def test_execute_has_blocked_list(self):
+ assert "blocked" in MODE_TOOL_RESTRICTIONS["execute"]
+
+ async def test_plan_allowed_includes_ask_user(self):
+ assert "ask_user" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"]
+
+ async def test_plan_allowed_includes_read_file(self):
+ assert "read_file" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"]
+
+ async def test_execute_blocked_includes_ask_user(self):
+ assert "ask_user" in MODE_TOOL_RESTRICTIONS["execute"]["blocked"]
diff --git a/backend/tests/test_tools_github.py b/backend/tests/test_tools_github.py
new file mode 100644
index 0000000..3b6a5e7
--- /dev/null
+++ b/backend/tests/test_tools_github.py
@@ -0,0 +1,55 @@
+"""Tests for GitHub tools — tool specs and helper functions."""
+
+import pytest
+from openmlr.tools.github import create_github_tools, _headers
+
+pytestmark = pytest.mark.asyncio
+
+
+class TestCreateGithubTools:
+ async def test_creates_all_tools(self):
+ tools = create_github_tools()
+ names = [t.name for t in tools]
+ assert "github_read_file" in names
+ assert "github_list_repos" in names
+ assert "github_find_examples" in names
+ assert len(tools) == 3
+
+ async def test_read_file_required_params(self):
+ tools = create_github_tools()
+ read = [t for t in tools if t.name == "github_read_file"][0]
+ required = read.parameters["required"]
+ assert "owner" in required
+ assert "repo" in required
+ assert "path" in required
+
+ async def test_list_repos_required_params(self):
+ tools = create_github_tools()
+ list_tool = [t for t in tools if t.name == "github_list_repos"][0]
+ assert "owner" in list_tool.parameters["required"]
+
+ async def test_find_examples_required_params(self):
+ tools = create_github_tools()
+ find = [t for t in tools if t.name == "github_find_examples"][0]
+ assert "query" in find.parameters["required"]
+
+ async def test_find_examples_language_filter(self):
+ tools = create_github_tools()
+ find = [t for t in tools if t.name == "github_find_examples"][0]
+ assert "language" in find.parameters["properties"]
+
+
+class TestHeaders:
+ async def test_headers_accept(self):
+ h = _headers()
+ assert h["Accept"] == "application/vnd.github+json"
+
+ async def test_headers_without_token(self, monkeypatch):
+ monkeypatch.delenv("GITHUB_TOKEN", raising=False)
+ h = _headers()
+ assert "Authorization" not in h
+
+ async def test_headers_with_token(self, monkeypatch):
+ monkeypatch.setenv("GITHUB_TOKEN", "ghp_test123")
+ h = _headers()
+ assert h["Authorization"] == "Bearer ghp_test123"
diff --git a/backend/tests/test_tools_local.py b/backend/tests/test_tools_local.py
new file mode 100644
index 0000000..575b7b6
--- /dev/null
+++ b/backend/tests/test_tools_local.py
@@ -0,0 +1,219 @@
+"""Tests for local tools — bash, read, write, edit, and path validation."""
+
+import os
+import tempfile
+from pathlib import Path
+import pytest
+
+from openmlr.tools.local import (
+ create_local_tools, _validate_path, _handle_read, _handle_write, _handle_edit,
+ DOCKER_IMAGE, CONTAINER_PREFIX, ALLOW_DIRECT_EXEC, WORKSPACE_ROOT,
+)
+
+
+class TestCreateLocalTools:
+ def test_creates_all_tools(self):
+ tools = create_local_tools()
+ names = [t.name for t in tools]
+ assert "bash" in names
+ assert "read" in names
+ assert "write" in names
+ assert "edit" in names
+ assert len(tools) == 4
+
+ def test_bash_tool_spec(self):
+ tools = create_local_tools()
+ bash = [t for t in tools if t.name == "bash"][0]
+ assert "command" in bash.parameters["properties"]
+ assert "timeout" in bash.parameters["properties"]
+ assert "workdir" in bash.parameters["properties"]
+ assert "command" in bash.parameters["required"]
+
+ def test_read_tool_spec(self):
+ tools = create_local_tools()
+ read_tool = [t for t in tools if t.name == "read"][0]
+ assert read_tool.handler is not None
+ assert "path" in read_tool.parameters["required"]
+
+ def test_write_tool_spec(self):
+ tools = create_local_tools()
+ write_tool = [t for t in tools if t.name == "write"][0]
+ assert "path" in write_tool.parameters["required"]
+ assert "content" in write_tool.parameters["required"]
+
+ def test_edit_tool_spec(self):
+ tools = create_local_tools()
+ edit_tool = [t for t in tools if t.name == "edit"][0]
+ required = edit_tool.parameters["required"]
+ assert "path" in required
+ assert "old_string" in required
+ assert "new_string" in required
+
+
+class TestValidatePath:
+ def test_resolves_relative_path(self):
+ cwd = os.getcwd()
+ path = Path(".", "test_file.txt")
+ resolved, error = _validate_path(path)
+ assert error is None
+ assert resolved.is_absolute()
+
+ def test_path_within_cwd_is_allowed(self):
+ path = Path.cwd() / "test" / "file.py"
+ resolved, error = _validate_path(path)
+ assert error is None
+
+ def test_blocked_system_path(self, monkeypatch):
+ monkeypatch.setattr("openmlr.tools.local.WORKSPACE_ROOT", "/home/user/projects")
+ path = Path("/etc", "passwd")
+ resolved, error = _validate_path(path)
+ assert error is not None
+ assert "outside workspace" in error
+
+ def test_blocked_root_path(self, monkeypatch):
+ monkeypatch.setattr("openmlr.tools.local.WORKSPACE_ROOT", "/home/user/projects")
+ path = Path("/root", "secret")
+ resolved, error = _validate_path(path)
+ assert error is not None
+
+ def test_blocked_var_path(self, monkeypatch):
+ monkeypatch.setattr("openmlr.tools.local.WORKSPACE_ROOT", "/home/user/projects")
+ path = Path("/var", "log")
+ resolved, error = _validate_path(path)
+ assert error is not None
+
+ def test_blocked_bin_path(self):
+ path = Path("/bin", "bash")
+ resolved, error = _validate_path(path)
+ assert error is not None
+
+ def test_with_workspace_root_set(self, monkeypatch):
+ workspace = str(Path.cwd() / "workspace")
+ monkeypatch.setattr("openmlr.tools.local.WORKSPACE_ROOT", workspace)
+ within = Path(workspace) / "file.txt"
+ resolved, error = _validate_path(within)
+ assert error is None
+
+
+@pytest.mark.asyncio
+class TestHandleRead:
+ async def test_reads_file_with_line_numbers(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ f = tmp_path / "test.txt"
+ f.write_text("line one\nline two\nline three\n")
+ result, success = await _handle_read("test.txt")
+ assert success is True
+ assert "1: line one" in result
+ assert "2: line two" in result
+ assert "3: line three" in result
+
+ async def test_read_with_offset_and_limit(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ f = tmp_path / "test.txt"
+ f.write_text("a\nb\nc\nd\ne\n")
+ result, success = await _handle_read("test.txt", offset=3, limit=2)
+ assert success is True
+ assert "3: c" in result
+ assert "4: d" in result
+ assert "5: e" not in result
+
+ async def test_read_nonexistent_file(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ result, success = await _handle_read("nonexistent_test_file.txt")
+ assert success is False
+ assert "File not found" in result
+
+ async def test_read_directory(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ d = tmp_path / "testdir"
+ d.mkdir()
+ (d / "file1.txt").write_text("x")
+ (d / "file2.txt").write_text("y")
+ result, success = await _handle_read("testdir")
+ assert success is True
+ assert "file1.txt" in result
+ assert "file2.txt" in result
+
+
+@pytest.mark.asyncio
+class TestHandleWrite:
+ async def test_writes_file(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ f = tmp_path / "output.txt"
+ result, success = await _handle_write("output.txt", content="hello world")
+ assert success is True
+ assert "Wrote" in result
+ assert f.read_text() == "hello world"
+
+ async def test_creates_parent_directories(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ f = tmp_path / "deeply" / "nested" / "dir" / "file.txt"
+ result, success = await _handle_write("deeply/nested/dir/file.txt", content="deep")
+ assert success is True
+ assert f.read_text() == "deep"
+
+ async def test_requires_path(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ result, success = await _handle_write(path="", content="test")
+ assert success is False
+
+ async def test_requires_content(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ result, success = await _handle_write(path="test.txt", content="")
+ assert success is False
+
+
+@pytest.mark.asyncio
+class TestHandleEdit:
+ async def test_replaces_string(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ f = tmp_path / "edit_test.txt"
+ f.write_text("Hello World")
+ result, success = await _handle_edit("edit_test.txt", "World", "Universe")
+ assert success is True
+ assert "Replaced" in result
+ assert f.read_text() == "Hello Universe"
+
+ async def test_old_string_not_found(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ f = tmp_path / "edit_test.txt"
+ f.write_text("Hello World")
+ result, success = await _handle_edit("edit_test.txt", "Mars", "Earth")
+ assert success is False
+ assert "old_string not found" in result
+
+ async def test_multiple_matches_without_replace_all(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ f = tmp_path / "edit_test.txt"
+ f.write_text("hello hello hello")
+ result, success = await _handle_edit("edit_test.txt", "hello", "hi")
+ assert success is False
+ assert "Found 3 matches" in result
+
+ async def test_replace_all(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ f = tmp_path / "edit_test.txt"
+ f.write_text("hello hello hello")
+ result, success = await _handle_edit("edit_test.txt", "hello", "hi", replace_all=True)
+ assert success is True
+ assert f.read_text() == "hi hi hi"
+
+ async def test_nonexistent_file(self, monkeypatch, tmp_path):
+ monkeypatch.chdir(tmp_path)
+ result, success = await _handle_edit("nonexistent_edit_test.txt", "a", "b")
+ assert success is False
+ assert "File not found" in result
+
+
+class TestModuleConstants:
+ def test_docker_image_default(self):
+ assert DOCKER_IMAGE == "python:3.12-slim"
+
+ def test_container_prefix(self):
+ assert CONTAINER_PREFIX == "openmlr-sandbox"
+
+ def test_allow_direct_exec_default(self, monkeypatch):
+ monkeypatch.delenv("OPENMLR_ALLOW_DIRECT_EXEC", raising=False)
+ import openmlr.tools.local
+ allow = openmlr.tools.local.ALLOW_DIRECT_EXEC
+ assert allow is False
diff --git a/backend/tests/test_tools_mcp.py b/backend/tests/test_tools_mcp.py
new file mode 100644
index 0000000..0cddbee
--- /dev/null
+++ b/backend/tests/test_tools_mcp.py
@@ -0,0 +1,64 @@
+"""Tests for MCP tools — env var substitution and config processing."""
+
+import pytest
+from openmlr.tools.mcp import substitute_env_vars, process_mcp_config
+
+pytestmark = pytest.mark.asyncio
+
+
+class TestSubstituteEnvVars:
+ async def test_replaces_var(self, monkeypatch):
+ monkeypatch.setenv("TEST_KEY", "test-value")
+ result = substitute_env_vars("prefix-${TEST_KEY}-suffix")
+ assert result == "prefix-test-value-suffix"
+
+ async def test_multiple_vars(self, monkeypatch):
+ monkeypatch.setenv("A", "aaa")
+ monkeypatch.setenv("B", "bbb")
+ result = substitute_env_vars("${A}_${B}")
+ assert result == "aaa_bbb"
+
+ async def test_unknown_var_unchanged(self, monkeypatch):
+ monkeypatch.delenv("UNKNOWN_VAR", raising=False)
+ result = substitute_env_vars("${UNKNOWN_VAR}")
+ assert result == "${UNKNOWN_VAR}"
+
+ async def test_no_vars(self):
+ result = substitute_env_vars("plain text no vars")
+ assert result == "plain text no vars"
+
+ async def test_partial_match(self):
+ result = substitute_env_vars("$TEST no match")
+ assert result == "$TEST no match"
+
+
+class TestProcessMCPConfig:
+ async def test_simple_dict(self, monkeypatch):
+ monkeypatch.setenv("API_KEY", "secret123")
+ config = {"url": "https://api.example.com?key=${API_KEY}"}
+ result = process_mcp_config(config)
+ assert result["url"] == "https://api.example.com?key=secret123"
+
+ async def test_nested_dict(self, monkeypatch):
+ monkeypatch.setenv("HOST", "localhost")
+ config = {"server": {"host": "${HOST}", "port": 8080}}
+ result = process_mcp_config(config)
+ assert result["server"]["host"] == "localhost"
+ assert result["server"]["port"] == 8080
+
+ async def test_list_values(self, monkeypatch):
+ monkeypatch.setenv("ARG1", "value1")
+ monkeypatch.setenv("ARG2", "value2")
+ config = {"args": ["${ARG1}", "${ARG2}", "literal"]}
+ result = process_mcp_config(config)
+ assert result["args"] == ["value1", "value2", "literal"]
+
+ async def test_non_string_values(self):
+ config = {"timeout": 30, "enabled": True, "null_val": None}
+ result = process_mcp_config(config)
+ assert result["timeout"] == 30
+ assert result["enabled"] is True
+ assert result["null_val"] is None
+
+ async def test_empty_config(self):
+ assert process_mcp_config({}) == {}
diff --git a/backend/tests/test_tools_papers.py b/backend/tests/test_tools_papers.py
new file mode 100644
index 0000000..37b8803
--- /dev/null
+++ b/backend/tests/test_tools_papers.py
@@ -0,0 +1,98 @@
+"""Tests for papers tool — helper functions and tool spec."""
+
+import pytest
+from openmlr.tools.papers import (
+ create_papers_tool, _to_openalex_id, _extract_arxiv_id,
+ _extract_arxiv_from_ids, _reconstruct_abstract, _check_budget,
+ _increment_budget, _get_budget_info,
+)
+
+pytestmark = pytest.mark.asyncio
+
+
+class TestCreatePapersTool:
+ async def test_creates_tool(self):
+ tool = create_papers_tool()
+ assert tool.name == "papers"
+ assert tool.handler is not None
+ assert "operation" in tool.parameters["required"]
+ ops = tool.parameters["properties"]["operation"]["enum"]
+ assert "search" in ops
+ assert "trending" in ops
+ assert "details" in ops
+ assert "read_paper" in ops
+ assert "citations" in ops
+ assert "recommend" in ops
+ assert "find_code" in ops
+ assert "find_datasets" in ops
+
+
+class TestExtractArxivId:
+ async def test_standard_format(self):
+ assert _extract_arxiv_id("2301.12345") == "2301.12345"
+
+ async def test_with_version(self):
+ assert _extract_arxiv_id("2301.12345v2") == "2301.12345v2"
+
+ async def test_from_url(self):
+ assert _extract_arxiv_id("https://arxiv.org/abs/2301.12345") == "2301.12345"
+
+ async def test_from_pdf_url(self):
+ assert _extract_arxiv_id("https://arxiv.org/pdf/2301.12345.pdf") == "2301.12345"
+
+ async def test_no_arxiv_id(self):
+ assert _extract_arxiv_id("some text without id") is None
+
+ async def test_from_doi_with_arxiv(self):
+ assert _extract_arxiv_id("10.48550/arXiv.2301.12345") == "2301.12345"
+
+
+class TestToOpenAlexId:
+ async def test_openalex_id(self):
+ assert _to_openalex_id("W123456") == "W123456"
+
+ async def test_openalex_url(self):
+ assert _to_openalex_id("https://openalex.org/W123456") == "https://openalex.org/W123456"
+
+ async def test_doi(self):
+ assert _to_openalex_id("10.1234/foo.bar") == "https://doi.org/10.1234/foo.bar"
+
+ async def test_arxiv_id(self):
+ result = _to_openalex_id("2301.12345")
+ assert "doi.org" in result
+ assert "arXiv" in result
+
+
+class TestReconstructAbstract:
+ async def test_simple(self):
+ inv = {"hello": [0], "world": [1]}
+ assert _reconstruct_abstract(inv) == "hello world"
+
+ async def test_empty(self):
+ assert _reconstruct_abstract({}) is None
+ assert _reconstruct_abstract(None) is None
+
+ async def test_multiple_positions(self):
+ inv = {"the": [0, 3], "cat": [1], "sat": [2], "mat": [4]}
+ result = _reconstruct_abstract(inv)
+ assert result == "the cat sat the mat"
+
+
+class TestBudgetFunctions:
+ async def test_check_budget_allows_first_call(self):
+ ok, msg = _check_budget()
+ assert ok is True
+ assert msg == ""
+
+ async def test_get_budget_info(self):
+ info = _get_budget_info()
+ assert "used" in info
+ assert "max" in info
+
+ async def test_increment_and_check(self):
+ from openmlr.tools.papers import _search_counts
+ _search_counts.clear()
+ _increment_budget()
+ info = _get_budget_info()
+ assert info["used"] >= 1
+ _search_counts.clear()
diff --git a/backend/tests/test_tools_research.py b/backend/tests/test_tools_research.py
new file mode 100644
index 0000000..3577f82
--- /dev/null
+++ b/backend/tests/test_tools_research.py
@@ -0,0 +1,62 @@
+"""Tests for research tool — tool specs and the research sub-agent dispatching."""
+
+import pytest
+from openmlr.tools.research import (
+ create_research_tool, _get_research_tool_specs, _execute_research_tool,
+ RESEARCH_SYSTEM_PROMPT, MAX_RESEARCH_ITERATIONS,
+)
+from openmlr.agent.types import ToolCall
+
+
+class TestCreateResearchTool:
+ def test_creates_tool(self):
+ tool = create_research_tool()
+ assert tool.name == "research"
+ assert tool.handler is not None
+ assert "query" in tool.parameters["required"]
+
+ def test_focus_enum(self):
+ tool = create_research_tool()
+ focus = tool.parameters["properties"]["focus"]
+ assert focus["enum"] == ["papers", "code", "docs", "general"]
+
+
+class TestGetResearchToolSpecs:
+ def test_returns_formatted_tools(self):
+ specs = _get_research_tool_specs()
+ assert len(specs) > 0
+ for spec in specs:
+ assert spec["type"] == "function"
+ assert "function" in spec
+ assert "name" in spec["function"]
+
+ def test_includes_search(self):
+ specs = _get_research_tool_specs()
+ names = [s["function"]["name"] for s in specs]
+ assert "web_search" in names
+
+ def test_includes_papers(self):
+ specs = _get_research_tool_specs()
+ names = [s["function"]["name"] for s in specs]
+ assert "papers" in names
+
+ def test_github_read_included(self):
+ specs = _get_research_tool_specs()
+ names = [s["function"]["name"] for s in specs]
+ assert "github_read_file" in names or "github_find_examples" in names
+
+
+class TestExecuteResearchTool:
+ @pytest.mark.asyncio
+ async def test_unknown_tool(self):
+ tc = ToolCall(id="tc1", name="unknown_tool", arguments={})
+ result, success = await _execute_research_tool(tc)
+ assert success is False
+ assert "not available" in result
+
+ def test_system_prompt_not_empty(self):
+ assert len(RESEARCH_SYSTEM_PROMPT) > 0
+ assert "research sub-agent" in RESEARCH_SYSTEM_PROMPT.lower()
+
+ def test_max_iterations(self):
+ assert MAX_RESEARCH_ITERATIONS == 60
diff --git a/backend/tests/test_tools_sandbox.py b/backend/tests/test_tools_sandbox.py
new file mode 100644
index 0000000..e05a486
--- /dev/null
+++ b/backend/tests/test_tools_sandbox.py
@@ -0,0 +1,64 @@
+"""Tests for sandbox tools — probe, create, exec, read, write."""
+
+import pytest
+from openmlr.tools.sandbox_tools import create_sandbox_tools, _handle_probe
+
+pytestmark = pytest.mark.asyncio
+from openmlr.sandbox.manager import SandboxManager
+
+
+class TestCreateSandboxTools:
+ async def test_creates_all_tools(self):
+ mgr = SandboxManager()
+ tools = create_sandbox_tools(mgr)
+ names = [t.name for t in tools]
+ assert "sandbox_probe" in names
+ assert "sandbox_create" in names
+ assert "sandbox_exec" in names
+ assert "sandbox_read" in names
+ assert "sandbox_write" in names
+ assert len(tools) == 5
+
+ async def test_sandbox_create_needs_approval(self):
+ mgr = SandboxManager()
+ tools = create_sandbox_tools(mgr)
+ create = [t for t in tools if t.name == "sandbox_create"][0]
+ assert create.needs_approval is not None
+ assert create.needs_approval({"provider": "modal"}) is True
+
+ async def test_sandbox_exec_requires_command(self):
+ mgr = SandboxManager()
+ tools = create_sandbox_tools(mgr)
+ exec_tool = [t for t in tools if t.name == "sandbox_exec"][0]
+ assert "command" in exec_tool.parameters["required"]
+
+ async def test_sandbox_probe_no_params(self):
+ mgr = SandboxManager()
+ tools = create_sandbox_tools(mgr)
+ probe = [t for t in tools if t.name == "sandbox_probe"][0]
+ assert probe.parameters["properties"] == {}
+
+ async def test_sandbox_create_provider_enum(self):
+ mgr = SandboxManager()
+ tools = create_sandbox_tools(mgr)
+ create = [t for t in tools if t.name == "sandbox_create"][0]
+ providers = create.parameters["properties"]["provider"]["enum"]
+ assert "local" in providers
+ assert "ssh" in providers
+ assert "modal" in providers
+
+
+class TestHandleProbe:
+ async def test_probe_without_sandbox(self):
+ mgr = SandboxManager()
+ result, success = await _handle_probe(mgr)
+ assert success is True
+ assert "Using local environment" in result or "OS:" in result
+
+ async def test_probe_with_local_sandbox(self):
+ mgr = SandboxManager()
+ await mgr.create("local")
+ result, success = await _handle_probe(mgr)
+ assert success is True
+ assert "Sandbox Environment" in result
+ await mgr.destroy()
diff --git a/backend/tests/test_tools_search.py b/backend/tests/test_tools_search.py
new file mode 100644
index 0000000..36c9f11
--- /dev/null
+++ b/backend/tests/test_tools_search.py
@@ -0,0 +1,27 @@
+"""Tests for web search tool — tool spec validation."""
+
+import pytest
+from openmlr.tools.search import create_search_tools
+
+pytestmark = pytest.mark.asyncio
+
+
+class TestCreateSearchTools:
+ async def test_creates_tool(self):
+ tools = create_search_tools()
+ assert len(tools) == 1
+ assert tools[0].name == "web_search"
+
+ async def test_handler_configured(self):
+ tools = create_search_tools()
+ assert tools[0].handler is not None
+
+ async def test_query_required(self):
+ tools = create_search_tools()
+ assert "query" in tools[0].parameters["required"]
+
+ async def test_count_parameter(self):
+ tools = create_search_tools()
+ props = tools[0].parameters["properties"]
+ assert "count" in props
+ assert props["count"]["type"] == "integer"
diff --git a/backend/tests/test_tools_writing.py b/backend/tests/test_tools_writing.py
new file mode 100644
index 0000000..bb72e20
--- /dev/null
+++ b/backend/tests/test_tools_writing.py
@@ -0,0 +1,216 @@
+"""Tests for writing tool — project management and paper operations."""
+
+import pytest
+from openmlr.tools.writing import (
+ create_writing_tool, _create_project, _set_outline,
+ _write_section, _get_draft, _get_draft_from_proj,
+ _list_sections, _add_citation, _refine_section,
+ _count_sections,
+)
+
+pytestmark = pytest.mark.asyncio
+
+
+class TestCreateWritingTool:
+ async def test_creates_tool(self):
+ tool = create_writing_tool()
+ assert tool.name == "writing"
+ assert tool.handler is not None
+ assert "operation" in tool.parameters["required"]
+ ops = tool.parameters["properties"]["operation"]["enum"]
+ assert "create_project" in ops
+ assert "write_section" in ops
+ assert "get_draft" in ops
+
+
+class TestCreateProject:
+ async def test_creates_project(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ result, ok = _create_project(conv_id=1, title="My Paper")
+ assert ok is True
+ assert "My Paper" in result
+ proj = _projects.get(1)
+ assert proj is not None
+ assert proj["title"] == "My Paper"
+ _projects.clear()
+
+ async def test_requires_title(self):
+ result, ok = _create_project(conv_id=1, title="")
+ assert ok is False
+ assert "title" in result.lower()
+
+
+class TestSetOutline:
+ async def test_no_project(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ result, ok = _set_outline(conv_id=999, outline=[])
+ assert ok is False
+ assert "No paper project" in result
+
+ async def test_requires_outline(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ _create_project(conv_id=1, title="Test")
+ result, ok = _set_outline(conv_id=1, outline=None)
+ assert ok is False
+ _projects.clear()
+
+ async def test_sets_outline(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ _create_project(conv_id=1, title="Test")
+ outline = [
+ {"id": "sec1", "title": "Introduction"},
+ {"id": "sec2", "title": "Methods", "subsections": [
+ {"id": "sec2.1", "title": "Setup"},
+ ]},
+ ]
+ result, ok = _set_outline(conv_id=1, outline=outline)
+ assert ok is True
+ assert "Introduction" in result
+ assert "Methods" in result
+ assert "Setup" in result
+ _projects.clear()
+
+
+class TestWriteSection:
+ async def test_no_project(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ result, ok = _write_section(conv_id=999, section_id="s1", content="text")
+ assert ok is False
+
+ async def test_writes_section(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ _create_project(conv_id=1, title="Test")
+ _set_outline(conv_id=1, outline=[{"id": "intro", "title": "Introduction"}])
+ result, ok = _write_section(conv_id=1, section_id="intro", content="Hello world.")
+ assert ok is True
+ assert "intro" in result
+ assert "auto-saved" in result
+ _projects.clear()
+
+
+class TestGetDraft:
+ async def test_no_project(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ result, ok = _get_draft(conv_id=999)
+ assert ok is False
+
+ async def test_generates_draft(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ _create_project(conv_id=1, title="The Paper")
+ _set_outline(conv_id=1, outline=[{"id": "intro", "title": "Introduction"}])
+ _write_section(conv_id=1, section_id="intro", content="This is the intro.")
+ result, ok = _get_draft(conv_id=1)
+ assert ok is True
+ assert "# The Paper" in result
+ assert "Introduction" in result
+ assert "This is the intro." in result
+ _projects.clear()
+
+
+class TestGetDraftFromProj:
+ async def test_generates_full_draft(self):
+ proj = {
+ "title": "ML Research",
+ "outline": [
+ {"id": "abstract", "title": "Abstract"},
+ {"id": "method", "title": "Method", "subsections": [
+ {"id": "method.experimental", "title": "Experimental Setup"},
+ ]},
+ ],
+ "sections": {
+ "abstract": "This is the abstract.",
+ "method": "The method section.",
+ "method.experimental": "Experimental details.",
+ },
+ "bibliography": [
+ {"key": "doe2024", "author": "J. Doe", "title": "Great Paper", "year": "2024"},
+ ],
+ }
+ result, ok = _get_draft_from_proj(proj)
+ assert ok is True
+ assert "ML Research" in result
+ assert "Abstract" in result
+ assert "Method" in result
+ assert "Experimental Setup" in result
+ assert "References" in result
+ assert "J. Doe" in result
+
+
+class TestListSections:
+ async def test_no_project(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ result, ok = _list_sections(conv_id=999)
+ assert ok is False
+ assert "No paper project" in result
+
+ async def test_lists_sections_with_status(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ _create_project(conv_id=1, title="Test")
+ _set_outline(conv_id=1, outline=[
+ {"id": "s1", "title": "Section 1"},
+ {"id": "s2", "title": "Section 2"},
+ ])
+ _write_section(conv_id=1, section_id="s1", content="written")
+ result, ok = _list_sections(conv_id=1)
+ assert ok is True
+ assert "[done]" in result
+ assert "[pending]" in result
+ _projects.clear()
+
+
+class TestAddCitation:
+ async def test_adds_citation(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ _create_project(conv_id=1, title="Test")
+ citation = {
+ "key": "smith2023",
+ "title": "Important Paper",
+ "author": "A. Smith",
+ "year": "2023",
+ }
+ result, ok = _add_citation(conv_id=1, citation=citation)
+ assert ok is True
+ assert "Added citation" in result
+ _projects.clear()
+
+
+class TestRefineSection:
+ async def test_returns_feedback_mode(self):
+ from openmlr.tools.writing import _projects
+ _projects.clear()
+ _create_project(conv_id=1, title="Test")
+ _set_outline(conv_id=1, outline=[{"id": "s1", "title": "Section"}])
+ _write_section(conv_id=1, section_id="s1", content="original content")
+ result, ok = _refine_section(
+ conv_id=1, section_id="s1",
+ content=None, feedback="make it better",
+ )
+ assert ok is True
+ assert "feedback" in result.lower()
+ _projects.clear()
+
+
+class TestCountSections:
+ async def test_counts_with_subsections(self):
+ outline = [
+ {"id": "a", "title": "A"},
+ {"id": "b", "title": "B", "subsections": [
+ {"id": "b1", "title": "B1"},
+ {"id": "b2", "title": "B2"},
+ ]},
+ ]
+ assert _count_sections(outline) == 4
+
+ async def test_empty_outline(self):
+ assert _count_sections([]) == 0
diff --git a/backend/tests/test_types.py b/backend/tests/test_types.py
new file mode 100644
index 0000000..4271696
--- /dev/null
+++ b/backend/tests/test_types.py
@@ -0,0 +1,172 @@
+"""Tests for agent core types — AgentEvent, OpType, Message, ToolCall, ToolSpec, LLMResult, Submission."""
+
+import json
+import pytest
+
+from openmlr.agent.types import (
+ AgentEvent, OpType, Message, ToolCall, ToolSpec, LLMResult, Submission,
+)
+
+
+class TestToolCall:
+ def test_creation(self):
+ tc = ToolCall(id="call_1", name="read_file", arguments={"path": "/tmp/test"})
+ assert tc.id == "call_1"
+ assert tc.name == "read_file"
+ assert tc.arguments == {"path": "/tmp/test"}
+
+ def test_empty_arguments(self):
+ tc = ToolCall(id="c2", name="noop", arguments={})
+ assert tc.arguments == {}
+
+
+class TestToolSpec:
+ def test_creation_without_optional(self):
+ ts = ToolSpec(name="test_tool", description="A test tool", parameters={"type": "object"})
+ assert ts.name == "test_tool"
+ assert ts.description == "A test tool"
+ assert ts.parameters == {"type": "object"}
+ assert ts.handler is None
+ assert ts.needs_approval is None
+
+ def test_creation_with_handler(self):
+ async def handler(arg1: str, arg2: int = 0) -> tuple[str, bool]:
+ return f"done: {arg1}", True
+
+ ts = ToolSpec(
+ name="with_handler",
+ description="Tool with handler",
+ parameters={"type": "object", "properties": {"arg1": {"type": "string"}}},
+ handler=handler,
+ )
+ assert ts.handler is not None
+ assert ts.name == "with_handler"
+
+ def test_needs_approval(self):
+ ts = ToolSpec(
+ name="dangerous",
+ description="Needs approval",
+ parameters={"type": "object"},
+ needs_approval=lambda **kwargs: True,
+ )
+ assert ts.needs_approval is not None
+
+
+class TestMessage:
+ def test_user_message(self):
+ msg = Message(role="user", content="Hello")
+ assert msg.role == "user"
+ assert msg.content == "Hello"
+ assert msg.tool_calls is None
+ assert msg.tool_call_id is None
+ assert msg.name is None
+
+ def test_assistant_with_tool_calls(self):
+ tc = ToolCall(id="tc1", name="bash", arguments={"cmd": "ls"})
+ msg = Message(role="assistant", content="", tool_calls=[tc])
+ assert msg.role == "assistant"
+ assert msg.content == ""
+ assert len(msg.tool_calls) == 1
+ assert msg.tool_calls[0].name == "bash"
+
+ def test_tool_result_message(self):
+ msg = Message(role="tool", content="output here", tool_call_id="tc1", name="bash")
+ assert msg.role == "tool"
+ assert msg.tool_call_id == "tc1"
+ assert msg.name == "bash"
+
+ def test_system_message(self):
+ msg = Message(role="system", content="You are an AI assistant.")
+ assert msg.role == "system"
+
+
+class TestAgentEvent:
+ def test_creation_without_data(self):
+ evt = AgentEvent(event_type="status")
+ assert evt.event_type == "status"
+ assert evt.data is None
+
+ def test_creation_with_data(self):
+ evt = AgentEvent(event_type="status", data={"status": "thinking..."})
+ assert evt.data == {"status": "thinking..."}
+
+ def test_creation_kwargs(self):
+ evt = AgentEvent(event_type="text_delta", data={"text": "hello"})
+ assert evt.event_type == "text_delta"
+ assert evt.data["text"] == "hello"
+
+ def test_to_sse_simple(self):
+ evt = AgentEvent(event_type="ping")
+ sse = evt.to_sse()
+ assert sse.startswith("data: ")
+ assert sse.endswith("\n\n")
+ parsed = json.loads(sse[6:-2])
+ assert parsed["event_type"] == "ping"
+ assert parsed["data"] is None
+
+ def test_to_sse_with_data(self):
+ evt = AgentEvent(event_type="status", data={"status": "ready"})
+ sse = evt.to_sse()
+ parsed = json.loads(sse[6:-2])
+ assert parsed["event_type"] == "status"
+ assert parsed["data"] == {"status": "ready"}
+
+ def test_to_sse_complex_data(self):
+ evt = AgentEvent(event_type="text_delta", data={"text": "hello\nworld", "index": 0})
+ sse = evt.to_sse()
+ parsed = json.loads(sse[6:-2])
+ assert parsed["data"]["text"] == "hello\nworld"
+ assert parsed["data"]["index"] == 0
+
+
+class TestOpType:
+ def test_enum_values(self):
+ assert OpType.USER_INPUT == "user_input"
+ assert OpType.EXEC_APPROVAL == "exec_approval"
+ assert OpType.COMPACT == "compact"
+ assert OpType.UNDO == "undo"
+ assert OpType.INTERRUPT == "interrupt"
+ assert OpType.SHUTDOWN == "shutdown"
+
+ def test_string_equality(self):
+ assert OpType.USER_INPUT == "user_input"
+ assert OpType.INTERRUPT != "user_input"
+
+
+class TestSubmission:
+ def test_creation(self):
+ sub = Submission(op=OpType.USER_INPUT, data="Send this message")
+ assert sub.op == OpType.USER_INPUT
+ assert sub.data == "Send this message"
+
+ def test_no_data_default(self):
+ sub = Submission(op=OpType.SHUTDOWN)
+ assert sub.op == OpType.SHUTDOWN
+ assert sub.data is None
+
+
+class TestLLMResult:
+ def test_basic_result(self):
+ result = LLMResult(content="Hello, world!", tool_calls=[], finish_reason="stop")
+ assert result.content == "Hello, world!"
+ assert result.tool_calls == []
+ assert result.finish_reason == "stop"
+ assert result.usage is None
+
+ def test_with_tool_calls(self):
+ tc = ToolCall(id="c1", name="search", arguments={"query": "test"})
+ result = LLMResult(content="", tool_calls=[tc], finish_reason="tool_calls")
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].name == "search"
+ assert result.finish_reason == "tool_calls"
+
+ def test_with_usage(self):
+ usage = {"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}
+ result = LLMResult(content="ok", tool_calls=[], finish_reason="stop", usage=usage)
+ assert result.usage == usage
+ assert result.usage["total_tokens"] == 150
+
+ def test_finish_reason_length(self):
+ """finish_reason can be 'length' for truncated responses."""
+ result = LLMResult(content="truncated", tool_calls=[], finish_reason="length")
+ assert result.finish_reason == "length"
diff --git a/frontend/src/__tests__/AgentSettings.test.tsx b/frontend/src/__tests__/AgentSettings.test.tsx
new file mode 100644
index 0000000..40ac845
--- /dev/null
+++ b/frontend/src/__tests__/AgentSettings.test.tsx
@@ -0,0 +1,64 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen, waitFor } from '@testing-library/react';
+import { AgentSettings } from '../components/settings/AgentSettings';
+import { api } from '../api';
+
+vi.mock('../api', () => ({
+ api: {
+ getSettings: vi.fn(),
+ updateSetting: vi.fn(),
+ },
+}));
+
+describe('AgentSettings', () => {
+ beforeEach(() => {
+ vi.clearAllMocks();
+ vi.mocked(api.getSettings).mockResolvedValue({ settings: {} });
+ });
+
+ it('renders model hint', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText(/Model used for new conversations/)).toBeInTheDocument();
+ });
+ });
+
+ it('renders default model input', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Default Model')).toBeInTheDocument();
+ expect(screen.getByPlaceholderText(/anthropic\/claude-sonnet-4/)).toBeInTheDocument();
+ });
+ });
+
+ it('renders research model input', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText(/Research \/ Title Model/)).toBeInTheDocument();
+ });
+ });
+
+ it('renders yolo mode checkbox', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText(/YOLO Mode/)).toBeInTheDocument();
+ });
+ });
+
+ it('renders save button', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Save Agent Settings')).toBeInTheDocument();
+ });
+ });
+
+ it('loads settings on mount', async () => {
+ vi.mocked(api.getSettings).mockResolvedValue({
+ settings: { agent: { default_model: 'claude-4', yolo_mode: true } },
+ });
+ render();
+ await waitFor(() => {
+ expect(api.getSettings).toHaveBeenCalled();
+ });
+ });
+});
diff --git a/frontend/src/__tests__/ApprovalModal.test.tsx b/frontend/src/__tests__/ApprovalModal.test.tsx
new file mode 100644
index 0000000..2f2ff69
--- /dev/null
+++ b/frontend/src/__tests__/ApprovalModal.test.tsx
@@ -0,0 +1,127 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen, fireEvent } from '@testing-library/react';
+import { ApprovalModal } from '../components/ApprovalModal';
+import { api } from '../api';
+
+vi.mock('../api', () => ({
+ api: {
+ sendApproval: vi.fn(),
+ },
+}));
+
+describe('ApprovalModal', () => {
+ it('renders header', () => {
+ vi.mocked(api.sendApproval).mockResolvedValue({});
+ const event = {
+ data: {
+ tool_calls: [{ id: 'tc1', name: 'bash', arguments: { cmd: 'ls' } }],
+ },
+ };
+ render();
+ expect(screen.getByText('Approve Tool Calls')).toBeInTheDocument();
+ });
+
+ it('renders tool names', () => {
+ vi.mocked(api.sendApproval).mockResolvedValue({});
+ const event = {
+ data: {
+ tool_calls: [
+ { id: 'tc1', name: 'bash', arguments: { cmd: 'ls' } },
+ { id: 'tc2', name: 'write', arguments: { path: '/tmp/t', content: 'hi' } },
+ ],
+ },
+ };
+ render();
+ expect(screen.getByText('bash')).toBeInTheDocument();
+ expect(screen.getByText('write')).toBeInTheDocument();
+ });
+
+ it('renders approve and reject buttons', () => {
+ vi.mocked(api.sendApproval).mockResolvedValue({});
+ const event = {
+ data: {
+ tool_calls: [{ id: 'tc1', name: 'test', arguments: {} }],
+ },
+ };
+ render();
+ expect(screen.getByRole('button', { name: 'Approve' })).toBeInTheDocument();
+ expect(screen.getByRole('button', { name: 'Reject' })).toBeInTheDocument();
+ });
+
+ it('displays tool args', () => {
+ vi.mocked(api.sendApproval).mockResolvedValue({});
+ const event = {
+ data: {
+ tool_calls: [{ id: 'tc1', name: 'run', arguments: { query: 'test', count: 5 } }],
+ },
+ };
+ render();
+ expect(screen.getByText(/query/)).toBeInTheDocument();
+ });
+
+ it('handles string arguments', () => {
+ vi.mocked(api.sendApproval).mockResolvedValue({});
+ const event = {
+ data: {
+ tool_calls: [{ id: 'tc1', name: 'echo', arguments: 'hello world' }],
+ },
+ };
+ render();
+ expect(screen.getByText('hello world')).toBeInTheDocument();
+ });
+
+ it('handles tools field as fallback', () => {
+ vi.mocked(api.sendApproval).mockResolvedValue({});
+ const event = {
+ data: {
+ tools: [{ id: 'tc1', name: 'fallback_tool', arguments: {} }],
+ },
+ };
+ render();
+ expect(screen.getByText('fallback_tool')).toBeInTheDocument();
+ });
+
+ it('calls sendApproval with approved', async () => {
+ const mockSend = vi.mocked(api.sendApproval).mockResolvedValue({});
+ const onClose = vi.fn();
+ const event = {
+ data: {
+ tool_calls: [{ id: 'tc1', name: 'bash', arguments: {} }],
+ },
+ };
+ render();
+
+ fireEvent.click(screen.getByRole('button', { name: 'Approve' }));
+ expect(mockSend).toHaveBeenCalledWith({ tc1: true });
+ });
+
+ it('calls sendApproval with rejected', () => {
+ const mockSend = vi.mocked(api.sendApproval).mockResolvedValue({});
+ const onClose = vi.fn();
+ const event = {
+ data: {
+ tool_calls: [{ id: 'tc1', name: 'bash', arguments: {} }],
+ },
+ };
+ render();
+
+ fireEvent.click(screen.getByRole('button', { name: 'Reject' }));
+ expect(mockSend).toHaveBeenCalledWith({ tc1: false });
+ });
+
+ it('approves multiple tool calls', () => {
+ const mockSend = vi.mocked(api.sendApproval).mockResolvedValue({});
+ const event = {
+ data: {
+ tool_calls: [
+ { id: 'a', name: 'one', arguments: {} },
+ { id: 'b', name: 'two', arguments: {} },
+ { id: 'c', name: 'three', arguments: {} },
+ ],
+ },
+ };
+ render();
+ fireEvent.click(screen.getByText('Approve'));
+ expect(mockSend).toHaveBeenCalledWith({ a: true, b: true, c: true });
+ });
+});
diff --git a/frontend/src/__tests__/AuthGuard.test.tsx b/frontend/src/__tests__/AuthGuard.test.tsx
new file mode 100644
index 0000000..babc938
--- /dev/null
+++ b/frontend/src/__tests__/AuthGuard.test.tsx
@@ -0,0 +1,106 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { render, screen, waitFor } from '@testing-library/react';
+import { MemoryRouter } from 'react-router-dom';
+import { AuthGuard } from '../components/AuthGuard';
+import { api, setToken } from '../api';
+
+vi.mock('../api', () => ({
+ api: {
+ getMe: vi.fn(),
+ },
+ setToken: vi.fn(),
+}));
+
+const getItemMock = vi.fn();
+const setItemMock = vi.fn();
+const removeItemMock = vi.fn();
+
+Object.defineProperty(window, 'localStorage', {
+ value: {
+ getItem: getItemMock,
+ setItem: setItemMock,
+ removeItem: removeItemMock,
+ },
+ writable: true,
+});
+
+describe('AuthGuard', () => {
+ beforeEach(() => {
+ vi.resetAllMocks();
+ getItemMock.mockReturnValue(null);
+ });
+
+ it('shows loading when checking token', () => {
+ getItemMock.mockReturnValue('existing-token');
+ vi.mocked(api.getMe).mockImplementation(() => new Promise(() => {}));
+
+ render(
+
+
+
+ );
+
+ expect(screen.getByText('Loading...')).toBeInTheDocument();
+ });
+
+ it('redirects to /login when no user and no token', async () => {
+ render(
+
+
+
+ );
+
+ await waitFor(() => {
+ // Navigate should redirect to /login
+ expect(screen.queryByText('Loading...')).not.toBeInTheDocument();
+ });
+ });
+
+ it('authenticates with valid token', async () => {
+ getItemMock.mockReturnValue('valid-token');
+ const user = { id: 1, username: 'test', display_name: 'Test' };
+ vi.mocked(api.getMe).mockResolvedValue(user);
+ const onAuth = vi.fn();
+
+ render(
+
+
+
+ );
+
+ await waitFor(() => {
+ expect(onAuth).toHaveBeenCalledWith(user);
+ expect(screen.queryByText('Loading...')).not.toBeInTheDocument();
+ });
+ });
+
+ it('clears token on invalid token', async () => {
+ getItemMock.mockReturnValue('invalid-token');
+ vi.mocked(api.getMe).mockRejectedValue(new Error('Invalid'));
+ const onAuth = vi.fn();
+
+ render(
+
+
+
+ );
+
+ await waitFor(() => {
+ expect(setToken).toHaveBeenCalledWith(null);
+ expect(onAuth).not.toHaveBeenCalled();
+ });
+ });
+
+ it('skips check when user already passed in', () => {
+ const user = { id: 1, username: 'test', display_name: 'Test' };
+ vi.mocked(api.getMe).mockResolvedValue(user);
+
+ render(
+
+
+
+ );
+
+ expect(api.getMe).not.toHaveBeenCalled();
+ });
+});
diff --git a/frontend/src/__tests__/ConfirmDialog.test.tsx b/frontend/src/__tests__/ConfirmDialog.test.tsx
new file mode 100644
index 0000000..e5fd9df
--- /dev/null
+++ b/frontend/src/__tests__/ConfirmDialog.test.tsx
@@ -0,0 +1,159 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen, fireEvent } from '@testing-library/react';
+import { ConfirmDialog } from '../components/ConfirmDialog';
+
+describe('ConfirmDialog', () => {
+ it('renders title and message', () => {
+ render(
+
+ );
+ expect(screen.getByText('Delete Item')).toBeInTheDocument();
+ expect(screen.getByText('Are you sure you want to delete this item?')).toBeInTheDocument();
+ });
+
+ it('renders default button labels', () => {
+ render(
+
+ );
+ expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument();
+ expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument();
+ });
+
+ it('renders custom button labels', () => {
+ render(
+
+ );
+ expect(screen.getByRole('button', { name: 'Yes, Delete' })).toBeInTheDocument();
+ expect(screen.getByRole('button', { name: 'No, Keep' })).toBeInTheDocument();
+ });
+
+ it('calls onConfirm when confirm button clicked', () => {
+ const onConfirm = vi.fn();
+ render(
+
+ );
+ fireEvent.click(screen.getByRole('button', { name: 'Confirm' }));
+ expect(onConfirm).toHaveBeenCalledTimes(1);
+ });
+
+ it('calls onCancel when cancel button clicked', () => {
+ const onCancel = vi.fn();
+ render(
+
+ );
+ fireEvent.click(screen.getByRole('button', { name: 'Cancel' }));
+ expect(onCancel).toHaveBeenCalledTimes(1);
+ });
+
+ it('calls onCancel when overlay clicked', () => {
+ const onCancel = vi.fn();
+ render(
+
+ );
+ const overlay = document.querySelector('.modal-overlay');
+ fireEvent.click(overlay!);
+ expect(onCancel).toHaveBeenCalled();
+ });
+
+ it('calls onCancel on Escape key', () => {
+ const onCancel = vi.fn();
+ render(
+
+ );
+ fireEvent.keyDown(window, { key: 'Escape' });
+ expect(onCancel).toHaveBeenCalled();
+ });
+
+ it('does not call onCancel for non-Escape keys', () => {
+ const onCancel = vi.fn();
+ render(
+
+ );
+ fireEvent.keyDown(window, { key: 'Enter' });
+ expect(onCancel).not.toHaveBeenCalled();
+ });
+
+ it('applies danger class to confirm button', () => {
+ render(
+
+ );
+ const confirmBtn = screen.getByRole('button', { name: 'Confirm' });
+ expect(confirmBtn.className).toContain('btn-danger');
+ });
+
+ it('uses btn-confirm class for non-danger', () => {
+ render(
+
+ );
+ const confirmBtn = screen.getByRole('button', { name: 'Confirm' });
+ expect(confirmBtn.className).toContain('btn-confirm');
+ });
+
+ it('cleans up Escape listener on unmount', () => {
+ const onCancel = vi.fn();
+ const { unmount } = render(
+
+ );
+ unmount();
+ fireEvent.keyDown(window, { key: 'Escape' });
+ expect(onCancel).not.toHaveBeenCalled();
+ });
+});
diff --git a/frontend/src/__tests__/LoginPage.test.tsx b/frontend/src/__tests__/LoginPage.test.tsx
new file mode 100644
index 0000000..587b99c
--- /dev/null
+++ b/frontend/src/__tests__/LoginPage.test.tsx
@@ -0,0 +1,212 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen, fireEvent, waitFor } from '@testing-library/react';
+import userEvent from '@testing-library/user-event';
+import { MemoryRouter } from 'react-router-dom';
+import { LoginPage } from '../components/LoginPage';
+import { api, setToken } from '../api';
+
+vi.mock('../api', () => ({
+ api: {
+ checkSetup: vi.fn(),
+ register: vi.fn(),
+ login: vi.fn(),
+ },
+ setToken: vi.fn(),
+}));
+
+const mockNavigate = vi.fn();
+vi.mock('react-router-dom', async (importOriginal) => {
+ const actual = await importOriginal() as any;
+ return {
+ ...actual,
+ useNavigate: () => mockNavigate,
+ };
+});
+
+function renderLoginPage(onAuth = vi.fn()) {
+ return render(
+
+
+
+ );
+}
+
+describe('LoginPage', () => {
+ beforeEach(() => {
+ vi.resetAllMocks();
+ vi.mocked(api.checkSetup).mockResolvedValue({ has_users: true });
+ });
+
+ it('renders sign in form', async () => {
+ renderLoginPage();
+ await waitFor(() => {
+ expect(screen.getByText('OpenMLR')).toBeInTheDocument();
+ });
+ expect(screen.getByPlaceholderText('Username')).toBeInTheDocument();
+ expect(screen.getByPlaceholderText('Password')).toBeInTheDocument();
+ });
+
+ it('shows login and register tabs when users exist', async () => {
+ renderLoginPage();
+ await waitFor(() => {
+ const tabs = screen.getAllByRole('button');
+ const tabTexts = tabs.map(t => t.textContent);
+ expect(tabTexts).toContain('Sign In');
+ expect(tabTexts).toContain('Register');
+ });
+ });
+
+ it('switches to register mode', async () => {
+ renderLoginPage();
+ await waitFor(() => {
+ expect(screen.getByText('Register')).toBeInTheDocument();
+ });
+ fireEvent.click(screen.getByText('Register'));
+ const createBtn = screen.getByText('Create Account');
+ expect(createBtn).toBeInTheDocument();
+ expect(screen.getByPlaceholderText('Display name (optional)')).toBeInTheDocument();
+ });
+
+ it('shows first-user notice when no users exist', async () => {
+ vi.mocked(api.checkSetup).mockResolvedValue({ has_users: false });
+ renderLoginPage();
+ await waitFor(() => {
+ expect(screen.getByText(/Create your account/)).toBeInTheDocument();
+ });
+ });
+
+ it('handles login submission', async () => {
+ const onAuth = vi.fn();
+ vi.mocked(api.login).mockResolvedValue({
+ access_token: 'test-token',
+ user: { id: 1, username: 'testuser', display_name: 'Test User' },
+ });
+
+ renderLoginPage(onAuth);
+
+ await waitFor(() => {
+ expect(screen.getByPlaceholderText('Username')).toBeInTheDocument();
+ });
+
+ const user = userEvent.setup();
+ await user.type(screen.getByPlaceholderText('Username'), 'testuser');
+ await user.type(screen.getByPlaceholderText('Password'), 'password123');
+
+ const submitButton = screen.getByText('Sign In', { selector: '.login-submit' });
+ await user.click(submitButton);
+
+ await waitFor(() => {
+ expect(api.login).toHaveBeenCalledWith('testuser', 'password123');
+ expect(setToken).toHaveBeenCalledWith('test-token');
+ expect(onAuth).toHaveBeenCalledWith({ id: 1, username: 'testuser', display_name: 'Test User' });
+ expect(mockNavigate).toHaveBeenCalledWith('/', { replace: true });
+ });
+ });
+
+ it('handles register submission', async () => {
+ const onAuth = vi.fn();
+ vi.mocked(api.register).mockResolvedValue({
+ access_token: 'reg-token',
+ user: { id: 2, username: 'newuser', display_name: 'New' },
+ });
+
+ renderLoginPage(onAuth);
+
+ await waitFor(() => {
+ fireEvent.click(screen.getByText('Register'));
+ });
+
+ const user = userEvent.setup();
+ await user.type(screen.getByPlaceholderText('Username'), 'newuser');
+ await user.type(screen.getByPlaceholderText('Display name (optional)'), 'New');
+ await user.type(screen.getByPlaceholderText('Password'), 'password123');
+ await user.click(screen.getByText('Create Account', { selector: '.login-submit' }));
+
+ await waitFor(() => {
+ expect(api.register).toHaveBeenCalledWith('newuser', 'password123', 'New');
+ expect(setToken).toHaveBeenCalledWith('reg-token');
+ });
+ });
+
+ it('displays error on auth failure', async () => {
+ vi.mocked(api.login).mockRejectedValue(new Error('Invalid credentials'));
+
+ renderLoginPage();
+
+ await waitFor(() => {
+ expect(screen.getByPlaceholderText('Username')).toBeInTheDocument();
+ });
+
+ const user = userEvent.setup();
+ await user.type(screen.getByPlaceholderText('Username'), 'wrong');
+ await user.type(screen.getByPlaceholderText('Password'), 'wrong');
+ await user.click(screen.getByText('Sign In', { selector: '.login-submit' }));
+
+ await waitFor(() => {
+ expect(screen.getByText('Invalid credentials')).toBeInTheDocument();
+ });
+ });
+
+ it('displays generic error for unknown failures', async () => {
+ vi.mocked(api.login).mockRejectedValue(new Error());
+
+ renderLoginPage();
+
+ await waitFor(() => {
+ expect(screen.getByPlaceholderText('Username')).toBeInTheDocument();
+ });
+
+ const user = userEvent.setup();
+ await user.type(screen.getByPlaceholderText('Username'), 'test');
+ await user.type(screen.getByPlaceholderText('Password'), 'pass');
+ await user.click(screen.getByText('Sign In', { selector: '.login-submit' }));
+
+ await waitFor(() => {
+ expect(screen.getByText('Authentication failed')).toBeInTheDocument();
+ });
+ });
+
+ it('clears error when switching modes', async () => {
+ vi.mocked(api.register).mockRejectedValue(new Error('Username taken'));
+
+ renderLoginPage();
+
+ await waitFor(() => {
+ fireEvent.click(screen.getByText('Register'));
+ });
+
+ // Fill required fields and submit register form
+ const user = userEvent.setup();
+ await user.type(screen.getByPlaceholderText('Username'), 'newuser');
+ await user.type(screen.getByPlaceholderText('Password'), 'password');
+ await user.click(screen.getByText('Create Account', { selector: '.login-submit' }));
+
+ await waitFor(() => {
+ expect(screen.getByText('Username taken')).toBeInTheDocument();
+ });
+
+ fireEvent.click(screen.getByRole('button', { name: 'Sign In' }));
+ await waitFor(() => {
+ expect(screen.queryByText('Username taken')).not.toBeInTheDocument();
+ });
+ });
+
+ it('disables submit button while loading', async () => {
+ vi.mocked(api.login).mockImplementation(() => new Promise(() => {}));
+
+ renderLoginPage();
+
+ await waitFor(() => {
+ expect(screen.getByPlaceholderText('Username')).toBeInTheDocument();
+ });
+
+ const user = userEvent.setup();
+ await user.type(screen.getByPlaceholderText('Username'), 'test');
+ await user.type(screen.getByPlaceholderText('Password'), 'pass');
+ await user.click(screen.getByText('Sign In', { selector: '.login-submit' }));
+
+ await waitFor(() => {
+ expect(screen.getByText('Please wait...')).toBeInTheDocument();
+ });
+ });
+});
diff --git a/frontend/src/__tests__/ModelModal.test.tsx b/frontend/src/__tests__/ModelModal.test.tsx
new file mode 100644
index 0000000..983f5b0
--- /dev/null
+++ b/frontend/src/__tests__/ModelModal.test.tsx
@@ -0,0 +1,176 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { render, screen, fireEvent, waitFor } from '@testing-library/react';
+import { ModelModal } from '../components/ModelModal';
+import { api } from '../api';
+
+vi.mock('../api', () => ({
+ api: {
+ getProviders: vi.fn(),
+ getModels: vi.fn(),
+ saveConfig: vi.fn(),
+ setModel: vi.fn(),
+ },
+}));
+
+const defaultProviders = [
+ { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true },
+ { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: false },
+];
+
+const defaultModels = [
+ { id: 'openai/gpt-4o', name: 'GPT-4o', provider: 'openai' },
+ { id: 'openai/gpt-4o-mini', name: 'GPT-4o Mini', provider: 'openai' },
+ { id: 'anthropic/claude-4', name: 'Claude 4', provider: 'anthropic' },
+];
+
+describe('ModelModal', () => {
+ beforeEach(() => {
+ vi.resetAllMocks();
+ vi.mocked(api.getProviders).mockResolvedValue({ providers: defaultProviders });
+ vi.mocked(api.getModels).mockResolvedValue({ models: defaultModels });
+ });
+
+ it('renders current model button', () => {
+ render();
+ expect(screen.getByText('openai/gpt-4o')).toBeInTheDocument();
+ });
+
+ it('opens modal on button click', () => {
+ render();
+ fireEvent.click(screen.getByText('openai/gpt-4o'));
+ expect(screen.getByText('Models')).toBeInTheDocument();
+ expect(screen.getByText('Providers')).toBeInTheDocument();
+ });
+
+ it('shows model list when opened', async () => {
+ render();
+ fireEvent.click(screen.getByText('openai/gpt-4o'));
+
+ await waitFor(() => {
+ expect(screen.getByText('GPT-4o')).toBeInTheDocument();
+ expect(screen.getByText('GPT-4o Mini')).toBeInTheDocument();
+ expect(screen.getByText('Claude 4')).toBeInTheDocument();
+ });
+ });
+
+ it('highlights current model', async () => {
+ render();
+ fireEvent.click(screen.getByText('openai/gpt-4o-mini'));
+
+ await waitFor(() => {
+ const options = document.querySelectorAll('.model-picker-option');
+ const mini = Array.from(options).find(o => o.textContent?.includes('GPT-4o Mini'));
+ expect(mini?.classList.contains('active')).toBe(true);
+ });
+ });
+
+ it('switches to providers tab', async () => {
+ render();
+ fireEvent.click(screen.getByText('openai/gpt-4o'));
+ fireEvent.click(screen.getByText('Providers'));
+
+ await waitFor(() => {
+ expect(screen.getByText('Configured')).toBeInTheDocument();
+ expect(screen.getByText('API key missing')).toBeInTheDocument();
+ });
+ });
+
+ it('filters models by provider', async () => {
+ render();
+ fireEvent.click(screen.getByText('openai/gpt-4o'));
+
+ await waitFor(() => {
+ expect(screen.getByText('Claude 4')).toBeInTheDocument();
+ });
+
+ const select = document.querySelector('select')!;
+ fireEvent.change(select, { target: { value: 'openai' } });
+
+ await waitFor(() => {
+ expect(screen.getByText('GPT-4o')).toBeInTheDocument();
+ expect(screen.queryByText('Claude 4')).not.toBeInTheDocument();
+ });
+ });
+
+ it('selects model and calls onModelChange', async () => {
+ vi.mocked(api.setModel).mockResolvedValue({ ok: true });
+ const onChange = vi.fn();
+
+ render();
+ fireEvent.click(screen.getByText('openai/gpt-4o'));
+
+ await waitFor(() => {
+ expect(screen.getByText('Claude 4')).toBeInTheDocument();
+ });
+
+ fireEvent.click(screen.getByText('Claude 4'));
+
+ await waitFor(() => {
+ expect(api.setModel).toHaveBeenCalledWith('anthropic/claude-4');
+ expect(onChange).toHaveBeenCalledWith('anthropic/claude-4');
+ });
+ });
+
+ it('closed modal on close button', async () => {
+ vi.mocked(api.setModel).mockResolvedValue({ ok: true });
+
+ render();
+ fireEvent.click(screen.getByText('openai/gpt-4o'));
+
+ await waitFor(() => {
+ expect(screen.getByText('GPT-4o')).toBeInTheDocument();
+ });
+
+ fireEvent.click(screen.getByText('Close'));
+
+ await waitFor(() => {
+ expect(screen.queryByText('GPT-4o Mini')).not.toBeInTheDocument();
+ });
+ });
+
+ it('shows custom model input', async () => {
+ render();
+ fireEvent.click(screen.getByText('openai/gpt-4o'));
+
+ await waitFor(() => {
+ expect(screen.getByPlaceholderText('Custom model ID')).toBeInTheDocument();
+ });
+ });
+
+ it('uses custom model on Enter', async () => {
+ vi.mocked(api.setModel).mockResolvedValue({ ok: true });
+ const onModelChange = vi.fn();
+
+ render();
+ fireEvent.click(screen.getByText('openai/gpt-4o'));
+
+ await waitFor(() => {
+ expect(screen.getByPlaceholderText('Custom model ID')).toBeInTheDocument();
+ });
+
+ const input = screen.getByPlaceholderText('Custom model ID');
+ fireEvent.change(input, { target: { value: 'my-custom-model' } });
+ fireEvent.keyDown(input, { key: 'Enter' });
+
+ await waitFor(() => {
+ expect(api.setModel).toHaveBeenCalledWith('my-custom-model');
+ expect(onModelChange).toHaveBeenCalledWith('my-custom-model');
+ });
+ });
+
+ it('closes modal when overlay clicked', async () => {
+ render();
+ fireEvent.click(screen.getByText('openai/gpt-4o'));
+
+ await waitFor(() => {
+ expect(screen.getByText('GPT-4o')).toBeInTheDocument();
+ });
+
+ const overlay = document.querySelector('.modal-overlay');
+ fireEvent.click(overlay!);
+
+ await waitFor(() => {
+ expect(screen.queryByText('GPT-4o Mini')).not.toBeInTheDocument();
+ });
+ });
+});
diff --git a/frontend/src/__tests__/OnboardingModal.test.tsx b/frontend/src/__tests__/OnboardingModal.test.tsx
new file mode 100644
index 0000000..8ac92bc
--- /dev/null
+++ b/frontend/src/__tests__/OnboardingModal.test.tsx
@@ -0,0 +1,185 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { render, screen, fireEvent, waitFor } from '@testing-library/react';
+import { OnboardingModal } from '../components/OnboardingModal';
+import { api } from '../api';
+
+vi.mock('../api', () => ({
+ api: {
+ getProviders: vi.fn(),
+ getModels: vi.fn(),
+ saveConfig: vi.fn(),
+ setModel: vi.fn(),
+ },
+}));
+
+const defaultProviders = [
+ { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: false },
+ { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: false },
+ { id: 'openrouter', name: 'OpenRouter', key_env: 'OPENROUTER_API_KEY', configured: false },
+ { id: 'opencode-go', name: 'OpenCode Go', key_env: 'OPENCODE_GO_API_KEY', configured: false },
+ { id: 'ollama', name: 'Ollama', key_env: 'OLLAMA_API_BASE', configured: false },
+ { id: 'brave', name: 'Brave Search', key_env: 'BRAVE_API_KEY', configured: false },
+];
+
+const defaultModels = [
+ { id: 'openai/gpt-4o', name: 'GPT-4o', provider: 'openai' },
+ { id: 'anthropic/claude-4', name: 'Claude 4', provider: 'anthropic' },
+];
+
+describe('OnboardingModal', () => {
+ beforeEach(() => {
+ vi.resetAllMocks();
+ vi.mocked(api.getProviders).mockResolvedValue({ providers: defaultProviders });
+ vi.mocked(api.getModels).mockResolvedValue({ models: defaultModels });
+ });
+
+ it('renders welcome heading', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Welcome to OpenMLR')).toBeInTheDocument();
+ });
+ });
+
+ it('shows loading state initially', () => {
+ vi.mocked(api.getProviders).mockImplementation(() => new Promise(() => {}));
+ render();
+ expect(screen.getByText('Loading...')).toBeInTheDocument();
+ });
+
+ it('shows providers step by default', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText(/Configure at least one LLM provider/)).toBeInTheDocument();
+ });
+ });
+
+ it('renders LLM providers', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('OpenAI')).toBeInTheDocument();
+ expect(screen.getByText('Anthropic')).toBeInTheDocument();
+ });
+ });
+
+ it('shows key inputs for unconfigured providers', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByPlaceholderText(/OPENAI_API_KEY/)).toBeInTheDocument();
+ });
+ });
+
+ it('shows "configured" for configured providers', async () => {
+ vi.mocked(api.getProviders).mockResolvedValue({
+ providers: [
+ { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true },
+ { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: false },
+ { id: 'openrouter', name: 'OpenRouter', key_env: 'OPENROUTER_API_KEY', configured: false },
+ { id: 'opencode-go', name: 'OpenCode Go', key_env: 'OPENCODE_GO_API_KEY', configured: false },
+ { id: 'ollama', name: 'Ollama', key_env: 'OLLAMA_API_BASE', configured: false },
+ ],
+ });
+ render();
+ // When at least one provider is configured, step auto-skips to model selection
+ await waitFor(() => {
+ expect(screen.getByText(/Pick a model/)).toBeInTheDocument();
+ });
+ });
+
+ it('skips to model selection when providers are configured', async () => {
+ vi.mocked(api.getProviders).mockResolvedValue({
+ providers: [
+ { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true },
+ ],
+ });
+ render();
+ await waitFor(() => {
+ expect(screen.getByText(/Pick a model/)).toBeInTheDocument();
+ });
+ });
+
+ it('disables save button when no keys entered', async () => {
+ render();
+ await waitFor(() => {
+ const btn = screen.getByText('Save & Continue');
+ expect(btn).toBeDisabled();
+ });
+ });
+
+ it('saves keys and moves to model step', async () => {
+ vi.mocked(api.saveConfig).mockResolvedValue({ ok: true });
+ vi.mocked(api.getProviders)
+ .mockResolvedValueOnce({ providers: defaultProviders })
+ .mockResolvedValueOnce({
+ providers: [{ ...defaultProviders[0], configured: true }, ...defaultProviders.slice(1)],
+ });
+
+ render();
+ await waitFor(() => {
+ expect(screen.getByPlaceholderText(/OPENAI_API_KEY/)).toBeInTheDocument();
+ });
+
+ // Type a key
+ const input = screen.getByPlaceholderText(/OPENAI_API_KEY/);
+ fireEvent.change(input, { target: { value: 'sk-test-key' } });
+
+ fireEvent.click(screen.getByText('Save & Continue'));
+
+ await waitFor(() => {
+ expect(api.saveConfig).toHaveBeenCalled();
+ expect(screen.getByText(/Pick a model/)).toBeInTheDocument();
+ });
+ });
+
+ it('shows models in model step', async () => {
+ vi.mocked(api.getProviders).mockResolvedValue({
+ providers: [{ ...defaultProviders[0], configured: true }],
+ });
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('GPT-4o')).toBeInTheDocument();
+ expect(screen.getByText('Claude 4')).toBeInTheDocument();
+ });
+ });
+
+ it('filters models by provider', async () => {
+ vi.mocked(api.getProviders).mockResolvedValue({
+ providers: [
+ { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true },
+ { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: true },
+ ],
+ });
+ render();
+ await waitFor(() => {
+ expect(screen.getByText(/Pick a model/)).toBeInTheDocument();
+ });
+
+ // Select OpenAI from provider filter
+ const select = document.querySelector('select')!;
+ fireEvent.change(select, { target: { value: 'openai' } });
+
+ await waitFor(() => {
+ expect(screen.getByText('GPT-4o')).toBeInTheDocument();
+ expect(screen.queryByText('Claude 4')).not.toBeInTheDocument();
+ });
+ });
+
+ it('selects model and completes', async () => {
+ const onComplete = vi.fn();
+ vi.mocked(api.setModel).mockResolvedValue({ ok: true });
+ vi.mocked(api.getProviders).mockResolvedValue({
+ providers: [{ ...defaultProviders[0], configured: true }],
+ });
+
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('GPT-4o')).toBeInTheDocument();
+ });
+
+ fireEvent.click(screen.getByText('GPT-4o'));
+
+ await waitFor(() => {
+ expect(api.setModel).toHaveBeenCalledWith('openai/gpt-4o');
+ expect(onComplete).toHaveBeenCalledWith('openai/gpt-4o');
+ });
+ });
+});
diff --git a/frontend/src/__tests__/ProvidersSettings.test.tsx b/frontend/src/__tests__/ProvidersSettings.test.tsx
new file mode 100644
index 0000000..6d3b98a
--- /dev/null
+++ b/frontend/src/__tests__/ProvidersSettings.test.tsx
@@ -0,0 +1,86 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen, fireEvent, waitFor } from '@testing-library/react';
+import { ProvidersSettings } from '../components/settings/ProvidersSettings';
+import { api } from '../api';
+
+vi.mock('../api', () => ({
+ api: {
+ getProviders: vi.fn(),
+ updateSetting: vi.fn(),
+ },
+}));
+
+describe('ProvidersSettings', () => {
+ const mockProviders = [
+ { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true },
+ { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: false },
+ ];
+
+ beforeEach(() => {
+ vi.clearAllMocks();
+ vi.mocked(api.getProviders).mockResolvedValue({ providers: mockProviders });
+ });
+
+ it('renders hint', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText(/API keys are stored/)).toBeInTheDocument();
+ });
+ });
+
+ it('renders provider names', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('OpenAI')).toBeInTheDocument();
+ expect(screen.getByText('Anthropic')).toBeInTheDocument();
+ });
+ });
+
+ it('shows configured/not set status', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Configured')).toBeInTheDocument();
+ expect(screen.getByText('Not set')).toBeInTheDocument();
+ });
+ });
+
+ it('renders key input placeholders', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByPlaceholderText('OPENAI_API_KEY')).toBeInTheDocument();
+ expect(screen.getByPlaceholderText('ANTHROPIC_API_KEY')).toBeInTheDocument();
+ });
+ });
+
+ it('renders save button', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Save Keys')).toBeInTheDocument();
+ });
+ });
+
+ it('disables save when no keys entered', async () => {
+ render();
+ await waitFor(() => {
+ const btn = screen.getByText('Save Keys');
+ expect(btn).toBeDisabled();
+ });
+ });
+
+ it('enables save when key entered', async () => {
+ render();
+ await waitFor(() => {
+ const input = screen.getByPlaceholderText('ANTHROPIC_API_KEY');
+ fireEvent.change(input, { target: { value: 'sk-ant-test' } });
+ });
+ const btn = screen.getByText('Save Keys');
+ expect(btn).not.toBeDisabled();
+ });
+
+ it('loads providers on mount', async () => {
+ render();
+ await waitFor(() => {
+ expect(api.getProviders).toHaveBeenCalled();
+ });
+ });
+});
diff --git a/frontend/src/__tests__/QuestionDrawer.test.tsx b/frontend/src/__tests__/QuestionDrawer.test.tsx
new file mode 100644
index 0000000..c61e459
--- /dev/null
+++ b/frontend/src/__tests__/QuestionDrawer.test.tsx
@@ -0,0 +1,136 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen, fireEvent, waitFor } from '@testing-library/react';
+import { QuestionDrawer } from '../components/QuestionDrawer';
+import { api } from '../api';
+import type { QuestionsPayload } from '../types';
+
+vi.mock('../api', () => ({
+ api: {
+ submitAnswers: vi.fn(),
+ },
+}));
+
+const defaultPayload: QuestionsPayload = {
+ questions: [
+ {
+ id: 'q1',
+ question: 'What is your preference?',
+ options: [
+ { label: 'Option A', description: 'First option' },
+ { label: 'Option B' },
+ { label: 'Option C', description: 'Third option' },
+ ],
+ },
+ {
+ id: 'q2',
+ question: 'Choose a framework.',
+ options: [
+ { label: 'PyTorch' },
+ { label: 'JAX', description: 'Functional approach' },
+ ],
+ },
+ ],
+ context: 'Project setup questions',
+};
+
+describe('QuestionDrawer', () => {
+ it('renders context title', () => {
+ vi.mocked(api.submitAnswers).mockResolvedValue({});
+ render(
+
+ );
+ expect(screen.getByText('Project setup questions')).toBeInTheDocument();
+ });
+
+ it('renders question tabs', () => {
+ vi.mocked(api.submitAnswers).mockResolvedValue({});
+ render(
+
+ );
+ expect(screen.getByText('1')).toBeInTheDocument();
+ expect(screen.getByText('2')).toBeInTheDocument();
+ });
+
+ it('renders current question text', () => {
+ vi.mocked(api.submitAnswers).mockResolvedValue({});
+ render(
+
+ );
+ expect(screen.getByText('What is your preference?')).toBeInTheDocument();
+ });
+
+ it('renders options with descriptions', () => {
+ vi.mocked(api.submitAnswers).mockResolvedValue({});
+ render(
+
+ );
+ expect(screen.getByText('Option A')).toBeInTheDocument();
+ expect(screen.getByText('First option')).toBeInTheDocument();
+ expect(screen.getByText('Option C')).toBeInTheDocument();
+ expect(screen.getByText('Third option')).toBeInTheDocument();
+ });
+
+ it('shows text input by default', () => {
+ vi.mocked(api.submitAnswers).mockResolvedValue({});
+ render(
+
+ );
+ expect(screen.getByPlaceholderText('Or type your own answer...')).toBeInTheDocument();
+ });
+
+ it('navigates to next question', () => {
+ vi.mocked(api.submitAnswers).mockResolvedValue({});
+ render(
+
+ );
+ fireEvent.click(screen.getByText('Next'));
+ expect(screen.getByText('Choose a framework.')).toBeInTheDocument();
+ });
+
+ it('shows progress', () => {
+ vi.mocked(api.submitAnswers).mockResolvedValue({});
+ render(
+
+ );
+ expect(screen.getByText('0 / 2')).toBeInTheDocument();
+ });
+
+ it('selects option and advances', async () => {
+ vi.mocked(api.submitAnswers).mockResolvedValue({});
+ render(
+
+ );
+ fireEvent.click(screen.getByText('Option A'));
+ await waitFor(() => {
+ expect(screen.getByText('Choose a framework.')).toBeInTheDocument();
+ });
+ });
+
+ it('calls submit with answers', async () => {
+ const mockSubmit = vi.mocked(api.submitAnswers).mockResolvedValue({});
+ const onDone = vi.fn();
+ render(
+
+ );
+ // Answer first question
+ fireEvent.click(screen.getByText('Option A'));
+ await waitFor(() => {
+ expect(screen.getByText('Choose a framework.')).toBeInTheDocument();
+ });
+ // Answer second question
+ fireEvent.click(screen.getByText('PyTorch'));
+
+ fireEvent.click(screen.getByText('Submit'));
+ expect(mockSubmit).toHaveBeenCalledWith({ q1: 'Option A', q2: 'PyTorch' });
+ });
+
+ it('calls onClose when X clicked', () => {
+ vi.mocked(api.submitAnswers).mockResolvedValue({});
+ const onClose = vi.fn();
+ render(
+
+ );
+ fireEvent.click(screen.getByText('×'));
+ expect(onClose).toHaveBeenCalled();
+ });
+});
diff --git a/frontend/src/__tests__/ReportDrawer.test.tsx b/frontend/src/__tests__/ReportDrawer.test.tsx
new file mode 100644
index 0000000..c09bcff
--- /dev/null
+++ b/frontend/src/__tests__/ReportDrawer.test.tsx
@@ -0,0 +1,82 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen } from '@testing-library/react';
+import ReactMarkdown from 'react-markdown';
+import { ReportDrawer } from '../components/ReportDrawer';
+
+vi.mock('../api', () => ({
+ api: {
+ getReport: vi.fn(),
+ },
+}));
+
+describe('ReportDrawer', () => {
+ it('renders title and close button', () => {
+ render(
+
+ );
+ expect(screen.getByText('My Test Report')).toBeInTheDocument();
+ expect(screen.getByText('×')).toBeInTheDocument();
+ });
+
+ it('renders cached content without loading', () => {
+ render(
+
+ );
+ expect(screen.queryByText('Loading...')).not.toBeInTheDocument();
+ });
+
+ it('shows loading state without cached content', async () => {
+ const { api } = await import('../api');
+ vi.mocked(api.getReport).mockImplementation(() => new Promise(() => {})); // never resolves
+ render(
+
+ );
+ expect(screen.getByText('Loading...')).toBeInTheDocument();
+ });
+
+ it('calls onClose when close button clicked', () => {
+ const onClose = vi.fn();
+ render(
+
+ );
+ screen.getByText('×').click();
+ expect(onClose).toHaveBeenCalled();
+ });
+
+ it('calls onClose when overlay clicked', () => {
+ const onClose = vi.fn();
+ render(
+
+ );
+ const overlay = document.querySelector('.report-overlay');
+ if (overlay) {
+ (overlay as HTMLElement).click();
+ expect(onClose).toHaveBeenCalled();
+ }
+ });
+});
diff --git a/frontend/src/__tests__/RightPanel.test.tsx b/frontend/src/__tests__/RightPanel.test.tsx
new file mode 100644
index 0000000..4dc575d
--- /dev/null
+++ b/frontend/src/__tests__/RightPanel.test.tsx
@@ -0,0 +1,185 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen } from '@testing-library/react';
+import { RightPanel } from '../components/RightPanel';
+import type { PlanTask, Resource, ContextUsage, SearchBudget } from '../types';
+
+vi.mock('../api', () => ({
+ api: {
+ getReport: vi.fn(),
+ },
+}));
+
+describe('RightPanel', () => {
+ const mockTasks: PlanTask[] = [
+ { title: 'Read papers', status: 'completed' },
+ { title: 'Implement model', status: 'in_progress' },
+ { title: 'Write report', status: 'pending' },
+ ];
+
+ const mockResources: Resource[] = [
+ { title: 'Plan', type: 'plan', id: 'plan-1' },
+ { title: 'ArXiv Paper', type: 'paper', id: 'paper-1' },
+ { title: 'Dataset X', type: 'dataset', url: 'https://example.com' },
+ ];
+
+ const mockContext: ContextUsage = { used: 50000, max: 200000, ratio: 0.25 };
+ const mockSearchBudget: SearchBudget = { used: 5, max: 25 };
+
+ it('renders toggle button when not visible', () => {
+ render(
+
+ );
+ expect(screen.getByTitle('Tasks & resources')).toBeInTheDocument();
+ });
+
+ it('renders tasks when visible', () => {
+ render(
+
+ );
+ expect(screen.getByText('Read papers')).toBeInTheDocument();
+ expect(screen.getByText('Implement model')).toBeInTheDocument();
+ expect(screen.getByText('Write report')).toBeInTheDocument();
+ });
+
+ it('shows task count', () => {
+ render(
+
+ );
+ expect(screen.getByText('Tasks (1/3)')).toBeInTheDocument();
+ });
+
+ it('shows "No tasks yet" when empty', () => {
+ render(
+
+ );
+ expect(screen.getByText('No tasks yet')).toBeInTheDocument();
+ expect(screen.getByText('No resources yet')).toBeInTheDocument();
+ });
+
+ it('renders context gauge', () => {
+ render(
+
+ );
+ expect(screen.getByText(/50k/)).toBeInTheDocument();
+ expect(screen.getByText(/200k/)).toBeInTheDocument();
+ });
+
+ it('renders search budget gauge', () => {
+ render(
+
+ );
+ expect(screen.getByText(/Searches:/)).toBeInTheDocument();
+ });
+
+ it('renders resources', () => {
+ render(
+
+ );
+ expect(screen.getByText('Dataset X')).toBeInTheDocument();
+ });
+
+ it('renders paper resource with export buttons', () => {
+ render(
+
+ );
+ expect(screen.getByText('.md')).toBeInTheDocument();
+ expect(screen.getByText('.tex')).toBeInTheDocument();
+ });
+
+ it('hides toggle badge when no tasks and visible', () => {
+ render(
+
+ );
+ expect(screen.queryByTitle('Tasks & resources')).not.toBeInTheDocument();
+ });
+
+ it('shows toggle badge with task count when collapsed with tasks', () => {
+ render(
+
+ );
+ const badge = document.querySelector('.toggle-badge');
+ expect(badge?.textContent).toBe('3');
+ });
+});
diff --git a/frontend/src/__tests__/SandboxSettings.test.tsx b/frontend/src/__tests__/SandboxSettings.test.tsx
new file mode 100644
index 0000000..3916af2
--- /dev/null
+++ b/frontend/src/__tests__/SandboxSettings.test.tsx
@@ -0,0 +1,60 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen, fireEvent, waitFor } from '@testing-library/react';
+import { SandboxSettings } from '../components/settings/SandboxSettings';
+import { api } from '../api';
+
+vi.mock('../api', () => ({
+ api: {
+ getSettings: vi.fn(),
+ updateSetting: vi.fn(),
+ },
+}));
+
+describe('SandboxSettings', () => {
+ beforeEach(() => {
+ vi.clearAllMocks();
+ vi.mocked(api.getSettings).mockResolvedValue({ settings: {} });
+ });
+
+ it('renders heading and hint', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText(/Execution environment/)).toBeInTheDocument();
+ });
+ });
+
+ it('renders default sandbox select', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Default Sandbox')).toBeInTheDocument();
+ expect(screen.getByText('Local')).toBeInTheDocument();
+ expect(screen.getByText('SSH Remote')).toBeInTheDocument();
+ expect(screen.getByText('Modal Cloud')).toBeInTheDocument();
+ });
+ });
+
+ it('renders modal token fields', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByPlaceholderText('MODAL_TOKEN_ID')).toBeInTheDocument();
+ expect(screen.getByPlaceholderText('MODAL_TOKEN_SECRET')).toBeInTheDocument();
+ });
+ });
+
+ it('renders save button', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Save Sandbox Settings')).toBeInTheDocument();
+ });
+ });
+
+ it('loads settings on mount', async () => {
+ vi.mocked(api.getSettings).mockResolvedValue({
+ settings: { sandbox: { default_sandbox: 'ssh' } },
+ });
+ render();
+ await waitFor(() => {
+ expect(api.getSettings).toHaveBeenCalled();
+ });
+ });
+});
diff --git a/frontend/src/__tests__/SettingsPage.test.tsx b/frontend/src/__tests__/SettingsPage.test.tsx
new file mode 100644
index 0000000..e736773
--- /dev/null
+++ b/frontend/src/__tests__/SettingsPage.test.tsx
@@ -0,0 +1,44 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { render, screen, fireEvent, waitFor } from '@testing-library/react';
+import { MemoryRouter } from 'react-router-dom';
+import { SettingsPage } from '../components/SettingsPage';
+
+describe('SettingsPage', () => {
+ function renderSettings(path = '/settings/agent') {
+ return render(
+
+
+
+ );
+ }
+
+ it('renders back link', () => {
+ renderSettings();
+ expect(screen.getByText(/Back to chat/)).toBeInTheDocument();
+ });
+
+ it('renders Settings title', () => {
+ renderSettings();
+ expect(screen.getByText('Settings')).toBeInTheDocument();
+ });
+
+ it('renders all nav links', () => {
+ renderSettings();
+ expect(screen.getByText('Providers')).toBeInTheDocument();
+ expect(screen.getByText('Agent')).toBeInTheDocument();
+ expect(screen.getByText('Sandbox')).toBeInTheDocument();
+ expect(screen.getByText('Writing')).toBeInTheDocument();
+ });
+
+ it('highlights active nav link', () => {
+ renderSettings('/settings/agent');
+ const agentLink = screen.getByText('Agent');
+ expect(agentLink.className).toContain('active');
+ });
+
+ it('does not highlight inactive links', () => {
+ renderSettings('/settings/agent');
+ const providers = screen.getByText('Providers');
+ expect(providers.className).not.toContain('active');
+ });
+});
diff --git a/frontend/src/__tests__/SettingsPanel.test.tsx b/frontend/src/__tests__/SettingsPanel.test.tsx
new file mode 100644
index 0000000..633e7a6
--- /dev/null
+++ b/frontend/src/__tests__/SettingsPanel.test.tsx
@@ -0,0 +1,104 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen, fireEvent, waitFor } from '@testing-library/react';
+import { SettingsPanel } from '../components/SettingsPanel';
+import { api } from '../api';
+
+vi.mock('../api', () => ({
+ api: {
+ getProviders: vi.fn(),
+ getSettings: vi.fn(),
+ updateSetting: vi.fn(),
+ },
+}));
+
+describe('SettingsPanel', () => {
+ beforeEach(() => {
+ vi.resetAllMocks();
+ vi.mocked(api.getProviders).mockResolvedValue({
+ providers: [
+ { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true },
+ { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: false },
+ ],
+ });
+ vi.mocked(api.getSettings).mockResolvedValue({ settings: {} });
+ });
+
+ it('renders settings heading', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Settings')).toBeInTheDocument();
+ });
+ });
+
+ it('renders tab buttons', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Providers')).toBeInTheDocument();
+ expect(screen.getByText('Agent')).toBeInTheDocument();
+ expect(screen.getByText('Sandbox')).toBeInTheDocument();
+ expect(screen.getByText('Writing')).toBeInTheDocument();
+ });
+ });
+
+ it('shows providers by default', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('OpenAI')).toBeInTheDocument();
+ expect(screen.getByText('Anthropic')).toBeInTheDocument();
+ });
+ });
+
+ it('shows configured status', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Configured')).toBeInTheDocument();
+ expect(screen.getByText('Not set')).toBeInTheDocument();
+ });
+ });
+
+ it('switches to agent tab', async () => {
+ render();
+ await waitFor(() => {
+ fireEvent.click(screen.getByText('Agent'));
+ });
+ expect(screen.getByText('Default Model')).toBeInTheDocument();
+ expect(screen.getByText(/YOLO Mode/)).toBeInTheDocument();
+ });
+
+ it('switches to sandbox tab', async () => {
+ render();
+ await waitFor(() => {
+ fireEvent.click(screen.getByText('Sandbox'));
+ });
+ expect(screen.getByText('Default Sandbox')).toBeInTheDocument();
+ });
+
+ it('switches to writing tab', async () => {
+ render();
+ await waitFor(() => {
+ fireEvent.click(screen.getByText('Writing'));
+ });
+ expect(screen.getByText('Citation Style')).toBeInTheDocument();
+ expect(screen.getByText('Export Format')).toBeInTheDocument();
+ });
+
+ it('calls onClose when close clicked', async () => {
+ const onClose = vi.fn();
+ render();
+ await waitFor(() => {
+ const closeBtn = screen.getByText('×');
+ fireEvent.click(closeBtn);
+ });
+ expect(onClose).toHaveBeenCalled();
+ });
+
+ it('calls onClose when overlay clicked', async () => {
+ const onClose = vi.fn();
+ render();
+ await waitFor(() => {
+ const overlay = document.querySelector('.modal-overlay');
+ fireEvent.click(overlay!);
+ });
+ expect(onClose).toHaveBeenCalled();
+ });
+});
diff --git a/frontend/src/__tests__/Sidebar.test.tsx b/frontend/src/__tests__/Sidebar.test.tsx
new file mode 100644
index 0000000..1553492
--- /dev/null
+++ b/frontend/src/__tests__/Sidebar.test.tsx
@@ -0,0 +1,237 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen, fireEvent } from '@testing-library/react';
+import { MemoryRouter } from 'react-router-dom';
+import { Sidebar } from '../components/Sidebar';
+import type { Conversation, User } from '../types';
+
+vi.mock('../api', () => ({
+ setToken: vi.fn(),
+}));
+
+const mockNavigate = vi.fn();
+vi.mock('react-router-dom', async (importOriginal) => {
+ const actual = await importOriginal() as any;
+ return {
+ ...actual,
+ useNavigate: () => mockNavigate,
+ };
+});
+
+const mockUser: User = { id: 1, username: 'tester', display_name: 'Test User' };
+
+const mockConversations: Conversation[] = [
+ {
+ id: 1,
+ uuid: 'conv-1',
+ title: 'First conversation',
+ model: 'gpt-4o',
+ mode: 'general',
+ user_message_count: 3,
+ created_at: new Date().toISOString(),
+ updated_at: new Date().toISOString(),
+ },
+ {
+ id: 2,
+ uuid: 'conv-2',
+ title: 'Research project',
+ model: 'claude-4',
+ mode: 'research',
+ user_message_count: 10,
+ created_at: new Date(Date.now() - 86400000 * 3).toISOString(),
+ updated_at: new Date(Date.now() - 86400000 * 3).toISOString(),
+ },
+];
+
+describe('Sidebar', () => {
+ it('renders new chat button', () => {
+ render(
+
+
+
+ );
+ expect(screen.getByText('+ New Chat')).toBeInTheDocument();
+ });
+
+ it('renders conversation titles', () => {
+ render(
+
+
+
+ );
+ expect(screen.getByText('First conversation')).toBeInTheDocument();
+ expect(screen.getByText('Research project')).toBeInTheDocument();
+ });
+
+ it('highlights current conversation', () => {
+ render(
+
+
+
+ );
+ const items = document.querySelectorAll('.conversation-item');
+ expect(items[0].classList.contains('active')).toBe(true);
+ });
+
+ it('calls onSwitch when conversation clicked', () => {
+ const onSwitch = vi.fn();
+ render(
+
+
+
+ );
+ fireEvent.click(screen.getByText('First conversation'));
+ expect(onSwitch).toHaveBeenCalledWith('conv-1');
+ });
+
+ it('shows empty state when no conversations', () => {
+ render(
+
+
+
+ );
+ expect(screen.getByText('No conversations yet')).toBeInTheDocument();
+ });
+
+ it('shows user display name', () => {
+ render(
+
+
+
+ );
+ expect(screen.getByText('Test User')).toBeInTheDocument();
+ });
+
+ it('renders action buttons', () => {
+ render(
+
+
+
+ );
+ expect(screen.getByText(/Undo/)).toBeInTheDocument();
+ expect(screen.getByText(/Compact/)).toBeInTheDocument();
+ });
+
+ it('calls onAction for undo', () => {
+ const onAction = vi.fn();
+ render(
+
+
+
+ );
+ fireEvent.click(screen.getByText(/Undo/));
+ expect(onAction).toHaveBeenCalledWith('undo');
+ });
+
+ it('calls onAction for compact', () => {
+ const onAction = vi.fn();
+ render(
+
+
+
+ );
+ fireEvent.click(screen.getByText(/Compact/));
+ expect(onAction).toHaveBeenCalledWith('compact');
+ });
+
+ it('filters conversations by search', () => {
+ render(
+
+
+
+ );
+ const searchInput = screen.getByPlaceholderText('Search...');
+ fireEvent.change(searchInput, { target: { value: 'Research' } });
+ expect(screen.getByText('Research project')).toBeInTheDocument();
+ expect(screen.queryByText('First conversation')).not.toBeInTheDocument();
+ });
+});
diff --git a/frontend/src/__tests__/WritingSettings.test.tsx b/frontend/src/__tests__/WritingSettings.test.tsx
new file mode 100644
index 0000000..c83a2c6
--- /dev/null
+++ b/frontend/src/__tests__/WritingSettings.test.tsx
@@ -0,0 +1,60 @@
+import { describe, it, expect, vi } from 'vitest';
+import { render, screen, waitFor } from '@testing-library/react';
+import { WritingSettings } from '../components/settings/WritingSettings';
+import { api } from '../api';
+
+vi.mock('../api', () => ({
+ api: {
+ getSettings: vi.fn(),
+ updateSetting: vi.fn(),
+ },
+}));
+
+describe('WritingSettings', () => {
+ beforeEach(() => {
+ vi.clearAllMocks();
+ vi.mocked(api.getSettings).mockResolvedValue({ settings: {} });
+ });
+
+ it('renders hint', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Paper writing preferences.')).toBeInTheDocument();
+ });
+ });
+
+ it('renders citation style select', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Citation Style')).toBeInTheDocument();
+ expect(screen.getByText('APA')).toBeInTheDocument();
+ expect(screen.getByText('IEEE')).toBeInTheDocument();
+ });
+ });
+
+ it('renders export format select', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Export Format')).toBeInTheDocument();
+ expect(screen.getByText('Markdown')).toBeInTheDocument();
+ expect(screen.getByText('LaTeX')).toBeInTheDocument();
+ });
+ });
+
+ it('renders save button', async () => {
+ render();
+ await waitFor(() => {
+ expect(screen.getByText('Save Writing Settings')).toBeInTheDocument();
+ });
+ });
+
+ it('loads settings on mount', async () => {
+ vi.mocked(api.getSettings).mockResolvedValue({
+ settings: { writing: { citation_style: 'ieee' } },
+ });
+ render();
+ await waitFor(() => {
+ expect(api.getSettings).toHaveBeenCalled();
+ });
+ });
+});
diff --git a/frontend/src/__tests__/useJobStatus.test.ts b/frontend/src/__tests__/useJobStatus.test.ts
new file mode 100644
index 0000000..c648928
--- /dev/null
+++ b/frontend/src/__tests__/useJobStatus.test.ts
@@ -0,0 +1,213 @@
+import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
+import { renderHook, waitFor, act } from '@testing-library/react';
+import { useJobStatus } from '../hooks/useJobStatus';
+import { api } from '../api';
+
+vi.mock('../api', () => ({
+ api: {
+ getConversationJobs: vi.fn(),
+ },
+}));
+
+describe('useJobStatus', () => {
+ beforeEach(() => {
+ vi.resetAllMocks();
+ vi.useFakeTimers({ shouldAdvanceTime: true });
+ });
+
+ afterEach(() => {
+ vi.useRealTimers();
+ });
+
+ async function flushPromises() {
+ await act(async () => {
+ await Promise.resolve();
+ });
+ }
+
+ it('returns empty jobs when disabled', async () => {
+ const { result } = renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', enabled: false })
+ );
+ expect(result.current.activeJobs).toEqual([]);
+ expect(result.current.isProcessing).toBe(false);
+ });
+
+ it('returns empty jobs when conversationUuid is null', async () => {
+ const { result } = renderHook(() =>
+ useJobStatus({ conversationUuid: null })
+ );
+ expect(result.current.activeJobs).toEqual([]);
+ });
+
+ it('fetches jobs on mount', async () => {
+ const mockJobs = [
+ { job_id: 'job-1', status: 'queued', created_at: '2024-01-01T00:00:00Z' },
+ { job_id: 'job-2', status: 'running', created_at: '2024-01-01T00:01:00Z' },
+ ];
+ vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: mockJobs });
+
+ renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 })
+ );
+
+ await flushPromises();
+
+ expect(api.getConversationJobs).toHaveBeenCalledTimes(1);
+ });
+
+ it('detects processing state for queued jobs', async () => {
+ const mockJobs = [
+ { job_id: 'job-1', status: 'queued', created_at: '2024-01-01T00:00:00Z' },
+ ];
+ vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: mockJobs });
+
+ const { result } = renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 })
+ );
+
+ await flushPromises();
+
+ expect(result.current.isProcessing).toBe(true);
+ });
+
+ it('detects processing state for running jobs', async () => {
+ const mockJobs = [
+ { job_id: 'job-1', status: 'running', created_at: '2024-01-01T00:00:00Z' },
+ ];
+ vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: mockJobs });
+
+ const { result } = renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 })
+ );
+
+ await flushPromises();
+
+ expect(result.current.isProcessing).toBe(true);
+ });
+
+ it('not processing when all jobs completed', async () => {
+ const mockJobs = [
+ { job_id: 'job-1', status: 'completed', created_at: '2024-01-01T00:00:00Z' },
+ ];
+ vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: mockJobs });
+
+ const { result } = renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 })
+ );
+
+ await flushPromises();
+
+ expect(result.current.isProcessing).toBe(false);
+ });
+
+ it('handles empty jobs array', async () => {
+ vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: [] });
+
+ const { result } = renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 })
+ );
+
+ await flushPromises();
+
+ expect(result.current.activeJobs).toEqual([]);
+ });
+
+ it('handles missing jobs key', async () => {
+ vi.mocked(api.getConversationJobs).mockResolvedValue({});
+
+ const { result } = renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 })
+ );
+
+ await flushPromises();
+
+ expect(result.current.activeJobs).toEqual([]);
+ });
+
+ it('calls onJobComplete when jobs finish', async () => {
+ const onComplete = vi.fn();
+ // First call: processing, Second call: completed
+ vi.mocked(api.getConversationJobs)
+ .mockResolvedValueOnce({ jobs: [{ job_id: 'j1', status: 'running' }] })
+ .mockResolvedValueOnce({ jobs: [{ job_id: 'j1', status: 'completed' }] });
+
+ const { result } = renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 5000, onJobComplete: onComplete })
+ );
+
+ await flushPromises();
+ expect(result.current.isProcessing).toBe(true);
+
+ // Advance time to trigger next poll
+ await act(async () => {
+ vi.advanceTimersByTime(5000);
+ });
+ await flushPromises();
+
+ expect(onComplete).toHaveBeenCalledWith('test-uuid');
+ });
+
+ it('tracks last job id', async () => {
+ vi.mocked(api.getConversationJobs).mockResolvedValue({
+ jobs: [{ job_id: 'job-last', status: 'running' }],
+ });
+
+ const { result } = renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 })
+ );
+
+ await flushPromises();
+
+ expect(result.current.lastJobId).toBe('job-last');
+ });
+
+ it('handles fetch errors gracefully', async () => {
+ vi.mocked(api.getConversationJobs).mockRejectedValue(new Error('Network error'));
+ const consoleDebug = vi.spyOn(console, 'debug').mockImplementation(() => {});
+
+ renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 })
+ );
+
+ await flushPromises();
+
+ expect(consoleDebug).toHaveBeenCalled();
+ consoleDebug.mockRestore();
+ });
+
+ it('polls at configured interval', async () => {
+ vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: [] });
+
+ renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 5000 })
+ );
+
+ await flushPromises();
+ expect(api.getConversationJobs).toHaveBeenCalledTimes(1);
+
+ await act(async () => {
+ vi.advanceTimersByTime(5000);
+ });
+ await flushPromises();
+
+ expect(api.getConversationJobs).toHaveBeenCalledTimes(2);
+ });
+
+ it('refresh triggers immediate fetch', async () => {
+ vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: [] });
+
+ const { result } = renderHook(() =>
+ useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 })
+ );
+
+ await flushPromises();
+ expect(api.getConversationJobs).toHaveBeenCalledTimes(1);
+
+ await act(async () => {
+ await result.current.refresh();
+ });
+
+ expect(api.getConversationJobs).toHaveBeenCalledTimes(2);
+ });
+});
diff --git a/frontend/src/__tests__/useSSE.test.ts b/frontend/src/__tests__/useSSE.test.ts
new file mode 100644
index 0000000..c6621dd
--- /dev/null
+++ b/frontend/src/__tests__/useSSE.test.ts
@@ -0,0 +1,134 @@
+import { describe, it, expect, vi } from 'vitest';
+import { renderHook, act } from '@testing-library/react';
+import { useSSE } from '../hooks/useSSE';
+import type { AgentEvent } from '../types';
+
+const instances: any[] = [];
+
+vi.stubGlobal('EventSource', vi.fn(function (this: any, url: string) {
+ this.url = url;
+ this.readyState = 0;
+ this.onopen = null;
+ this.onmessage = null;
+ this.onerror = null;
+ this.close = function () {
+ this.readyState = 2;
+ };
+ instances.push(this);
+ return this;
+}));
+
+(globalThis as any).EventSource.CONNECTING = 0;
+(globalThis as any).EventSource.OPEN = 1;
+(globalThis as any).EventSource.CLOSED = 2;
+
+function newest() {
+ return instances[instances.length - 1];
+}
+
+describe('useSSE', () => {
+ beforeEach(() => {
+ instances.length = 0;
+ });
+
+ it('creates EventSource when enabled is true', () => {
+ const onEvent = vi.fn();
+ renderHook(() => useSSE(onEvent, true));
+
+ expect((globalThis as any).EventSource).toHaveBeenCalled();
+ expect(instances.length).toBe(1);
+ expect(instances[0].url).toContain('/api/events');
+ });
+
+ it('does not create EventSource when enabled is false', () => {
+ const onEvent = vi.fn();
+ renderHook(() => useSSE(onEvent, false));
+
+ expect(instances.length).toBe(0);
+ });
+
+ it('appends token as query parameter', () => {
+ const onEvent = vi.fn();
+ renderHook(() => useSSE(onEvent, true, 'test-token'));
+
+ expect(instances.length).toBe(1);
+ expect(instances[0].url).toContain('token=test-token');
+ });
+
+ it('sets connected to true on open', () => {
+ const onEvent = vi.fn();
+ const { result } = renderHook(() => useSSE(onEvent, true));
+
+ act(() => {
+ newest().onopen(new Event('open'));
+ });
+
+ expect(result.current.connected).toBe(true);
+ });
+
+ it('calls onEvent when message received', () => {
+ const onEvent = vi.fn();
+ renderHook(() => useSSE(onEvent, true));
+
+ const event: AgentEvent = { event_type: 'text_delta', data: { text: 'hello' } };
+ act(() => {
+ newest().onmessage(new MessageEvent('message', { data: JSON.stringify(event) }));
+ });
+
+ expect(onEvent).toHaveBeenCalledWith(event);
+ });
+
+ it('does not call onEvent for messages without data', () => {
+ const onEvent = vi.fn();
+ renderHook(() => useSSE(onEvent, true));
+
+ act(() => {
+ newest().onmessage(new MessageEvent('message', { data: '' }));
+ });
+
+ expect(onEvent).not.toHaveBeenCalled();
+ });
+
+ it('handles malformed JSON gracefully', () => {
+ const onEvent = vi.fn();
+ renderHook(() => useSSE(onEvent, true));
+
+ act(() => {
+ newest().onmessage(new MessageEvent('message', { data: 'not-json' }));
+ });
+
+ expect(onEvent).not.toHaveBeenCalled();
+ });
+
+ it('sets connected to false on error and closes connection', () => {
+ vi.useFakeTimers();
+ const onEvent = vi.fn();
+ const { result } = renderHook(() => useSSE(onEvent, true));
+
+ act(() => {
+ newest().onopen(new Event('open'));
+ });
+ expect(result.current.connected).toBe(true);
+
+ act(() => {
+ newest().onerror(new Event('error'));
+ });
+ expect(result.current.connected).toBe(false);
+
+ vi.useRealTimers();
+ });
+
+ it('calls onReconnect when reconnecting after disconnect', () => {
+ const onEvent = vi.fn();
+ const onReconnect = vi.fn();
+ renderHook(() => useSSE(onEvent, true, null, onReconnect));
+
+ act(() => {
+ newest().onopen(new Event('open'));
+ });
+ act(() => {
+ newest().onopen(new Event('open'));
+ });
+ expect(onReconnect).toHaveBeenCalled();
+ });
+});
diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts
index 9077474..4b0c245 100644
--- a/frontend/vite.config.ts
+++ b/frontend/vite.config.ts
@@ -18,5 +18,10 @@ export default defineConfig({
globals: true,
setupFiles: ['./src/__tests__/setup.ts'],
include: ['src/**/*.test.{ts,tsx}'],
+ coverage: {
+ provider: 'v8',
+ include: ['src/**/*.{ts,tsx}'],
+ exclude: ['src/__tests__/**', 'src/types.ts', 'src/main.tsx'],
+ },
},
})
diff --git a/qodana.yaml b/qodana.yaml
index 843fce2..9a4826a 100644
--- a/qodana.yaml
+++ b/qodana.yaml
@@ -3,8 +3,8 @@
####################################################################################################################
version: "1.0"
-linter: jetbrains/qodana-python-community:2026.1
+linter: jetbrains/qodana-python-community:2025.3
profile:
name: qodana.recommended
include:
- - name: CheckDependencyLicenses
\ No newline at end of file
+ - name: CheckDependencyLicenses
diff --git a/site/docs/setup.md b/site/docs/setup.md
index 8797c31..878a7c9 100644
--- a/site/docs/setup.md
+++ b/site/docs/setup.md
@@ -149,8 +149,8 @@ Run `make help` for the full list:
| `make db-upgrade` | Run migrations |
| **Testing** | |
| `make test` | Run all tests (backend + frontend + docs build) |
-| `make test-backend` | Backend tests only (149 tests) |
-| `make test-frontend` | Frontend tests only (29 tests) |
+| `make test-backend` | Backend tests only (591 tests) |
+| `make test-frontend` | Frontend tests only (182 tests) |
| `make test-docs` | Docs build check |
| **Other** | |
| `make check` | Type-check backend + frontend |