From 976e0b7b1b90dd7758e76c6d7dcd222af36c088c Mon Sep 17 00:00:00 2001 From: xprilion Date: Tue, 28 Apr 2026 16:45:29 +0530 Subject: [PATCH 1/3] Fix HF tool calls and stricter todo list --- backend/configs/prompts/system_prompt.yaml | 34 +++++++--- backend/openmlr/agent/session.py | 4 ++ backend/openmlr/tools/http_utils.py | 2 +- backend/openmlr/tools/plan.py | 60 ++++++++++++----- backend/openmlr/tools/registry.py | 76 +++++++++++++++++++++- 5 files changed, 149 insertions(+), 27 deletions(-) diff --git a/backend/configs/prompts/system_prompt.yaml b/backend/configs/prompts/system_prompt.yaml index 7647b59..edf8334 100644 --- a/backend/configs/prompts/system_prompt.yaml +++ b/backend/configs/prompts/system_prompt.yaml @@ -45,28 +45,42 @@ prompt: | Calls to ask_user will be **rejected by the system**. **Rules**: - 1. Follow the task plan — check it with `plan_tool get` if unsure - 2. Work through tasks one at a time, marking them in_progress then completed - 3. When completing a task, provide a summary and next_hints - 4. Add all papers, code, datasets as resources via plan_tool add_resource - 5. Do NOT ask the user questions — if something is ambiguous, make a reasonable + 1. ALWAYS start by calling `plan_tool get` to load the current task plan + 2. You MUST mark a task as `in_progress` before doing any work on it — + work tools (bash, write, edit, writing, research, sandbox, etc.) are + **blocked by the system** unless a task is currently in_progress + 3. Work through ONE task at a time — complete the current task before + starting the next one + 4. You MUST mark each task as `completed` with a `summary` and `next_hints` + when finished — the system auto-generates a completion report. The task + status will NOT be updated without a summary. + 5. Do NOT skip task completion — the system will block you from starting + the next task if the current one has no completion report + 6. Add all papers, code, datasets as resources via plan_tool add_resource + 7. Do NOT ask the user questions — if something is ambiguous, make a reasonable decision and document it in your completion report - 6. Keep pushing through the task list until done or interrupted - 7. Generate completion reports for each task + 8. Keep pushing through the task list until done or interrupted {% else %} ## CURRENT MODE: PLAN (default) Plan only — ask questions, gather context, create plan. No execution. {% endif %} - # Task Management + # Task Management (MANDATORY) - Always create a plan using `plan_tool` before starting work (in Plan mode) + - In Execute mode, ALWAYS call `plan_tool get` first to load the task plan + - Mark a task as `in_progress` BEFORE doing any work — work tools are blocked + by the system without an active in_progress task - When completing a task, call `plan_tool` update with status="completed", - include a `summary` of what was accomplished and `next_hints` + include a `summary` of what was accomplished and `next_hints` — the system + will reject the completion without a summary - This auto-generates a completion report stored as a resource - - ONE task in_progress at a time — complete current before starting next + - ONE task in_progress at a time — you cannot start a new task until the + current one is completed with a report or cancelled - Use `plan_tool add_resource` to track every paper, code repo, or doc + - The workflow is enforced by the system: get plan → mark in_progress → + do work → mark completed with summary → repeat for next task # Academic Research Tools diff --git a/backend/openmlr/agent/session.py b/backend/openmlr/agent/session.py index 8051f29..deddf24 100644 --- a/backend/openmlr/agent/session.py +++ b/backend/openmlr/agent/session.py @@ -37,6 +37,10 @@ class Session: # Sandbox reference sandbox: Any | None = None + # Plan task state (cached for tool enforcement — updated by plan_tool) + plan_tasks: list[dict] | None = None # None = not loaded yet + _plan_loaded: bool = False # True once we've checked DB or plan_tool ran + # Turn counter (for title generation etc.) turn_count: int = 0 diff --git a/backend/openmlr/tools/http_utils.py b/backend/openmlr/tools/http_utils.py index 137ca69..3c0a597 100644 --- a/backend/openmlr/tools/http_utils.py +++ b/backend/openmlr/tools/http_utils.py @@ -76,7 +76,7 @@ async def fetch_with_retry( for attempt in range(max_retries + 1): try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(follow_redirects=True) as client: response = await client.request( method, url, diff --git a/backend/openmlr/tools/plan.py b/backend/openmlr/tools/plan.py index f0d796e..ccb916f 100644 --- a/backend/openmlr/tools/plan.py +++ b/backend/openmlr/tools/plan.py @@ -124,6 +124,7 @@ async def _handle_plan( {"title": t.get("title", ""), "status": t.get("status", "pending")} for t in tasks ] await ops.upsert_conversation_tasks(db, conv_id, task_list) + _sync_session_plan(session, task_list) await _emit_plan(session, conv_id, db) # Auto-save plan as PLAN.md resource (pinned) @@ -144,6 +145,7 @@ async def _handle_plan( ] task_list.append({"title": title, "status": "pending"}) await ops.upsert_conversation_tasks(db, conv_id, task_list) + _sync_session_plan(session, task_list) await _emit_plan(session, conv_id, db) # Update PLAN.md @@ -164,7 +166,22 @@ async def _handle_plan( task = existing[task_index] old_status = task.status - # ENFORCEMENT: When starting a new task (in_progress), check if previous in_progress task has a report + # ── VALIDATION (all checks BEFORE any state changes) ── + + # ENFORCEMENT: Completing a task requires a summary + generates a report + if status == "completed" and old_status != "completed": + if not summary: + return ( + f"COMPLETION REQUIRES SUMMARY: To mark task {task_index} as completed, " + f"you must provide a 'summary' of what was accomplished.\n\n" + f"Example:\n" + f"plan_tool(operation='update', task_index={task_index}, status='completed', " + f"summary='Found 5 relevant papers on X technique...', " + f"next_hints='Review paper Y for implementation details')" + ), False + + # ENFORCEMENT: When starting a new task (in_progress), the previous + # in_progress task must be completed/cancelled with a report first. if status == "in_progress" and old_status != "in_progress": in_progress_tasks = [i for i, t in enumerate(existing) if t.status == "in_progress"] if in_progress_tasks: @@ -177,34 +194,31 @@ async def _handle_plan( ) if not has_report: return ( - f"⚠️ WORKFLOW VIOLATION: Cannot start task {task_index} while task {prev_idx} " - f"('{prev_task.title}') is still in progress.\n\n" + f"WORKFLOW VIOLATION: Cannot start task {task_index} while task {prev_idx} " + f"('{prev_task.title}') is still in progress without a completion report.\n\n" f"You must either:\n" f"1. Complete task {prev_idx} first with status='completed', summary, and next_hints\n" f"2. Cancel task {prev_idx} if it's no longer needed\n\n" f"This ensures a completion report is generated before moving on." ), False - # Update status + # ── STATE UPDATE (validation passed — persist changes) ── + task_list = [ {"title": t.title, "status": t.status, "priority": t.priority} for t in existing ] task_list[task_index]["status"] = status await ops.upsert_conversation_tasks(db, conv_id, task_list) + _sync_session_plan(session, task_list) await _emit_plan(session, conv_id, db) - # ENFORCEMENT: Completing a task requires a summary - if status == "completed" and old_status != "completed": - if not summary: - return ( - f"⚠️ COMPLETION REQUIRES SUMMARY: To mark task {task_index} as completed, " - f"you must provide a 'summary' of what was accomplished.\n\n" - f"Example:\n" - f"plan_tool(operation='update', task_index={task_index}, status='completed', " - f"summary='Found 5 relevant papers on X technique...', " - f"next_hints='Review paper Y for implementation details')" - ), False + # Update PLAN.md to reflect new status + plan_md = _generate_plan_md(task_list) + await ops.upsert_plan_resource(db, conv_id, plan_md) + + # ── POST-UPDATE: Generate completion report if task was completed ── + if status == "completed" and old_status != "completed": report = _generate_completion_report(task.title, summary, next_hints) report_id = f"report-{task_index}-{len(existing)}" @@ -224,9 +238,18 @@ async def _handle_plan( result += f"\nHints for next tasks: {next_hints}" return result, True + await _emit_resources(session, conv_id, db) return await _format_plan(db, conv_id), True elif operation == "get": + # Sync session plan state on read (lazy load for enforcement) + existing = await ops.get_conversation_tasks(db, conv_id) + if existing: + task_list = [ + {"title": t.title, "status": t.status, "priority": t.priority} for t in existing + ] + _sync_session_plan(session, task_list) + result = await _format_plan(db, conv_id) # Include any next_hints from recent reports for context resources = await ops.get_conversation_resources(db, conv_id) @@ -267,6 +290,13 @@ async def _handle_plan( return f"Unknown operation: {operation}", False +def _sync_session_plan(session, task_list: list[dict]) -> None: + """Update the session's cached plan state for tool enforcement.""" + if session: + session.plan_tasks = task_list + session._plan_loaded = True + + async def get_report_content(report_id: str) -> str | None: """Retrieve a stored report by ID. Used by the API.""" session_factory = _get_session_factory() diff --git a/backend/openmlr/tools/registry.py b/backend/openmlr/tools/registry.py index 168142e..eb46b84 100644 --- a/backend/openmlr/tools/registry.py +++ b/backend/openmlr/tools/registry.py @@ -1,9 +1,12 @@ """ToolRouter — registers, dispatches, and manages all agent tools.""" import inspect +import logging from ..agent.types import ToolSpec +logger = logging.getLogger("openmlr.tools.registry") + # Define which tools are allowed in each mode # Tools not listed are allowed in all modes MODE_TOOL_RESTRICTIONS = { @@ -118,6 +121,70 @@ def is_tool_allowed(self, name: str) -> tuple[bool, str]: error_msg = restrictions.get("blocked_message", "Tool '{tool}' not allowed in this mode.") return False, error_msg.format(tool=name, mode=self._current_mode) + async def _check_task_enforcement(self, tool_name: str, session) -> str | None: + """Check if a work tool can run — requires an in_progress task when a plan exists. + + Returns an error message string if blocked, or None if allowed. + """ + # plan_tool is always allowed (it's how you update task status) + plan_allowed = MODE_TOOL_RESTRICTIONS.get("plan", {}).get("allowed", set()) + if tool_name in plan_allowed: + return None + + # Lazy-load plan state from DB if not yet loaded this session + if not getattr(session, "_plan_loaded", False): + try: + plan_tasks = await self._load_plan_from_db(session) + session.plan_tasks = plan_tasks + session._plan_loaded = True + except Exception as e: + logger.warning(f"Failed to load plan state for enforcement: {e}") + # Don't block on DB errors — allow the tool call + return None + + plan_tasks = getattr(session, "plan_tasks", None) + if not plan_tasks: + # No plan exists — no enforcement needed + return None + + # Check if any task is in_progress + in_progress = any(t.get("status") == "in_progress" for t in plan_tasks) + if in_progress: + return None + + # Check if all tasks are completed/cancelled (work is done) + all_done = all(t.get("status") in ("completed", "cancelled") for t in plan_tasks) + if all_done: + return None + + return ( + f"TASK ENFORCEMENT: Cannot use '{tool_name}' without an active task.\n\n" + f"A task plan exists but no task is marked as in_progress. " + f"You must mark a task as in_progress before doing any work.\n\n" + f"Steps:\n" + f"1. Call plan_tool(operation='get') to review the current plan\n" + f"2. Call plan_tool(operation='update', task_index=N, status='in_progress') " + f"to start working on a task\n" + f"3. Then you can use work tools like '{tool_name}'\n\n" + f"After finishing work, mark the task completed with a summary before starting the next one." + ) + + async def _load_plan_from_db(self, session) -> list[dict] | None: + """Load plan tasks from DB for a session. Used for lazy enforcement init.""" + conv_id = getattr(session, "conversation_id", None) + if not conv_id: + return None + + from ..db import operations as ops + from .plan import _get_session_factory + + session_factory = _get_session_factory() + async with session_factory() as db: + tasks = await ops.get_conversation_tasks(db, conv_id) + if not tasks: + return None + return [{"title": t.title, "status": t.status, "priority": t.priority} for t in tasks] + def get_tool(self, name: str) -> ToolSpec | None: """Look up a tool by name.""" return self.tools.get(name) @@ -167,12 +234,19 @@ async def call_tool( allowed, error_msg = self.is_tool_allowed(name) if not allowed: warning = ( - f"⚠️ MODE VIOLATION: {error_msg}\n\n" + f"MODE VIOLATION: {error_msg}\n\n" f"Current mode: {self._current_mode.upper()}\n" f"To use this tool, ask the user to switch modes using ask_user with suggest_mode parameter." ) return warning, False + # ENFORCEMENT: In execute mode, work tools require an in_progress task. + # "Work tools" = anything NOT in the plan-mode allowed set (read-only tools). + if enforce_mode and session and self._current_mode == "execute": + violation = await self._check_task_enforcement(name, session) + if violation: + return violation, False + tool = self.tools.get(name) if not tool: return f"Unknown tool: {name}", False From a75c59af3e8707467f57d2df82bbf55b0df685c6 Mon Sep 17 00:00:00 2001 From: xprilion Date: Wed, 29 Apr 2026 10:03:44 +0530 Subject: [PATCH 2/3] improve layout --- backend/configs/prompts/system_prompt.yaml | 193 +++++--- backend/openmlr/agent/loop.py | 15 +- backend/openmlr/agent/session.py | 3 + backend/openmlr/db/operations.py | 37 +- backend/openmlr/models.py | 1 + backend/openmlr/routes/agent.py | 60 ++- backend/openmlr/routes/projects.py | 18 +- backend/openmlr/routes/terminal.py | 3 +- backend/openmlr/services/redis_pubsub.py | 37 ++ backend/openmlr/services/session_manager.py | 40 +- backend/openmlr/tasks/agent_tasks.py | 20 + backend/openmlr/tools/local.py | 146 ++++-- backend/openmlr/tools/plan.py | 171 ++++++- backend/openmlr/tools/registry.py | 57 ++- backend/openmlr/tools/research.py | 63 ++- backend/openmlr/tools/writing.py | 107 ++++- backend/openmlr/workspace/persistence.py | 12 + backend/pyproject.toml | 3 + backend/tests/conftest.py | 30 +- backend/tests/test_agent_loop.py | 103 +++++ backend/tests/test_conversations.py | 17 +- backend/tests/test_db_operations.py | 78 +++- backend/tests/test_models.py | 9 + backend/tests/test_redis_pubsub.py | 60 +++ backend/tests/test_tool_registry.py | 113 ++++- backend/tests/test_tools_local.py | 60 +++ backend/tests/test_tools_writing.py | 138 ++++++ backend/tests/test_workspace.py | 22 + frontend/src/App.tsx | 134 +++++- frontend/src/__tests__/InputArea.test.tsx | 22 +- .../src/__tests__/OnboardingModal.test.tsx | 8 +- .../src/__tests__/ProjectSelector.test.tsx | 231 ++++++++++ frontend/src/__tests__/RightPanel.test.tsx | 76 ++- .../src/__tests__/SandboxSettings.test.tsx | 60 --- frontend/src/__tests__/SettingsPanel.test.tsx | 109 ----- frontend/src/__tests__/Sidebar.test.tsx | 183 +++----- .../src/__tests__/TodoReviewDrawer.test.tsx | 197 ++++++++ frontend/src/api.ts | 6 +- frontend/src/components/CollapsiblePanel.tsx | 39 ++ frontend/src/components/ComputeSelector.tsx | 15 +- frontend/src/components/FileTree.tsx | 42 +- frontend/src/components/InputArea.tsx | 60 +-- frontend/src/components/OnboardingModal.tsx | 96 +++- frontend/src/components/ProjectSelector.tsx | 94 ++++ frontend/src/components/RightPanel.tsx | 433 ++++++------------ frontend/src/components/SettingsPanel.tsx | 369 --------------- frontend/src/components/Sidebar.tsx | 130 +++--- frontend/src/components/Terminal.tsx | 22 +- frontend/src/components/TodoReviewDrawer.tsx | 250 ++++++++++ .../components/settings/SandboxSettings.tsx | 106 ----- frontend/src/types.ts | 8 + 51 files changed, 2853 insertions(+), 1453 deletions(-) create mode 100644 frontend/src/__tests__/ProjectSelector.test.tsx delete mode 100644 frontend/src/__tests__/SandboxSettings.test.tsx delete mode 100644 frontend/src/__tests__/SettingsPanel.test.tsx create mode 100644 frontend/src/__tests__/TodoReviewDrawer.test.tsx create mode 100644 frontend/src/components/CollapsiblePanel.tsx create mode 100644 frontend/src/components/ProjectSelector.tsx delete mode 100644 frontend/src/components/SettingsPanel.tsx create mode 100644 frontend/src/components/TodoReviewDrawer.tsx delete mode 100644 frontend/src/components/settings/SandboxSettings.tsx diff --git a/backend/configs/prompts/system_prompt.yaml b/backend/configs/prompts/system_prompt.yaml index edf8334..99011f6 100644 --- a/backend/configs/prompts/system_prompt.yaml +++ b/backend/configs/prompts/system_prompt.yaml @@ -1,8 +1,8 @@ title: OpenMLR System Prompt -version: 6 +version: 7 prompt: | - You are OpenMLR, an ML research intern. You help users plan, research, + You are OpenMLR, an ML research assistant. You help users plan, research, write, and execute ML work end-to-end. # Mode System @@ -13,7 +13,7 @@ prompt: | ## CURRENT MODE: PLAN You are in **Plan mode**. Your job is to understand the task, ask clarifying - questions, gather context, and produce a comprehensive plan. + questions, and produce a comprehensive plan. **Available tools**: ask_user, plan_tool, read (local files), web_search, papers, github_read_file, github_find_examples, github_search_repos, @@ -25,15 +25,20 @@ prompt: | sandbox/code execution, compute_select, compute_sync_up, compute_sync_down. Calls to unavailable tools will be **rejected by the system**. - **Rules**: + **Plan Mode Rules**: 1. Ask clarifying questions using `ask_user` before making assumptions - 2. Search the web, papers, GitHub, and Hugging Face to gather context - 3. Read local project files with `read` to understand the codebase - 4. Create a structured plan using `plan_tool` with clear, actionable tasks - 5. The plan is auto-saved as PLAN.md in resources — the user can see it - 6. Do NOT execute any work — plan only - 7. Do NOT write content, run code, or make changes - 8. Be thorough in your plan — it will be the blueprint for Execute mode + 2. You may use search/papers for **quick feasibility checks** only + (e.g., confirming a method exists, checking if a dataset is available) + 3. Do NOT conduct comprehensive research in Plan mode — add research + tasks to the plan for Execute mode instead + 4. Read local project files with `read` to understand the codebase + 5. Create a structured plan using `plan_tool` with clear, actionable tasks + 6. The plan is auto-saved as PLAN.md in the workspace — visible in Files tab + 7. Do NOT execute any work — plan only + 8. Do NOT write content, run code, or make changes + 9. Research tasks (literature review, deep paper analysis, code search) + belong as Execute mode tasks — NOT done during planning + 10. Be thorough in your plan — it will be the blueprint for Execute mode {% elif mode == "execute" %} ## CURRENT MODE: EXECUTE @@ -44,7 +49,7 @@ prompt: | **Available tools**: ALL tools EXCEPT ask_user. Calls to ask_user will be **rejected by the system**. - **Rules**: + **Execute Mode Rules**: 1. ALWAYS start by calling `plan_tool get` to load the current task plan 2. You MUST mark a task as `in_progress` before doing any work on it — work tools (bash, write, edit, writing, research, sandbox, etc.) are @@ -52,13 +57,14 @@ prompt: | 3. Work through ONE task at a time — complete the current task before starting the next one 4. You MUST mark each task as `completed` with a `summary` and `next_hints` - when finished — the system auto-generates a completion report. The task - status will NOT be updated without a summary. + when finished — the system auto-generates a completion report saved to + the workspace as a file. The task status will NOT be updated without + a summary. 5. Do NOT skip task completion — the system will block you from starting the next task if the current one has no completion report - 6. Add all papers, code, datasets as resources via plan_tool add_resource - 7. Do NOT ask the user questions — if something is ambiguous, make a reasonable - decision and document it in your completion report + 6. Track all papers, code, datasets via `plan_tool add_resource` + 7. Do NOT ask the user questions — if something is ambiguous, make a + reasonable decision and document it in your completion report 8. Keep pushing through the task list until done or interrupted {% else %} @@ -66,6 +72,47 @@ prompt: | Plan only — ask questions, gather context, create plan. No execution. {% endif %} + # Tool Selection Guide + + Use this decision tree to pick the right tool: + + ## Finding papers and methods + - Quick check (1-3 queries): use `papers` or `web_search` directly + - Deep literature review (10+ papers): use `research` sub-agent (Execute only) + - Track every paper found: `plan_tool add_resource` with type=paper + + ## Writing an academic paper + - ALWAYS use the `writing` tool: create_project -> set_outline -> write_section + - NEVER use the `write` file tool for papers + - Paper auto-saves to `papers/` in the workspace (visible in Files tab) + - Write ALL sections — do NOT leave `[Not yet written]` placeholders + + ## Writing code, configs, scripts + - Use `write` for new files + - Use `edit` for modifying existing files (find-and-replace) + - Use `bash` for running commands, installing packages, executing scripts + - All file operations target the project workspace automatically + - Written files appear in the Files tab immediately + + ## Tracking research knowledge + - `workspace knowledge_add` for entities (papers, methods, datasets) + - `workspace knowledge_relate` for relationships between entities + - `workspace note` for research summaries and important findings + - These persist across conversations in the same project + + ## Running code and experiments + - `bash` executes in Docker isolation (8GB RAM, read-only root) + - Default timeout: 120s, max: 3600s + - Working directory is the project workspace + - Install dependencies first: `bash(command='pip install ...')` + - Always check environment before running: `bash(command='python --version')` + + ## Deep research + - Use the `research` sub-agent for comprehensive investigations + - It has independent context and uses: web_search, papers, github tools, hf tools + - Max 60 iterations, separate token budget + - Save findings using `workspace knowledge_add` and `plan_tool add_resource` + # Task Management (MANDATORY) - Always create a plan using `plan_tool` before starting work (in Plan mode) @@ -75,21 +122,43 @@ prompt: | - When completing a task, call `plan_tool` update with status="completed", include a `summary` of what was accomplished and `next_hints` — the system will reject the completion without a summary - - This auto-generates a completion report stored as a resource + - This auto-generates a completion report saved to `.project-meta/reports/` - ONE task in_progress at a time — you cannot start a new task until the current one is completed with a report or cancelled - - Use `plan_tool add_resource` to track every paper, code repo, or doc - - The workflow is enforced by the system: get plan → mark in_progress → - do work → mark completed with summary → repeat for next task + - The workflow is enforced by the system: get plan -> mark in_progress -> + do work -> mark completed with summary -> repeat for next task + - In Execute mode, creating/adding tasks requires user approval + + # Workspace Structure + + The project workspace is your persistent working directory. All files you + create appear in the Files tab for the user. Use the correct directory: + + ``` + code/ — source code, scripts, notebooks + data/ — datasets, preprocessed data + models/ — trained models, checkpoints + outputs/ — results, figures, tables + papers/ — paper drafts (auto-saved by writing tool) + research/ — research notes and searches + notes/ — markdown notes (saved by workspace note tool) + searches/ — saved search results + citations/ — citation files + logs/ — tool failures, compute probes, experiments + .project-meta/ — auto-managed by the system + plans/ — PLAN.md (auto-updated by plan_tool) + reports/ — completion reports (auto-generated) + knowledge.json — knowledge graph + ``` # Academic Research Tools For literature review and paper research, use the `papers` tool: - **search**: Search papers via OpenAlex (with Semantic Scholar fallback) - **arxiv_search**: Search arXiv directly (best for ML/CS/Physics preprints) - - **semantic_search**: Search directly via Semantic Scholar (good for recent papers) + - **semantic_search**: Search directly via Semantic Scholar - **trending**: Find highly-cited recent papers in a topic - - **details**: Get full metadata for a paper (by DOI, arXiv ID, or OpenAlex ID) + - **details**: Get full metadata for a paper (by DOI, arXiv ID, etc.) - **read_paper**: Read arXiv paper sections (provides table of contents first) - **citations**: Get references and citing papers - **recommend**: Find related papers @@ -97,41 +166,47 @@ prompt: | - **find_code**: Find code implementations via Papers With Code - **find_datasets**: Find related datasets via Papers With Code - **Search source recommendations**: - - Use `arxiv_search` for latest ML/AI preprints and CS research - - Use `search` (OpenAlex) for broad academic coverage across all fields - - Use `semantic_search` when you need abstracts or Semantic Scholar IDs + **Search source tips**: + - `arxiv_search` for latest ML/AI preprints and CS research + - `search` (OpenAlex) for broad academic coverage across all fields + - `semantic_search` when you need abstracts or Semantic Scholar IDs - For finding code implementations on GitHub: + For finding code on GitHub: - **github_search_repos**: Search repos by keywords, topics, or paper titles - - **github_find_examples**: Search code for specific patterns or implementations - - **github_get_readme**: Read a repo's README to understand what it does + - **github_find_examples**: Search code for specific patterns + - **github_get_readme**: Read a repo's README - **github_read_file**: Read specific files from repositories - **Research workflow**: - 1. Use `papers search` or `semantic_search` to find relevant papers - 2. Use `papers details` to get abstracts and metadata - 3. Use `papers read_paper` to read full arXiv papers section by section - 4. Use `github_search_repos` to find implementations - 5. Track all findings with `plan_tool add_resource` - # Paper Writing When writing a paper, use the `writing` tool exclusively: 1. `create_project` with a title 2. `set_outline` with section structure - 3. `write_section` for each section — the paper is AUTO-SAVED after each write + 3. `write_section` for each section — auto-saved after each write 4. `add_citation` for references 5. `get_draft` to review the full paper + 6. `list_sections` to check progress - **CRITICAL**: Do NOT use the `write` file tool to save papers. The `writing` tool - auto-saves to the database and the user can preview/export from the Paper tab. - Do NOT call `export` — the user handles export from the UI. + **CRITICAL**: Do NOT use the `write` file tool to save papers. The `writing` + tool auto-saves to the database and writes to `papers/` in the workspace. + + **CRITICAL**: Do NOT leave sections marked "[Not yet written]" or with + placeholder content. Write substantive content for EVERY section. Use + `get_draft` to verify all sections are filled before finishing. # Code Execution - Code runs inside a Docker container (/workspace) when needed. - Before running code: check the environment, install dependencies. + Code runs inside a Docker container when running locally: + - 8GB memory limit, 256 process limit + - Read-only root filesystem (only /tmp is writable) + - Bridge network (no host access) + - Working directory: the project workspace + + Before running code: + 1. Check the environment: `bash(command='python --version && pip list')` + 2. Install dependencies: `bash(command='pip install torch transformers')` + 3. Run your script: `bash(command='python train.py --epochs 10')` + Never modify the user's host environment directly. # Context Management @@ -148,10 +223,9 @@ prompt: | # MCP Tools - The user may have configured MCP (Model Context Protocol) servers that provide - additional tools. These tools appear in the available tools list below and can - be used like any other tool. MCP tools can provide capabilities like database - access, file system operations, external APIs, and more. + The user may have configured MCP (Model Context Protocol) servers that + provide additional tools. These appear in the available tools list and + can be used like any other tool. ## Available Tools {% for spec in tool_specs %} @@ -172,26 +246,21 @@ prompt: | If a project workspace is active, use it to persist knowledge across conversations: - - Use `workspace status` at the start of a conversation to understand what has - been done before (papers found, notes written, experiments run, known failures) - - Use `workspace knowledge_add` to record important entities (papers, methods, - datasets, findings) in the knowledge graph - - Use `workspace knowledge_relate` to link entities (e.g., paper --proposes--> method) - - Use `workspace note` to save research summaries and important findings + - Use `workspace status` at the start to check what has been done before + - Use `workspace knowledge_add` to record important entities + - Use `workspace knowledge_relate` to link entities + - Use `workspace note` to save research summaries - Use `workspace knowledge_summary` to review accumulated knowledge - - Use `workspace recent_failures` to check for known tool/API issues before retrying + - Use `workspace recent_failures` to check for known tool/API issues - The workspace persists independently of compute resources. Files in the workspace - (code/, data/, papers/, research/, outputs/) survive compute changes and new conversations. + The workspace persists independently of compute resources. Files survive + compute changes and new conversations. - **Important**: The workspace is the source of truth for the project. Always check it - before doing redundant work. Save important findings so future conversations can build on them. + **Important**: The workspace is the source of truth for the project. Always + check it before doing redundant work. # Compute Planning - When starting tasks that require significant computation (training models, processing large datasets, etc.): + When starting tasks that require significant computation: 1. Use `compute_plan` to verify the active node meets requirements 2. If not, use `compute_select` to switch to a suitable node 3. Always `sandbox_probe` before executing code on a node for the first time - - Note: The compute resource is separate from the project workspace. Switching compute - does not affect your workspace files. The workspace is always available locally. diff --git a/backend/openmlr/agent/loop.py b/backend/openmlr/agent/loop.py index 133f631..2ae3850 100644 --- a/backend/openmlr/agent/loop.py +++ b/backend/openmlr/agent/loop.py @@ -60,7 +60,9 @@ async def _run_agent(session: Session, tool_router, user_message: str, mode: str # Inject per-message mode hint (short reinforcement of system prompt rules) mode_hint = f"[Mode: {effective_mode.upper()}] " + ( - "Plan only — ask questions, gather context, create plan. No execution." + "Plan only — ask questions, create plan. " + "Use search/papers only for quick feasibility checks. " + "Do NOT do comprehensive research here — add research as Execute mode tasks." if effective_mode == "plan" else "Execute the plan — do the work, no questions. All tools except ask_user." ) @@ -167,6 +169,17 @@ async def _run_agent(session: Session, tool_router, user_message: str, mode: str ) ) + # Persist assistant text that accompanies tool calls (otherwise + # it only lives in the in-memory ContextManager and is lost on + # page refresh). + if result.content: + await session.emit( + AgentEvent( + event_type="assistant_message", + data={"content": result.content}, + ) + ) + # Check for approval-required tools needs_approval = [] auto_approve = [] diff --git a/backend/openmlr/agent/session.py b/backend/openmlr/agent/session.py index deddf24..efb8e34 100644 --- a/backend/openmlr/agent/session.py +++ b/backend/openmlr/agent/session.py @@ -34,6 +34,9 @@ class Session: # Question/answer flow (ask_user tool) pending_answers: Any | None = None + # TODO approval flow (plan_tool in execute mode) + pending_todo_approval: Any | None = None + # Sandbox reference sandbox: Any | None = None diff --git a/backend/openmlr/db/operations.py b/backend/openmlr/db/operations.py index c5dbaa9..8818a57 100644 --- a/backend/openmlr/db/operations.py +++ b/backend/openmlr/db/operations.py @@ -172,14 +172,35 @@ async def create_conversation( async def get_conversations(db: AsyncSession, user_id: int) -> list[Conversation]: + """Return all conversations for a user that belong to a project.""" result = await db.execute( select(Conversation) - .where(Conversation.user_id == user_id) + .where(Conversation.user_id == user_id, Conversation.project_id.isnot(None)) .order_by(Conversation.updated_at.desc()) ) return list(result.scalars().all()) +async def delete_orphan_conversations(db: AsyncSession, user_id: int) -> int: + """Delete conversations with no project (project_id IS NULL). + + Returns the count of deleted conversations. Messages, tasks, resources, + and jobs cascade-delete via the FK constraints. + """ + result = await db.execute( + select(Conversation).where( + Conversation.user_id == user_id, Conversation.project_id.is_(None) + ) + ) + orphans = list(result.scalars().all()) + count = len(orphans) + for conv in orphans: + await db.delete(conv) + if count > 0: + await db.commit() + return count + + async def get_conversation_by_id(db: AsyncSession, conv_id: int) -> Conversation | None: result = await db.execute(select(Conversation).where(Conversation.id == conv_id)) return result.scalar_one_or_none() @@ -394,6 +415,20 @@ async def update_task_status( return False +# ---- Workspace Path Helpers ---- + + +async def get_project_workspace_for_conversation(db: AsyncSession, conv_id: int) -> str | None: + """Resolve the project workspace path for a conversation (conv -> project -> workspace_path).""" + conv = await get_conversation_by_id(db, conv_id) + if not conv or not conv.project_id: + return None + project = await get_project_by_id(db, conv.project_id) + if not project: + return None + return project.workspace_path + + # ---- Conversation Resources ---- diff --git a/backend/openmlr/models.py b/backend/openmlr/models.py index 5e339a0..7a2f9d2 100644 --- a/backend/openmlr/models.py +++ b/backend/openmlr/models.py @@ -40,6 +40,7 @@ class ConversationCreate(BaseModel): title: str | None = "New conversation" model: str | None = None mode: str | None = "general" # "research", "writing", "coding", "general" + project_uuid: str | None = None # required — conversations must belong to a project class ConversationResponse(BaseModel): diff --git a/backend/openmlr/routes/agent.py b/backend/openmlr/routes/agent.py index cd4c68c..bafd95b 100644 --- a/backend/openmlr/routes/agent.py +++ b/backend/openmlr/routes/agent.py @@ -102,6 +102,9 @@ async def list_conversations( user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): + # Clean up orphan conversations (no project) on every list call. + # This handles the migration from pre-project-required era. + await ops.delete_orphan_conversations(db, user.id) convs = await ops.get_conversations(db, user.id) return {"conversations": [_conv_dict(c) for c in convs]} @@ -113,12 +116,21 @@ async def create_conversation( user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): + # Conversations must belong to a project + project_id = None + if body.project_uuid: + project = await ops.get_project_by_uuid(db, body.project_uuid, user.id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + project_id = project.id + conv = await ops.create_conversation( db, user.id, title=body.title, model=body.model, mode=body.mode, + project_id=project_id, ) _sm(request).current_conversation_id = conv.id return {"conversation": _conv_dict(conv)} @@ -365,12 +377,12 @@ async def send_message( user_default_model = user_agent_settings.get("default_model") if not sm.current_conversation_id: - # Create conversation with user's default model - conv = await ops.create_conversation(db, user.id, model=user_default_model) - sm.current_conversation_id = conv.id - else: - conv = await ops.get_conversation_by_id(db, sm.current_conversation_id) + raise HTTPException( + status_code=400, + detail="No active conversation. Create a conversation first.", + ) + conv = await ops.get_conversation_by_id(db, sm.current_conversation_id) if not conv: raise HTTPException(status_code=400, detail="No active conversation") @@ -592,6 +604,42 @@ async def submit_approval( return {"ok": True} +@router.post("/todo-approval") +async def submit_todo_approval( + request: Request, + user: User = Depends(get_current_user), +): + """Submit approval/rejection for proposed TODO list changes.""" + body = await request.json() + approved = body.get("approved", False) + tasks = body.get("tasks") # optional modified task list + + result = {"approved": approved, "tasks": tasks} + + # Try in-process session first (inline mode) + active = _sm(request).get_current_session() + if ( + active + and hasattr(active.session, "pending_todo_approval") + and active.session.pending_todo_approval + ): + if not active.session.pending_todo_approval.done(): + active.session.pending_todo_approval.set_result(result) + return {"ok": True} + + # Publish to Redis for background job workers + try: + from ..services.redis_pubsub import publish_todo_approval + + sm = _sm(request) + if sm.current_conversation_id: + await publish_todo_approval(sm.current_conversation_id, result) + except Exception as e: + logger.warning(f"Failed to relay todo approval via Redis: {e}") + + return {"ok": True} + + @router.post("/undo") async def undo(request: Request, user: User = Depends(get_current_user)): active = _sm(request).get_current_session() @@ -698,7 +746,7 @@ async def _persist(event: AgentEvent): }, ) except Exception: - pass + logger.exception("Failed to persist message to DB (event_type=%s)", event.event_type) active.session.on_event(_persist) diff --git a/backend/openmlr/routes/projects.py b/backend/openmlr/routes/projects.py index c3857e0..5cbc25f 100644 --- a/backend/openmlr/routes/projects.py +++ b/backend/openmlr/routes/projects.py @@ -58,7 +58,12 @@ def _get_workspaces_root() -> Path: async def get_or_create_default_project(db, user_id: int): - """Get (or create) the user's default project. Every user has exactly one.""" + """Get (or create) the user's default project. + + DEPRECATED: This exists only as a fallback for the terminal route. + New code should not call this — all conversations must belong to + a user-created project. + """ existing = await ops.get_project_by_slug(db, user_id, DEFAULT_PROJECT_SLUG) if existing: return existing @@ -71,7 +76,7 @@ async def get_or_create_default_project(db, user_id: int): user_id, DEFAULT_PROJECT_NAME, DEFAULT_PROJECT_SLUG, - description="Default workspace for all conversations", + description="Default workspace (legacy fallback)", workspace_path=workspace_path, settings={"is_default": True}, ) @@ -176,16 +181,15 @@ async def list_projects( user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): - """List all projects for the current user. Ensures default project exists.""" - # Ensure default project exists - await get_or_create_default_project(db, user.id) - + """List all projects for the current user. Excludes the legacy default project.""" projects = await ops.get_user_projects(db, user.id, include_archived=include_archived) result = [] for p in projects: + # Skip the legacy default project — it shouldn't appear in the UI + if p.settings and p.settings.get("is_default"): + continue convs = await ops.get_project_conversations(db, p.id) d = _project_dict(p, conv_count=len(convs)) - d["is_default"] = bool(p.settings and p.settings.get("is_default")) result.append(d) return {"projects": result} diff --git a/backend/openmlr/routes/terminal.py b/backend/openmlr/routes/terminal.py index cba8ba4..6f1c0a0 100644 --- a/backend/openmlr/routes/terminal.py +++ b/backend/openmlr/routes/terminal.py @@ -156,7 +156,8 @@ async def terminal_websocket( if project_uuid: project = await ops.get_project_by_uuid(db, project_uuid, user.id) else: - # Use default project + # Legacy fallback — terminal should always receive a project UUID + # from the frontend. This path will be removed in a future version. from .projects import get_or_create_default_project project = await get_or_create_default_project(db, user.id) diff --git a/backend/openmlr/services/redis_pubsub.py b/backend/openmlr/services/redis_pubsub.py index 6883114..d5d8fe3 100644 --- a/backend/openmlr/services/redis_pubsub.py +++ b/backend/openmlr/services/redis_pubsub.py @@ -213,3 +213,40 @@ async def wait_for_answers(conversation_id: int, timeout: float = 300) -> dict | except Exception as e: logger.warning(f"Failed to wait for answers from Redis: {e}") return None + + +# ── TODO approval relay for background jobs ────────────── + +TODO_APPROVAL_KEY_PREFIX = "openmlr:todo_approval:" + + +async def publish_todo_approval(conversation_id: int, result: dict) -> None: + """Publish TODO approval/rejection to Redis for the background worker.""" + try: + client = await get_redis() + key = f"{TODO_APPROVAL_KEY_PREFIX}{conversation_id}" + await client.set(key, json.dumps(result), ex=600) + await client.publish(f"{TODO_APPROVAL_KEY_PREFIX}notify", str(conversation_id)) + except Exception as e: + logger.warning(f"Failed to publish todo approval to Redis: {e}") + + +async def wait_for_todo_approval(conversation_id: int, timeout: float = 300) -> dict | None: + """Wait for TODO approval from Redis. Used by background worker's plan_tool.""" + try: + client = await get_redis() + key = f"{TODO_APPROVAL_KEY_PREFIX}{conversation_id}" + + elapsed = 0.0 + while elapsed < timeout: + data = await client.get(key) + if data: + await client.delete(key) + return json.loads(data) + await asyncio.sleep(1.0) + elapsed += 1.0 + + return None + except Exception as e: + logger.warning(f"Failed to wait for todo approval from Redis: {e}") + return None diff --git a/backend/openmlr/services/session_manager.py b/backend/openmlr/services/session_manager.py index 5a48cb3..cb73799 100644 --- a/backend/openmlr/services/session_manager.py +++ b/backend/openmlr/services/session_manager.py @@ -137,6 +137,31 @@ async def get_or_create_session( conversation_uuid=uuid, ) + # Resolve the project workspace path for workspace tools and local tools. + # If the conversation belongs to a project with a workspace_path, bind it. + project_workspace_path: str | None = None + if user_id and db: + try: + conv = await ops.get_conversation_by_id(db, conversation_id) + if conv and conv.project_id: + project = await ops.get_project_by_id(db, conv.project_id) + if project and project.workspace_path: + project_workspace_path = project.workspace_path + except Exception as e: + log.warning(f"Session {conversation_id}: failed to resolve project workspace - {e}") + + # Activate workspace tools (knowledge graph, notes, persistence) + # and local tools (read/write/edit/bash target the project directory) + if project_workspace_path: + from ..tools.local import set_project_workspace + from ..tools.workspace_tools import set_workspace_context + + set_workspace_context(project_workspace_path) + set_project_workspace(project_workspace_path) + log.info( + f"Session {conversation_id}: workspace context set to {project_workspace_path}" + ) + # If a compute node is configured, activate it if effective_node: try: @@ -224,8 +249,12 @@ async def get_or_create_session( compute_env=compute_env, ) - # Wire event broadcasting + # Wire event broadcasting — inject conversation_uuid so the frontend + # can filter events per conversation (mirrors the Celery path). async def _broadcast(event: AgentEvent): + if event.data is None: + event.data = {} + event.data["conversation_uuid"] = uuid await self.event_bus.broadcast(event) session.on_event(_broadcast) @@ -258,6 +287,15 @@ async def remove_session(self, conversation_id: int) -> None: active.session.pending_answers.cancel() except Exception: pass + if ( + hasattr(active.session, "pending_todo_approval") + and active.session.pending_todo_approval + ): + try: + if not active.session.pending_todo_approval.done(): + active.session.pending_todo_approval.cancel() + except Exception: + pass try: await active.sandbox_manager.destroy() except Exception: diff --git a/backend/openmlr/tasks/agent_tasks.py b/backend/openmlr/tasks/agent_tasks.py index 3e9dd2b..15145b9 100644 --- a/backend/openmlr/tasks/agent_tasks.py +++ b/backend/openmlr/tasks/agent_tasks.py @@ -113,6 +113,26 @@ async def _async_process_message( sandbox_manager = SandboxManager() tool_router = create_tool_router(sandbox_manager) + # Resolve project workspace for workspace tools and local tools + async with worker_session() as db: + try: + conv = await ops.get_conversation_by_id(db, conversation_id) + if conv and conv.project_id: + from ..db.operations import get_project_by_id + + project = await get_project_by_id(db, conv.project_id) + if project and project.workspace_path: + from ..tools.local import set_project_workspace + from ..tools.workspace_tools import set_workspace_context + + set_workspace_context(project.workspace_path) + set_project_workspace(project.workspace_path) + logger.info( + f"Worker job {job_id}: workspace context set to {project.workspace_path}" + ) + except Exception as e: + logger.warning(f"Worker job {job_id}: failed to resolve project workspace - {e}") + # Build and set system prompt session.context_manager.system_prompt = build_system_prompt( tool_specs=tool_router.get_raw_specs(), diff --git a/backend/openmlr/tools/local.py b/backend/openmlr/tools/local.py index fa0a6af..4904ed0 100644 --- a/backend/openmlr/tools/local.py +++ b/backend/openmlr/tools/local.py @@ -10,6 +10,7 @@ import asyncio import logging import os +from contextvars import ContextVar from pathlib import Path from ..agent.types import ToolSpec @@ -26,6 +27,30 @@ # Files outside this directory cannot be read/written/edited WORKSPACE_ROOT = os.environ.get("OPENMLR_WORKSPACE_ROOT", "") +# Per-async-context project workspace override. When a project is active its +# workspace path is injected here so that read/write/edit/bash tools +# automatically target the project directory (files then appear in the +# frontend FileTree). +_project_workspace_var: ContextVar[str | None] = ContextVar("project_workspace", default=None) + + +def set_project_workspace(path: str | None) -> None: + """Set the active project workspace for the current async context.""" + _project_workspace_var.set(path) + + +def _get_effective_root() -> Path: + """Return the effective workspace root for file operations. + + Priority: project workspace > WORKSPACE_ROOT env var > cwd. + """ + project_ws = _project_workspace_var.get(None) + if project_ws: + return Path(project_ws).resolve() + if WORKSPACE_ROOT: + return Path(WORKSPACE_ROOT).resolve() + return Path.cwd().resolve() + def _running_in_container() -> bool: """Detect if we're running inside a Docker container. @@ -57,41 +82,44 @@ def _validate_path(path: Path) -> tuple[Path, str | None]: """Validate path is within allowed workspace. Returns (resolved_path, error_or_none).""" try: resolved = path.resolve() - - # If WORKSPACE_ROOT is set, enforce it - if WORKSPACE_ROOT: - workspace = Path(WORKSPACE_ROOT).resolve() - try: - resolved.relative_to(workspace) - except ValueError: - return resolved, f"Path {resolved} is outside workspace {workspace}" - else: - # Default: allow paths under current working directory - cwd = Path.cwd().resolve() - try: - resolved.relative_to(cwd) - except ValueError: - # Also allow paths that are explicitly absolute and exist (for reading configs etc) - # But block obvious dangerous paths - dangerous_prefixes = [ - "/etc", - "/root", - "/var", - "/usr", - "/bin", - "/sbin", - "/boot", - "/sys", - "/proc", - ] - for prefix in dangerous_prefixes: - if str(resolved).startswith(prefix): - return ( - resolved, - f"Access denied: {resolved} is in a protected system directory", - ) - - return resolved, None + effective_root = _get_effective_root() + + try: + resolved.relative_to(effective_root) + return resolved, None + except ValueError: + pass + + # Also allow CWD even when project workspace is active (for read-only + # access to project configuration files etc.) + cwd = Path.cwd().resolve() + try: + resolved.relative_to(cwd) + return resolved, None + except ValueError: + pass + + # Block obvious dangerous paths + dangerous_prefixes = [ + "/etc", + "/root", + "/var", + "/usr", + "/bin", + "/sbin", + "/boot", + "/sys", + "/proc", + ] + for prefix in dangerous_prefixes: + if str(resolved).startswith(prefix): + return ( + resolved, + f"Access denied: {resolved} is in a protected system directory", + ) + + # If none of the above matched, reject unless under effective root + return resolved, f"Path {resolved} is outside workspace {effective_root}" except Exception as e: return path, f"Path validation error: {e}" @@ -101,9 +129,17 @@ def create_local_tools() -> list[ToolSpec]: ToolSpec( name="bash", description=( - "Execute a shell command. When running locally, uses Docker for isolation. " - "When running in a containerized deployment, commands run directly in the isolated environment. " - "Use for running scripts, installing packages, training models, etc." + "Execute a shell command in the project workspace.\n\n" + "Commands run in Docker isolation with: 8GB memory limit, 256 process limit, " + "read-only root filesystem (/tmp is writable), bridge network.\n\n" + "Use for: running scripts, installing packages (pip/conda), training models, " + "data processing, system commands.\n\n" + "The working directory is the project workspace. Files created here appear " + "in the Files tab.\n\n" + "Common patterns:\n" + "- Install deps: bash(command='pip install torch transformers')\n" + "- Run script: bash(command='python train.py --epochs 10')\n" + "- Check env: bash(command='python --version && pip list')" ), parameters={ "type": "object", @@ -124,7 +160,14 @@ def create_local_tools() -> list[ToolSpec]: ), ToolSpec( name="read", - description="Read a file from the local filesystem with line numbers.", + description=( + "Read a file from the project workspace with line numbers.\n\n" + "Returns up to 2000 lines starting from the given offset. " + "Use to inspect code, data files, logs, or configuration.\n" + "Relative paths resolve from the project workspace root.\n\n" + "For large files use offset/limit:\n" + "- read(path='train.py', offset=100, limit=50) reads lines 100-149" + ), parameters={ "type": "object", "properties": { @@ -141,7 +184,14 @@ def create_local_tools() -> list[ToolSpec]: ), ToolSpec( name="write", - description="Write content to a file. Creates parent directories if needed.", + description=( + "Write content to a file in the project workspace. Creates parent " + "directories automatically.\n\n" + "Use for: source code, configuration files, scripts, data files.\n" + "Do NOT use for academic papers — use the 'writing' tool instead.\n\n" + "Relative paths resolve from the project workspace root. " + "Written files appear in the Files tab immediately." + ), parameters={ "type": "object", "properties": { @@ -154,7 +204,13 @@ def create_local_tools() -> list[ToolSpec]: ), ToolSpec( name="edit", - description="Edit a file by replacing a specific string with another.", + description=( + "Edit an existing file by replacing a specific string with another.\n\n" + "Provide the exact string to find (old_string) and its replacement " + "(new_string). Use replace_all=true to replace all occurrences.\n\n" + "If old_string matches multiple times and replace_all is false, the " + "edit fails — provide more surrounding context to make it unique." + ), parameters={ "type": "object", "properties": { @@ -195,7 +251,7 @@ async def _handle_bash( command: str, timeout: int = 120, workdir: str = None, **kwargs ) -> tuple[str, bool]: timeout = min(int(timeout), 3600) - cwd = workdir or os.getcwd() + cwd = workdir or str(_get_effective_root()) # If we're already running inside a container, execute directly # The container itself provides isolation, so no need for Docker-in-Docker @@ -321,7 +377,7 @@ async def _handle_read(path: str, offset: int = 1, limit: int = 2000, **kwargs) try: target = Path(path).expanduser() if not target.is_absolute(): - target = Path.cwd() / target + target = _get_effective_root() / target # Security: Validate path is within allowed workspace target, error = _validate_path(target) @@ -369,7 +425,7 @@ async def _handle_write(path: str = "", content: str = "", **kwargs) -> tuple[st try: target = Path(path).expanduser() if not target.is_absolute(): - target = Path.cwd() / target + target = _get_effective_root() / target # Security: Validate path is within allowed workspace target, error = _validate_path(target) @@ -389,7 +445,7 @@ async def _handle_edit( try: target = Path(path).expanduser() if not target.is_absolute(): - target = Path.cwd() / target + target = _get_effective_root() / target # Security: Validate path is within allowed workspace target, error = _validate_path(target) diff --git a/backend/openmlr/tools/plan.py b/backend/openmlr/tools/plan.py index ccb916f..fda8961 100644 --- a/backend/openmlr/tools/plan.py +++ b/backend/openmlr/tools/plan.py @@ -1,8 +1,11 @@ """Plan tool — task tracking with completion reports and plan change approval. Tasks and resources are persisted to the database per conversation. +In Execute mode, structural plan changes (create, add) require user approval +via a dedicated TODO review UI. """ +import asyncio import logging from datetime import UTC, datetime @@ -30,16 +33,19 @@ def create_plan_tool() -> ToolSpec: return ToolSpec( name="plan_tool", description=( - "Manage the task plan. Emits live updates to the UI panel.\n" + "Manage the task plan. Updates are shown live in the Tasks panel and " + "PLAN.md is auto-saved to the workspace (.project-meta/plans/PLAN.md).\n\n" "Operations:\n" - " 'create' — create a new task list (propose to user for approval)\n" - " 'update' — change task status (when marking completed, provide a summary)\n" - " 'get' — show current plan\n" - " 'add' — add a task (propose to user for approval)\n" - " 'add_resource' — track a paper/URL/code/report the agent has read\n" - "When marking a task completed, include a 'summary' with key findings " - "and a 'next_hints' with recommendations for upcoming tasks. " - "The tool auto-generates a completion report stored as a resource." + "- 'create': Create a new task list. In Execute mode, requires user approval.\n" + "- 'update': Change task status. When completing: provide summary + next_hints.\n" + " The system auto-generates a report saved to .project-meta/reports/.\n" + "- 'get': Show the current plan with status and hints from the latest report.\n" + "- 'add': Add a single task. In Execute mode, requires user approval.\n" + "- 'add_resource': Track a paper/URL/code/report the agent has found.\n\n" + "Enforcement:\n" + "- Completing a task without a summary is rejected by the system.\n" + "- Starting a new task while another is in_progress is blocked.\n" + "- Work tools are blocked unless a task is marked in_progress." ), parameters={ "type": "object", @@ -114,6 +120,10 @@ async def _handle_plan( logger.warning("No conversation_id in session, plan tool cannot persist") return "Error: No active conversation.", False + # Detect current mode from session for approval gating + current_mode = getattr(session, "current_mode", "plan") if session else "plan" + needs_todo_approval = current_mode == "execute" and operation in ("create", "add") + session_factory = _get_session_factory() async with session_factory() as db: if operation == "create": @@ -123,6 +133,16 @@ async def _handle_plan( task_list = [ {"title": t.get("title", ""), "status": t.get("status", "pending")} for t in tasks ] + + # In Execute mode, request user approval before applying plan changes + if needs_todo_approval and session: + approved_tasks = await _request_todo_approval( + session, conv_id, db, "create", proposed_tasks=task_list + ) + if approved_tasks is None: + return "User rejected the proposed plan.", False + task_list = approved_tasks + await ops.upsert_conversation_tasks(db, conv_id, task_list) _sync_session_plan(session, task_list) await _emit_plan(session, conv_id, db) @@ -132,6 +152,10 @@ async def _handle_plan( await ops.upsert_plan_resource(db, conv_id, plan_md) await _emit_resources(session, conv_id, db) + # Write PLAN.md to workspace filesystem + await _write_to_workspace(conv_id, "PLAN.md", plan_md, ".project-meta/plans") + await _emit_files_changed(session, ".project-meta/plans") + return await _format_plan(db, conv_id), True elif operation == "add": @@ -143,15 +167,36 @@ async def _handle_plan( task_list = [ {"title": t.title, "status": t.status, "priority": t.priority} for t in existing ] - task_list.append({"title": title, "status": "pending"}) - await ops.upsert_conversation_tasks(db, conv_id, task_list) - _sync_session_plan(session, task_list) + new_task = {"title": title, "status": "pending"} + proposed_list = task_list + [new_task] + + # In Execute mode, request user approval before adding + if needs_todo_approval and session: + approved_tasks = await _request_todo_approval( + session, + conv_id, + db, + "add", + proposed_tasks=proposed_list, + current_tasks=task_list, + ) + if approved_tasks is None: + return "User rejected the proposed task addition.", False + proposed_list = approved_tasks + + await ops.upsert_conversation_tasks(db, conv_id, proposed_list) + _sync_session_plan(session, proposed_list) await _emit_plan(session, conv_id, db) # Update PLAN.md - await ops.upsert_plan_resource(db, conv_id, _generate_plan_md(task_list)) + plan_md = _generate_plan_md(proposed_list) + await ops.upsert_plan_resource(db, conv_id, plan_md) await _emit_resources(session, conv_id, db) + # Write PLAN.md to workspace filesystem + await _write_to_workspace(conv_id, "PLAN.md", plan_md, ".project-meta/plans") + await _emit_files_changed(session, ".project-meta/plans") + return await _format_plan(db, conv_id), True elif operation == "update": @@ -216,6 +261,9 @@ async def _handle_plan( plan_md = _generate_plan_md(task_list) await ops.upsert_plan_resource(db, conv_id, plan_md) + # Write PLAN.md to workspace filesystem + await _write_to_workspace(conv_id, "PLAN.md", plan_md, ".project-meta/plans") + # ── POST-UPDATE: Generate completion report if task was completed ── if status == "completed" and old_status != "completed": @@ -232,6 +280,15 @@ async def _handle_plan( ) await _emit_resources(session, conv_id, db) + # Write report to workspace filesystem + from ..workspace.persistence import WorkspacePersistence + + safe_title = WorkspacePersistence._sanitize_filename(task.title) + await _write_to_workspace( + conv_id, f"{safe_title}.md", report, ".project-meta/reports" + ) + await _emit_files_changed(session, ".project-meta/reports") + result = await _format_plan(db, conv_id) result += f"\n\nCompletion report generated for: {task.title}" if next_hints: @@ -239,6 +296,7 @@ async def _handle_plan( return result, True await _emit_resources(session, conv_id, db) + await _emit_files_changed(session, ".project-meta/plans") return await _format_plan(db, conv_id), True elif operation == "get": @@ -290,6 +348,34 @@ async def _handle_plan( return f"Unknown operation: {operation}", False +async def _write_to_workspace(conv_id: int, filename: str, content: str, subdir: str = "") -> None: + """Write a resource file to the project workspace so it appears in the FileTree. + + Silently skips if there is no project workspace (e.g., no active project). + """ + try: + session_factory = _get_session_factory() + async with session_factory() as db: + ws_path = await ops.get_project_workspace_for_conversation(db, conv_id) + if not ws_path: + return + from pathlib import Path + + target_dir = Path(ws_path) + if subdir: + target_dir = target_dir / subdir + target_dir.mkdir(parents=True, exist_ok=True) + (target_dir / filename).write_text(content, encoding="utf-8") + except Exception as e: + logger.warning(f"Failed to write {filename} to workspace: {e}") + + +async def _emit_files_changed(session, path: str = "") -> None: + """Notify the frontend that workspace files changed so FileTree refreshes.""" + if session: + await session.emit(AgentEvent(event_type="workspace_files_changed", data={"path": path})) + + def _sync_session_plan(session, task_list: list[dict]) -> None: """Update the session's cached plan state for tool enforcement.""" if session: @@ -384,6 +470,65 @@ async def _emit_resources(session, conv_id: int, db): ) +async def _request_todo_approval( + session, + conv_id: int, + db, + change_type: str, + proposed_tasks: list[dict], + current_tasks: list[dict] | None = None, +) -> list[dict] | None: + """Emit a todo_approval_required event and wait for the user's response. + + Returns the (possibly modified) task list if approved, or None if rejected. + Uses the same Future-based pattern as ask_user. + """ + import os + + # Build the payload for the frontend + payload = { + "change_type": change_type, # "create" or "add" + "proposed_tasks": proposed_tasks, + "current_tasks": current_tasks or [], + } + + await session.emit(AgentEvent(event_type="todo_approval_required", data=payload)) + + result = None + + # Try Redis-based relay first (background jobs) + try: + from ..services.redis_pubsub import wait_for_todo_approval + + if os.environ.get("USE_BACKGROUND_JOBS", "").lower() in ("true", "1", "yes"): + result = await wait_for_todo_approval(session.conversation_id, timeout=300) + except Exception: + pass + + # Fallback: in-process Future (inline mode) + if result is None: + future = asyncio.get_event_loop().create_future() + session.pending_todo_approval = future + + try: + result = await asyncio.wait_for(future, timeout=300) + except TimeoutError: + session.pending_todo_approval = None + return None + + session.pending_todo_approval = None + + if not result: + return None + + # result: {"approved": bool, "tasks": [...] | None} + if not result.get("approved"): + return None + + # If user modified the tasks, use their version + return result.get("tasks") or proposed_tasks + + async def _format_plan(db, conv_id: int) -> str: """Format the plan as a string for LLM context.""" tasks = await ops.get_conversation_tasks(db, conv_id) diff --git a/backend/openmlr/tools/registry.py b/backend/openmlr/tools/registry.py index eb46b84..aa84f52 100644 --- a/backend/openmlr/tools/registry.py +++ b/backend/openmlr/tools/registry.py @@ -57,6 +57,24 @@ } +# Tools that count as "research" for plan-mode budget tracking +_RESEARCH_TOOLS = { + "web_search", + "papers", + "github_search_repos", + "github_find_examples", + "github_get_readme", + "github_read_file", + "github_list_repos", + "hf_search_models", + "hf_model_info", + "hf_search_datasets", + "hf_dataset_info", + "hf_read_file", +} +_PLAN_RESEARCH_LIMIT = 5 # warn after this many research calls in plan mode + + class ToolRouter: """Central tool registry and dispatcher.""" @@ -67,6 +85,7 @@ def __init__(self): self._current_mode: str = "general" self._user_id: int | None = None self._db = None + self._plan_research_calls: int = 0 def set_context(self, user_id: int | None = None, db=None) -> None: """Set per-request context (user_id, db) for tools that need them.""" @@ -86,6 +105,8 @@ def register_many(self, specs: list[ToolSpec]) -> None: def set_mode(self, mode: str) -> None: """Set the current operating mode for tool filtering.""" + if mode != self._current_mode: + self._plan_research_calls = 0 self._current_mode = mode def get_mode(self) -> str: @@ -240,6 +261,14 @@ async def call_tool( ) return warning, False + # ENFORCEMENT: In plan mode, track research tool usage and warn if excessive. + if enforce_mode and self._current_mode == "plan" and name in _RESEARCH_TOOLS: + self._plan_research_calls += 1 + if self._plan_research_calls > _PLAN_RESEARCH_LIMIT: + logger.info( + f"Plan-mode research budget exceeded ({self._plan_research_calls} calls)" + ) + # ENFORCEMENT: In execute mode, work tools require an in_progress task. # "Work tools" = anything NOT in the plan-mode allowed set (read-only tools). if enforce_mode and session and self._current_mode == "execute": @@ -247,6 +276,22 @@ async def call_tool( if violation: return violation, False + # Prepare the plan-mode research budget warning (appended to output below) + _research_warning = "" + if ( + enforce_mode + and self._current_mode == "plan" + and name in _RESEARCH_TOOLS + and self._plan_research_calls > _PLAN_RESEARCH_LIMIT + ): + _research_warning = ( + f"\n\n[PLAN MODE RESEARCH BUDGET: You have made " + f"{self._plan_research_calls} research tool calls in Plan mode. " + f"This is getting excessive. Plan mode is for quick feasibility checks, " + f"not comprehensive research. Please add remaining research as tasks in " + f"the plan for Execute mode and finalize the plan now.]" + ) + tool = self.tools.get(name) if not tool: return f"Unknown tool: {name}", False @@ -267,7 +312,12 @@ async def call_tool( if "db" in sig.parameters and "db" not in kwargs: kwargs["db"] = self._db try: - return await tool.handler(**kwargs) if kwargs else await tool.handler(**arguments) + output, success = ( + await tool.handler(**kwargs) if kwargs else await tool.handler(**arguments) + ) + if _research_warning: + output += _research_warning + return output, success except TypeError as e: # Handle argument mismatches (model sending wrong param names) return ( @@ -279,7 +329,10 @@ async def call_tool( if self._mcp_client: try: result = await self._mcp_client.call_tool(name, arguments) - return _convert_mcp_content(result), True + output = _convert_mcp_content(result) + if _research_warning: + output += _research_warning + return output, True except Exception as e: return f"MCP tool error: {str(e)}", False diff --git a/backend/openmlr/tools/research.py b/backend/openmlr/tools/research.py index ac6603d..49ecda0 100644 --- a/backend/openmlr/tools/research.py +++ b/backend/openmlr/tools/research.py @@ -16,21 +16,50 @@ research a topic using the tools available to you. You have independent context that won't affect the main conversation. +## Available Tools + +You can use ONLY these read-only tools: +- **web_search**: General web search for docs, blog posts, tutorials +- **papers**: Academic paper search (OpenAlex, arXiv, Semantic Scholar, Citations, etc.) +- **github_read_file**: Read specific files from GitHub repos +- **github_find_examples**: Search code patterns across GitHub +- **hf_search_models**: Search Hugging Face models +- **hf_read_file**: Read files from Hugging Face repos + +You CANNOT write files, run code, modify the knowledge graph, or save notes. +The parent agent will save your findings after you return them. + +## Constraints + +- Maximum 60 tool calls (iterations) +- Token budget: ~190k tokens — stop and synthesize before hitting the limit +- If an API call fails, try an alternative source (e.g., arXiv instead of OpenAlex) +- Do NOT keep retrying the same failed call + ## Research Protocol -1. Start broad: search for papers, docs, and code examples -2. Go deep: read key papers section by section, crawl citation graphs -3. Cross-reference: compare methodologies across papers -4. Synthesize: create structured summaries with recipe tables +1. **Start broad**: search for papers, docs, and code examples (3-5 searches) +2. **Go deep**: read key papers section by section, crawl citation graphs +3. **Cross-reference**: compare methodologies across papers +4. **Synthesize**: create structured summaries with recipe tables + +## When to Stop + +Stop researching and synthesize when ANY of these are true: +- You have found 8+ relevant papers with clear methodology details +- You have used 40+ tool calls +- You are getting diminishing returns (same papers appearing repeatedly) +- You have enough information to answer the original question ## Output Format -When done, provide a structured summary: -- Key findings (bulleted list) -- Recipe table if applicable: +Provide a structured summary with: +- **Key findings** (bulleted list, most important first) +- **Recipe table** if applicable: | Paper | Result | Dataset | Method | Key Insight | -- Recommended approach with citations -- Links to relevant code/repos +- **Recommended approach** with citations +- **Links** to relevant code repos and implementations +- **Open questions** that could not be resolved Be thorough but concise. Focus on actionable information.""" @@ -39,10 +68,18 @@ def create_research_tool() -> ToolSpec: return ToolSpec( name="research", description=( - "Spawn an independent research sub-agent that searches docs, papers, " - "and code without affecting the main conversation context. " - "Use for deep dives into topics, literature surveys, " - "finding implementations, etc. Returns structured findings." + "Spawn an independent research sub-agent for deep investigation.\n\n" + "The sub-agent has its own context window and can make up to 60 tool " + "calls using: web_search, papers, github_read_file, github_find_examples, " + "hf_search_models, hf_read_file.\n\n" + "Use when you need:\n" + "- Comprehensive literature review (10+ papers)\n" + "- Deep analysis of a specific methodology\n" + "- Cross-referencing multiple papers' approaches\n" + "- Finding and comparing code implementations\n\n" + "Do NOT use for quick lookups — use papers/web_search directly instead.\n" + "The sub-agent returns structured findings; save important ones via " + "workspace knowledge_add and plan_tool add_resource." ), parameters={ "type": "object", diff --git a/backend/openmlr/tools/writing.py b/backend/openmlr/tools/writing.py index 8b3823a..2f0b560 100644 --- a/backend/openmlr/tools/writing.py +++ b/backend/openmlr/tools/writing.py @@ -70,7 +70,7 @@ async def _get_author_info(db, conv_id: int) -> dict | None: async def _save_project(conv_id: int, proj: dict) -> None: - """Save project metadata and draft to DB.""" + """Save project metadata and draft to DB and workspace filesystem.""" _projects[conv_id] = proj session_factory = _get_session_factory() @@ -89,16 +89,52 @@ async def _save_project(conv_id: int, proj: dict) -> None: draft, _ = _get_draft_from_proj(proj, author_info) await ops.upsert_paper_resource(db, conv_id, proj.get("title", "Paper"), draft) + # Also write the paper draft and metadata to the project workspace + # so they appear in the FileTree. + try: + ws_path = await ops.get_project_workspace_for_conversation(db, conv_id) + if ws_path: + from pathlib import Path + + papers_dir = Path(ws_path) / "papers" + papers_dir.mkdir(parents=True, exist_ok=True) + + # Sanitize title for filename + safe_title = ( + "".join( + c if c.isalnum() or c in "-_ " else "_" for c in proj.get("title", "paper") + )[:80] + .strip() + .replace(" ", "_") + or "paper" + ) + + (papers_dir / f"{safe_title}.md").write_text(draft, encoding="utf-8") + (papers_dir / f".{safe_title}.meta.json").write_text( + json.dumps(proj, indent=2, default=str), encoding="utf-8" + ) + except Exception as e: + logger.warning(f"Failed to write paper to workspace: {e}") + def create_writing_tool() -> ToolSpec: return ToolSpec( name="writing", description=( - "Manage academic paper writing. Supports section-by-section writing.\n" - "Operations: create_project, set_outline, write_section, refine_section, " - "add_citation, get_draft, list_sections.\n\n" - "The paper is auto-saved after each write. Users can preview and export " - "from the Paper tab in the UI — do NOT use the 'write' file tool for papers." + "Manage academic paper writing with section-by-section authoring.\n\n" + "Workflow: create_project -> set_outline -> write_section (for each) -> " + "add_citation -> get_draft to review.\n\n" + "Operations:\n" + "- create_project: Start a new paper with a title\n" + "- set_outline: Define section structure [{id, title, subsections}]\n" + "- write_section: Write content for a section by ID\n" + "- refine_section: Revise an existing section\n" + "- add_citation: Add a reference to the bibliography\n" + "- get_draft: Review the full rendered paper\n" + "- list_sections: Check which sections are done/pending\n\n" + "The paper auto-saves to the database AND to papers/ in the workspace " + "(visible in Files tab). Do NOT use the 'write' tool for papers.\n" + "You MUST write ALL sections — do NOT leave '[Not yet written]' placeholders." ), parameters={ "type": "object", @@ -198,6 +234,7 @@ async def _handle_writing( if ok and conv_id: await _save_project(conv_id, _projects[conv_id]) await _emit_resources(session, conv_id) + await _emit_files_changed(session, "papers") return result, ok # For all other operations, try to load existing project @@ -209,24 +246,28 @@ async def _handle_writing( if ok and conv_id: await _save_project(conv_id, _projects[conv_id]) await _emit_resources(session, conv_id) + await _emit_files_changed(session, "papers") return result, ok elif operation == "write_section": result, ok = _write_section(conv_id, section_id, content) if ok and conv_id: await _save_project(conv_id, _projects[conv_id]) await _emit_resources(session, conv_id) + await _emit_files_changed(session, "papers") return result, ok elif operation == "refine_section": result, ok = _refine_section(conv_id, section_id, content, feedback) if ok and content and conv_id: await _save_project(conv_id, _projects[conv_id]) await _emit_resources(session, conv_id) + await _emit_files_changed(session, "papers") return result, ok elif operation == "add_citation": result, ok = _add_citation(conv_id, citation) if ok and conv_id: await _save_project(conv_id, _projects[conv_id]) await _emit_resources(session, conv_id) + await _emit_files_changed(session, "papers") return result, ok elif operation == "get_draft": return await _get_draft(conv_id) @@ -285,10 +326,20 @@ def _write_section(conv_id: int, section_id: str, content: str) -> tuple[str, bo proj["sections"][section_id] = content written = len(proj["sections"]) total = _count_sections(proj["outline"]) - return ( + incomplete = _get_incomplete_sections(proj) + msg = ( f"Section '{section_id}' written ({len(content)} chars). " f"Progress: {written}/{total} sections. Paper auto-saved." - ), True + ) + if incomplete: + msg += ( + f"\n\nRemaining incomplete sections ({len(incomplete)}): " + + ", ".join(incomplete) + + "\nYou MUST write all remaining sections — do NOT leave placeholders." + ) + else: + msg += "\n\nAll sections are now written." + return msg, True def _refine_section(conv_id: int, section_id: str, content: str, feedback: str) -> tuple[str, bool]: @@ -335,7 +386,39 @@ async def _get_draft(conv_id: int) -> tuple[str, bool]: async with session_factory() as db: author_info = await _get_author_info(db, conv_id) - return _get_draft_from_proj(proj, author_info) + draft, ok = _get_draft_from_proj(proj, author_info) + + # Append warning about incomplete sections so the agent cannot + # consider the paper finished while placeholders remain. + incomplete = _get_incomplete_sections(proj) + if incomplete: + draft += ( + "\n\n---\n" + f"**WARNING — {len(incomplete)} section(s) still incomplete " + "(marked '[Not yet written]'):**\n" + ) + for sec in incomplete: + draft += f" - {sec}\n" + draft += ( + "\nYou MUST write content for every section before the paper " + "can be considered complete. Do NOT leave placeholder sections." + ) + + return draft, ok + + +def _get_incomplete_sections(proj: dict) -> list[str]: + """Return a list of section IDs that still have placeholder content.""" + incomplete = [] + for sec in proj.get("outline", []): + sid = sec.get("id", "") + if sid and sid not in proj.get("sections", {}): + incomplete.append(f"{sid} ({sec.get('title', '')})") + for sub in sec.get("subsections", []): + sub_id = sub.get("id", "") + if sub_id and sub_id not in proj.get("sections", {}): + incomplete.append(f"{sub_id} ({sub.get('title', '')})") + return incomplete def _get_draft_from_proj(proj: dict, author_info: dict | None = None) -> tuple[str, bool]: @@ -405,6 +488,12 @@ def _list_sections(conv_id: int) -> tuple[str, bool]: return "\n".join(lines), True +async def _emit_files_changed(session, path: str = "") -> None: + """Notify the frontend that workspace files changed so FileTree refreshes.""" + if session: + await session.emit(AgentEvent(event_type="workspace_files_changed", data={"path": path})) + + async def _emit_resources(session, conv_id: int) -> None: """Emit resources update event to frontend.""" if not session: diff --git a/backend/openmlr/workspace/persistence.py b/backend/openmlr/workspace/persistence.py index 1ca330d..241694a 100644 --- a/backend/openmlr/workspace/persistence.py +++ b/backend/openmlr/workspace/persistence.py @@ -298,6 +298,18 @@ def update_state(self, **kwargs) -> dict: self.save_state(state) return state + # ── Report Storage ───────────────────────────────────── + + def save_report(self, title: str, content: str) -> Path: + """Save a task completion report to the workspace.""" + dir_path = self._ensure_dir(".project-meta", "reports") + safe_title = self._sanitize_filename(title) + filename = f"{safe_title}.md" + filepath = dir_path / filename + filepath.write_text(content, encoding="utf-8") + log.debug(f"Saved report to {filepath}") + return filepath + # ── Plan Storage ───────────────────────────────────── def save_plan( diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 63816c9..aff4a1a 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -111,6 +111,9 @@ ignore = [ [tool.ruff.lint.isort] known-first-party = ["openmlr"] +[tool.pytest.ini_options] +asyncio_mode = "auto" + [tool.coverage.run] source = ["openmlr"] omit = ["openmlr/db/migrations/*"] diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 191665f..8a5dba8 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -10,7 +10,6 @@ pytestmark = pytest.mark.asyncio -import asyncio from collections.abc import AsyncGenerator import httpx @@ -55,14 +54,6 @@ # --------------------------------------------------------------------------- -@pytest_asyncio.fixture(scope="session") -def event_loop(): - """Create a single event loop for the entire test session.""" - loop = asyncio.new_event_loop() - yield loop - loop.close() - - @pytest_asyncio.fixture(autouse=True) async def _setup_db(): """Create all tables before each test and drop them after. @@ -76,6 +67,27 @@ async def _setup_db(): await conn.run_sync(Base.metadata.drop_all) +@pytest.fixture(autouse=True, scope="session") +def _dispose_engine_at_exit(): + """Dispose the async engine after the test session completes. + + Without this, the aiosqlite connection pool keeps background tasks + alive and prevents the process from exiting after tests finish. + """ + yield + import asyncio + + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(_test_engine.dispose()) + else: + loop.run_until_complete(_test_engine.dispose()) + except RuntimeError: + # No event loop available — create a temporary one for cleanup + asyncio.run(_test_engine.dispose()) + + async def _override_get_db() -> AsyncGenerator[AsyncSession, None]: """Dependency override that yields a test SQLite session.""" async with _TestSessionLocal() as session: diff --git a/backend/tests/test_agent_loop.py b/backend/tests/test_agent_loop.py index d093103..860ea94 100644 --- a/backend/tests/test_agent_loop.py +++ b/backend/tests/test_agent_loop.py @@ -217,6 +217,109 @@ async def test_cancelled_stops_early(self, mock_session, mock_router): mock_session.emit.assert_any_call(AgentEvent(event_type="interrupted")) +class TestAssistantMessagePersistence: + """Tests for Fix 1: assistant_message emitted for text + tool_calls. + + These tests use a real Session object (not mocked) so the full + agent loop runs without MagicMock interaction issues. + """ + + async def test_emits_assistant_message_with_tool_calls(self, config): + """When LLM returns text AND tool calls, assistant_message should be emitted + so the text is persisted to the DB (not just held in memory).""" + session = Session(config=config) + session.config.stream = False + session.config.yolo_mode = True + + router = MagicMock(spec=ToolRouter) + router.set_mode = MagicMock() + router.get_tool_specs_for_llm = MagicMock(return_value=[]) + router.get_tool = MagicMock(return_value=None) + router.call_tool = AsyncMock(return_value=("search results", True)) + + tool_call = ToolCall(id="tc1", name="web_search", arguments={"query": "test"}) + + emitted = [] + original_emit = session.emit + + async def capture_emit(event): + emitted.append(event) + + session.emit = capture_emit + + with ( + patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen, + patch("openmlr.agent.loop.detect_doom_loop", return_value=None), + ): + mock_gen.side_effect = [ + LLMResult( + content="Let me search for that.", + tool_calls=[tool_call], + finish_reason="tool_calls", + usage={"total_tokens": 50}, + ), + LLMResult( + content="Here are the results.", + tool_calls=[], + finish_reason="stop", + usage={"total_tokens": 80}, + ), + ] + await _run_agent(session, router, "find papers") + + assistant_msgs = [e for e in emitted if e.event_type == "assistant_message"] + assert len(assistant_msgs) >= 2 + contents = [e.data["content"] for e in assistant_msgs] + assert "Let me search for that." in contents + assert "Here are the results." in contents + + async def test_no_assistant_message_for_empty_text_with_tool_calls(self, config): + """When LLM returns empty text + tool calls, no assistant_message should be emitted.""" + session = Session(config=config) + session.config.stream = False + session.config.yolo_mode = True + + router = MagicMock(spec=ToolRouter) + router.set_mode = MagicMock() + router.get_tool_specs_for_llm = MagicMock(return_value=[]) + router.get_tool = MagicMock(return_value=None) + router.call_tool = AsyncMock(return_value=("file contents", True)) + + tool_call = ToolCall(id="tc1", name="read", arguments={"path": "f.txt"}) + + emitted = [] + + async def capture_emit(event): + emitted.append(event) + + session.emit = capture_emit + + with ( + patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen, + patch("openmlr.agent.loop.detect_doom_loop", return_value=None), + ): + mock_gen.side_effect = [ + LLMResult( + content="", + tool_calls=[tool_call], + finish_reason="tool_calls", + usage={"total_tokens": 30}, + ), + LLMResult( + content="Done.", + tool_calls=[], + finish_reason="stop", + usage={"total_tokens": 50}, + ), + ] + await _run_agent(session, router, "read the file") + + assistant_msgs = [e for e in emitted if e.event_type == "assistant_message"] + contents = [e.data["content"] for e in assistant_msgs] + assert "" not in contents + assert "Done." in contents + + class TestRunAgentTurn: async def test_delegates_to_run_agent(self, mock_session, mock_router): mock_session.context_manager.get_messages.return_value = [] diff --git a/backend/tests/test_conversations.py b/backend/tests/test_conversations.py index 744fb83..368ea5e 100644 --- a/backend/tests/test_conversations.py +++ b/backend/tests/test_conversations.py @@ -56,8 +56,18 @@ async def _patch_app_state(): # --------------------------------------------------------------------------- +async def _ensure_project(auth_client) -> str: + """Create a test project and return its UUID.""" + resp = await auth_client.post("/api/projects", json={"name": "Test Project"}) + assert resp.status_code == 200 + return resp.json()["project"]["uuid"] + + async def _create_conversation(auth_client, **overrides): """Shortcut: create a conversation and return the response body.""" + # Ensure a project exists so conversations aren't orphaned + if "project_uuid" not in overrides: + overrides["project_uuid"] = await _ensure_project(auth_client) payload = {"title": "Test Conv", "model": "test-model", "mode": "general"} payload.update(overrides) resp = await auth_client.post("/api/conversations", json=payload) @@ -117,9 +127,10 @@ async def test_list_conversations_empty(auth_client): async def test_list_conversations(auth_client): - """GET /api/conversations returns all conversations for the user.""" - await _create_conversation(auth_client, title="First") - await _create_conversation(auth_client, title="Second") + """GET /api/conversations returns all project-scoped conversations for the user.""" + project_uuid = await _ensure_project(auth_client) + await _create_conversation(auth_client, title="First", project_uuid=project_uuid) + await _create_conversation(auth_client, title="Second", project_uuid=project_uuid) resp = await auth_client.get("/api/conversations") assert resp.status_code == 200 diff --git a/backend/tests/test_db_operations.py b/backend/tests/test_db_operations.py index 1c595f9..f4cc6a1 100644 --- a/backend/tests/test_db_operations.py +++ b/backend/tests/test_db_operations.py @@ -28,9 +28,22 @@ async def test_get_conversations_empty(self, db_session: AsyncSession, test_user convs = await ops.get_conversations(db_session, test_user.id) assert convs == [] - async def test_get_conversations(self, db_session: AsyncSession, test_user): - await ops.create_conversation(db_session, test_user.id, title="Conv 1") - await ops.create_conversation(db_session, test_user.id, title="Conv 2") + async def test_get_conversations_excludes_orphans(self, db_session: AsyncSession, test_user): + """Conversations without a project_id should not be returned.""" + # Create orphan conversation (no project) + await ops.create_conversation(db_session, test_user.id, title="Orphan") + convs = await ops.get_conversations(db_session, test_user.id) + assert len(convs) == 0 # Orphans are filtered out + + async def test_get_conversations_with_project(self, db_session: AsyncSession, test_user): + """Conversations with a project_id should be returned.""" + project = await ops.create_project(db_session, test_user.id, "Test Project", "test-project") + await ops.create_conversation( + db_session, test_user.id, title="Conv 1", project_id=project.id + ) + await ops.create_conversation( + db_session, test_user.id, title="Conv 2", project_id=project.id + ) convs = await ops.get_conversations(db_session, test_user.id) assert len(convs) == 2 assert convs[0].title == "Conv 2" # most recent first @@ -94,8 +107,16 @@ async def test_conversations_isolated_by_user(self, db_session: AsyncSession, te db_session.add(user2) await db_session.flush() - await ops.create_conversation(db_session, test_user.id, title="User 1 Conv") - await ops.create_conversation(db_session, user2.id, title="User 2 Conv") + # Both users need projects for conversations to be visible + proj1 = await ops.create_project(db_session, test_user.id, "P1", "p1") + proj2 = await ops.create_project(db_session, user2.id, "P2", "p2") + + await ops.create_conversation( + db_session, test_user.id, title="User 1 Conv", project_id=proj1.id + ) + await ops.create_conversation( + db_session, user2.id, title="User 2 Conv", project_id=proj2.id + ) convs_u1 = await ops.get_conversations(db_session, test_user.id) convs_u2 = await ops.get_conversations(db_session, user2.id) @@ -105,6 +126,53 @@ async def test_conversations_isolated_by_user(self, db_session: AsyncSession, te assert convs_u2[0].title == "User 2 Conv" +class TestDeleteOrphanConversations: + async def test_deletes_orphan_conversations(self, db_session: AsyncSession, test_user): + """Conversations with project_id=NULL should be deleted.""" + await ops.create_conversation(db_session, test_user.id, title="Orphan 1") + await ops.create_conversation(db_session, test_user.id, title="Orphan 2") + count = await ops.delete_orphan_conversations(db_session, test_user.id) + assert count == 2 + + async def test_preserves_project_conversations(self, db_session: AsyncSession, test_user): + """Conversations with a project_id should NOT be deleted.""" + proj = await ops.create_project(db_session, test_user.id, "P", "p") + await ops.create_conversation( + db_session, test_user.id, title="Has Project", project_id=proj.id + ) + await ops.create_conversation(db_session, test_user.id, title="No Project") + count = await ops.delete_orphan_conversations(db_session, test_user.id) + assert count == 1 + convs = await ops.get_conversations(db_session, test_user.id) + assert len(convs) == 1 + assert convs[0].title == "Has Project" + + async def test_no_orphans_returns_zero(self, db_session: AsyncSession, test_user): + count = await ops.delete_orphan_conversations(db_session, test_user.id) + assert count == 0 + + +class TestGetProjectWorkspaceForConversation: + async def test_returns_workspace_path(self, db_session: AsyncSession, test_user): + proj = await ops.create_project( + db_session, test_user.id, "P", "p", workspace_path="/tmp/test-ws" + ) + conv = await ops.create_conversation( + db_session, test_user.id, title="C", project_id=proj.id + ) + path = await ops.get_project_workspace_for_conversation(db_session, conv.id) + assert path == "/tmp/test-ws" + + async def test_returns_none_for_orphan(self, db_session: AsyncSession, test_user): + conv = await ops.create_conversation(db_session, test_user.id, title="Orphan") + path = await ops.get_project_workspace_for_conversation(db_session, conv.id) + assert path is None + + async def test_returns_none_for_nonexistent_conv(self, db_session: AsyncSession): + path = await ops.get_project_workspace_for_conversation(db_session, 9999) + assert path is None + + class TestMessageOperations: @pytest_asyncio.fixture(autouse=True) async def _conv(self, db_session: AsyncSession, test_user): diff --git a/backend/tests/test_models.py b/backend/tests/test_models.py index 3efc436..ef9cf1b 100644 --- a/backend/tests/test_models.py +++ b/backend/tests/test_models.py @@ -71,6 +71,7 @@ def test_defaults(self): assert c.title == "New conversation" assert c.model is None assert c.mode == "general" + assert c.project_uuid is None def test_custom(self): c = ConversationCreate(title="Research Q1", model="gpt-4o", mode="research") @@ -78,6 +79,14 @@ def test_custom(self): assert c.model == "gpt-4o" assert c.mode == "research" + def test_with_project_uuid(self): + c = ConversationCreate(title="Test", project_uuid="abc-123-def") + assert c.project_uuid == "abc-123-def" + + def test_project_uuid_defaults_to_none(self): + c = ConversationCreate() + assert c.project_uuid is None + class TestConversationResponse: def test_creation(self): diff --git a/backend/tests/test_redis_pubsub.py b/backend/tests/test_redis_pubsub.py index 1b98681..aa2f20c 100644 --- a/backend/tests/test_redis_pubsub.py +++ b/backend/tests/test_redis_pubsub.py @@ -160,6 +160,61 @@ async def test_returns_none_on_redis_error(self): assert result is None +@pytest.mark.asyncio +class TestPublishTodoApproval: + async def test_sets_todo_approval_key_in_redis(self): + import json + + from openmlr.services.redis_pubsub import publish_todo_approval + + mock_redis = AsyncMock() + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + await publish_todo_approval(42, {"approved": True, "tasks": [{"title": "t1"}]}) + + mock_redis.set.assert_called_once() + call_args = mock_redis.set.call_args[0] + assert "todo_approval" in call_args[0] + assert "42" in call_args[0] + data = json.loads(call_args[1]) + assert data["approved"] is True + + async def test_handles_redis_error(self): + from openmlr.services.redis_pubsub import publish_todo_approval + + mock_redis = AsyncMock() + mock_redis.set.side_effect = Exception("Redis down") + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + await publish_todo_approval(1, {"approved": False}) # should not raise + + +@pytest.mark.asyncio +class TestWaitForTodoApproval: + async def test_returns_approval_when_set(self): + import json + + from openmlr.services.redis_pubsub import wait_for_todo_approval + + result_data = {"approved": True, "tasks": [{"title": "task1"}]} + mock_redis = AsyncMock() + mock_redis.get.return_value = json.dumps(result_data) + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + result = await wait_for_todo_approval(conversation_id=42, timeout=0.1) + + assert result is not None + assert result["approved"] is True + assert len(result["tasks"]) == 1 + + async def test_returns_none_on_timeout(self): + from openmlr.services.redis_pubsub import wait_for_todo_approval + + mock_redis = AsyncMock() + mock_redis.get.return_value = None + with patch("openmlr.services.redis_pubsub.get_redis", return_value=mock_redis): + result = await wait_for_todo_approval(conversation_id=42, timeout=0.1) + + assert result is None + + class TestModuleConstants: def test_channel_name(self): from openmlr.services.redis_pubsub import CHANNEL_NAME @@ -176,6 +231,11 @@ def test_interrupt_key_prefix(self): assert INTERRUPT_KEY_PREFIX == "openmlr:interrupt:" + def test_todo_approval_key_prefix(self): + from openmlr.services.redis_pubsub import TODO_APPROVAL_KEY_PREFIX + + assert TODO_APPROVAL_KEY_PREFIX == "openmlr:todo_approval:" + def test_redis_url_from_env(self, monkeypatch): monkeypatch.setenv("REDIS_URL", "redis://custom:6379/1") from importlib import reload diff --git a/backend/tests/test_tool_registry.py b/backend/tests/test_tool_registry.py index 13b9162..36ffbfa 100644 --- a/backend/tests/test_tool_registry.py +++ b/backend/tests/test_tool_registry.py @@ -1,8 +1,13 @@ -"""Tests for ToolRouter — registration, dispatch, mode filtering.""" +"""Tests for ToolRouter — registration, dispatch, mode filtering, research budget.""" import pytest -from openmlr.tools.registry import MODE_TOOL_RESTRICTIONS, ToolRouter +from openmlr.tools.registry import ( + _PLAN_RESEARCH_LIMIT, + _RESEARCH_TOOLS, + MODE_TOOL_RESTRICTIONS, + ToolRouter, +) pytestmark = pytest.mark.asyncio from openmlr.agent.types import ToolSpec @@ -273,3 +278,107 @@ async def test_plan_allowlist_has_no_phantom_entries(self): f"Plan allowlist contains phantom tool '{tool_name}' " f"that is not registered in the ToolRouter" ) + + +class TestPlanModeResearchBudget: + """Tests for the plan-mode research call budget and warning system.""" + + async def test_research_tools_constant_not_empty(self): + assert len(_RESEARCH_TOOLS) > 0 + assert "web_search" in _RESEARCH_TOOLS + assert "papers" in _RESEARCH_TOOLS + + async def test_budget_limit_positive(self): + assert _PLAN_RESEARCH_LIMIT > 0 + + async def test_counter_resets_on_mode_switch(self, router): + router.set_mode("plan") + router._plan_research_calls = 10 + router.set_mode("execute") + assert router._plan_research_calls == 0 + + async def test_counter_not_reset_on_same_mode(self, router): + router.set_mode("plan") + router._plan_research_calls = 5 + router.set_mode("plan") + assert router._plan_research_calls == 5 + + async def test_research_call_increments_counter(self, router): + """Calling a research tool in plan mode increments the counter.""" + + async def research_handler(**kwargs): + return "results", True + + research_tool = ToolSpec( + name="web_search", + description="Search", + parameters={"type": "object", "properties": {}}, + handler=research_handler, + ) + router.register(research_tool) + router.set_mode("plan") + assert router._plan_research_calls == 0 + + await router.call_tool("web_search", {}) + assert router._plan_research_calls == 1 + + async def test_budget_warning_appended_after_limit(self, router): + """After exceeding the limit, a warning should be appended to tool output.""" + + async def research_handler(**kwargs): + return "search results here", True + + research_tool = ToolSpec( + name="papers", + description="Papers", + parameters={"type": "object", "properties": {}}, + handler=research_handler, + ) + router.register(research_tool) + router.set_mode("plan") + router._plan_research_calls = _PLAN_RESEARCH_LIMIT # at limit + + output, success = await router.call_tool("papers", {}) + assert success is True + assert "PLAN MODE RESEARCH BUDGET" in output + assert "search results here" in output + + async def test_no_warning_under_limit(self, router): + """Under the limit, no warning should be appended.""" + + async def research_handler(**kwargs): + return "search results here", True + + research_tool = ToolSpec( + name="papers", + description="Papers", + parameters={"type": "object", "properties": {}}, + handler=research_handler, + ) + router.register(research_tool) + router.set_mode("plan") + router._plan_research_calls = 0 + + output, success = await router.call_tool("papers", {}) + assert success is True + assert "PLAN MODE RESEARCH BUDGET" not in output + + async def test_no_budget_tracking_in_execute_mode(self, router): + """Research budget should not apply in execute mode.""" + + async def research_handler(**kwargs): + return "results", True + + research_tool = ToolSpec( + name="web_search", + description="Search", + parameters={"type": "object", "properties": {}}, + handler=research_handler, + ) + router.register(research_tool) + router.set_mode("execute") + router._plan_research_calls = 100 # way over limit + + output, success = await router.call_tool("web_search", {}) + assert success is True + assert "PLAN MODE RESEARCH BUDGET" not in output diff --git a/backend/tests/test_tools_local.py b/backend/tests/test_tools_local.py index c598aa3..0e4295a 100644 --- a/backend/tests/test_tools_local.py +++ b/backend/tests/test_tools_local.py @@ -8,12 +8,14 @@ from openmlr.tools.local import ( CONTAINER_PREFIX, DOCKER_IMAGE, + _get_effective_root, _handle_edit, _handle_read, _handle_write, _running_in_container, _validate_path, create_local_tools, + set_project_workspace, ) @@ -101,6 +103,64 @@ def test_with_workspace_root_set(self, monkeypatch): assert error is None +class TestProjectWorkspace: + """Tests for project workspace targeting (set_project_workspace, _get_effective_root).""" + + def test_get_effective_root_defaults_to_cwd(self): + set_project_workspace(None) + root = _get_effective_root() + assert root == Path.cwd().resolve() + + def test_get_effective_root_uses_project_workspace(self, tmp_path): + set_project_workspace(str(tmp_path)) + try: + root = _get_effective_root() + assert root == tmp_path.resolve() + finally: + set_project_workspace(None) + + def test_get_effective_root_prefers_project_over_env(self, tmp_path, monkeypatch): + monkeypatch.setattr("openmlr.tools.local.WORKSPACE_ROOT", "/some/other/path") + set_project_workspace(str(tmp_path)) + try: + root = _get_effective_root() + assert root == tmp_path.resolve() + finally: + set_project_workspace(None) + + def test_validate_path_allows_project_workspace(self, tmp_path): + set_project_workspace(str(tmp_path)) + try: + path = tmp_path / "code" / "train.py" + resolved, error = _validate_path(path) + assert error is None + finally: + set_project_workspace(None) + + def test_validate_path_blocks_outside_project_workspace(self, tmp_path, monkeypatch): + other_dir = tmp_path / "other" + other_dir.mkdir() + project_dir = tmp_path / "project" + project_dir.mkdir() + set_project_workspace(str(project_dir)) + # Also change cwd so the "also allow CWD" fallback doesn't save it + monkeypatch.chdir(project_dir) + try: + path = other_dir / "secret.txt" + resolved, error = _validate_path(path) + assert error is not None + assert "outside workspace" in error + finally: + set_project_workspace(None) + + def test_set_project_workspace_clears(self): + set_project_workspace("/tmp/test-project") + set_project_workspace(None) + root = _get_effective_root() + # Should fall back to CWD or WORKSPACE_ROOT, not /tmp/test-project + assert str(root) != "/tmp/test-project" + + @pytest.mark.asyncio class TestHandleRead: async def test_reads_file_with_line_numbers(self, monkeypatch, tmp_path): diff --git a/backend/tests/test_tools_writing.py b/backend/tests/test_tools_writing.py index efd525d..0610b47 100644 --- a/backend/tests/test_tools_writing.py +++ b/backend/tests/test_tools_writing.py @@ -8,6 +8,7 @@ _create_project, _get_draft, _get_draft_from_proj, + _get_incomplete_sections, _list_sections, _refine_section, _set_outline, @@ -262,3 +263,140 @@ async def test_counts_with_subsections(self): async def test_empty_outline(self): assert _count_sections([]) == 0 + + +class TestGetIncompleteSections: + async def test_all_incomplete(self): + proj = { + "outline": [ + {"id": "s1", "title": "Intro"}, + {"id": "s2", "title": "Methods"}, + ], + "sections": {}, + } + incomplete = _get_incomplete_sections(proj) + assert len(incomplete) == 2 + assert "s1 (Intro)" in incomplete + assert "s2 (Methods)" in incomplete + + async def test_all_complete(self): + proj = { + "outline": [ + {"id": "s1", "title": "Intro"}, + {"id": "s2", "title": "Methods"}, + ], + "sections": {"s1": "content", "s2": "content"}, + } + incomplete = _get_incomplete_sections(proj) + assert len(incomplete) == 0 + + async def test_partial_complete(self): + proj = { + "outline": [ + {"id": "s1", "title": "Intro"}, + {"id": "s2", "title": "Methods"}, + ], + "sections": {"s1": "content"}, + } + incomplete = _get_incomplete_sections(proj) + assert len(incomplete) == 1 + assert "s2 (Methods)" in incomplete + + async def test_includes_subsections(self): + proj = { + "outline": [ + { + "id": "s1", + "title": "Methods", + "subsections": [ + {"id": "s1.1", "title": "Setup"}, + {"id": "s1.2", "title": "Training"}, + ], + }, + ], + "sections": {"s1": "content", "s1.1": "content"}, + } + incomplete = _get_incomplete_sections(proj) + assert len(incomplete) == 1 + assert "s1.2 (Training)" in incomplete + + async def test_empty_outline(self): + proj = {"outline": [], "sections": {}} + incomplete = _get_incomplete_sections(proj) + assert len(incomplete) == 0 + + +class TestWriteSectionWarnings: + async def test_shows_remaining_incomplete_sections(self): + from openmlr.tools.writing import _projects + + _projects.clear() + _create_project(conv_id=1, title="Test") + _set_outline( + conv_id=1, + outline=[ + {"id": "s1", "title": "Intro"}, + {"id": "s2", "title": "Methods"}, + {"id": "s3", "title": "Results"}, + ], + ) + result, ok = _write_section(conv_id=1, section_id="s1", content="Introduction text.") + assert ok is True + assert "Remaining incomplete sections (2)" in result + assert "s2" in result + assert "s3" in result + assert "MUST write all remaining" in result + _projects.clear() + + async def test_shows_all_complete_when_done(self): + from openmlr.tools.writing import _projects + + _projects.clear() + _create_project(conv_id=1, title="Test") + _set_outline(conv_id=1, outline=[{"id": "s1", "title": "Intro"}]) + result, ok = _write_section(conv_id=1, section_id="s1", content="Done.") + assert ok is True + assert "All sections are now written" in result + _projects.clear() + + +class TestGetDraftIncompleteWarning: + async def test_draft_includes_incomplete_warning(self): + from unittest.mock import AsyncMock, patch + + from openmlr.tools.writing import _projects + + _projects.clear() + _create_project(conv_id=1, title="Test") + _set_outline( + conv_id=1, outline=[{"id": "s1", "title": "Intro"}, {"id": "s2", "title": "Methods"}] + ) + _write_section(conv_id=1, section_id="s1", content="Written.") + + with patch( + "openmlr.tools.writing._get_author_info", new_callable=AsyncMock, return_value=None + ): + draft, ok = await _get_draft(conv_id=1) + assert ok is True + assert "WARNING" in draft + assert "1 section(s) still incomplete" in draft + assert "s2 (Methods)" in draft + _projects.clear() + + async def test_draft_no_warning_when_complete(self): + from unittest.mock import AsyncMock, patch + + from openmlr.tools.writing import _projects + + _projects.clear() + _create_project(conv_id=1, title="Test") + _set_outline(conv_id=1, outline=[{"id": "s1", "title": "Intro"}]) + _write_section(conv_id=1, section_id="s1", content="Complete content.") + + with patch( + "openmlr.tools.writing._get_author_info", new_callable=AsyncMock, return_value=None + ): + draft, ok = await _get_draft(conv_id=1) + assert ok is True + assert "WARNING" not in draft + _projects.clear() diff --git a/backend/tests/test_workspace.py b/backend/tests/test_workspace.py index d3f59c2..c436131 100644 --- a/backend/tests/test_workspace.py +++ b/backend/tests/test_workspace.py @@ -310,6 +310,28 @@ def test_state_persistence(self, workspace_dir): assert "Attention is effective for NLP" in state2["key_findings"] assert "Does it scale?" in state2["open_questions"] + def test_save_report(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + filepath = wp.save_report( + title="Literature Review", + content="# Report\n\nFound 5 relevant papers.", + ) + assert filepath.exists() + content = filepath.read_text() + assert "Found 5 relevant papers" in content + assert filepath.parent.name == "reports" + + def test_save_report_sanitizes_filename(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + filepath = wp.save_report( + title="Task: Find papers/methods!", + content="Report content.", + ) + assert filepath.exists() + # Filename should not contain special chars + assert "/" not in filepath.name + assert "!" not in filepath.name + def test_save_plan(self, workspace_dir): wp = WorkspacePersistence(workspace_dir) filepath = wp.save_plan( diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 334579b..46c812a 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -2,10 +2,11 @@ import { useState, useCallback, useEffect, useRef } from 'react'; import { Routes, Route, Navigate, useNavigate, useParams } from 'react-router-dom'; import { Copy, Check } from 'lucide-react'; import { ComputeSelector } from './components/ComputeSelector'; +import { ProjectSelector } from './components/ProjectSelector'; import { useSSE } from './hooks/useSSE'; import { useJobStatus } from './hooks/useJobStatus'; import { api } from './api'; -import type { AgentEvent, Message, Conversation, User, QuestionsPayload, PlanTask, Resource, ContextUsage, SearchBudget, Project } from './types'; +import type { AgentEvent, Message, Conversation, User, QuestionsPayload, PlanTask, Resource, ContextUsage, SearchBudget, Project, TodoApprovalPayload } from './types'; import { MessageList } from './components/MessageList'; import { InputArea, type Mode } from './components/InputArea'; import { Sidebar } from './components/Sidebar'; @@ -15,6 +16,7 @@ import { LoginPage } from './components/LoginPage'; import { QuestionDrawer } from './components/QuestionDrawer'; import { RightPanel } from './components/RightPanel'; import { ReportDrawer } from './components/ReportDrawer'; +import { TodoReviewDrawer } from './components/TodoReviewDrawer'; import { AuthGuard } from './components/AuthGuard'; import { OnboardingModal } from './components/OnboardingModal'; import { Terminal } from './components/Terminal'; @@ -111,10 +113,15 @@ function ChatUI({ const [showProjectModal, setShowProjectModal] = useState(false); const [showManageProjects, setShowManageProjects] = useState(false); const [terminalOpen, setTerminalOpen] = useState(false); + const [terminalConnected, setTerminalConnected] = useState(false); + const [todoApprovalPayload, setTodoApprovalPayload] = useState(null); + const [fileTreeRefreshKey, setFileTreeRefreshKey] = useState(0); - // Ref to always have current conv UUID in SSE callback (avoids stale closure) + // Refs to always have current values in SSE callback (avoids stale closure) const currentConvUuidRef = useRef(currentConvUuid); currentConvUuidRef.current = currentConvUuid; + const activeProjectRef = useRef(activeProject); + activeProjectRef.current = activeProject; // Sequence counter to discard stale switchConv responses after rapid switching const switchSeqRef = useRef(0); @@ -126,9 +133,17 @@ function ChatUI({ const isProcessing = currentStatus === 'processing'; const agentTurnActive = currentStatus !== 'idle'; - const loadConversations = useCallback(async () => { - try { - const data = await api.listConversations(); + const loadConversations = useCallback(async (project?: Project | null) => { + // Use the ref value when called without args (from SSE callbacks etc.) + const proj = project !== undefined ? project : activeProjectRef.current; + try { + let data; + if (proj?.uuid) { + data = await api.listProjectConversations(proj.uuid); + } else { + // No project selected — return empty. All conversations must belong to a project. + return []; + } setConversations(data.conversations || []); return data.conversations || []; } catch { @@ -163,11 +178,29 @@ function ChatUI({ } }, []); - // Initial load - load conversations and activate the correct one + // Initial load - load projects first, then conversations for the active project useEffect(() => { const init = async () => { - await Promise.all([loadComputeNodes(), loadProjects()]); - const convs = await loadConversations(); + const [, projData] = await Promise.all([ + loadComputeNodes(), + api.listProjects().catch(() => ({ projects: [] })), + ]); + const allProjects: Project[] = projData.projects || []; + setProjects(allProjects); + + // If no projects exist, the user needs to create one. + // The OnboardingModal or ProjectModal will handle this — don't proceed. + if (allProjects.length === 0) { + setShowProjectModal(true); + return; + } + + // Auto-select the first project if none is active + const proj = allProjects[0]; + setActiveProject(proj); + + // Load conversations for this project + const convs = await loadConversations(proj); // If URL has a conversation UUID, load it directly if (routeUuid) { @@ -175,10 +208,10 @@ function ChatUI({ return; } - // If no conversations exist, create one automatically + // If no conversations exist for this project, create one if (convs.length === 0) { try { - const data = await api.createConversation(); + const data = await api.createConversation(undefined, undefined, undefined, proj.uuid); const conv = data.conversation; setConversations([conv]); setCurrentConvUuid(conv.uuid); @@ -196,6 +229,25 @@ function ChatUI({ init(); }, []); + // Reload conversations when activeProject changes and auto-select the first + useEffect(() => { + if (!activeProject) return; + loadConversations(activeProject).then((convs) => { + if (convs.length > 0) { + const first = convs[0]; + setCurrentConvUuid(first.uuid); + navigate(`/${first.uuid}`, { replace: true }); + switchConv(first.uuid); + } else { + // No conversations in this project — clear state + setCurrentConvUuid(null); + setMessages([]); + setTasks([]); + setResources([]); + } + }); + }, [activeProject, loadConversations]); + // Handle navigation to a different conversation via URL change useEffect(() => { if (routeUuid && routeUuid !== currentConvUuid) { @@ -224,7 +276,7 @@ function ChatUI({ // Only update model if conversation has one explicitly set; don't overwrite the user's sticky model if (data.conversation?.model) setModel(data.conversation.model); setContextUsage(null); setSearchBudget(null); - setApprovalEvent(null); setQuestionsPayload(null); + setApprovalEvent(null); setQuestionsPayload(null); setTodoApprovalPayload(null); // Load persisted tasks and resources from database setTasks(data.tasks?.map((t: any) => ({ title: t.title, status: t.status })) || []); @@ -260,13 +312,17 @@ function ChatUI({ }; const handleNewConversation = async () => { + if (!activeProject) { + setShowProjectModal(true); + return; + } try { - const data = await api.createConversation(); + const data = await api.createConversation(undefined, undefined, undefined, activeProject.uuid); const conv = data.conversation; setConversations((prev) => [conv, ...prev]); setCurrentConvUuid(conv.uuid); setMessages([]); setTasks([]); setResources([]); setContextUsage(null); setSearchBudget(null); - setApprovalEvent(null); setQuestionsPayload(null); + setApprovalEvent(null); setQuestionsPayload(null); setTodoApprovalPayload(null); if (conv.model) setModel(conv.model); // Load default compute for new conversation await loadActiveCompute(conv.uuid); @@ -281,7 +337,7 @@ function ChatUI({ setConvStatuses((prev) => { const n = { ...prev }; delete n[uuid]; return n; }); if (currentConvUuid === uuid) { setCurrentConvUuid(null); setMessages([]); setTasks([]); setResources([]); - setApprovalEvent(null); setQuestionsPayload(null); + setApprovalEvent(null); setQuestionsPayload(null); setTodoApprovalPayload(null); setActiveCompute(null); navigate('/', { replace: true }); } @@ -483,11 +539,13 @@ function ChatUI({ break; } case 'resources_update': setResources(data?.resources || []); setRightPanelOpen(true); break; + case 'workspace_files_changed': setFileTreeRefreshKey((k) => k + 1); break; case 'context_usage': if (data) setContextUsage(data as ContextUsage); break; case 'search_budget': if (data) setSearchBudget(data as SearchBudget); break; case 'approval_required': setApprovalEvent(event); setCurrentConvStatus('waiting_approval'); break; + case 'todo_approval_required': setTodoApprovalPayload(data as TodoApprovalPayload); setCurrentConvStatus('waiting_approval'); break; case 'turn_complete': - setApprovalEvent(null); + setApprovalEvent(null); setTodoApprovalPayload(null); setMessages((prev) => { const c = prev.filter((m) => !(m.role === 'system' && m.content === '::thinking::')); const last = c[c.length - 1]; @@ -594,6 +652,13 @@ function ChatUI({ />
+ setShowProjectModal(true)} + onManageProjects={() => setShowManageProjects(true)} + /> setShowProjectModal(true)} - onManageProjects={() => setShowManageProjects(true)} + onTerminalToggle={() => setTerminalOpen((v) => !v)} />
{/* Empty state */} {messages.length === 0 && !effectiveProcessing && ( @@ -649,6 +713,10 @@ function ChatUI({ {approvalEvent && setApprovalEvent(null)} />} + {todoApprovalPayload && { + setTodoApprovalPayload(null); + setCurrentConvStatus('processing'); + }} onClose={() => { setTodoApprovalPayload(null); setCurrentConvStatus('idle'); }} />} {questionsPayload && { setQuestionsPayload(null); setCurrentConvStatus('processing'); @@ -668,7 +736,7 @@ function ChatUI({
{/* RightPanel is fixed position, doesn't affect flex layout */} - setRightPanelOpen((v) => !v)} onViewReport={(r) => setViewingReport(r)} /> + setRightPanelOpen((v) => !v)} onViewReport={(r) => setViewingReport(r)} onSearchBudgetChange={(newMax) => setSearchBudget((prev) => prev ? { ...prev, max: newMax } : prev)} />
{/* Terminal panel */} @@ -676,10 +744,23 @@ function ChatUI({ projectUuid={activeProject?.uuid || null} visible={terminalOpen} onToggle={() => setTerminalOpen((v) => !v)} + onConnectionChange={setTerminalConnected} + rightOffset={rightPanelOpen ? 288 : 48} /> {viewingReport && setViewingReport(null)} />} - {showProjectModal && setShowProjectModal(false)} onCreate={(p) => { setProjects((prev) => [p, ...prev]); setActiveProject(p); }} />} + {showProjectModal && setShowProjectModal(false)} onCreate={async (p) => { + setProjects((prev) => [p, ...prev]); + setActiveProject(p); + // Auto-create a first conversation in the new project + try { + const data = await api.createConversation(undefined, undefined, undefined, p.uuid); + const conv = data.conversation; + setConversations([conv]); + setCurrentConvUuid(conv.uuid); + navigate(`/${conv.uuid}`, { replace: true }); + } catch { /* */ } + }} />} {showManageProjects && setShowManageProjects(false)} onChanged={() => { loadProjects(); }} />} ); @@ -703,9 +784,14 @@ export default function App() { }).catch(() => {}); }, []); - const handleOnboardingComplete = useCallback((selectedModel: string) => { + const handleOnboardingComplete = useCallback((selectedModel: string, project?: Project) => { setModel(selectedModel); setNeedsOnboarding(false); + // If onboarding created a project, reload to pick it up + if (project) { + // Force a full page reload to reinitialize with the new project + window.location.reload(); + } }, []); return ( diff --git a/frontend/src/__tests__/InputArea.test.tsx b/frontend/src/__tests__/InputArea.test.tsx index eb098b2..93d479a 100644 --- a/frontend/src/__tests__/InputArea.test.tsx +++ b/frontend/src/__tests__/InputArea.test.tsx @@ -92,32 +92,12 @@ describe('InputArea', () => { expect(sendBtn).toBeDisabled(); }); - it('shows Send & Execute button in plan mode', () => { + it('does not render a separate Execute button', () => { render(); - const execBtn = screen.getByTitle('Send & switch to Execute mode (Cmd+Enter)'); - expect(execBtn).toBeInTheDocument(); - }); - - it('Send & Execute button not shown in execute mode', () => { - render(); const execBtn = screen.queryByTitle('Send & switch to Execute mode (Cmd+Enter)'); expect(execBtn).toBeNull(); }); - it('Send & Execute calls onSend with execute mode and switches', () => { - const onSend = vi.fn(); - const onModeChange = vi.fn(); - const onTextChange = vi.fn(); - render( - - ); - const execBtn = screen.getByTitle('Send & switch to Execute mode (Cmd+Enter)'); - fireEvent.click(execBtn); - expect(onModeChange).toHaveBeenCalledWith('execute'); - expect(onSend).toHaveBeenCalledWith('do it', 'execute'); - expect(onTextChange).toHaveBeenCalledWith(''); - }); - it('keyboard shortcut Cmd+M toggles from execute to plan', () => { const onModeChange = vi.fn(); render(); diff --git a/frontend/src/__tests__/OnboardingModal.test.tsx b/frontend/src/__tests__/OnboardingModal.test.tsx index 8ac92bc..0fce820 100644 --- a/frontend/src/__tests__/OnboardingModal.test.tsx +++ b/frontend/src/__tests__/OnboardingModal.test.tsx @@ -163,7 +163,7 @@ describe('OnboardingModal', () => { }); }); - it('selects model and completes', async () => { + it('selects model and advances to project step', async () => { const onComplete = vi.fn(); vi.mocked(api.setModel).mockResolvedValue({ ok: true }); vi.mocked(api.getProviders).mockResolvedValue({ @@ -177,9 +177,13 @@ describe('OnboardingModal', () => { fireEvent.click(screen.getByText('GPT-4o')); + // After selecting a model, should advance to project creation step await waitFor(() => { expect(api.setModel).toHaveBeenCalledWith('openai/gpt-4o'); - expect(onComplete).toHaveBeenCalledWith('openai/gpt-4o'); + expect(screen.getByText('Create Project & Start')).toBeInTheDocument(); }); + + // onComplete should NOT be called yet — user needs to create a project + expect(onComplete).not.toHaveBeenCalled(); }); }); diff --git a/frontend/src/__tests__/ProjectSelector.test.tsx b/frontend/src/__tests__/ProjectSelector.test.tsx new file mode 100644 index 0000000..ea1fc54 --- /dev/null +++ b/frontend/src/__tests__/ProjectSelector.test.tsx @@ -0,0 +1,231 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen, fireEvent } from '@testing-library/react'; +import { ProjectSelector } from '../components/ProjectSelector'; +import type { Project } from '../types'; + +const mockProjects: Project[] = [ + { + id: 1, + uuid: 'proj-1', + name: 'ML Research', + slug: 'ml-research', + description: 'Machine learning research', + workspace_path: '/tmp/ml', + status: 'active', + settings: {}, + created_at: '2024-01-01', + updated_at: '2024-01-01', + conversation_count: 3, + }, + { + id: 2, + uuid: 'proj-2', + name: 'NLP Project', + slug: 'nlp-project', + description: null, + workspace_path: '/tmp/nlp', + status: 'active', + settings: {}, + created_at: '2024-01-02', + updated_at: '2024-01-02', + conversation_count: 1, + }, +]; + +describe('ProjectSelector', () => { + it('renders active project name', () => { + render( + + ); + expect(screen.getByText('ML Research')).toBeInTheDocument(); + }); + + it('renders "Project" when no active project', () => { + render( + + ); + expect(screen.getByText('Project')).toBeInTheDocument(); + }); + + it('does NOT show "All Conversations" option', () => { + render( + + ); + fireEvent.click(screen.getByText('ML Research')); + expect(screen.queryByText('All Conversations')).not.toBeInTheDocument(); + }); + + it('shows all projects in dropdown', () => { + render( + + ); + fireEvent.click(screen.getByText('ML Research')); + expect(screen.getByText('NLP Project')).toBeInTheDocument(); + }); + + it('highlights active project', () => { + render( + + ); + fireEvent.click(screen.getByText('ML Research')); + // The active project button in dropdown should have primary styling + const buttons = screen.getAllByRole('button'); + const activeButton = buttons.find((b) => b.textContent?.includes('ML Research') && b.className.includes('bg-primary')); + expect(activeButton).toBeTruthy(); + }); + + it('calls onSelectProject when project clicked', () => { + const onSelect = vi.fn(); + render( + + ); + fireEvent.click(screen.getByText('ML Research')); + fireEvent.click(screen.getByText('NLP Project')); + expect(onSelect).toHaveBeenCalledWith(mockProjects[1]); + }); + + it('calls onSelectProject with Project, never null', () => { + const onSelect = vi.fn(); + render( + + ); + fireEvent.click(screen.getByText('ML Research')); + fireEvent.click(screen.getByText('NLP Project')); + // Verify the argument is a Project object, not null + expect(onSelect).toHaveBeenCalledTimes(1); + expect(onSelect.mock.calls[0][0]).not.toBeNull(); + expect(onSelect.mock.calls[0][0].uuid).toBe('proj-2'); + }); + + it('shows "New Project" button in dropdown', () => { + render( + + ); + fireEvent.click(screen.getByText('ML Research')); + expect(screen.getByText('New Project')).toBeInTheDocument(); + }); + + it('calls onNewProject when "New Project" clicked', () => { + const onNew = vi.fn(); + render( + + ); + fireEvent.click(screen.getByText('ML Research')); + fireEvent.click(screen.getByText('New Project')); + expect(onNew).toHaveBeenCalled(); + }); + + it('shows conversation count per project', () => { + render( + + ); + fireEvent.click(screen.getByText('ML Research')); + expect(screen.getByText('3')).toBeInTheDocument(); + expect(screen.getByText('1')).toBeInTheDocument(); + }); + + it('shows "No projects yet" when empty', () => { + render( + + ); + fireEvent.click(screen.getByText('Project')); + expect(screen.getByText('No projects yet')).toBeInTheDocument(); + }); + + it('shows Manage Projects when projects exist', () => { + render( + + ); + fireEvent.click(screen.getByText('ML Research')); + expect(screen.getByText('Manage Projects')).toBeInTheDocument(); + }); + + it('closes dropdown on project select', () => { + render( + + ); + fireEvent.click(screen.getByText('ML Research')); + expect(screen.getByText('NLP Project')).toBeInTheDocument(); + fireEvent.click(screen.getByText('NLP Project')); + // Dropdown should be closed + expect(screen.queryByText('New Project')).not.toBeInTheDocument(); + }); +}); diff --git a/frontend/src/__tests__/RightPanel.test.tsx b/frontend/src/__tests__/RightPanel.test.tsx index eb5ae61..14b3803 100644 --- a/frontend/src/__tests__/RightPanel.test.tsx +++ b/frontend/src/__tests__/RightPanel.test.tsx @@ -25,7 +25,7 @@ describe('RightPanel', () => { const mockContext: ContextUsage = { used: 50000, max: 200000, ratio: 0.25 }; const mockSearchBudget: SearchBudget = { used: 5, max: 25 }; - it('renders toggle button when not visible', () => { + it('renders collapsed rail with expand button when not visible', () => { render( { onViewReport={vi.fn()} /> ); - expect(screen.getByTitle('Tasks & resources')).toBeInTheDocument(); + expect(screen.getByTitle('Expand panel')).toBeInTheDocument(); + expect(screen.getByTitle('Todos')).toBeInTheDocument(); }); it('renders tasks when visible', () => { @@ -59,7 +60,7 @@ describe('RightPanel', () => { expect(screen.getByText('Write report')).toBeInTheDocument(); }); - it('shows task count', () => { + it('shows task completion count badge', () => { render( { onViewReport={vi.fn()} /> ); - expect(screen.getByText('Tasks (1/3)')).toBeInTheDocument(); + // CollapsiblePanel renders badge with "done/total" + expect(screen.getByText('1/3')).toBeInTheDocument(); }); it('shows "No tasks yet" when empty', () => { @@ -89,10 +91,9 @@ describe('RightPanel', () => { /> ); expect(screen.getByText('No tasks yet')).toBeInTheDocument(); - expect(screen.getByText('No resources yet')).toBeInTheDocument(); }); - it('renders context gauge', () => { + it('renders context gauge with data', () => { render( { expect(screen.getByText(/200k/)).toBeInTheDocument(); }); + it('renders context gauge placeholder when null', () => { + render( + + ); + expect(screen.getByText('Context: --')).toBeInTheDocument(); + }); + it('renders search budget gauge', () => { render( { expect(screen.getByText(/Searches:/)).toBeInTheDocument(); }); - it('renders resources', () => { + it('renders default search budget when null', () => { + render( + + ); + expect(screen.getByText('Searches: 0 / 25')).toBeInTheDocument(); + }); + + it('does not render resources section (resources are now in FileTree)', () => { render( { onViewReport={vi.fn()} /> ); - expect(screen.getByText('Dataset X')).toBeInTheDocument(); + // Resources section was removed — resources now appear as files in the workspace + expect(screen.queryByText('No resources yet')).not.toBeInTheDocument(); }); - it('renders paper resource with export buttons', () => { + it('does not render paper export buttons (papers are now in FileTree)', () => { render( { onViewReport={vi.fn()} /> ); - expect(screen.getByText('.md')).toBeInTheDocument(); - expect(screen.getByText('.tex')).toBeInTheDocument(); + // Paper export buttons were removed — papers are now files in workspace + expect(screen.queryByText('.md')).not.toBeInTheDocument(); + expect(screen.queryByText('.tex')).not.toBeInTheDocument(); }); - it('hides toggle badge when no tasks and visible', () => { + it('shows task count badge on collapsed rail', () => { render( { onViewReport={vi.fn()} /> ); - expect(screen.queryByTitle('Tasks & resources')).not.toBeInTheDocument(); + const todosButton = screen.getByTitle('Todos'); + const badge = todosButton.querySelector('span'); + expect(badge?.textContent).toBe('3'); }); - it('shows toggle badge with task count when collapsed with tasks', () => { + it('has search budget settings button', () => { render( ); - // The badge is shown inside the toggle button with task count - const toggleButton = screen.getByTitle('Tasks & resources'); - const badge = toggleButton.querySelector('span'); - expect(badge?.textContent).toBe('3'); + expect(screen.getByTitle('Change search budget')).toBeInTheDocument(); }); }); diff --git a/frontend/src/__tests__/SandboxSettings.test.tsx b/frontend/src/__tests__/SandboxSettings.test.tsx deleted file mode 100644 index 088b34c..0000000 --- a/frontend/src/__tests__/SandboxSettings.test.tsx +++ /dev/null @@ -1,60 +0,0 @@ -import { describe, it, expect, vi, beforeEach } from 'vitest'; -import { render, screen, waitFor } from '@testing-library/react'; -import { SandboxSettings } from '../components/settings/SandboxSettings'; -import { api } from '../api'; - -vi.mock('../api', () => ({ - api: { - getSettings: vi.fn(), - updateSetting: vi.fn(), - }, -})); - -describe('SandboxSettings', () => { - beforeEach(() => { - vi.clearAllMocks(); - vi.mocked(api.getSettings).mockResolvedValue({ settings: {} }); - }); - - it('renders heading and hint', async () => { - render(); - await waitFor(() => { - expect(screen.getByText(/Execution environment/)).toBeInTheDocument(); - }); - }); - - it('renders default sandbox select', async () => { - render(); - await waitFor(() => { - expect(screen.getByText('Default Sandbox')).toBeInTheDocument(); - expect(screen.getByText('Local')).toBeInTheDocument(); - expect(screen.getByText('SSH Remote')).toBeInTheDocument(); - expect(screen.getByText('Modal Cloud')).toBeInTheDocument(); - }); - }); - - it('renders modal token fields', async () => { - render(); - await waitFor(() => { - expect(screen.getByPlaceholderText('MODAL_TOKEN_ID')).toBeInTheDocument(); - expect(screen.getByPlaceholderText('MODAL_TOKEN_SECRET')).toBeInTheDocument(); - }); - }); - - it('renders save button', async () => { - render(); - await waitFor(() => { - expect(screen.getByText('Save Sandbox Settings')).toBeInTheDocument(); - }); - }); - - it('loads settings on mount', async () => { - vi.mocked(api.getSettings).mockResolvedValue({ - settings: { sandbox: { default_sandbox: 'ssh' } }, - }); - render(); - await waitFor(() => { - expect(api.getSettings).toHaveBeenCalled(); - }); - }); -}); diff --git a/frontend/src/__tests__/SettingsPanel.test.tsx b/frontend/src/__tests__/SettingsPanel.test.tsx deleted file mode 100644 index 055cdfa..0000000 --- a/frontend/src/__tests__/SettingsPanel.test.tsx +++ /dev/null @@ -1,109 +0,0 @@ -import { describe, it, expect, vi, beforeEach } from 'vitest'; -import { render, screen, fireEvent, waitFor } from '@testing-library/react'; -import { SettingsPanel } from '../components/SettingsPanel'; -import { api } from '../api'; - -vi.mock('../api', () => ({ - api: { - getProviders: vi.fn(), - getSettings: vi.fn(), - updateSetting: vi.fn(), - }, -})); - -describe('SettingsPanel', () => { - beforeEach(() => { - vi.resetAllMocks(); - vi.mocked(api.getProviders).mockResolvedValue({ - providers: [ - { id: 'openai', name: 'OpenAI', key_env: 'OPENAI_API_KEY', configured: true }, - { id: 'anthropic', name: 'Anthropic', key_env: 'ANTHROPIC_API_KEY', configured: false }, - ], - }); - vi.mocked(api.getSettings).mockResolvedValue({ settings: {} }); - }); - - it('renders settings heading', async () => { - render(); - await waitFor(() => { - expect(screen.getByText('Settings')).toBeInTheDocument(); - }); - }); - - it('renders tab buttons', async () => { - render(); - await waitFor(() => { - expect(screen.getByText('Providers')).toBeInTheDocument(); - expect(screen.getByText('Agent')).toBeInTheDocument(); - expect(screen.getByText('Sandbox')).toBeInTheDocument(); - expect(screen.getByText('Writing')).toBeInTheDocument(); - }); - }); - - it('shows providers by default', async () => { - render(); - await waitFor(() => { - expect(screen.getByText('OpenAI')).toBeInTheDocument(); - expect(screen.getByText('Anthropic')).toBeInTheDocument(); - }); - }); - - it('shows configured status', async () => { - render(); - await waitFor(() => { - expect(screen.getByText('Configured')).toBeInTheDocument(); - expect(screen.getByText('Not set')).toBeInTheDocument(); - }); - }); - - it('switches to agent tab', async () => { - render(); - await waitFor(() => { - fireEvent.click(screen.getByText('Agent')); - }); - expect(screen.getByText('Default Model')).toBeInTheDocument(); - expect(screen.getByText(/YOLO Mode/)).toBeInTheDocument(); - }); - - it('switches to sandbox tab', async () => { - render(); - await waitFor(() => { - fireEvent.click(screen.getByText('Sandbox')); - }); - expect(screen.getByText('Default Sandbox')).toBeInTheDocument(); - }); - - it('switches to writing tab', async () => { - render(); - await waitFor(() => { - fireEvent.click(screen.getByText('Writing')); - }); - expect(screen.getByText('Citation Style')).toBeInTheDocument(); - expect(screen.getByText('Export Format')).toBeInTheDocument(); - }); - - it('calls onClose when close button clicked', async () => { - const onClose = vi.fn(); - const { container } = render(); - await waitFor(() => { - expect(screen.getByText('Settings')).toBeInTheDocument(); - }); - // Close button now uses lucide X icon, find it by the SVG class - const closeBtn = container.querySelector('.lucide-x')?.closest('button'); - expect(closeBtn).toBeTruthy(); - fireEvent.click(closeBtn!); - expect(onClose).toHaveBeenCalled(); - }); - - it('calls onClose when overlay clicked', async () => { - const onClose = vi.fn(); - render(); - await waitFor(() => { - expect(screen.getByText('Settings')).toBeInTheDocument(); - }); - // The overlay is the fixed div with bg-black/60 - const overlay = document.querySelector('.fixed.inset-0'); - fireEvent.click(overlay!); - expect(onClose).toHaveBeenCalled(); - }); -}); diff --git a/frontend/src/__tests__/Sidebar.test.tsx b/frontend/src/__tests__/Sidebar.test.tsx index f31447f..f38b962 100644 --- a/frontend/src/__tests__/Sidebar.test.tsx +++ b/frontend/src/__tests__/Sidebar.test.tsx @@ -42,23 +42,25 @@ const mockConversations: Conversation[] = [ }, ]; +const defaultProps = { + conversations: mockConversations, + currentUuid: null as string | null, + user: mockUser, + convStatuses: {} as Record, + terminalOpen: false, + terminalConnected: false, + terminalSessionCount: 0, + onSwitch: vi.fn(), + onNew: vi.fn(), + onDelete: vi.fn(), + onTerminalToggle: vi.fn(), +}; + describe('Sidebar', () => { it('renders new chat button', () => { render( - + ); expect(screen.getByText('New Chat')).toBeInTheDocument(); @@ -67,19 +69,7 @@ describe('Sidebar', () => { it('renders conversation titles', () => { render( - + ); expect(screen.getByText('First conversation')).toBeInTheDocument(); @@ -89,24 +79,10 @@ describe('Sidebar', () => { it('highlights current conversation with bg-primary/10 class', () => { render( - + ); - // Find the conversation item containing "First conversation" const convItem = screen.getByText('First conversation').closest('div'); - // Check that the parent container has the active styling (bg-primary/10) expect(convItem?.className).toContain('bg-primary/10'); }); @@ -114,19 +90,7 @@ describe('Sidebar', () => { const onSwitch = vi.fn(); render( - + ); fireEvent.click(screen.getByText('First conversation')); @@ -136,19 +100,7 @@ describe('Sidebar', () => { it('shows empty state when no conversations', () => { render( - + ); expect(screen.getByText('No conversations yet')).toBeInTheDocument(); @@ -157,19 +109,7 @@ describe('Sidebar', () => { it('shows user display name', () => { render( - + ); expect(screen.getByText('Test User')).toBeInTheDocument(); @@ -178,19 +118,7 @@ describe('Sidebar', () => { it('renders settings and sign out buttons', () => { render( - + ); expect(screen.getByTitle('Settings')).toBeInTheDocument(); @@ -200,19 +128,7 @@ describe('Sidebar', () => { it('filters conversations by search', () => { render( - + ); const searchInput = screen.getByPlaceholderText('Search...'); @@ -225,22 +141,51 @@ describe('Sidebar', () => { const onNew = vi.fn(); render( - + ); fireEvent.click(screen.getByText('New Chat')); expect(onNew).toHaveBeenCalled(); }); + + it('renders terminal button with Closed status when terminal is not open', () => { + render( + + + + ); + expect(screen.getByText('Terminal')).toBeInTheDocument(); + expect(screen.getByText('Closed')).toBeInTheDocument(); + }); + + it('renders terminal button with Connected status when terminal is open and connected', () => { + render( + + + + ); + expect(screen.getByText('Terminal')).toBeInTheDocument(); + expect(screen.getByText('Connected')).toBeInTheDocument(); + expect(screen.getByText('1')).toBeInTheDocument(); + }); + + it('renders terminal button with Disconnected when terminal is open but not connected', () => { + render( + + + + ); + expect(screen.getByText('Disconnected')).toBeInTheDocument(); + }); + + it('calls onTerminalToggle when terminal button clicked', () => { + const onTerminalToggle = vi.fn(); + render( + + + + ); + fireEvent.click(screen.getByText('Terminal')); + expect(onTerminalToggle).toHaveBeenCalled(); + }); }); diff --git a/frontend/src/__tests__/TodoReviewDrawer.test.tsx b/frontend/src/__tests__/TodoReviewDrawer.test.tsx new file mode 100644 index 0000000..64c9ebc --- /dev/null +++ b/frontend/src/__tests__/TodoReviewDrawer.test.tsx @@ -0,0 +1,197 @@ +import { describe, it, expect, vi } from 'vitest'; +import { render, screen, fireEvent } from '@testing-library/react'; +import { TodoReviewDrawer } from '../components/TodoReviewDrawer'; + +vi.mock('../api', () => ({ + api: { + submitTodoApproval: vi.fn().mockResolvedValue({ ok: true }), + }, +})); + +describe('TodoReviewDrawer', () => { + const basePayload = { + change_type: 'create' as const, + proposed_tasks: [ + { title: 'Read papers', status: 'pending' }, + { title: 'Train model', status: 'pending' }, + { title: 'Write report', status: 'pending' }, + ], + current_tasks: [] as Array<{ title: string; status: string }>, + }; + + it('renders header with correct title for create', () => { + render( + + ); + expect(screen.getByText('Review Proposed Plan')).toBeInTheDocument(); + }); + + it('renders header with correct title for add', () => { + render( + + ); + expect(screen.getByText('Review Task Addition')).toBeInTheDocument(); + }); + + it('renders proposed tasks', () => { + render( + + ); + expect(screen.getByText('Read papers')).toBeInTheDocument(); + expect(screen.getByText('Train model')).toBeInTheDocument(); + expect(screen.getByText('Write report')).toBeInTheDocument(); + }); + + it('shows task count', () => { + render( + + ); + expect(screen.getByText('3 tasks')).toBeInTheDocument(); + }); + + it('renders approve and reject buttons', () => { + render( + + ); + expect(screen.getByText('Approve')).toBeInTheDocument(); + expect(screen.getByText('Reject')).toBeInTheDocument(); + }); + + it('marks new tasks with "new" badge', () => { + const payload = { + change_type: 'add' as const, + proposed_tasks: [ + { title: 'Existing task', status: 'completed' }, + { title: 'New task', status: 'pending' }, + ], + current_tasks: [{ title: 'Existing task', status: 'completed' }], + }; + render( + + ); + const badges = screen.getAllByText('new'); + expect(badges.length).toBe(1); // only the new task has a badge + }); + + it('renders current tasks column when they exist', () => { + const payload = { + ...basePayload, + change_type: 'add' as const, + current_tasks: [{ title: 'Old task', status: 'in_progress' }], + }; + render( + + ); + expect(screen.getByText('Current Plan')).toBeInTheDocument(); + expect(screen.getByText('Old task')).toBeInTheDocument(); + }); + + it('calls onClose when X button clicked', () => { + const onClose = vi.fn(); + render( + + ); + // Find the X button by looking for all buttons and finding one with the X icon + const allButtons = screen.getAllByRole('button'); + // The X close button is in the header area and contains an SVG + const closeButton = allButtons.find((btn) => { + const svg = btn.querySelector('svg'); + return svg && btn.closest('.border-b') !== null; + }); + if (closeButton) { + fireEvent.click(closeButton); + } + expect(onClose).toHaveBeenCalled(); + }); + + it('calls submitTodoApproval with approved on approve click', async () => { + const { api } = await import('../api'); + const onDone = vi.fn(); + render( + + ); + + fireEvent.click(screen.getByText('Approve')); + + // Wait for async handler + await vi.waitFor(() => { + expect(api.submitTodoApproval).toHaveBeenCalledWith(true, expect.any(Array)); + }); + }); + + it('calls submitTodoApproval with rejected on reject click', async () => { + const { api } = await import('../api'); + const onDone = vi.fn(); + render( + + ); + + fireEvent.click(screen.getByText('Reject')); + + await vi.waitFor(() => { + expect(api.submitTodoApproval).toHaveBeenCalledWith(false); + }); + }); + + it('has an add task input', () => { + render( + + ); + expect(screen.getByPlaceholderText('Add a task...')).toBeInTheDocument(); + }); + + it('shows edit hint text', () => { + render( + + ); + expect(screen.getByText(/You can edit tasks before approving/)).toBeInTheDocument(); + }); +}); diff --git a/frontend/src/api.ts b/frontend/src/api.ts index 3f5ce35..7fd7afb 100644 --- a/frontend/src/api.ts +++ b/frontend/src/api.ts @@ -74,14 +74,16 @@ export const api = { submitAnswers: (answers: Record) => post('/api/answers', { answers }), interrupt: () => post('/api/interrupt', {}), sendApproval: (approvals: Record) => post('/api/approval', { approvals }), + submitTodoApproval: (approved: boolean, tasks?: any[]) => + post('/api/todo-approval', { approved, tasks }), undo: () => post('/api/undo', {}), compact: () => post('/api/compact', {}), setModel: (model: string) => post('/api/model', { model }), // Conversations listConversations: () => get('/api/conversations'), - createConversation: (title?: string, model?: string, mode?: string) => - post('/api/conversations', { title, model, mode }), + createConversation: (title?: string, model?: string, mode?: string, projectUuid?: string) => + post('/api/conversations', { title, model, mode, project_uuid: projectUuid }), getConversation: (uuid: string) => get(`/api/conversations/${uuid}`), deleteConversation: (uuid: string) => del(`/api/conversations/${uuid}`), switchConversation: (uuid: string) => post(`/api/conversations/${uuid}/switch`, {}), diff --git a/frontend/src/components/CollapsiblePanel.tsx b/frontend/src/components/CollapsiblePanel.tsx new file mode 100644 index 0000000..3bbd3d5 --- /dev/null +++ b/frontend/src/components/CollapsiblePanel.tsx @@ -0,0 +1,39 @@ +import { useState, type ReactNode } from 'react'; +import { ChevronDown, ChevronRight } from 'lucide-react'; + +interface CollapsiblePanelProps { + title: string; + icon?: ReactNode; + badge?: string | number; + defaultExpanded?: boolean; + children: ReactNode; +} + +export function CollapsiblePanel({ title, icon, badge, defaultExpanded = true, children }: CollapsiblePanelProps) { + const [expanded, setExpanded] = useState(defaultExpanded); + + return ( +
+ + {expanded && ( +
+ {children} +
+ )} +
+ ); +} diff --git a/frontend/src/components/ComputeSelector.tsx b/frontend/src/components/ComputeSelector.tsx index c7b836f..8dc1638 100644 --- a/frontend/src/components/ComputeSelector.tsx +++ b/frontend/src/components/ComputeSelector.tsx @@ -1,5 +1,6 @@ import { useState, useEffect, useRef } from 'react'; -import { Cpu, ChevronDown, Monitor } from 'lucide-react'; +import { useNavigate } from 'react-router-dom'; +import { Cpu, ChevronDown, Monitor, Settings } from 'lucide-react'; interface ComputeNode { id: number; @@ -17,6 +18,7 @@ interface ComputeSelectorProps { export function ComputeSelector({ currentNode, nodes, onChange }: ComputeSelectorProps) { const [open, setOpen] = useState(false); const ref = useRef(null); + const navigate = useNavigate(); useEffect(() => { function handleClickOutside(event: MouseEvent) { @@ -98,6 +100,17 @@ export function ComputeSelector({ currentNode, nodes, onChange }: ComputeSelecto No compute nodes configured )} + + {/* Manage Compute link */} +
+ +
)} diff --git a/frontend/src/components/FileTree.tsx b/frontend/src/components/FileTree.tsx index a892cc6..9f4a67a 100644 --- a/frontend/src/components/FileTree.tsx +++ b/frontend/src/components/FileTree.tsx @@ -10,12 +10,16 @@ import { ChevronDown, RefreshCw, AlertCircle, + Pin, + ClipboardList, + BookOpen, } from 'lucide-react'; import { api } from '../api'; import type { FileNode } from '../types'; interface Props { projectUuid: string; + refreshKey?: number; onFileSelect?: (path: string, content: string) => void; } @@ -40,6 +44,16 @@ const FILE_ICONS: Record = { '.svg': , }; +/** Get a special badge/icon for generated resource files. */ +function getSpecialBadge(path: string, name: string): React.ReactNode | null { + if (name === 'PLAN.md') return ; + if (path.includes('.project-meta/reports/') && name.endsWith('.md')) + return ; + if (path.startsWith('papers/') && name.endsWith('.md') && !name.startsWith('.')) + return ; + return null; +} + function getFileIcon(name: string, isDir: boolean): React.ReactNode { if (isDir) return null; // handled by folder icons const ext = name.includes('.') ? '.' + name.split('.').pop() : ''; @@ -64,6 +78,8 @@ function TreeItem({ onToggle: (path: string) => void; onSelect: (path: string) => void; }) { + const badge = getSpecialBadge(node.path, node.name); + return (
{fileLoading ? ( diff --git a/frontend/src/components/InputArea.tsx b/frontend/src/components/InputArea.tsx index d3f693d..12a25e7 100644 --- a/frontend/src/components/InputArea.tsx +++ b/frontend/src/components/InputArea.tsx @@ -1,5 +1,5 @@ import { useRef, useEffect, useCallback, useState } from 'react'; -import { ArrowUp, Square, Play } from 'lucide-react'; +import { ArrowUp, Square } from 'lucide-react'; export type Mode = 'plan' | 'execute'; @@ -34,16 +34,6 @@ export function InputArea({ disabled, showStop, mode, onModeChange, onSend, onSt if (textareaRef.current) textareaRef.current.style.height = 'auto'; }, [text, disabled, onSend, mode, onTextChange]); - /** Send message AND switch to execute mode in one action. */ - const submitAndExecute = useCallback(() => { - const trimmed = text.trim(); - if (!trimmed || disabled) return; - onTextChange(''); - onModeChange('execute'); - onSend(trimmed, 'execute'); - if (textareaRef.current) textareaRef.current.style.height = 'auto'; - }, [text, disabled, onSend, onModeChange, onTextChange]); - const toggleMode = useCallback(() => { onModeChange(mode === 'plan' ? 'execute' : 'plan'); }, [mode, onModeChange]); @@ -56,15 +46,11 @@ export function InputArea({ disabled, showStop, mode, onModeChange, onSend, onSt // Cmd+M: toggle mode e.preventDefault(); onModeChange(mode === 'plan' ? 'execute' : 'plan'); - } else if (e.key === 'Enter' && mode === 'plan') { - // Cmd+Enter in plan mode: send & execute - e.preventDefault(); - submitAndExecute(); } }; window.addEventListener('keydown', handler); return () => window.removeEventListener('keydown', handler); - }, [mode, onModeChange, submitAndExecute]); + }, [mode, onModeChange]); const isPlan = mode === 'plan'; @@ -132,36 +118,20 @@ export function InputArea({ disabled, showStop, mode, onModeChange, onSend, onSt )} - {/* Send buttons */} + {/* Send button */} {!disabled && ( -
- {/* Primary send: uses current mode */} - - - {/* Send & Execute: visible only in Plan mode */} - {isPlan && ( - - )} -
+ )} diff --git a/frontend/src/components/OnboardingModal.tsx b/frontend/src/components/OnboardingModal.tsx index b4f2f4a..9616b74 100644 --- a/frontend/src/components/OnboardingModal.tsx +++ b/frontend/src/components/OnboardingModal.tsx @@ -1,5 +1,6 @@ import { useState, useEffect, useMemo } from 'react'; import { api } from '../api'; +import type { Project } from '../types'; interface Provider { id: string; @@ -16,11 +17,11 @@ interface ModelInfo { } interface Props { - onComplete: (model: string) => void; + onComplete: (model: string, project?: Project) => void; } export function OnboardingModal({ onComplete }: Props) { - const [step, setStep] = useState<'providers' | 'model'>('providers'); + const [step, setStep] = useState<'providers' | 'model' | 'project'>('providers'); const [providers, setProviders] = useState([]); const [models, setModels] = useState([]); const [loading, setLoading] = useState(true); @@ -28,6 +29,12 @@ export function OnboardingModal({ onComplete }: Props) { const [saving, setSaving] = useState(false); const [search, setSearch] = useState(''); const [selectedProvider, setSelectedProvider] = useState('all'); + const [selectedModel, setSelectedModel] = useState(''); + + // Project creation state + const [projectName, setProjectName] = useState(''); + const [projectDesc, setProjectDesc] = useState(''); + const [creatingProject, setCreatingProject] = useState(false); const loadData = async () => { setLoading(true); @@ -77,12 +84,10 @@ export function OnboardingModal({ onComplete }: Props) { setSaving(true); try { await api.saveConfig(toSave); - // Refresh providers and models after saving keys const [pData, mData] = await Promise.all([api.getProviders(), api.getModels()]); setProviders(pData.providers || []); setModels(mData.models || []); setKeyInputs({}); - // Only go to model step if we now have models if ((mData.models || []).length > 0) { setStep('model'); } @@ -93,7 +98,24 @@ export function OnboardingModal({ onComplete }: Props) { const selectModel = async (modelId: string) => { await api.setModel(modelId); - onComplete(modelId); + setSelectedModel(modelId); + // Move to project creation step + setStep('project'); + }; + + const createProjectAndFinish = async () => { + if (!projectName.trim()) return; + setCreatingProject(true); + try { + const data = await api.createProject(projectName.trim(), projectDesc.trim() || undefined); + const project = data.project as Project; + onComplete(selectedModel, project); + } catch { + // If project creation fails, still complete onboarding with the model + onComplete(selectedModel); + } finally { + setCreatingProject(false); + } }; if (loading) { @@ -115,8 +137,21 @@ export function OnboardingModal({ onComplete }: Props) {

{step === 'providers' ? 'Configure at least one LLM provider to get started.' - : 'Pick a model to use for conversations.'} + : step === 'model' + ? 'Pick a model to use for conversations.' + : 'Create your first research project.'}

+ {/* Step indicator */} +
+ {['providers', 'model', 'project'].map((s, i) => ( +
+ ))} +
{/* Providers step */} @@ -173,7 +208,6 @@ export function OnboardingModal({ onComplete }: Props) { {step === 'model' && (
{models.length === 0 ? ( - /* No models available — send user back to configure a provider */

No models available

@@ -190,7 +224,6 @@ export function OnboardingModal({ onComplete }: Props) {
) : ( <> - {/* Filters */}
- {/* Model list */}
{filteredModels.map((m) => (