diff --git a/backend/configs/prompts/system_prompt.yaml b/backend/configs/prompts/system_prompt.yaml index dd0421f..611e1d8 100644 --- a/backend/configs/prompts/system_prompt.yaml +++ b/backend/configs/prompts/system_prompt.yaml @@ -1,97 +1,84 @@ title: OpenMLR System Prompt -version: 4 +version: 5 prompt: | - You are OpenMLR, an ML research intern. You plan, research, write, and - execute ML work end-to-end. + You are OpenMLR, an ML research intern. You help users plan, research, + write, and execute ML work end-to-end. - # Clarification First + # Mode System - **CRITICAL**: Before taking any significant action, ask clarifying questions - using the `ask_user` tool with `allow_text: true` on every question. - Do NOT assume which model, dataset, approach, hardware, or scope to use. - Ask at least one clarifying question for non-trivial tasks. - - # Task Management - - - Always create a plan using `plan_tool` before starting work. - - When proposing a new plan or adding tasks, explain what you're planning - and why. Use `ask_user` to get approval for significant plan changes. - - When completing a task, call `plan_tool` update with status="completed", - include a `summary` of what was accomplished and `next_hints` for upcoming tasks. - This auto-generates a completion report stored as a resource. - - After each task completion, evaluate: does the remaining plan still make sense? - If not, propose changes to the user before continuing. - - Keep pushing forward through the task list. Don't stop after one task unless - waiting for user input. - - Use `plan_tool add_resource` to track every paper, code repo, dataset, or - doc you reference. This builds a knowledge base for the session. - - # Per-Message Modes — STRICT ENFORCEMENT + The user controls which mode you operate in. There are two modes: {% if mode == "plan" %} - ## Plan Mode — RESTRICTED - **ONLY** these tools are available: ask_user, plan_tool, read_file, list_dir + ## CURRENT MODE: PLAN - DO NOT call: web_search, papers, research, writing, or any execution tools. - These will fail with a mode violation error. + You are in **Plan mode**. Your job is to understand the task, ask clarifying + questions, gather context, and produce a comprehensive plan. - Your job in plan mode: - 1. Ask clarifying questions using ask_user - 2. Create a plan using plan_tool - 3. When ready, use ask_user with suggest_mode='research' to propose switching + **Available tools**: ask_user, plan_tool, read_file, list_dir, glob_files, + grep_search, web_search, papers, github_search, github_read_file, github_read_repo - WAIT for the user to approve the mode switch. - {% elif mode == "research" %} - ## Research Mode — RESTRICTED - **ONLY** these tools are available: ask_user, plan_tool, web_search, papers, - research, read_file, github_search, github_read_file + **NOT available**: writing, research sub-agent, sandbox/code execution tools. + Calls to unavailable tools will be rejected. - DO NOT call: writing tool or execution tools. + **Rules**: + 1. Ask clarifying questions using `ask_user` before making assumptions + 2. Search the web, papers, and code repos to gather context + 3. Create a structured plan using `plan_tool` with clear, actionable tasks + 4. The plan is auto-saved as PLAN.md in resources — the user can see it + 5. Do NOT execute any work — plan only + 6. Do NOT write content, run code, or make changes + 7. Be thorough in your plan — it will be the blueprint for Execute mode - Your job in research mode: - 1. Search papers and web for information - 2. Add all sources as resources via plan_tool add_resource - 3. Complete tasks with summaries and reports - 4. When research is complete, use ask_user with suggest_mode='write' + {% elif mode == "execute" %} + ## CURRENT MODE: EXECUTE - WAIT for the user to approve the mode switch. - {% elif mode == "write" %} - ## Write Mode — RESTRICTED - **ONLY** these tools are available: ask_user, plan_tool, writing, - read_file, web_search, papers (for citations) + You are in **Execute mode**. Your job is to follow the plan and do the work. + Do NOT ask questions — just execute. - Your job in write mode: - 1. Write content using the writing tool - 2. Reference resources from research phase - 3. Generate completion reports for each section - {% else %} - ## General Mode - All tools available. Ask clarifying questions first. - {% endif %} - - # CRITICAL WORKFLOW RULES - - 1. **ONE TASK AT A TIME**: You can only have one task in_progress at a time. - Complete the current task with a summary and report before starting the next. + **Available tools**: ALL tools EXCEPT ask_user. + Calls to ask_user will be rejected. - 2. **COMPLETION REPORTS REQUIRED**: When marking a task completed, you MUST provide: - - summary: what was accomplished - - next_hints: recommendations for upcoming tasks + **Rules**: + 1. Follow the task plan — check it with `plan_tool get` if unsure + 2. Work through tasks one at a time, marking them in_progress then completed + 3. When completing a task, provide a summary and next_hints + 4. Add all papers, code, datasets as resources via plan_tool add_resource + 5. Do NOT ask the user questions — if something is ambiguous, make a reasonable + decision and document it in your completion report + 6. Keep pushing through the task list until done or interrupted + 7. Generate completion reports for each task - 3. **MODE SWITCHING REQUIRES USER APPROVAL**: - - Do NOT switch modes automatically - - Use ask_user with suggest_mode to propose a switch - - WAIT for the user's response - - 4. **STAY IN YOUR LANE**: - - In plan mode: only ask and plan - - In research mode: only search and read - - In write mode: only write and cite + {% else %} + ## CURRENT MODE: EXECUTE (default) + All tools available except ask_user. Execute the work. + {% endif %} + + # Task Management + + - Always create a plan using `plan_tool` before starting work (in Plan mode) + - When completing a task, call `plan_tool` update with status="completed", + include a `summary` of what was accomplished and `next_hints` + - This auto-generates a completion report stored as a resource + - ONE task in_progress at a time — complete current before starting next + - Use `plan_tool add_resource` to track every paper, code repo, or doc + + # Paper Writing + + When writing a paper, use the `writing` tool exclusively: + 1. `create_project` with a title + 2. `set_outline` with section structure + 3. `write_section` for each section — the paper is AUTO-SAVED after each write + 4. `add_citation` for references + 5. `get_draft` to review the full paper + + **CRITICAL**: Do NOT use the `write` file tool to save papers. The `writing` tool + auto-saves to the database and the user can preview/export from the Paper tab. + Do NOT call `export` — the user handles export from the UI. # Code Execution - Code runs inside a Docker container (/workspace) in ALL modes when needed. + Code runs inside a Docker container (/workspace) when needed. Before running code: check the environment, install dependencies. Never modify the user's host environment directly. @@ -101,10 +88,6 @@ prompt: | information that was summarized. Prefer short tool outputs. Re-read completion reports (via `plan_tool get`) for context on past work. - # Knowledge is Outdated - - Your ML library knowledge is outdated. Always verify against documentation. - # Communication - Concise and direct. No flattery, no emojis. diff --git a/backend/openmlr/agent/loop.py b/backend/openmlr/agent/loop.py index 6c6194c..deaf92e 100644 --- a/backend/openmlr/agent/loop.py +++ b/backend/openmlr/agent/loop.py @@ -48,35 +48,17 @@ async def _run_agent(session: Session, tool_router, user_message: str, mode: str session.pending_approval = None # Set the mode on the tool router for strict enforcement - effective_mode = mode if mode in ("plan", "research", "write") else "general" + effective_mode = mode if mode in ("plan", "execute") else "execute" tool_router.set_mode(effective_mode) - # Inject per-message mode context if provided - if mode and mode in ("plan", "research", "write"): - mode_hints = { - "plan": ( - "[Mode: PLAN — STRICT ENFORCEMENT]\n" - "- Only ask_user and plan_tool are available\n" - "- Do NOT execute any research, writing, or code tools\n" - "- Ask clarifying questions and create a plan\n" - "- When ready, use ask_user with suggest_mode='research' or 'write' to propose switching" - ), - "research": ( - "[Mode: RESEARCH — STRICT ENFORCEMENT]\n" - "- Search papers, web, and gather information only\n" - "- Do NOT write content or execute code\n" - "- Add all sources as resources via plan_tool\n" - "- When research is complete, use ask_user with suggest_mode='write' to propose switching" - ), - "write": ( - "[Mode: WRITE — STRICT ENFORCEMENT]\n" - "- Write and edit content only\n" - "- Use writing tool for paper sections\n" - "- Reference resources gathered in research phase\n" - "- When writing is complete, generate a report via plan_tool" - ), - } - session.context_manager.add_message(Message(role="system", content=mode_hints[mode])) + # Inject per-message mode hint (short reinforcement of system prompt rules) + mode_hint = ( + f"[Mode: {effective_mode.upper()}] " + + ("Plan only — ask questions, gather context, create plan. No execution." + if effective_mode == "plan" else + "Execute the plan — do the work, no questions. All tools except ask_user.") + ) + session.context_manager.add_message(Message(role="system", content=mode_hint)) session.context_manager.add_message(Message(role="user", content=user_message)) diff --git a/backend/openmlr/db/operations.py b/backend/openmlr/db/operations.py index 9f9e1cf..37be93d 100644 --- a/backend/openmlr/db/operations.py +++ b/backend/openmlr/db/operations.py @@ -324,6 +324,75 @@ async def upsert_conversation_resources( return new_resources +PLAN_RESOURCE_ID = "plan-md" + + +async def upsert_plan_resource(db: AsyncSession, conv_id: int, content: str) -> ConversationResource: + """Create or update the pinned PLAN.md resource for a conversation.""" + existing = await get_resource_by_id(db, f"{PLAN_RESOURCE_ID}-{conv_id}") + if existing: + existing.content = content + await db.commit() + await db.refresh(existing) + return existing + return await add_conversation_resource( + db, conv_id, + title="PLAN.md", + resource_type="plan", + content=content, + resource_id=f"{PLAN_RESOURCE_ID}-{conv_id}", + ) + + +PAPER_RESOURCE_ID = "paper" + + +async def upsert_paper_resource( + db: AsyncSession, conv_id: int, title: str, content: str, +) -> ConversationResource: + """Create or update the paper draft resource for a conversation.""" + rid = f"{PAPER_RESOURCE_ID}-{conv_id}" + existing = await get_resource_by_id(db, rid) + if existing: + existing.title = title + existing.content = content + await db.commit() + await db.refresh(existing) + return existing + return await add_conversation_resource( + db, conv_id, + title=title, + resource_type="paper", + content=content, + resource_id=rid, + ) + + +async def upsert_resource( + db: AsyncSession, conv_id: int, + resource_id: str, title: str, resource_type: str, + content: str = None, url: str = None, +) -> ConversationResource: + """Create or update a resource by resource_id.""" + existing = await get_resource_by_id(db, resource_id) + if existing: + existing.title = title + existing.content = content + if url: + existing.url = url + await db.commit() + await db.refresh(existing) + return existing + return await add_conversation_resource( + db, conv_id, + title=title, + resource_type=resource_type, + content=content, + url=url, + resource_id=resource_id, + ) + + # ---- Agent Jobs ---- async def create_agent_job( diff --git a/backend/openmlr/routes/agent.py b/backend/openmlr/routes/agent.py index d580245..f395d4b 100644 --- a/backend/openmlr/routes/agent.py +++ b/backend/openmlr/routes/agent.py @@ -387,12 +387,43 @@ async def submit_answers( @router.post("/interrupt") -async def interrupt(request: Request, user: User = Depends(get_current_user)): - """Cancel the current agent turn.""" - active = _sm(request).get_current_session() +async def interrupt( + request: Request, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Cancel the current agent turn (in-process and background workers).""" + sm = _sm(request) + + # 1. Cancel the in-process session (works for inline / non-Celery mode) + active = sm.get_current_session() if active: active.session.cancel() - await _bus(request).broadcast(AgentEvent(event_type="interrupted")) + + # 2. For background Celery workers: relay interrupt via Redis + revoke task + conv_id = sm.current_conversation_id + if conv_id: + from ..services.redis_pubsub import publish_interrupt + await publish_interrupt(conv_id) + + # Also try to revoke active Celery tasks for this conversation + try: + from ..services.job_manager import get_job_manager, USE_BACKGROUND_JOBS + if USE_BACKGROUND_JOBS: + job_manager = get_job_manager() + active_jobs = await job_manager.get_active_jobs(db, conv_id) + for job_info in active_jobs: + jid = job_info["job_id"] + # Revoke with SIGTERM so the worker process is interrupted + if job_manager.celery_app: + job_manager.celery_app.control.revoke(jid, terminate=True, signal="SIGTERM") + logger.info(f"Revoked Celery task {jid} for conversation {conv_id}") + # Mark the job as cancelled in DB + await ops.update_job_status(db, jid, "cancelled") + except Exception as e: + logger.warning(f"Failed to revoke background jobs: {e}") + + await _bus(request).broadcast(AgentEvent(event_type="interrupted")) return {"ok": True} diff --git a/backend/openmlr/services/redis_pubsub.py b/backend/openmlr/services/redis_pubsub.py index b25ec1d..452830d 100644 --- a/backend/openmlr/services/redis_pubsub.py +++ b/backend/openmlr/services/redis_pubsub.py @@ -156,6 +156,41 @@ async def publish_answers(conversation_id: int, answers: dict) -> None: logger.warning(f"Failed to publish answers to Redis: {e}") +INTERRUPT_KEY_PREFIX = "openmlr:interrupt:" + + +async def publish_interrupt(conversation_id: int) -> None: + """Set a Redis key to signal interruption to a background worker.""" + try: + client = await get_redis() + key = f"{INTERRUPT_KEY_PREFIX}{conversation_id}" + await client.set(key, "1", ex=60) # TTL 60 seconds + logger.info(f"Published interrupt for conversation {conversation_id}") + except Exception as e: + logger.warning(f"Failed to publish interrupt to Redis: {e}") + + +async def check_interrupt(conversation_id: int) -> bool: + """Check whether an interrupt signal exists for the given conversation.""" + try: + client = await get_redis() + key = f"{INTERRUPT_KEY_PREFIX}{conversation_id}" + return await client.exists(key) > 0 + except Exception as e: + logger.warning(f"Failed to check interrupt in Redis: {e}") + return False + + +async def clear_interrupt(conversation_id: int) -> None: + """Remove the interrupt key after it has been consumed.""" + try: + client = await get_redis() + key = f"{INTERRUPT_KEY_PREFIX}{conversation_id}" + await client.delete(key) + except Exception as e: + logger.warning(f"Failed to clear interrupt in Redis: {e}") + + async def wait_for_answers(conversation_id: int, timeout: float = 300) -> dict | None: """Wait for user answers from Redis. Used by background worker's ask_user handler.""" try: diff --git a/backend/openmlr/tasks/agent_tasks.py b/backend/openmlr/tasks/agent_tasks.py index 4246ba7..2885c08 100644 --- a/backend/openmlr/tasks/agent_tasks.py +++ b/backend/openmlr/tasks/agent_tasks.py @@ -149,6 +149,25 @@ async def _broadcast(event: AgentEvent): session.on_event(_broadcast) + # Start a background task that polls Redis for an interrupt signal + # and cancels the session when found. + async def _poll_interrupt(): + from ..services.redis_pubsub import check_interrupt, clear_interrupt + try: + while True: + await asyncio.sleep(2) + if await check_interrupt(conversation_id): + logger.info(f"Interrupt detected via Redis for conversation {conversation_id}, cancelling session") + session.cancel() + await clear_interrupt(conversation_id) + break + except asyncio.CancelledError: + pass + except Exception as e: + logger.warning(f"Interrupt poll error: {e}") + + interrupt_task = asyncio.create_task(_poll_interrupt()) + try: # Run the agent turn await run_agent_turn(session, tool_router, message, mode=mode) @@ -175,12 +194,26 @@ async def _broadcast(event: AgentEvent): raise finally: + # Stop the interrupt polling task + interrupt_task.cancel() + try: + await interrupt_task + except asyncio.CancelledError: + pass + # Cleanup try: await sandbox_manager.destroy() except Exception: pass + # Clear any lingering interrupt key + try: + from ..services.redis_pubsub import clear_interrupt + await clear_interrupt(conversation_id) + except Exception: + pass + # Broadcast ready status await publish_event(AgentEvent( event_type="status", diff --git a/backend/openmlr/tools/local.py b/backend/openmlr/tools/local.py index 6815373..847e671 100644 --- a/backend/openmlr/tools/local.py +++ b/backend/openmlr/tools/local.py @@ -270,7 +270,18 @@ async def _handle_read(path: str, offset: int = 1, limit: int = 2000, **kwargs) return f"Error reading: {str(e)}", False -async def _handle_write(path: str, content: str, **kwargs) -> tuple[str, bool]: +async def _handle_write(path: str = "", content: str = "", **kwargs) -> tuple[str, bool]: + # Handle models that abbreviate argument names (e.g., 'p' for 'path', 'c' for 'content') + if not path: + path = kwargs.get("p", kwargs.get("file", kwargs.get("filepath", ""))) + if not content: + content = kwargs.get("c", kwargs.get("text", kwargs.get("data", ""))) + + if not path: + return "Error: 'path' argument is required.", False + if not content: + return "Error: 'content' argument is required.", False + try: target = Path(path).expanduser() if not target.is_absolute(): diff --git a/backend/openmlr/tools/plan.py b/backend/openmlr/tools/plan.py index 9d6e3a8..4323a5b 100644 --- a/backend/openmlr/tools/plan.py +++ b/backend/openmlr/tools/plan.py @@ -6,12 +6,23 @@ import logging from datetime import datetime, timezone from ..agent.types import ToolSpec, AgentEvent -from ..db.engine import async_session from ..db import operations as ops logger = logging.getLogger("openmlr.tools.plan") +def _get_session_factory(): + """Get the correct async session factory for the current context (web or worker).""" + from ..db.engine import _worker_engine, async_session + # If we're in a Celery worker context, use the worker engine + eng = _worker_engine.get(None) + if eng is not None: + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + return async_sessionmaker(eng, class_=AsyncSession, expire_on_commit=False) + # Otherwise use the main web engine + return async_session + + def create_plan_tool() -> ToolSpec: return ToolSpec( name="plan_tool", @@ -82,7 +93,8 @@ async def _handle_plan( logger.warning("No conversation_id in session, plan tool cannot persist") return "Error: No active conversation.", False - async with async_session() as db: + session_factory = _get_session_factory() + async with session_factory() as db: if operation == "create": if not tasks: return "Provide 'tasks' array.", False @@ -90,6 +102,12 @@ async def _handle_plan( task_list = [{"title": t.get("title", ""), "status": t.get("status", "pending")} for t in tasks] await ops.upsert_conversation_tasks(db, conv_id, task_list) await _emit_plan(session, conv_id, db) + + # Auto-save plan as PLAN.md resource (pinned) + plan_md = _generate_plan_md(task_list) + await ops.upsert_plan_resource(db, conv_id, plan_md) + await _emit_resources(session, conv_id, db) + return await _format_plan(db, conv_id), True elif operation == "add": @@ -102,6 +120,11 @@ async def _handle_plan( task_list.append({"title": title, "status": "pending"}) await ops.upsert_conversation_tasks(db, conv_id, task_list) await _emit_plan(session, conv_id, db) + + # Update PLAN.md + await ops.upsert_plan_resource(db, conv_id, _generate_plan_md(task_list)) + await _emit_resources(session, conv_id, db) + return await _format_plan(db, conv_id), True elif operation == "update": @@ -217,11 +240,30 @@ async def _handle_plan( async def get_report_content(report_id: str) -> str | None: """Retrieve a stored report by ID. Used by the API.""" - async with async_session() as db: + session_factory = _get_session_factory() + async with session_factory() as db: resource = await ops.get_resource_by_id(db, report_id) return resource.content if resource else None +def _generate_plan_md(tasks: list[dict]) -> str: + """Generate a PLAN.md markdown document from the task list.""" + now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC") + icons = {"pending": "- [ ]", "in_progress": "- [~]", "completed": "- [x]", "cancelled": "- [-]"} + lines = [ + "# Plan", + "", + f"*Last updated: {now}*", + "", + ] + for t in tasks: + status = t.get("status", "pending") + lines.append(f"{icons.get(status, '- [ ]')} {t.get('title', '')}") + done = sum(1 for t in tasks if t.get("status") == "completed") + lines.extend(["", f"**Progress: {done}/{len(tasks)}**"]) + return "\n".join(lines) + + def _generate_completion_report(task_title: str, summary: str = None, next_hints: str = None) -> str: """Generate a structured markdown completion report.""" now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC") diff --git a/backend/openmlr/tools/registry.py b/backend/openmlr/tools/registry.py index f0895ca..81a90ce 100644 --- a/backend/openmlr/tools/registry.py +++ b/backend/openmlr/tools/registry.py @@ -9,47 +9,28 @@ # Tools not listed are allowed in all modes MODE_TOOL_RESTRICTIONS = { "plan": { - # In plan mode: only planning, asking, and reading tools + # Plan mode: ask questions, create plans, read context — NO execution tools "allowed": { "ask_user", "plan_tool", - # Read-only tools - "read_file", "list_dir", "glob_files", "grep_search", - }, - "blocked_message": ( - "Tool '{tool}' is not available in PLAN mode. " - "Plan mode is for planning and asking questions only. " - "Suggest switching to research or write mode using ask_user with suggest_mode." - ), - }, - "research": { - # In research mode: search, papers, reading, planning — NO ask_user - "allowed": { - "plan_tool", - "web_search", "papers", "research", + # Read-only tools for gathering context "read_file", "list_dir", "glob_files", "grep_search", + "web_search", "papers", "github_search", "github_read_file", "github_read_repo", }, "blocked_message": ( - "Tool '{tool}' is not available in RESEARCH mode. " - "Research mode is for searching, reading papers, and gathering information. " - "Do NOT ask the user questions in this mode — just do the research. " - "If you need clarification, present your findings first." + "Tool '{tool}' is not available in PLAN mode. " + "Plan mode is for asking questions, planning tasks, and gathering context. " + "Switch to Execute mode to run tools, write content, or execute code." ), }, - "write": { - # In write mode: writing, planning, reading, limited search — NO ask_user - "allowed": { - "plan_tool", "writing", - "read_file", "list_dir", "glob_files", "grep_search", - "web_search", "papers", # For citations - }, + "execute": { + # Execute mode: all tools EXCEPT ask_user — just do the work + "blocked": {"ask_user"}, "blocked_message": ( - "Tool '{tool}' is not available in WRITE mode. " - "Write mode is for drafting and editing content. " - "Do NOT ask the user questions in this mode — just write." + "Tool '{tool}' is not available in EXECUTE mode. " + "Execute mode is for doing the work — do not ask questions, just execute the plan." ), }, - # "general" mode has no restrictions } @@ -85,13 +66,23 @@ def is_tool_allowed(self, name: str) -> tuple[bool, str]: """Check if a tool is allowed in the current mode. Returns (allowed, error_message). + Supports both 'allowed' (whitelist) and 'blocked' (blacklist) sets. """ if self._current_mode not in MODE_TOOL_RESTRICTIONS: return True, "" restrictions = MODE_TOOL_RESTRICTIONS[self._current_mode] - allowed_tools = restrictions.get("allowed", set()) + # Blacklist mode: specific tools are blocked + blocked_tools = restrictions.get("blocked", set()) + if blocked_tools: + if name in blocked_tools: + error_msg = restrictions.get("blocked_message", "Tool '{tool}' not allowed in this mode.") + return False, error_msg.format(tool=name, mode=self._current_mode) + return True, "" + + # Whitelist mode: only specific tools allowed + allowed_tools = restrictions.get("allowed", set()) if name in allowed_tools: return True, "" @@ -162,7 +153,14 @@ async def call_tool( kwargs = dict(arguments) if "session" in sig.parameters: kwargs["session"] = session - return await tool.handler(**kwargs) if kwargs else await tool.handler(arguments) + # Also pass tool_call_id if the handler accepts it + if "tool_call_id" in sig.parameters and "tool_call_id" not in kwargs: + kwargs["tool_call_id"] = kwargs.pop("id", "") + try: + return await tool.handler(**kwargs) if kwargs else await tool.handler(**arguments) + except TypeError as e: + # Handle argument mismatches (model sending wrong param names) + return f"Tool argument error: {e}. Expected parameters: {list(sig.parameters.keys())}", False # MCP tool (no handler — dispatch to MCP client) if self._mcp_client: diff --git a/backend/openmlr/tools/writing.py b/backend/openmlr/tools/writing.py index 4bc992b..78f6d72 100644 --- a/backend/openmlr/tools/writing.py +++ b/backend/openmlr/tools/writing.py @@ -1,20 +1,81 @@ -"""Paper writing tool — section-by-section academic paper authoring.""" +"""Paper writing tool — section-by-section academic paper authoring. -import json -from ..agent.types import ToolSpec +Projects are persisted to the database as resources so they survive +across Celery workers and server restarts. +""" -# In-memory writing projects (will be backed by DB later) -_projects: dict[str, dict] = {} +import json +import logging +from datetime import datetime, timezone +from ..agent.types import ToolSpec, AgentEvent +from ..db import operations as ops + +logger = logging.getLogger("openmlr.tools.writing") + + +def _get_session_factory(): + """Get the correct async session factory for the current context.""" + from ..db.engine import _worker_engine, async_session + eng = _worker_engine.get(None) + if eng is not None: + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + return async_sessionmaker(eng, class_=AsyncSession, expire_on_commit=False) + return async_session + + +# In-memory cache (hydrated from DB on first access per conversation) +_projects: dict[int, dict] = {} # keyed by conversation_id + + +async def _load_project(conv_id: int) -> dict | None: + """Load project from DB if not already cached.""" + if conv_id in _projects: + return _projects[conv_id] + + session_factory = _get_session_factory() + async with session_factory() as db: + resource = await ops.get_resource_by_id(db, f"paper-{conv_id}") + if resource and resource.content: + # Try to load the project JSON from a special metadata resource + meta_resource = await ops.get_resource_by_id(db, f"paper-meta-{conv_id}") + if meta_resource and meta_resource.content: + try: + proj = json.loads(meta_resource.content) + _projects[conv_id] = proj + return proj + except json.JSONDecodeError: + pass + return None + + +async def _save_project(conv_id: int, proj: dict) -> None: + """Save project metadata and draft to DB.""" + _projects[conv_id] = proj + + session_factory = _get_session_factory() + async with session_factory() as db: + # Save project metadata (structure, bibliography, etc.) + await ops.upsert_resource( + db, conv_id, + resource_id=f"paper-meta-{conv_id}", + title=f"Paper Metadata: {proj.get('title', 'Untitled')}", + resource_type="doc", + content=json.dumps(proj, default=str), + ) + # Save the rendered draft as the paper resource + draft, _ = _get_draft_from_proj(proj) + await ops.upsert_paper_resource(db, conv_id, proj.get("title", "Paper"), draft) def create_writing_tool() -> ToolSpec: return ToolSpec( name="writing", description=( - "Manage academic paper writing projects. Supports section-by-section " - "writing with research corpus integration. Operations: create_project, " - "set_outline, write_section, refine_section, add_citation, " - "get_draft, list_sections, export." + "Manage academic paper writing. Supports section-by-section writing.\n" + "Operations: create_project, set_outline, write_section, refine_section, " + "add_citation, get_draft, list_sections.\n\n" + "The paper is auto-saved after each write. Users can preview and export " + "from the Paper tab in the UI — do NOT use the 'write' file tool for papers." ), parameters={ "type": "object", @@ -24,7 +85,7 @@ def create_writing_tool() -> ToolSpec: "enum": [ "create_project", "set_outline", "write_section", "refine_section", "add_citation", "get_draft", - "list_sections", "export", + "list_sections", ], "description": "Which writing operation to perform", }, @@ -73,21 +134,16 @@ def create_writing_tool() -> ToolSpec: "type": "object", "description": "BibTeX-style citation object", "properties": { - "key": {"type": "string", "description": "Citation key (e.g. smith2024)"}, - "type": {"type": "string", "description": "Entry type (article, inproceedings, etc.)"}, - "title": {"type": "string", "description": "Paper title"}, - "author": {"type": "string", "description": "Author names"}, - "year": {"type": "string", "description": "Publication year"}, - "venue": {"type": "string", "description": "Journal or conference"}, - "url": {"type": "string", "description": "URL or DOI"}, + "key": {"type": "string"}, + "type": {"type": "string"}, + "title": {"type": "string"}, + "author": {"type": "string"}, + "year": {"type": "string"}, + "venue": {"type": "string"}, + "url": {"type": "string"}, }, "required": ["key", "title", "author", "year"], }, - "format": { - "type": "string", - "enum": ["markdown", "latex"], - "description": "Export format (default: markdown)", - }, }, "required": ["operation"], }, @@ -104,100 +160,120 @@ async def _handle_writing( content: str = None, feedback: str = None, citation: dict = None, - format: str = "markdown", + session=None, **kwargs, ) -> tuple[str, bool]: """Route writing operations.""" + conv_id = session.conversation_id if session else None if operation == "create_project": - return _create_project(title) - elif operation == "set_outline": - return _set_outline(project_id, outline) + result, ok = _create_project(conv_id, title) + if ok and conv_id: + await _save_project(conv_id, _projects[conv_id]) + await _emit_resources(session, conv_id) + return result, ok + + # For all other operations, try to load existing project + if conv_id: + await _load_project(conv_id) + + if operation == "set_outline": + result, ok = _set_outline(conv_id, outline) + if ok and conv_id: + await _save_project(conv_id, _projects[conv_id]) + await _emit_resources(session, conv_id) + return result, ok elif operation == "write_section": - return _write_section(project_id, section_id, content) + result, ok = _write_section(conv_id, section_id, content) + if ok and conv_id: + await _save_project(conv_id, _projects[conv_id]) + await _emit_resources(session, conv_id) + return result, ok elif operation == "refine_section": - return _refine_section(project_id, section_id, content, feedback) + result, ok = _refine_section(conv_id, section_id, content, feedback) + if ok and content and conv_id: + await _save_project(conv_id, _projects[conv_id]) + await _emit_resources(session, conv_id) + return result, ok elif operation == "add_citation": - return _add_citation(project_id, citation) + result, ok = _add_citation(conv_id, citation) + if ok and conv_id: + await _save_project(conv_id, _projects[conv_id]) + await _emit_resources(session, conv_id) + return result, ok elif operation == "get_draft": - return _get_draft(project_id) + return _get_draft(conv_id) elif operation == "list_sections": - return _list_sections(project_id) - elif operation == "export": - return _export(project_id, format) + return _list_sections(conv_id) else: return f"Unknown operation: {operation}", False -def _get_project(project_id: str) -> dict | None: - if not project_id: - # Return most recent project - if _projects: - return list(_projects.values())[-1] +def _get_project(conv_id: int) -> dict | None: + """Get project from in-memory cache.""" + if not conv_id: return None - return _projects.get(project_id) + return _projects.get(conv_id) -def _create_project(title: str) -> tuple[str, bool]: +def _create_project(conv_id: int, title: str) -> tuple[str, bool]: if not title: return "Provide a 'title' for the project.", False - import uuid - pid = str(uuid.uuid4())[:8] - _projects[pid] = { - "id": pid, + proj = { "title": title, "outline": [], "sections": {}, "bibliography": [], - "status": "draft", + "created_at": datetime.now(timezone.utc).isoformat(), } - return f"Created project '{title}' (id: {pid}). Use set_outline to define sections.", True + if conv_id: + _projects[conv_id] = proj + return f"Created paper project: '{title}'. Use set_outline to define sections.", True -def _set_outline(project_id: str, outline: list) -> tuple[str, bool]: - proj = _get_project(project_id) +def _set_outline(conv_id: int, outline: list) -> tuple[str, bool]: + proj = _get_project(conv_id) if not proj: - return "Project not found. Create one first.", False + return "No paper project exists. Call create_project first.", False if not outline: return "Provide an 'outline' array.", False proj["outline"] = outline lines = [f"Outline set for '{proj['title']}':\n"] - for i, sec in enumerate(outline): - lines.append(f" {sec.get('id', i)}. {sec.get('title', 'Untitled')}") + for sec in outline: + lines.append(f" {sec.get('id', '?')}. {sec.get('title', 'Untitled')}") for sub in sec.get("subsections", []): lines.append(f" {sub.get('id', '')}. {sub.get('title', '')}") return "\n".join(lines), True -def _write_section(project_id: str, section_id: str, content: str) -> tuple[str, bool]: - proj = _get_project(project_id) +def _write_section(conv_id: int, section_id: str, content: str) -> tuple[str, bool]: + proj = _get_project(conv_id) if not proj: - return "Project not found.", False + return "No paper project exists. Call create_project first.", False if not section_id or not content: return "Provide both 'section_id' and 'content'.", False proj["sections"][section_id] = content - written = len(proj["sections"]) total = _count_sections(proj["outline"]) return ( f"Section '{section_id}' written ({len(content)} chars). " - f"Progress: {written}/{total} sections complete." + f"Progress: {written}/{total} sections. Paper auto-saved." ), True -def _refine_section(project_id: str, section_id: str, content: str, feedback: str) -> tuple[str, bool]: - proj = _get_project(project_id) +def _refine_section(conv_id: int, section_id: str, content: str, feedback: str) -> tuple[str, bool]: + proj = _get_project(conv_id) if not proj: - return "Project not found.", False + return "No paper project exists.", False if not section_id: return "Provide 'section_id' to refine.", False if content: proj["sections"][section_id] = content - return f"Section '{section_id}' refined ({len(content)} chars).", True + return f"Section '{section_id}' refined ({len(content)} chars). Paper auto-saved.", True else: existing = proj["sections"].get(section_id, "") return ( @@ -208,26 +284,30 @@ def _refine_section(project_id: str, section_id: str, content: str, feedback: st ), True -def _add_citation(project_id: str, citation: dict) -> tuple[str, bool]: - proj = _get_project(project_id) +def _add_citation(conv_id: int, citation: dict) -> tuple[str, bool]: + proj = _get_project(conv_id) if not proj: - return "Project not found.", False + return "No paper project exists.", False if not citation: return "Provide a 'citation' object.", False proj["bibliography"].append(citation) key = citation.get("key", f"ref{len(proj['bibliography'])}") - return f"Added citation [@{key}]. Bibliography now has {len(proj['bibliography'])} entries.", True + return f"Added citation [@{key}]. Bibliography: {len(proj['bibliography'])} entries.", True -def _get_draft(project_id: str) -> tuple[str, bool]: - proj = _get_project(project_id) +def _get_draft(conv_id: int) -> tuple[str, bool]: + proj = _get_project(conv_id) if not proj: - return "Project not found.", False + return "No paper project exists.", False + return _get_draft_from_proj(proj) + +def _get_draft_from_proj(proj: dict) -> tuple[str, bool]: + """Generate the full markdown draft from a project dict.""" lines = [f"# {proj['title']}\n"] - if proj["outline"]: + if proj.get("outline"): for sec in proj["outline"]: sid = sec.get("id", "") title = sec.get("title", "") @@ -240,15 +320,13 @@ def _get_draft(project_id: str) -> tuple[str, bool]: sub_content = proj["sections"].get(sub_id, "[Not yet written]") lines.append(f"\n### {sub_title}\n\n{sub_content}") else: - # No outline — just dump sections - for sid, content in proj["sections"].items(): + for sid, content in proj.get("sections", {}).items(): lines.append(f"\n## {sid}\n\n{content}") - # Bibliography - if proj["bibliography"]: + if proj.get("bibliography"): lines.append("\n## References\n") - for i, c in enumerate(proj["bibliography"], 1): - key = c.get("key", f"ref{i}") + for c in proj["bibliography"]: + key = c.get("key", "?") author = c.get("author", "Unknown") title = c.get("title", "Untitled") year = c.get("year", "?") @@ -257,13 +335,13 @@ def _get_draft(project_id: str) -> tuple[str, bool]: return "\n".join(lines), True -def _list_sections(project_id: str) -> tuple[str, bool]: - proj = _get_project(project_id) +def _list_sections(conv_id: int) -> tuple[str, bool]: + proj = _get_project(conv_id) if not proj: - return "Project not found.", False + return "No paper project exists.", False lines = [f"## Sections for '{proj['title']}'\n"] - if proj["outline"]: + if proj.get("outline"): for sec in proj["outline"]: sid = sec.get("id", "") written = "done" if sid in proj["sections"] else "pending" @@ -271,65 +349,24 @@ def _list_sections(project_id: str) -> tuple[str, bool]: lines.append(f" [{written}] {sid}: {sec.get('title', '')} ({char_count} chars)") else: lines.append("No outline defined. Use set_outline first.") - return "\n".join(lines), True -def _export(project_id: str, fmt: str = "markdown") -> tuple[str, bool]: - proj = _get_project(project_id) - if not proj: - return "Project not found.", False - - draft, _ = _get_draft(project_id) - - if fmt == "latex": - return _convert_to_latex(proj, draft), True - else: - return f"Markdown draft:\n\n{draft}", True - - -def _convert_to_latex(proj: dict, markdown: str) -> str: - """Basic Markdown to LaTeX conversion.""" - lines = [ - "\\documentclass{article}", - "\\usepackage[utf8]{inputenc}", - "\\usepackage{amsmath,amssymb}", - "\\usepackage{hyperref}", - "", - f"\\title{{{proj['title']}}}", - "\\author{}", - "\\date{\\today}", - "", - "\\begin{document}", - "\\maketitle", - "", - ] - - # Simple conversion - for line in markdown.split("\n"): - if line.startswith("### "): - lines.append(f"\\subsubsection{{{line[4:]}}}") - elif line.startswith("## "): - lines.append(f"\\subsection{{{line[3:]}}}") - elif line.startswith("# "): - lines.append(f"\\section{{{line[2:]}}}") - else: - lines.append(line) - - # Bibliography - if proj["bibliography"]: - lines.append("") - lines.append("\\begin{thebibliography}{99}") - for c in proj["bibliography"]: - key = c.get("key", "") - author = c.get("author", "") - title = c.get("title", "") - year = c.get("year", "") - lines.append(f"\\bibitem{{{key}}} {author}. \\textit{{{title}}}. {year}.") - lines.append("\\end{thebibliography}") - - lines.append("\\end{document}") - return "\n".join(lines) +async def _emit_resources(session, conv_id: int) -> None: + """Emit resources update event to frontend.""" + if not session: + return + session_factory = _get_session_factory() + async with session_factory() as db: + resources = await ops.get_conversation_resources(db, conv_id) + res_list = [ + {"title": r.title, "url": r.url or "", "type": r.type, "id": r.resource_id} + for r in resources + ] + await session.emit(AgentEvent( + event_type="resources_update", + data={"resources": res_list}, + )) def _count_sections(outline: list) -> int: diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 334d8ee..f09599c 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -63,7 +63,6 @@ function ChatUI({ const [viewingReport, setViewingReport] = useState(null); const [inputMode, setInputMode] = useState('plan'); const [inputText, setInputText] = useState(''); - const [pendingModeSwitch, setPendingModeSwitch] = useState(null); // Ref to always have current conv UUID in SSE callback (avoids stale closure) const currentConvUuidRef = useRef(currentConvUuid); @@ -491,21 +490,6 @@ function ChatUI({ setCurrentConvStatus('processing'); setMessages((prev) => [...prev, { id: nextId(), role: 'user', content: `Answered:\n${summary}` }]); }} onClose={() => setQuestionsPayload(null)} />} - {pendingModeSwitch && ( -
-
- Agent suggests switching to {pendingModeSwitch} mode -
- - -
-
-
- )} > = {}) { return { disabled: false, - mode: 'general' as Mode, + mode: 'plan' as Mode, onModeChange: vi.fn(), onSend: vi.fn(), onStop: vi.fn(), @@ -17,30 +17,25 @@ function defaultProps(overrides: Partial> } describe('InputArea', () => { - it('renders mode buttons (Plan, Research, Write, General)', () => { + it('renders mode toggle button showing P in plan mode', () => { render(); - expect(screen.getByText('Plan')).toBeInTheDocument(); - expect(screen.getByText('Research')).toBeInTheDocument(); - expect(screen.getByText('Write')).toBeInTheDocument(); - expect(screen.getByText('General')).toBeInTheDocument(); + const toggle = screen.getByText('P'); + expect(toggle).toBeInTheDocument(); + expect(toggle).toHaveClass('mode-plan'); }); - it('active mode button has active class', () => { - render(); - // The button with title="Research" should have the active class - const researchBtn = screen.getByTitle('Research'); - expect(researchBtn).toHaveClass('active'); - - // Other buttons should not have active class - const planBtn = screen.getByTitle('Plan'); - expect(planBtn).not.toHaveClass('active'); + it('renders mode toggle button showing E in execute mode', () => { + render(); + const toggle = screen.getByText('E'); + expect(toggle).toBeInTheDocument(); + expect(toggle).toHaveClass('mode-execute'); }); - it('clicking mode button calls onModeChange', () => { + it('clicking toggle switches mode', () => { const onModeChange = vi.fn(); - render(); - fireEvent.click(screen.getByTitle('Write')); - expect(onModeChange).toHaveBeenCalledWith('write'); + render(); + fireEvent.click(screen.getByText('P')); + expect(onModeChange).toHaveBeenCalledWith('execute'); }); it('send button click calls onSend with text and mode', () => { @@ -48,47 +43,26 @@ describe('InputArea', () => { const onTextChange = vi.fn(); render( , ); const sendBtn = screen.getByRole('button', { name: '↑' }); fireEvent.click(sendBtn); - expect(onSend).toHaveBeenCalledWith('Hello there', 'plan'); + expect(onSend).toHaveBeenCalledWith('hello', 'plan'); expect(onTextChange).toHaveBeenCalledWith(''); }); - it('Enter key submits (without shift)', () => { + it('enter key submits', () => { const onSend = vi.fn(); - const onTextChange = vi.fn(); - render( - , - ); + render(); const textarea = screen.getByRole('textbox'); - fireEvent.keyDown(textarea, { key: 'Enter', shiftKey: false }); - expect(onSend).toHaveBeenCalledWith('Enter submit test', 'general'); + fireEvent.keyDown(textarea, { key: 'Enter' }); + expect(onSend).toHaveBeenCalledOnce(); }); - it('Shift+Enter does not submit', () => { + it('shift+enter does not submit', () => { const onSend = vi.fn(); - render( - , - ); + render(); const textarea = screen.getByRole('textbox'); fireEvent.keyDown(textarea, { key: 'Enter', shiftKey: true }); expect(onSend).not.toHaveBeenCalled(); @@ -114,4 +88,18 @@ describe('InputArea', () => { const sendBtn = screen.getByRole('button', { name: '↑' }); expect(sendBtn).toBeDisabled(); }); + + it('keyboard shortcut Cmd+B switches to plan', () => { + const onModeChange = vi.fn(); + render(); + fireEvent.keyDown(window, { key: 'b', metaKey: true }); + expect(onModeChange).toHaveBeenCalledWith('plan'); + }); + + it('keyboard shortcut Cmd+E switches to execute', () => { + const onModeChange = vi.fn(); + render(); + fireEvent.keyDown(window, { key: 'e', metaKey: true }); + expect(onModeChange).toHaveBeenCalledWith('execute'); + }); }); diff --git a/frontend/src/components/InputArea.tsx b/frontend/src/components/InputArea.tsx index 60f126b..579d008 100644 --- a/frontend/src/components/InputArea.tsx +++ b/frontend/src/components/InputArea.tsx @@ -1,10 +1,10 @@ import { useRef, useEffect, useCallback } from 'react'; -export type Mode = 'plan' | 'research' | 'write' | 'general'; +export type Mode = 'plan' | 'execute'; interface Props { disabled: boolean; - showStop?: boolean; // Show the stop button even when input is enabled (e.g., during questions/approval) + showStop?: boolean; mode: Mode; onModeChange: (mode: Mode) => void; onSend: (text: string, mode: Mode) => void; @@ -13,13 +13,6 @@ interface Props { onTextChange: (text: string) => void; } -const MODE_INFO: Record = { - plan: { label: 'Plan', icon: 'P', placeholder: 'Ask questions, plan tasks, clarify scope...' }, - research: { label: 'Research', icon: 'R', placeholder: 'Search papers, find code, explore literature...' }, - write: { label: 'Write', icon: 'W', placeholder: 'Write sections, manage citations, draft content...' }, - general: { label: 'General', icon: 'G', placeholder: 'General conversation, any task...' }, -}; - export function InputArea({ disabled, showStop, mode, onModeChange, onSend, onStop, text, onTextChange }: Props) { const textareaRef = useRef(null); @@ -40,24 +33,38 @@ export function InputArea({ disabled, showStop, mode, onModeChange, onSend, onSt if (textareaRef.current) textareaRef.current.style.height = 'auto'; }, [text, disabled, onSend, mode, onTextChange]); - const modes: Mode[] = ['plan', 'research', 'write', 'general']; + const toggleMode = useCallback(() => { + onModeChange(mode === 'plan' ? 'execute' : 'plan'); + }, [mode, onModeChange]); + + // Keyboard shortcuts: Cmd+B = Plan, Cmd+E = Execute + useEffect(() => { + const handler = (e: KeyboardEvent) => { + if (!(e.metaKey || e.ctrlKey)) return; + if (e.key === 'b' || e.key === 'B') { + e.preventDefault(); + onModeChange('plan'); + } else if (e.key === 'e' || e.key === 'E') { + e.preventDefault(); + onModeChange('execute'); + } + }; + window.addEventListener('keydown', handler); + return () => window.removeEventListener('keydown', handler); + }, [onModeChange]); + + const isPlan = mode === 'plan'; return (
-
- {modes.map((m) => ( - - ))} -
+