From a6f3260a1cf3b76003ea396d06f0751f89a73855 Mon Sep 17 00:00:00 2001 From: xprilion Date: Sun, 26 Apr 2026 17:33:35 +0530 Subject: [PATCH] Add more tests --- Makefile | 11 +- README.md | 2 +- backend/pyproject.toml | 12 + backend/tests/conftest.py | 4 + backend/tests/test_agent_loop.py | 374 +++++++++++++++++ backend/tests/test_app.py | 43 ++ backend/tests/test_celery_app.py | 42 ++ backend/tests/test_db_engine.py | 43 ++ backend/tests/test_db_operations.py | 393 ++++++++++++++++++ backend/tests/test_dependencies.py | 48 +++ backend/tests/test_job_manager.py | 84 ++++ backend/tests/test_llm.py | 246 +++++++++++ backend/tests/test_models.py | 186 +++++++++ backend/tests/test_models_orm.py | 299 +++++++++++++ backend/tests/test_prompts.py | 81 ++++ backend/tests/test_redis_pubsub.py | 184 ++++++++ backend/tests/test_routes_health.py | 42 ++ backend/tests/test_routes_settings.py | 146 +++++++ backend/tests/test_sandbox_manager.py | 54 +++ backend/tests/test_sandbox_types.py | 139 +++++++ backend/tests/test_session_manager.py | 99 +++++ backend/tests/test_tool_registry.py | 233 +++++++++++ backend/tests/test_tools_github.py | 55 +++ backend/tests/test_tools_local.py | 219 ++++++++++ backend/tests/test_tools_mcp.py | 64 +++ backend/tests/test_tools_papers.py | 98 +++++ backend/tests/test_tools_research.py | 62 +++ backend/tests/test_tools_sandbox.py | 64 +++ backend/tests/test_tools_search.py | 27 ++ backend/tests/test_tools_writing.py | 216 ++++++++++ backend/tests/test_types.py | 172 ++++++++ frontend/src/__tests__/AgentSettings.test.tsx | 64 +++ frontend/src/__tests__/ApprovalModal.test.tsx | 127 ++++++ frontend/src/__tests__/AuthGuard.test.tsx | 106 +++++ frontend/src/__tests__/ConfirmDialog.test.tsx | 159 +++++++ frontend/src/__tests__/LoginPage.test.tsx | 212 ++++++++++ frontend/src/__tests__/ModelModal.test.tsx | 176 ++++++++ .../src/__tests__/OnboardingModal.test.tsx | 185 +++++++++ .../src/__tests__/ProvidersSettings.test.tsx | 86 ++++ 
.../src/__tests__/QuestionDrawer.test.tsx | 136 ++++++ frontend/src/__tests__/ReportDrawer.test.tsx | 82 ++++ frontend/src/__tests__/RightPanel.test.tsx | 185 +++++++++ .../src/__tests__/SandboxSettings.test.tsx | 60 +++ frontend/src/__tests__/SettingsPage.test.tsx | 44 ++ frontend/src/__tests__/SettingsPanel.test.tsx | 104 +++++ frontend/src/__tests__/Sidebar.test.tsx | 237 +++++++++++ .../src/__tests__/WritingSettings.test.tsx | 60 +++ frontend/src/__tests__/useJobStatus.test.ts | 213 ++++++++++ frontend/src/__tests__/useSSE.test.ts | 134 ++++++ frontend/vite.config.ts | 5 + qodana.yaml | 4 +- site/docs/setup.md | 4 +- 52 files changed, 6118 insertions(+), 7 deletions(-) create mode 100644 backend/tests/test_agent_loop.py create mode 100644 backend/tests/test_app.py create mode 100644 backend/tests/test_celery_app.py create mode 100644 backend/tests/test_db_engine.py create mode 100644 backend/tests/test_db_operations.py create mode 100644 backend/tests/test_dependencies.py create mode 100644 backend/tests/test_job_manager.py create mode 100644 backend/tests/test_llm.py create mode 100644 backend/tests/test_models.py create mode 100644 backend/tests/test_models_orm.py create mode 100644 backend/tests/test_prompts.py create mode 100644 backend/tests/test_redis_pubsub.py create mode 100644 backend/tests/test_routes_health.py create mode 100644 backend/tests/test_routes_settings.py create mode 100644 backend/tests/test_sandbox_manager.py create mode 100644 backend/tests/test_sandbox_types.py create mode 100644 backend/tests/test_session_manager.py create mode 100644 backend/tests/test_tool_registry.py create mode 100644 backend/tests/test_tools_github.py create mode 100644 backend/tests/test_tools_local.py create mode 100644 backend/tests/test_tools_mcp.py create mode 100644 backend/tests/test_tools_papers.py create mode 100644 backend/tests/test_tools_research.py create mode 100644 backend/tests/test_tools_sandbox.py create mode 100644 
backend/tests/test_tools_search.py create mode 100644 backend/tests/test_tools_writing.py create mode 100644 backend/tests/test_types.py create mode 100644 frontend/src/__tests__/AgentSettings.test.tsx create mode 100644 frontend/src/__tests__/ApprovalModal.test.tsx create mode 100644 frontend/src/__tests__/AuthGuard.test.tsx create mode 100644 frontend/src/__tests__/ConfirmDialog.test.tsx create mode 100644 frontend/src/__tests__/LoginPage.test.tsx create mode 100644 frontend/src/__tests__/ModelModal.test.tsx create mode 100644 frontend/src/__tests__/OnboardingModal.test.tsx create mode 100644 frontend/src/__tests__/ProvidersSettings.test.tsx create mode 100644 frontend/src/__tests__/QuestionDrawer.test.tsx create mode 100644 frontend/src/__tests__/ReportDrawer.test.tsx create mode 100644 frontend/src/__tests__/RightPanel.test.tsx create mode 100644 frontend/src/__tests__/SandboxSettings.test.tsx create mode 100644 frontend/src/__tests__/SettingsPage.test.tsx create mode 100644 frontend/src/__tests__/SettingsPanel.test.tsx create mode 100644 frontend/src/__tests__/Sidebar.test.tsx create mode 100644 frontend/src/__tests__/WritingSettings.test.tsx create mode 100644 frontend/src/__tests__/useJobStatus.test.ts create mode 100644 frontend/src/__tests__/useSSE.test.ts diff --git a/Makefile b/Makefile index a58a8f7..48114d7 100644 --- a/Makefile +++ b/Makefile @@ -115,8 +115,15 @@ test-docs: ## Verify docs site builds cleanly cd site && npx vitepress build docs .PHONY: test-coverage -test-coverage: ## Run all tests with coverage reports - cd $(BACKEND) && uv run pytest tests/ --tb=short -v +test-coverage: test-coverage-backend test-coverage-frontend ## Run all tests with coverage reports + +.PHONY: test-coverage-backend +test-coverage-backend: ## Backend tests with coverage + cd $(BACKEND) && uv run pytest tests/ --cov --cov-report=term-missing --tb=short -v + +.PHONY: test-coverage-frontend +test-coverage-frontend: ## Frontend tests with coverage + cd $(FRONTEND) && 
pnpm test --coverage cd $(FRONTEND) && pnpm test # ─── Docker ─────────────────────────────────────────────── diff --git a/README.md b/README.md index 39f29bd..4d97df9 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ See `.env.example` for all options. ## Testing ```bash -make test # Run all tests (149 backend + 29 frontend + docs build) +make test # Run all tests (591 backend + 182 frontend + docs build) make test-backend # Backend tests only make test-frontend # Frontend tests only make test-docs # Docs build check diff --git a/backend/pyproject.toml b/backend/pyproject.toml index d7c2876..7b7824f 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -63,6 +63,18 @@ openmlr = "openmlr.main:main" dev-dependencies = [ "pytest>=8.0.0", "pytest-asyncio>=0.24.0", + "pytest-cov>=6.0.0", "aiosqlite>=0.20.0", "httpx>=0.28.0", ] + +[tool.coverage.run] +source = ["openmlr"] +omit = ["openmlr/db/migrations/*"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "if __name__ == .__main__.:", + "raise NotImplementedError", +] diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index e7eb1c4..0731f80 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -105,11 +105,15 @@ async def client() -> AsyncGenerator[httpx.AsyncClient, None]: from openmlr.app import app from openmlr.db.engine import get_db as engine_get_db from openmlr.dependencies import get_db as dep_get_db + from openmlr.config import AgentConfig # Override both the canonical get_db *and* the re-export in dependencies app.dependency_overrides[engine_get_db] = _override_get_db app.dependency_overrides[dep_get_db] = _override_get_db + # Set minimal app state since lifespan is skipped + app.state.config = AgentConfig() + transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient( transport=transport, diff --git a/backend/tests/test_agent_loop.py b/backend/tests/test_agent_loop.py new file mode 100644 index 0000000..77a6e77 --- 
/dev/null +++ b/backend/tests/test_agent_loop.py @@ -0,0 +1,374 @@ +"""Tests for agent loop — tool execution, approval, undo, compact, submissions.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from openmlr.agent.types import AgentEvent, Message, ToolCall, ToolSpec, Submission, OpType, LLMResult +from openmlr.agent.session import Session +from openmlr.agent.context import ContextManager +from openmlr.agent.loop import ( + _execute_tool, _handle_approval, _undo, _compact, + _run_agent, run_agent_turn, submission_loop, + _stream_llm_call, _non_stream_llm_call, _compact_llm_call, +) +from openmlr.config import AgentConfig +from openmlr.tools.registry import ToolRouter + +pytestmark = pytest.mark.asyncio + +# ── Test fixtures ────────────────────────────────────────── + +@pytest.fixture +def config(): + return AgentConfig(model_name="test/model", max_iterations=10, stream=False) + +@pytest.fixture +def mock_session(config): + session = MagicMock(spec=Session) + session.config = config + session.submission_queue = MagicMock() + session.context_manager = MagicMock(spec=ContextManager) + session.emit = AsyncMock() + session.cancel = MagicMock() + session.clear_cancel = MagicMock() + session.is_cancelled = MagicMock(return_value=False) + session.pending_approval = None + session.pending_answers = None + session.turn_count = 0 + session.on_event = MagicMock() + session.update_model = MagicMock() + return session + +@pytest.fixture +def mock_router(): + router = MagicMock(spec=ToolRouter) + router.call_tool = AsyncMock(return_value=("tool output", True)) + router.set_mode = MagicMock() + router.get_tool_specs_for_llm = MagicMock(return_value=[]) + router.get_tool = MagicMock(return_value=None) + router.get_mode = MagicMock(return_value="execute") + return router + + +# ── Tool Execution ───────────────────────────────────────── + +class TestExecuteTool: + async def test_executes_tool_and_returns_output(self, mock_session, mock_router): + tc = 
ToolCall(id="tc1", name="bash", arguments={"cmd": "ls"}) + mock_router.call_tool.return_value = ("file list output", True) + + output, success = await _execute_tool(mock_session, mock_router, tc) + + assert output == "file list output" + assert success is True + mock_router.call_tool.assert_called_once_with("bash", {"cmd": "ls"}, session=mock_session) + assert mock_session.emit.called + + async def test_handles_tool_error(self, mock_session, mock_router): + tc = ToolCall(id="tc1", name="bad_tool", arguments={}) + mock_router.call_tool.side_effect = RuntimeError("execution failed") + + output, success = await _execute_tool(mock_session, mock_router, tc) + + assert success is False + assert "Tool execution error" in output + + async def test_emits_both_state_changes(self, mock_session, mock_router): + tc = ToolCall(id="tc1", name="test", arguments={}) + + await _execute_tool(mock_session, mock_router, tc) + + emitted_event_types = [] + for call in mock_session.emit.call_args_list: + args = call[0] + if args and hasattr(args[0], 'event_type'): + emitted_event_types.append(args[0].event_type) + assert len(emitted_event_types) >= 2 + assert "tool_state_change" in emitted_event_types + + +# ── Approval Handling ────────────────────────────────────── + +class TestHandleApproval: + async def test_approves_tool_calls(self, mock_session, mock_router): + tcs = [ToolCall(id="tc1", name="bash", arguments={"cmd": "ls"})] + mock_session.pending_approval = {"tool_calls": tcs, "tool_router": mock_router} + mock_session.context_manager = MagicMock() + mock_session.context_manager.add_message = MagicMock() + mock_session.context_manager.get_messages = MagicMock(return_value=[]) + mock_session.config.stream = False + + await _handle_approval(mock_session, mock_router, {"tc1": True}) + + mock_router.call_tool.assert_called_once_with("bash", {"cmd": "ls"}, session=mock_session) + + async def test_rejects_tool_calls(self, mock_session, mock_router): + tcs = [ToolCall(id="tc1", 
name="dangerous", arguments={})] + mock_session.pending_approval = {"tool_calls": tcs, "tool_router": mock_router} + mock_session.context_manager = MagicMock() + mock_session.context_manager.add_message = MagicMock() + mock_session.context_manager.get_messages = MagicMock(return_value=[]) + mock_session.config.stream = False + + await _handle_approval(mock_session, mock_router, {"tc1": False}) + + assert mock_router.call_tool.called == False + + async def test_no_pending_approval_returns(self, mock_session, mock_router): + mock_session.pending_approval = None + await _handle_approval(mock_session, mock_router, {}) + # No exception, nothing happens + + +# ── Undo ─────────────────────────────────────────────────── + +class TestUndo: + async def test_undo_calls_context_manager(self, mock_session): + mock_session.context_manager.undo_last_turn.return_value = 3 + + await _undo(mock_session) + + mock_session.context_manager.undo_last_turn.assert_called_once() + mock_session.emit.assert_called() + + +# ── Compaction ───────────────────────────────────────────── + +class TestCompact: + async def test_compact_calls_context_manager(self, mock_session): + mock_session.context_manager.compact = AsyncMock(return_value="Summary of conversation") + + await _compact(mock_session) + + mock_session.context_manager.compact.assert_called_once() + mock_session.emit.assert_called() + + +# ── Run Agent ────────────────────────────────────────────── + +class TestRunAgent: + async def test_runs_with_no_tool_calls(self, mock_session, mock_router): + """Agent processes a message, LLM returns content with no tool calls.""" + mock_session.context_manager.get_messages.return_value = [] + mock_session.context_manager.needs_compaction.return_value = False + mock_session.context_manager.get_token_usage.return_value = {"used": 100, "max": 200000, "ratio": 0.0} + mock_session.config.stream = False + + with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen: + mock_gen.return_value = 
LLMResult( + content="I can help with that!", + tool_calls=[], + finish_reason="stop", + usage={"total_tokens": 50}, + ) + + await _run_agent(mock_session, mock_router, "help me") + + assert mock_session.context_manager.add_message.called + assert mock_session.emit.called + + async def test_handles_error_gracefully(self, mock_session, mock_router): + mock_session.context_manager.get_messages.side_effect = RuntimeError("Something broke") + + await _run_agent(mock_session, mock_router, "test") + + # Should not raise, emits error event + assert mock_session.emit.called + + async def test_cancelled_stops_early(self, mock_session, mock_router): + mock_session.is_cancelled.return_value = True + mock_session.context_manager.get_messages.return_value = [] + mock_session.context_manager.needs_compaction.return_value = False + mock_session.context_manager.get_token_usage.return_value = {"used": 0, "max": 200000, "ratio": 0.0} + + await _run_agent(mock_session, mock_router, "test") + + mock_session.emit.assert_any_call(AgentEvent(event_type="interrupted")) + + +class TestRunAgentTurn: + async def test_delegates_to_run_agent(self, mock_session, mock_router): + mock_session.context_manager.get_messages.return_value = [] + mock_session.context_manager.needs_compaction.return_value = False + mock_session.context_manager.get_token_usage.return_value = {"ratio": 0.0} + mock_session.config.stream = False + + with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen: + mock_gen.return_value = LLMResult( + content="Hello!", tool_calls=[], finish_reason="stop", + ) + + await run_agent_turn(mock_session, mock_router, "Hi", mode="plan") + + mock_router.set_mode.assert_called_with("plan") + + async def test_default_mode_is_execute(self, mock_session, mock_router): + mock_session.context_manager.get_messages.return_value = [] + mock_session.context_manager.needs_compaction.return_value = False + mock_session.context_manager.get_token_usage.return_value = {"ratio": 0.0} + 
mock_session.config.stream = False + + with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen: + mock_gen.return_value = LLMResult( + content="Ok", tool_calls=[], finish_reason="stop", + ) + await run_agent_turn(mock_session, mock_router, "test", mode="unknown") + + mock_router.set_mode.assert_called_with("execute") + + +# ── Submissions ──────────────────────────────────────────── + +class TestSubmissionLoop: + async def test_processes_user_input(self, mock_session, mock_router): + mock_session.submission_queue.get = AsyncMock(side_effect=[ + Submission(op=OpType.USER_INPUT, data="hello"), + Submission(op=OpType.SHUTDOWN), + ]) + mock_session.context_manager.get_messages.return_value = [] + mock_session.context_manager.needs_compaction.return_value = False + mock_session.context_manager.get_token_usage.return_value = {"ratio": 0.0} + mock_session.config.stream = False + + with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen: + mock_gen.return_value = LLMResult( + content="Hi!", tool_calls=[], finish_reason="stop", + ) + await submission_loop(mock_session, mock_router) + + assert mock_session.emit.called + + async def test_processes_compact(self, mock_session, mock_router): + mock_session.submission_queue.get = AsyncMock(side_effect=[ + Submission(op=OpType.COMPACT), + Submission(op=OpType.SHUTDOWN), + ]) + mock_session.context_manager.compact = AsyncMock(return_value="Summary") + + await submission_loop(mock_session, mock_router) + + mock_session.context_manager.compact.assert_called_once() + + async def test_processes_undo(self, mock_session, mock_router): + mock_session.submission_queue.get = AsyncMock(side_effect=[ + Submission(op=OpType.UNDO), + Submission(op=OpType.SHUTDOWN), + ]) + mock_session.context_manager.undo_last_turn.return_value = 3 + + await submission_loop(mock_session, mock_router) + + mock_session.context_manager.undo_last_turn.assert_called_once() + + async def test_processes_interrupt(self, mock_session, mock_router): + 
mock_session.submission_queue.get = AsyncMock(side_effect=[ + Submission(op=OpType.INTERRUPT), + Submission(op=OpType.SHUTDOWN), + ]) + + await submission_loop(mock_session, mock_router) + + mock_session.cancel.assert_called_once() + + async def test_shutdown_exits(self, mock_session, mock_router): + mock_session.submission_queue.get = AsyncMock(return_value=Submission(op=OpType.SHUTDOWN)) + + await submission_loop(mock_session, mock_router) + + mock_session.emit.assert_any_call(AgentEvent(event_type="shutdown")) + + +# ── LLM Call Helpers ─────────────────────────────────────── + +class TestNonStreamLLMCall: + async def test_returns_llm_result(self, mock_session): + mock_session.is_cancelled.return_value = False + messages = [{"role": "user", "content": "help"}] + tools = [] + + with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen: + mock_gen.return_value = LLMResult( + content="Response", tool_calls=[], finish_reason="stop", + ) + result = await _non_stream_llm_call(mock_session, messages, tools) + + assert result is not None + assert result.content == "Response" + + async def test_emits_chunk_and_end(self, mock_session): + mock_session.is_cancelled.return_value = False + + with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen: + mock_gen.return_value = LLMResult( + content="Output", tool_calls=[], finish_reason="stop", + ) + await _non_stream_llm_call(mock_session, [], []) + + assert mock_session.emit.called + + +class TestStreamLLMCall: + async def test_returns_llm_result_from_chunks(self, mock_session): + mock_session.is_cancelled.return_value = False + mock_session.config = AgentConfig(model_name="test", stream=True) + + async def mock_stream(messages, config, tools): + yield "Hello" + yield " world" + + with patch("openmlr.agent.loop.LLMProvider.generate_stream") as mock_str: + mock_str.return_value = mock_stream(None, None, None) + result = await _stream_llm_call(mock_session, [], []) + + assert result is not None + assert 
result.content == "Hello world" + assert result.finish_reason == "stop" + + async def test_cancelled_returns_none(self, mock_session): + mock_session.is_cancelled.return_value = True + + async def mock_stream(messages, config, tools): + yield "Hello" + if False: + yield + + with patch("openmlr.agent.loop.LLMProvider.generate_stream") as mock_str: + mock_str.return_value = mock_stream(None, None, None) + result = await _stream_llm_call(mock_session, [], []) + + assert result is None + + async def test_handles_stream_with_tool_calls(self, mock_session): + mock_session.is_cancelled.return_value = False + mock_session.config = AgentConfig(model_name="test", stream=True) + tc = ToolCall(id="call_1", name="search", arguments={"query": "test"}) + + async def mock_stream(messages, config, tools): + yield "Finding..." + yield tc + + with patch("openmlr.agent.loop.LLMProvider.generate_stream") as mock_str: + mock_str.return_value = mock_stream(None, None, None) + result = await _stream_llm_call(mock_session, [], []) + + assert result is not None + assert result.content == "Finding..." + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "search" + assert result.finish_reason == "tool_calls" + + +# ── Compact LLM Call ─────────────────────────────────────── + +class TestCompactLLMCall: + async def test_returns_content(self): + messages = [{"role": "user", "content": "summarize"}] + config = AgentConfig(model_name="test/title", stream=False) + + with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen: + mock_gen.return_value = LLMResult( + content="A summary.", tool_calls=[], finish_reason="stop", + ) + result = await _compact_llm_call(messages, config) + + assert result == "A summary." 
diff --git a/backend/tests/test_app.py b/backend/tests/test_app.py new file mode 100644 index 0000000..222eeb2 --- /dev/null +++ b/backend/tests/test_app.py @@ -0,0 +1,43 @@ +"""Tests for app entrypoint and main module.""" + +import pytest +from openmlr.app import app + +pytestmark = pytest.mark.asyncio + + +class TestAppCreation: + async def test_app_title(self): + assert app.title == "OpenMLR" + + async def test_app_version(self): + assert app.version == "2.0.0" + + async def test_app_routers_registered(self): + route_paths = [r.path for r in app.routes] + assert "/api/auth/register" in route_paths + assert "/api/auth/login" in route_paths + assert "/api/health" in route_paths + assert "/api/message" in route_paths + + async def test_cors_middleware_configured(self): + from fastapi.middleware.cors import CORSMiddleware + middlewares = [m.cls for m in app.user_middleware] + assert CORSMiddleware in middlewares + + async def test_global_exception_handler_configured(self): + from starlette.responses import JSONResponse + handlers = app.exception_handlers + assert Exception in handlers + + +class TestMainModule: + async def test_main_is_callable(self): + from openmlr.main import main + assert callable(main) + + async def test_main_contains_uvicorn_import(self): + from openmlr.main import main + import inspect + source = inspect.getsource(main) + assert "uvicorn" in source diff --git a/backend/tests/test_celery_app.py b/backend/tests/test_celery_app.py new file mode 100644 index 0000000..21c2cda --- /dev/null +++ b/backend/tests/test_celery_app.py @@ -0,0 +1,42 @@ +"""Tests for Celery app configuration.""" + +import pytest +from celery import Celery + + +class TestCeleryApp: + def test_is_celery_instance(self): + from openmlr.celery_app import celery_app + assert isinstance(celery_app, Celery) + + def test_has_correct_name(self): + from openmlr.celery_app import celery_app + assert celery_app.main == "openmlr" + + def test_config_has_serializer(self): + from 
openmlr.celery_app import celery_app + assert celery_app.conf.task_serializer == "json" + assert "json" in celery_app.conf.accept_content + + def test_config_has_timezone(self): + from openmlr.celery_app import celery_app + assert celery_app.conf.timezone == "UTC" + assert celery_app.conf.enable_utc is True + + def test_config_worker_settings(self): + from openmlr.celery_app import celery_app + assert celery_app.conf.worker_prefetch_multiplier == 1 + assert celery_app.conf.task_acks_late is True + + def test_config_result_expiry(self): + from openmlr.celery_app import celery_app + assert celery_app.conf.result_expires == 3600 + + def test_task_routing_configured(self): + from openmlr.celery_app import celery_app + routes = celery_app.conf.task_routes + assert routes is not None + + def test_get_celery_app(self): + from openmlr.celery_app import get_celery_app, celery_app + assert get_celery_app() is celery_app diff --git a/backend/tests/test_db_engine.py b/backend/tests/test_db_engine.py new file mode 100644 index 0000000..aa47ab3 --- /dev/null +++ b/backend/tests/test_db_engine.py @@ -0,0 +1,43 @@ +"""Tests for database engine factory and configuration.""" + +import pytest + + +class TestEngineConfig: + def test_database_url_exists(self): + from openmlr.db.engine import DATABASE_URL + assert DATABASE_URL is not None + assert len(DATABASE_URL) > 0 + + def test_engine_created(self): + from openmlr.db.engine import engine + assert engine is not None + + def test_async_session_created(self): + from openmlr.db.engine import async_session + assert async_session is not None + + def test_worker_engine_context_var(self): + from openmlr.db.engine import _worker_engine + from contextvars import ContextVar + assert isinstance(_worker_engine, ContextVar) + + +@pytest.mark.asyncio +class TestGetWorkerSession: + async def test_returns_sessionmaker(self): + from openmlr.db.engine import get_worker_session + result = get_worker_session() + from sqlalchemy.ext.asyncio import 
async_sessionmaker + assert isinstance(result, async_sessionmaker) + + +@pytest.mark.asyncio +class TestGetDB: + async def test_yields_session(self): + from openmlr.db.engine import get_db + sessions = [] + async for s in get_db(): + sessions.append(s) + break + assert len(sessions) == 1 diff --git a/backend/tests/test_db_operations.py b/backend/tests/test_db_operations.py new file mode 100644 index 0000000..ba7263d --- /dev/null +++ b/backend/tests/test_db_operations.py @@ -0,0 +1,393 @@ +"""Tests for database CRUD operations.""" + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession + +pytestmark = pytest.mark.asyncio + +from openmlr.db import operations as ops +from openmlr.db.models import UserSetting + + +class TestConversationOperations: + async def test_create_conversation(self, db_session: AsyncSession, test_user): + conv = await ops.create_conversation(db_session, test_user.id, title="Test Conv") + assert conv.id is not None + assert conv.title == "Test Conv" + assert conv.user_id == test_user.id + assert conv.mode == "general" + + async def test_create_conversation_with_model(self, db_session: AsyncSession, test_user): + conv = await ops.create_conversation(db_session, test_user.id, model="gpt-4o", mode="coding") + assert conv.model == "gpt-4o" + assert conv.mode == "coding" + + async def test_get_conversations_empty(self, db_session: AsyncSession, test_user): + convs = await ops.get_conversations(db_session, test_user.id) + assert convs == [] + + async def test_get_conversations(self, db_session: AsyncSession, test_user): + await ops.create_conversation(db_session, test_user.id, title="Conv 1") + await ops.create_conversation(db_session, test_user.id, title="Conv 2") + convs = await ops.get_conversations(db_session, test_user.id) + assert len(convs) == 2 + assert convs[0].title == "Conv 2" # most recent first + + async def test_get_conversation_by_id(self, db_session: AsyncSession, test_user): + conv = await 
ops.create_conversation(db_session, test_user.id, title="Find Me") + found = await ops.get_conversation_by_id(db_session, conv.id) + assert found is not None + assert found.title == "Find Me" + + async def test_get_conversation_by_id_not_found(self, db_session: AsyncSession): + found = await ops.get_conversation_by_id(db_session, 9999) + assert found is None + + async def test_get_conversation_by_uuid(self, db_session: AsyncSession, test_user): + conv = await ops.create_conversation(db_session, test_user.id) + found = await ops.get_conversation_by_uuid(db_session, conv.uuid) + assert found is not None + assert found.id == conv.id + + async def test_get_conversation_by_uuid_not_found(self, db_session: AsyncSession): + found = await ops.get_conversation_by_uuid(db_session, "nonexistent") + assert found is None + + async def test_delete_conversation(self, db_session: AsyncSession, test_user): + conv = await ops.create_conversation(db_session, test_user.id) + deleted = await ops.delete_conversation(db_session, conv.id) + assert deleted is True + found = await ops.get_conversation_by_id(db_session, conv.id) + assert found is None + + async def test_delete_nonexistent(self, db_session: AsyncSession): + deleted = await ops.delete_conversation(db_session, 9999) + assert deleted is False + + async def test_update_title(self, db_session: AsyncSession, test_user): + conv = await ops.create_conversation(db_session, test_user.id, title="Old") + await ops.update_conversation_title(db_session, conv.id, "New Title") + found = await ops.get_conversation_by_id(db_session, conv.id) + assert found.title == "New Title" + + async def test_update_model(self, db_session: AsyncSession, test_user): + conv = await ops.create_conversation(db_session, test_user.id) + await ops.update_conversation_model(db_session, conv.id, "claude-sonnet-4") + found = await ops.get_conversation_by_id(db_session, conv.id) + assert found.model == "claude-sonnet-4" + + async def 
test_increment_user_message_count(self, db_session: AsyncSession, test_user): + conv = await ops.create_conversation(db_session, test_user.id) + assert conv.user_message_count == 0 + await ops.increment_user_message_count(db_session, conv.id) + found = await ops.get_conversation_by_id(db_session, conv.id) + assert found.user_message_count == 1 + + async def test_conversations_isolated_by_user(self, db_session: AsyncSession, test_user): + # Create another user + from openmlr.db.models import User + from openmlr.auth.security import hash_password + user2 = User(username="user2", password_hash=hash_password("pwd"), is_active=True) + db_session.add(user2) + await db_session.flush() + + await ops.create_conversation(db_session, test_user.id, title="User 1 Conv") + await ops.create_conversation(db_session, user2.id, title="User 2 Conv") + + convs_u1 = await ops.get_conversations(db_session, test_user.id) + convs_u2 = await ops.get_conversations(db_session, user2.id) + assert len(convs_u1) == 1 + assert len(convs_u2) == 1 + assert convs_u1[0].title == "User 1 Conv" + assert convs_u2[0].title == "User 2 Conv" + + +class TestMessageOperations: + @pytest_asyncio.fixture(autouse=True) + async def _conv(self, db_session: AsyncSession, test_user): + self.conv = await ops.create_conversation(db_session, test_user.id, title="Msg Test") + + async def test_add_message(self, db_session: AsyncSession): + msg = await ops.add_message(db_session, self.conv.id, "user", "Hello!") + assert msg.id is not None + assert msg.role == "user" + assert msg.content == "Hello!" 
+ assert msg.conversation_id == self.conv.id + + async def test_add_message_with_metadata(self, db_session: AsyncSession): + msg = await ops.add_message( + db_session, self.conv.id, "assistant", "Done", + metadata={"tool": "search", "duration": 1.5}, + ) + assert msg.meta == {"tool": "search", "duration": 1.5} + + async def test_get_messages_empty(self, db_session: AsyncSession): + msgs = await ops.get_messages(db_session, self.conv.id) + assert msgs == [] + + async def test_get_messages(self, db_session: AsyncSession): + await ops.add_message(db_session, self.conv.id, "user", "First") + await ops.add_message(db_session, self.conv.id, "assistant", "Second") + msgs = await ops.get_messages(db_session, self.conv.id) + assert len(msgs) == 2 + assert msgs[0].content == "First" + assert msgs[1].content == "Second" + + async def test_clear_messages(self, db_session: AsyncSession): + await ops.add_message(db_session, self.conv.id, "user", "A") + await ops.add_message(db_session, self.conv.id, "assistant", "B") + await ops.clear_messages(db_session, self.conv.id) + msgs = await ops.get_messages(db_session, self.conv.id) + assert msgs == [] + + +class TestSettingsOperations: + async def test_set_and_get_user_setting(self, db_session: AsyncSession, test_user): + await ops.set_user_setting(db_session, test_user.id, "agent", "default_model", "gpt-4o") + val = await ops.get_user_setting(db_session, test_user.id, "agent", "default_model") + assert val == "gpt-4o" + + async def test_set_user_setting_update(self, db_session: AsyncSession, test_user): + await ops.set_user_setting(db_session, test_user.id, "agent", "yolo_mode", True) + await ops.set_user_setting(db_session, test_user.id, "agent", "yolo_mode", False) + val = await ops.get_user_setting(db_session, test_user.id, "agent", "yolo_mode") + assert val is False + + async def test_get_nonexistent_setting(self, db_session: AsyncSession, test_user): + val = await ops.get_user_setting(db_session, test_user.id, "agent", 
"nonexistent") + assert val is None + + async def test_get_all_settings(self, db_session: AsyncSession, test_user): + await ops.set_user_setting(db_session, test_user.id, "agent", "model", "gpt-4o") + await ops.set_user_setting(db_session, test_user.id, "providers", "openai_key", "sk-123") + settings = await ops.get_all_settings(db_session, test_user.id) + assert "agent" in settings + assert "providers" in settings + assert settings["agent"]["model"] == "gpt-4o" + assert settings["providers"]["openai_key"] == "sk-123" + + async def test_get_all_settings_by_category(self, db_session: AsyncSession, test_user): + await ops.set_user_setting(db_session, test_user.id, "agent", "model", "gpt-4o") + await ops.set_user_setting(db_session, test_user.id, "providers", "key", "val") + agent_settings = await ops.get_all_settings(db_session, test_user.id, category="agent") + assert "agent" in agent_settings + assert "providers" not in agent_settings + + async def test_delete_user_setting(self, db_session: AsyncSession, test_user): + await ops.set_user_setting(db_session, test_user.id, "agent", "model", "gpt-4o") + await ops.delete_user_setting(db_session, test_user.id, "agent", "model") + val = await ops.get_user_setting(db_session, test_user.id, "agent", "model") + assert val is None + + async def test_set_user_setting_int_value(self, db_session: AsyncSession, test_user): + await ops.set_user_setting(db_session, test_user.id, "agent", "max_iterations", 100) + val = await ops.get_user_setting(db_session, test_user.id, "agent", "max_iterations") + assert val == 100 + + async def test_set_user_setting_float_value(self, db_session: AsyncSession, test_user): + await ops.set_user_setting(db_session, test_user.id, "agent", "threshold", 0.85) + val = await ops.get_user_setting(db_session, test_user.id, "agent", "threshold") + assert val == 0.85 + + async def test_get_user_agent_settings(self, db_session: AsyncSession, test_user): + await ops.set_user_setting(db_session, test_user.id, 
"agent", "default_model", "claude") + await ops.set_user_setting(db_session, test_user.id, "agent", "yolo_mode", True) + settings = await ops.get_user_agent_settings(db_session, test_user.id) + assert settings["default_model"] == "claude" + assert settings["yolo_mode"] is True + + +class TestTaskOperations: + @pytest_asyncio.fixture(autouse=True) + async def _conv(self, db_session: AsyncSession, test_user): + self.conv = await ops.create_conversation(db_session, test_user.id) + + async def test_upsert_tasks_create(self, db_session: AsyncSession): + tasks = [ + {"title": "Task 1", "status": "pending"}, + {"title": "Task 2", "status": "in_progress"}, + ] + result = await ops.upsert_conversation_tasks(db_session, self.conv.id, tasks) + assert len(result) == 2 + assert result[0].title == "Task 1" + assert result[0].order_index == 0 + assert result[1].order_index == 1 + + async def test_upsert_tasks_replace(self, db_session: AsyncSession): + await ops.upsert_conversation_tasks(db_session, self.conv.id, [ + {"title": "Old Task"}, + ]) + result = await ops.upsert_conversation_tasks(db_session, self.conv.id, [ + {"title": "New Task"}, + ]) + assert len(result) == 1 + assert result[0].title == "New Task" + + async def test_get_tasks_empty(self, db_session: AsyncSession): + tasks = await ops.get_conversation_tasks(db_session, self.conv.id) + assert tasks == [] + + async def test_get_tasks(self, db_session: AsyncSession): + await ops.upsert_conversation_tasks(db_session, self.conv.id, [ + {"title": "T1", "status": "pending", "priority": "high"}, + {"title": "T2", "status": "completed"}, + ]) + tasks = await ops.get_conversation_tasks(db_session, self.conv.id) + assert len(tasks) == 2 + assert tasks[0].title == "T1" + assert tasks[0].priority == "high" + + async def test_update_task_status(self, db_session: AsyncSession): + await ops.upsert_conversation_tasks(db_session, self.conv.id, [ + {"title": "Do this", "status": "pending"}, + ]) + ok = await 
ops.update_task_status(db_session, self.conv.id, 0, "completed") + assert ok is True + tasks = await ops.get_conversation_tasks(db_session, self.conv.id) + assert tasks[0].status == "completed" + + async def test_update_task_status_out_of_range(self, db_session: AsyncSession): + await ops.upsert_conversation_tasks(db_session, self.conv.id, [ + {"title": "Only one"}, + ]) + ok = await ops.update_task_status(db_session, self.conv.id, 5, "completed") + assert ok is False + + +class TestResourceOperations: + @pytest_asyncio.fixture(autouse=True) + async def _conv(self, db_session: AsyncSession, test_user): + self.conv = await ops.create_conversation(db_session, test_user.id) + + async def test_add_resource(self, db_session: AsyncSession): + res = await ops.add_conversation_resource( + db_session, self.conv.id, + title="Paper 1", resource_type="paper", url="https://example.com", + ) + assert res.id is not None + assert res.title == "Paper 1" + assert res.type == "paper" + assert res.url == "https://example.com" + + async def test_get_resources(self, db_session: AsyncSession): + await ops.add_conversation_resource(db_session, self.conv.id, title="R1", resource_type="doc") + await ops.add_conversation_resource(db_session, self.conv.id, title="R2", resource_type="code") + resources = await ops.get_conversation_resources(db_session, self.conv.id) + assert len(resources) == 2 + + async def test_get_resource_by_id(self, db_session: AsyncSession): + res = await ops.add_conversation_resource( + db_session, self.conv.id, title="Test", resource_type="doc", + ) + found = await ops.get_resource_by_id(db_session, res.resource_id) + assert found is not None + assert found.title == "Test" + + async def test_upsert_resources_replace(self, db_session: AsyncSession): + await ops.add_conversation_resource(db_session, self.conv.id, title="Old", resource_type="doc") + result = await ops.upsert_conversation_resources(db_session, self.conv.id, [ + {"title": "New", "type": "doc"}, + ]) + 
assert len(result) == 1 + assert result[0].title == "New" + + async def test_upsert_plan_resource_new(self, db_session: AsyncSession): + res = await ops.upsert_plan_resource(db_session, self.conv.id, "# Plan content") + assert res.title == "PLAN.md" + assert res.type == "plan" + assert res.content == "# Plan content" + + async def test_upsert_plan_resource_update(self, db_session: AsyncSession): + await ops.upsert_plan_resource(db_session, self.conv.id, "First version") + res = await ops.upsert_plan_resource(db_session, self.conv.id, "Updated version") + assert res.content == "Updated version" + + async def test_upsert_paper_resource(self, db_session: AsyncSession): + res = await ops.upsert_paper_resource( + db_session, self.conv.id, "My Paper", "## Abstract\nContent", + ) + assert res.title == "My Paper" + assert res.type == "paper" + assert "Abstract" in res.content + + async def test_upsert_resource_create(self, db_session: AsyncSession): + res = await ops.upsert_resource( + db_session, self.conv.id, + resource_id="custom-id", title="Custom", resource_type="report", + content="Report content", + ) + assert res.resource_id == "custom-id" + assert res.content == "Report content" + + async def test_upsert_resource_update(self, db_session: AsyncSession): + await ops.upsert_resource( + db_session, self.conv.id, + resource_id="rid", title="Old Title", resource_type="doc", content="Old", + ) + res = await ops.upsert_resource( + db_session, self.conv.id, + resource_id="rid", title="New Title", resource_type="doc", content="New", + ) + assert res.title == "New Title" + assert res.content == "New" + + +class TestAgentJobOperations: + @pytest_asyncio.fixture(autouse=True) + async def _conv(self, db_session: AsyncSession, test_user): + self.conv = await ops.create_conversation(db_session, test_user.id) + + async def test_create_agent_job(self, db_session: AsyncSession): + job = await ops.create_agent_job( + db_session, self.conv.id, self.conv.user_id, "Process this", + ) + 
assert job.job_id is not None + assert job.status == "queued" + assert job.message == "Process this" + + async def test_get_agent_job(self, db_session: AsyncSession): + job = await ops.create_agent_job( + db_session, self.conv.id, self.conv.user_id, "Test", + ) + found = await ops.get_agent_job(db_session, job.job_id) + assert found is not None + assert found.status == "queued" + + async def test_get_agent_job_not_found(self, db_session: AsyncSession): + found = await ops.get_agent_job(db_session, "nonexistent") + assert found is None + + async def test_get_active_jobs(self, db_session: AsyncSession): + job1 = await ops.create_agent_job(db_session, self.conv.id, self.conv.user_id, "Job 1") + job2 = await ops.create_agent_job(db_session, self.conv.id, self.conv.user_id, "Job 2") + # Complete one + await ops.update_job_status(db_session, job1.job_id, "completed") + active = await ops.get_active_jobs_for_conversation(db_session, self.conv.id) + assert len(active) == 1 + assert active[0].job_id == job2.job_id + + async def test_update_job_status_to_running(self, db_session: AsyncSession): + job = await ops.create_agent_job(db_session, self.conv.id, self.conv.user_id, "Test") + ok = await ops.update_job_status(db_session, job.job_id, "running") + assert ok is True + found = await ops.get_agent_job(db_session, job.job_id) + assert found.status == "running" + assert found.started_at is not None + + async def test_update_job_status_to_completed(self, db_session: AsyncSession): + job = await ops.create_agent_job(db_session, self.conv.id, self.conv.user_id, "Test") + await ops.update_job_status(db_session, job.job_id, "completed") + found = await ops.get_agent_job(db_session, job.job_id) + assert found.status == "completed" + assert found.completed_at is not None + + async def test_update_job_status_not_found(self, db_session: AsyncSession): + ok = await ops.update_job_status(db_session, "nonexistent", "completed") + assert ok is False + + async def 
test_update_job_status_with_error(self, db_session: AsyncSession): + job = await ops.create_agent_job(db_session, self.conv.id, self.conv.user_id, "Test") + await ops.update_job_status(db_session, job.job_id, "failed", error="Something broke") + found = await ops.get_agent_job(db_session, job.job_id) + assert found.error == "Something broke" diff --git a/backend/tests/test_dependencies.py b/backend/tests/test_dependencies.py new file mode 100644 index 0000000..c5fc776 --- /dev/null +++ b/backend/tests/test_dependencies.py @@ -0,0 +1,48 @@ +"""Tests for FastAPI dependencies — get_config, get_current_user.""" + +import pytest +from httpx import AsyncClient + +pytestmark = pytest.mark.asyncio + + +class TestGetConfig: + async def test_get_config_returns_agent_config(self): + from openmlr.dependencies import get_config + config = get_config() + assert config is not None + assert hasattr(config, "model_name") + assert hasattr(config, "max_iterations") + + async def test_get_config_is_cached(self): + from openmlr.dependencies import get_config + config1 = get_config() + config2 = get_config() + assert config1 is config2 + + +class TestGetCurrentUser: + async def test_valid_token_returns_user(self, auth_client: AsyncClient): + resp = await auth_client.get("/api/auth/me") + assert resp.status_code == 200 + data = resp.json() + assert data["username"] == "testuser" + + async def test_no_token_returns_401(self, client: AsyncClient): + resp = await client.get("/api/auth/me") + assert resp.status_code == 401 + + async def test_invalid_token_returns_401(self, client: AsyncClient): + client.headers["Authorization"] = "Bearer invalid.token.here" + resp = await client.get("/api/auth/me") + assert resp.status_code == 401 + + +class TestGetDB: + async def test_db_session_yielded(self, client: AsyncClient): + from openmlr.dependencies import get_db + sessions = [] + async for s in get_db(): + sessions.append(s) + break + assert len(sessions) == 1 diff --git 
a/backend/tests/test_job_manager.py b/backend/tests/test_job_manager.py new file mode 100644 index 0000000..8305cb8 --- /dev/null +++ b/backend/tests/test_job_manager.py @@ -0,0 +1,84 @@ +"""Tests for JobManager — background job creation and status tracking.""" + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession + +pytestmark = pytest.mark.asyncio + +from openmlr.services.job_manager import JobManager, get_job_manager +from openmlr.db import operations as ops + + +@pytest_asyncio.fixture +async def conversation(db_session: AsyncSession, test_user): + return await ops.create_conversation(db_session, test_user.id, title="Job Test") + + +class TestJobManager: + async def test_get_job_manager_singleton(self): + jm1 = get_job_manager() + jm2 = get_job_manager() + assert jm1 is jm2 + + async def test_create_job_disabled_by_default(self, db_session: AsyncSession, conversation, test_user): + jm = JobManager() + # USE_BACKGROUND_JOBS is controlled by env — test without making assumptions + from openmlr.services.job_manager import USE_BACKGROUND_JOBS + job = await jm.create_job( + db=db_session, + conversation_id=conversation.id, + user_id=test_user.id, + message="Test message", + ) + if not USE_BACKGROUND_JOBS: + assert job is None + else: + assert job is not None + assert job.status == "queued" + + async def test_get_job_status_nonexistent(self, db_session: AsyncSession): + jm = JobManager() + status = await jm.get_job_status(db_session, "nonexistent") + assert status is None + + async def test_get_job_status_from_db(self, db_session: AsyncSession, conversation, test_user): + job = await ops.create_agent_job( + db_session, conversation.id, test_user.id, "Test", + ) + jm = JobManager() + status = await jm.get_job_status(db_session, job.job_id) + assert status is not None + assert status["job_id"] == job.job_id + assert status["status"] == "queued" + assert "created_at" in status + + async def test_get_active_jobs(self, db_session: 
AsyncSession, conversation, test_user): + job1 = await ops.create_agent_job(db_session, conversation.id, test_user.id, "J1") + job2 = await ops.create_agent_job(db_session, conversation.id, test_user.id, "J2") + await ops.update_job_status(db_session, job2.job_id, "completed") + + jm = JobManager() + active = await jm.get_active_jobs(db_session, conversation.id) + assert len(active) == 1 + assert active[0]["job_id"] == job1.job_id + + async def test_cancel_queued_job(self, db_session: AsyncSession, conversation, test_user): + job = await ops.create_agent_job(db_session, conversation.id, test_user.id, "Test") + jm = JobManager() + cancelled = await jm.cancel_job(db_session, job.job_id) + assert cancelled is True + found = await ops.get_agent_job(db_session, job.job_id) + assert found.status == "cancelled" + + async def test_cancel_nonexistent_job(self, db_session: AsyncSession): + jm = JobManager() + cancelled = await jm.cancel_job(db_session, "nonexistent") + assert cancelled is False + + async def test_cancel_already_running_job(self, db_session: AsyncSession, conversation, test_user): + job = await ops.create_agent_job(db_session, conversation.id, test_user.id, "Test") + await ops.update_job_status(db_session, job.job_id, "running") + jm = JobManager() + cancelled = await jm.cancel_job(db_session, job.job_id) + assert cancelled is False diff --git a/backend/tests/test_llm.py b/backend/tests/test_llm.py new file mode 100644 index 0000000..e636ab1 --- /dev/null +++ b/backend/tests/test_llm.py @@ -0,0 +1,246 @@ +"""Unit tests for LLMProvider — API key resolution, model normalization, tool param conversion, retry logic.""" + +import pytest + +from openmlr.agent.llm import LLMProvider + + +class TestGetApiKey: + @pytest.mark.parametrize("model_name,env_var", [ + ("openai/gpt-4o", "OPENAI_API_KEY"), + ("anthropic/claude-sonnet-4", "ANTHROPIC_API_KEY"), + ("openrouter/openai/gpt-4o", "OPENROUTER_API_KEY"), + ("opencode-go/glm-5.1", "OPENCODE_GO_API_KEY"), + ]) + def 
test_model_prefix_maps_to_env_var(self, monkeypatch, model_name, env_var): + monkeypatch.setenv(env_var, f"test-key-{env_var}") + key = LLMProvider._get_api_key(model_name) + assert key == f"test-key-{env_var}" + + def test_local_model_uses_local_api_key(self, monkeypatch): + monkeypatch.setenv("LOCAL_API_KEY", "local") + key = LLMProvider._get_api_key("local/default") + assert key == "local" + + def test_ollama_defaults_to_not_needed(self, monkeypatch): + monkeypatch.delenv("LOCAL_API_KEY", raising=False) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + key = LLMProvider._get_api_key("ollama/llama3.1") + assert key == "not-needed" + + def test_lmstudio_defaults_to_not_needed(self, monkeypatch): + monkeypatch.delenv("LOCAL_API_KEY", raising=False) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + key = LLMProvider._get_api_key("lmstudio/default") + assert key == "not-needed" + + def test_fallback_to_any_available_key(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-anthro") + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + key = LLMProvider._get_api_key("unknown/model") + assert key == "sk-anthro" + + +class TestNormalizeModel: + @pytest.mark.parametrize("full_name,normalized", [ + ("openai/gpt-4o", "gpt-4o"), + ("anthropic/claude-sonnet-4", "claude-sonnet-4"), + ("openrouter/anthropic/claude-3-sonnet", "anthropic/claude-3-sonnet"), + ("ollama/llama3.1", "llama3.1"), + ("lmstudio/default", "default"), + ("local/custom-model", "custom-model"), + ("opencode-go/glm-5.1", "glm-5.1"), + ]) + def test_normalize_strips_prefix(self, full_name, normalized): + result = LLMProvider._normalize_model(full_name) + assert result == normalized + + def 
test_no_prefix_passes_through(self): + result = LLMProvider._normalize_model("gpt-4o") + assert result == "gpt-4o" + + +class TestGetBaseUrl: + def test_local_base_url(self): + assert LLMProvider._get_base_url("local/default") == "http://localhost:8000/v1" + + def test_ollama_base_url(self): + assert LLMProvider._get_base_url("ollama/llama3") == "http://localhost:11434/v1" + + def test_lmstudio_base_url(self): + assert LLMProvider._get_base_url("lmstudio/default") == "http://localhost:1234/v1" + + def test_openrouter_base_url(self): + assert LLMProvider._get_base_url("openrouter/gpt-4o") == "https://openrouter.ai/api/v1" + + def test_opencode_go_base_url(self): + assert LLMProvider._get_base_url("opencode-go/glm-5.1") == "https://opencode.ai/zen/go/v1" + + def test_standard_provider_no_base_url(self): + assert LLMProvider._get_base_url("openai/gpt-4o") is None + assert LLMProvider._get_base_url("anthropic/claude-4") is None + + def test_local_custom_base_url(self, monkeypatch): + monkeypatch.setenv("LOCAL_API_BASE", "http://custom:9999/v1") + assert LLMProvider._get_base_url("local/model") == "http://custom:9999/v1" + + def test_ollama_custom_base_url(self, monkeypatch): + monkeypatch.setenv("OLLAMA_API_BASE", "http://remote:11435/v1") + assert LLMProvider._get_base_url("ollama/model") == "http://remote:11435/v1" + + def test_lmstudio_custom_base_url(self, monkeypatch): + monkeypatch.setenv("LMSTUDIO_API_BASE", "http://other:1235/v1") + assert LLMProvider._get_base_url("lmstudio/model") == "http://other:1235/v1" + + +class TestAnthropicFormatDetection: + def test_native_anthropic(self): + assert LLMProvider._is_anthropic_model("anthropic/claude-sonnet-4") is True + + def test_openrouter_is_not_anthropic(self): + assert LLMProvider._is_anthropic_model("openrouter/anthropic/claude-4") is False + + def test_opencode_go_deepseek_is_anthropic_format(self): + assert LLMProvider._is_opencode_go_anthropic_format("opencode-go/deepseek-v4-pro") is True + + def 
test_opencode_go_deepseek_flash_is_anthropic_format(self): + assert LLMProvider._is_opencode_go_anthropic_format("opencode-go/deepseek-v4-flash") is True + + def test_opencode_go_minimax_is_anthropic_format(self): + assert LLMProvider._is_opencode_go_anthropic_format("opencode-go/minimax-m2.7") is True + + def test_opencode_go_glm_is_not_anthropic_format(self): + assert LLMProvider._is_opencode_go_anthropic_format("opencode-go/glm-5.1") is False + + def test_uses_anthropic_format_native(self): + assert LLMProvider._uses_anthropic_format("anthropic/claude-4") is True + + def test_uses_anthropic_format_opencode_go_deepseek(self): + assert LLMProvider._uses_anthropic_format("opencode-go/deepseek-v4-pro") is True + + def test_uses_anthropic_format_openai(self): + assert LLMProvider._uses_anthropic_format("openai/gpt-4o") is False + + +class TestToolParamConversion: + def test_openai_tool_param_none(self): + assert LLMProvider._openai_tool_param(None) is None + + def test_openai_tool_param_empty(self): + assert LLMProvider._openai_tool_param([]) is None + + def test_openai_tool_param_raw_format(self): + tools = [ + {"name": "search", "description": "Search web", "parameters": {"type": "object", "properties": {}}}, + ] + result = LLMProvider._openai_tool_param(tools) + assert len(result) == 1 + assert result[0]["type"] == "function" + assert result[0]["function"]["name"] == "search" + + def test_openai_tool_param_already_formatted(self): + tools = [ + {"type": "function", "function": {"name": "bash", "description": "Run cmd", "parameters": {}}}, + ] + result = LLMProvider._openai_tool_param(tools) + assert len(result) == 1 + assert result[0]["type"] == "function" + + def test_anthropic_tool_param_none(self): + assert LLMProvider._anthropic_tool_param(None) is None + + def test_anthropic_tool_param_conversion(self): + tools = [ + { + "type": "function", + "function": { + "name": "read", + "description": "Read file", + "parameters": { + "type": "object", + "properties": 
{"path": {"type": "string"}}, + "required": ["path"], + }, + }, + }, + ] + result = LLMProvider._anthropic_tool_param(tools) + assert len(result) == 1 + assert result[0]["name"] == "read" + assert result[0]["input_schema"]["type"] == "object" + assert "path" in result[0]["input_schema"]["properties"] + assert result[0]["input_schema"]["required"] == ["path"] + + def test_anthropic_tool_param_unwrapped(self): + tools = [ + {"name": "bash", "description": "Run cmd", "parameters": {"type": "object", "properties": {}}}, + ] + result = LLMProvider._anthropic_tool_param(tools) + assert len(result) == 1 + assert result[0]["name"] == "bash" + + +class TestToAnthropicMessages: + def test_separates_system_prompt(self): + messages = [ + {"role": "system", "content": "You are an assistant."}, + {"role": "user", "content": "Hello"}, + ] + system, chat = LLMProvider._to_anthropic_messages(messages) + assert "You are an assistant" in system + assert len(chat) == 1 + assert chat[0]["role"] == "user" + assert chat[0]["content"] == "Hello" + + def test_converts_assistant_with_tool_calls(self): + messages = [ + {"role": "user", "content": "Read file"}, + { + "role": "assistant", + "content": "Let me read that.", + "tool_calls": [ + {"id": "tc1", "name": "read_file", "arguments": {"path": "/tmp/test"}}, + ], + }, + ] + _system, chat = LLMProvider._to_anthropic_messages(messages) + assert len(chat) == 2 + + def test_converts_tool_result(self): + messages = [ + {"role": "tool", "content": "file contents here", "tool_call_id": "tc1"}, + ] + _system, chat = LLMProvider._to_anthropic_messages(messages) + assert len(chat) == 1 + assert chat[0]["role"] == "user" + + +class TestRetryLogic: + def test_is_retryable_rate_limit(self): + assert LLMProvider._is_retryable(Exception("429 Rate limit exceeded")) is True + + def test_is_retryable_timeout(self): + assert LLMProvider._is_retryable(Exception("Connection timeout")) is True + + def test_is_retryable_server_error(self): + assert 
LLMProvider._is_retryable(Exception("server_error 503")) is True + + def test_is_not_retryable_auth_error(self): + assert LLMProvider._is_retryable(Exception("Invalid API key")) is False + + def test_is_not_retryable_validation_error(self): + assert LLMProvider._is_retryable(Exception("Model not found")) is False + + def test_is_retryable_overloaded(self): + assert LLMProvider._is_retryable(Exception("Server is overloaded")) is True + + def test_is_retryable_capacity(self): + assert LLMProvider._is_retryable(Exception("Exceeded capacity")) is True + + def test_is_retryable_502(self): + assert LLMProvider._is_retryable(Exception("502 Bad Gateway")) is True diff --git a/backend/tests/test_models.py b/backend/tests/test_models.py new file mode 100644 index 0000000..c1031ee --- /dev/null +++ b/backend/tests/test_models.py @@ -0,0 +1,186 @@ +"""Tests for Pydantic API models — validation, defaults, serialization.""" + +import pytest +from pydantic import ValidationError + +from openmlr.models import ( + UserRegister, UserLogin, TokenResponse, UserInfo, + ConversationCreate, ConversationResponse, MessageResponse, ConversationDetail, + MessageSend, ApprovalRequest, + SettingUpdate, ProviderConfig, ModelSwitch, AgentEvent, +) + + +class TestUserRegister: + def test_valid(self): + u = UserRegister(username="testuser", password="testpassword123") + assert u.username == "testuser" + assert u.password == "testpassword123" + assert u.display_name is None + + def test_with_display_name(self): + u = UserRegister(username="tester", password="pass123", display_name="Test User") + assert u.display_name == "Test User" + + def test_username_too_short(self): + with pytest.raises(ValidationError): + UserRegister(username="ab", password="password123") + + def test_username_too_long(self): + with pytest.raises(ValidationError): + UserRegister(username="a" * 51, password="password123") + + def test_password_too_short(self): + with pytest.raises(ValidationError): + 
UserRegister(username="testuser", password="12345") + + def test_username_exactly_min(self): + u = UserRegister(username="abc", password="123456") + assert u.username == "abc" + + +class TestUserLogin: + def test_valid(self): + u = UserLogin(username="testuser", password="secret") + assert u.username == "testuser" + assert u.password == "secret" + + +class TestTokenResponse: + def test_defaults(self): + t = TokenResponse(access_token="abc123", user={"id": 1, "username": "test"}) + assert t.access_token == "abc123" + assert t.token_type == "bearer" + assert t.user == {"id": 1, "username": "test"} + + +class TestConversationCreate: + def test_defaults(self): + c = ConversationCreate() + assert c.title == "New conversation" + assert c.model is None + assert c.mode == "general" + + def test_custom(self): + c = ConversationCreate(title="Research Q1", model="gpt-4o", mode="research") + assert c.title == "Research Q1" + assert c.model == "gpt-4o" + assert c.mode == "research" + + +class TestConversationResponse: + def test_creation(self): + from datetime import datetime, timezone + now = datetime.now(timezone.utc) + c = ConversationResponse( + id=1, uuid="abc-def", title="Test Conv", model="gpt-4o", + mode="general", user_message_count=5, + created_at=now, updated_at=now, + ) + assert c.id == 1 + assert c.uuid == "abc-def" + assert c.user_message_count == 5 + + +class TestMessageResponse: + def test_creation(self): + from datetime import datetime, timezone + now = datetime.now(timezone.utc) + m = MessageResponse(id=1, role="user", content="Hello", metadata=None, created_at=now) + assert m.id == 1 + assert m.role == "user" + assert m.content == "Hello" + + def test_with_metadata(self): + from datetime import datetime, timezone + now = datetime.now(timezone.utc) + m = MessageResponse(id=2, role="assistant", content="Hi", metadata={"tool": "search"}, created_at=now) + assert m.metadata == {"tool": "search"} + + +class TestConversationDetail: + def test_creation(self): + from 
datetime import datetime, timezone + now = datetime.now(timezone.utc) + conv = ConversationResponse( + id=1, uuid="x", title="C", model=None, mode="general", + user_message_count=0, created_at=now, updated_at=now, + ) + msgs = [MessageResponse(id=1, role="user", content="Hi", metadata=None, created_at=now)] + cd = ConversationDetail(conversation=conv, messages=msgs) + assert len(cd.messages) == 1 + assert cd.conversation.id == 1 + + +class TestMessageSend: + def test_basic(self): + m = MessageSend(message="Hello world") + assert m.message == "Hello world" + assert m.mode is None + + def test_with_mode(self): + m = MessageSend(message="Research this", mode="research") + assert m.mode == "research" + + +class TestApprovalRequest: + def test_valid(self): + a = ApprovalRequest(approvals={"call_1": True, "call_2": False}) + assert a.approvals == {"call_1": True, "call_2": False} + + def test_empty(self): + a = ApprovalRequest(approvals={}) + assert a.approvals == {} + + +class TestSettingUpdate: + def test_str_value(self): + s = SettingUpdate(value="hello") + assert s.value == "hello" + + def test_int_value(self): + s = SettingUpdate(value=42) + assert s.value == 42 + + def test_bool_value(self): + s = SettingUpdate(value=True) + assert s.value is True + + def test_dict_value(self): + s = SettingUpdate(value={"key": "val"}) + assert s.value == {"key": "val"} + + +class TestProviderConfig: + def test_empty(self): + p = ProviderConfig() + assert p.openai_api_key is None + assert p.anthropic_api_key is None + + def test_with_keys(self): + p = ProviderConfig(openai_api_key="sk-123", brave_api_key="bsk-456") + assert p.openai_api_key == "sk-123" + assert p.brave_api_key == "bsk-456" + + def test_modal_config(self): + p = ProviderConfig(modal_token_id="tid", modal_token_secret="tsec") + assert p.modal_token_id == "tid" + assert p.modal_token_secret == "tsec" + + +class TestModelSwitch: + def test_valid(self): + m = ModelSwitch(model="gpt-4o") + assert m.model == "gpt-4o" + + 
+class TestAgentEventPydantic: + def test_valid(self): + e = AgentEvent(event_type="status", data={"key": "val"}) + assert e.event_type == "status" + assert e.data == {"key": "val"} + + def test_no_data(self): + e = AgentEvent(event_type="ping") + assert e.event_type == "ping" + assert e.data is None diff --git a/backend/tests/test_models_orm.py b/backend/tests/test_models_orm.py new file mode 100644 index 0000000..bd2e5ce --- /dev/null +++ b/backend/tests/test_models_orm.py @@ -0,0 +1,299 @@ +"""Tests for SQLAlchemy ORM models — creation, relationships, constraints.""" + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession + +pytestmark = pytest.mark.asyncio +from sqlalchemy import select + +from openmlr.db.models import ( + User, Conversation, Message, ResearchCorpus, + WritingProject, SandboxConfig, ConversationTask, + ConversationResource, AgentJob, UserSetting, +) +from openmlr.auth.security import hash_password + + +class TestUserModel: + async def test_create_user(self, db_session: AsyncSession): + user = User( + username="newuser", + password_hash=hash_password("password123"), + display_name="New User", + is_active=True, + ) + db_session.add(user) + await db_session.commit() + assert user.id is not None + assert user.username == "newuser" + assert user.is_active is True + assert user.created_at is not None + + async def test_user_unique_username(self, db_session: AsyncSession, test_user): + dup = User( + username="testuser", + password_hash=hash_password("another"), + is_active=True, + ) + db_session.add(dup) + with pytest.raises(Exception): + await db_session.commit() + + async def test_user_relationships_exist(self, db_session: AsyncSession, test_user): + # Async SQLAlchemy requires relationship attributes to exist on the class + assert hasattr(User, "conversations") + assert hasattr(User, "settings") + + +class TestConversationModel: + async def test_create_conversation(self, db_session: AsyncSession, test_user): + conv 
= Conversation( + user_id=test_user.id, + title="Test Conversation", + model="gpt-4o", + mode="research", + ) + db_session.add(conv) + await db_session.commit() + assert conv.id is not None + assert conv.uuid is not None + assert len(conv.uuid) > 0 + assert conv.user_message_count == 0 + assert conv.created_at is not None + assert conv.updated_at is not None + + async def test_conversation_user_relationship(self, db_session: AsyncSession, test_user): + conv = Conversation(user_id=test_user.id, title="Rel Test") + db_session.add(conv) + await db_session.commit() + await db_session.refresh(conv) + assert conv.user_id == test_user.id + + +class TestMessageModel: + @pytest_asyncio.fixture(autouse=True) + async def _conv(self, db_session: AsyncSession, test_user): + conv = Conversation(user_id=test_user.id, title="Msg Model Test") + db_session.add(conv) + await db_session.commit() + self.conv = conv + + async def test_create_message(self, db_session: AsyncSession): + msg = Message( + conversation_id=self.conv.id, + role="user", + content="Hello world", + ) + db_session.add(msg) + await db_session.commit() + assert msg.id is not None + assert msg.role == "user" + assert msg.content == "Hello world" + + async def test_message_with_metadata(self, db_session: AsyncSession): + msg = Message( + conversation_id=self.conv.id, + role="assistant", + content="Done", + meta={"tool": "search", "usage": {"tokens": 100}}, + ) + db_session.add(msg) + await db_session.commit() + assert msg.meta == {"tool": "search", "usage": {"tokens": 100}} + + async def test_message_has_created_at(self, db_session: AsyncSession): + msg = Message(conversation_id=self.conv.id, role="system", content="Init") + db_session.add(msg) + await db_session.commit() + assert msg.created_at is not None + + +class TestConversationTask: + @pytest_asyncio.fixture(autouse=True) + async def _conv(self, db_session: AsyncSession, test_user): + conv = Conversation(user_id=test_user.id, title="Task Model Test") + 
db_session.add(conv) + await db_session.commit() + self.conv = conv + + async def test_create_task(self, db_session: AsyncSession): + task = ConversationTask( + conversation_id=self.conv.id, + title="Test task", + status="pending", + order_index=0, + ) + db_session.add(task) + await db_session.commit() + assert task.id is not None + assert task.title == "Test task" + assert task.status == "pending" + assert task.order_index == 0 + + async def test_task_with_priority(self, db_session: AsyncSession): + task = ConversationTask( + conversation_id=self.conv.id, + title="Urgent", + status="pending", + priority="high", + order_index=0, + ) + db_session.add(task) + await db_session.commit() + assert task.priority == "high" + + +class TestConversationResource: + @pytest_asyncio.fixture(autouse=True) + async def _conv(self, db_session: AsyncSession, test_user): + conv = Conversation(user_id=test_user.id, title="Resource Model Test") + db_session.add(conv) + await db_session.commit() + self.conv = conv + + async def test_create_resource(self, db_session: AsyncSession): + res = ConversationResource( + conversation_id=self.conv.id, + resource_id="res-001", + title="Paper 1", + type="paper", + url="https://arxiv.org/abs/1234.5678", + ) + db_session.add(res) + await db_session.commit() + assert res.id is not None + assert res.type == "paper" + assert res.url == "https://arxiv.org/abs/1234.5678" + + async def test_resource_with_content(self, db_session: AsyncSession): + res = ConversationResource( + conversation_id=self.conv.id, + resource_id="rpt-001", + title="Report", + type="report", + content="## Findings\nKey results...", + ) + db_session.add(res) + await db_session.commit() + assert res.content is not None + + +class TestAgentJob: + @pytest_asyncio.fixture(autouse=True) + async def _conv(self, db_session: AsyncSession, test_user): + conv = Conversation(user_id=test_user.id, title="Job Model Test") + db_session.add(conv) + await db_session.commit() + self.conv = conv + + 
async def test_create_job(self, db_session: AsyncSession): + job = AgentJob( + job_id="job-123", + conversation_id=self.conv.id, + user_id=self.conv.user_id, + message="Process this", + status="queued", + ) + db_session.add(job) + await db_session.commit() + assert job.id is not None + assert job.job_id == "job-123" + assert job.status == "queued" + assert job.created_at is not None + + async def test_job_optional_fields(self, db_session: AsyncSession): + job = AgentJob( + job_id="job-456", + conversation_id=self.conv.id, + user_id=self.conv.user_id, + message="Test", + status="queued", + mode="research", + ) + db_session.add(job) + await db_session.commit() + assert job.mode == "research" + + +class TestUserSetting: + async def test_create_setting(self, db_session: AsyncSession, test_user): + setting = UserSetting( + user_id=test_user.id, + category="agent", + key="default_model", + value="gpt-4o", + ) + db_session.add(setting) + await db_session.commit() + assert setting.id is not None + assert setting.category == "agent" + assert setting.key == "default_model" + assert setting.value == "gpt-4o" + + async def test_setting_uniqueness(self, db_session: AsyncSession, test_user): + s1 = UserSetting(user_id=test_user.id, category="agent", key="model", value="gpt-4o") + db_session.add(s1) + await db_session.commit() + + # Verify we can query the setting back + from sqlalchemy import select + result = await db_session.execute( + select(UserSetting).where( + UserSetting.user_id == test_user.id, + UserSetting.category == "agent", + UserSetting.key == "model", + ) + ) + settings = result.scalars().all() + assert len(settings) == 1 + assert settings[0].value == "gpt-4o" + + +class TestResearchCorpus: + async def test_create_corpus(self, db_session: AsyncSession, test_user): + corpus = ResearchCorpus( + user_id=test_user.id, + paper_id="arxiv-123", + title="Test Paper", + abstract="This is a test abstract", + source="arxiv", + tags=["ml", "nlp"], + ) + db_session.add(corpus) 
+        await db_session.commit()
+        assert corpus.id is not None
+        assert corpus.paper_id == "arxiv-123"
+        assert corpus.source == "arxiv"
+
+
+class TestWritingProject:
+    async def test_create_project(self, db_session: AsyncSession, test_user):
+        project = WritingProject(
+            user_id=test_user.id,
+            title="My Paper",
+            status="drafting",
+            sections={"abstract": ""},
+        )
+        db_session.add(project)
+        await db_session.commit()
+        assert project.id is not None
+        assert project.title == "My Paper"
+        assert project.status == "drafting"
+        assert "abstract" in project.sections
+
+
+class TestSandboxConfig:
+    async def test_create_config(self, db_session: AsyncSession, test_user):
+        config = SandboxConfig(
+            user_id=test_user.id,
+            name="My Modal Sandbox",
+            type="modal",
+            config={"gpu": "a100"},
+            is_default=True,
+        )
+        db_session.add(config)
+        await db_session.commit()
+        assert config.id is not None
+        assert config.type == "modal"
+        assert config.is_default is True
diff --git a/backend/tests/test_prompts.py b/backend/tests/test_prompts.py
new file mode 100644
index 0000000..6ccd6f1
--- /dev/null
+++ b/backend/tests/test_prompts.py
@@ -0,0 +1,81 @@
+"""Tests for system prompt builder."""
+
+import pytest
+
+from openmlr.agent.prompts import build_system_prompt, COMPACT_PROMPT
+from openmlr.agent.types import ToolSpec
+
+
+class TestBuildSystemPrompt:
+    def test_renders_with_tools(self):
+        tools = [
+            ToolSpec(name="read_file", description="Read a file", parameters={"type": "object"}),
+            ToolSpec(name="write_file", description="Write a file", parameters={"type": "object"}),
+        ]
+        prompt = build_system_prompt(tool_specs=tools, mode="general", username="tester")
+        assert isinstance(prompt, str)
+        assert len(prompt) > 0
+        assert "read_file" in prompt
+
+    def test_renders_with_username(self):
+        tools = [ToolSpec(name="test_tool", description="Test", parameters={"type": "object"})]
+        prompt = build_system_prompt(tool_specs=tools, username="alice")
+        assert
isinstance(prompt, str) + assert len(prompt) > 0 + + def test_renders_with_sandbox_info(self): + tools = [ToolSpec(name="bash", description="Run commands", parameters={"type": "object"})] + prompt = build_system_prompt(tool_specs=tools, sandbox_info="local") + assert isinstance(prompt, str) + + def test_renders_with_mode_plan(self): + tools = [ToolSpec(name="ask_user", description="Ask questions", parameters={"type": "object"})] + prompt = build_system_prompt(tool_specs=tools, mode="plan") + assert isinstance(prompt, str) + + def test_renders_with_mode_research(self): + tools = [ToolSpec(name="papers", description="Search papers", parameters={"type": "object"})] + prompt = build_system_prompt(tool_specs=tools, mode="research") + assert isinstance(prompt, str) + + def test_renders_with_config(self): + from openmlr.config import AgentConfig + config = AgentConfig(model_name="test/model", max_iterations=10) + tools = [ToolSpec(name="test", description="Test tool", parameters={"type": "object"})] + prompt = build_system_prompt(tool_specs=tools, config=config) + assert isinstance(prompt, str) + + def test_multiple_tools_appear_in_prompt(self): + tools = [ + ToolSpec(name="tool_a", description="First tool", parameters={"type": "object"}), + ToolSpec(name="tool_b", description="Second tool", parameters={"type": "object"}), + ToolSpec(name="tool_c", description="Third tool", parameters={"type": "object"}), + ] + prompt = build_system_prompt(tool_specs=tools) + assert isinstance(prompt, str) + + def test_empty_tools(self): + prompt = build_system_prompt(tool_specs=[]) + assert isinstance(prompt, str) + assert len(prompt) > 0 + + def test_contains_date_and_time(self): + tools = [ToolSpec(name="t", description="d", parameters={"type": "object"})] + prompt = build_system_prompt(tool_specs=tools) + # Contains date in YYYY-MM-DD format + import re + assert re.search(r"\d{4}-\d{2}-\d{2}", prompt) + + def test_contains_sandbox_info_in_prompt(self): + tools = [ToolSpec(name="bash", 
description="Run", parameters={"type": "object"})] + prompt = build_system_prompt(tool_specs=tools, sandbox_info="SSH remote (4 GPUs)") + assert isinstance(prompt, str) + + +class TestCompactPrompt: + def test_is_string(self): + assert isinstance(COMPACT_PROMPT, str) + assert len(COMPACT_PROMPT) > 0 + + def test_mentions_summary(self): + assert "summary" in COMPACT_PROMPT.lower() diff --git a/backend/tests/test_redis_pubsub.py b/backend/tests/test_redis_pubsub.py new file mode 100644 index 0000000..80fd6d9 --- /dev/null +++ b/backend/tests/test_redis_pubsub.py @@ -0,0 +1,184 @@ +"""Tests for Redis pub/sub — event publishing, answer relay, interrupt signaling.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + + +@pytest.mark.asyncio +class TestPublishEvent: + async def test_publishes_json_to_redis(self): + from openmlr.agent.types import AgentEvent + from openmlr.services.redis_pubsub import publish_event + + mock_redis = AsyncMock() + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + event = AgentEvent(event_type="status", data={"status": "ready"}) + await publish_event(event) + + mock_redis.publish.assert_called_once() + call_args = mock_redis.publish.call_args[0] + assert call_args[0] == "openmlr:events" + assert "status" in call_args[1] + + async def test_handles_redis_error(self): + from openmlr.agent.types import AgentEvent + from openmlr.services.redis_pubsub import publish_event + + mock_redis = AsyncMock() + mock_redis.publish.side_effect = Exception("Redis down") + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + event = AgentEvent(event_type="status") + await publish_event(event) # should not raise + + +@pytest.mark.asyncio +class TestPublishAnswers: + async def test_sets_answers_key_in_redis(self): + from openmlr.services.redis_pubsub import publish_answers + + mock_redis = AsyncMock() + with patch("openmlr.services.redis_pubsub.get_redis", 
return_value=mock_redis): + await publish_answers(conversation_id=42, answers={"q1": "Option A"}) + + assert mock_redis.set.called + key = mock_redis.set.call_args[0][0] + assert "openmlr:answers:" in key + assert "42" in key + + async def test_handles_redis_error(self): + from openmlr.services.redis_pubsub import publish_answers + + mock_redis = AsyncMock() + mock_redis.set.side_effect = Exception("Redis down") + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + await publish_answers(conversation_id=1, answers={"q": "a"}) # should not raise + + +@pytest.mark.asyncio +class TestPublishInterrupt: + async def test_sets_interrupt_key(self): + from openmlr.services.redis_pubsub import publish_interrupt + + mock_redis = AsyncMock() + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + await publish_interrupt(conversation_id=99) + + assert mock_redis.set.called + key = mock_redis.set.call_args[0][0] + assert "openmlr:interrupt:" in key + assert "99" in key + + async def test_uses_60s_expiry(self): + from openmlr.services.redis_pubsub import publish_interrupt + + mock_redis = AsyncMock() + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + await publish_interrupt(conversation_id=1) + + call_kwargs = mock_redis.set.call_args[1] + assert call_kwargs.get("ex") == 60 + + +@pytest.mark.asyncio +class TestCheckInterrupt: + async def test_returns_true_when_key_exists(self): + from openmlr.services.redis_pubsub import check_interrupt + + mock_redis = AsyncMock() + mock_redis.exists.return_value = 1 + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + result = await check_interrupt(conversation_id=5) + + assert result is True + + async def test_returns_false_when_not_found(self): + from openmlr.services.redis_pubsub import check_interrupt + + mock_redis = AsyncMock() + mock_redis.exists.return_value = 0 + with patch("openmlr.services.redis_pubsub.get_redis", 
return_value=mock_redis): + result = await check_interrupt(conversation_id=5) + + assert result is False + + async def test_returns_false_on_redis_error(self): + from openmlr.services.redis_pubsub import check_interrupt + + mock_redis = AsyncMock() + mock_redis.exists.side_effect = Exception("Redis down") + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + result = await check_interrupt(conversation_id=5) + + assert result is False + + +@pytest.mark.asyncio +class TestClearInterrupt: + async def test_deletes_key(self): + from openmlr.services.redis_pubsub import clear_interrupt + + mock_redis = AsyncMock() + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + await clear_interrupt(conversation_id=5) + + assert mock_redis.delete.called + + +@pytest.mark.asyncio +class TestWaitForAnswers: + async def test_returns_answers_when_set(self): + from openmlr.services.redis_pubsub import wait_for_answers + + mock_redis = AsyncMock() + mock_redis.get.return_value = '{"q1": "Option A"}' + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + result = await wait_for_answers(conversation_id=1, timeout=0.5) + + assert result == {"q1": "Option A"} + mock_redis.delete.assert_called_once() + + async def test_returns_none_on_timeout(self): + import time + from openmlr.services.redis_pubsub import wait_for_answers + + mock_redis = AsyncMock() + mock_redis.get.return_value = None # never gets set + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + with patch("openmlr.services.redis_pubsub.asyncio.sleep", return_value=None): + result = await wait_for_answers(conversation_id=1, timeout=0.1) + + assert result is None + + async def test_returns_none_on_redis_error(self): + from openmlr.services.redis_pubsub import wait_for_answers + + mock_redis = AsyncMock() + mock_redis.get.side_effect = Exception("Redis down") + with patch("openmlr.services.redis_pubsub.get_redis", 
return_value=mock_redis):
+            result = await wait_for_answers(conversation_id=1, timeout=0.1)
+
+        assert result is None
+
+
+class TestModuleConstants:
+    def test_channel_name(self):
+        from openmlr.services.redis_pubsub import CHANNEL_NAME
+        assert CHANNEL_NAME == "openmlr:events"
+
+    def test_answers_key_prefix(self):
+        from openmlr.services.redis_pubsub import ANSWERS_KEY_PREFIX
+        assert ANSWERS_KEY_PREFIX == "openmlr:answers:"
+
+    def test_interrupt_key_prefix(self):
+        from openmlr.services.redis_pubsub import INTERRUPT_KEY_PREFIX
+        assert INTERRUPT_KEY_PREFIX == "openmlr:interrupt:"
+
+    def test_redis_url_from_env(self, monkeypatch):
+        from importlib import reload
+        import openmlr.services.redis_pubsub
+        with monkeypatch.context() as patched_env:
+            patched_env.setenv("REDIS_URL", "redis://custom:6379/1")
+            reload(openmlr.services.redis_pubsub)
+            custom_url = openmlr.services.redis_pubsub.REDIS_URL
+        reload(openmlr.services.redis_pubsub)  # env is restored on context exit; re-sync the module
+        assert custom_url == "redis://custom:6379/1"
diff --git a/backend/tests/test_routes_health.py b/backend/tests/test_routes_health.py
new file mode 100644
index 0000000..6af06e6
--- /dev/null
+++ b/backend/tests/test_routes_health.py
@@ -0,0 +1,42 @@
+"""Tests for the health check endpoints."""
+
+import pytest
+from httpx import AsyncClient
+
+pytestmark = pytest.mark.asyncio
+
+
+class TestHealthEndpoints:
+    async def test_api_health_returns_ok(self, client: AsyncClient):
+        resp = await client.get("/api/health")
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["status"] == "ok"
+        assert "version" in data
+        assert "timestamp" in data
+        assert isinstance(data["version"], str)
+
+    async def test_health_returns_ok(self, client: AsyncClient):
+        resp = await client.get("/health")
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["status"] == "ok"
+        assert "version" in data
+        assert "timestamp" in data
+
+    async def test_both_endpoints_return_same_structure(self, client: AsyncClient):
+        resp1 = await
client.get("/api/health")
+        resp2 = await client.get("/health")
+        assert resp1.json()["status"] == resp2.json()["status"]
+        assert resp1.json()["version"] == resp2.json()["version"]
+
+    async def test_health_not_require_auth(self, client: AsyncClient):
+        resp = await client.get("/api/health")
+        assert resp.status_code == 200
+
+    async def test_health_timestamp_is_iso_format(self, client: AsyncClient):
+        resp = await client.get("/health")
+        data = resp.json()
+        import re
+        iso_pattern = r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}"
+        assert re.search(iso_pattern, data["timestamp"])
diff --git a/backend/tests/test_routes_settings.py b/backend/tests/test_routes_settings.py
new file mode 100644
index 0000000..91db400
--- /dev/null
+++ b/backend/tests/test_routes_settings.py
@@ -0,0 +1,146 @@
+"""Tests for settings API routes."""
+
+import os
+import pytest
+from httpx import AsyncClient
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from openmlr.db import operations as ops
+
+pytestmark = pytest.mark.asyncio
+
+
+class TestAgentSettings:
+    async def test_get_all_settings_empty(self, auth_client: AsyncClient):
+        resp = await auth_client.get("/api/settings")
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "settings" in data
+
+    async def test_get_all_settings_after_set(self, auth_client: AsyncClient, db_session: AsyncSession, test_user):
+        await ops.set_user_setting(db_session, test_user.id, "agent", "default_model", "claude")
+        resp = await auth_client.get("/api/settings")
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "agent" in data["settings"]
+        assert data["settings"]["agent"]["default_model"] == "claude"
+
+    async def test_get_settings_category(self, auth_client: AsyncClient, db_session: AsyncSession, test_user):
+        await ops.set_user_setting(db_session, test_user.id, "agent", "yolo_mode", True)
+        resp = await auth_client.get("/api/settings/agent")
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "settings" in data
+        assert data["settings"]["yolo_mode"] is True
+
+    async def test_update_setting(self, auth_client: AsyncClient, db_session: AsyncSession, test_user):
+        resp = await auth_client.put(
+            "/api/settings/agent/test_key",
+            json={"value": "test_value"},
+        )
+        assert resp.status_code == 200
+        assert resp.json() == {"ok": True}
+
+    async def test_update_setting_missing_value(self, auth_client: AsyncClient):
+        resp = await auth_client.put(
+            "/api/settings/agent/test_key",
+            json={},
+        )
+        assert resp.status_code == 400
+
+    async def test_update_setting_with_dict(self, auth_client: AsyncClient):
+        resp = await auth_client.put(
+            "/api/settings/mcp/config",
+            json={"value": {"server": "test", "port": 1234}},
+        )
+        assert resp.status_code == 200
+
+    async def test_delete_setting(self, auth_client: AsyncClient, db_session: AsyncSession, test_user):
+        await ops.set_user_setting(db_session, test_user.id, "agent", "remove_me", "yes")
+        resp = await auth_client.delete("/api/settings/agent/remove_me")
+        assert resp.status_code == 200
+        assert resp.json() == {"ok": True}
+
+    async def test_update_provider_key_sets_env(self, auth_client: AsyncClient, db_session: AsyncSession, test_user, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "")  # registers teardown restore so the key cannot leak
+        resp = await auth_client.put(
+            "/api/settings/providers/openai_api_key", json={"value": "sk-test-key"}
+        )
+        assert resp.status_code == 200
+        assert os.environ["OPENAI_API_KEY"] == "sk-test-key"
+
+    async def test_settings_require_auth(self, client: AsyncClient):
+        resp = await client.get("/api/settings")
+        assert resp.status_code == 401
+
+
+class TestProviders:
+    async def test_list_providers(self, auth_client: AsyncClient):
+        resp = await auth_client.get("/api/providers")
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "providers" in data
+        providers = data["providers"]
+        assert isinstance(providers, list)
+        assert len(providers) > 0
+        provider_ids = [p["id"] for p in providers]
+        assert "openai" in provider_ids
+        assert "anthropic" in provider_ids
+        assert "openrouter" in provider_ids
+
+    async def test_provider_has_required_fields(self, auth_client: AsyncClient):
+        resp = await auth_client.get("/api/providers")
+        data = resp.json()
+        for p in data["providers"]:
+            assert "id" in p
+            assert "name" in p
+            assert "key_env" in p
+            assert "configured" in p
+
+
+class TestAppStatus:
+    async def test_get_status(self, auth_client: AsyncClient):
+        resp = await auth_client.get("/api/status")
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "model" in data
+        assert "research_model" in data
+        assert "yolo_mode" in data
+        assert "needs_onboarding" in data
+
+
+class TestModels:
+    async def test_list_models(self, auth_client: AsyncClient):
+        resp = await auth_client.get("/api/models")
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "models" in data
+        models = data["models"]
+        assert isinstance(models, list)
+        assert len(models) > 0
+
+    async def test_models_have_required_fields(self, auth_client: AsyncClient):
+        resp = await auth_client.get("/api/models")
+        models = resp.json()["models"]
+        for m in models:
+            assert "id" in m
+            assert "name" in m
+            assert "provider" in m
+
+
+class TestConfigEndpoint:
+    async def test_save_config(self, auth_client: AsyncClient, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "")  # NOTE(review): endpoint presumably exports the key; restore at teardown
+        resp = await auth_client.post(
+            "/api/config", json={"OPENAI_API_KEY": "sk-from-config"}
+        )
+        assert resp.status_code == 200
+        assert resp.json() == {"ok": True}
+
+    async def test_save_config_ignores_non_whitelisted(self, auth_client: AsyncClient):
+        resp = await auth_client.post(
+            "/api/config",
+            json={"RANDOM_KEY": "should-be-ignored"},
+        )
+        assert resp.status_code == 200
+        assert "RANDOM_KEY" not in os.environ
diff --git a/backend/tests/test_sandbox_manager.py b/backend/tests/test_sandbox_manager.py
new file mode 100644
index 0000000..ca0a751
--- /dev/null
+++ b/backend/tests/test_sandbox_manager.py
@@ -0,0 +1,54 @@
+"""Tests for SandboxManager — lifecycle management and provider selection."""
+
+import pytest
+from openmlr.sandbox.manager import
SandboxManager + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def manager(): + return SandboxManager() + + +class TestSandboxManager: + async def test_initial_state(self, manager): + assert manager.get_active() is None + assert manager.active_type == "none" + + async def test_create_local(self, manager): + sandbox = await manager.create("local") + assert sandbox is not None + assert manager.active_type == "local" + assert manager.get_active() is sandbox + + async def test_create_then_destroy(self, manager): + sandbox = await manager.create("local") + await manager.destroy() + assert manager.get_active() is None + assert manager.active_type == "none" + + async def test_create_replaces_existing(self, manager): + local1 = await manager.create("local") + local2 = await manager.create("local") + assert manager.get_active() is local2 + + async def test_ensure_local_creates_if_none(self, manager): + sandbox = await manager.ensure_local() + assert sandbox is not None + assert manager.active_type == "local" + + async def test_ensure_local_returns_existing(self, manager): + sandbox1 = await manager.create("local") + sandbox2 = await manager.ensure_local() + assert sandbox1 is sandbox2 + + async def test_create_unknown_provider_raises(self, manager): + with pytest.raises(ValueError, match="Unknown sandbox provider"): + await manager.create("invalid_provider") + + async def test_destroy_when_none_different_provider(self, manager): + await manager.create("local") + # Simulate a provider mismatch — destroy should still work + await manager.destroy() + assert manager.active_type == "none" diff --git a/backend/tests/test_sandbox_types.py b/backend/tests/test_sandbox_types.py new file mode 100644 index 0000000..d4cbd5f --- /dev/null +++ b/backend/tests/test_sandbox_types.py @@ -0,0 +1,139 @@ +"""Tests for sandbox interface types and LocalSandbox.""" + +import tempfile +from pathlib import Path +import pytest + +from openmlr.sandbox.interface import EnvironmentInfo, 
ExecutionResult, SandboxInterface +from openmlr.sandbox.local import LocalSandbox +from openmlr.sandbox.manager import SandboxManager + + +class TestEnvironmentInfo: + def test_defaults(self): + info = EnvironmentInfo() + assert info.os == "unknown" + assert info.python_version == "unknown" + assert info.gpu_available is False + assert info.gpu_info is None + assert info.installed_packages == [] + assert info.available_disk_gb == 0.0 + assert info.available_ram_gb == 0.0 + + def test_custom_values(self): + info = EnvironmentInfo( + os="Linux", + python_version="3.12.0", + gpu_available=True, + gpu_info="NVIDIA A100", + installed_packages=["torch", "numpy"], + available_disk_gb=50.0, + available_ram_gb=32.0, + ) + assert info.os == "Linux" + assert info.gpu_available is True + assert "torch" in info.installed_packages + + +class TestExecutionResult: + def test_defaults(self): + r = ExecutionResult(output="done", success=True) + assert r.output == "done" + assert r.success is True + assert r.exit_code == 0 + assert r.duration_seconds == 0.0 + + def test_failure(self): + r = ExecutionResult(output="error", success=False, exit_code=1, duration_seconds=2.5) + assert r.success is False + assert r.exit_code == 1 + assert r.duration_seconds == 2.5 + + def test_truncation_handled_by_caller(self): + # Truncation is done by the tools, not the dataclass + r = ExecutionResult(output="x" * 100000, success=True) + assert len(r.output) == 100000 + + +class TestSandboxInterface: + def test_is_abstract(self): + with pytest.raises(TypeError): + SandboxInterface() + + +@pytest.mark.asyncio +class TestLocalSandbox: + async def test_create_default(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + sb = LocalSandbox() + await sb.create({}) + assert sb.workdir == str(tmp_path) + + async def test_create_with_workdir(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + sb = LocalSandbox() + await sb.create({"workdir": str(tmp_path)}) + assert sb.workdir == str(tmp_path) 
+ + async def test_write_and_read_file(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + sb = LocalSandbox(str(tmp_path)) + await sb.create({}) + f = tmp_path / "test.txt" + ok = await sb.write_file(str(f), "hello") + assert ok is True + content = await sb.read_file("test.txt") + assert content == "hello" + + async def test_file_exists(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + sb = LocalSandbox(str(tmp_path)) + await sb.create({}) + f = tmp_path / "exists.txt" + f.write_text("data") + assert await sb.file_exists("exists.txt") is True + assert await sb.file_exists("nope.txt") is False + + async def test_edit_file(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + sb = LocalSandbox(str(tmp_path)) + await sb.create({}) + f = tmp_path / "edit.txt" + f.write_text("old text here") + ok = await sb.edit_file("edit.txt", "old", "new") + assert ok is True + assert f.read_text() == "new text here" + + async def test_edit_nonexistent(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + sb = LocalSandbox(str(tmp_path)) + await sb.create({}) + # edit_file tries to read the file first, which raises FileNotFoundError + with pytest.raises(FileNotFoundError): + await sb.edit_file("nope.txt", "a", "b") + + async def test_list_files(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + sb = LocalSandbox(str(tmp_path)) + await sb.create({}) + (tmp_path / "a.txt").write_text("a") + (tmp_path / "b.txt").write_text("b") + (tmp_path / "subdir").mkdir() + files = await sb.list_files(".") + assert "a.txt" in files + assert "b.txt" in files + assert "subdir/" in files + + async def test_destroy_is_noop(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + sb = LocalSandbox() + await sb.create({}) + await sb.destroy() + + async def test_execute_simple_command(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + sb = LocalSandbox(str(tmp_path)) + await sb.create({}) + result = await sb.execute("echo hello", timeout=10) 
+ assert result.success is True + assert "hello" in result.output diff --git a/backend/tests/test_session_manager.py b/backend/tests/test_session_manager.py new file mode 100644 index 0000000..6990d70 --- /dev/null +++ b/backend/tests/test_session_manager.py @@ -0,0 +1,99 @@ +"""Tests for SessionManager — multi-session lifecycle and message queuing.""" + +import pytest +from openmlr.services.session_manager import SessionManager, ActiveSession + +pytestmark = pytest.mark.asyncio +from openmlr.services.event_bus import EventBus +from openmlr.config import AgentConfig + + +@pytest.fixture +def config(): + return AgentConfig( + model_name="test/model", + max_iterations=10, + stream=False, + yolo_mode=False, + ) + + +@pytest.fixture +def event_bus(): + return EventBus() + + +@pytest.fixture +def session_manager(event_bus, config): + return SessionManager(event_bus=event_bus, default_config=config) + + +class TestSessionManager: + async def test_initial_state(self, session_manager): + assert session_manager.current_conversation_id is None + assert session_manager.is_processing is False + assert session_manager.get_current_session() is None + + async def test_get_session_nonexistent(self, session_manager): + assert session_manager.get_session(999) is None + + async def test_get_or_create_session(self, session_manager): + active = await session_manager.get_or_create_session( + conversation_id=1, + uuid="test-uuid-1", + model="test/model", + mode="general", + username="tester", + ) + assert active is not None + assert active.conversation_id == 1 + assert active.uuid == "test-uuid-1" + assert active.session is not None + assert active.tool_router is not None + + async def test_get_or_create_returns_existing(self, session_manager): + s1 = await session_manager.get_or_create_session(1, "u1") + s2 = await session_manager.get_or_create_session(1, "u1") + assert s1 is s2 + + async def test_get_session_after_create(self, session_manager): + await 
session_manager.get_or_create_session(1, "u1") + s = session_manager.get_session(1) + assert s is not None + assert s.conversation_id == 1 + + async def test_remove_session(self, session_manager): + await session_manager.get_or_create_session(1, "u1") + await session_manager.remove_session(1) + assert session_manager.get_session(1) is None + + async def test_remove_nonexistent_session(self, session_manager): + await session_manager.remove_session(999) + + async def test_multiple_sessions(self, session_manager): + s1 = await session_manager.get_or_create_session(1, "u1") + s2 = await session_manager.get_or_create_session(2, "u2") + assert s1.conversation_id == 1 + assert s2.conversation_id == 2 + assert s1 is not s2 + + async def test_session_loads_existing_messages(self, session_manager): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + active = await session_manager.get_or_create_session( + 1, "u1", existing_messages=messages, + ) + msgs = active.session.context_manager.get_messages() + assert len(msgs) >= 2 # includes system prompt + existing messages + + async def test_generate_title_no_session(self, session_manager): + title = await session_manager.generate_title(999, []) + assert title is None + + async def test_remove_session_clears_current(self, session_manager): + await session_manager.get_or_create_session(42, "u42") + session_manager.current_conversation_id = 42 + await session_manager.remove_session(42) + assert session_manager.current_conversation_id is None diff --git a/backend/tests/test_tool_registry.py b/backend/tests/test_tool_registry.py new file mode 100644 index 0000000..2f835ef --- /dev/null +++ b/backend/tests/test_tool_registry.py @@ -0,0 +1,233 @@ +"""Tests for ToolRouter — registration, dispatch, mode filtering.""" + +import pytest +from openmlr.tools.registry import ToolRouter, MODE_TOOL_RESTRICTIONS + +pytestmark = pytest.mark.asyncio +from openmlr.agent.types import ToolSpec, 
ToolCall + + +@pytest.fixture +def router(): + return ToolRouter() + + +@pytest.fixture +def dummy_tool(): + async def handler(arg1: str = "default") -> tuple[str, bool]: + return f"handled: {arg1}", True + + return ToolSpec( + name="dummy", + description="A dummy tool", + parameters={ + "type": "object", + "properties": { + "arg1": {"type": "string", "description": "Test arg"}, + }, + "required": ["arg1"], + }, + handler=handler, + ) + + +@pytest.fixture +def read_tool(): + async def handler(path: str) -> tuple[str, bool]: + return f"read: {path}", True + + return ToolSpec( + name="read_file", + description="Read a file", + parameters={ + "type": "object", + "properties": { + "path": {"type": "string"}, + }, + "required": ["path"], + }, + handler=handler, + ) + + +@pytest.fixture +def bash_tool(): + async def handler(cmd: str) -> tuple[str, bool]: + return f"executed: {cmd}", True + + return ToolSpec( + name="bash", + description="Execute a command", + parameters={ + "type": "object", + "properties": { + "cmd": {"type": "string"}, + }, + "required": ["cmd"], + }, + handler=handler, + ) + + +class TestToolRegistration: + async def test_register_single(self, router, dummy_tool): + router.register(dummy_tool) + assert "dummy" in router.tools + + async def test_register_many(self, router, dummy_tool, read_tool): + router.register_many([dummy_tool, read_tool]) + assert len(router.tools) == 2 + assert "dummy" in router.tools + assert "read_file" in router.tools + + async def test_get_tool(self, router, dummy_tool): + router.register(dummy_tool) + tool = router.get_tool("dummy") + assert tool is not None + assert tool.name == "dummy" + + async def test_get_tool_not_found(self, router): + assert router.get_tool("nonexistent") is None + + async def test_register_blocked_tool(self, router, dummy_tool): + router._blocklist = {"dummy"} + router.register(dummy_tool) + assert "dummy" not in router.tools + + +class TestToolDispatch: + async def test_call_tool_simple(self, 
router, dummy_tool): + router.register(dummy_tool) + result, success = await router.call_tool("dummy", {"arg1": "hello"}) + assert success is True + assert "handled: hello" in result + + async def test_call_tool_unknown(self, router): + result, success = await router.call_tool("unknown", {}) + assert success is False + assert "Unknown tool" in result + + async def test_call_tool_default_args(self, router, dummy_tool): + router.register(dummy_tool) + result, success = await router.call_tool("dummy", {}) + assert success is True + assert "handled: default" in result + + async def test_call_tool_type_error(self, router): + async def handler(required_arg: str) -> tuple[str, bool]: + return "ok", True + + tool = ToolSpec( + name="strict", description="Needs arg", parameters={"type": "object"}, + handler=handler, + ) + router.register(tool) + result, success = await router.call_tool("strict", {"wrong_param": "val"}) + assert success is False + assert "Tool argument error" in result + + +class TestToolSpecsForLLM: + async def test_get_specs_in_openai_format(self, router, dummy_tool, read_tool): + router.register(dummy_tool) + router.register(read_tool) + specs = router.get_tool_specs_for_llm() + assert len(specs) == 2 + for spec in specs: + assert spec["type"] == "function" + assert "function" in spec + assert "name" in spec["function"] + assert "description" in spec["function"] + assert "parameters" in spec["function"] + + async def test_get_raw_specs(self, router, dummy_tool, read_tool): + router.register(dummy_tool) + router.register(read_tool) + raw = router.get_raw_specs() + assert len(raw) == 2 + assert all(isinstance(s, ToolSpec) for s in raw) + + +class TestModeFiltering: + async def test_default_mode_allows_all(self, router, dummy_tool, bash_tool): + router.register(dummy_tool) + router.register(bash_tool) + allowed, msg = router.is_tool_allowed("bash") + assert allowed is True + assert msg == "" + + async def test_plan_mode_blocks_bash(self, router, bash_tool): 
+ router.register(bash_tool) + router.set_mode("plan") + allowed, msg = router.is_tool_allowed("bash") + assert allowed is False + assert "PLAN mode" in msg + + async def test_plan_mode_allows_ask_user(self, router): + router.set_mode("plan") + allowed, msg = router.is_tool_allowed("ask_user") + assert allowed is True + + async def test_plan_mode_allows_read_file(self, router): + router.set_mode("plan") + allowed, msg = router.is_tool_allowed("read_file") + assert allowed is True + + async def test_execute_mode_blocks_ask_user(self, router): + router.set_mode("execute") + allowed, msg = router.is_tool_allowed("ask_user") + assert allowed is False + assert "EXECUTE mode" in msg + + async def test_execute_mode_allows_bash(self, router, bash_tool): + router.register(bash_tool) + router.set_mode("execute") + allowed, msg = router.is_tool_allowed("bash") + assert allowed is True + + async def test_mode_filtering_in_get_specs(self, router, dummy_tool, bash_tool, read_tool): + router.register(dummy_tool) + router.register(bash_tool) + router.register(read_tool) + router.set_mode("plan") + specs = router.get_tool_specs_for_llm(filter_by_mode=True) + spec_names = [s["function"]["name"] for s in specs] + assert "bash" not in spec_names + + async def test_no_filter_in_get_specs(self, router, dummy_tool, bash_tool): + router.register(dummy_tool) + router.register(bash_tool) + router.set_mode("plan") + specs = router.get_tool_specs_for_llm(filter_by_mode=False) + spec_names = [s["function"]["name"] for s in specs] + assert "bash" in spec_names + assert "dummy" in spec_names + + async def test_call_tool_mode_enforcement(self, router, bash_tool): + router.register(bash_tool) + router.set_mode("plan") + result, success = await router.call_tool("bash", {"cmd": "ls"}, enforce_mode=True) + assert success is False + assert "MODE VIOLATION" in result + + async def test_get_mode(self, router): + assert router.get_mode() == "general" + router.set_mode("research") + assert 
router.get_mode() == "research" + + +class TestModeRestrictionsConfig: + async def test_plan_has_allowed_list(self): + assert "allowed" in MODE_TOOL_RESTRICTIONS["plan"] + + async def test_execute_has_blocked_list(self): + assert "blocked" in MODE_TOOL_RESTRICTIONS["execute"] + + async def test_plan_allowed_includes_ask_user(self): + assert "ask_user" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + + async def test_plan_allowed_includes_read_file(self): + assert "read_file" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + + async def test_execute_blocked_includes_ask_user(self): + assert "ask_user" in MODE_TOOL_RESTRICTIONS["execute"]["blocked"] diff --git a/backend/tests/test_tools_github.py b/backend/tests/test_tools_github.py new file mode 100644 index 0000000..3b6a5e7 --- /dev/null +++ b/backend/tests/test_tools_github.py @@ -0,0 +1,55 @@ +"""Tests for GitHub tools — tool specs and helper functions.""" + +import pytest +from openmlr.tools.github import create_github_tools, _headers + +pytestmark = pytest.mark.asyncio + + +class TestCreateGithubTools: + async def test_creates_all_tools(self): + tools = create_github_tools() + names = [t.name for t in tools] + assert "github_read_file" in names + assert "github_list_repos" in names + assert "github_find_examples" in names + assert len(tools) == 3 + + async def test_read_file_required_params(self): + tools = create_github_tools() + read = [t for t in tools if t.name == "github_read_file"][0] + required = read.parameters["required"] + assert "owner" in required + assert "repo" in required + assert "path" in required + + async def test_list_repos_required_params(self): + tools = create_github_tools() + list_tool = [t for t in tools if t.name == "github_list_repos"][0] + assert "owner" in list_tool.parameters["required"] + + async def test_find_examples_required_params(self): + tools = create_github_tools() + find = [t for t in tools if t.name == "github_find_examples"][0] + assert "query" in find.parameters["required"] 
+ + async def test_find_examples_language_filter(self): + tools = create_github_tools() + find = [t for t in tools if t.name == "github_find_examples"][0] + assert "language" in find.parameters["properties"] + + +class TestHeaders: + async def test_headers_accept(self): + h = _headers() + assert h["Accept"] == "application/vnd.github+json" + + async def test_headers_without_token(self, monkeypatch): + monkeypatch.delenv("GITHUB_TOKEN", raising=False) + h = _headers() + assert "Authorization" not in h + + async def test_headers_with_token(self, monkeypatch): + monkeypatch.setenv("GITHUB_TOKEN", "ghp_test123") + h = _headers() + assert h["Authorization"] == "Bearer ghp_test123" diff --git a/backend/tests/test_tools_local.py b/backend/tests/test_tools_local.py new file mode 100644 index 0000000..575b7b6 --- /dev/null +++ b/backend/tests/test_tools_local.py @@ -0,0 +1,219 @@ +"""Tests for local tools — bash, read, write, edit, and path validation.""" + +import os +import tempfile +from pathlib import Path +import pytest + +from openmlr.tools.local import ( + create_local_tools, _validate_path, _handle_read, _handle_write, _handle_edit, + DOCKER_IMAGE, CONTAINER_PREFIX, ALLOW_DIRECT_EXEC, WORKSPACE_ROOT, +) + + +class TestCreateLocalTools: + def test_creates_all_tools(self): + tools = create_local_tools() + names = [t.name for t in tools] + assert "bash" in names + assert "read" in names + assert "write" in names + assert "edit" in names + assert len(tools) == 4 + + def test_bash_tool_spec(self): + tools = create_local_tools() + bash = [t for t in tools if t.name == "bash"][0] + assert "command" in bash.parameters["properties"] + assert "timeout" in bash.parameters["properties"] + assert "workdir" in bash.parameters["properties"] + assert "command" in bash.parameters["required"] + + def test_read_tool_spec(self): + tools = create_local_tools() + read_tool = [t for t in tools if t.name == "read"][0] + assert read_tool.handler is not None + assert "path" in 
read_tool.parameters["required"] + + def test_write_tool_spec(self): + tools = create_local_tools() + write_tool = [t for t in tools if t.name == "write"][0] + assert "path" in write_tool.parameters["required"] + assert "content" in write_tool.parameters["required"] + + def test_edit_tool_spec(self): + tools = create_local_tools() + edit_tool = [t for t in tools if t.name == "edit"][0] + required = edit_tool.parameters["required"] + assert "path" in required + assert "old_string" in required + assert "new_string" in required + + +class TestValidatePath: + def test_resolves_relative_path(self): + cwd = os.getcwd() + path = Path(".", "test_file.txt") + resolved, error = _validate_path(path) + assert error is None + assert resolved.is_absolute() + + def test_path_within_cwd_is_allowed(self): + path = Path.cwd() / "test" / "file.py" + resolved, error = _validate_path(path) + assert error is None + + def test_blocked_system_path(self, monkeypatch): + monkeypatch.setattr("openmlr.tools.local.WORKSPACE_ROOT", "/home/user/projects") + path = Path("/etc", "passwd") + resolved, error = _validate_path(path) + assert error is not None + assert "outside workspace" in error + + def test_blocked_root_path(self, monkeypatch): + monkeypatch.setattr("openmlr.tools.local.WORKSPACE_ROOT", "/home/user/projects") + path = Path("/root", "secret") + resolved, error = _validate_path(path) + assert error is not None + + def test_blocked_var_path(self, monkeypatch): + monkeypatch.setattr("openmlr.tools.local.WORKSPACE_ROOT", "/home/user/projects") + path = Path("/var", "log") + resolved, error = _validate_path(path) + assert error is not None + + def test_blocked_bin_path(self): + path = Path("/bin", "bash") + resolved, error = _validate_path(path) + assert error is not None + + def test_with_workspace_root_set(self, monkeypatch): + workspace = str(Path.cwd() / "workspace") + monkeypatch.setattr("openmlr.tools.local.WORKSPACE_ROOT", workspace) + within = Path(workspace) / "file.txt" + 
resolved, error = _validate_path(within) + assert error is None + + +@pytest.mark.asyncio +class TestHandleRead: + async def test_reads_file_with_line_numbers(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + f = tmp_path / "test.txt" + f.write_text("line one\nline two\nline three\n") + result, success = await _handle_read("test.txt") + assert success is True + assert "1: line one" in result + assert "2: line two" in result + assert "3: line three" in result + + async def test_read_with_offset_and_limit(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + f = tmp_path / "test.txt" + f.write_text("a\nb\nc\nd\ne\n") + result, success = await _handle_read("test.txt", offset=3, limit=2) + assert success is True + assert "3: c" in result + assert "4: d" in result + assert "5: e" not in result + + async def test_read_nonexistent_file(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + result, success = await _handle_read("nonexistent_test_file.txt") + assert success is False + assert "File not found" in result + + async def test_read_directory(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + d = tmp_path / "testdir" + d.mkdir() + (d / "file1.txt").write_text("x") + (d / "file2.txt").write_text("y") + result, success = await _handle_read("testdir") + assert success is True + assert "file1.txt" in result + assert "file2.txt" in result + + +@pytest.mark.asyncio +class TestHandleWrite: + async def test_writes_file(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + f = tmp_path / "output.txt" + result, success = await _handle_write("output.txt", content="hello world") + assert success is True + assert "Wrote" in result + assert f.read_text() == "hello world" + + async def test_creates_parent_directories(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + f = tmp_path / "deeply" / "nested" / "dir" / "file.txt" + result, success = await _handle_write("deeply/nested/dir/file.txt", content="deep") + assert success 
is True + assert f.read_text() == "deep" + + async def test_requires_path(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + result, success = await _handle_write(path="", content="test") + assert success is False + + async def test_requires_content(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + result, success = await _handle_write(path="test.txt", content="") + assert success is False + + +@pytest.mark.asyncio +class TestHandleEdit: + async def test_replaces_string(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + f = tmp_path / "edit_test.txt" + f.write_text("Hello World") + result, success = await _handle_edit("edit_test.txt", "World", "Universe") + assert success is True + assert "Replaced" in result + assert f.read_text() == "Hello Universe" + + async def test_old_string_not_found(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + f = tmp_path / "edit_test.txt" + f.write_text("Hello World") + result, success = await _handle_edit("edit_test.txt", "Mars", "Earth") + assert success is False + assert "old_string not found" in result + + async def test_multiple_matches_without_replace_all(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + f = tmp_path / "edit_test.txt" + f.write_text("hello hello hello") + result, success = await _handle_edit("edit_test.txt", "hello", "hi") + assert success is False + assert "Found 3 matches" in result + + async def test_replace_all(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + f = tmp_path / "edit_test.txt" + f.write_text("hello hello hello") + result, success = await _handle_edit("edit_test.txt", "hello", "hi", replace_all=True) + assert success is True + assert f.read_text() == "hi hi hi" + + async def test_nonexistent_file(self, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + result, success = await _handle_edit("nonexistent_edit_test.txt", "a", "b") + assert success is False + assert "File not found" in result + + +class TestModuleConstants: + 
def test_docker_image_default(self): + assert DOCKER_IMAGE == "python:3.12-slim" + + def test_container_prefix(self): + assert CONTAINER_PREFIX == "openmlr-sandbox" + + def test_allow_direct_exec_default(self, monkeypatch): + monkeypatch.delenv("OPENMLR_ALLOW_DIRECT_EXEC", raising=False) + import openmlr.tools.local + allow = openmlr.tools.local.ALLOW_DIRECT_EXEC + assert allow is False diff --git a/backend/tests/test_tools_mcp.py b/backend/tests/test_tools_mcp.py new file mode 100644 index 0000000..0cddbee --- /dev/null +++ b/backend/tests/test_tools_mcp.py @@ -0,0 +1,64 @@ +"""Tests for MCP tools — env var substitution and config processing.""" + +import pytest +from openmlr.tools.mcp import substitute_env_vars, process_mcp_config + +pytestmark = pytest.mark.asyncio + + +class TestSubstituteEnvVars: + async def test_replaces_var(self, monkeypatch): + monkeypatch.setenv("TEST_KEY", "test-value") + result = substitute_env_vars("prefix-${TEST_KEY}-suffix") + assert result == "prefix-test-value-suffix" + + async def test_multiple_vars(self, monkeypatch): + monkeypatch.setenv("A", "aaa") + monkeypatch.setenv("B", "bbb") + result = substitute_env_vars("${A}_${B}") + assert result == "aaa_bbb" + + async def test_unknown_var_unchanged(self, monkeypatch): + monkeypatch.delenv("UNKNOWN_VAR", raising=False) + result = substitute_env_vars("${UNKNOWN_VAR}") + assert result == "${UNKNOWN_VAR}" + + async def test_no_vars(self): + result = substitute_env_vars("plain text no vars") + assert result == "plain text no vars" + + async def test_partial_match(self): + result = substitute_env_vars("$TEST no match") + assert result == "$TEST no match" + + +class TestProcessMCPConfig: + async def test_simple_dict(self, monkeypatch): + monkeypatch.setenv("API_KEY", "secret123") + config = {"url": "https://api.example.com?key=${API_KEY}"} + result = process_mcp_config(config) + assert result["url"] == "https://api.example.com?key=secret123" + + async def test_nested_dict(self, monkeypatch): 
+ monkeypatch.setenv("HOST", "localhost") + config = {"server": {"host": "${HOST}", "port": 8080}} + result = process_mcp_config(config) + assert result["server"]["host"] == "localhost" + assert result["server"]["port"] == 8080 + + async def test_list_values(self, monkeypatch): + monkeypatch.setenv("ARG1", "value1") + monkeypatch.setenv("ARG2", "value2") + config = {"args": ["${ARG1}", "${ARG2}", "literal"]} + result = process_mcp_config(config) + assert result["args"] == ["value1", "value2", "literal"] + + async def test_non_string_values(self): + config = {"timeout": 30, "enabled": True, "null_val": None} + result = process_mcp_config(config) + assert result["timeout"] == 30 + assert result["enabled"] is True + assert result["null_val"] is None + + async def test_empty_config(self): + assert process_mcp_config({}) == {} diff --git a/backend/tests/test_tools_papers.py b/backend/tests/test_tools_papers.py new file mode 100644 index 0000000..37b8803 --- /dev/null +++ b/backend/tests/test_tools_papers.py @@ -0,0 +1,98 @@ +"""Tests for papers tool — helper functions and tool spec.""" + +import pytest +from openmlr.tools.papers import ( + create_papers_tool, _to_openalex_id, _extract_arxiv_id, + _extract_arxiv_from_ids, _reconstruct_abstract, _check_budget, + _increment_budget, _get_budget_info, +) + +pytestmark = pytest.mark.asyncio + + +class TestCreatePapersTool: + async def test_creates_tool(self): + tool = create_papers_tool() + assert tool.name == "papers" + assert tool.handler is not None + assert "operation" in tool.parameters["required"] + ops = tool.parameters["properties"]["operation"]["enum"] + assert "search" in ops + assert "trending" in ops + assert "details" in ops + assert "read_paper" in ops + assert "citations" in ops + assert "recommend" in ops + assert "find_code" in ops + assert "find_datasets" in ops + + +class TestExtractArxivId: + async def test_standard_format(self): + assert _extract_arxiv_id("2301.12345") == "2301.12345" + + async def 
test_with_version(self): + assert _extract_arxiv_id("2301.12345v2") == "2301.12345v2" + + async def test_from_url(self): + assert _extract_arxiv_id("https://arxiv.org/abs/2301.12345") == "2301.12345" + + async def test_from_pdf_url(self): + assert _extract_arxiv_id("https://arxiv.org/pdf/2301.12345.pdf") == "2301.12345" + + async def test_no_arxiv_id(self): + assert _extract_arxiv_id("some text without id") is None + + async def test_from_doi_with_arxiv(self): + assert _extract_arxiv_id("10.48550/arXiv.2301.12345") == "2301.12345" + + +class TestToOpenAlexId: + async def test_openalex_id(self): + assert _to_openalex_id("W123456") == "W123456" + + async def test_openalex_url(self): + assert _to_openalex_id("https://openalex.org/W123456") == "https://openalex.org/W123456" + + async def test_doi(self): + assert _to_openalex_id("10.1234/foo.bar") == "https://doi.org/10.1234/foo.bar" + + async def test_arxiv_id(self): + result = _to_openalex_id("2301.12345") + assert "doi.org" in result + assert "arXiv" in result + + +class TestReconstructAbstract: + async def test_simple(self): + inv = {"hello": [0], "world": [1]} + assert _reconstruct_abstract(inv) == "hello world" + + async def test_empty(self): + assert _reconstruct_abstract({}) is None + assert _reconstruct_abstract(None) is None + + async def test_multiple_positions(self): + inv = {"the": [0, 3], "cat": [1], "sat": [2], "mat": [4]} + result = _reconstruct_abstract(inv) + assert result == "the cat sat the mat" + + +class TestBudgetFunctions: + async def test_check_budget_allows_first_call(self): + ok, msg = _check_budget() + assert ok is True + assert msg == "" + + async def test_get_budget_info(self): + info = _get_budget_info() + assert "used" in info + assert "max" in info + + async def test_increment_and_check(self): + from openmlr.tools.papers import _search_counts + _search_counts.clear() + _increment_budget() + info = _get_budget_info() + assert info["used"] >= 1 + _search_counts.clear() diff --git 
a/backend/tests/test_tools_research.py b/backend/tests/test_tools_research.py new file mode 100644 index 0000000..3577f82 --- /dev/null +++ b/backend/tests/test_tools_research.py @@ -0,0 +1,62 @@ +"""Tests for research tool — tool specs and the research sub-agent dispatching.""" + +import pytest +from openmlr.tools.research import ( + create_research_tool, _get_research_tool_specs, _execute_research_tool, + RESEARCH_SYSTEM_PROMPT, MAX_RESEARCH_ITERATIONS, +) +from openmlr.agent.types import ToolCall + + +class TestCreateResearchTool: + def test_creates_tool(self): + tool = create_research_tool() + assert tool.name == "research" + assert tool.handler is not None + assert "query" in tool.parameters["required"] + + def test_focus_enum(self): + tool = create_research_tool() + focus = tool.parameters["properties"]["focus"] + assert focus["enum"] == ["papers", "code", "docs", "general"] + + +class TestGetResearchToolSpecs: + def test_returns_formatted_tools(self): + specs = _get_research_tool_specs() + assert len(specs) > 0 + for spec in specs: + assert spec["type"] == "function" + assert "function" in spec + assert "name" in spec["function"] + + def test_includes_search(self): + specs = _get_research_tool_specs() + names = [s["function"]["name"] for s in specs] + assert "web_search" in names + + def test_includes_papers(self): + specs = _get_research_tool_specs() + names = [s["function"]["name"] for s in specs] + assert "papers" in names + + def test_github_read_included(self): + specs = _get_research_tool_specs() + names = [s["function"]["name"] for s in specs] + assert "github_read_file" in names or "github_find_examples" in names + + +class TestExecuteResearchTool: + @pytest.mark.asyncio + async def test_unknown_tool(self): + tc = ToolCall(id="tc1", name="unknown_tool", arguments={}) + result, success = await _execute_research_tool(tc) + assert success is False + assert "not available" in result + + def test_system_prompt_not_empty(self): + assert 
len(RESEARCH_SYSTEM_PROMPT) > 0 + assert "research sub-agent" in RESEARCH_SYSTEM_PROMPT.lower() + + def test_max_iterations(self): + assert MAX_RESEARCH_ITERATIONS == 60 diff --git a/backend/tests/test_tools_sandbox.py b/backend/tests/test_tools_sandbox.py new file mode 100644 index 0000000..e05a486 --- /dev/null +++ b/backend/tests/test_tools_sandbox.py @@ -0,0 +1,64 @@ +"""Tests for sandbox tools — probe, create, exec, read, write.""" + +import pytest +from openmlr.tools.sandbox_tools import create_sandbox_tools, _handle_probe + +pytestmark = pytest.mark.asyncio +from openmlr.sandbox.manager import SandboxManager + + +class TestCreateSandboxTools: + async def test_creates_all_tools(self): + mgr = SandboxManager() + tools = create_sandbox_tools(mgr) + names = [t.name for t in tools] + assert "sandbox_probe" in names + assert "sandbox_create" in names + assert "sandbox_exec" in names + assert "sandbox_read" in names + assert "sandbox_write" in names + assert len(tools) == 5 + + async def test_sandbox_create_needs_approval(self): + mgr = SandboxManager() + tools = create_sandbox_tools(mgr) + create = [t for t in tools if t.name == "sandbox_create"][0] + assert create.needs_approval is not None + assert create.needs_approval({"provider": "modal"}) is True + + async def test_sandbox_exec_requires_command(self): + mgr = SandboxManager() + tools = create_sandbox_tools(mgr) + exec_tool = [t for t in tools if t.name == "sandbox_exec"][0] + assert "command" in exec_tool.parameters["required"] + + async def test_sandbox_probe_no_params(self): + mgr = SandboxManager() + tools = create_sandbox_tools(mgr) + probe = [t for t in tools if t.name == "sandbox_probe"][0] + assert probe.parameters["properties"] == {} + + async def test_sandbox_create_provider_enum(self): + mgr = SandboxManager() + tools = create_sandbox_tools(mgr) + create = [t for t in tools if t.name == "sandbox_create"][0] + providers = create.parameters["properties"]["provider"]["enum"] + assert "local" in 
providers + assert "ssh" in providers + assert "modal" in providers + + +class TestHandleProbe: + async def test_probe_without_sandbox(self): + mgr = SandboxManager() + result, success = await _handle_probe(mgr) + assert success is True + assert "Using local environment" in result or "OS:" in result + + async def test_probe_with_local_sandbox(self): + mgr = SandboxManager() + await mgr.create("local") + result, success = await _handle_probe(mgr) + assert success is True + assert "Sandbox Environment" in result + await mgr.destroy() diff --git a/backend/tests/test_tools_search.py b/backend/tests/test_tools_search.py new file mode 100644 index 0000000..36c9f11 --- /dev/null +++ b/backend/tests/test_tools_search.py @@ -0,0 +1,27 @@ +"""Tests for web search tool — tool spec validation.""" + +import pytest +from openmlr.tools.search import create_search_tools + +pytestmark = pytest.mark.asyncio + + +class TestCreateSearchTools: + async def test_creates_tool(self): + tools = create_search_tools() + assert len(tools) == 1 + assert tools[0].name == "web_search" + + async def test_handler_configured(self): + tools = create_search_tools() + assert tools[0].handler is not None + + async def test_query_required(self): + tools = create_search_tools() + assert "query" in tools[0].parameters["required"] + + async def test_count_parameter(self): + tools = create_search_tools() + props = tools[0].parameters["properties"] + assert "count" in props + assert props["count"]["type"] == "integer" diff --git a/backend/tests/test_tools_writing.py b/backend/tests/test_tools_writing.py new file mode 100644 index 0000000..bb72e20 --- /dev/null +++ b/backend/tests/test_tools_writing.py @@ -0,0 +1,216 @@ +"""Tests for writing tool — project management and paper operations.""" + +import pytest +from openmlr.tools.writing import ( + create_writing_tool, _create_project, _set_outline, + _write_section, _get_draft, _get_draft_from_proj, + _list_sections, _add_citation, _refine_section, + 
_count_sections, +) + +pytestmark = pytest.mark.asyncio + + +class TestCreateWritingTool: + async def test_creates_tool(self): + tool = create_writing_tool() + assert tool.name == "writing" + assert tool.handler is not None + assert "operation" in tool.parameters["required"] + ops = tool.parameters["properties"]["operation"]["enum"] + assert "create_project" in ops + assert "write_section" in ops + assert "get_draft" in ops + + +class TestCreateProject: + async def test_creates_project(self): + from openmlr.tools.writing import _projects + _projects.clear() + result, ok = _create_project(conv_id=1, title="My Paper") + assert ok is True + assert "My Paper" in result + proj = _projects.get(1) + assert proj is not None + assert proj["title"] == "My Paper" + _projects.clear() + + async def test_requires_title(self): + result, ok = _create_project(conv_id=1, title="") + assert ok is False + assert "title" in result.lower() + + +class TestSetOutline: + async def test_no_project(self): + from openmlr.tools.writing import _projects + _projects.clear() + result, ok = _set_outline(conv_id=999, outline=[]) + assert ok is False + assert "No paper project" in result + + async def test_requires_outline(self): + from openmlr.tools.writing import _projects + _projects.clear() + _create_project(conv_id=1, title="Test") + result, ok = _set_outline(conv_id=1, outline=None) + assert ok is False + _projects.clear() + + async def test_sets_outline(self): + from openmlr.tools.writing import _projects + _projects.clear() + _create_project(conv_id=1, title="Test") + outline = [ + {"id": "sec1", "title": "Introduction"}, + {"id": "sec2", "title": "Methods", "subsections": [ + {"id": "sec2.1", "title": "Setup"}, + ]}, + ] + result, ok = _set_outline(conv_id=1, outline=outline) + assert ok is True + assert "Introduction" in result + assert "Methods" in result + assert "Setup" in result + _projects.clear() + + +class TestWriteSection: + async def test_no_project(self): + from 
openmlr.tools.writing import _projects + _projects.clear() + result, ok = _write_section(conv_id=999, section_id="s1", content="text") + assert ok is False + + async def test_writes_section(self): + from openmlr.tools.writing import _projects + _projects.clear() + _create_project(conv_id=1, title="Test") + _set_outline(conv_id=1, outline=[{"id": "intro", "title": "Introduction"}]) + result, ok = _write_section(conv_id=1, section_id="intro", content="Hello world.") + assert ok is True + assert "intro" in result + assert "auto-saved" in result + _projects.clear() + + +class TestGetDraft: + async def test_no_project(self): + from openmlr.tools.writing import _projects + _projects.clear() + result, ok = _get_draft(conv_id=999) + assert ok is False + + async def test_generates_draft(self): + from openmlr.tools.writing import _projects + _projects.clear() + _create_project(conv_id=1, title="The Paper") + _set_outline(conv_id=1, outline=[{"id": "intro", "title": "Introduction"}]) + _write_section(conv_id=1, section_id="intro", content="This is the intro.") + result, ok = _get_draft(conv_id=1) + assert ok is True + assert "# The Paper" in result + assert "Introduction" in result + assert "This is the intro." in result + _projects.clear() + + +class TestGetDraftFromProj: + async def test_generates_full_draft(self): + proj = { + "title": "ML Research", + "outline": [ + {"id": "abstract", "title": "Abstract"}, + {"id": "method", "title": "Method", "subsections": [ + {"id": "method.experimental", "title": "Experimental Setup"}, + ]}, + ], + "sections": { + "abstract": "This is the abstract.", + "method": "The method section.", + "method.experimental": "Experimental details.", + }, + "bibliography": [ + {"key": "doe2024", "author": "J. 
Doe", "title": "Great Paper", "year": "2024"}, + ], + } + result, ok = _get_draft_from_proj(proj) + assert ok is True + assert "ML Research" in result + assert "Abstract" in result + assert "Method" in result + assert "Experimental Setup" in result + assert "References" in result + assert "J. Doe" in result + + +class TestListSections: + async def test_no_project(self): + from openmlr.tools.writing import _projects + _projects.clear() + result, ok = _list_sections(conv_id=999) + assert ok is False + assert "No paper project" in result + + async def test_lists_sections_with_status(self): + from openmlr.tools.writing import _projects + _projects.clear() + _create_project(conv_id=1, title="Test") + _set_outline(conv_id=1, outline=[ + {"id": "s1", "title": "Section 1"}, + {"id": "s2", "title": "Section 2"}, + ]) + _write_section(conv_id=1, section_id="s1", content="written") + result, ok = _list_sections(conv_id=1) + assert ok is True + assert "[done]" in result + assert "[pending]" in result + _projects.clear() + + +class TestAddCitation: + async def test_adds_citation(self): + from openmlr.tools.writing import _projects + _projects.clear() + _create_project(conv_id=1, title="Test") + citation = { + "key": "smith2023", + "title": "Important Paper", + "author": "A. 
Smith", + "year": "2023", + } + result, ok = _add_citation(conv_id=1, citation=citation) + assert ok is True + assert "Added citation" in result + _projects.clear() + + +class TestRefineSection: + async def test_returns_feedback_mode(self): + from openmlr.tools.writing import _projects + _projects.clear() + _create_project(conv_id=1, title="Test") + _set_outline(conv_id=1, outline=[{"id": "s1", "title": "Section"}]) + _write_section(conv_id=1, section_id="s1", content="original content") + result, ok = _refine_section( + conv_id=1, section_id="s1", + content=None, feedback="make it better", + ) + assert ok is True + assert "feedback" in result.lower() + _projects.clear() + + +class TestCountSections: + async def test_counts_with_subsections(self): + outline = [ + {"id": "a", "title": "A"}, + {"id": "b", "title": "B", "subsections": [ + {"id": "b1", "title": "B1"}, + {"id": "b2", "title": "B2"}, + ]}, + ] + assert _count_sections(outline) == 4 + + async def test_empty_outline(self): + assert _count_sections([]) == 0 diff --git a/backend/tests/test_types.py b/backend/tests/test_types.py new file mode 100644 index 0000000..4271696 --- /dev/null +++ b/backend/tests/test_types.py @@ -0,0 +1,172 @@ +"""Tests for agent core types — AgentEvent, OpType, Message, ToolCall, ToolSpec, LLMResult, Submission.""" + +import json +import pytest + +from openmlr.agent.types import ( + AgentEvent, OpType, Message, ToolCall, ToolSpec, LLMResult, Submission, +) + + +class TestToolCall: + def test_creation(self): + tc = ToolCall(id="call_1", name="read_file", arguments={"path": "/tmp/test"}) + assert tc.id == "call_1" + assert tc.name == "read_file" + assert tc.arguments == {"path": "/tmp/test"} + + def test_empty_arguments(self): + tc = ToolCall(id="c2", name="noop", arguments={}) + assert tc.arguments == {} + + +class TestToolSpec: + def test_creation_without_optional(self): + ts = ToolSpec(name="test_tool", description="A test tool", parameters={"type": "object"}) + assert ts.name == 
"test_tool" + assert ts.description == "A test tool" + assert ts.parameters == {"type": "object"} + assert ts.handler is None + assert ts.needs_approval is None + + def test_creation_with_handler(self): + async def handler(arg1: str, arg2: int = 0) -> tuple[str, bool]: + return f"done: {arg1}", True + + ts = ToolSpec( + name="with_handler", + description="Tool with handler", + parameters={"type": "object", "properties": {"arg1": {"type": "string"}}}, + handler=handler, + ) + assert ts.handler is not None + assert ts.name == "with_handler" + + def test_needs_approval(self): + ts = ToolSpec( + name="dangerous", + description="Needs approval", + parameters={"type": "object"}, + needs_approval=lambda **kwargs: True, + ) + assert ts.needs_approval is not None + + +class TestMessage: + def test_user_message(self): + msg = Message(role="user", content="Hello") + assert msg.role == "user" + assert msg.content == "Hello" + assert msg.tool_calls is None + assert msg.tool_call_id is None + assert msg.name is None + + def test_assistant_with_tool_calls(self): + tc = ToolCall(id="tc1", name="bash", arguments={"cmd": "ls"}) + msg = Message(role="assistant", content="", tool_calls=[tc]) + assert msg.role == "assistant" + assert msg.content == "" + assert len(msg.tool_calls) == 1 + assert msg.tool_calls[0].name == "bash" + + def test_tool_result_message(self): + msg = Message(role="tool", content="output here", tool_call_id="tc1", name="bash") + assert msg.role == "tool" + assert msg.tool_call_id == "tc1" + assert msg.name == "bash" + + def test_system_message(self): + msg = Message(role="system", content="You are an AI assistant.") + assert msg.role == "system" + + +class TestAgentEvent: + def test_creation_without_data(self): + evt = AgentEvent(event_type="status") + assert evt.event_type == "status" + assert evt.data is None + + def test_creation_with_data(self): + evt = AgentEvent(event_type="status", data={"status": "thinking..."}) + assert evt.data == {"status": 
"thinking..."} + + def test_creation_kwargs(self): + evt = AgentEvent(event_type="text_delta", data={"text": "hello"}) + assert evt.event_type == "text_delta" + assert evt.data["text"] == "hello" + + def test_to_sse_simple(self): + evt = AgentEvent(event_type="ping") + sse = evt.to_sse() + assert sse.startswith("data: ") + assert sse.endswith("\n\n") + parsed = json.loads(sse[6:-2]) + assert parsed["event_type"] == "ping" + assert parsed["data"] is None + + def test_to_sse_with_data(self): + evt = AgentEvent(event_type="status", data={"status": "ready"}) + sse = evt.to_sse() + parsed = json.loads(sse[6:-2]) + assert parsed["event_type"] == "status" + assert parsed["data"] == {"status": "ready"} + + def test_to_sse_complex_data(self): + evt = AgentEvent(event_type="text_delta", data={"text": "hello\nworld", "index": 0}) + sse = evt.to_sse() + parsed = json.loads(sse[6:-2]) + assert parsed["data"]["text"] == "hello\nworld" + assert parsed["data"]["index"] == 0 + + +class TestOpType: + def test_enum_values(self): + assert OpType.USER_INPUT == "user_input" + assert OpType.EXEC_APPROVAL == "exec_approval" + assert OpType.COMPACT == "compact" + assert OpType.UNDO == "undo" + assert OpType.INTERRUPT == "interrupt" + assert OpType.SHUTDOWN == "shutdown" + + def test_string_equality(self): + assert OpType.USER_INPUT == "user_input" + assert OpType.INTERRUPT != "user_input" + + +class TestSubmission: + def test_creation(self): + sub = Submission(op=OpType.USER_INPUT, data="Send this message") + assert sub.op == OpType.USER_INPUT + assert sub.data == "Send this message" + + def test_no_data_default(self): + sub = Submission(op=OpType.SHUTDOWN) + assert sub.op == OpType.SHUTDOWN + assert sub.data is None + + +class TestLLMResult: + def test_basic_result(self): + result = LLMResult(content="Hello, world!", tool_calls=[], finish_reason="stop") + assert result.content == "Hello, world!" 
+ assert result.tool_calls == [] + assert result.finish_reason == "stop" + assert result.usage is None + + def test_with_tool_calls(self): + tc = ToolCall(id="c1", name="search", arguments={"query": "test"}) + result = LLMResult(content="", tool_calls=[tc], finish_reason="tool_calls") + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "search" + assert result.finish_reason == "tool_calls" + + def test_with_usage(self): + usage = {"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150} + result = LLMResult(content="ok", tool_calls=[], finish_reason="stop", usage=usage) + assert result.usage == usage + assert result.usage["total_tokens"] == 150 + + def test_finish_reason_length(self): + """finish_reason can be 'length' for truncated responses.""" + result = LLMResult(content="truncated", tool_calls=[], finish_reason="length") + assert result.finish_reason == "length" diff --git a/frontend/src/__tests__/AgentSettings.test.tsx b/frontend/src/__tests__/AgentSettings.test.tsx new file mode 100644 index 0000000..40ac845 --- /dev/null +++ b/frontend/src/__tests__/AgentSettings.test.tsx @@ -0,0 +1,64 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen, waitFor } from '@testing-library/react'; +import { AgentSettings } from '../components/settings/AgentSettings'; +import { api } from '../api'; + +vi.mock('../api', () => ({ + api: { + getSettings: vi.fn(), + updateSetting: vi.fn(), + }, +})); + +describe('AgentSettings', () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(api.getSettings).mockResolvedValue({ settings: {} }); + }); + + it('renders model hint', async () => { + render(); + await waitFor(() => { + expect(screen.getByText(/Model used for new conversations/)).toBeInTheDocument(); + }); + }); + + it('renders default model input', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Default Model')).toBeInTheDocument(); + 
expect(screen.getByPlaceholderText(/anthropic\/claude-sonnet-4/)).toBeInTheDocument(); + }); + }); + + it('renders research model input', async () => { + render(); + await waitFor(() => { + expect(screen.getByText(/Research \/ Title Model/)).toBeInTheDocument(); + }); + }); + + it('renders yolo mode checkbox', async () => { + render(); + await waitFor(() => { + expect(screen.getByText(/YOLO Mode/)).toBeInTheDocument(); + }); + }); + + it('renders save button', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Save Agent Settings')).toBeInTheDocument(); + }); + }); + + it('loads settings on mount', async () => { + vi.mocked(api.getSettings).mockResolvedValue({ + settings: { agent: { default_model: 'claude-4', yolo_mode: true } }, + }); + render(); + await waitFor(() => { + expect(api.getSettings).toHaveBeenCalled(); + }); + }); +}); diff --git a/frontend/src/__tests__/ApprovalModal.test.tsx b/frontend/src/__tests__/ApprovalModal.test.tsx new file mode 100644 index 0000000..2f2ff69 --- /dev/null +++ b/frontend/src/__tests__/ApprovalModal.test.tsx @@ -0,0 +1,127 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen, fireEvent } from '@testing-library/react'; +import { ApprovalModal } from '../components/ApprovalModal'; +import { api } from '../api'; + +vi.mock('../api', () => ({ + api: { + sendApproval: vi.fn(), + }, +})); + +describe('ApprovalModal', () => { + it('renders header', () => { + vi.mocked(api.sendApproval).mockResolvedValue({}); + const event = { + data: { + tool_calls: [{ id: 'tc1', name: 'bash', arguments: { cmd: 'ls' } }], + }, + }; + render(); + expect(screen.getByText('Approve Tool Calls')).toBeInTheDocument(); + }); + + it('renders tool names', () => { + vi.mocked(api.sendApproval).mockResolvedValue({}); + const event = { + data: { + tool_calls: [ + { id: 'tc1', name: 'bash', arguments: { cmd: 'ls' } }, + { id: 'tc2', name: 'write', arguments: { path: '/tmp/t', content: 'hi' } }, + ], + }, + }; + 
render(); + expect(screen.getByText('bash')).toBeInTheDocument(); + expect(screen.getByText('write')).toBeInTheDocument(); + }); + + it('renders approve and reject buttons', () => { + vi.mocked(api.sendApproval).mockResolvedValue({}); + const event = { + data: { + tool_calls: [{ id: 'tc1', name: 'test', arguments: {} }], + }, + }; + render(); + expect(screen.getByRole('button', { name: 'Approve' })).toBeInTheDocument(); + expect(screen.getByRole('button', { name: 'Reject' })).toBeInTheDocument(); + }); + + it('displays tool args', () => { + vi.mocked(api.sendApproval).mockResolvedValue({}); + const event = { + data: { + tool_calls: [{ id: 'tc1', name: 'run', arguments: { query: 'test', count: 5 } }], + }, + }; + render(); + expect(screen.getByText(/query/)).toBeInTheDocument(); + }); + + it('handles string arguments', () => { + vi.mocked(api.sendApproval).mockResolvedValue({}); + const event = { + data: { + tool_calls: [{ id: 'tc1', name: 'echo', arguments: 'hello world' }], + }, + }; + render(); + expect(screen.getByText('hello world')).toBeInTheDocument(); + }); + + it('handles tools field as fallback', () => { + vi.mocked(api.sendApproval).mockResolvedValue({}); + const event = { + data: { + tools: [{ id: 'tc1', name: 'fallback_tool', arguments: {} }], + }, + }; + render(); + expect(screen.getByText('fallback_tool')).toBeInTheDocument(); + }); + + it('calls sendApproval with approved', async () => { + const mockSend = vi.mocked(api.sendApproval).mockResolvedValue({}); + const onClose = vi.fn(); + const event = { + data: { + tool_calls: [{ id: 'tc1', name: 'bash', arguments: {} }], + }, + }; + render(); + + fireEvent.click(screen.getByRole('button', { name: 'Approve' })); + expect(mockSend).toHaveBeenCalledWith({ tc1: true }); + }); + + it('calls sendApproval with rejected', () => { + const mockSend = vi.mocked(api.sendApproval).mockResolvedValue({}); + const onClose = vi.fn(); + const event = { + data: { + tool_calls: [{ id: 'tc1', name: 'bash', arguments: {} 
}], + }, + }; + render(); + + fireEvent.click(screen.getByRole('button', { name: 'Reject' })); + expect(mockSend).toHaveBeenCalledWith({ tc1: false }); + }); + + it('approves multiple tool calls', () => { + const mockSend = vi.mocked(api.sendApproval).mockResolvedValue({}); + const event = { + data: { + tool_calls: [ + { id: 'a', name: 'one', arguments: {} }, + { id: 'b', name: 'two', arguments: {} }, + { id: 'c', name: 'three', arguments: {} }, + ], + }, + }; + render(); + fireEvent.click(screen.getByText('Approve')); + expect(mockSend).toHaveBeenCalledWith({ a: true, b: true, c: true }); + }); +}); diff --git a/frontend/src/__tests__/AuthGuard.test.tsx b/frontend/src/__tests__/AuthGuard.test.tsx new file mode 100644 index 0000000..babc938 --- /dev/null +++ b/frontend/src/__tests__/AuthGuard.test.tsx @@ -0,0 +1,106 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { render, screen, waitFor } from '@testing-library/react'; +import { MemoryRouter } from 'react-router-dom'; +import { AuthGuard } from '../components/AuthGuard'; +import { api, setToken } from '../api'; + +vi.mock('../api', () => ({ + api: { + getMe: vi.fn(), + }, + setToken: vi.fn(), +})); + +const getItemMock = vi.fn(); +const setItemMock = vi.fn(); +const removeItemMock = vi.fn(); + +Object.defineProperty(window, 'localStorage', { + value: { + getItem: getItemMock, + setItem: setItemMock, + removeItem: removeItemMock, + }, + writable: true, +}); + +describe('AuthGuard', () => { + beforeEach(() => { + vi.resetAllMocks(); + getItemMock.mockReturnValue(null); + }); + + it('shows loading when checking token', () => { + getItemMock.mockReturnValue('existing-token'); + vi.mocked(api.getMe).mockImplementation(() => new Promise(() => {})); + + render( + + + + ); + + expect(screen.getByText('Loading...')).toBeInTheDocument(); + }); + + it('redirects to /login when no user and no token', async () => { + render( + + + + ); + + await waitFor(() => { + // Navigate should redirect to 
/login + expect(screen.queryByText('Loading...')).not.toBeInTheDocument(); + }); + }); + + it('authenticates with valid token', async () => { + getItemMock.mockReturnValue('valid-token'); + const user = { id: 1, username: 'test', display_name: 'Test' }; + vi.mocked(api.getMe).mockResolvedValue(user); + const onAuth = vi.fn(); + + render( + + + + ); + + await waitFor(() => { + expect(onAuth).toHaveBeenCalledWith(user); + expect(screen.queryByText('Loading...')).not.toBeInTheDocument(); + }); + }); + + it('clears token on invalid token', async () => { + getItemMock.mockReturnValue('invalid-token'); + vi.mocked(api.getMe).mockRejectedValue(new Error('Invalid')); + const onAuth = vi.fn(); + + render( + + + + ); + + await waitFor(() => { + expect(setToken).toHaveBeenCalledWith(null); + expect(onAuth).not.toHaveBeenCalled(); + }); + }); + + it('skips check when user already passed in', () => { + const user = { id: 1, username: 'test', display_name: 'Test' }; + vi.mocked(api.getMe).mockResolvedValue(user); + + render( + + + + ); + + expect(api.getMe).not.toHaveBeenCalled(); + }); +}); diff --git a/frontend/src/__tests__/ConfirmDialog.test.tsx b/frontend/src/__tests__/ConfirmDialog.test.tsx new file mode 100644 index 0000000..e5fd9df --- /dev/null +++ b/frontend/src/__tests__/ConfirmDialog.test.tsx @@ -0,0 +1,159 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen, fireEvent } from '@testing-library/react'; +import { ConfirmDialog } from '../components/ConfirmDialog'; + +describe('ConfirmDialog', () => { + it('renders title and message', () => { + render( + + ); + expect(screen.getByText('Delete Item')).toBeInTheDocument(); + expect(screen.getByText('Are you sure you want to delete this item?')).toBeInTheDocument(); + }); + + it('renders default button labels', () => { + render( + + ); + expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument(); + expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument(); + }); 
+ + it('renders custom button labels', () => { + render( + + ); + expect(screen.getByRole('button', { name: 'Yes, Delete' })).toBeInTheDocument(); + expect(screen.getByRole('button', { name: 'No, Keep' })).toBeInTheDocument(); + }); + + it('calls onConfirm when confirm button clicked', () => { + const onConfirm = vi.fn(); + render( + + ); + fireEvent.click(screen.getByRole('button', { name: 'Confirm' })); + expect(onConfirm).toHaveBeenCalledTimes(1); + }); + + it('calls onCancel when cancel button clicked', () => { + const onCancel = vi.fn(); + render( + + ); + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })); + expect(onCancel).toHaveBeenCalledTimes(1); + }); + + it('calls onCancel when overlay clicked', () => { + const onCancel = vi.fn(); + render( + + ); + const overlay = document.querySelector('.modal-overlay'); + fireEvent.click(overlay!); + expect(onCancel).toHaveBeenCalled(); + }); + + it('calls onCancel on Escape key', () => { + const onCancel = vi.fn(); + render( + + ); + fireEvent.keyDown(window, { key: 'Escape' }); + expect(onCancel).toHaveBeenCalled(); + }); + + it('does not call onCancel for non-Escape keys', () => { + const onCancel = vi.fn(); + render( + + ); + fireEvent.keyDown(window, { key: 'Enter' }); + expect(onCancel).not.toHaveBeenCalled(); + }); + + it('applies danger class to confirm button', () => { + render( + + ); + const confirmBtn = screen.getByRole('button', { name: 'Confirm' }); + expect(confirmBtn.className).toContain('btn-danger'); + }); + + it('uses btn-confirm class for non-danger', () => { + render( + + ); + const confirmBtn = screen.getByRole('button', { name: 'Confirm' }); + expect(confirmBtn.className).toContain('btn-confirm'); + }); + + it('cleans up Escape listener on unmount', () => { + const onCancel = vi.fn(); + const { unmount } = render( + + ); + unmount(); + fireEvent.keyDown(window, { key: 'Escape' }); + expect(onCancel).not.toHaveBeenCalled(); + }); +}); diff --git 
a/frontend/src/__tests__/LoginPage.test.tsx b/frontend/src/__tests__/LoginPage.test.tsx new file mode 100644 index 0000000..587b99c --- /dev/null +++ b/frontend/src/__tests__/LoginPage.test.tsx @@ -0,0 +1,212 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen, fireEvent, waitFor } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { MemoryRouter } from 'react-router-dom'; +import { LoginPage } from '../components/LoginPage'; +import { api, setToken } from '../api'; + +vi.mock('../api', () => ({ + api: { + checkSetup: vi.fn(), + register: vi.fn(), + login: vi.fn(), + }, + setToken: vi.fn(), +})); + +const mockNavigate = vi.fn(); +vi.mock('react-router-dom', async (importOriginal) => { + const actual = await importOriginal() as any; + return { + ...actual, + useNavigate: () => mockNavigate, + }; +}); + +function renderLoginPage(onAuth = vi.fn()) { + return render( + + + + ); +} + +describe('LoginPage', () => { + beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(api.checkSetup).mockResolvedValue({ has_users: true }); + }); + + it('renders sign in form', async () => { + renderLoginPage(); + await waitFor(() => { + expect(screen.getByText('OpenMLR')).toBeInTheDocument(); + }); + expect(screen.getByPlaceholderText('Username')).toBeInTheDocument(); + expect(screen.getByPlaceholderText('Password')).toBeInTheDocument(); + }); + + it('shows login and register tabs when users exist', async () => { + renderLoginPage(); + await waitFor(() => { + const tabs = screen.getAllByRole('button'); + const tabTexts = tabs.map(t => t.textContent); + expect(tabTexts).toContain('Sign In'); + expect(tabTexts).toContain('Register'); + }); + }); + + it('switches to register mode', async () => { + renderLoginPage(); + await waitFor(() => { + expect(screen.getByText('Register')).toBeInTheDocument(); + }); + fireEvent.click(screen.getByText('Register')); + const createBtn = screen.getByText('Create Account'); + 
expect(createBtn).toBeInTheDocument(); + expect(screen.getByPlaceholderText('Display name (optional)')).toBeInTheDocument(); + }); + + it('shows first-user notice when no users exist', async () => { + vi.mocked(api.checkSetup).mockResolvedValue({ has_users: false }); + renderLoginPage(); + await waitFor(() => { + expect(screen.getByText(/Create your account/)).toBeInTheDocument(); + }); + }); + + it('handles login submission', async () => { + const onAuth = vi.fn(); + vi.mocked(api.login).mockResolvedValue({ + access_token: 'test-token', + user: { id: 1, username: 'testuser', display_name: 'Test User' }, + }); + + renderLoginPage(onAuth); + + await waitFor(() => { + expect(screen.getByPlaceholderText('Username')).toBeInTheDocument(); + }); + + const user = userEvent.setup(); + await user.type(screen.getByPlaceholderText('Username'), 'testuser'); + await user.type(screen.getByPlaceholderText('Password'), 'password123'); + + const submitButton = screen.getByText('Sign In', { selector: '.login-submit' }); + await user.click(submitButton); + + await waitFor(() => { + expect(api.login).toHaveBeenCalledWith('testuser', 'password123'); + expect(setToken).toHaveBeenCalledWith('test-token'); + expect(onAuth).toHaveBeenCalledWith({ id: 1, username: 'testuser', display_name: 'Test User' }); + expect(mockNavigate).toHaveBeenCalledWith('/', { replace: true }); + }); + }); + + it('handles register submission', async () => { + const onAuth = vi.fn(); + vi.mocked(api.register).mockResolvedValue({ + access_token: 'reg-token', + user: { id: 2, username: 'newuser', display_name: 'New' }, + }); + + renderLoginPage(onAuth); + + await waitFor(() => { + fireEvent.click(screen.getByText('Register')); + }); + + const user = userEvent.setup(); + await user.type(screen.getByPlaceholderText('Username'), 'newuser'); + await user.type(screen.getByPlaceholderText('Display name (optional)'), 'New'); + await user.type(screen.getByPlaceholderText('Password'), 'password123'); + await 
user.click(screen.getByText('Create Account', { selector: '.login-submit' })); + + await waitFor(() => { + expect(api.register).toHaveBeenCalledWith('newuser', 'password123', 'New'); + expect(setToken).toHaveBeenCalledWith('reg-token'); + }); + }); + + it('displays error on auth failure', async () => { + vi.mocked(api.login).mockRejectedValue(new Error('Invalid credentials')); + + renderLoginPage(); + + await waitFor(() => { + expect(screen.getByPlaceholderText('Username')).toBeInTheDocument(); + }); + + const user = userEvent.setup(); + await user.type(screen.getByPlaceholderText('Username'), 'wrong'); + await user.type(screen.getByPlaceholderText('Password'), 'wrong'); + await user.click(screen.getByText('Sign In', { selector: '.login-submit' })); + + await waitFor(() => { + expect(screen.getByText('Invalid credentials')).toBeInTheDocument(); + }); + }); + + it('displays generic error for unknown failures', async () => { + vi.mocked(api.login).mockRejectedValue(new Error()); + + renderLoginPage(); + + await waitFor(() => { + expect(screen.getByPlaceholderText('Username')).toBeInTheDocument(); + }); + + const user = userEvent.setup(); + await user.type(screen.getByPlaceholderText('Username'), 'test'); + await user.type(screen.getByPlaceholderText('Password'), 'pass'); + await user.click(screen.getByText('Sign In', { selector: '.login-submit' })); + + await waitFor(() => { + expect(screen.getByText('Authentication failed')).toBeInTheDocument(); + }); + }); + + it('clears error when switching modes', async () => { + vi.mocked(api.register).mockRejectedValue(new Error('Username taken')); + + renderLoginPage(); + + await waitFor(() => { + fireEvent.click(screen.getByText('Register')); + }); + + // Fill required fields and submit register form + const user = userEvent.setup(); + await user.type(screen.getByPlaceholderText('Username'), 'newuser'); + await user.type(screen.getByPlaceholderText('Password'), 'password'); + await user.click(screen.getByText('Create 
Account', { selector: '.login-submit' })); + + await waitFor(() => { + expect(screen.getByText('Username taken')).toBeInTheDocument(); + }); + + fireEvent.click(screen.getByRole('button', { name: 'Sign In' })); + await waitFor(() => { + expect(screen.queryByText('Username taken')).not.toBeInTheDocument(); + }); + }); + + it('disables submit button while loading', async () => { + vi.mocked(api.login).mockImplementation(() => new Promise(() => {})); + + renderLoginPage(); + + await waitFor(() => { + expect(screen.getByPlaceholderText('Username')).toBeInTheDocument(); + }); + + const user = userEvent.setup(); + await user.type(screen.getByPlaceholderText('Username'), 'test'); + await user.type(screen.getByPlaceholderText('Password'), 'pass'); + await user.click(screen.getByText('Sign In', { selector: '.login-submit' })); + + await waitFor(() => { + expect(screen.getByText('Please wait...')).toBeInTheDocument(); + }); + }); +}); diff --git a/frontend/src/__tests__/ModelModal.test.tsx b/frontend/src/__tests__/ModelModal.test.tsx new file mode 100644 index 0000000..983f5b0 --- /dev/null +++ b/frontend/src/__tests__/ModelModal.test.tsx @@ -0,0 +1,176 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { render, screen, fireEvent, waitFor } from '@testing-library/react'; +import { ModelModal } from '../components/ModelModal'; +import { api } from '../api'; + +vi.mock('../api', () => ({ + api: { + getProviders: vi.fn(), + getModels: vi.fn(), + saveConfig: vi.fn(), + setModel: vi.fn(), + }, +})); + +const defaultProviders = [ + { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true }, + { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: false }, +]; + +const defaultModels = [ + { id: 'openai/gpt-4o', name: 'GPT-4o', provider: 'openai' }, + { id: 'openai/gpt-4o-mini', name: 'GPT-4o Mini', provider: 'openai' }, + { id: 'anthropic/claude-4', name: 'Claude 4', provider: 'anthropic' }, +]; + 
+describe('ModelModal', () => { + beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(api.getProviders).mockResolvedValue({ providers: defaultProviders }); + vi.mocked(api.getModels).mockResolvedValue({ models: defaultModels }); + }); + + it('renders current model button', () => { + render(); + expect(screen.getByText('openai/gpt-4o')).toBeInTheDocument(); + }); + + it('opens modal on button click', () => { + render(); + fireEvent.click(screen.getByText('openai/gpt-4o')); + expect(screen.getByText('Models')).toBeInTheDocument(); + expect(screen.getByText('Providers')).toBeInTheDocument(); + }); + + it('shows model list when opened', async () => { + render(); + fireEvent.click(screen.getByText('openai/gpt-4o')); + + await waitFor(() => { + expect(screen.getByText('GPT-4o')).toBeInTheDocument(); + expect(screen.getByText('GPT-4o Mini')).toBeInTheDocument(); + expect(screen.getByText('Claude 4')).toBeInTheDocument(); + }); + }); + + it('highlights current model', async () => { + render(); + fireEvent.click(screen.getByText('openai/gpt-4o-mini')); + + await waitFor(() => { + const options = document.querySelectorAll('.model-picker-option'); + const mini = Array.from(options).find(o => o.textContent?.includes('GPT-4o Mini')); + expect(mini?.classList.contains('active')).toBe(true); + }); + }); + + it('switches to providers tab', async () => { + render(); + fireEvent.click(screen.getByText('openai/gpt-4o')); + fireEvent.click(screen.getByText('Providers')); + + await waitFor(() => { + expect(screen.getByText('Configured')).toBeInTheDocument(); + expect(screen.getByText('API key missing')).toBeInTheDocument(); + }); + }); + + it('filters models by provider', async () => { + render(); + fireEvent.click(screen.getByText('openai/gpt-4o')); + + await waitFor(() => { + expect(screen.getByText('Claude 4')).toBeInTheDocument(); + }); + + const select = document.querySelector('select')!; + fireEvent.change(select, { target: { value: 'openai' } }); + + await waitFor(() => { + 
expect(screen.getByText('GPT-4o')).toBeInTheDocument(); + expect(screen.queryByText('Claude 4')).not.toBeInTheDocument(); + }); + }); + + it('selects model and calls onModelChange', async () => { + vi.mocked(api.setModel).mockResolvedValue({ ok: true }); + const onChange = vi.fn(); + + render(); + fireEvent.click(screen.getByText('openai/gpt-4o')); + + await waitFor(() => { + expect(screen.getByText('Claude 4')).toBeInTheDocument(); + }); + + fireEvent.click(screen.getByText('Claude 4')); + + await waitFor(() => { + expect(api.setModel).toHaveBeenCalledWith('anthropic/claude-4'); + expect(onChange).toHaveBeenCalledWith('anthropic/claude-4'); + }); + }); + + it('closed modal on close button', async () => { + vi.mocked(api.setModel).mockResolvedValue({ ok: true }); + + render(); + fireEvent.click(screen.getByText('openai/gpt-4o')); + + await waitFor(() => { + expect(screen.getByText('GPT-4o')).toBeInTheDocument(); + }); + + fireEvent.click(screen.getByText('Close')); + + await waitFor(() => { + expect(screen.queryByText('GPT-4o Mini')).not.toBeInTheDocument(); + }); + }); + + it('shows custom model input', async () => { + render(); + fireEvent.click(screen.getByText('openai/gpt-4o')); + + await waitFor(() => { + expect(screen.getByPlaceholderText('Custom model ID')).toBeInTheDocument(); + }); + }); + + it('uses custom model on Enter', async () => { + vi.mocked(api.setModel).mockResolvedValue({ ok: true }); + const onModelChange = vi.fn(); + + render(); + fireEvent.click(screen.getByText('openai/gpt-4o')); + + await waitFor(() => { + expect(screen.getByPlaceholderText('Custom model ID')).toBeInTheDocument(); + }); + + const input = screen.getByPlaceholderText('Custom model ID'); + fireEvent.change(input, { target: { value: 'my-custom-model' } }); + fireEvent.keyDown(input, { key: 'Enter' }); + + await waitFor(() => { + expect(api.setModel).toHaveBeenCalledWith('my-custom-model'); + expect(onModelChange).toHaveBeenCalledWith('my-custom-model'); + }); + }); + + 
it('closes modal when overlay clicked', async () => { + render(); + fireEvent.click(screen.getByText('openai/gpt-4o')); + + await waitFor(() => { + expect(screen.getByText('GPT-4o')).toBeInTheDocument(); + }); + + const overlay = document.querySelector('.modal-overlay'); + fireEvent.click(overlay!); + + await waitFor(() => { + expect(screen.queryByText('GPT-4o Mini')).not.toBeInTheDocument(); + }); + }); +}); diff --git a/frontend/src/__tests__/OnboardingModal.test.tsx b/frontend/src/__tests__/OnboardingModal.test.tsx new file mode 100644 index 0000000..8ac92bc --- /dev/null +++ b/frontend/src/__tests__/OnboardingModal.test.tsx @@ -0,0 +1,185 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { render, screen, fireEvent, waitFor } from '@testing-library/react'; +import { OnboardingModal } from '../components/OnboardingModal'; +import { api } from '../api'; + +vi.mock('../api', () => ({ + api: { + getProviders: vi.fn(), + getModels: vi.fn(), + saveConfig: vi.fn(), + setModel: vi.fn(), + }, +})); + +const defaultProviders = [ + { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: false }, + { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: false }, + { id: 'openrouter', name: 'OpenRouter', key_env: 'OPENROUTER_API_KEY', configured: false }, + { id: 'opencode-go', name: 'OpenCode Go', key_env: 'OPENCODE_GO_API_KEY', configured: false }, + { id: 'ollama', name: 'Ollama', key_env: 'OLLAMA_API_BASE', configured: false }, + { id: 'brave', name: 'Brave Search', key_env: 'BRAVE_API_KEY', configured: false }, +]; + +const defaultModels = [ + { id: 'openai/gpt-4o', name: 'GPT-4o', provider: 'openai' }, + { id: 'anthropic/claude-4', name: 'Claude 4', provider: 'anthropic' }, +]; + +describe('OnboardingModal', () => { + beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(api.getProviders).mockResolvedValue({ providers: defaultProviders }); + vi.mocked(api.getModels).mockResolvedValue({ models: 
defaultModels }); + }); + + it('renders welcome heading', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Welcome to OpenMLR')).toBeInTheDocument(); + }); + }); + + it('shows loading state initially', () => { + vi.mocked(api.getProviders).mockImplementation(() => new Promise(() => {})); + render(); + expect(screen.getByText('Loading...')).toBeInTheDocument(); + }); + + it('shows providers step by default', async () => { + render(); + await waitFor(() => { + expect(screen.getByText(/Configure at least one LLM provider/)).toBeInTheDocument(); + }); + }); + + it('renders LLM providers', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('OpenAI')).toBeInTheDocument(); + expect(screen.getByText('Anthropic')).toBeInTheDocument(); + }); + }); + + it('shows key inputs for unconfigured providers', async () => { + render(); + await waitFor(() => { + expect(screen.getByPlaceholderText(/OPENAI_API_KEY/)).toBeInTheDocument(); + }); + }); + + it('shows "configured" for configured providers', async () => { + vi.mocked(api.getProviders).mockResolvedValue({ + providers: [ + { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true }, + { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: false }, + { id: 'openrouter', name: 'OpenRouter', key_env: 'OPENROUTER_API_KEY', configured: false }, + { id: 'opencode-go', name: 'OpenCode Go', key_env: 'OPENCODE_GO_API_KEY', configured: false }, + { id: 'ollama', name: 'Ollama', key_env: 'OLLAMA_API_BASE', configured: false }, + ], + }); + render(); + // When at least one provider is configured, step auto-skips to model selection + await waitFor(() => { + expect(screen.getByText(/Pick a model/)).toBeInTheDocument(); + }); + }); + + it('skips to model selection when providers are configured', async () => { + vi.mocked(api.getProviders).mockResolvedValue({ + providers: [ + { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true 
}, + ], + }); + render(); + await waitFor(() => { + expect(screen.getByText(/Pick a model/)).toBeInTheDocument(); + }); + }); + + it('disables save button when no keys entered', async () => { + render(); + await waitFor(() => { + const btn = screen.getByText('Save & Continue'); + expect(btn).toBeDisabled(); + }); + }); + + it('saves keys and moves to model step', async () => { + vi.mocked(api.saveConfig).mockResolvedValue({ ok: true }); + vi.mocked(api.getProviders) + .mockResolvedValueOnce({ providers: defaultProviders }) + .mockResolvedValueOnce({ + providers: [{ ...defaultProviders[0], configured: true }, ...defaultProviders.slice(1)], + }); + + render(); + await waitFor(() => { + expect(screen.getByPlaceholderText(/OPENAI_API_KEY/)).toBeInTheDocument(); + }); + + // Type a key + const input = screen.getByPlaceholderText(/OPENAI_API_KEY/); + fireEvent.change(input, { target: { value: 'sk-test-key' } }); + + fireEvent.click(screen.getByText('Save & Continue')); + + await waitFor(() => { + expect(api.saveConfig).toHaveBeenCalled(); + expect(screen.getByText(/Pick a model/)).toBeInTheDocument(); + }); + }); + + it('shows models in model step', async () => { + vi.mocked(api.getProviders).mockResolvedValue({ + providers: [{ ...defaultProviders[0], configured: true }], + }); + render(); + await waitFor(() => { + expect(screen.getByText('GPT-4o')).toBeInTheDocument(); + expect(screen.getByText('Claude 4')).toBeInTheDocument(); + }); + }); + + it('filters models by provider', async () => { + vi.mocked(api.getProviders).mockResolvedValue({ + providers: [ + { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true }, + { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: true }, + ], + }); + render(); + await waitFor(() => { + expect(screen.getByText(/Pick a model/)).toBeInTheDocument(); + }); + + // Select OpenAI from provider filter + const select = document.querySelector('select')!; + fireEvent.change(select, { target: { 
value: 'openai' } }); + + await waitFor(() => { + expect(screen.getByText('GPT-4o')).toBeInTheDocument(); + expect(screen.queryByText('Claude 4')).not.toBeInTheDocument(); + }); + }); + + it('selects model and completes', async () => { + const onComplete = vi.fn(); + vi.mocked(api.setModel).mockResolvedValue({ ok: true }); + vi.mocked(api.getProviders).mockResolvedValue({ + providers: [{ ...defaultProviders[0], configured: true }], + }); + + render(); + await waitFor(() => { + expect(screen.getByText('GPT-4o')).toBeInTheDocument(); + }); + + fireEvent.click(screen.getByText('GPT-4o')); + + await waitFor(() => { + expect(api.setModel).toHaveBeenCalledWith('openai/gpt-4o'); + expect(onComplete).toHaveBeenCalledWith('openai/gpt-4o'); + }); + }); +}); diff --git a/frontend/src/__tests__/ProvidersSettings.test.tsx b/frontend/src/__tests__/ProvidersSettings.test.tsx new file mode 100644 index 0000000..6d3b98a --- /dev/null +++ b/frontend/src/__tests__/ProvidersSettings.test.tsx @@ -0,0 +1,86 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen, fireEvent, waitFor } from '@testing-library/react'; +import { ProvidersSettings } from '../components/settings/ProvidersSettings'; +import { api } from '../api'; + +vi.mock('../api', () => ({ + api: { + getProviders: vi.fn(), + updateSetting: vi.fn(), + }, +})); + +describe('ProvidersSettings', () => { + const mockProviders = [ + { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true }, + { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: false }, + ]; + + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(api.getProviders).mockResolvedValue({ providers: mockProviders }); + }); + + it('renders hint', async () => { + render(); + await waitFor(() => { + expect(screen.getByText(/API keys are stored/)).toBeInTheDocument(); + }); + }); + + it('renders provider names', async () => { + render(); + await waitFor(() => { + 
expect(screen.getByText('OpenAI')).toBeInTheDocument(); + expect(screen.getByText('Anthropic')).toBeInTheDocument(); + }); + }); + + it('shows configured/not set status', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Configured')).toBeInTheDocument(); + expect(screen.getByText('Not set')).toBeInTheDocument(); + }); + }); + + it('renders key input placeholders', async () => { + render(); + await waitFor(() => { + expect(screen.getByPlaceholderText('OPENAI_API_KEY')).toBeInTheDocument(); + expect(screen.getByPlaceholderText('ANTHROPIC_API_KEY')).toBeInTheDocument(); + }); + }); + + it('renders save button', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Save Keys')).toBeInTheDocument(); + }); + }); + + it('disables save when no keys entered', async () => { + render(); + await waitFor(() => { + const btn = screen.getByText('Save Keys'); + expect(btn).toBeDisabled(); + }); + }); + + it('enables save when key entered', async () => { + render(); + await waitFor(() => { + const input = screen.getByPlaceholderText('ANTHROPIC_API_KEY'); + fireEvent.change(input, { target: { value: 'sk-ant-test' } }); + }); + const btn = screen.getByText('Save Keys'); + expect(btn).not.toBeDisabled(); + }); + + it('loads providers on mount', async () => { + render(); + await waitFor(() => { + expect(api.getProviders).toHaveBeenCalled(); + }); + }); +}); diff --git a/frontend/src/__tests__/QuestionDrawer.test.tsx b/frontend/src/__tests__/QuestionDrawer.test.tsx new file mode 100644 index 0000000..c61e459 --- /dev/null +++ b/frontend/src/__tests__/QuestionDrawer.test.tsx @@ -0,0 +1,136 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen, fireEvent, waitFor } from '@testing-library/react'; +import { QuestionDrawer } from '../components/QuestionDrawer'; +import { api } from '../api'; +import type { QuestionsPayload } from '../types'; + +vi.mock('../api', () => ({ + api: { + submitAnswers: vi.fn(), + }, +})); 
+ +const defaultPayload: QuestionsPayload = { + questions: [ + { + id: 'q1', + question: 'What is your preference?', + options: [ + { label: 'Option A', description: 'First option' }, + { label: 'Option B' }, + { label: 'Option C', description: 'Third option' }, + ], + }, + { + id: 'q2', + question: 'Choose a framework.', + options: [ + { label: 'PyTorch' }, + { label: 'JAX', description: 'Functional approach' }, + ], + }, + ], + context: 'Project setup questions', +}; + +describe('QuestionDrawer', () => { + it('renders context title', () => { + vi.mocked(api.submitAnswers).mockResolvedValue({}); + render( + + ); + expect(screen.getByText('Project setup questions')).toBeInTheDocument(); + }); + + it('renders question tabs', () => { + vi.mocked(api.submitAnswers).mockResolvedValue({}); + render( + + ); + expect(screen.getByText('1')).toBeInTheDocument(); + expect(screen.getByText('2')).toBeInTheDocument(); + }); + + it('renders current question text', () => { + vi.mocked(api.submitAnswers).mockResolvedValue({}); + render( + + ); + expect(screen.getByText('What is your preference?')).toBeInTheDocument(); + }); + + it('renders options with descriptions', () => { + vi.mocked(api.submitAnswers).mockResolvedValue({}); + render( + + ); + expect(screen.getByText('Option A')).toBeInTheDocument(); + expect(screen.getByText('First option')).toBeInTheDocument(); + expect(screen.getByText('Option C')).toBeInTheDocument(); + expect(screen.getByText('Third option')).toBeInTheDocument(); + }); + + it('shows text input by default', () => { + vi.mocked(api.submitAnswers).mockResolvedValue({}); + render( + + ); + expect(screen.getByPlaceholderText('Or type your own answer...')).toBeInTheDocument(); + }); + + it('navigates to next question', () => { + vi.mocked(api.submitAnswers).mockResolvedValue({}); + render( + + ); + fireEvent.click(screen.getByText('Next')); + expect(screen.getByText('Choose a framework.')).toBeInTheDocument(); + }); + + it('shows progress', () => { + 
vi.mocked(api.submitAnswers).mockResolvedValue({}); + render( + + ); + expect(screen.getByText('0 / 2')).toBeInTheDocument(); + }); + + it('selects option and advances', async () => { + vi.mocked(api.submitAnswers).mockResolvedValue({}); + render( + + ); + fireEvent.click(screen.getByText('Option A')); + await waitFor(() => { + expect(screen.getByText('Choose a framework.')).toBeInTheDocument(); + }); + }); + + it('calls submit with answers', async () => { + const mockSubmit = vi.mocked(api.submitAnswers).mockResolvedValue({}); + const onDone = vi.fn(); + render( + + ); + // Answer first question + fireEvent.click(screen.getByText('Option A')); + await waitFor(() => { + expect(screen.getByText('Choose a framework.')).toBeInTheDocument(); + }); + // Answer second question + fireEvent.click(screen.getByText('PyTorch')); + + fireEvent.click(screen.getByText('Submit')); + expect(mockSubmit).toHaveBeenCalledWith({ q1: 'Option A', q2: 'PyTorch' }); + }); + + it('calls onClose when X clicked', () => { + vi.mocked(api.submitAnswers).mockResolvedValue({}); + const onClose = vi.fn(); + render( + + ); + fireEvent.click(screen.getByText('×')); + expect(onClose).toHaveBeenCalled(); + }); +}); diff --git a/frontend/src/__tests__/ReportDrawer.test.tsx b/frontend/src/__tests__/ReportDrawer.test.tsx new file mode 100644 index 0000000..c09bcff --- /dev/null +++ b/frontend/src/__tests__/ReportDrawer.test.tsx @@ -0,0 +1,82 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen } from '@testing-library/react'; +import ReactMarkdown from 'react-markdown'; +import { ReportDrawer } from '../components/ReportDrawer'; + +vi.mock('../api', () => ({ + api: { + getReport: vi.fn(), + }, +})); + +describe('ReportDrawer', () => { + it('renders title and close button', () => { + render( + + ); + expect(screen.getByText('My Test Report')).toBeInTheDocument(); + expect(screen.getByText('×')).toBeInTheDocument(); + }); + + it('renders cached content without loading', () => { 
+ render( + + ); + expect(screen.queryByText('Loading...')).not.toBeInTheDocument(); + }); + + it('shows loading state without cached content', async () => { + const { api } = await import('../api'); + vi.mocked(api.getReport).mockImplementation(() => new Promise(() => {})); // never resolves + render( + + ); + expect(screen.getByText('Loading...')).toBeInTheDocument(); + }); + + it('calls onClose when close button clicked', () => { + const onClose = vi.fn(); + render( + + ); + screen.getByText('×').click(); + expect(onClose).toHaveBeenCalled(); + }); + + it('calls onClose when overlay clicked', () => { + const onClose = vi.fn(); + render( + + ); + const overlay = document.querySelector('.report-overlay'); + if (overlay) { + (overlay as HTMLElement).click(); + expect(onClose).toHaveBeenCalled(); + } + }); +}); diff --git a/frontend/src/__tests__/RightPanel.test.tsx b/frontend/src/__tests__/RightPanel.test.tsx new file mode 100644 index 0000000..4dc575d --- /dev/null +++ b/frontend/src/__tests__/RightPanel.test.tsx @@ -0,0 +1,185 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen } from '@testing-library/react'; +import { RightPanel } from '../components/RightPanel'; +import type { PlanTask, Resource, ContextUsage, SearchBudget } from '../types'; + +vi.mock('../api', () => ({ + api: { + getReport: vi.fn(), + }, +})); + +describe('RightPanel', () => { + const mockTasks: PlanTask[] = [ + { title: 'Read papers', status: 'completed' }, + { title: 'Implement model', status: 'in_progress' }, + { title: 'Write report', status: 'pending' }, + ]; + + const mockResources: Resource[] = [ + { title: 'Plan', type: 'plan', id: 'plan-1' }, + { title: 'ArXiv Paper', type: 'paper', id: 'paper-1' }, + { title: 'Dataset X', type: 'dataset', url: 'https://example.com' }, + ]; + + const mockContext: ContextUsage = { used: 50000, max: 200000, ratio: 0.25 }; + const mockSearchBudget: SearchBudget = { used: 5, max: 25 }; + + it('renders toggle button when not 
visible', () => { + render( + + ); + expect(screen.getByTitle('Tasks & resources')).toBeInTheDocument(); + }); + + it('renders tasks when visible', () => { + render( + + ); + expect(screen.getByText('Read papers')).toBeInTheDocument(); + expect(screen.getByText('Implement model')).toBeInTheDocument(); + expect(screen.getByText('Write report')).toBeInTheDocument(); + }); + + it('shows task count', () => { + render( + + ); + expect(screen.getByText('Tasks (1/3)')).toBeInTheDocument(); + }); + + it('shows "No tasks yet" when empty', () => { + render( + + ); + expect(screen.getByText('No tasks yet')).toBeInTheDocument(); + expect(screen.getByText('No resources yet')).toBeInTheDocument(); + }); + + it('renders context gauge', () => { + render( + + ); + expect(screen.getByText(/50k/)).toBeInTheDocument(); + expect(screen.getByText(/200k/)).toBeInTheDocument(); + }); + + it('renders search budget gauge', () => { + render( + + ); + expect(screen.getByText(/Searches:/)).toBeInTheDocument(); + }); + + it('renders resources', () => { + render( + + ); + expect(screen.getByText('Dataset X')).toBeInTheDocument(); + }); + + it('renders paper resource with export buttons', () => { + render( + + ); + expect(screen.getByText('.md')).toBeInTheDocument(); + expect(screen.getByText('.tex')).toBeInTheDocument(); + }); + + it('hides toggle badge when no tasks and visible', () => { + render( + + ); + expect(screen.queryByTitle('Tasks & resources')).not.toBeInTheDocument(); + }); + + it('shows toggle badge with task count when collapsed with tasks', () => { + render( + + ); + const badge = document.querySelector('.toggle-badge'); + expect(badge?.textContent).toBe('3'); + }); +}); diff --git a/frontend/src/__tests__/SandboxSettings.test.tsx b/frontend/src/__tests__/SandboxSettings.test.tsx new file mode 100644 index 0000000..3916af2 --- /dev/null +++ b/frontend/src/__tests__/SandboxSettings.test.tsx @@ -0,0 +1,60 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, 
screen, fireEvent, waitFor } from '@testing-library/react'; +import { SandboxSettings } from '../components/settings/SandboxSettings'; +import { api } from '../api'; + +vi.mock('../api', () => ({ + api: { + getSettings: vi.fn(), + updateSetting: vi.fn(), + }, +})); + +describe('SandboxSettings', () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(api.getSettings).mockResolvedValue({ settings: {} }); + }); + + it('renders heading and hint', async () => { + render(); + await waitFor(() => { + expect(screen.getByText(/Execution environment/)).toBeInTheDocument(); + }); + }); + + it('renders default sandbox select', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Default Sandbox')).toBeInTheDocument(); + expect(screen.getByText('Local')).toBeInTheDocument(); + expect(screen.getByText('SSH Remote')).toBeInTheDocument(); + expect(screen.getByText('Modal Cloud')).toBeInTheDocument(); + }); + }); + + it('renders modal token fields', async () => { + render(); + await waitFor(() => { + expect(screen.getByPlaceholderText('MODAL_TOKEN_ID')).toBeInTheDocument(); + expect(screen.getByPlaceholderText('MODAL_TOKEN_SECRET')).toBeInTheDocument(); + }); + }); + + it('renders save button', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Save Sandbox Settings')).toBeInTheDocument(); + }); + }); + + it('loads settings on mount', async () => { + vi.mocked(api.getSettings).mockResolvedValue({ + settings: { sandbox: { default_sandbox: 'ssh' } }, + }); + render(); + await waitFor(() => { + expect(api.getSettings).toHaveBeenCalled(); + }); + }); +}); diff --git a/frontend/src/__tests__/SettingsPage.test.tsx b/frontend/src/__tests__/SettingsPage.test.tsx new file mode 100644 index 0000000..e736773 --- /dev/null +++ b/frontend/src/__tests__/SettingsPage.test.tsx @@ -0,0 +1,44 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { render, screen, fireEvent, waitFor } from '@testing-library/react'; +import 
{ MemoryRouter } from 'react-router-dom'; +import { SettingsPage } from '../components/SettingsPage'; + +describe('SettingsPage', () => { + function renderSettings(path = '/settings/agent') { + return render( + + + + ); + } + + it('renders back link', () => { + renderSettings(); + expect(screen.getByText(/Back to chat/)).toBeInTheDocument(); + }); + + it('renders Settings title', () => { + renderSettings(); + expect(screen.getByText('Settings')).toBeInTheDocument(); + }); + + it('renders all nav links', () => { + renderSettings(); + expect(screen.getByText('Providers')).toBeInTheDocument(); + expect(screen.getByText('Agent')).toBeInTheDocument(); + expect(screen.getByText('Sandbox')).toBeInTheDocument(); + expect(screen.getByText('Writing')).toBeInTheDocument(); + }); + + it('highlights active nav link', () => { + renderSettings('/settings/agent'); + const agentLink = screen.getByText('Agent'); + expect(agentLink.className).toContain('active'); + }); + + it('does not highlight inactive links', () => { + renderSettings('/settings/agent'); + const providers = screen.getByText('Providers'); + expect(providers.className).not.toContain('active'); + }); +}); diff --git a/frontend/src/__tests__/SettingsPanel.test.tsx b/frontend/src/__tests__/SettingsPanel.test.tsx new file mode 100644 index 0000000..633e7a6 --- /dev/null +++ b/frontend/src/__tests__/SettingsPanel.test.tsx @@ -0,0 +1,104 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen, fireEvent, waitFor } from '@testing-library/react'; +import { SettingsPanel } from '../components/SettingsPanel'; +import { api } from '../api'; + +vi.mock('../api', () => ({ + api: { + getProviders: vi.fn(), + getSettings: vi.fn(), + updateSetting: vi.fn(), + }, +})); + +describe('SettingsPanel', () => { + beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(api.getProviders).mockResolvedValue({ + providers: [ + { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true }, + { id: 
'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: false }, + ], + }); + vi.mocked(api.getSettings).mockResolvedValue({ settings: {} }); + }); + + it('renders settings heading', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Settings')).toBeInTheDocument(); + }); + }); + + it('renders tab buttons', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Providers')).toBeInTheDocument(); + expect(screen.getByText('Agent')).toBeInTheDocument(); + expect(screen.getByText('Sandbox')).toBeInTheDocument(); + expect(screen.getByText('Writing')).toBeInTheDocument(); + }); + }); + + it('shows providers by default', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('OpenAI')).toBeInTheDocument(); + expect(screen.getByText('Anthropic')).toBeInTheDocument(); + }); + }); + + it('shows configured status', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Configured')).toBeInTheDocument(); + expect(screen.getByText('Not set')).toBeInTheDocument(); + }); + }); + + it('switches to agent tab', async () => { + render(); + await waitFor(() => { + fireEvent.click(screen.getByText('Agent')); + }); + expect(screen.getByText('Default Model')).toBeInTheDocument(); + expect(screen.getByText(/YOLO Mode/)).toBeInTheDocument(); + }); + + it('switches to sandbox tab', async () => { + render(); + await waitFor(() => { + fireEvent.click(screen.getByText('Sandbox')); + }); + expect(screen.getByText('Default Sandbox')).toBeInTheDocument(); + }); + + it('switches to writing tab', async () => { + render(); + await waitFor(() => { + fireEvent.click(screen.getByText('Writing')); + }); + expect(screen.getByText('Citation Style')).toBeInTheDocument(); + expect(screen.getByText('Export Format')).toBeInTheDocument(); + }); + + it('calls onClose when close clicked', async () => { + const onClose = vi.fn(); + render(); + await waitFor(() => { + const closeBtn = screen.getByText('×'); 
+ fireEvent.click(closeBtn); + }); + expect(onClose).toHaveBeenCalled(); + }); + + it('calls onClose when overlay clicked', async () => { + const onClose = vi.fn(); + render(); + await waitFor(() => { + const overlay = document.querySelector('.modal-overlay'); + fireEvent.click(overlay!); + }); + expect(onClose).toHaveBeenCalled(); + }); +}); diff --git a/frontend/src/__tests__/Sidebar.test.tsx b/frontend/src/__tests__/Sidebar.test.tsx new file mode 100644 index 0000000..1553492 --- /dev/null +++ b/frontend/src/__tests__/Sidebar.test.tsx @@ -0,0 +1,237 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen, fireEvent } from '@testing-library/react'; +import { MemoryRouter } from 'react-router-dom'; +import { Sidebar } from '../components/Sidebar'; +import type { Conversation, User } from '../types'; + +vi.mock('../api', () => ({ + setToken: vi.fn(), +})); + +const mockNavigate = vi.fn(); +vi.mock('react-router-dom', async (importOriginal) => { + const actual = await importOriginal() as any; + return { + ...actual, + useNavigate: () => mockNavigate, + }; +}); + +const mockUser: User = { id: 1, username: 'tester', display_name: 'Test User' }; + +const mockConversations: Conversation[] = [ + { + id: 1, + uuid: 'conv-1', + title: 'First conversation', + model: 'gpt-4o', + mode: 'general', + user_message_count: 3, + created_at: new Date().toISOString(), + updated_at: new Date().toISOString(), + }, + { + id: 2, + uuid: 'conv-2', + title: 'Research project', + model: 'claude-4', + mode: 'research', + user_message_count: 10, + created_at: new Date(Date.now() - 86400000 * 3).toISOString(), + updated_at: new Date(Date.now() - 86400000 * 3).toISOString(), + }, +]; + +describe('Sidebar', () => { + it('renders new chat button', () => { + render( + + + + ); + expect(screen.getByText('+ New Chat')).toBeInTheDocument(); + }); + + it('renders conversation titles', () => { + render( + + + + ); + expect(screen.getByText('First 
conversation')).toBeInTheDocument(); + expect(screen.getByText('Research project')).toBeInTheDocument(); + }); + + it('highlights current conversation', () => { + render( + + + + ); + const items = document.querySelectorAll('.conversation-item'); + expect(items[0].classList.contains('active')).toBe(true); + }); + + it('calls onSwitch when conversation clicked', () => { + const onSwitch = vi.fn(); + render( + + + + ); + fireEvent.click(screen.getByText('First conversation')); + expect(onSwitch).toHaveBeenCalledWith('conv-1'); + }); + + it('shows empty state when no conversations', () => { + render( + + + + ); + expect(screen.getByText('No conversations yet')).toBeInTheDocument(); + }); + + it('shows user display name', () => { + render( + + + + ); + expect(screen.getByText('Test User')).toBeInTheDocument(); + }); + + it('renders action buttons', () => { + render( + + + + ); + expect(screen.getByText(/Undo/)).toBeInTheDocument(); + expect(screen.getByText(/Compact/)).toBeInTheDocument(); + }); + + it('calls onAction for undo', () => { + const onAction = vi.fn(); + render( + + + + ); + fireEvent.click(screen.getByText(/Undo/)); + expect(onAction).toHaveBeenCalledWith('undo'); + }); + + it('calls onAction for compact', () => { + const onAction = vi.fn(); + render( + + + + ); + fireEvent.click(screen.getByText(/Compact/)); + expect(onAction).toHaveBeenCalledWith('compact'); + }); + + it('filters conversations by search', () => { + render( + + + + ); + const searchInput = screen.getByPlaceholderText('Search...'); + fireEvent.change(searchInput, { target: { value: 'Research' } }); + expect(screen.getByText('Research project')).toBeInTheDocument(); + expect(screen.queryByText('First conversation')).not.toBeInTheDocument(); + }); +}); diff --git a/frontend/src/__tests__/WritingSettings.test.tsx b/frontend/src/__tests__/WritingSettings.test.tsx new file mode 100644 index 0000000..c83a2c6 --- /dev/null +++ b/frontend/src/__tests__/WritingSettings.test.tsx @@ -0,0 +1,60 @@ 
+import { describe, it, expect, vi } from 'vitest'; +import { render, screen, waitFor } from '@testing-library/react'; +import { WritingSettings } from '../components/settings/WritingSettings'; +import { api } from '../api'; + +vi.mock('../api', () => ({ + api: { + getSettings: vi.fn(), + updateSetting: vi.fn(), + }, +})); + +describe('WritingSettings', () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(api.getSettings).mockResolvedValue({ settings: {} }); + }); + + it('renders hint', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Paper writing preferences.')).toBeInTheDocument(); + }); + }); + + it('renders citation style select', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Citation Style')).toBeInTheDocument(); + expect(screen.getByText('APA')).toBeInTheDocument(); + expect(screen.getByText('IEEE')).toBeInTheDocument(); + }); + }); + + it('renders export format select', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Export Format')).toBeInTheDocument(); + expect(screen.getByText('Markdown')).toBeInTheDocument(); + expect(screen.getByText('LaTeX')).toBeInTheDocument(); + }); + }); + + it('renders save button', async () => { + render(); + await waitFor(() => { + expect(screen.getByText('Save Writing Settings')).toBeInTheDocument(); + }); + }); + + it('loads settings on mount', async () => { + vi.mocked(api.getSettings).mockResolvedValue({ + settings: { writing: { citation_style: 'ieee' } }, + }); + render(); + await waitFor(() => { + expect(api.getSettings).toHaveBeenCalled(); + }); + }); +}); diff --git a/frontend/src/__tests__/useJobStatus.test.ts b/frontend/src/__tests__/useJobStatus.test.ts new file mode 100644 index 0000000..c648928 --- /dev/null +++ b/frontend/src/__tests__/useJobStatus.test.ts @@ -0,0 +1,213 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { renderHook, waitFor, act } from '@testing-library/react'; 
+import { useJobStatus } from '../hooks/useJobStatus'; +import { api } from '../api'; + +vi.mock('../api', () => ({ + api: { + getConversationJobs: vi.fn(), + }, +})); + +describe('useJobStatus', () => { + beforeEach(() => { + vi.resetAllMocks(); + vi.useFakeTimers({ shouldAdvanceTime: true }); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + async function flushPromises() { + await act(async () => { + await Promise.resolve(); + }); + } + + it('returns empty jobs when disabled', async () => { + const { result } = renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', enabled: false }) + ); + expect(result.current.activeJobs).toEqual([]); + expect(result.current.isProcessing).toBe(false); + }); + + it('returns empty jobs when conversationUuid is null', async () => { + const { result } = renderHook(() => + useJobStatus({ conversationUuid: null }) + ); + expect(result.current.activeJobs).toEqual([]); + }); + + it('fetches jobs on mount', async () => { + const mockJobs = [ + { job_id: 'job-1', status: 'queued', created_at: '2024-01-01T00:00:00Z' }, + { job_id: 'job-2', status: 'running', created_at: '2024-01-01T00:01:00Z' }, + ]; + vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: mockJobs }); + + renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 }) + ); + + await flushPromises(); + + expect(api.getConversationJobs).toHaveBeenCalledTimes(1); + }); + + it('detects processing state for queued jobs', async () => { + const mockJobs = [ + { job_id: 'job-1', status: 'queued', created_at: '2024-01-01T00:00:00Z' }, + ]; + vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: mockJobs }); + + const { result } = renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 }) + ); + + await flushPromises(); + + expect(result.current.isProcessing).toBe(true); + }); + + it('detects processing state for running jobs', async () => { + const mockJobs = [ + { job_id: 'job-1', status: 
'running', created_at: '2024-01-01T00:00:00Z' }, + ]; + vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: mockJobs }); + + const { result } = renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 }) + ); + + await flushPromises(); + + expect(result.current.isProcessing).toBe(true); + }); + + it('not processing when all jobs completed', async () => { + const mockJobs = [ + { job_id: 'job-1', status: 'completed', created_at: '2024-01-01T00:00:00Z' }, + ]; + vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: mockJobs }); + + const { result } = renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 }) + ); + + await flushPromises(); + + expect(result.current.isProcessing).toBe(false); + }); + + it('handles empty jobs array', async () => { + vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: [] }); + + const { result } = renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 }) + ); + + await flushPromises(); + + expect(result.current.activeJobs).toEqual([]); + }); + + it('handles missing jobs key', async () => { + vi.mocked(api.getConversationJobs).mockResolvedValue({}); + + const { result } = renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 }) + ); + + await flushPromises(); + + expect(result.current.activeJobs).toEqual([]); + }); + + it('calls onJobComplete when jobs finish', async () => { + const onComplete = vi.fn(); + // First call: processing, Second call: completed + vi.mocked(api.getConversationJobs) + .mockResolvedValueOnce({ jobs: [{ job_id: 'j1', status: 'running' }] }) + .mockResolvedValueOnce({ jobs: [{ job_id: 'j1', status: 'completed' }] }); + + const { result } = renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 5000, onJobComplete: onComplete }) + ); + + await flushPromises(); + expect(result.current.isProcessing).toBe(true); + + // Advance time to trigger next poll + 
await act(async () => { + vi.advanceTimersByTime(5000); + }); + await flushPromises(); + + expect(onComplete).toHaveBeenCalledWith('test-uuid'); + }); + + it('tracks last job id', async () => { + vi.mocked(api.getConversationJobs).mockResolvedValue({ + jobs: [{ job_id: 'job-last', status: 'running' }], + }); + + const { result } = renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 }) + ); + + await flushPromises(); + + expect(result.current.lastJobId).toBe('job-last'); + }); + + it('handles fetch errors gracefully', async () => { + vi.mocked(api.getConversationJobs).mockRejectedValue(new Error('Network error')); + const consoleDebug = vi.spyOn(console, 'debug').mockImplementation(() => {}); + + renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 }) + ); + + await flushPromises(); + + expect(consoleDebug).toHaveBeenCalled(); + consoleDebug.mockRestore(); + }); + + it('polls at configured interval', async () => { + vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: [] }); + + renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 5000 }) + ); + + await flushPromises(); + expect(api.getConversationJobs).toHaveBeenCalledTimes(1); + + await act(async () => { + vi.advanceTimersByTime(5000); + }); + await flushPromises(); + + expect(api.getConversationJobs).toHaveBeenCalledTimes(2); + }); + + it('refresh triggers immediate fetch', async () => { + vi.mocked(api.getConversationJobs).mockResolvedValue({ jobs: [] }); + + const { result } = renderHook(() => + useJobStatus({ conversationUuid: 'test-uuid', pollInterval: 10000 }) + ); + + await flushPromises(); + expect(api.getConversationJobs).toHaveBeenCalledTimes(1); + + await act(async () => { + await result.current.refresh(); + }); + + expect(api.getConversationJobs).toHaveBeenCalledTimes(2); + }); +}); diff --git a/frontend/src/__tests__/useSSE.test.ts b/frontend/src/__tests__/useSSE.test.ts new file mode 100644 index 
0000000..c6621dd --- /dev/null +++ b/frontend/src/__tests__/useSSE.test.ts @@ -0,0 +1,134 @@ +import { describe, it, expect, vi } from 'vitest'; +import { renderHook, act } from '@testing-library/react'; +import { useSSE } from '../hooks/useSSE'; +import type { AgentEvent } from '../types'; + +const instances: any[] = []; + +vi.stubGlobal('EventSource', vi.fn(function (this: any, url: string) { + this.url = url; + this.readyState = 0; + this.onopen = null; + this.onmessage = null; + this.onerror = null; + this.close = function () { + this.readyState = 2; + }; + instances.push(this); + return this; +})); + +(globalThis as any).EventSource.CONNECTING = 0; +(globalThis as any).EventSource.OPEN = 1; +(globalThis as any).EventSource.CLOSED = 2; + +function newest() { + return instances[instances.length - 1]; +} + +describe('useSSE', () => { + beforeEach(() => { + instances.length = 0; + }); + + it('creates EventSource when enabled is true', () => { + const onEvent = vi.fn(); + renderHook(() => useSSE(onEvent, true)); + + expect((globalThis as any).EventSource).toHaveBeenCalled(); + expect(instances.length).toBe(1); + expect(instances[0].url).toContain('/api/events'); + }); + + it('does not create EventSource when enabled is false', () => { + const onEvent = vi.fn(); + renderHook(() => useSSE(onEvent, false)); + + expect(instances.length).toBe(0); + }); + + it('appends token as query parameter', () => { + const onEvent = vi.fn(); + renderHook(() => useSSE(onEvent, true, 'test-token')); + + expect(instances.length).toBe(1); + expect(instances[0].url).toContain('token=test-token'); + }); + + it('sets connected to true on open', () => { + const onEvent = vi.fn(); + const { result } = renderHook(() => useSSE(onEvent, true)); + + act(() => { + newest().onopen(new Event('open')); + }); + + expect(result.current.connected).toBe(true); + }); + + it('calls onEvent when message received', () => { + const onEvent = vi.fn(); + renderHook(() => useSSE(onEvent, true)); + + const 
event: AgentEvent = { event_type: 'text_delta', data: { text: 'hello' } }; + act(() => { + newest().onmessage(new MessageEvent('message', { data: JSON.stringify(event) })); + }); + + expect(onEvent).toHaveBeenCalledWith(event); + }); + + it('does not call onEvent for messages without data', () => { + const onEvent = vi.fn(); + renderHook(() => useSSE(onEvent, true)); + + act(() => { + newest().onmessage(new MessageEvent('message', { data: '' })); + }); + + expect(onEvent).not.toHaveBeenCalled(); + }); + + it('handles malformed JSON gracefully', () => { + const onEvent = vi.fn(); + renderHook(() => useSSE(onEvent, true)); + + act(() => { + newest().onmessage(new MessageEvent('message', { data: 'not-json' })); + }); + + expect(onEvent).not.toHaveBeenCalled(); + }); + + it('sets connected to false on error and closes connection', () => { + vi.useFakeTimers(); + const onEvent = vi.fn(); + const { result } = renderHook(() => useSSE(onEvent, true)); + + act(() => { + newest().onopen(new Event('open')); + }); + expect(result.current.connected).toBe(true); + + act(() => { + newest().onerror(new Event('error')); + }); + expect(result.current.connected).toBe(false); + + vi.useRealTimers(); + }); + + it('calls onReconnect when reconnecting after disconnect', () => { + const onEvent = vi.fn(); + const onReconnect = vi.fn(); + renderHook(() => useSSE(onEvent, true, null, onReconnect)); + + act(() => { + newest().onopen(new Event('open')); + }); + act(() => { + newest().onopen(new Event('open')); + }); + expect(onReconnect).toHaveBeenCalled(); + }); +}); diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 9077474..4b0c245 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -18,5 +18,10 @@ export default defineConfig({ globals: true, setupFiles: ['./src/__tests__/setup.ts'], include: ['src/**/*.test.{ts,tsx}'], + coverage: { + provider: 'v8', + include: ['src/**/*.{ts,tsx}'], + exclude: ['src/__tests__/**', 'src/types.ts', 'src/main.tsx'], + }, 
}, }) diff --git a/qodana.yaml b/qodana.yaml index 843fce2..9a4826a 100644 --- a/qodana.yaml +++ b/qodana.yaml @@ -3,8 +3,8 @@ #################################################################################################################### version: "1.0" -linter: jetbrains/qodana-python-community:2026.1 +linter: jetbrains/qodana-python-community:2025.3 profile: name: qodana.recommended include: - - name: CheckDependencyLicenses \ No newline at end of file + - name: CheckDependencyLicenses diff --git a/site/docs/setup.md b/site/docs/setup.md index 8797c31..878a7c9 100644 --- a/site/docs/setup.md +++ b/site/docs/setup.md @@ -149,8 +149,8 @@ Run `make help` for the full list: | `make db-upgrade` | Run migrations | | **Testing** | | | `make test` | Run all tests (backend + frontend + docs build) | -| `make test-backend` | Backend tests only (149 tests) | -| `make test-frontend` | Frontend tests only (29 tests) | +| `make test-backend` | Backend tests only (591 tests) | +| `make test-frontend` | Frontend tests only (182 tests) | | `make test-docs` | Docs build check | | **Other** | | | `make check` | Type-check backend + frontend |