From 68c59adf6e0598cdb8a99df0c6cc88066e9219c3 Mon Sep 17 00:00:00 2001 From: xprilion Date: Mon, 27 Apr 2026 18:38:47 +0530 Subject: [PATCH 1/7] Add workspaces and projects - Project entity with persistent workspace directories (Docker volume) - Knowledge graph (networkx) for cross-conversation memory - File tree panel, interactive terminal, project selector UI - Workspace tool for agent: status, search, notes, knowledge CRUD - ArXiv API fixed to use HTTPS - Security hardened: path traversal prevention, env scrubbing, size limits, sandbox containment, filename sanitization, DB operation allowlists, contextvars for session isolation - 65 new tests (projects, workspace, knowledge graph, tools) - Docs: projects.md, updated tools.md, README, VitePress nav --- .env.example | 8 + README.md | 6 +- backend/configs/prompts/system_prompt.yaml | 22 + backend/openmlr/app.py | 10 +- backend/openmlr/compute/workspace.py | 177 ++++-- backend/openmlr/db/engine.py | 5 + .../migrations/versions/004_add_projects.py | 48 ++ backend/openmlr/db/models.py | 94 ++- backend/openmlr/db/operations.py | 256 ++++++-- backend/openmlr/routes/projects.py | 545 ++++++++++++++++++ backend/openmlr/routes/terminal.py | 275 +++++++++ backend/openmlr/sandbox/local.py | 55 +- backend/openmlr/sandbox/manager.py | 25 +- backend/openmlr/tools/papers.py | 181 ++++-- backend/openmlr/tools/registry.py | 59 +- backend/openmlr/tools/workspace_tools.py | 383 ++++++++++++ backend/openmlr/workspace/__init__.py | 6 + backend/openmlr/workspace/knowledge.py | 375 ++++++++++++ backend/openmlr/workspace/persistence.py | 353 ++++++++++++ backend/pyproject.toml | 3 + backend/tests/test_projects.py | 196 +++++++ backend/tests/test_tools_workspace.py | 208 +++++++ backend/tests/test_workspace.py | 333 +++++++++++ docker-compose.prod.yml | 4 + docker-compose.yml | 3 + frontend/src/App.tsx | 33 +- frontend/src/__tests__/RightPanel.test.tsx | 10 + frontend/src/__tests__/Sidebar.test.tsx | 36 ++ frontend/src/api.ts | 22 + frontend/src/components/FileTree.tsx | 303 ++++++++++ frontend/src/components/ProjectModal.tsx | 104 ++++ frontend/src/components/RightPanel.tsx | 52 +- frontend/src/components/Sidebar.tsx | 59 +- frontend/src/components/Terminal.tsx | 208 +++++++ frontend/src/types.ts | 24 + site/docs/.vitepress/config.ts | 2 + site/docs/projects.md | 173 ++++++ site/docs/tools.md | 23 +- 38 files changed, 4482 insertions(+), 197 deletions(-) create mode 100644 backend/openmlr/db/migrations/versions/004_add_projects.py create mode 100644 backend/openmlr/routes/projects.py create mode 100644 backend/openmlr/routes/terminal.py create mode 100644 backend/openmlr/tools/workspace_tools.py create mode 100644 backend/openmlr/workspace/__init__.py create mode 100644 backend/openmlr/workspace/knowledge.py create mode 100644 backend/openmlr/workspace/persistence.py create mode 100644 backend/tests/test_projects.py create mode 100644 backend/tests/test_tools_workspace.py create mode 100644 backend/tests/test_workspace.py create mode 100644 frontend/src/components/FileTree.tsx create mode 100644 frontend/src/components/ProjectModal.tsx create mode 100644 frontend/src/components/Terminal.tsx create mode 100644 site/docs/projects.md diff --git a/.env.example b/.env.example index 30c6638..cd0d11c 100644 --- a/.env.example +++ b/.env.example @@ -78,6 +78,14 @@ GITHUB_TOKEN= # OpenAlex email for faster rate limits (paper search works without it) # OPENALEX_EMAIL=you@example.com +# ═══════════════════════════════════════════════════════════ +# PROJECT WORKSPACES +# ═══════════════════════════════════════════════════════════ + +# Where project workspaces are stored (bind mount in production) +# Docker Compose dev uses a named volume; production uses this path. +# OPENMLR_WORKSPACES_PATH=./.workspaces + # ═══════════════════════════════════════════════════════════ # SANDBOX / CODE EXECUTION # ═══════════════════════════════════════════════════════════ diff --git a/README.md b/README.md index bc781b3..d1cbc2d 100644 --- a/README.md +++ b/README.md @@ -19,10 +19,12 @@ ## Features +- **Projects & Workspaces** — Persistent workspaces with knowledge graphs, file trees, and cross-conversation memory. Research accumulates across chats. +- **Interactive terminal** — Built-in terminal connected to the project workspace. Run commands directly alongside AI-driven research. - **Plan + Execute modes** — Plan mode gathers context; Execute mode does the work. Toggle with `Cmd+M`. - **Paper research** — OpenAlex, Semantic Scholar, arXiv, CrossRef, Papers With Code. Reads full papers, crawls citation graphs. - **Paper writing** — Section-by-section drafting with auto-save. Export to Markdown/LaTeX. -- **Compute environments** — Execute code on local Docker, SSH remotes, or Modal cloud. Probe GPU/CPU capabilities. +- **Compute environments** — Execute code on local Docker, SSH remotes, or Modal cloud. Workspace persists independently of compute. - **Background jobs** — Celery + Redis. Close the browser, come back later. - **Multi-provider LLMs** — OpenAI, Anthropic, OpenRouter, plus local models (Ollama, LM Studio). - **MCP servers** — Connect external tools via the Model Context Protocol. @@ -40,6 +42,8 @@ make up Open `http://localhost:3000`. Create an account. Add your API keys in **Settings > Providers**. +Project workspaces are stored in a persistent Docker volume (`.workspaces/`), so your research data survives container rebuilds. + > No API keys needed to start — the app guides you through configuration after login. ## Development diff --git a/backend/configs/prompts/system_prompt.yaml b/backend/configs/prompts/system_prompt.yaml index bd340c8..ae4c1d2 100644 --- a/backend/configs/prompts/system_prompt.yaml +++ b/backend/configs/prompts/system_prompt.yaml @@ -150,8 +150,30 @@ prompt: | {{ compute_env }} {% endif %} + # Project Workspace + + If a project workspace is active, use it to persist knowledge across conversations: + + - Use `workspace status` at the start of a conversation to understand what has + been done before (papers found, notes written, experiments run, known failures) + - Use `workspace knowledge_add` to record important entities (papers, methods, + datasets, findings) in the knowledge graph + - Use `workspace knowledge_relate` to link entities (e.g., paper --proposes--> method) + - Use `workspace note` to save research summaries and important findings + - Use `workspace knowledge_summary` to review accumulated knowledge + - Use `workspace recent_failures` to check for known tool/API issues before retrying + + The workspace persists independently of compute resources. Files in the workspace + (code/, data/, papers/, research/, outputs/) survive compute changes and new conversations. + + **Important**: The workspace is the source of truth for the project. Always check it + before doing redundant work. Save important findings so future conversations can build on them. + # Compute Planning When starting tasks that require significant computation (training models, processing large datasets, etc.): 1. Use `compute_plan` to verify the active node meets requirements 2. If not, use `compute_select` to switch to a suitable node 3. Always `sandbox_probe` before executing code on a node for the first time + + Note: The compute resource is separate from the project workspace. Switching compute + does not affect your workspace files. The workspace is always available locally. diff --git a/backend/openmlr/app.py b/backend/openmlr/app.py index 66f1d03..dead080 100644 --- a/backend/openmlr/app.py +++ b/backend/openmlr/app.py @@ -22,6 +22,7 @@ async def lifespan(app: FastAPI): """Startup: create tables & shared state. Shutdown: teardown sessions.""" import logging + logger = logging.getLogger("openmlr.app") async with engine.begin() as conn: @@ -57,7 +58,9 @@ async def lifespan(app: FastAPI): ) # CORS configuration - restrict in production -_cors_origins = os.environ.get("CORS_ORIGINS", "http://localhost:3000,http://localhost:5173").split(",") +_cors_origins = os.environ.get("CORS_ORIGINS", "http://localhost:3000,http://localhost:5173").split( + "," +) _cors_origins = [origin.strip() for origin in _cors_origins if origin.strip()] app.add_middleware( @@ -74,7 +77,9 @@ async def lifespan(app: FastAPI): from .routes.compute import router as compute_router from .routes.health import router as health_router from .routes.keys import router as keys_router +from .routes.projects import router as projects_router from .routes.settings import router as settings_router +from .routes.terminal import router as terminal_router app.include_router(auth_router) app.include_router(agent_router) @@ -82,12 +87,15 @@ async def lifespan(app: FastAPI): app.include_router(health_router) app.include_router(keys_router) app.include_router(compute_router) +app.include_router(projects_router) +app.include_router(terminal_router) # ── Global error handler ──────────────────────────────── @app.exception_handler(Exception) async def global_exception_handler(request: Request, exc: Exception): import logging + logger = logging.getLogger(__name__) logger.exception(f"Unhandled exception: {exc}") # Don't leak internal details to client diff --git a/backend/openmlr/compute/workspace.py b/backend/openmlr/compute/workspace.py index 9517bf3..f342793 100644 --- a/backend/openmlr/compute/workspace.py +++ b/backend/openmlr/compute/workspace.py @@ -1,14 +1,54 @@ -"""Workspace Manager — per-conversation filesystem isolation.""" +"""Workspace Manager — project-scoped filesystem with backward-compatible conversation support. +The workspace is the persistent home for all project artifacts: +code, data, models, outputs, papers, research notes, logs, and knowledge graph. +It persists across conversations and compute resource changes. +""" + +import json +import logging import os import shutil import tarfile from datetime import UTC, datetime from pathlib import Path +log = logging.getLogger(__name__) + +# Default workspace root — overridden by OPENMLR_WORKSPACES_PATH in Docker +WORKSPACES_ROOT = Path(os.environ.get("OPENMLR_WORKSPACES_PATH", "/app/.workspaces")) + +# Standard project workspace subdirectories +PROJECT_SUBDIRS = [ + "code", + "data", + "models", + "outputs", + "papers", + "research", + "research/searches", + "research/notes", + "research/citations", + "logs", + "logs/tool_failures", + "logs/compute", + "logs/experiments", + "venvs", + ".project-meta", + ".project-meta/plans", +] + +# Legacy conversation-only subdirectories (backward compat) +LEGACY_SUBDIRS = ["data", "models", "code", "outputs", ".openmlr-meta"] + class WorkspaceManager: - """Manages isolated workspace directories for each conversation.""" + """Manages isolated workspace directories for projects and conversations. + + Supports two modes: + - Project mode: workspace at WORKSPACES_ROOT/{project_slug}/ + - Legacy mode: workspace at ~/.openmlr/workspaces/workspace-{uuid}/ + """ def __init__(self, base_dir: str | Path = None): self.base_dir = Path(base_dir) if base_dir else Path.home() / ".openmlr" @@ -21,23 +61,103 @@ def _ensure_dirs(self) -> None: self.workspace_dir.mkdir(parents=True, exist_ok=True) self.archive_dir.mkdir(parents=True, exist_ok=True) + # ── Project-scoped workspaces ──────────────────────── + + @staticmethod + def get_project_workspace_path(project_slug: str) -> Path: + """Get the workspace directory for a project.""" + return WORKSPACES_ROOT / project_slug + + @staticmethod + def create_project_workspace(project_slug: str, name: str = "", description: str = "") -> Path: + """Create a new project workspace with all standard subdirectories.""" + path = WORKSPACES_ROOT / project_slug + path.mkdir(parents=True, exist_ok=True) + + for subdir in PROJECT_SUBDIRS: + (path / subdir).mkdir(parents=True, exist_ok=True) + + # Write initial project metadata if it doesn't exist + meta_path = path / ".project-meta" / "project.json" + if not meta_path.exists(): + meta_path.write_text( + json.dumps( + { + "name": name or project_slug, + "slug": project_slug, + "description": description, + "created_at": datetime.now(UTC).isoformat(), + }, + indent=2, + ) + ) + + # Initialize empty knowledge graph if it doesn't exist + kg_path = path / ".project-meta" / "knowledge.json" + if not kg_path.exists(): + kg_path.write_text( + json.dumps( + { + "nodes": [], + "edges": [], + "version": 1, + }, + indent=2, + ) + ) + + # Initialize empty state file for cross-conversation persistence + state_path = path / ".project-meta" / "state.json" + if not state_path.exists(): + state_path.write_text( + json.dumps( + { + "last_conversation_uuid": None, + "open_questions": [], + "key_findings": [], + "active_experiments": [], + }, + indent=2, + ) + ) + + return path + + @staticmethod + def project_workspace_exists(project_slug: str) -> bool: + """Check if a project workspace exists.""" + return (WORKSPACES_ROOT / project_slug).exists() + + @staticmethod + def get_project_workspace_size(project_slug: str) -> int: + """Get total size of a project workspace in bytes.""" + path = WORKSPACES_ROOT / project_slug + if not path.exists(): + return 0 + total = 0 + for dirpath, _, filenames in os.walk(path): + for f in filenames: + fp = Path(dirpath) / f + if fp.exists(): + total += fp.stat().st_size + return total + + # ── Legacy conversation-scoped workspaces ──────────── + def get_workspace_path(self, conversation_uuid: str) -> Path: - """Get the workspace directory for a conversation.""" + """Get the workspace directory for a conversation (legacy mode).""" return self.workspace_dir / f"workspace-{conversation_uuid}" def create_workspace(self, conversation_uuid: str) -> Path: - """Create a new workspace directory for a conversation.""" + """Create a new workspace directory for a conversation (legacy mode).""" path = self.get_workspace_path(conversation_uuid) path.mkdir(parents=True, exist_ok=True) - # Create standard subdirectories - for subdir in ["data", "models", "code", "outputs"]: + for subdir in LEGACY_SUBDIRS: (path / subdir).mkdir(exist_ok=True) - # Create meta directory (hidden from agent) - (path / ".openmlr-meta").mkdir(exist_ok=True) return path def workspace_exists(self, conversation_uuid: str) -> bool: - """Check if a workspace exists.""" + """Check if a conversation workspace exists.""" return self.get_workspace_path(conversation_uuid).exists() def archive_workspace(self, conversation_uuid: str) -> Path | None: @@ -82,34 +202,27 @@ def get_workspace_size(self, conversation_uuid: str) -> int: return total def list_workspaces(self) -> list[dict]: - """List all workspaces with metadata.""" + """List all conversation workspaces with metadata.""" workspaces = [] for path in self.workspace_dir.glob("workspace-*"): if path.is_dir(): uuid = path.name.replace("workspace-", "") size = self.get_workspace_size(uuid) - workspaces.append({ - "uuid": uuid, - "path": str(path), - "size_bytes": size, - "created": datetime.fromtimestamp(path.stat().st_ctime, UTC).isoformat(), - }) + workspaces.append( + { + "uuid": uuid, + "path": str(path), + "size_bytes": size, + "created": datetime.fromtimestamp(path.stat().st_ctime, UTC).isoformat(), + } + ) return sorted(workspaces, key=lambda x: x["created"], reverse=True) def cleanup_archives(self, max_age_days: int = 30, max_count: int = 100) -> dict: - """Clean up old workspace archives. - - Args: - max_age_days: Delete archives older than this many days - max_count: Keep at most this many archives, delete oldest first - - Returns: - Dict with deleted count and freed bytes - """ + """Clean up old workspace archives.""" deleted = 0 freed_bytes = 0 - # Get all archives sorted by modification time (oldest first) archives = [] for path in self.archive_dir.glob("workspace-*.tar.gz"): if path.is_file(): @@ -118,7 +231,6 @@ def cleanup_archives(self, max_age_days: int = 30, max_count: int = 100) -> dict archives.sort(key=lambda x: x["mtime"]) - # Delete old archives now = datetime.now(UTC) for archive in archives: age_days = (now - archive["mtime"]).days @@ -127,7 +239,6 @@ def cleanup_archives(self, max_age_days: int = 30, max_count: int = 100) -> dict archive["path"].unlink() deleted += 1 - # Delete excess archives (oldest first) remaining = [a for a in archives if a["path"].exists()] while len(remaining) > max_count: oldest = remaining.pop(0) @@ -138,15 +249,7 @@ def cleanup_archives(self, max_age_days: int = 30, max_count: int = 100) -> dict return {"deleted": deleted, "freed_bytes": freed_bytes} def cleanup_workspaces(self, conversation_uuids: list[str], archive: bool = True) -> dict: - """Clean up workspaces for deleted conversations. - - Args: - conversation_uuids: List of conversation UUIDs to keep - archive: Whether to archive before deleting - - Returns: - Dict with deleted count and freed bytes - """ + """Clean up workspaces for deleted conversations.""" deleted = 0 freed_bytes = 0 keep_set = set(conversation_uuids) diff --git a/backend/openmlr/db/engine.py b/backend/openmlr/db/engine.py index 8c3e995..0afc5fd 100644 --- a/backend/openmlr/db/engine.py +++ b/backend/openmlr/db/engine.py @@ -55,3 +55,8 @@ async def get_db() -> AsyncSession: yield session finally: await session.close() + + +def get_async_session(): + """Get an async session as a context manager (for non-dependency use like WebSockets).""" + return async_session() diff --git a/backend/openmlr/db/migrations/versions/004_add_projects.py b/backend/openmlr/db/migrations/versions/004_add_projects.py new file mode 100644 index 0000000..95f69a8 --- /dev/null +++ b/backend/openmlr/db/migrations/versions/004_add_projects.py @@ -0,0 +1,48 @@ +"""Add projects table and project_id to conversations + +Revision ID: 004_add_projects +Revises: 003_migrate_sandbox_to_compute +Create Date: 2026-04-27 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '004_add_projects' +down_revision: Union[str, None] = '003_migrate_sandbox_to_compute' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create projects table + op.create_table( + 'projects', + sa.Column('id', sa.Integer(), primary_key=True), + sa.Column('uuid', sa.String(36), unique=True, nullable=False), + sa.Column('user_id', sa.Integer(), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('slug', sa.String(255), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('workspace_path', sa.String(1000), nullable=True), + sa.Column('status', sa.String(20), server_default='active', nullable=False), + sa.Column('settings', sa.JSON(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + ) + op.create_index('ix_projects_user_id', 'projects', ['user_id']) + op.create_unique_constraint('uq_projects_user_slug', 'projects', ['user_id', 'slug']) + + # Add project_id column to conversations + op.add_column( + 'conversations', + sa.Column('project_id', sa.Integer(), sa.ForeignKey('projects.id', ondelete='SET NULL'), nullable=True), + ) + op.create_index('ix_conversations_project_id', 'conversations', ['project_id']) + + +def downgrade() -> None: + op.drop_index('ix_conversations_project_id', table_name='conversations') + op.drop_column('conversations', 'project_id') + op.drop_table('projects') diff --git a/backend/openmlr/db/models.py b/backend/openmlr/db/models.py index ef40f74..59900e7 100644 --- a/backend/openmlr/db/models.py +++ b/backend/openmlr/db/models.py @@ -12,6 +12,7 @@ Integer, String, Text, + UniqueConstraint, ) from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.orm import DeclarativeBase, relationship @@ -37,12 +38,21 @@ class User(Base): updated_at = Column(DateTime(timezone=True), default=_utcnow, onupdate=_utcnow, nullable=False) settings = relationship("UserSetting", back_populates="user", cascade="all, delete-orphan") - conversations = relationship("Conversation", back_populates="user", cascade="all, delete-orphan") - sandbox_configs = relationship("SandboxConfig", back_populates="user", cascade="all, delete-orphan") + conversations = relationship( + "Conversation", back_populates="user", cascade="all, delete-orphan" + ) + projects = relationship("Project", back_populates="user", cascade="all, delete-orphan") + sandbox_configs = relationship( + "SandboxConfig", back_populates="user", cascade="all, delete-orphan" + ) ssh_keys = relationship("SSHKey", back_populates="user", cascade="all, delete-orphan") compute_nodes = relationship("ComputeNode", back_populates="user", cascade="all, delete-orphan") - research_corpus = relationship("ResearchCorpus", back_populates="user", cascade="all, delete-orphan") - writing_projects = relationship("WritingProject", back_populates="user", cascade="all, delete-orphan") + research_corpus = relationship( + "ResearchCorpus", back_populates="user", cascade="all, delete-orphan" + ) + writing_projects = relationship( + "WritingProject", back_populates="user", cascade="all, delete-orphan" + ) class UserSetting(Base): @@ -63,26 +73,57 @@ class UserSetting(Base): ) +class Project(Base): + """A project groups multiple conversations around a persistent workspace.""" + + __tablename__ = "projects" + + id = Column(Integer, primary_key=True) + uuid = Column(String(36), unique=True, nullable=False, default=lambda: str(uuid.uuid4())) + user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + name = Column(String(255), nullable=False) + slug = Column(String(255), nullable=False) + description = Column(Text, nullable=True) + workspace_path = Column(String(1000), nullable=True) # absolute path to workspace dir + status = Column(String(20), default="active", nullable=False) # active, archived + settings = Column("settings", JSON, nullable=True) # project-level overrides + created_at = Column(DateTime(timezone=True), default=_utcnow, nullable=False) + updated_at = Column(DateTime(timezone=True), default=_utcnow, onupdate=_utcnow, nullable=False) + + user = relationship("User", back_populates="projects") + conversations = relationship("Conversation", back_populates="project") + + __table_args__ = (UniqueConstraint("user_id", "slug", name="uq_projects_user_slug"),) + + class Conversation(Base): __tablename__ = "conversations" id = Column(Integer, primary_key=True) uuid = Column(String(36), unique=True, nullable=False, default=lambda: str(uuid.uuid4())) user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + project_id = Column(Integer, ForeignKey("projects.id", ondelete="SET NULL"), nullable=True) title = Column(String(255), default="New conversation", nullable=False) model = Column(String(100), nullable=True) - mode = Column(String(20), default="general", nullable=False) # research, writing, coding, general + mode = Column( + String(20), default="general", nullable=False + ) # research, writing, coding, general user_message_count = Column(Integer, default=0, nullable=False) extra = Column("extra", JSON, nullable=True) created_at = Column(DateTime(timezone=True), default=_utcnow, nullable=False) updated_at = Column(DateTime(timezone=True), default=_utcnow, onupdate=_utcnow, nullable=False) user = relationship("User", back_populates="conversations") + project = relationship("Project", back_populates="conversations") messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan") corpus = relationship("ResearchCorpus", back_populates="conversation") writing_project = relationship("WritingProject", back_populates="conversation") - tasks = relationship("ConversationTask", back_populates="conversation", cascade="all, delete-orphan") - resources = relationship("ConversationResource", back_populates="conversation", cascade="all, delete-orphan") + tasks = relationship( + "ConversationTask", back_populates="conversation", cascade="all, delete-orphan" + ) + resources = relationship( + "ConversationResource", back_populates="conversation", cascade="all, delete-orphan" + ) jobs = relationship("AgentJob", back_populates="conversation", cascade="all, delete-orphan") @@ -90,7 +131,9 @@ class Message(Base): __tablename__ = "messages" id = Column(Integer, primary_key=True) - conversation_id = Column(Integer, ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False) + conversation_id = Column( + Integer, ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False + ) role = Column(String(20), nullable=False) # system, user, assistant, tool content = Column(Text, nullable=False) meta = Column("meta", JSON, nullable=True) @@ -160,7 +203,9 @@ class ResearchCorpus(Base): __tablename__ = "research_corpus" id = Column(Integer, primary_key=True) - conversation_id = Column(Integer, ForeignKey("conversations.id", ondelete="SET NULL"), nullable=True) + conversation_id = Column( + Integer, ForeignKey("conversations.id", ondelete="SET NULL"), nullable=True + ) user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) paper_id = Column(String(100), nullable=True) # arxiv ID or DOI title = Column(String(500), nullable=False) @@ -181,7 +226,9 @@ class WritingProject(Base): id = Column(Integer, primary_key=True) user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) - conversation_id = Column(Integer, ForeignKey("conversations.id", ondelete="SET NULL"), nullable=True) + conversation_id = Column( + Integer, ForeignKey("conversations.id", ondelete="SET NULL"), nullable=True + ) title = Column(String(500), nullable=False) outline = Column(JSON, nullable=True) # section structure sections = Column(JSON, default=dict, nullable=False) # section_id -> markdown content @@ -197,12 +244,17 @@ class WritingProject(Base): class ConversationTask(Base): """Persisted tasks (todo items) for a conversation.""" + __tablename__ = "conversation_tasks" id = Column(Integer, primary_key=True) - conversation_id = Column(Integer, ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False) + conversation_id = Column( + Integer, ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False + ) title = Column(String(500), nullable=False) - status = Column(String(20), default="pending", nullable=False) # pending, in_progress, completed, cancelled + status = Column( + String(20), default="pending", nullable=False + ) # pending, in_progress, completed, cancelled priority = Column(String(20), default="medium", nullable=True) # high, medium, low order_index = Column(Integer, default=0, nullable=False) created_at = Column(DateTime(timezone=True), default=_utcnow, nullable=False) @@ -213,11 +265,16 @@ class ConversationTask(Base): class ConversationResource(Base): """Persisted resources (papers, code, datasets, reports) for a conversation.""" + __tablename__ = "conversation_resources" id = Column(Integer, primary_key=True) - conversation_id = Column(Integer, ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False) - resource_id = Column(String(100), unique=True, nullable=False, default=lambda: str(uuid.uuid4())[:8]) + conversation_id = Column( + Integer, ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False + ) + resource_id = Column( + String(100), unique=True, nullable=False, default=lambda: str(uuid.uuid4())[:8] + ) title = Column(String(500), nullable=False) url = Column(String(2000), nullable=True) type = Column(String(20), default="doc", nullable=False) # paper, code, dataset, doc, report @@ -229,13 +286,18 @@ class ConversationResource(Base): class AgentJob(Base): """Background job tracking for agent execution.""" + __tablename__ = "agent_jobs" id = Column(Integer, primary_key=True) job_id = Column(String(100), unique=True, nullable=False, default=lambda: str(uuid.uuid4())) - conversation_id = Column(Integer, ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False) + conversation_id = Column( + Integer, ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False + ) user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) - status = Column(String(20), default="queued", nullable=False) # queued, running, completed, failed, cancelled + status = Column( + String(20), default="queued", nullable=False + ) # queued, running, completed, failed, cancelled message = Column(Text, nullable=True) # The user message that triggered this job mode = Column(String(20), nullable=True) # research, writing, coding, general error = Column(Text, nullable=True) # Error message if failed diff --git a/backend/openmlr/db/operations.py b/backend/openmlr/db/operations.py index 7bafae0..c5dbaa9 100644 --- a/backend/openmlr/db/operations.py +++ b/backend/openmlr/db/operations.py @@ -12,24 +12,158 @@ ConversationResource, ConversationTask, Message, + Project, SSHKey, UserSetting, ) +# ---- Projects ---- + + +async def create_project( + db: AsyncSession, + user_id: int, + name: str, + slug: str, + description: str | None = None, + workspace_path: str | None = None, + settings: dict | None = None, +) -> Project: + project = Project( + user_id=user_id, + name=name, + slug=slug, + description=description, + workspace_path=workspace_path, + settings=settings, + ) + db.add(project) + await db.commit() + await db.refresh(project) + return project + + +async def get_user_projects( + db: AsyncSession, + user_id: int, + include_archived: bool = False, +) -> list[Project]: + query = select(Project).where(Project.user_id == user_id) + if not include_archived: + query = query.where(Project.status == "active") + query = query.order_by(Project.updated_at.desc()) + result = await db.execute(query) + return list(result.scalars().all()) + + +async def get_project_by_id( + db: AsyncSession, project_id: int, user_id: int | None = None +) -> Project | None: + query = select(Project).where(Project.id == project_id) + if user_id is not None: + query = query.where(Project.user_id == user_id) + result = await db.execute(query) + return result.scalar_one_or_none() + + +async def get_project_by_uuid( + db: AsyncSession, uuid: str, user_id: int | None = None +) -> Project | None: + query = select(Project).where(Project.uuid == uuid) + if user_id is not None: + query = query.where(Project.user_id == user_id) + result = await db.execute(query) + return result.scalar_one_or_none() + + +async def get_project_by_slug(db: AsyncSession, user_id: int, slug: str) -> Project | None: + result = await db.execute( + select(Project).where(Project.user_id == user_id, Project.slug == slug) + ) + return result.scalar_one_or_none() + + +# Explicit allowlist of fields that can be updated via update_project. +# Prevents injection of workspace_path, user_id, id, uuid, etc. +_PROJECT_UPDATABLE_FIELDS = {"name", "slug", "description", "settings", "status"} + + +async def update_project( + db: AsyncSession, + project_id: int, + user_id: int, + **kwargs, +) -> Project | None: + result = await db.execute( + select(Project).where(Project.id == project_id, Project.user_id == user_id) + ) + project = result.scalar_one_or_none() + if not project: + return None + for key, value in kwargs.items(): + if key in _PROJECT_UPDATABLE_FIELDS: + setattr(project, key, value) + await db.commit() + await db.refresh(project) + return project + + +async def archive_project(db: AsyncSession, project_id: int, user_id: int) -> Project | None: + return await update_project(db, project_id, user_id, status="archived") + + +async def get_project_conversations(db: AsyncSession, project_id: int) -> list[Conversation]: + result = await db.execute( + select(Conversation) + .where(Conversation.project_id == project_id) + .order_by(Conversation.updated_at.desc()) + ) + return list(result.scalars().all()) + + +async def attach_conversation_to_project( + db: AsyncSession, + conversation_id: int, + project_id: int | None, + user_id: int | None = None, +) -> bool: + """Attach or detach a conversation from a project. + + When user_id is provided, verifies ownership of both conversation and project. + """ + conv = await get_conversation_by_id(db, conversation_id) + if not conv: + return False + # Verify conversation ownership when user_id is provided + if user_id is not None and conv.user_id != user_id: + return False + # Verify project ownership when attaching (not detaching) + if project_id is not None and user_id is not None: + project = await get_project_by_id(db, project_id, user_id) + if not project: + return False + conv.project_id = project_id + await db.commit() + return True + + # ---- Conversations ---- + async def create_conversation( db: AsyncSession, user_id: int, title: str = "New conversation", model: str | None = None, mode: str = "general", + project_id: int | None = None, ) -> Conversation: conv = Conversation( user_id=user_id, title=title, model=model, mode=mode, + project_id=project_id, ) db.add(conv) await db.commit() @@ -66,23 +200,17 @@ async def delete_conversation(db: AsyncSession, conv_id: int) -> bool: async def update_conversation_title(db: AsyncSession, conv_id: int, title: str): - await db.execute( - update(Conversation).where(Conversation.id == conv_id).values(title=title) - ) + await db.execute(update(Conversation).where(Conversation.id == conv_id).values(title=title)) await db.commit() async def update_conversation_model(db: AsyncSession, conv_id: int, model: str): - await db.execute( - update(Conversation).where(Conversation.id == conv_id).values(model=model) - ) + await db.execute(update(Conversation).where(Conversation.id == conv_id).values(model=model)) await db.commit() async def update_conversation_extra(db: AsyncSession, conv_id: int, extra: dict): - await db.execute( - update(Conversation).where(Conversation.id == conv_id).values(extra=extra) - ) + await db.execute(update(Conversation).where(Conversation.id == conv_id).values(extra=extra)) await db.commit() @@ -97,11 +225,10 @@ async def increment_user_message_count(db: AsyncSession, conv_id: int): # ---- Messages ---- + async def get_messages(db: AsyncSession, conv_id: int) -> list[Message]: result = await db.execute( - select(Message) - .where(Message.conversation_id == conv_id) - .order_by(Message.created_at.asc()) + select(Message).where(Message.conversation_id == conv_id).order_by(Message.created_at.asc()) ) return list(result.scalars().all()) @@ -126,21 +253,18 @@ async def add_message( async def clear_messages(db: AsyncSession, conv_id: int): - await db.execute( - delete(Message).where(Message.conversation_id == conv_id) - ) + await db.execute(delete(Message).where(Message.conversation_id == conv_id)) await db.commit() # ---- Settings ---- -async def get_user_setting( - db: AsyncSession, user_id: int, category: str, key: str -) -> dict | None: + +async def get_user_setting(db: AsyncSession, user_id: int, category: str, key: str) -> dict | None: from .models import UserSetting + result = await db.execute( - select(UserSetting) - .where( + select(UserSetting).where( UserSetting.user_id == user_id, UserSetting.category == category, UserSetting.key == key, @@ -151,12 +275,16 @@ async def get_user_setting( async def set_user_setting( - db: AsyncSession, user_id: int, category: str, key: str, value: dict | list | str | int | float | bool + db: AsyncSession, + user_id: int, + category: str, + key: str, + value: dict | list | str | int | float | bool, ): from .models import UserSetting + result = await db.execute( - select(UserSetting) - .where( + select(UserSetting).where( UserSetting.user_id == user_id, UserSetting.category == category, UserSetting.key == key, @@ -185,6 +313,7 @@ def _clean_json_value(val: object) -> object: async def get_all_settings(db: AsyncSession, user_id: int, category: str | None = None) -> dict: from .models import UserSetting + query = select(UserSetting).where(UserSetting.user_id == user_id) if category: query = query.where(UserSetting.category == category) @@ -201,6 +330,7 @@ async def get_all_settings(db: AsyncSession, user_id: int, category: str | None async def delete_user_setting(db: AsyncSession, user_id: int, category: str, key: str): from .models import UserSetting + await db.execute( delete(UserSetting).where( UserSetting.user_id == user_id, @@ -213,6 +343,7 @@ async def delete_user_setting(db: AsyncSession, user_id: int, category: str, key # ---- Conversation Tasks ---- + async def get_conversation_tasks(db: AsyncSession, conv_id: int) -> list[ConversationTask]: result = await db.execute( select(ConversationTask) @@ -229,9 +360,7 @@ async def upsert_conversation_tasks( ) -> list[ConversationTask]: """Replace all tasks for a conversation with the new list.""" # Delete existing tasks - await db.execute( - delete(ConversationTask).where(ConversationTask.conversation_id == conv_id) - ) + await db.execute(delete(ConversationTask).where(ConversationTask.conversation_id == conv_id)) # Insert new tasks new_tasks = [] @@ -267,6 +396,7 @@ async def update_task_status( # ---- Conversation Resources ---- + async def get_conversation_resources(db: AsyncSession, conv_id: int) -> list[ConversationResource]: result = await db.execute( select(ConversationResource) @@ -286,6 +416,7 @@ async def add_conversation_resource( resource_id: str | None = None, ) -> ConversationResource: import uuid as uuid_mod + resource = ConversationResource( conversation_id=conv_id, resource_id=resource_id or str(uuid_mod.uuid4())[:8], @@ -341,7 +472,9 @@ async def upsert_conversation_resources( PLAN_RESOURCE_ID = "plan-md" -async def upsert_plan_resource(db: AsyncSession, conv_id: int, content: str) -> ConversationResource: +async def upsert_plan_resource( + db: AsyncSession, conv_id: int, content: str +) -> ConversationResource: """Create or update the pinned PLAN.md resource for a conversation.""" existing = await get_resource_by_id(db, f"{PLAN_RESOURCE_ID}-{conv_id}") if existing: @@ -350,7 +483,8 @@ async def upsert_plan_resource(db: AsyncSession, conv_id: int, content: str) -> await db.refresh(existing) return existing return await add_conversation_resource( - db, conv_id, + db, + conv_id, title="PLAN.md", resource_type="plan", content=content, @@ -362,7 +496,10 @@ async def upsert_plan_resource(db: AsyncSession, conv_id: int, content: str) -> async def upsert_paper_resource( - db: AsyncSession, conv_id: int, title: str, content: str, + db: AsyncSession, + conv_id: int, + title: str, + content: str, ) -> ConversationResource: """Create or update the paper draft resource for a conversation.""" rid = f"{PAPER_RESOURCE_ID}-{conv_id}" @@ -374,7 +511,8 @@ async def upsert_paper_resource( await db.refresh(existing) return existing return await add_conversation_resource( - db, conv_id, + db, + conv_id, title=title, resource_type="paper", content=content, @@ -383,9 +521,13 @@ async def upsert_paper_resource( async def upsert_resource( - db: AsyncSession, conv_id: int, - resource_id: str, title: str, resource_type: str, - content: str | None = None, url: str | None = None, + db: AsyncSession, + conv_id: int, + resource_id: str, + title: str, + resource_type: str, + content: str | None = None, + url: str | None = None, ) -> ConversationResource: """Create or update a resource by resource_id.""" existing = await get_resource_by_id(db, resource_id) @@ -398,7 +540,8 @@ async def upsert_resource( await db.refresh(existing) return existing return await add_conversation_resource( - db, conv_id, + db, + conv_id, title=title, resource_type=resource_type, content=content, @@ -409,6 +552,7 @@ async def upsert_resource( # ---- Agent Jobs ---- + async def create_agent_job( db: AsyncSession, conv_id: int, @@ -417,6 +561,7 @@ async def create_agent_job( mode: str | None = None, ) -> AgentJob: import uuid as uuid_mod + job = AgentJob( job_id=str(uuid_mod.uuid4()), conversation_id=conv_id, @@ -432,9 +577,7 @@ async def create_agent_job( async def get_agent_job(db: AsyncSession, job_id: str) -> AgentJob | None: - result = await db.execute( - select(AgentJob).where(AgentJob.job_id == job_id) - ) + result = await db.execute(select(AgentJob).where(AgentJob.job_id == job_id)) return result.scalar_one_or_none() @@ -478,6 +621,7 @@ async def update_job_status( # ---- User Settings ---- + async def get_user_settings(db: AsyncSession, user_id: int, category: str | None = None) -> dict: """Get user settings as a dict. Optionally filter by category.""" query = select(UserSetting).where(UserSetting.user_id == user_id) @@ -495,10 +639,7 @@ async def get_user_settings(db: AsyncSession, user_id: int, category: str | None async def get_user_agent_settings(db: AsyncSession, user_id: int) -> dict: """Get user's agent settings (default_model, research_model, yolo_mode).""" result = await db.execute( - select(UserSetting).where( - UserSetting.user_id == user_id, - UserSetting.category == "agent" - ) + select(UserSetting).where(UserSetting.user_id == user_id, UserSetting.category == "agent") ) settings = {} for s in result.scalars().all(): @@ -508,9 +649,15 @@ async def get_user_agent_settings(db: AsyncSession, user_id: int) -> dict: # ---- SSH Keys ---- + async def create_ssh_key( - db: AsyncSession, user_id: int, filename: str, fingerprint: str, - algorithm: str, public_key: str, comment: str | None = None, + db: AsyncSession, + user_id: int, + filename: str, + fingerprint: str, + algorithm: str, + public_key: str, + comment: str | None = None, ) -> SSHKey: key = SSHKey( user_id=user_id, @@ -554,9 +701,15 @@ async def delete_ssh_key(db: AsyncSession, user_id: int, filename: str) -> bool: # ---- Compute Nodes ---- + async def create_compute_node( - db: AsyncSession, user_id: int, name: str, node_type: str, config: dict, - is_default: bool = False, priority: int = 0, + db: AsyncSession, + user_id: int, + name: str, + node_type: str, + config: dict, + is_default: bool = False, + priority: int = 0, ) -> ComputeNode: node = ComputeNode( user_id=user_id, @@ -574,12 +727,16 @@ async def create_compute_node( async def get_compute_nodes(db: AsyncSession, user_id: int) -> list[ComputeNode]: result = await db.execute( - select(ComputeNode).where(ComputeNode.user_id == user_id).order_by(ComputeNode.priority.desc(), ComputeNode.created_at.desc()) + select(ComputeNode) + .where(ComputeNode.user_id == user_id) + .order_by(ComputeNode.priority.desc(), ComputeNode.created_at.desc()) ) return list(result.scalars().all()) -async def get_compute_node_by_id(db: AsyncSession, node_id: int, user_id: int | None = None) -> ComputeNode | None: +async def get_compute_node_by_id( + db: AsyncSession, node_id: int, user_id: int | None = None +) -> ComputeNode | None: query = select(ComputeNode).where(ComputeNode.id == node_id) if user_id is not None: query = query.where(ComputeNode.user_id == user_id) @@ -595,7 +752,10 @@ async def get_compute_node_by_name(db: AsyncSession, user_id: int, name: str) -> async def update_compute_node( - db: AsyncSession, node_id: int, user_id: int, **kwargs, + db: AsyncSession, + node_id: int, + user_id: int, + **kwargs, ) -> ComputeNode | None: result = await db.execute( select(ComputeNode).where(ComputeNode.id == node_id, ComputeNode.user_id == user_id) diff --git a/backend/openmlr/routes/projects.py b/backend/openmlr/routes/projects.py new file mode 100644 index 0000000..ae37f64 --- /dev/null +++ b/backend/openmlr/routes/projects.py @@ -0,0 +1,545 @@ +"""Project routes — CRUD, file tree, file operations. + +Security: +- Path traversal prevention via relative_to() (not str.startswith) +- Upload/write size limits +- Symlink-aware rmtree +- No server-side paths leaked to API responses +- Workspace root fallback uses restrictive permissions +""" + +import json +import logging +import mimetypes +import os +import re +import shutil +import uuid as uuid_mod +from pathlib import Path + +from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile +from fastapi.responses import FileResponse +from sqlalchemy.ext.asyncio import AsyncSession + +from ..db import operations as ops +from ..db.engine import get_db +from ..db.models import User +from ..dependencies import get_current_user + +router = APIRouter(prefix="/api/projects", tags=["projects"]) + +log = logging.getLogger(__name__) + +# Size limits +MAX_UPLOAD_BYTES = 100 * 1024 * 1024 # 100 MB +MAX_WRITE_BYTES = 10 * 1024 * 1024 # 10 MB + + +def _get_workspaces_root() -> Path: + """Get the workspace root directory, falling back to a temp dir if needed.""" + configured = os.environ.get("OPENMLR_WORKSPACES_PATH") + if configured: + return Path(configured) + default = Path("/app/.workspaces") + if default.parent.exists(): + return default + # Fallback for non-Docker environments (tests, native dev) + import tempfile + + fallback = Path(tempfile.gettempdir()) / "openmlr-workspaces" + fallback.mkdir(parents=True, exist_ok=True, mode=0o700) + return fallback + + +WORKSPACES_ROOT = _get_workspaces_root() + + +def _slugify(name: str) -> str: + """Generate a filesystem-safe slug from a project name.""" + slug = name.lower().strip() + slug = re.sub(r"[^\w\s-]", "", slug) + slug = re.sub(r"[\s_]+", "-", slug) + slug = re.sub(r"-+", "-", slug).strip("-") + return slug[:60] or "project" + + +def _project_dict(project, conv_count: int | None = None) -> dict: + """Serialize a project for the API. No server-side paths exposed.""" + d = { + "id": project.id, + "uuid": project.uuid, + "name": project.name, + "slug": project.slug, + "description": project.description, + "status": project.status, + "settings": project.settings or {}, + "created_at": project.created_at.isoformat() if project.created_at else None, + "updated_at": project.updated_at.isoformat() if project.updated_at else None, + } + if conv_count is not None: + d["conversation_count"] = conv_count + return d + + +def _ensure_workspace(workspace_path: str) -> Path: + """Ensure workspace directory and standard subdirs exist.""" + ws = Path(workspace_path) + ws.mkdir(parents=True, exist_ok=True) + for subdir in [ + "code", + "data", + "models", + "outputs", + "papers", + "research", + "research/searches", + "research/notes", + "research/citations", + "logs", + "logs/tool_failures", + "logs/compute", + "logs/experiments", + "venvs", + ".project-meta", + ".project-meta/plans", + ]: + (ws / subdir).mkdir(parents=True, exist_ok=True) + return ws + + +def _safe_resolve(workspace_path: str, relative_path: str) -> Path: + """Resolve a relative path within the workspace, preventing traversal attacks. + + Uses Path.relative_to() for correct containment checking (not str.startswith). + """ + ws = Path(workspace_path).resolve() + target = (ws / relative_path).resolve() + try: + target.relative_to(ws) + except ValueError: + raise HTTPException(status_code=400, detail="Path traversal not allowed") + return target + + +def _safe_rmtree(target: Path, workspace_path: str) -> None: + """Remove a directory tree, refusing to follow symlinks that escape the workspace.""" + ws = Path(workspace_path).resolve() + + # Check for symlinks that point outside workspace before deleting + for root, dirs, files in os.walk(str(target)): + root_path = Path(root) + for name in dirs + files: + item = root_path / name + if item.is_symlink(): + link_target = item.resolve() + try: + link_target.relative_to(ws) + except ValueError: + raise HTTPException( + status_code=400, + detail="Cannot delete: contains symlink to outside workspace", + ) + + shutil.rmtree(target) + + +# ── Project CRUD ───────────────────────────────────────── + + +@router.get("") +async def list_projects( + include_archived: bool = False, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """List all projects for the current user.""" + projects = await ops.get_user_projects(db, user.id, include_archived=include_archived) + result = [] + for p in projects: + convs = await ops.get_project_conversations(db, p.id) + result.append(_project_dict(p, conv_count=len(convs))) + return {"projects": result} + + +@router.post("") +async def create_project( + request: Request, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Create a new project with a workspace directory.""" + body = await request.json() + name = body.get("name", "").strip() + description = body.get("description", "").strip() or None + + if not name: + raise HTTPException(status_code=400, detail="Missing 'name'") + + slug = _slugify(name) + + # Check for duplicate slug + existing = await ops.get_project_by_slug(db, user.id, slug) + if existing: + slug = f"{slug}-{str(uuid_mod.uuid4())[:6]}" + + # Create workspace directory + workspace_path = str(WORKSPACES_ROOT / slug) + _ensure_workspace(workspace_path) + + # Write initial project metadata + meta_path = Path(workspace_path) / ".project-meta" / "project.json" + meta_path.write_text( + json.dumps( + { + "name": name, + "slug": slug, + "description": description, + "created_by": user.username, + }, + indent=2, + ) + ) + + # Initialize empty knowledge graph + kg_path = Path(workspace_path) / ".project-meta" / "knowledge.json" + kg_path.write_text( + json.dumps( + { + "nodes": [], + "edges": [], + "version": 1, + }, + indent=2, + ) + ) + + project = await ops.create_project( + db, + user.id, + name, + slug, + description=description, + workspace_path=workspace_path, + settings=body.get("settings"), + ) + + return {"project": _project_dict(project, conv_count=0)} + + +@router.get("/{project_uuid}") +async def get_project( + project_uuid: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Get project details including conversation count.""" + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + convs = await ops.get_project_conversations(db, project.id) + return {"project": _project_dict(project, conv_count=len(convs))} + + +@router.put("/{project_uuid}") +async def update_project( + project_uuid: str, + request: Request, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Update project name, description, or settings.""" + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + body = await request.json() + updates = {} + if "name" in body: + updates["name"] = body["name"].strip() + if "description" in body: + updates["description"] = body["description"].strip() or None + if "settings" in body: + updates["settings"] = body["settings"] + + updated = await ops.update_project(db, project.id, user.id, **updates) + return {"project": _project_dict(updated)} + + +@router.delete("/{project_uuid}") +async def delete_project( + project_uuid: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Archive a project (soft delete). Workspace files are preserved.""" + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + await ops.archive_project(db, project.id, user.id) + return {"ok": True} + + +@router.get("/{project_uuid}/conversations") +async def list_project_conversations( + project_uuid: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """List all conversations within a project.""" + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + convs = await ops.get_project_conversations(db, project.id) + return { + "conversations": [ + { + "id": c.id, + "uuid": c.uuid, + "title": c.title, + "model": c.model, + "mode": c.mode, + "user_message_count": c.user_message_count, + "created_at": c.created_at.isoformat() if c.created_at else None, + "updated_at": c.updated_at.isoformat() if c.updated_at else None, + } + for c in convs + ] + } + + +@router.post("/{project_uuid}/attach/{conversation_uuid}") +async def attach_conversation( + project_uuid: str, + conversation_uuid: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Attach an existing conversation to a project.""" + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + conv = await ops.get_conversation_by_uuid(db, conversation_uuid) + if not conv or conv.user_id != user.id: + raise HTTPException(status_code=404, detail="Conversation not found") + await ops.attach_conversation_to_project(db, conv.id, project.id, user.id) + return {"ok": True} + + +@router.post("/{project_uuid}/detach/{conversation_uuid}") +async def detach_conversation( + project_uuid: str, + conversation_uuid: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Detach a conversation from a project.""" + # Verify both project and conversation ownership + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + conv = await ops.get_conversation_by_uuid(db, conversation_uuid) + if not conv or conv.user_id != user.id: + raise HTTPException(status_code=404, detail="Conversation not found") + await ops.attach_conversation_to_project(db, conv.id, None, user.id) + return {"ok": True} + + +# ── File Tree & File Operations ────────────────────────── + + +@router.get("/{project_uuid}/files") +async def list_files( + project_uuid: str, + path: str = "", + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """List files and directories in the project workspace.""" + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project or not project.workspace_path: + raise HTTPException(status_code=404, detail="Project not found") + + target = _safe_resolve(project.workspace_path, path) + if not target.exists(): + raise HTTPException(status_code=404, detail="Path not found") + if not target.is_dir(): + raise HTTPException(status_code=400, detail="Not a directory") + + entries = [] + try: + for item in sorted( + target.iterdir(), + key=lambda p: (not p.is_dir(), p.name.lower()), + ): + # Skip hidden files except .project-meta + if item.name.startswith(".") and item.name != ".project-meta": + continue + try: + stat = item.stat(follow_symlinks=False) + except OSError: + continue + entries.append( + { + "name": item.name, + "path": str(item.relative_to(Path(project.workspace_path))), + "is_dir": item.is_dir(), + "size": stat.st_size if item.is_file() else None, + "modified": stat.st_mtime, + } + ) + except PermissionError as exc: + raise HTTPException(status_code=403, detail="Permission denied") from exc + + return {"path": path, "entries": entries} + + +@router.get("/{project_uuid}/files/{file_path:path}") +async def read_file( + project_uuid: str, + file_path: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Read a file from the project workspace.""" + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project or not project.workspace_path: + raise HTTPException(status_code=404, detail="Project not found") + + target = _safe_resolve(project.workspace_path, file_path) + if not target.exists(): + raise HTTPException(status_code=404, detail="File not found") + if target.is_dir(): + return await list_files(project_uuid, file_path, user, db) + + # Reject symlinks that point outside workspace + if target.is_symlink(): + try: + target.resolve().relative_to(Path(project.workspace_path).resolve()) + except ValueError: + raise HTTPException(status_code=400, detail="Symlink points outside workspace") + + # For text files, return content as JSON + mime, _ = mimetypes.guess_type(str(target)) + is_text = ( + mime is None + or mime.startswith("text/") + or mime in ("application/json", "application/xml", "application/x-yaml") + ) + + if is_text: + try: + content = target.read_text(encoding="utf-8", errors="replace") + if len(content) > 500_000: + content = content[:500_000] + "\n\n[... truncated at 500KB ...]" + return { + "path": file_path, + "content": content, + "size": target.stat().st_size, + } + except Exception: + pass + + return FileResponse(str(target), filename=target.name) + + +@router.put("/{project_uuid}/files/{file_path:path}") +async def write_file( + project_uuid: str, + file_path: str, + request: Request, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Write content to a file in the project workspace.""" + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project or not project.workspace_path: + raise HTTPException(status_code=404, detail="Project not found") + + target = _safe_resolve(project.workspace_path, file_path) + + body = await request.json() + content = body.get("content", "") + + # Enforce write size limit + if len(content) > MAX_WRITE_BYTES: + raise HTTPException( + status_code=413, + detail=f"Content too large (max {MAX_WRITE_BYTES // 1024 // 1024}MB)", + ) + + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(content, encoding="utf-8") + + return {"ok": True, "path": file_path, "size": target.stat().st_size} + + +@router.delete("/{project_uuid}/files/{file_path:path}") +async def delete_file( + project_uuid: str, + file_path: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Delete a file or directory from the project workspace.""" + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project or not project.workspace_path: + raise HTTPException(status_code=404, detail="Project not found") + + target = _safe_resolve(project.workspace_path, file_path) + if not target.exists(): + raise HTTPException(status_code=404, detail="File not found") + + # Prevent deleting top-level standard dirs + ws = Path(project.workspace_path) + rel = target.relative_to(ws) + protected = { + "code", + "data", + "models", + "outputs", + "papers", + "research", + "logs", + ".project-meta", + } + if str(rel) in protected: + raise HTTPException( + status_code=400, + detail="Cannot delete standard workspace directory", + ) + + if target.is_dir(): + _safe_rmtree(target, project.workspace_path) + else: + target.unlink() + + return {"ok": True} + + +@router.post("/{project_uuid}/upload/{file_path:path}") +async def upload_file( + project_uuid: str, + file_path: str, + file: UploadFile, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Upload a file to the project workspace.""" + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project or not project.workspace_path: + raise HTTPException(status_code=404, detail="Project not found") + + target = _safe_resolve(project.workspace_path, file_path) + + # Read with size limit to prevent OOM + content = await file.read(MAX_UPLOAD_BYTES + 1) + if len(content) > MAX_UPLOAD_BYTES: + raise HTTPException( + status_code=413, + detail=f"File too large (max {MAX_UPLOAD_BYTES // 1024 // 1024}MB)", + ) + + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(content) + + return {"ok": True, "path": file_path, "size": len(content)} diff --git a/backend/openmlr/routes/terminal.py b/backend/openmlr/routes/terminal.py new file mode 100644 index 0000000..dbd2028 --- /dev/null +++ b/backend/openmlr/routes/terminal.py @@ -0,0 +1,275 @@ +"""Terminal WebSocket endpoint — interactive PTY connected to compute resource. + +Provides a real terminal experience via xterm.js on the frontend, +connected to the project workspace's compute environment. + +Security: +- Minimal environment (no server secrets leaked) +- Workspace path validated against WORKSPACES_ROOT +- Shell spawned via subprocess (not os.fork) to avoid async corruption +- --norc --noprofile to prevent .bashrc injection +- Proper zombie process cleanup with SIGKILL escalation +""" + +import asyncio +import fcntl +import json +import logging +import os +import pty +import signal +import struct +import subprocess +import termios +from pathlib import Path + +from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect + +from ..auth.security import decode_access_token +from ..db import operations as ops +from ..db.engine import get_async_session +from ..db.models import User + +router = APIRouter(tags=["terminal"]) + +log = logging.getLogger(__name__) + +# Allowlisted environment variables for the PTY process. +# Server secrets (DATABASE_URL, API keys, JWT_SECRET_KEY, etc.) are NOT passed. +_SAFE_ENV_KEYS = {"LANG", "LC_ALL", "LC_CTYPE", "TZ"} + + +def _build_safe_env(workspace_path: str) -> dict[str, str]: + """Build a minimal, safe environment for the PTY child process.""" + env = { + "TERM": "xterm-256color", + "HOME": workspace_path, + "PWD": workspace_path, + "PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin", + "SHELL": "/bin/bash", + "USER": "openmlr", + } + # Copy only safe locale/timezone vars from parent + for key in _SAFE_ENV_KEYS: + val = os.environ.get(key) + if val: + env[key] = val + return env + + +def _validate_workspace_path(workspace_path: str) -> bool: + """Validate that a workspace path is within the expected root.""" + from .projects import WORKSPACES_ROOT + + try: + resolved = Path(workspace_path).resolve() + resolved.relative_to(WORKSPACES_ROOT.resolve()) + return True + except (ValueError, RuntimeError): + return False + + +async def _authenticate_ws(token: str | None) -> User | None: + """Authenticate WebSocket connection via token query param.""" + if not token: + return None + + payload = decode_access_token(token) + if not payload: + return None + + async with get_async_session() as db: + from sqlalchemy import select + + result = await db.execute( + select(User).where( + User.id == int(payload["sub"]), + User.is_active == True, # noqa: E712 + ) + ) + return result.scalar_one_or_none() + + +async def _cleanup_process(pid: int, master_fd: int) -> None: + """Clean up PTY process with SIGKILL escalation to prevent zombies.""" + # Close the master fd first + try: + os.close(master_fd) + except OSError: + pass + + if pid <= 0: + return + + # Send SIGTERM and wait with timeout + try: + os.kill(pid, signal.SIGTERM) + except ProcessLookupError: + return + + # Poll up to 2 seconds for graceful exit + for _ in range(20): + try: + result, _ = os.waitpid(pid, os.WNOHANG) + if result != 0: + return # Process exited + except ChildProcessError: + return # Already reaped + await asyncio.sleep(0.1) + + # Escalate to SIGKILL + try: + os.kill(pid, signal.SIGKILL) + os.waitpid(pid, 0) # Blocking wait after SIGKILL + except (ProcessLookupError, ChildProcessError): + pass + + +@router.websocket("/api/terminal/{project_uuid}") +async def terminal_websocket( + websocket: WebSocket, + project_uuid: str, + token: str = Query(default=None), +): + """WebSocket endpoint for interactive terminal sessions. + + Spawns a PTY process in the project workspace directory. + Messages from the client are written to the PTY stdin. + Output from the PTY is sent back to the client. + + Special messages (JSON): + - {"type": "resize", "cols": 80, "rows": 24} - resize the terminal + - {"type": "input", "data": "..."} - send input to the PTY + - Plain text messages are treated as input + """ + # Authenticate + user = await _authenticate_ws(token) + if not user: + await websocket.close(code=4001, reason="Unauthorized") + return + + # Look up the project to get the workspace path + async with get_async_session() as db: + project = await ops.get_project_by_uuid(db, project_uuid, user.id) + if not project or not project.workspace_path: + await websocket.close(code=4004, reason="Project not found") + return + workspace_path = project.workspace_path + + # Validate workspace path is within allowed root + if not _validate_workspace_path(workspace_path): + log.warning( + f"Terminal rejected: workspace path {workspace_path} " + f"is outside allowed root (user={user.id})" + ) + await websocket.close(code=4003, reason="Invalid workspace path") + return + + # Verify workspace exists + if not Path(workspace_path).exists(): + await websocket.close(code=4004, reason="Workspace not found") + return + + await websocket.accept() + + # Spawn PTY using subprocess instead of os.fork() to avoid + # corrupting the async event loop and leaking file descriptors. + master_fd, slave_fd = pty.openpty() + env = _build_safe_env(workspace_path) + shell = "/bin/bash" + + try: + proc = subprocess.Popen( + [shell, "--norc", "--noprofile"], + stdin=slave_fd, + stdout=slave_fd, + stderr=slave_fd, + cwd=workspace_path, + env=env, + start_new_session=True, + close_fds=True, + ) + pid = proc.pid + except Exception as e: + log.error(f"Failed to spawn terminal: {e}") + os.close(master_fd) + os.close(slave_fd) + await websocket.close(code=4500, reason="Failed to spawn shell") + return + + # Close slave fd in parent — only the child uses it + os.close(slave_fd) + + # Set master fd to non-blocking + flags = fcntl.fcntl(master_fd, fcntl.F_GETFL) + fcntl.fcntl(master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + async def read_pty(): + """Read from PTY and send to WebSocket.""" + loop = asyncio.get_event_loop() + try: + while True: + try: + data = await loop.run_in_executor(None, lambda: os.read(master_fd, 4096)) + if not data: + break + await websocket.send_bytes(data) + except OSError: + break + except WebSocketDisconnect: + break + except Exception as e: + log.debug(f"PTY read ended: {e}") + + async def write_pty(): + """Read from WebSocket and write to PTY.""" + try: + while True: + msg = await websocket.receive() + if msg.get("type") == "websocket.disconnect": + break + + if "text" in msg: + try: + data = json.loads(msg["text"]) + if isinstance(data, dict): + if data.get("type") == "resize": + cols = min(int(data.get("cols", 80)), 500) + rows = min(int(data.get("rows", 24)), 200) + winsize = struct.pack("HHHH", rows, cols, 0, 0) + fcntl.ioctl(master_fd, termios.TIOCSWINSZ, winsize) + continue + elif data.get("type") == "input": + input_data = data.get("data", "") + if isinstance(input_data, str): + os.write(master_fd, input_data.encode()[:4096]) + continue + except (json.JSONDecodeError, ValueError): + pass + # Plain text input — cap at 4KB per message + os.write(master_fd, msg["text"].encode()[:4096]) + + elif "bytes" in msg: + os.write(master_fd, msg["bytes"][:4096]) + + except WebSocketDisconnect: + pass + except Exception as e: + log.debug(f"PTY write ended: {e}") + + # Run reader and writer concurrently + try: + reader_task = asyncio.create_task(read_pty()) + writer_task = asyncio.create_task(write_pty()) + done, pending = await asyncio.wait( + [reader_task, writer_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + finally: + await _cleanup_process(pid, master_fd) + try: + await websocket.close() + except Exception: + pass diff --git a/backend/openmlr/sandbox/local.py b/backend/openmlr/sandbox/local.py index 4b578d3..2968f15 100644 --- a/backend/openmlr/sandbox/local.py +++ b/backend/openmlr/sandbox/local.py @@ -15,19 +15,30 @@ class LocalSandbox(SandboxInterface): def __init__(self, workdir: str = None, workspace_manager=None): self._workspace_manager = workspace_manager self._conversation_uuid = None + self._project_workspace = None # project workspace path (takes priority) self.workdir = workdir or os.getcwd() async def create(self, config: dict) -> "LocalSandbox": self.workdir = config.get("workdir", os.getcwd()) self._conversation_uuid = config.get("conversation_uuid") + self._project_workspace = config.get("project_workspace_path") - # If workspace manager is available and conversation UUID is set, - # use the per-conversation workspace - if self._workspace_manager and self._conversation_uuid: + # Priority: project workspace > conversation workspace > default workdir + if self._project_workspace: + # Validate the project workspace path is within the allowed root + ws_path = Path(self._project_workspace).resolve() + from ..compute.workspace import WORKSPACES_ROOT + + try: + ws_path.relative_to(WORKSPACES_ROOT.resolve()) + except ValueError: + raise ValueError("Project workspace path is outside allowed root") + ws_path.mkdir(parents=True, exist_ok=True) + self.workdir = str(ws_path) + elif self._workspace_manager and self._conversation_uuid: ws_path = self._workspace_manager.create_workspace(self._conversation_uuid) self.workdir = str(ws_path) elif self._workspace_manager: - # Fallback: create workspace without UUID ws_path = self._workspace_manager.create_workspace("default") self.workdir = str(ws_path) @@ -36,7 +47,9 @@ async def create(self, config: dict) -> "LocalSandbox": async def execute(self, command: str, timeout: int = 120) -> ExecutionResult: return await self.execute_stream(command, timeout) - async def execute_stream(self, command: str, timeout: int = 120, on_chunk=None) -> ExecutionResult: + async def execute_stream( + self, command: str, timeout: int = 120, on_chunk=None + ) -> ExecutionResult: """Execute a command with optional streaming output.""" start = time.monotonic() try: @@ -99,16 +112,27 @@ async def _read_stream(stream, is_stderr): except Exception as e: return ExecutionResult(output=f"Error: {str(e)}", success=False, exit_code=-1) - async def read_file(self, path: str) -> str: + def _resolve_path(self, path: str) -> Path: + """Resolve a path relative to workdir, preventing traversal outside it.""" target = Path(path).expanduser() if not target.is_absolute(): target = Path(self.workdir) / target + resolved = target.resolve() + workdir_resolved = Path(self.workdir).resolve() + # Allow paths within workdir, or absolute paths if no project workspace + if self._project_workspace: + try: + resolved.relative_to(workdir_resolved) + except ValueError: + raise PermissionError("Access denied: path outside workspace") + return resolved + + async def read_file(self, path: str) -> str: + target = self._resolve_path(path) return target.read_text(encoding="utf-8", errors="replace") async def write_file(self, path: str, content: str) -> bool: - target = Path(path).expanduser() - if not target.is_absolute(): - target = Path(self.workdir) / target + target = self._resolve_path(path) target.parent.mkdir(parents=True, exist_ok=True) target.write_text(content, encoding="utf-8") return True @@ -122,19 +146,12 @@ async def edit_file(self, path: str, old: str, new: str) -> bool: return True async def file_exists(self, path: str) -> bool: - target = Path(path).expanduser() - if not target.is_absolute(): - target = Path(self.workdir) / target + target = self._resolve_path(path) return target.exists() async def list_files(self, path: str = ".") -> list[str]: - target = Path(path).expanduser() - if not target.is_absolute(): - target = Path(self.workdir) / target - return sorted([ - f"{e.name}{'/' if e.is_dir() else ''}" - for e in target.iterdir() - ]) + target = self._resolve_path(path) + return sorted([f"{e.name}{'/' if e.is_dir() else ''}" for e in target.iterdir()]) async def probe_environment(self): return await probe_sandbox(self) diff --git a/backend/openmlr/sandbox/manager.py b/backend/openmlr/sandbox/manager.py index 32060a0..7dd893b 100644 --- a/backend/openmlr/sandbox/manager.py +++ b/backend/openmlr/sandbox/manager.py @@ -1,4 +1,10 @@ -"""SandboxManager — lifecycle management and provider selection.""" +"""SandboxManager — lifecycle management and provider selection. + +The sandbox handles code execution on a compute resource. +The workspace (project-scoped) is decoupled: it persists independently +of which compute resource is active. The sandbox receives the workspace +path so it can operate within the project directory. +""" from .interface import SandboxInterface from .local import LocalSandbox @@ -7,13 +13,24 @@ class SandboxManager: - """Manages sandbox lifecycle: create, switch, destroy.""" + """Manages sandbox lifecycle: create, switch, destroy. + + Workspace and compute are decoupled: + - project_workspace_path: persistent project directory (survives compute changes) + - provider/config: determines WHERE code executes (local, ssh, modal) + """ - def __init__(self, workspace_manager=None, conversation_uuid: str = None): + def __init__( + self, + workspace_manager=None, + conversation_uuid: str = None, + project_workspace_path: str = None, + ): self._active: SandboxInterface | None = None self.active_type: str = "none" self._workspace_manager = workspace_manager self._conversation_uuid = conversation_uuid + self._project_workspace_path = project_workspace_path def get_active(self) -> SandboxInterface | None: return self._active @@ -28,6 +45,8 @@ async def create(self, provider: str, config: dict = None) -> SandboxInterface: # Inject workspace and conversation context config["conversation_uuid"] = self._conversation_uuid + if self._project_workspace_path: + config["project_workspace_path"] = self._project_workspace_path if provider == "local": sandbox = LocalSandbox(workspace_manager=self._workspace_manager) diff --git a/backend/openmlr/tools/papers.py b/backend/openmlr/tools/papers.py index ffd97db..37561dc 100644 --- a/backend/openmlr/tools/papers.py +++ b/backend/openmlr/tools/papers.py @@ -20,7 +20,7 @@ OPENALEX_API = "https://api.openalex.org" SEMANTIC_SCHOLAR_API = "https://api.semanticscholar.org/graph/v1" CROSSREF_API = "https://api.crossref.org" -ARXIV_API = "http://export.arxiv.org/api/query" +ARXIV_API = "https://export.arxiv.org/api/query" AR5IV_BASE = "https://ar5iv.labs.arxiv.org/html" PWC_API = "https://paperswithcode.com/api/v1" @@ -72,8 +72,17 @@ def create_papers_tool() -> ToolSpec: "operation": { "type": "string", "enum": [ - "search", "arxiv_search", "semantic_search", "trending", "details", "read_paper", - "citations", "recommend", "find_code", "find_datasets", "author_papers", + "search", + "arxiv_search", + "semantic_search", + "trending", + "details", + "read_paper", + "citations", + "recommend", + "find_code", + "find_datasets", + "author_papers", ], "description": ( "Operation to perform: " @@ -126,22 +135,36 @@ def create_papers_tool() -> ToolSpec: _search_counts: dict[int, int] = {} # session hash -> count _BUDGET_DEFAULT = 25 + def _check_budget(session=None) -> tuple[bool, str]: """Check if search budget allows another API call. Returns (ok, message).""" key = id(session) if session else 0 count = _search_counts.get(key, 0) - budget = session.config.paper_search_budget if session and hasattr(session, 'config') else _BUDGET_DEFAULT + budget = ( + session.config.paper_search_budget + if session and hasattr(session, "config") + else _BUDGET_DEFAULT + ) if count >= budget: - return False, f"Search budget exhausted ({count}/{budget} calls). Ask the user before continuing." + return ( + False, + f"Search budget exhausted ({count}/{budget} calls). Ask the user before continuing.", + ) return True, "" + def _increment_budget(session=None): key = id(session) if session else 0 _search_counts[key] = _search_counts.get(key, 0) + 1 + def _get_budget_info(session=None) -> dict: key = id(session) if session else 0 - budget = session.config.paper_search_budget if session and hasattr(session, 'config') else _BUDGET_DEFAULT + budget = ( + session.config.paper_search_budget + if session and hasattr(session, "config") + else _BUDGET_DEFAULT + ) return {"used": _search_counts.get(key, 0), "max": budget} @@ -158,7 +181,18 @@ async def _handle_papers( **kwargs, ) -> tuple[str, bool]: # Budget check for API-calling operations - api_ops = {"search", "arxiv_search", "semantic_search", "trending", "details", "citations", "recommend", "find_code", "find_datasets", "author_papers"} + api_ops = { + "search", + "arxiv_search", + "semantic_search", + "trending", + "details", + "citations", + "recommend", + "find_code", + "find_datasets", + "author_papers", + } if operation in api_ops: ok, msg = _check_budget(session) if not ok: @@ -167,10 +201,13 @@ async def _handle_papers( # Emit budget update if session: from ..agent.types import AgentEvent - await session.emit(AgentEvent( - event_type="search_budget", - data=_get_budget_info(session), - )) + + await session.emit( + AgentEvent( + event_type="search_budget", + data=_get_budget_info(session), + ) + ) handlers = { "search": lambda: _search(query, year_from, year_to, limit, source), @@ -196,7 +233,10 @@ async def _handle_papers( # ── Search (OpenAlex with S2 fallback) ──────────────────────────────────── -async def _search(query: str, year_from: int = None, year_to: int = None, limit: int = 10, source: str = "auto") -> tuple[str, bool]: + +async def _search( + query: str, year_from: int = None, year_to: int = None, limit: int = 10, source: str = "auto" +) -> tuple[str, bool]: if not query: return "Provide a 'query' for search.", False @@ -219,7 +259,9 @@ async def _search(query: str, year_from: int = None, year_to: int = None, limit: return "Invalid source specified", False -async def _openalex_search(query: str, year_from: int = None, year_to: int = None, limit: int = 10) -> tuple[str, bool]: +async def _openalex_search( + query: str, year_from: int = None, year_to: int = None, limit: int = 10 +) -> tuple[str, bool]: """Search using OpenAlex API with retry logic.""" params = _get_openalex_params({"search": query, "per_page": min(limit, 50)}) @@ -254,7 +296,9 @@ async def _openalex_search(query: str, year_from: int = None, year_to: int = Non total = r.json().get("meta", {}).get("count", len(works)) lines = [f"Found {total} papers for '{query}' (via OpenAlex):\n"] for i, w in enumerate(works, 1): - authors = ", ".join(a.get("author", {}).get("display_name", "") for a in (w.get("authorships") or [])[:3]) + authors = ", ".join( + a.get("author", {}).get("display_name", "") for a in (w.get("authorships") or [])[:3] + ) if len(w.get("authorships", [])) > 3: authors += " et al." doi = (w.get("doi") or "").replace("https://doi.org/", "") @@ -270,7 +314,10 @@ async def _openalex_search(query: str, year_from: int = None, year_to: int = Non # ── arXiv Search ──────────────────────────────────── -async def _arxiv_search(query: str, year_from: int = None, year_to: int = None, limit: int = 10) -> tuple[str, bool]: + +async def _arxiv_search( + query: str, year_from: int = None, year_to: int = None, limit: int = 10 +) -> tuple[str, bool]: """Search arXiv papers directly. Great for ML/CS/Physics preprints.""" if not query: return "Provide a 'query' for search.", False @@ -367,7 +414,10 @@ async def _arxiv_search(query: str, year_from: int = None, year_to: int = None, # ── Semantic Scholar Search ──────────────────────────────────── -async def _semantic_scholar_search(query: str, year_from: int = None, year_to: int = None, limit: int = 10) -> tuple[str, bool]: + +async def _semantic_scholar_search( + query: str, year_from: int = None, year_to: int = None, limit: int = 10 +) -> tuple[str, bool]: """Search using Semantic Scholar API with retry logic.""" if not query: return "Provide a 'query' for search.", False @@ -400,13 +450,19 @@ async def _semantic_scholar_search(query: str, year_from: int = None, year_to: i max_retries=3, ) except RateLimitError: - return "Semantic Scholar rate limit reached. Try again later or add SEMANTIC_SCHOLAR_API_KEY.", False + return ( + "Semantic Scholar rate limit reached. Try again later or add SEMANTIC_SCHOLAR_API_KEY.", + False, + ) except Exception as e: log.warning(f"Semantic Scholar search error: {e}") return f"Semantic Scholar error: {str(e)[:200]}", False if r.status_code == 429: - return "Semantic Scholar rate limit reached. Try again later or add SEMANTIC_SCHOLAR_API_KEY.", False + return ( + "Semantic Scholar rate limit reached. Try again later or add SEMANTIC_SCHOLAR_API_KEY.", + False, + ) if r.status_code != 200: return f"Semantic Scholar error {r.status_code}: {r.text[:300]}", False @@ -441,12 +497,15 @@ async def _semantic_scholar_search(query: str, year_from: int = None, year_to: i # ── Trending (OpenAlex) ────────────────────────────────── + async def _trending(query: str = None, limit: int = 10) -> tuple[str, bool]: - params = _get_openalex_params({ - "sort": "cited_by_count:desc", - "filter": "from_publication_date:2024-01-01", - "per_page": min(limit, 50), - }) + params = _get_openalex_params( + { + "sort": "cited_by_count:desc", + "filter": "from_publication_date:2024-01-01", + "per_page": min(limit, 50), + } + ) if query: params["search"] = query @@ -472,7 +531,9 @@ async def _trending(query: str = None, limit: int = 10) -> tuple[str, bool]: lines = [f"Trending papers{f' on: {query}' if query else ''}:\n"] for i, w in enumerate(works, 1): - authors = ", ".join(a.get("author", {}).get("display_name", "") for a in (w.get("authorships") or [])[:3]) + authors = ", ".join( + a.get("author", {}).get("display_name", "") for a in (w.get("authorships") or [])[:3] + ) lines.append( f"{i}. **{w.get('title', 'Untitled')}** ({w.get('publication_year', '?')})\n" f" {authors} | {w.get('cited_by_count', 0)} citations\n" @@ -482,6 +543,7 @@ async def _trending(query: str = None, limit: int = 10) -> tuple[str, bool]: # ── Details (OpenAlex + CrossRef) ───────────────────────── + async def _details(paper_id: str) -> tuple[str, bool]: if not paper_id: return "Provide a 'paper_id'.", False @@ -511,7 +573,9 @@ async def _details(paper_id: str) -> tuple[str, bool]: return f"Paper not found: {paper_id}", False w = r.json() - authors = ", ".join(a.get("author", {}).get("display_name", "") for a in (w.get("authorships") or [])) + authors = ", ".join( + a.get("author", {}).get("display_name", "") for a in (w.get("authorships") or []) + ) doi = (w.get("doi") or "").replace("https://doi.org/", "") oa_url = (w.get("open_access") or {}).get("oa_url", "") arxiv_id = _extract_arxiv_from_ids(w.get("ids", {})) @@ -555,8 +619,12 @@ async def _crossref_details(doi: str) -> tuple[str, bool]: w = r.json().get("message", {}) title = (w.get("title") or ["Untitled"])[0] - authors = ", ".join(f"{a.get('given', '')} {a.get('family', '')}" for a in (w.get("author") or [])) - year = (w.get("published-print") or w.get("published-online") or {}).get("date-parts", [[None]])[0][0] + authors = ", ".join( + f"{a.get('given', '')} {a.get('family', '')}" for a in (w.get("author") or []) + ) + year = (w.get("published-print") or w.get("published-online") or {}).get( + "date-parts", [[None]] + )[0][0] lines = [ f"# {title}", @@ -571,6 +639,7 @@ async def _crossref_details(doi: str) -> tuple[str, bool]: # ── Read Paper (ArXiv HTML via ar5iv) ───────────────────── + async def _read_paper(paper_id: str, section: str = None) -> tuple[str, bool]: if not paper_id: return "Provide a 'paper_id' (arXiv ID like '2301.12345').", False @@ -594,6 +663,7 @@ async def _read_paper(paper_id: str, section: str = None) -> tuple[str, bool]: return f"Failed to fetch paper HTML (status {r.status_code}).", False from bs4 import BeautifulSoup + soup = BeautifulSoup(r.text, "lxml") sections = _parse_sections(soup) @@ -620,6 +690,7 @@ async def _read_paper(paper_id: str, section: str = None) -> tuple[str, bool]: # ── Citations (OpenAlex) ────────────────────────────────── + async def _citations(paper_id: str, limit: int = 10) -> tuple[str, bool]: if not paper_id: return "Provide a 'paper_id'.", False @@ -645,7 +716,9 @@ async def _citations(paper_id: str, limit: int = 10) -> tuple[str, bool]: w = r.json() ref_ids = w.get("referenced_works", [])[:limit] - lines = [f"## References ({len(w.get('referenced_works', []))} total, showing {len(ref_ids)})\n"] + lines = [ + f"## References ({len(w.get('referenced_works', []))} total, showing {len(ref_ids)})\n" + ] # Batch-fetch referenced works if ref_ids: @@ -671,11 +744,13 @@ async def _citations(paper_id: str, limit: int = 10) -> tuple[str, bool]: try: r3 = await fetch_with_retry( f"{OPENALEX_API}/works", - params=_get_openalex_params({ - "filter": f"cites:{oa_id}", - "sort": "cited_by_count:desc", - "per_page": limit, - }), + params=_get_openalex_params( + { + "filter": f"cites:{oa_id}", + "sort": "cited_by_count:desc", + "per_page": limit, + } + ), timeout=20, max_retries=2, ) @@ -693,6 +768,7 @@ async def _citations(paper_id: str, limit: int = 10) -> tuple[str, bool]: # ── Recommendations (OpenAlex related_works) ────────────── + async def _recommend(paper_id: str, limit: int = 10) -> tuple[str, bool]: if not paper_id: return "Provide a 'paper_id'.", False @@ -736,7 +812,9 @@ async def _recommend(paper_id: str, limit: int = 10) -> tuple[str, bool]: lines = ["## Related Papers\n"] for i, w in enumerate(r2.json().get("results", []), 1): - authors = ", ".join(a.get("author", {}).get("display_name", "") for a in (w.get("authorships") or [])[:3]) + authors = ", ".join( + a.get("author", {}).get("display_name", "") for a in (w.get("authorships") or [])[:3] + ) lines.append( f"{i}. **{w.get('title', 'Untitled')}** ({w.get('publication_year', '?')})\n" f" {authors} | {w.get('cited_by_count', 0)} citations\n" @@ -746,6 +824,7 @@ async def _recommend(paper_id: str, limit: int = 10) -> tuple[str, bool]: # ── Find Code (Papers With Code) ───────────────────────── + async def _find_code(query: str) -> tuple[str, bool]: if not query: return "Provide a query.", False @@ -781,6 +860,7 @@ async def _find_code(query: str) -> tuple[str, bool]: # ── Find Datasets (Papers With Code) ───────────────────── + async def _find_datasets(query: str) -> tuple[str, bool]: if not query: return "Provide a query.", False @@ -815,6 +895,7 @@ async def _find_datasets(query: str) -> tuple[str, bool]: # ── Author Papers (Semantic Scholar) ───────────────────── + async def _author_papers(author_query: str, limit: int = 10) -> tuple[str, bool]: """Find papers by a specific author using Semantic Scholar.""" if not author_query: @@ -833,13 +914,19 @@ async def _author_papers(author_query: str, limit: int = 10) -> tuple[str, bool] max_retries=3, ) except RateLimitError: - return "Semantic Scholar rate limit reached. Try again later or add SEMANTIC_SCHOLAR_API_KEY.", False + return ( + "Semantic Scholar rate limit reached. Try again later or add SEMANTIC_SCHOLAR_API_KEY.", + False, + ) except Exception as e: log.warning(f"Semantic Scholar author search error: {e}") return f"Author search error: {str(e)[:200]}", False if r.status_code == 429: - return "Semantic Scholar rate limit reached. Try again later or add SEMANTIC_SCHOLAR_API_KEY.", False + return ( + "Semantic Scholar rate limit reached. Try again later or add SEMANTIC_SCHOLAR_API_KEY.", + False, + ) if r.status_code != 200: return f"Author search error {r.status_code}: {r.text[:300]}", False @@ -867,7 +954,10 @@ async def _author_papers(author_query: str, limit: int = 10) -> tuple[str, bool] max_retries=3, ) except RateLimitError: - return "Semantic Scholar rate limit reached. Try again later or add SEMANTIC_SCHOLAR_API_KEY.", False + return ( + "Semantic Scholar rate limit reached. Try again later or add SEMANTIC_SCHOLAR_API_KEY.", + False, + ) except Exception as e: log.warning(f"Semantic Scholar author papers error: {e}") return f"Error fetching author papers: {str(e)[:200]}", False @@ -904,6 +994,7 @@ async def _author_papers(author_query: str, limit: int = 10) -> tuple[str, bool] # ── Helpers ─────────────────────────────────────────────── + def _to_openalex_id(paper_id: str) -> str: """Convert various IDs to OpenAlex lookup format.""" if paper_id.startswith("W") or paper_id.startswith("https://openalex.org/"): @@ -917,10 +1008,10 @@ def _to_openalex_id(paper_id: str) -> str: def _extract_arxiv_id(text: str) -> str | None: - match = re.search(r'(\d{4}\.\d{4,5}(?:v\d+)?)', text) + match = re.search(r"(\d{4}\.\d{4,5}(?:v\d+)?)", text) if match: return match.group(1) - match = re.search(r'arxiv\.org/(?:abs|pdf)/(\d{4}\.\d{4,5}(?:v\d+)?)', text) + match = re.search(r"arxiv\.org/(?:abs|pdf)/(\d{4}\.\d{4,5}(?:v\d+)?)", text) if match: return match.group(1) return None @@ -955,11 +1046,13 @@ def _parse_sections(soup) -> list[dict]: abstract = soup.find("div", class_="ltx_abstract") if abstract: - sections.append({ - "title": "Abstract", - "text": abstract.get_text(strip=True).replace("Abstract", "", 1).strip(), - "level": 2, - }) + sections.append( + { + "title": "Abstract", + "text": abstract.get_text(strip=True).replace("Abstract", "", 1).strip(), + "level": 2, + } + ) for heading in soup.find_all(["h2", "h3", "h4"]): level = int(heading.name[1]) diff --git a/backend/openmlr/tools/registry.py b/backend/openmlr/tools/registry.py index 67f9997..1d8bc28 100644 --- a/backend/openmlr/tools/registry.py +++ b/backend/openmlr/tools/registry.py @@ -10,15 +10,28 @@ "plan": { # Plan mode: ask questions, create plans, read context — NO execution tools "allowed": { - "ask_user", "plan_tool", + "ask_user", + "plan_tool", # Read-only tools for gathering context - "read_file", "list_dir", "glob_files", "grep_search", - "web_search", "papers", - "github_search", "github_read_file", "github_read_repo", - "github_find_examples", "github_search_repos", "github_get_readme", + "read_file", + "list_dir", + "glob_files", + "grep_search", + "web_search", + "papers", + "github_search", + "github_read_file", + "github_read_repo", + "github_find_examples", + "github_search_repos", + "github_get_readme", "github_list_repos", # Compute planning (read-only / advisory) - "compute_list", "compute_plan", "compute_probe", + "compute_list", + "compute_plan", + "compute_probe", + # Workspace (knowledge graph, notes, search — always accessible) + "workspace", }, "blocked_message": ( "Tool '{tool}' is not available in PLAN mode. " @@ -87,7 +100,9 @@ def is_tool_allowed(self, name: str) -> tuple[bool, str]: blocked_tools = restrictions.get("blocked", set()) if blocked_tools: if name in blocked_tools: - error_msg = restrictions.get("blocked_message", "Tool '{tool}' not allowed in this mode.") + error_msg = restrictions.get( + "blocked_message", "Tool '{tool}' not allowed in this mode." + ) return False, error_msg.format(tool=name, mode=self._current_mode) return True, "" @@ -116,14 +131,16 @@ def get_tool_specs_for_llm(self, filter_by_mode: bool = True) -> list[dict]: if not allowed: continue - specs.append({ - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters, - }, - }) + specs.append( + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + } + ) return specs def get_raw_specs(self) -> list[ToolSpec]: @@ -175,7 +192,10 @@ async def call_tool( return await tool.handler(**kwargs) if kwargs else await tool.handler(**arguments) except TypeError as e: # Handle argument mismatches (model sending wrong param names) - return f"Tool argument error: {e}. Expected parameters: {list(sig.parameters.keys())}", False + return ( + f"Tool argument error: {e}. Expected parameters: {list(sig.parameters.keys())}", + False, + ) # MCP tool (no handler — dispatch to MCP client) if self._mcp_client: @@ -253,11 +273,18 @@ def create_tool_router(sandbox_manager=None) -> ToolRouter: # Register compute tools from .compute_tools import create_compute_tools + router.register_many(create_compute_tools()) + # Register workspace tools + from .workspace_tools import create_workspace_tools + + router.register_many(create_workspace_tools()) + # Register sandbox tools if manager provided if sandbox_manager: from .sandbox_tools import create_sandbox_tools + router.register_many(create_sandbox_tools(sandbox_manager)) return router diff --git a/backend/openmlr/tools/workspace_tools.py b/backend/openmlr/tools/workspace_tools.py new file mode 100644 index 0000000..2789a0f --- /dev/null +++ b/backend/openmlr/tools/workspace_tools.py @@ -0,0 +1,383 @@ +"""Workspace tools — project workspace operations for the agent. + +Provides tools for the agent to interact with the project workspace: +- View workspace status and file tree +- Search files in workspace +- Save research notes +- Read/update the knowledge graph +- Log tool failures +""" + +import json +import logging +from contextvars import ContextVar + +from ..agent.types import ToolSpec +from ..workspace.knowledge import KnowledgeGraph +from ..workspace.persistence import WorkspacePersistence + +log = logging.getLogger(__name__) + +# Per-async-context workspace references — safe for concurrent sessions. +# ContextVar ensures each request/task has its own workspace context, +# preventing cross-user contamination in the async server. +_workspace_path_var: ContextVar[str | None] = ContextVar("workspace_path", default=None) +_persistence_var: ContextVar[WorkspacePersistence | None] = ContextVar("persistence", default=None) +_knowledge_var: ContextVar[KnowledgeGraph | None] = ContextVar("knowledge", default=None) + + +def set_workspace_context(workspace_path: str | None) -> None: + """Set the project workspace path for the current async context.""" + _workspace_path_var.set(workspace_path) + if workspace_path: + _persistence_var.set(WorkspacePersistence(workspace_path)) + _knowledge_var.set(KnowledgeGraph(workspace_path)) + else: + _persistence_var.set(None) + _knowledge_var.set(None) + + +def _require_workspace() -> tuple[WorkspacePersistence, KnowledgeGraph]: + """Ensure workspace is configured for the current context.""" + persistence = _persistence_var.get() + knowledge = _knowledge_var.get() + if not persistence or not knowledge: + raise ValueError("No project workspace is active. Create or select a project first.") + return persistence, knowledge + + +async def _handle_workspace( + operation: str, + # workspace_status + # workspace_search + query: str = "", + # workspace_note + topic: str = "", + content: str = "", + # knowledge_add + entity_id: str = "", + entity_type: str = "", + label: str = "", + properties: str = "", + # knowledge_relate + source_id: str = "", + target_id: str = "", + relationship: str = "", + # knowledge_query + # knowledge_summary + session=None, + **kwargs, +) -> tuple[str, bool]: + """Handle workspace tool operations.""" + try: + if operation == "status": + return await _workspace_status(session) + elif operation == "search": + return await _workspace_search(query) + elif operation == "note": + return await _workspace_note(topic, content, session) + elif operation == "knowledge_add": + return await _knowledge_add(entity_id, entity_type, label, properties, session) + elif operation == "knowledge_relate": + return await _knowledge_relate(source_id, target_id, relationship, session) + elif operation == "knowledge_query": + return await _knowledge_query(query) + elif operation == "knowledge_summary": + return await _knowledge_summary() + elif operation == "recent_failures": + return await _recent_failures() + else: + return f"Unknown workspace operation: {operation}", False + except ValueError as e: + return str(e), False + except Exception as e: + log.warning(f"Workspace tool error ({operation}): {e}") + return "Workspace operation failed. Check server logs for details.", False + + +async def _workspace_status(session=None) -> tuple[str, bool]: + """Get workspace status and summary.""" + persistence, knowledge = _require_workspace() + + summary = persistence.get_workspace_summary() + kg_summary = knowledge.get_summary() + + lines = [ + "## Workspace Status", + "", + f"**Papers:** {summary['papers']}", + f"**Research notes:** {summary['research_notes']}", + f"**Search results saved:** {summary['search_results']}", + f"**Code files:** {summary['code_files']}", + f"**Experiments logged:** {summary['experiments']}", + f"**Tool failures logged:** {summary['tool_failures']}", + "", + "### Knowledge Graph", + f"Entities: {kg_summary['total_nodes']} | Relationships: {kg_summary['total_edges']}", + ] + + if kg_summary.get("type_counts"): + lines.append( + "Types: " + ", ".join(f"{t}: {c}" for t, c in kg_summary["type_counts"].items()) + ) + + if summary.get("recent_tool_failures"): + lines.append("\n### Recent Tool Failures") + for f in summary["recent_tool_failures"]: + lines.append(f"- **{f['tool']}**: {f['error'][:100]}") + + state = persistence.get_state() + if state.get("key_findings"): + lines.append("\n### Key Findings") + for finding in state["key_findings"][-5:]: + lines.append(f"- {finding}") + + if state.get("open_questions"): + lines.append("\n### Open Questions") + for q in state["open_questions"][-5:]: + lines.append(f"- {q}") + + return "\n".join(lines), True + + +async def _workspace_search(query: str) -> tuple[str, bool]: + """Search files in workspace by name or content.""" + import os + + persistence, _ = _require_workspace() + + if not query: + return "Please provide a search query.", False + + results = [] + query_lower = query.lower() + ws_path = persistence.workspace_path + + # Limits to prevent DoS from deeply nested or very large workspaces + max_depth = 8 + max_files_scanned = 5000 + files_scanned = 0 + ws_path_str = str(ws_path) + + for dirpath, dirnames, filenames in os.walk(ws_path): + # Enforce depth limit + depth = dirpath[len(ws_path_str) :].count(os.sep) + if depth >= max_depth: + dirnames.clear() # Don't descend further + continue + + for fname in filenames: + if files_scanned >= max_files_scanned: + break + if fname.startswith("."): + continue + + files_scanned += 1 + fpath = os.path.join(dirpath, fname) + rel_path = os.path.relpath(fpath, ws_path) + + # Name match + if query_lower in fname.lower(): + results.append(f"- **{rel_path}** (name match)") + continue + + # Content match (text files only, skip large files) + try: + if os.path.getsize(fpath) > 500_000: + continue + with open(fpath, encoding="utf-8", errors="ignore") as f: + content = f.read(10000) + if query_lower in content.lower(): + results.append(f"- **{rel_path}** (content match)") + except Exception: + continue + + if files_scanned >= max_files_scanned: + break + + if not results: + return f"No files found matching '{query}'.", True + + return f"## Search Results for '{query}'\n\n" + "\n".join(results[:30]), True + + +async def _workspace_note(topic: str, content: str, session=None) -> tuple[str, bool]: + """Save a research note to the workspace.""" + persistence, _ = _require_workspace() + + if not topic or not content: + return "Please provide both 'topic' and 'content' for the note.", False + + conv_uuid = getattr(session, "conversation_uuid", None) if session else None + filepath = persistence.save_research_note(topic, content, conv_uuid) + + return f"Research note saved: {filepath.name}", True + + +async def _knowledge_add( + entity_id: str, + entity_type: str, + label: str, + properties: str, + session=None, +) -> tuple[str, bool]: + """Add an entity to the knowledge graph.""" + _, knowledge = _require_workspace() + + if not entity_id or not entity_type or not label: + return "Please provide entity_id, entity_type, and label.", False + + props = {} + if properties: + try: + props = json.loads(properties) + except json.JSONDecodeError: + return "Invalid JSON in properties.", False + + conv_uuid = getattr(session, "conversation_uuid", None) if session else None + is_new = knowledge.add_entity(entity_id, entity_type, label, props, conv_uuid) + knowledge.save() + + action = "Added" if is_new else "Updated" + return f"{action} entity: {label} ({entity_type})", True + + +async def _knowledge_relate( + source_id: str, + target_id: str, + relationship: str, + session=None, +) -> tuple[str, bool]: + """Add a relationship between entities in the knowledge graph.""" + _, knowledge = _require_workspace() + + if not source_id or not target_id or not relationship: + return "Please provide source_id, target_id, and relationship.", False + + conv_uuid = getattr(session, "conversation_uuid", None) if session else None + success = knowledge.add_relationship( + source_id, target_id, relationship, conversation_uuid=conv_uuid + ) + if success: + knowledge.save() + return f"Added relationship: {source_id} --[{relationship}]--> {target_id}", True + return "Failed to add relationship. Ensure both entities exist.", False + + +async def _knowledge_query(query: str) -> tuple[str, bool]: + """Search the knowledge graph.""" + _, knowledge = _require_workspace() + + if not query: + return "Please provide a search query.", False + + results = knowledge.search_entities(query) + if not results: + return f"No entities found matching '{query}'.", True + + lines = [f"## Knowledge Graph: '{query}'\n"] + for entity in results: + lines.append(f"- **{entity.get('label', entity['id'])}** ({entity.get('type', '?')})") + neighbors = knowledge.get_neighbors(entity["id"]) + for n in neighbors[:5]: + lines.append(f" - {n.get('relationship', '?')} -> {n.get('label', n['id'])}") + + return "\n".join(lines), True + + +async def _knowledge_summary() -> tuple[str, bool]: + """Get a full knowledge graph summary for context.""" + _, knowledge = _require_workspace() + context = knowledge.get_context_for_conversation() + if not context: + return "Knowledge graph is empty.", True + return context, True + + +async def _recent_failures() -> tuple[str, bool]: + """Get recent tool failure logs.""" + persistence, _ = _require_workspace() + failures = persistence.get_recent_failures(limit=10) + if not failures: + return "No recent tool failures.", True + + lines = ["## Recent Tool Failures\n"] + for f in failures: + lines.append(f"- **{f['tool']}** ({f.get('timestamp', '?')}): {f['error'][:200]}") + return "\n".join(lines), True + + +def create_workspace_tools() -> list[ToolSpec]: + """Create workspace tool specs.""" + return [ + ToolSpec( + name="workspace", + description=( + "Interact with the project workspace — persistent storage for research data, " + "knowledge graph, notes, and logs.\n\n" + "Operations:\n" + "- status: View workspace summary (file counts, knowledge graph size, recent failures)\n" + "- search: Search files by name or content (requires 'query')\n" + "- note: Save a research note (requires 'topic' and 'content')\n" + "- knowledge_add: Add entity to knowledge graph (requires 'entity_id', 'entity_type', 'label'; optional 'properties' as JSON)\n" + "- knowledge_relate: Add relationship (requires 'source_id', 'target_id', 'relationship')\n" + "- knowledge_query: Search knowledge graph (requires 'query')\n" + "- knowledge_summary: Get full knowledge graph context\n" + "- recent_failures: View recent tool/API failure logs\n\n" + "Entity types: paper, concept, method, dataset, finding, question, experiment, tool, author, code_artifact\n" + "Relationship types: cites, implements, evaluates_on, proposes, introduces, relates_to, answers, depends_on, uses, produces, contradicts, extends" + ), + parameters={ + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": [ + "status", + "search", + "note", + "knowledge_add", + "knowledge_relate", + "knowledge_query", + "knowledge_summary", + "recent_failures", + ], + "description": "The workspace operation to perform.", + }, + "query": { + "type": "string", + "description": "Search query (for search, knowledge_query).", + }, + "topic": {"type": "string", "description": "Note topic (for note)."}, + "content": {"type": "string", "description": "Note content (for note)."}, + "entity_id": { + "type": "string", + "description": "Entity ID (for knowledge_add).", + }, + "entity_type": { + "type": "string", + "description": "Entity type (for knowledge_add).", + }, + "label": {"type": "string", "description": "Entity label (for knowledge_add)."}, + "properties": { + "type": "string", + "description": "JSON string of additional properties (for knowledge_add).", + }, + "source_id": { + "type": "string", + "description": "Source entity ID (for knowledge_relate).", + }, + "target_id": { + "type": "string", + "description": "Target entity ID (for knowledge_relate).", + }, + "relationship": { + "type": "string", + "description": "Relationship type (for knowledge_relate).", + }, + }, + "required": ["operation"], + }, + handler=_handle_workspace, + ), + ] diff --git a/backend/openmlr/workspace/__init__.py b/backend/openmlr/workspace/__init__.py new file mode 100644 index 0000000..c475db8 --- /dev/null +++ b/backend/openmlr/workspace/__init__.py @@ -0,0 +1,6 @@ +"""Workspace package — project-scoped persistence, knowledge graph, and data logging.""" + +from .knowledge import KnowledgeGraph +from .persistence import WorkspacePersistence + +__all__ = ["KnowledgeGraph", "WorkspacePersistence"] diff --git a/backend/openmlr/workspace/knowledge.py b/backend/openmlr/workspace/knowledge.py new file mode 100644 index 0000000..db5746a --- /dev/null +++ b/backend/openmlr/workspace/knowledge.py @@ -0,0 +1,375 @@ +"""Knowledge Graph — lightweight persistent knowledge store backed by networkx. + +Stores entities (papers, concepts, methods, datasets, findings) and their +relationships as a directed graph. Serialized as JSON in the project workspace. + +The graph enables: +- Cross-conversation knowledge accumulation +- Context injection when starting new conversations +- Finding related prior work within a project +- Tracking what the agent knows vs. doesn't know +""" + +import json +import logging +from datetime import UTC, datetime +from pathlib import Path + +import networkx as nx + +log = logging.getLogger(__name__) + +# Node types for the knowledge graph +NODE_TYPES = { + "paper", + "concept", + "method", + "dataset", + "finding", + "question", + "experiment", + "tool", + "author", + "code_artifact", +} + +# Edge types (relationships) +EDGE_TYPES = { + "cites", # paper -> paper + "implements", # code_artifact -> method + "evaluates_on", # experiment -> dataset + "proposes", # paper -> method + "introduces", # paper -> dataset + "relates_to", # any -> any + "answers", # finding -> question + "depends_on", # method -> method, code -> code + "authored_by", # paper -> author + "uses", # experiment -> method + "produces", # experiment -> finding + "contradicts", # finding -> finding + "extends", # method -> method +} + + +# Size limits to prevent DoS via unbounded graph growth +MAX_NODES = 10_000 +MAX_EDGES = 50_000 + + +class KnowledgeGraph: + """A persistent knowledge graph for a project workspace. + + Uses networkx DiGraph internally and serializes to JSON. + Thread-safe for single-writer (agent loop is single-threaded per conversation). + """ + + def __init__(self, workspace_path: str | Path): + self.workspace_path = Path(workspace_path) + self.kg_path = self.workspace_path / ".project-meta" / "knowledge.json" + self._graph: nx.DiGraph = nx.DiGraph() + self._dirty = False + self._load() + + def _load(self) -> None: + """Load the knowledge graph from disk.""" + if not self.kg_path.exists(): + self._graph = nx.DiGraph() + return + + try: + data = json.loads(self.kg_path.read_text(encoding="utf-8")) + if data.get("nodes") or data.get("edges"): + self._graph = nx.DiGraph() + for node in data.get("nodes", []): + node_id = node.get("id") + if not node_id: + log.warning("Skipping node without 'id' in knowledge graph") + continue + attrs = {k: v for k, v in node.items() if k != "id"} + self._graph.add_node(node_id, **attrs) + for edge in data.get("edges", []): + src = edge.get("source") + tgt = edge.get("target") + if not src or not tgt: + log.warning("Skipping edge without source/target in knowledge graph") + continue + attrs = {k: v for k, v in edge.items() if k not in ("source", "target")} + self._graph.add_edge(src, tgt, **attrs) + else: + self._graph = nx.DiGraph() + except Exception as e: + log.warning(f"Failed to load knowledge graph: {e}") + self._graph = nx.DiGraph() + + def save(self) -> None: + """Persist the knowledge graph to disk.""" + if not self._dirty and self.kg_path.exists(): + return + + self.kg_path.parent.mkdir(parents=True, exist_ok=True) + + nodes = [] + for node_id, attrs in self._graph.nodes(data=True): + nodes.append({"id": node_id, **attrs}) + + edges = [] + for src, tgt, attrs in self._graph.edges(data=True): + edges.append({"source": src, "target": tgt, **attrs}) + + data = { + "version": 1, + "updated_at": datetime.now(UTC).isoformat(), + "node_count": len(nodes), + "edge_count": len(edges), + "nodes": nodes, + "edges": edges, + } + + self.kg_path.write_text(json.dumps(data, indent=2, default=str), encoding="utf-8") + self._dirty = False + + # ── Node operations ────────────────────────────────── + + # Reserved attribute names that properties cannot overwrite + _RESERVED_ATTRS = {"type", "label", "created_at", "updated_at", "source_conversation", "id"} + + def add_entity( + self, + entity_id: str, + entity_type: str, + label: str, + properties: dict | None = None, + conversation_uuid: str | None = None, + ) -> bool: + """Add or update an entity node. + + Returns True if the entity was newly added, False if updated. + Validates entity_type against NODE_TYPES. + Enforces MAX_NODES limit. + """ + if entity_type not in NODE_TYPES: + log.warning(f"Invalid entity type '{entity_type}', using 'concept'") + entity_type = "concept" + + is_new = entity_id not in self._graph + if is_new and self._graph.number_of_nodes() >= MAX_NODES: + log.warning(f"Knowledge graph at capacity ({MAX_NODES} nodes)") + return False + + attrs = { + "type": entity_type, + "label": label, + "updated_at": datetime.now(UTC).isoformat(), + } + if is_new: + attrs["created_at"] = datetime.now(UTC).isoformat() + if conversation_uuid: + attrs["source_conversation"] = conversation_uuid + if properties: + # Filter out reserved keys to prevent internal field overwrite + safe_props = {k: v for k, v in properties.items() if k not in self._RESERVED_ATTRS} + attrs.update(safe_props) + + self._graph.add_node(entity_id, **attrs) + self._dirty = True + return is_new + + def get_entity(self, entity_id: str) -> dict | None: + """Get an entity by ID.""" + if entity_id not in self._graph: + return None + return {"id": entity_id, **self._graph.nodes[entity_id]} + + def find_entities(self, entity_type: str | None = None, limit: int = 50) -> list[dict]: + """Find entities, optionally filtered by type.""" + results = [] + for node_id, attrs in self._graph.nodes(data=True): + if entity_type and attrs.get("type") != entity_type: + continue + results.append({"id": node_id, **attrs}) + if len(results) >= limit: + break + return results + + def search_entities(self, query: str, limit: int = 20) -> list[dict]: + """Search entities by label (case-insensitive substring match).""" + query_lower = query.lower() + results = [] + for node_id, attrs in self._graph.nodes(data=True): + label = attrs.get("label", "") + if query_lower in label.lower() or query_lower in node_id.lower(): + results.append({"id": node_id, **attrs}) + if len(results) >= limit: + break + return results + + def remove_entity(self, entity_id: str) -> bool: + """Remove an entity and all its edges.""" + if entity_id not in self._graph: + return False + self._graph.remove_node(entity_id) + self._dirty = True + return True + + # ── Edge operations ────────────────────────────────── + + def add_relationship( + self, + source_id: str, + target_id: str, + relationship: str, + properties: dict | None = None, + conversation_uuid: str | None = None, + ) -> bool: + """Add a directed relationship between two entities. + + Both entities must already exist. Returns True if edge was newly added. + Validates relationship against EDGE_TYPES. + Enforces MAX_EDGES limit. + """ + if relationship not in EDGE_TYPES: + log.warning(f"Invalid relationship type '{relationship}', using 'relates_to'") + relationship = "relates_to" + + if source_id not in self._graph or target_id not in self._graph: + log.warning( + f"Cannot add edge {source_id}->{target_id}: " + f"missing {'source' if source_id not in self._graph else 'target'}" + ) + return False + + is_new = not self._graph.has_edge(source_id, target_id) + if is_new and self._graph.number_of_edges() >= MAX_EDGES: + log.warning(f"Knowledge graph at edge capacity ({MAX_EDGES} edges)") + return False + + attrs = { + "type": relationship, + "updated_at": datetime.now(UTC).isoformat(), + } + if is_new: + attrs["created_at"] = datetime.now(UTC).isoformat() + if conversation_uuid: + attrs["source_conversation"] = conversation_uuid + if properties: + safe_props = {k: v for k, v in properties.items() if k not in self._RESERVED_ATTRS} + attrs.update(safe_props) + + self._graph.add_edge(source_id, target_id, **attrs) + self._dirty = True + return is_new + + def get_neighbors(self, entity_id: str, direction: str = "both") -> list[dict]: + """Get connected entities. + + Args: + entity_id: The entity to find neighbors for. + direction: "out" (successors), "in" (predecessors), or "both". + """ + if entity_id not in self._graph: + return [] + + neighbors = set() + if direction in ("out", "both"): + neighbors.update(self._graph.successors(entity_id)) + if direction in ("in", "both"): + neighbors.update(self._graph.predecessors(entity_id)) + + results = [] + for nid in neighbors: + edge_data = self._graph.edges.get((entity_id, nid), {}) or self._graph.edges.get( + (nid, entity_id), {} + ) + results.append( + { + "id": nid, + **self._graph.nodes[nid], + "relationship": edge_data.get("type", "relates_to"), + } + ) + return results + + # ── Query helpers ──────────────────────────────────── + + def get_summary(self) -> dict: + """Get a summary of the knowledge graph for context injection.""" + type_counts: dict[str, int] = {} + for _, attrs in self._graph.nodes(data=True): + t = attrs.get("type", "unknown") + type_counts[t] = type_counts.get(t, 0) + 1 + + # Get recent entities (by updated_at) + recent = sorted( + [{"id": nid, **attrs} for nid, attrs in self._graph.nodes(data=True)], + key=lambda x: x.get("updated_at", ""), + reverse=True, + )[:10] + + return { + "total_nodes": self._graph.number_of_nodes(), + "total_edges": self._graph.number_of_edges(), + "type_counts": type_counts, + "recent_entities": [ + {"id": e["id"], "type": e.get("type"), "label": e.get("label")} for e in recent + ], + } + + def get_context_for_conversation(self, max_tokens_approx: int = 2000) -> str: + """Generate a text summary of the knowledge graph for injecting into agent context. + + Produces a compact representation suitable for the system prompt. + """ + if self._graph.number_of_nodes() == 0: + return "" + + lines = ["## Project Knowledge Graph\n"] + + # Group by type + by_type: dict[str, list] = {} + for nid, attrs in self._graph.nodes(data=True): + t = attrs.get("type", "other") + by_type.setdefault(t, []).append((nid, attrs)) + + char_count = 0 + for entity_type, entities in by_type.items(): + if char_count > max_tokens_approx * 4: # rough char estimate + lines.append( + f"\n... and more ({self._graph.number_of_nodes() - len(lines)} entities)" + ) + break + + lines.append(f"\n### {entity_type.replace('_', ' ').title()}s") + for nid, attrs in entities[:15]: # cap per type + label = attrs.get("label", nid) + line = f"- **{label}**" + # Add key properties + if attrs.get("abstract"): + line += f": {attrs['abstract'][:150]}..." + elif attrs.get("description"): + line += f": {attrs['description'][:150]}..." + lines.append(line) + char_count += len(line) + + # Add key relationships + if self._graph.number_of_edges() > 0: + lines.append("\n### Key Relationships") + edge_count = 0 + for src, tgt, attrs in self._graph.edges(data=True): + if edge_count >= 20: + lines.append(f"... and {self._graph.number_of_edges() - edge_count} more") + break + src_label = self._graph.nodes[src].get("label", src) + tgt_label = self._graph.nodes[tgt].get("label", tgt) + rel = attrs.get("type", "relates_to") + lines.append(f"- {src_label} --[{rel}]--> {tgt_label}") + edge_count += 1 + + return "\n".join(lines) + + @property + def node_count(self) -> int: + return self._graph.number_of_nodes() + + @property + def edge_count(self) -> int: + return self._graph.number_of_edges() diff --git a/backend/openmlr/workspace/persistence.py b/backend/openmlr/workspace/persistence.py new file mode 100644 index 0000000..1ca330d --- /dev/null +++ b/backend/openmlr/workspace/persistence.py @@ -0,0 +1,353 @@ +"""Workspace Persistence — file-based storage for project working data. + +Handles saving/loading of: +- Search results (paper searches, web searches) +- Research notes and summaries +- Tool failure logs +- Compute capability snapshots +- Experiment logs +- Cross-conversation state +""" + +import json +import logging +from datetime import UTC, datetime +from pathlib import Path + +log = logging.getLogger(__name__) + + +class WorkspacePersistence: + """File-based persistence for a project workspace.""" + + def __init__(self, workspace_path: str | Path): + self.workspace_path = Path(workspace_path) + if not self.workspace_path.exists(): + log.warning(f"Workspace path does not exist: {workspace_path}") + + def _ensure_dir(self, *parts: str) -> Path: + """Ensure a subdirectory exists and return its path.""" + path = self.workspace_path.joinpath(*parts) + path.mkdir(parents=True, exist_ok=True) + return path + + def _timestamp(self) -> str: + return datetime.now(UTC).strftime("%Y%m%d-%H%M%S") + + @staticmethod + def _sanitize_filename(name: str, max_len: int = 80) -> str: + """Sanitize a string for safe use in filenames. Alphanumeric + hyphen/underscore only.""" + return "".join(c if c.isalnum() or c in "-_" else "_" for c in name)[:max_len] or "unknown" + + # ── Search Results ─────────────────────────────────── + + def save_search_results( + self, + query: str, + source: str, + results: list[dict], + conversation_uuid: str | None = None, + ) -> Path: + """Save paper/web search results to workspace.""" + dir_path = self._ensure_dir("research", "searches") + filename = f"{self._timestamp()}_{self._sanitize_filename(source)}.json" + filepath = dir_path / filename + + data = { + "query": query, + "source": source, + "timestamp": datetime.now(UTC).isoformat(), + "conversation_uuid": conversation_uuid, + "result_count": len(results), + "results": results, + } + filepath.write_text(json.dumps(data, indent=2, default=str), encoding="utf-8") + log.debug(f"Saved {len(results)} search results to {filepath}") + return filepath + + def get_recent_searches(self, limit: int = 10) -> list[dict]: + """Get recent search results (metadata only, not full results).""" + dir_path = self.workspace_path / "research" / "searches" + if not dir_path.exists(): + return [] + + searches = [] + for filepath in sorted(dir_path.glob("*.json"), reverse=True): + try: + data = json.loads(filepath.read_text(encoding="utf-8")) + searches.append( + { + "query": data.get("query"), + "source": data.get("source"), + "timestamp": data.get("timestamp"), + "result_count": data.get("result_count", 0), + "filename": filepath.name, + } + ) + if len(searches) >= limit: + break + except Exception: + continue + return searches + + # ── Research Notes ─────────────────────────────────── + + def save_research_note( + self, + topic: str, + content: str, + conversation_uuid: str | None = None, + ) -> Path: + """Save a research note or summary.""" + dir_path = self._ensure_dir("research", "notes") + # Sanitize topic for filename + safe_topic = "".join(c if c.isalnum() or c in "-_ " else "" for c in topic) + safe_topic = safe_topic.strip().replace(" ", "_")[:100] or "note" + filename = f"{self._timestamp()}_{safe_topic}.md" + filepath = dir_path / filename + + header = f"# {topic}\n\n" + header += f"_Generated: {datetime.now(UTC).isoformat()}_\n" + if conversation_uuid: + header += f"_Conversation: {conversation_uuid}_\n" + header += "\n---\n\n" + + filepath.write_text(header + content, encoding="utf-8") + log.debug(f"Saved research note to {filepath}") + return filepath + + def get_research_notes(self, limit: int = 20) -> list[dict]: + """List available research notes.""" + dir_path = self.workspace_path / "research" / "notes" + if not dir_path.exists(): + return [] + + notes = [] + for filepath in sorted(dir_path.glob("*.md"), reverse=True): + try: + content = filepath.read_text(encoding="utf-8") + # Extract title from first line + title = content.split("\n")[0].lstrip("# ").strip() + notes.append( + { + "title": title, + "filename": filepath.name, + "size": filepath.stat().st_size, + "modified": filepath.stat().st_mtime, + } + ) + if len(notes) >= limit: + break + except Exception: + continue + return notes + + # ── Paper Storage ──────────────────────────────────── + + def save_paper( + self, + paper_id: str, + title: str, + content: str, + metadata: dict | None = None, + ) -> Path: + """Save a parsed paper to the workspace.""" + dir_path = self._ensure_dir("papers") + # Use paper_id as filename (strictly sanitized) + safe_id = self._sanitize_filename(paper_id) + filepath = dir_path / f"{safe_id}.md" + + header = f"# {title}\n\n" + if metadata: + if metadata.get("authors"): + header += f"**Authors:** {metadata['authors']}\n" + if metadata.get("year"): + header += f"**Year:** {metadata['year']}\n" + if metadata.get("url"): + header += f"**URL:** {metadata['url']}\n" + header += "\n---\n\n" + + filepath.write_text(header + content, encoding="utf-8") + + # Save metadata separately as JSON + meta_path = dir_path / f"{safe_id}.meta.json" + meta_data = { + "paper_id": paper_id, + "title": title, + "saved_at": datetime.now(UTC).isoformat(), + **(metadata or {}), + } + meta_path.write_text(json.dumps(meta_data, indent=2, default=str), encoding="utf-8") + + return filepath + + # ── Tool Failure Logs ──────────────────────────────── + + def log_tool_failure( + self, + tool_name: str, + error: str, + args: dict | None = None, + conversation_uuid: str | None = None, + ) -> Path: + """Log a tool/API/MCP failure for future reference.""" + dir_path = self._ensure_dir("logs", "tool_failures") + filename = f"{self._timestamp()}_{self._sanitize_filename(tool_name)}.json" + filepath = dir_path / filename + + data = { + "tool": tool_name, + "error": error, + "args": args, + "timestamp": datetime.now(UTC).isoformat(), + "conversation_uuid": conversation_uuid, + } + filepath.write_text(json.dumps(data, indent=2, default=str), encoding="utf-8") + log.debug(f"Logged tool failure: {tool_name} -> {filepath}") + return filepath + + def get_recent_failures(self, limit: int = 10) -> list[dict]: + """Get recent tool failure logs.""" + dir_path = self.workspace_path / "logs" / "tool_failures" + if not dir_path.exists(): + return [] + + failures = [] + for filepath in sorted(dir_path.glob("*.json"), reverse=True): + try: + data = json.loads(filepath.read_text(encoding="utf-8")) + failures.append(data) + if len(failures) >= limit: + break + except Exception: + continue + return failures + + # ── Compute Logs ───────────────────────────────────── + + def log_compute_probe( + self, + node_name: str, + capabilities: dict, + ) -> Path: + """Log compute node probe results.""" + dir_path = self._ensure_dir("logs", "compute") + filename = f"{self._timestamp()}_{self._sanitize_filename(node_name)}.json" + filepath = dir_path / filename + + data = { + "node_name": node_name, + "capabilities": capabilities, + "probed_at": datetime.now(UTC).isoformat(), + } + filepath.write_text(json.dumps(data, indent=2, default=str), encoding="utf-8") + return filepath + + # ── Experiment Logs ────────────────────────────────── + + def log_experiment( + self, + name: str, + command: str, + result: dict, + conversation_uuid: str | None = None, + ) -> Path: + """Log an experiment execution.""" + dir_path = self._ensure_dir("logs", "experiments") + safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in name)[:80] + filename = f"{self._timestamp()}_{safe_name}.json" + filepath = dir_path / filename + + data = { + "name": name, + "command": command, + "result": result, + "timestamp": datetime.now(UTC).isoformat(), + "conversation_uuid": conversation_uuid, + } + filepath.write_text(json.dumps(data, indent=2, default=str), encoding="utf-8") + return filepath + + # ── Cross-conversation State ───────────────────────── + + def get_state(self) -> dict: + """Load the cross-conversation state.""" + state_path = self.workspace_path / ".project-meta" / "state.json" + if not state_path.exists(): + return { + "last_conversation_uuid": None, + "open_questions": [], + "key_findings": [], + "active_experiments": [], + } + try: + return json.loads(state_path.read_text(encoding="utf-8")) + except Exception: + return {} + + def save_state(self, state: dict) -> None: + """Save the cross-conversation state.""" + state_path = self.workspace_path / ".project-meta" / "state.json" + state_path.parent.mkdir(parents=True, exist_ok=True) + state_path.write_text(json.dumps(state, indent=2, default=str), encoding="utf-8") + + def update_state(self, **kwargs) -> dict: + """Update specific fields in the cross-conversation state.""" + state = self.get_state() + state.update(kwargs) + self.save_state(state) + return state + + # ── Plan Storage ───────────────────────────────────── + + def save_plan( + self, + plan_content: str, + conversation_uuid: str, + ) -> Path: + """Save a task plan to the workspace.""" + dir_path = self._ensure_dir(".project-meta", "plans") + # Sanitize the UUID to prevent path injection + safe_uuid = self._sanitize_filename(conversation_uuid) + filename = f"{safe_uuid}.md" + filepath = dir_path / filename + filepath.write_text(plan_content, encoding="utf-8") + return filepath + + # ── Workspace Summary ──────────────────────────────── + + def get_workspace_summary(self) -> dict: + """Get a summary of all workspace contents for context injection.""" + summary = { + "papers": self._count_files("papers", "*.md"), + "research_notes": self._count_files("research/notes", "*.md"), + "search_results": self._count_files("research/searches", "*.json"), + "code_files": self._count_files_recursive("code"), + "experiments": self._count_files("logs/experiments", "*.json"), + "tool_failures": self._count_files("logs/tool_failures", "*.json"), + } + + # Add recent failures as warnings + recent_failures = self.get_recent_failures(limit=5) + if recent_failures: + summary["recent_tool_failures"] = [ + {"tool": f["tool"], "error": f["error"][:200], "time": f.get("timestamp")} + for f in recent_failures + ] + + return summary + + def _count_files(self, subdir: str, pattern: str) -> int: + path = self.workspace_path / subdir + if not path.exists(): + return 0 + return len(list(path.glob(pattern))) + + def _count_files_recursive(self, subdir: str) -> int: + path = self.workspace_path / subdir + if not path.exists(): + return 0 + count = 0 + for _ in path.rglob("*"): + count += 1 + return count diff --git a/backend/pyproject.toml b/backend/pyproject.toml index f6ac4d8..0ffb822 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -38,6 +38,9 @@ dependencies = [ "python-dotenv>=1.0.0", "pyyaml>=6.0.0", + # Knowledge graph (workspace persistence) + "networkx>=3.0", + # Background jobs "celery>=5.4.0", "redis>=5.0.0", diff --git a/backend/tests/test_projects.py b/backend/tests/test_projects.py new file mode 100644 index 0000000..62930a8 --- /dev/null +++ b/backend/tests/test_projects.py @@ -0,0 +1,196 @@ +"""Tests for Project model, DB operations, and API routes.""" + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from openmlr.db import operations as ops +from openmlr.db.models import User + +pytestmark = pytest.mark.asyncio + + +# ── DB Operations ──────────────────────────────────────── + + +class TestProjectOperations: + async def test_create_project(self, db_session: AsyncSession, test_user: User): + project = await ops.create_project( + db_session, + test_user.id, + name="Test Project", + slug="test-project", + description="A test project", + ) + assert project.id is not None + assert project.uuid is not None + assert project.name == "Test Project" + assert project.slug == "test-project" + assert project.status == "active" + + async def test_get_user_projects(self, db_session: AsyncSession, test_user: User): + await ops.create_project(db_session, test_user.id, "P1", "p1") + await ops.create_project(db_session, test_user.id, "P2", "p2") + + projects = await ops.get_user_projects(db_session, test_user.id) + assert len(projects) == 2 + + async def test_get_project_by_uuid(self, db_session: AsyncSession, test_user: User): + project = await ops.create_project(db_session, test_user.id, "Find Me", "find-me") + found = await ops.get_project_by_uuid(db_session, project.uuid, test_user.id) + assert found is not None + assert found.name == "Find Me" + + async def test_get_project_by_uuid_wrong_user(self, db_session: AsyncSession, test_user: User): + project = await ops.create_project(db_session, test_user.id, "Private", "private") + found = await ops.get_project_by_uuid(db_session, project.uuid, user_id=9999) + assert found is None + + async def test_get_project_by_slug(self, db_session: AsyncSession, test_user: User): + await ops.create_project(db_session, test_user.id, "Slug Test", "slug-test") + found = await ops.get_project_by_slug(db_session, test_user.id, "slug-test") + assert found is not None + assert found.name == "Slug Test" + + async def test_update_project(self, db_session: AsyncSession, test_user: User): + project = await ops.create_project(db_session, test_user.id, "Original", "original") + updated = await ops.update_project( + db_session, + project.id, + test_user.id, + name="Updated Name", + description="New description", + ) + assert updated.name == "Updated Name" + assert updated.description == "New description" + + async def test_archive_project(self, db_session: AsyncSession, test_user: User): + project = await ops.create_project(db_session, test_user.id, "To Archive", "to-archive") + archived = await ops.archive_project(db_session, project.id, test_user.id) + assert archived.status == "archived" + + async def test_get_user_projects_excludes_archived( + self, db_session: AsyncSession, test_user: User + ): + await ops.create_project(db_session, test_user.id, "Active", "active") + p2 = await ops.create_project(db_session, test_user.id, "To Archive", "to-archive") + await ops.archive_project(db_session, p2.id, test_user.id) + + active_only = await ops.get_user_projects(db_session, test_user.id, include_archived=False) + assert len(active_only) == 1 + assert active_only[0].name == "Active" + + all_projects = await ops.get_user_projects(db_session, test_user.id, include_archived=True) + assert len(all_projects) == 2 + + async def test_attach_conversation_to_project(self, db_session: AsyncSession, test_user: User): + project = await ops.create_project(db_session, test_user.id, "With Conv", "with-conv") + conv = await ops.create_conversation(db_session, test_user.id, title="Test Conv") + + success = await ops.attach_conversation_to_project(db_session, conv.id, project.id) + assert success is True + + convs = await ops.get_project_conversations(db_session, project.id) + assert len(convs) == 1 + assert convs[0].title == "Test Conv" + + async def test_detach_conversation_from_project( + self, db_session: AsyncSession, test_user: User + ): + project = await ops.create_project(db_session, test_user.id, "Detach Test", "detach-test") + conv = await ops.create_conversation(db_session, test_user.id, project_id=project.id) + + convs = await ops.get_project_conversations(db_session, project.id) + assert len(convs) == 1 + + await ops.attach_conversation_to_project(db_session, conv.id, None) + convs = await ops.get_project_conversations(db_session, project.id) + assert len(convs) == 0 + + async def test_create_conversation_with_project( + self, db_session: AsyncSession, test_user: User + ): + project = await ops.create_project(db_session, test_user.id, "Direct", "direct") + conv = await ops.create_conversation( + db_session, + test_user.id, + title="Project Conv", + project_id=project.id, + ) + assert conv.project_id == project.id + + +# ── API Routes ─────────────────────────────────────────── + + +class TestProjectRoutes: + async def test_create_project_api(self, auth_client): + resp = await auth_client.post( + "/api/projects", + json={ + "name": "API Project", + "description": "Created via API", + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["project"]["name"] == "API Project" + assert data["project"]["slug"] == "api-project" + assert data["project"]["status"] == "active" + + async def test_list_projects_api(self, auth_client): + await auth_client.post("/api/projects", json={"name": "Project 1"}) + await auth_client.post("/api/projects", json={"name": "Project 2"}) + + resp = await auth_client.get("/api/projects") + assert resp.status_code == 200 + data = resp.json() + assert len(data["projects"]) == 2 + + async def test_get_project_api(self, auth_client): + create_resp = await auth_client.post("/api/projects", json={"name": "Get Me"}) + uuid = create_resp.json()["project"]["uuid"] + + resp = await auth_client.get(f"/api/projects/{uuid}") + assert resp.status_code == 200 + assert resp.json()["project"]["name"] == "Get Me" + + async def test_update_project_api(self, auth_client): + create_resp = await auth_client.post("/api/projects", json={"name": "Update Me"}) + uuid = create_resp.json()["project"]["uuid"] + + resp = await auth_client.put( + f"/api/projects/{uuid}", + json={ + "name": "Updated", + "description": "New desc", + }, + ) + assert resp.status_code == 200 + assert resp.json()["project"]["name"] == "Updated" + + async def test_delete_project_api(self, auth_client): + create_resp = await auth_client.post("/api/projects", json={"name": "Delete Me"}) + uuid = create_resp.json()["project"]["uuid"] + + resp = await auth_client.delete(f"/api/projects/{uuid}") + assert resp.status_code == 200 + + # Should be archived, not truly deleted + get_resp = await auth_client.get(f"/api/projects/{uuid}") + assert get_resp.json()["project"]["status"] == "archived" + + async def test_create_project_missing_name(self, auth_client): + resp = await auth_client.post("/api/projects", json={}) + assert resp.status_code == 400 + + async def test_get_nonexistent_project(self, auth_client): + resp = await auth_client.get("/api/projects/nonexistent-uuid") + assert resp.status_code == 404 + + async def test_project_conversations_api(self, auth_client): + create_resp = await auth_client.post("/api/projects", json={"name": "Conv Test"}) + uuid = create_resp.json()["project"]["uuid"] + + resp = await auth_client.get(f"/api/projects/{uuid}/conversations") + assert resp.status_code == 200 + assert resp.json()["conversations"] == [] diff --git a/backend/tests/test_tools_workspace.py b/backend/tests/test_tools_workspace.py new file mode 100644 index 0000000..69edca8 --- /dev/null +++ b/backend/tests/test_tools_workspace.py @@ -0,0 +1,208 @@ +"""Tests for workspace agent tools.""" + +import os +import tempfile + +import pytest + +from openmlr.tools.workspace_tools import ( + _handle_workspace, + create_workspace_tools, + set_workspace_context, +) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def workspace_dir(): + """Create a temporary workspace directory with standard structure.""" + with tempfile.TemporaryDirectory() as tmpdir: + for subdir in [ + "code", + "data", + "models", + "outputs", + "papers", + "research", + "research/searches", + "research/notes", + "research/citations", + "logs", + "logs/tool_failures", + "logs/compute", + "logs/experiments", + ".project-meta", + ".project-meta/plans", + ]: + os.makedirs(os.path.join(tmpdir, subdir), exist_ok=True) + set_workspace_context(tmpdir) + yield tmpdir + set_workspace_context(None) + + +class TestWorkspaceTools: + def test_create_workspace_tools(self): + tools = create_workspace_tools() + assert len(tools) == 1 + assert tools[0].name == "workspace" + + async def test_status_operation(self, workspace_dir): + result, success = await _handle_workspace(operation="status") + assert success is True + assert "Workspace Status" in result + + async def test_note_operation(self, workspace_dir): + result, success = await _handle_workspace( + operation="note", + topic="Test Topic", + content="This is a test note.", + ) + assert success is True + assert "saved" in result.lower() + + async def test_note_missing_params(self, workspace_dir): + result, success = await _handle_workspace(operation="note", topic="", content="") + assert success is False + + async def test_knowledge_add_operation(self, workspace_dir): + result, success = await _handle_workspace( + operation="knowledge_add", + entity_id="paper-1", + entity_type="paper", + label="Test Paper", + ) + assert success is True + assert "Added" in result + + async def test_knowledge_add_with_properties(self, workspace_dir): + result, success = await _handle_workspace( + operation="knowledge_add", + entity_id="paper-2", + entity_type="paper", + label="Paper 2", + properties='{"year": 2024, "venue": "NeurIPS"}', + ) + assert success is True + + async def test_knowledge_add_missing_params(self, workspace_dir): + result, success = await _handle_workspace( + operation="knowledge_add", + entity_id="", + entity_type="", + label="", + ) + assert success is False + + async def test_knowledge_relate_operation(self, workspace_dir): + # Add entities first + await _handle_workspace( + operation="knowledge_add", + entity_id="p1", + entity_type="paper", + label="Paper 1", + ) + await _handle_workspace( + operation="knowledge_add", + entity_id="m1", + entity_type="method", + label="Method 1", + ) + + result, success = await _handle_workspace( + operation="knowledge_relate", + source_id="p1", + target_id="m1", + relationship="proposes", + ) + assert success is True + assert "proposes" in result + + async def test_knowledge_relate_missing_entity(self, workspace_dir): + await _handle_workspace( + operation="knowledge_add", + entity_id="p1", + entity_type="paper", + label="Paper 1", + ) + result, success = await _handle_workspace( + operation="knowledge_relate", + source_id="p1", + target_id="missing", + relationship="proposes", + ) + assert success is False + + async def test_knowledge_query_operation(self, workspace_dir): + await _handle_workspace( + operation="knowledge_add", + entity_id="attn", + entity_type="method", + label="Self-Attention", + ) + result, success = await _handle_workspace( + operation="knowledge_query", + query="attention", + ) + assert success is True + assert "Self-Attention" in result + + async def test_knowledge_query_no_results(self, workspace_dir): + result, success = await _handle_workspace( + operation="knowledge_query", + query="nonexistent", + ) + assert success is True + assert "No entities found" in result + + async def test_knowledge_summary_empty(self, workspace_dir): + result, success = await _handle_workspace(operation="knowledge_summary") + assert success is True + assert "empty" in result.lower() + + async def test_knowledge_summary_with_data(self, workspace_dir): + await _handle_workspace( + operation="knowledge_add", + entity_id="p1", + entity_type="paper", + label="Test Paper", + ) + result, success = await _handle_workspace(operation="knowledge_summary") + assert success is True + assert "Test Paper" in result + + async def test_recent_failures_empty(self, workspace_dir): + result, success = await _handle_workspace(operation="recent_failures") + assert success is True + assert "No recent" in result + + async def test_search_operation(self, workspace_dir): + # Create a test file + test_file = os.path.join(workspace_dir, "code", "test.py") + with open(test_file, "w") as f: + f.write("import torch\nmodel = TransformerModel()") + + result, success = await _handle_workspace( + operation="search", + query="transformer", + ) + assert success is True + assert "test.py" in result + + async def test_search_no_results(self, workspace_dir): + result, success = await _handle_workspace( + operation="search", + query="xyznonexistent", + ) + assert success is True + assert "No files found" in result + + async def test_unknown_operation(self, workspace_dir): + result, success = await _handle_workspace(operation="unknown_op") + assert success is False + + async def test_no_workspace_context(self): + set_workspace_context(None) + result, success = await _handle_workspace(operation="status") + assert success is False + assert "No project workspace" in result diff --git a/backend/tests/test_workspace.py b/backend/tests/test_workspace.py new file mode 100644 index 0000000..d3f59c2 --- /dev/null +++ b/backend/tests/test_workspace.py @@ -0,0 +1,333 @@ +"""Tests for workspace persistence and knowledge graph.""" + +import json +import os +import tempfile + +import pytest + +from openmlr.workspace.knowledge import KnowledgeGraph +from openmlr.workspace.persistence import WorkspacePersistence + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def workspace_dir(): + """Create a temporary workspace directory with standard structure.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create standard subdirs + for subdir in [ + "code", + "data", + "models", + "outputs", + "papers", + "research", + "research/searches", + "research/notes", + "research/citations", + "logs", + "logs/tool_failures", + "logs/compute", + "logs/experiments", + "venvs", + ".project-meta", + ".project-meta/plans", + ]: + os.makedirs(os.path.join(tmpdir, subdir), exist_ok=True) + yield tmpdir + + +# ── Knowledge Graph ────────────────────────────────────── + + +class TestKnowledgeGraph: + def test_init_empty(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + assert kg.node_count == 0 + assert kg.edge_count == 0 + + def test_add_entity(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + is_new = kg.add_entity("paper-1", "paper", "Attention Is All You Need") + assert is_new is True + assert kg.node_count == 1 + + def test_add_entity_update(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + kg.add_entity("paper-1", "paper", "Original Label") + is_new = kg.add_entity("paper-1", "paper", "Updated Label") + assert is_new is False + assert kg.node_count == 1 + + entity = kg.get_entity("paper-1") + assert entity["label"] == "Updated Label" + + def test_add_entity_with_properties(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + kg.add_entity( + "paper-1", + "paper", + "Test Paper", + properties={ + "year": 2017, + "abstract": "We propose a new architecture...", + }, + ) + entity = kg.get_entity("paper-1") + assert entity["year"] == 2017 + assert "architecture" in entity["abstract"] + + def test_get_nonexistent_entity(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + assert kg.get_entity("nope") is None + + def test_find_entities_by_type(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + kg.add_entity("p1", "paper", "Paper 1") + kg.add_entity("p2", "paper", "Paper 2") + kg.add_entity("m1", "method", "Method 1") + + papers = kg.find_entities("paper") + assert len(papers) == 2 + + methods = kg.find_entities("method") + assert len(methods) == 1 + + def test_search_entities(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + kg.add_entity("attn", "method", "Self-Attention Mechanism") + kg.add_entity("conv", "method", "Convolutional Neural Network") + kg.add_entity("bert", "paper", "BERT: Pre-training") + + results = kg.search_entities("attention") + assert len(results) == 1 + assert results[0]["id"] == "attn" + + def test_remove_entity(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + kg.add_entity("rm-me", "paper", "Remove Me") + assert kg.node_count == 1 + + removed = kg.remove_entity("rm-me") + assert removed is True + assert kg.node_count == 0 + + def test_remove_nonexistent_entity(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + assert kg.remove_entity("nope") is False + + def test_add_relationship(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + kg.add_entity("p1", "paper", "Paper 1") + kg.add_entity("m1", "method", "Method 1") + + is_new = kg.add_relationship("p1", "m1", "proposes") + assert is_new is True + assert kg.edge_count == 1 + + def test_add_relationship_missing_entity(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + kg.add_entity("p1", "paper", "Paper 1") + + is_new = kg.add_relationship("p1", "missing", "proposes") + assert is_new is False + assert kg.edge_count == 0 + + def test_get_neighbors(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + kg.add_entity("p1", "paper", "Paper 1") + kg.add_entity("m1", "method", "Method 1") + kg.add_entity("m2", "method", "Method 2") + kg.add_relationship("p1", "m1", "proposes") + kg.add_relationship("p1", "m2", "proposes") + + neighbors = kg.get_neighbors("p1", direction="out") + assert len(neighbors) == 2 + + neighbors_in = kg.get_neighbors("m1", direction="in") + assert len(neighbors_in) == 1 + + def test_save_and_reload(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + kg.add_entity("p1", "paper", "Paper 1", properties={"year": 2020}) + kg.add_entity("m1", "method", "Method 1") + kg.add_relationship("p1", "m1", "proposes") + kg.save() + + # Reload from disk + kg2 = KnowledgeGraph(workspace_dir) + assert kg2.node_count == 2 + assert kg2.edge_count == 1 + + entity = kg2.get_entity("p1") + assert entity["label"] == "Paper 1" + assert entity["year"] == 2020 + + def test_get_summary(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + kg.add_entity("p1", "paper", "Paper 1") + kg.add_entity("m1", "method", "Method 1") + kg.add_relationship("p1", "m1", "proposes") + + summary = kg.get_summary() + assert summary["total_nodes"] == 2 + assert summary["total_edges"] == 1 + assert "paper" in summary["type_counts"] + assert "method" in summary["type_counts"] + + def test_get_context_for_conversation_empty(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + context = kg.get_context_for_conversation() + assert context == "" + + def test_get_context_for_conversation(self, workspace_dir): + kg = KnowledgeGraph(workspace_dir) + kg.add_entity("p1", "paper", "Attention Paper") + kg.add_entity("m1", "method", "Self-Attention") + kg.add_relationship("p1", "m1", "proposes") + + context = kg.get_context_for_conversation() + assert "Attention Paper" in context + assert "Self-Attention" in context + assert "proposes" in context + + +# ── Workspace Persistence ──────────────────────────────── + + +class TestWorkspacePersistence: + def test_save_search_results(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + filepath = wp.save_search_results( + query="transformer attention", + source="arxiv", + results=[{"title": "Paper 1"}, {"title": "Paper 2"}], + ) + assert filepath.exists() + data = json.loads(filepath.read_text()) + assert data["query"] == "transformer attention" + assert data["source"] == "arxiv" + assert len(data["results"]) == 2 + + def test_get_recent_searches(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + wp.save_search_results("q1", "arxiv", [{"t": "r1"}]) + wp.save_search_results("q2", "openalex", [{"t": "r2"}]) + + searches = wp.get_recent_searches(limit=10) + assert len(searches) == 2 + + def test_save_research_note(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + filepath = wp.save_research_note( + topic="Attention Mechanisms", + content="Self-attention allows models to...", + ) + assert filepath.exists() + content = filepath.read_text() + assert "Attention Mechanisms" in content + assert "Self-attention allows" in content + + def test_get_research_notes(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + wp.save_research_note("Note 1", "Content 1") + wp.save_research_note("Note 2", "Content 2") + + notes = wp.get_research_notes() + assert len(notes) == 2 + + def test_save_paper(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + filepath = wp.save_paper( + paper_id="2301.12345", + title="Test Paper", + content="## Introduction\n\nThis paper...", + metadata={"authors": "Smith et al.", "year": 2023}, + ) + assert filepath.exists() + content = filepath.read_text() + assert "Test Paper" in content + assert "Introduction" in content + + def test_log_tool_failure(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + filepath = wp.log_tool_failure( + tool_name="papers", + error="arXiv rate limit reached", + args={"query": "test"}, + ) + assert filepath.exists() + data = json.loads(filepath.read_text()) + assert data["tool"] == "papers" + assert "rate limit" in data["error"] + + def test_get_recent_failures(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + wp.log_tool_failure("papers", "Error 1") + wp.log_tool_failure("web_search", "Error 2") + + failures = wp.get_recent_failures() + assert len(failures) == 2 + + def test_log_compute_probe(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + filepath = wp.log_compute_probe( + node_name="gpu-server", + capabilities={"gpu": "A100", "ram_gb": 128}, + ) + assert filepath.exists() + + def test_log_experiment(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + filepath = wp.log_experiment( + name="train-bert", + command="python train.py --lr 0.001", + result={"loss": 0.05, "accuracy": 0.95}, + ) + assert filepath.exists() + data = json.loads(filepath.read_text()) + assert data["name"] == "train-bert" + assert data["result"]["accuracy"] == 0.95 + + def test_state_persistence(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + + # Initial state + state = wp.get_state() + assert state.get("key_findings") == [] or state.get("key_findings") is None + + # Update state + wp.update_state( + key_findings=["Attention is effective for NLP"], + open_questions=["Does it scale?"], + ) + + # Reload + wp2 = WorkspacePersistence(workspace_dir) + state2 = wp2.get_state() + assert "Attention is effective for NLP" in state2["key_findings"] + assert "Does it scale?" in state2["open_questions"] + + def test_save_plan(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + filepath = wp.save_plan( + plan_content="# Plan\n\n1. Read papers\n2. Train model", + conversation_uuid="test-conv-uuid", + ) + assert filepath.exists() + assert "Read papers" in filepath.read_text() + + def test_get_workspace_summary(self, workspace_dir): + wp = WorkspacePersistence(workspace_dir) + + # Add some files + wp.save_search_results("q1", "arxiv", []) + wp.save_research_note("Note", "Content") + wp.log_tool_failure("test", "Error") + + summary = wp.get_workspace_summary() + assert summary["search_results"] >= 1 + assert summary["research_notes"] >= 1 + assert summary["tool_failures"] >= 1 diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml index d479428..230115e 100644 --- a/docker-compose.prod.yml +++ b/docker-compose.prod.yml @@ -67,6 +67,7 @@ services: volumes: - ./backend/configs:/app/backend/configs - ./.keys:/app/.keys + - ${OPENMLR_WORKSPACES_PATH:-./.workspaces}:/app/.workspaces security_opt: - no-new-privileges:true restart: unless-stopped @@ -95,6 +96,7 @@ services: volumes: - ./backend/configs:/app/backend/configs - ./.keys:/app/.keys + - ${OPENMLR_WORKSPACES_PATH:-./.workspaces}:/app/.workspaces security_opt: - no-new-privileges:true restart: unless-stopped @@ -116,3 +118,5 @@ services: volumes: pgdata: redisdata: + # Note: workspaces use a bind mount (OPENMLR_WORKSPACES_PATH or ./.workspaces) + # for easy backup and inspection. Not a named volume. diff --git a/docker-compose.yml b/docker-compose.yml index 6eaadfd..127590f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -71,6 +71,7 @@ services: - backend-venv:/app/backend/.venv - ./frontend/dist:/app/frontend/dist - ./.keys:/app/.keys + - workspaces:/app/.workspaces # Worker with auto-restart on code changes worker: @@ -105,6 +106,7 @@ services: - ./backend:/app/backend - backend-venv:/app/backend/.venv - ./.keys:/app/.keys + - workspaces:/app/.workspaces # Docs site with live reload docs: @@ -122,3 +124,4 @@ volumes: redisdata: backend-venv: docs-node-modules: + workspaces: # Project workspaces — persists across container rebuilds diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 806ce81..501a017 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -5,7 +5,7 @@ import { ComputeSelector } from './components/ComputeSelector'; import { useSSE } from './hooks/useSSE'; import { useJobStatus } from './hooks/useJobStatus'; import { api } from './api'; -import type { AgentEvent, Message, Conversation, User, QuestionsPayload, PlanTask, Resource, ContextUsage, SearchBudget } from './types'; +import type { AgentEvent, Message, Conversation, User, QuestionsPayload, PlanTask, Resource, ContextUsage, SearchBudget, Project } from './types'; import { MessageList } from './components/MessageList'; import { InputArea, type Mode } from './components/InputArea'; import { Sidebar } from './components/Sidebar'; @@ -17,6 +17,8 @@ import { RightPanel } from './components/RightPanel'; import { ReportDrawer } from './components/ReportDrawer'; import { AuthGuard } from './components/AuthGuard'; import { OnboardingModal } from './components/OnboardingModal'; +import { Terminal } from './components/Terminal'; +import { ProjectModal } from './components/ProjectModal'; import { SettingsPage } from './components/SettingsPage'; import { ProvidersSettings } from './components/settings/ProvidersSettings'; import { AgentSettings } from './components/settings/AgentSettings'; @@ -103,6 +105,10 @@ function ChatUI({ const [inputText, setInputText] = useState(''); const [computeNodes, setComputeNodes] = useState([]); const [activeCompute, setActiveCompute] = useState(null); + const [projects, setProjects] = useState([]); + const [activeProject, setActiveProject] = useState(null); + const [showProjectModal, setShowProjectModal] = useState(false); + const [terminalOpen, setTerminalOpen] = useState(false); // Ref to always have current conv UUID in SSE callback (avoids stale closure) const currentConvUuidRef = useRef(currentConvUuid); @@ -132,6 +138,15 @@ function ChatUI({ } }, []); + const loadProjects = useCallback(async () => { + try { + const data = await api.listProjects(); + setProjects(data.projects || []); + } catch { + setProjects([]); + } + }, []); + const loadActiveCompute = useCallback(async (uuid: string) => { try { const data = await api.getConversationCompute(uuid); @@ -144,7 +159,7 @@ function ChatUI({ // Initial load - load conversations and activate the correct one useEffect(() => { const init = async () => { - await loadComputeNodes(); + await Promise.all([loadComputeNodes(), loadProjects()]); const convs = await loadConversations(); // If URL has a conversation UUID, load it directly @@ -578,8 +593,12 @@ function ChatUI({ setShowProjectModal(true)} />
{/* RightPanel is fixed position, doesn't affect flex layout */} - setRightPanelOpen((v) => !v)} onViewReport={(r) => setViewingReport(r)} /> + setRightPanelOpen((v) => !v)} onViewReport={(r) => setViewingReport(r)} />
+ + {/* Terminal panel */} + setTerminalOpen((v) => !v)} + /> {viewingReport && setViewingReport(null)} />} + {showProjectModal && setShowProjectModal(false)} onCreate={(p) => { setProjects((prev) => [p, ...prev]); setActiveProject(p); }} />} ); } diff --git a/frontend/src/__tests__/RightPanel.test.tsx b/frontend/src/__tests__/RightPanel.test.tsx index 8b0b5b0..eb5ae61 100644 --- a/frontend/src/__tests__/RightPanel.test.tsx +++ b/frontend/src/__tests__/RightPanel.test.tsx @@ -33,6 +33,7 @@ describe('RightPanel', () => { contextUsage={mockContext} searchBudget={null} visible={false} + projectUuid={null} onToggle={vi.fn()} onViewReport={vi.fn()} /> @@ -48,6 +49,7 @@ describe('RightPanel', () => { contextUsage={null} searchBudget={null} visible={true} + projectUuid={null} onToggle={vi.fn()} onViewReport={vi.fn()} /> @@ -65,6 +67,7 @@ describe('RightPanel', () => { contextUsage={null} searchBudget={null} visible={true} + projectUuid={null} onToggle={vi.fn()} onViewReport={vi.fn()} /> @@ -80,6 +83,7 @@ describe('RightPanel', () => { contextUsage={null} searchBudget={null} visible={true} + projectUuid={null} onToggle={vi.fn()} onViewReport={vi.fn()} /> @@ -96,6 +100,7 @@ describe('RightPanel', () => { contextUsage={mockContext} searchBudget={null} visible={true} + projectUuid={null} onToggle={vi.fn()} onViewReport={vi.fn()} /> @@ -112,6 +117,7 @@ describe('RightPanel', () => { contextUsage={null} searchBudget={mockSearchBudget} visible={true} + projectUuid={null} onToggle={vi.fn()} onViewReport={vi.fn()} /> @@ -127,6 +133,7 @@ describe('RightPanel', () => { contextUsage={null} searchBudget={null} visible={true} + projectUuid={null} onToggle={vi.fn()} onViewReport={vi.fn()} /> @@ -144,6 +151,7 @@ describe('RightPanel', () => { contextUsage={null} searchBudget={null} visible={true} + projectUuid={null} onToggle={vi.fn()} onViewReport={vi.fn()} /> @@ -160,6 +168,7 @@ describe('RightPanel', () => { contextUsage={null} searchBudget={null} visible={false} + projectUuid={null} onToggle={vi.fn()} onViewReport={vi.fn()} /> @@ -175,6 +184,7 @@ describe('RightPanel', () => { contextUsage={null} searchBudget={null} visible={false} + projectUuid={null} onToggle={vi.fn()} onViewReport={vi.fn()} /> diff --git a/frontend/src/__tests__/Sidebar.test.tsx b/frontend/src/__tests__/Sidebar.test.tsx index c5fe47b..f31447f 100644 --- a/frontend/src/__tests__/Sidebar.test.tsx +++ b/frontend/src/__tests__/Sidebar.test.tsx @@ -51,9 +51,13 @@ describe('Sidebar', () => { currentUuid={null} user={mockUser} convStatuses={{}} + projects={[]} + activeProject={null} onSwitch={vi.fn()} onNew={vi.fn()} onDelete={vi.fn()} + onSelectProject={vi.fn()} + onNewProject={vi.fn()} /> ); @@ -68,9 +72,13 @@ describe('Sidebar', () => { currentUuid={null} user={mockUser} convStatuses={{}} + projects={[]} + activeProject={null} onSwitch={vi.fn()} onNew={vi.fn()} onDelete={vi.fn()} + onSelectProject={vi.fn()} + onNewProject={vi.fn()} /> ); @@ -86,9 +94,13 @@ describe('Sidebar', () => { currentUuid="conv-1" user={mockUser} convStatuses={{}} + projects={[]} + activeProject={null} onSwitch={vi.fn()} onNew={vi.fn()} onDelete={vi.fn()} + onSelectProject={vi.fn()} + onNewProject={vi.fn()} /> ); @@ -107,9 +119,13 @@ describe('Sidebar', () => { currentUuid={null} user={mockUser} convStatuses={{}} + projects={[]} + activeProject={null} onSwitch={onSwitch} onNew={vi.fn()} onDelete={vi.fn()} + onSelectProject={vi.fn()} + onNewProject={vi.fn()} /> ); @@ -125,9 +141,13 @@ describe('Sidebar', () => { currentUuid={null} user={mockUser} convStatuses={{}} + projects={[]} + activeProject={null} onSwitch={vi.fn()} onNew={vi.fn()} onDelete={vi.fn()} + onSelectProject={vi.fn()} + onNewProject={vi.fn()} /> ); @@ -142,9 +162,13 @@ describe('Sidebar', () => { currentUuid={null} user={mockUser} convStatuses={{}} + projects={[]} + activeProject={null} onSwitch={vi.fn()} onNew={vi.fn()} onDelete={vi.fn()} + onSelectProject={vi.fn()} + onNewProject={vi.fn()} /> ); @@ -159,9 +183,13 @@ describe('Sidebar', () => { currentUuid={null} user={mockUser} convStatuses={{}} + projects={[]} + activeProject={null} onSwitch={vi.fn()} onNew={vi.fn()} onDelete={vi.fn()} + onSelectProject={vi.fn()} + onNewProject={vi.fn()} /> ); @@ -177,9 +205,13 @@ describe('Sidebar', () => { currentUuid={null} user={mockUser} convStatuses={{}} + projects={[]} + activeProject={null} onSwitch={vi.fn()} onNew={vi.fn()} onDelete={vi.fn()} + onSelectProject={vi.fn()} + onNewProject={vi.fn()} /> ); @@ -198,9 +230,13 @@ describe('Sidebar', () => { currentUuid={null} user={mockUser} convStatuses={{}} + projects={[]} + activeProject={null} onSwitch={vi.fn()} onNew={onNew} onDelete={vi.fn()} + onSelectProject={vi.fn()} + onNewProject={vi.fn()} /> ); diff --git a/frontend/src/api.ts b/frontend/src/api.ts index 7d5a3ee..781fd78 100644 --- a/frontend/src/api.ts +++ b/frontend/src/api.ts @@ -116,6 +116,28 @@ export const api = { createKey: (body: Record) => post('/api/keys', body), deleteKey: (filename: string) => del(`/api/keys/${filename}`), + // Projects + listProjects: (includeArchived = false) => get(`/api/projects${includeArchived ? '?include_archived=true' : ''}`), + createProject: (name: string, description?: string) => post('/api/projects', { name, description }), + getProject: (uuid: string) => get(`/api/projects/${uuid}`), + updateProject: (uuid: string, body: Record) => put(`/api/projects/${uuid}`, body), + deleteProject: (uuid: string) => del(`/api/projects/${uuid}`), + listProjectConversations: (uuid: string) => get(`/api/projects/${uuid}/conversations`), + attachConversation: (projectUuid: string, convUuid: string) => + post(`/api/projects/${projectUuid}/attach/${convUuid}`, {}), + detachConversation: (projectUuid: string, convUuid: string) => + post(`/api/projects/${projectUuid}/detach/${convUuid}`, {}), + + // Project Files + listFiles: (projectUuid: string, path = '') => + get(`/api/projects/${projectUuid}/files${path ? `?path=${encodeURIComponent(path)}` : ''}`), + readFile: (projectUuid: string, filePath: string) => + get(`/api/projects/${projectUuid}/files/${encodeURIComponent(filePath)}`), + writeFile: (projectUuid: string, filePath: string, content: string) => + put(`/api/projects/${projectUuid}/files/${encodeURIComponent(filePath)}`, { content }), + deleteFile: (projectUuid: string, filePath: string) => + del(`/api/projects/${projectUuid}/files/${encodeURIComponent(filePath)}`), + // Compute Nodes getComputeNodes: () => get('/api/compute/nodes'), createComputeNode: (body: Record) => post('/api/compute/nodes', body), diff --git a/frontend/src/components/FileTree.tsx b/frontend/src/components/FileTree.tsx new file mode 100644 index 0000000..a892cc6 --- /dev/null +++ b/frontend/src/components/FileTree.tsx @@ -0,0 +1,303 @@ +import { useState, useCallback, useEffect } from 'react'; +import { + Folder, + FolderOpen, + FileText, + FileCode, + FileJson, + Image, + ChevronRight, + ChevronDown, + RefreshCw, + AlertCircle, +} from 'lucide-react'; +import { api } from '../api'; +import type { FileNode } from '../types'; + +interface Props { + projectUuid: string; + onFileSelect?: (path: string, content: string) => void; +} + +interface TreeNode extends FileNode { + children?: TreeNode[]; + loading?: boolean; + expanded?: boolean; +} + +const FILE_ICONS: Record = { + '.py': , + '.js': , + '.ts': , + '.tsx': , + '.json': , + '.md': , + '.txt': , + '.yaml': , + '.yml': , + '.png': , + '.jpg': , + '.svg': , +}; + +function getFileIcon(name: string, isDir: boolean): React.ReactNode { + if (isDir) return null; // handled by folder icons + const ext = name.includes('.') ? '.' + name.split('.').pop() : ''; + return FILE_ICONS[ext] || ; +} + +function formatSize(bytes: number | null): string { + if (bytes === null || bytes === undefined) return ''; + if (bytes < 1024) return `${bytes}B`; + if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)}K`; + return `${(bytes / (1024 * 1024)).toFixed(1)}M`; +} + +function TreeItem({ + node, + depth, + onToggle, + onSelect, +}: { + node: TreeNode; + depth: number; + onToggle: (path: string) => void; + onSelect: (path: string) => void; +}) { + return ( +
+ + + {/* Children */} + {node.is_dir && node.expanded && node.children && ( +
+ {node.children.map((child) => ( + + ))} + {node.children.length === 0 && ( +
+ (empty) +
+ )} +
+ )} +
+ ); +} + +export function FileTree({ projectUuid, onFileSelect }: Props) { + const [nodes, setNodes] = useState([]); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const [selectedFile, setSelectedFile] = useState(null); + const [fileContent, setFileContent] = useState(null); + const [fileLoading, setFileLoading] = useState(false); + + const loadDirectory = useCallback(async (path: string = ''): Promise => { + try { + const data = await api.listFiles(projectUuid, path); + return (data.entries || []).map((entry: FileNode) => ({ + ...entry, + expanded: false, + children: entry.is_dir ? undefined : undefined, + })); + } catch (err: any) { + setError(err.message); + return []; + } + }, [projectUuid]); + + // Initial load + useEffect(() => { + setLoading(true); + setError(null); + loadDirectory('').then((entries) => { + setNodes(entries); + setLoading(false); + }); + }, [projectUuid, loadDirectory]); + + const handleToggle = useCallback(async (path: string) => { + setNodes((prev) => { + const update = (items: TreeNode[]): TreeNode[] => + items.map((item) => { + if (item.path === path) { + if (item.expanded) { + return { ...item, expanded: false }; + } + return { ...item, loading: true, expanded: true }; + } + if (item.children) { + return { ...item, children: update(item.children) }; + } + return item; + }); + return update(prev); + }); + + // Load children + const children = await loadDirectory(path); + setNodes((prev) => { + const update = (items: TreeNode[]): TreeNode[] => + items.map((item) => { + if (item.path === path) { + return { ...item, loading: false, children }; + } + if (item.children) { + return { ...item, children: update(item.children) }; + } + return item; + }); + return update(prev); + }); + }, [loadDirectory]); + + const handleSelect = useCallback(async (path: string) => { + setSelectedFile(path); + setFileLoading(true); + setFileContent(null); + try { + const data = await api.readFile(projectUuid, path); + if (data.content !== undefined) { + setFileContent(data.content); + onFileSelect?.(path, data.content); + } + } catch { + setFileContent(null); + } + setFileLoading(false); + }, [projectUuid, onFileSelect]); + + const handleRefresh = useCallback(async () => { + setLoading(true); + setError(null); + const entries = await loadDirectory(''); + setNodes(entries); + setLoading(false); + }, [loadDirectory]); + + if (loading) { + return ( +
+ + Loading files... +
+ ); + } + + if (error) { + return ( +
+ + {error} +
+ ); + } + + return ( +
+ {/* Header */} +
+ Files + +
+ + {/* Tree */} +
+ {nodes.length === 0 ? ( +
No files yet
+ ) : ( + nodes.map((node) => ( + + )) + )} +
+ + {/* Selected file preview */} + {selectedFile && ( +
+
+ {selectedFile} + +
+ {fileLoading ? ( +
Loading...
+ ) : fileContent !== null ? ( +
+              {fileContent.length > 5000 ? fileContent.slice(0, 5000) + '\n...' : fileContent}
+            
+ ) : ( +
Binary file
+ )} +
+ )} +
+ ); +} diff --git a/frontend/src/components/ProjectModal.tsx b/frontend/src/components/ProjectModal.tsx new file mode 100644 index 0000000..7a4fbe7 --- /dev/null +++ b/frontend/src/components/ProjectModal.tsx @@ -0,0 +1,104 @@ +import { useState } from 'react'; +import { X, FolderPlus } from 'lucide-react'; +import { api } from '../api'; +import type { Project } from '../types'; + +interface Props { + onClose: () => void; + onCreate: (project: Project) => void; +} + +export function ProjectModal({ onClose, onCreate }: Props) { + const [name, setName] = useState(''); + const [description, setDescription] = useState(''); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + + const handleCreate = async () => { + if (!name.trim()) return; + setLoading(true); + setError(null); + try { + const data = await api.createProject(name.trim(), description.trim() || undefined); + onCreate(data.project); + onClose(); + } catch (err: any) { + setError(err.message || 'Failed to create project'); + } finally { + setLoading(false); + } + }; + + return ( +
+
+ {/* Header */} +
+
+ +

New Project

+
+ +
+ + {/* Form */} +
+
+ + setName(e.target.value)} + placeholder="e.g., Attention Mechanism Survey" + className="w-full px-3 py-2 bg-bg border border-border rounded-lg text-text placeholder-text-dim focus:border-primary focus:outline-none transition-colors" + autoFocus + onKeyDown={(e) => e.key === 'Enter' && handleCreate()} + /> +

+ A workspace directory will be created for this project +

+
+ +
+ +