diff --git a/.keys/.gitignore b/.keys/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/.keys/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/backend/configs/prompts/system_prompt.yaml b/backend/configs/prompts/system_prompt.yaml index f73a71c..bd340c8 100644 --- a/backend/configs/prompts/system_prompt.yaml +++ b/backend/configs/prompts/system_prompt.yaml @@ -145,3 +145,13 @@ prompt: | - Date: {{ date }} - User: {{ username }} - Mode: {{ mode }} + + {% if compute_env %} + {{ compute_env }} + {% endif %} + + # Compute Planning + When starting tasks that require significant computation (training models, processing large datasets, etc.): + 1. Use `compute_plan` to verify the active node meets requirements + 2. If not, use `compute_select` to switch to a suitable node + 3. Always `sandbox_probe` before executing code on a node for the first time diff --git a/backend/openmlr/agent/prompts.py b/backend/openmlr/agent/prompts.py index 6004fcd..8528cf7 100644 --- a/backend/openmlr/agent/prompts.py +++ b/backend/openmlr/agent/prompts.py @@ -25,6 +25,7 @@ def build_system_prompt( mode: str = "general", username: str = "user", sandbox_info: str = "none", + compute_env: str = "", config: AgentConfig | None = None, ) -> str: """Build the full system prompt from YAML template.""" @@ -58,6 +59,7 @@ def build_system_prompt( timezone="UTC", username=username, sandbox_info=sandbox_info, + compute_env=compute_env, ) return prompt diff --git a/backend/openmlr/app.py b/backend/openmlr/app.py index d034841..b150bfc 100644 --- a/backend/openmlr/app.py +++ b/backend/openmlr/app.py @@ -71,13 +71,17 @@ async def lifespan(app: FastAPI): # ── API routers ────────────────────────────────────────── from .auth.router import router as auth_router from .routes.agent import router as agent_router +from .routes.compute import router as compute_router from .routes.health import router as health_router +from .routes.keys import router as keys_router from 
.routes.settings import router as settings_router app.include_router(auth_router) app.include_router(agent_router) app.include_router(settings_router) app.include_router(health_router) +app.include_router(keys_router) +app.include_router(compute_router) # ── Global error handler ──────────────────────────────── diff --git a/backend/openmlr/celery_app.py b/backend/openmlr/celery_app.py index 91b43ab..7190264 100644 --- a/backend/openmlr/celery_app.py +++ b/backend/openmlr/celery_app.py @@ -13,7 +13,7 @@ "openmlr", broker=REDIS_URL, backend=REDIS_URL, - include=["openmlr.tasks.agent_tasks"], + include=["openmlr.tasks.agent_tasks", "openmlr.tasks.compute_tasks"], ) # Celery configuration @@ -43,6 +43,18 @@ # Default queue task_default_queue="default", + + # Beat schedule for periodic tasks + beat_schedule={ + "health-check-all-nodes": { + "task": "openmlr.tasks.compute_tasks.health_check_all_nodes", + "schedule": 300.0, # Every 5 minutes + }, + "cleanup-old-workspaces": { + "task": "openmlr.tasks.compute_tasks.cleanup_old_workspaces", + "schedule": 86400.0, # Every 24 hours + }, + }, ) diff --git a/backend/openmlr/compute/__init__.py b/backend/openmlr/compute/__init__.py new file mode 100644 index 0000000..ff1c488 --- /dev/null +++ b/backend/openmlr/compute/__init__.py @@ -0,0 +1,6 @@ +from .capabilities import ComputeCapabilities, GPUInfo +from .manager import ComputeManager +from .probe import probe_sandbox +from .workspace import WorkspaceManager + +__all__ = ["ComputeCapabilities", "ComputeManager", "GPUInfo", "probe_sandbox", "WorkspaceManager"] diff --git a/backend/openmlr/compute/capabilities.py b/backend/openmlr/compute/capabilities.py new file mode 100644 index 0000000..21c36e8 --- /dev/null +++ b/backend/openmlr/compute/capabilities.py @@ -0,0 +1,85 @@ +"""Compute capability discovery and planning.""" + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class GPUInfo: + """Information about a GPU.""" + model: str = "" + 
@dataclass
class GPUInfo:
    """Description of a single GPU on a compute node."""

    model: str = ""
    vram_gb: float = 0.0
    cuda_version: str = ""
    driver_version: str = ""


@dataclass
class ComputeCapabilities:
    """Comprehensive capabilities of a compute node.

    Every field has a default, so a partially probed node still yields a
    usable object. ``to_dict`` / ``from_dict`` round-trip the object through
    plain JSON-compatible types.
    """

    # Platform
    platform: str = "unknown"
    cpu_cores: int = 0
    cpu_arch: str = "unknown"

    # Memory
    total_ram_gb: float = 0.0
    available_ram_gb: float = 0.0

    # Storage
    total_disk_gb: float = 0.0
    available_disk_gb: float = 0.0

    # GPU
    gpu_available: bool = False
    gpu_count: int = 0
    gpu_info: list[GPUInfo] = field(default_factory=list)

    # Software
    python_versions: list[str] = field(default_factory=list)
    docker_available: bool = False
    conda_envs: list[str] = field(default_factory=list)
    installed_packages: list[str] = field(default_factory=list)

    # Network
    has_internet: bool = True  # optimistic default until a probe says otherwise
    latency_ms: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict (nested ``GPUInfo`` objects become dicts).

        Keys follow field declaration order, so new fields are picked up
        automatically instead of having to be mirrored by hand.
        """
        from dataclasses import asdict

        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "ComputeCapabilities":
        """Deserialize from a dict produced by :meth:`to_dict`.

        Unknown keys are ignored, both at the top level and inside each
        ``gpu_info`` entry, so payloads written by a newer version still load.
        An empty or missing payload yields an all-defaults instance.
        """
        from dataclasses import fields

        caps = cls()
        if not data:
            return caps

        # Only real dataclass fields may be assigned; a plain hasattr() check
        # would let a key such as "to_dict" overwrite a method.
        known = {f.name for f in fields(cls)}
        gpu_fields = {f.name for f in fields(GPUInfo)}

        for key, value in data.items():
            if key not in known:
                continue
            if key == "gpu_info":
                # A null/empty value must still produce a list, otherwise
                # to_dict() would later fail iterating over None.
                caps.gpu_info = [
                    GPUInfo(**{k: v for k, v in entry.items() if k in gpu_fields})
                    for entry in (value or [])
                ]
            else:
                setattr(caps, key, value)
        return caps
"""Compute Node Manager — registry, validation, and lifecycle."""

from pathlib import Path


class ComputeManager:
    """High-level operations for compute node management."""

    def __init__(self, key_manager):
        # key_manager must provide key_exists(filename); it is used to check
        # that an SSH node refers to a key that is actually present.
        self.key_manager = key_manager

    def validate_node_config(self, node_type: str, config: dict) -> tuple[bool, str]:
        """Validate a compute node configuration. Pure check, no side effects.

        Args:
            node_type: One of "ssh", "local" or "modal".
            config: Type-specific settings. ``None`` is treated as an empty
                config so that a missing payload is reported as a validation
                failure rather than raised as an AttributeError.

        Returns:
            ``(True, "")`` when valid, otherwise ``(False, reason)``.
        """
        validators = {
            "ssh": self._validate_ssh_config,
            "local": self._validate_local_config,
            "modal": self._validate_modal_config,
        }
        validator = validators.get(node_type)
        if validator is None:
            return False, f"Unknown node type: {node_type}"

        if config is None:
            config = {}
        if not isinstance(config, dict):
            return False, "Node config must be a JSON object"
        return validator(config)

    def _validate_ssh_config(self, config: dict) -> tuple[bool, str]:
        """Require host and username; if a key file is named it must exist."""
        for name in ("host", "username"):
            if not config.get(name):
                return False, f"SSH config requires '{name}'"

        key_filename = config.get("key_filename")
        if key_filename and not self.key_manager.key_exists(key_filename):
            return False, f"SSH key not found: {key_filename}"

        return True, ""

    def _validate_local_config(self, config: dict) -> tuple[bool, str]:
        """Reject a workdir that exists but is not a directory."""
        workdir = config.get("workdir", "")
        if workdir:
            path = Path(workdir).expanduser()
            # Only validate — don't create directories as a side effect
            if path.exists() and not path.is_dir():
                return False, f"Path exists but is not a directory: {path}"
        return True, ""

    def _validate_modal_config(self, config: dict) -> tuple[bool, str]:
        """Modal nodes currently have no required settings."""
        return True, ""
def probe_sandbox(sandbox) -> ComputeCapabilities: + """Deep capability discovery for any sandbox implementation.""" + caps = ComputeCapabilities() + start = time.monotonic() + + # Platform + result = await sandbox.execute("uname -s -r 2>/dev/null || echo 'unknown'", timeout=5) + if result.success: + caps.platform = result.output.strip() + + # CPU cores and architecture + result = await sandbox.execute("nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo '0'", timeout=5) + if result.success: + try: + caps.cpu_cores = int(result.output.strip()) + except ValueError: + pass + + result = await sandbox.execute("uname -m 2>/dev/null || echo 'unknown'", timeout=5) + if result.success: + caps.cpu_arch = result.output.strip() + + # RAM (Linux) + result = await sandbox.execute( + "free -g 2>/dev/null | grep Mem | awk '{print $2, $7}' || " + "echo '0 0'", + timeout=5, + ) + if result.success: + parts = result.output.strip().split() + if len(parts) >= 2: + try: + caps.total_ram_gb = float(parts[0]) + caps.available_ram_gb = float(parts[1]) + except ValueError: + pass + + # Disk + result = await sandbox.execute( + "df -BG / 2>/dev/null | tail -1 | awk '{print $2, $4}' || echo '0 0'", + timeout=5, + ) + if result.success: + parts = result.output.strip().split() + if len(parts) >= 2: + try: + caps.total_disk_gb = float(parts[0].replace("G", "")) + caps.available_disk_gb = float(parts[1].replace("G", "")) + except ValueError: + pass + + # GPU — query model, memory, driver; then get CUDA version separately + result = await sandbox.execute( + "nvidia-smi --query-gpu=name,memory.total,driver_version " + "--format=csv,noheader 2>/dev/null || echo ''", + timeout=10, + ) + if result.success and result.output.strip(): + lines = [ln.strip() for ln in result.output.strip().split("\n") if ln.strip()] + caps.gpu_count = len(lines) + caps.gpu_available = caps.gpu_count > 0 + + # Get CUDA toolkit version + cuda_ver = "" + cuda_result = await sandbox.execute( + "nvidia-smi 2>/dev/null | 
grep 'CUDA Version' | awk '{print $9}'", + timeout=5, + ) + if cuda_result.success and cuda_result.output.strip(): + cuda_ver = cuda_result.output.strip() + + for line in lines: + parts = [p.strip() for p in line.split(",")] + if len(parts) >= 3: + gpu = GPUInfo( + model=parts[0], + vram_gb=_parse_vram(parts[1]), + cuda_version=cuda_ver, + driver_version=parts[2], + ) + caps.gpu_info.append(gpu) + + # Python versions + result = await sandbox.execute( + "python3 --version 2>/dev/null; ls /usr/bin/python* 2>/dev/null || true", + timeout=5, + ) + if result.success: + versions = [] + for line in result.output.strip().split("\n"): + line = line.strip() + if line.startswith("Python "): + versions.append(line.replace("Python ", "")) + elif "/python" in line and not line.endswith("*"): + # Extract version from path like /usr/bin/python3.11 + ver = line.split("/")[-1].replace("python", "") + if ver and ver not in versions: + versions.append(ver) + caps.python_versions = versions + + # Docker + result = await sandbox.execute( + "docker info >/dev/null 2>&1 && echo 'DOCKER_OK' || echo 'DOCKER_FAIL'", + timeout=5, + ) + if result.success and "DOCKER_OK" in result.output: + caps.docker_available = True + + # Conda envs + result = await sandbox.execute( + "conda env list 2>/dev/null | grep -v '^#' | awk '{print $1}' || true", + timeout=5, + ) + if result.success: + envs = [ln.strip() for ln in result.output.strip().split("\n") if ln.strip()] + caps.conda_envs = envs + + # Key packages + result = await sandbox.execute( + "pip list --format=freeze 2>/dev/null | head -50 || true", + timeout=10, + ) + if result.success: + caps.installed_packages = [ + line.strip() for line in result.output.strip().split("\n") + if line.strip() and "==" in line + ] + + # Internet connectivity + result = await sandbox.execute( + "curl -s -o /dev/null -w '%{http_code}' --max-time 5 https://pypi.org/simple/ 2>/dev/null || echo '000'", + timeout=10, + ) + if result.success and result.output.strip() == 
"200": + caps.has_internet = True + else: + # Fallback ping + result = await sandbox.execute( + "ping -c 1 -W 3 8.8.8.8 2>/dev/null || true", + timeout=10, + ) + caps.has_internet = result.success and "1 received" in result.output + + caps.latency_ms = (time.monotonic() - start) * 1000 + return caps + + +def _parse_vram(vram_str: str) -> float: + """Parse VRAM string like '24576 MiB' or '24 GB' to GB.""" + vram_str = vram_str.strip().lower() + try: + if "mib" in vram_str: + return float(vram_str.replace("mib", "").strip()) / 1024 + elif "gib" in vram_str: + return float(vram_str.replace("gib", "").strip()) + elif "gb" in vram_str: + return float(vram_str.replace("gb", "").strip()) + elif "mb" in vram_str: + return float(vram_str.replace("mb", "").strip()) / 1024 + else: + return float(vram_str) + except ValueError: + return 0.0 diff --git a/backend/openmlr/compute/workspace.py b/backend/openmlr/compute/workspace.py new file mode 100644 index 0000000..9517bf3 --- /dev/null +++ b/backend/openmlr/compute/workspace.py @@ -0,0 +1,166 @@ +"""Workspace Manager — per-conversation filesystem isolation.""" + +import os +import shutil +import tarfile +from datetime import UTC, datetime +from pathlib import Path + + +class WorkspaceManager: + """Manages isolated workspace directories for each conversation.""" + + def __init__(self, base_dir: str | Path = None): + self.base_dir = Path(base_dir) if base_dir else Path.home() / ".openmlr" + self.workspace_dir = self.base_dir / "workspaces" + self.archive_dir = self.base_dir / "archive" + self._ensure_dirs() + + def _ensure_dirs(self) -> None: + """Ensure workspace and archive directories exist.""" + self.workspace_dir.mkdir(parents=True, exist_ok=True) + self.archive_dir.mkdir(parents=True, exist_ok=True) + + def get_workspace_path(self, conversation_uuid: str) -> Path: + """Get the workspace directory for a conversation.""" + return self.workspace_dir / f"workspace-{conversation_uuid}" + + def create_workspace(self, 
conversation_uuid: str) -> Path: + """Create a new workspace directory for a conversation.""" + path = self.get_workspace_path(conversation_uuid) + path.mkdir(parents=True, exist_ok=True) + # Create standard subdirectories + for subdir in ["data", "models", "code", "outputs"]: + (path / subdir).mkdir(exist_ok=True) + # Create meta directory (hidden from agent) + (path / ".openmlr-meta").mkdir(exist_ok=True) + return path + + def workspace_exists(self, conversation_uuid: str) -> bool: + """Check if a workspace exists.""" + return self.get_workspace_path(conversation_uuid).exists() + + def archive_workspace(self, conversation_uuid: str) -> Path | None: + """Archive a workspace before deletion. Returns archive path.""" + path = self.get_workspace_path(conversation_uuid) + if not path.exists(): + return None + + timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S") + archive_name = f"workspace-{conversation_uuid}-{timestamp}.tar.gz" + archive_path = self.archive_dir / archive_name + + with tarfile.open(archive_path, "w:gz") as tar: + tar.add(path, arcname=path.name) + + return archive_path + + def delete_workspace(self, conversation_uuid: str, archive: bool = True) -> bool: + """Delete a workspace. 
If archive=True, archive it first.""" + path = self.get_workspace_path(conversation_uuid) + if not path.exists(): + return False + + if archive: + self.archive_workspace(conversation_uuid) + + shutil.rmtree(path) + return True + + def get_workspace_size(self, conversation_uuid: str) -> int: + """Get total size of a workspace in bytes.""" + path = self.get_workspace_path(conversation_uuid) + if not path.exists(): + return 0 + + total = 0 + for dirpath, _, filenames in os.walk(path): + for f in filenames: + fp = Path(dirpath) / f + if fp.exists(): + total += fp.stat().st_size + return total + + def list_workspaces(self) -> list[dict]: + """List all workspaces with metadata.""" + workspaces = [] + for path in self.workspace_dir.glob("workspace-*"): + if path.is_dir(): + uuid = path.name.replace("workspace-", "") + size = self.get_workspace_size(uuid) + workspaces.append({ + "uuid": uuid, + "path": str(path), + "size_bytes": size, + "created": datetime.fromtimestamp(path.stat().st_ctime, UTC).isoformat(), + }) + return sorted(workspaces, key=lambda x: x["created"], reverse=True) + + def cleanup_archives(self, max_age_days: int = 30, max_count: int = 100) -> dict: + """Clean up old workspace archives. 
+ + Args: + max_age_days: Delete archives older than this many days + max_count: Keep at most this many archives, delete oldest first + + Returns: + Dict with deleted count and freed bytes + """ + deleted = 0 + freed_bytes = 0 + + # Get all archives sorted by modification time (oldest first) + archives = [] + for path in self.archive_dir.glob("workspace-*.tar.gz"): + if path.is_file(): + mtime = datetime.fromtimestamp(path.stat().st_mtime, UTC) + archives.append({"path": path, "mtime": mtime, "size": path.stat().st_size}) + + archives.sort(key=lambda x: x["mtime"]) + + # Delete old archives + now = datetime.now(UTC) + for archive in archives: + age_days = (now - archive["mtime"]).days + if age_days > max_age_days: + freed_bytes += archive["size"] + archive["path"].unlink() + deleted += 1 + + # Delete excess archives (oldest first) + remaining = [a for a in archives if a["path"].exists()] + while len(remaining) > max_count: + oldest = remaining.pop(0) + freed_bytes += oldest["size"] + oldest["path"].unlink() + deleted += 1 + + return {"deleted": deleted, "freed_bytes": freed_bytes} + + def cleanup_workspaces(self, conversation_uuids: list[str], archive: bool = True) -> dict: + """Clean up workspaces for deleted conversations. 
+ + Args: + conversation_uuids: List of conversation UUIDs to keep + archive: Whether to archive before deleting + + Returns: + Dict with deleted count and freed bytes + """ + deleted = 0 + freed_bytes = 0 + keep_set = set(conversation_uuids) + + for path in self.workspace_dir.glob("workspace-*"): + if not path.is_dir(): + continue + uuid = path.name.replace("workspace-", "") + if uuid not in keep_set: + size = self.get_workspace_size(uuid) + if archive: + self.archive_workspace(uuid) + shutil.rmtree(path) + freed_bytes += size + deleted += 1 + + return {"deleted": deleted, "freed_bytes": freed_bytes} diff --git a/backend/openmlr/db/models.py b/backend/openmlr/db/models.py index bb0cd93..ef40f74 100644 --- a/backend/openmlr/db/models.py +++ b/backend/openmlr/db/models.py @@ -39,6 +39,8 @@ class User(Base): settings = relationship("UserSetting", back_populates="user", cascade="all, delete-orphan") conversations = relationship("Conversation", back_populates="user", cascade="all, delete-orphan") sandbox_configs = relationship("SandboxConfig", back_populates="user", cascade="all, delete-orphan") + ssh_keys = relationship("SSHKey", back_populates="user", cascade="all, delete-orphan") + compute_nodes = relationship("ComputeNode", back_populates="user", cascade="all, delete-orphan") research_corpus = relationship("ResearchCorpus", back_populates="user", cascade="all, delete-orphan") writing_projects = relationship("WritingProject", back_populates="user", cascade="all, delete-orphan") @@ -111,6 +113,49 @@ class SandboxConfig(Base): user = relationship("User", back_populates="sandbox_configs") +class SSHKey(Base): + __tablename__ = "ssh_keys" + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + filename = Column(String(255), nullable=False) + fingerprint = Column(String(255), nullable=False) + algorithm = Column(String(50), nullable=False) + public_key = Column(Text, nullable=False) + comment = 
Column(Text, nullable=True) + created_at = Column(DateTime(timezone=True), default=_utcnow, nullable=False) + + user = relationship("User", back_populates="ssh_keys") + __table_args__ = ( + # Unique constraint on (user_id, filename) + {}, + ) + + +class ComputeNode(Base): + __tablename__ = "compute_nodes" + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + name = Column(String(100), nullable=False) + type = Column(String(20), nullable=False) # local, ssh, modal + config = Column(JSON, nullable=False, default=dict) + capabilities = Column(JSON, nullable=True) + health_status = Column(String(20), default="unknown", nullable=False) + last_probed_at = Column(DateTime(timezone=True), nullable=True) + last_seen_at = Column(DateTime(timezone=True), nullable=True) + is_default = Column(Boolean, default=False) + priority = Column(Integer, default=0, nullable=False) + created_at = Column(DateTime(timezone=True), default=_utcnow, nullable=False) + updated_at = Column(DateTime(timezone=True), default=_utcnow, onupdate=_utcnow, nullable=False) + + user = relationship("User", back_populates="compute_nodes") + __table_args__ = ( + # Unique constraint on (user_id, name) + {}, + ) + + class ResearchCorpus(Base): __tablename__ = "research_corpus" diff --git a/backend/openmlr/db/operations.py b/backend/openmlr/db/operations.py index 663a440..2ea72a1 100644 --- a/backend/openmlr/db/operations.py +++ b/backend/openmlr/db/operations.py @@ -7,10 +7,12 @@ from .models import ( AgentJob, + ComputeNode, Conversation, ConversationResource, ConversationTask, Message, + SSHKey, UserSetting, ) @@ -77,6 +79,13 @@ async def update_conversation_model(db: AsyncSession, conv_id: int, model: str): await db.commit() +async def update_conversation_extra(db: AsyncSession, conv_id: int, extra: dict): + await db.execute( + update(Conversation).where(Conversation.id == conv_id).values(extra=extra) + ) + await db.commit() + + 
async def delete_ssh_key(db: AsyncSession, user_id: int, filename: str) -> bool:
    """Delete the SSH key row identified by ``(user_id, filename)``.

    Returns:
        True when a matching row was found and deleted, False when the
        user has no key with that filename.
    """
    lookup = select(SSHKey).where(SSHKey.user_id == user_id, SSHKey.filename == filename)
    record = (await db.execute(lookup)).scalar_one_or_none()
    if record:
        await db.delete(record)
        await db.commit()
        return True
    return False
async def update_compute_node(
    db: AsyncSession, node_id: int, user_id: int, **kwargs,
) -> ComputeNode | None:
    """Apply attribute updates to a compute node owned by ``user_id``.

    Keyword arguments naming an attribute that exists on the node are
    assigned; unknown names are ignored.

    Returns:
        The refreshed node, or None when no node with that id belongs to
        the user.
    """
    owned = select(ComputeNode).where(ComputeNode.id == node_id, ComputeNode.user_id == user_id)
    target = (await db.execute(owned)).scalar_one_or_none()
    if target:
        for attr, new_value in kwargs.items():
            if hasattr(target, attr):
                setattr(target, attr, new_value)
        await db.commit()
        await db.refresh(target)
        return target
    return None
update(ComputeNode) + .where(ComputeNode.id == node_id, ComputeNode.user_id == user_id) + .values(is_default=True) + ) + await db.commit() + + +async def get_default_compute_node(db: AsyncSession, user_id: int) -> ComputeNode | None: + result = await db.execute( + select(ComputeNode).where( + ComputeNode.user_id == user_id, + ComputeNode.is_default.is_(True), + ) + ) + return result.scalar_one_or_none() diff --git a/backend/openmlr/keys/__init__.py b/backend/openmlr/keys/__init__.py new file mode 100644 index 0000000..4e02442 --- /dev/null +++ b/backend/openmlr/keys/__init__.py @@ -0,0 +1,3 @@ +from .manager import KeyManager + +__all__ = ["KeyManager"] diff --git a/backend/openmlr/keys/manager.py b/backend/openmlr/keys/manager.py new file mode 100644 index 0000000..a7798d4 --- /dev/null +++ b/backend/openmlr/keys/manager.py @@ -0,0 +1,169 @@ +"""SSH Key Asset Manager — handles .keys/ directory lifecycle.""" + +import os +import stat +from pathlib import Path + +from cryptography.hazmat.primitives import serialization as crypto_serialization +from cryptography.hazmat.primitives.asymmetric import ed25519, rsa + + +class KeyManager: + """Manages SSH private keys stored in a dedicated directory.""" + + def __init__(self, keys_dir: str | Path = None): + self.keys_dir = Path(keys_dir) if keys_dir else Path(__file__).parent.parent.parent.parent.parent / ".keys" + self._ensure_dir() + + def _ensure_dir(self) -> None: + """Ensure .keys/ directory exists with correct permissions.""" + self.keys_dir.mkdir(parents=True, exist_ok=True) + # Set directory permissions to 0o700 (owner read/write/execute only) + os.chmod(self.keys_dir, 0o700) + + def list_keys(self) -> list[dict]: + """List all key files (metadata only, no private content).""" + keys = [] + for path in sorted(self.keys_dir.glob("id_*")): + if path.suffix == ".pub": + continue + pub_path = path.with_suffix(path.suffix + ".pub") + keys.append({ + "filename": path.name, + "has_public": pub_path.exists(), + 
"size_bytes": path.stat().st_size, + }) + return keys + + def key_exists(self, filename: str) -> bool: + """Check if a key file exists.""" + return (self.keys_dir / filename).exists() + + def get_key_path(self, filename: str) -> Path: + """Get the absolute path to a key file.""" + return self.keys_dir / filename + + def write_key(self, filename: str, private_key_pem: str | bytes) -> Path: + """Write a private key to disk with restrictive permissions.""" + key_path = self.keys_dir / filename + if isinstance(private_key_pem, str): + private_key_pem = private_key_pem.encode("utf-8") + + key_path.write_bytes(private_key_pem) + # Set file permissions to 0o600 (owner read/write only) + os.chmod(key_path, stat.S_IRUSR | stat.S_IWUSR) + return key_path + + def read_key(self, filename: str) -> str: + """Read a private key from disk. Use sparingly.""" + key_path = self.keys_dir / filename + if not key_path.exists(): + raise FileNotFoundError(f"Key not found: {filename}") + return key_path.read_text("utf-8") + + def delete_key(self, filename: str) -> bool: + """Delete a key pair from disk.""" + key_path = self.keys_dir / filename + pub_path = key_path.with_suffix(key_path.suffix + ".pub") + deleted = False + if key_path.exists(): + key_path.unlink() + deleted = True + if pub_path.exists(): + pub_path.unlink() + deleted = True + return deleted + + def generate_key_pair(self, filename: str, algorithm: str = "ed25519", comment: str = "") -> tuple[Path, Path]: + """Generate a new SSH key pair and write to disk.""" + key_path = self.keys_dir / filename + pub_path = key_path.with_suffix(key_path.suffix + ".pub") + + if algorithm == "ed25519": + private_key = ed25519.Ed25519PrivateKey.generate() + private_pem = private_key.private_bytes( + encoding=crypto_serialization.Encoding.PEM, + format=crypto_serialization.PrivateFormat.OpenSSH, + encryption_algorithm=crypto_serialization.NoEncryption(), + ) + public_bytes = private_key.public_key().public_bytes( + 
encoding=crypto_serialization.Encoding.OpenSSH, + format=crypto_serialization.PublicFormat.OpenSSH, + ) + elif algorithm == "rsa": + private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096) + private_pem = private_key.private_bytes( + encoding=crypto_serialization.Encoding.PEM, + format=crypto_serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=crypto_serialization.NoEncryption(), + ) + public_bytes = private_key.public_key().public_bytes( + encoding=crypto_serialization.Encoding.OpenSSH, + format=crypto_serialization.PublicFormat.OpenSSH, + ) + else: + raise ValueError(f"Unsupported algorithm: {algorithm}. Use 'ed25519' or 'rsa'.") + + # Write private key with 0o600 + key_path.write_bytes(private_pem) + os.chmod(key_path, stat.S_IRUSR | stat.S_IWUSR) + + # Write public key with 0o644 + pub_pem = public_bytes + (f" {comment}".encode() if comment else b"") + pub_path.write_bytes(pub_pem) + os.chmod(pub_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH) + + return key_path, pub_path + + def validate_key(self, private_key_pem: str | bytes) -> dict: + """Validate an SSH private key and return metadata.""" + if isinstance(private_key_pem, str): + private_key_pem = private_key_pem.encode("utf-8") + + # Try to load as OpenSSH format + try: + key = crypto_serialization.load_ssh_private_key(private_key_pem, password=None) + except Exception: + # Try PEM format + try: + key = crypto_serialization.load_pem_private_key(private_key_pem, password=None) + except Exception as e: + raise ValueError(f"Invalid private key: {e}") + + # Determine algorithm + key_type = type(key).__name__.lower() + if "ed25519" in key_type: + algorithm = "ssh-ed25519" + elif "rsa" in key_type: + algorithm = "ssh-rsa" + else: + algorithm = key_type + + # Generate public key for fingerprint + public_key = key.public_key() + public_bytes = public_key.public_bytes( + encoding=crypto_serialization.Encoding.OpenSSH, + 
format=crypto_serialization.PublicFormat.OpenSSH, + ) + + # Compute SHA256 fingerprint matching ssh-keygen format: + # The fingerprint is the SHA256 of the raw key blob (base64-decoded + # portion of the OpenSSH public key line), base64-encoded. + import base64 + import hashlib + + pub_line = public_bytes.decode("utf-8").strip() + # OpenSSH format: "ssh-ed25519 AAAA... comment" + parts = pub_line.split() + if len(parts) >= 2: + key_blob = base64.b64decode(parts[1]) + else: + key_blob = public_bytes + raw_hash = hashlib.sha256(key_blob).digest() + fingerprint = base64.b64encode(raw_hash).decode("ascii").rstrip("=") + + return { + "algorithm": algorithm, + "fingerprint": f"SHA256:{fingerprint}", + "public_key": pub_line, + } diff --git a/backend/openmlr/routes/agent.py b/backend/openmlr/routes/agent.py index 5afdfd0..bc17185 100644 --- a/backend/openmlr/routes/agent.py +++ b/backend/openmlr/routes/agent.py @@ -207,6 +207,128 @@ async def switch_conversation( return {"ok": True} +# ── Per-Conversation Compute ───────────────────────────── + +@router.get("/conversations/{uuid}/compute") +async def get_conversation_compute( + uuid: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Get the active compute node for a conversation.""" + conv = await _get_conv_or_404(db, uuid, user.id) + + # Check conversation override first + if conv.extra and conv.extra.get("compute_node_id"): + node = await ops.get_compute_node_by_id(db, conv.extra["compute_node_id"], user.id) + if node: + return { + "node": { + "id": node.id, + "name": node.name, + "type": node.type, + }, + "source": "conversation", + } + + # Fall back to user's default + default_node = await ops.get_default_compute_node(db, user.id) + if default_node: + return { + "node": { + "id": default_node.id, + "name": default_node.name, + "type": default_node.type, + }, + "source": "default", + } + + return {"node": None, "source": None} + + 
@router.post("/conversations/{uuid}/compute")
async def set_conversation_compute(
    uuid: str,
    request: Request,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Bind a compute node to a conversation.

    Body: ``{"node_id": <int | null>}``. A null/missing ``node_id`` clears
    the per-conversation override. Otherwise the node must belong to the
    caller; its id and name are stored in the conversation's ``extra`` and,
    if the conversation has an active session, that session's sandbox and
    tool router are rebuilt against the new node.

    Raises:
        HTTPException: 400 for a malformed body or non-integer ``node_id``,
            404 when the node does not exist for this user.
    """
    try:
        body = await request.json()
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    node_id = body.get("node_id")

    conv = await _get_conv_or_404(db, uuid, user.id)

    if node_id is None:
        # Clear override
        extra = conv.extra or {}
        extra.pop("compute_node_id", None)
        extra.pop("compute_node_name", None)
        await ops.update_conversation_extra(db, conv.id, extra)
        return {"ok": True, "node": None}

    # Reject values that cannot be a primary key before they reach the query.
    if isinstance(node_id, bool):
        raise HTTPException(status_code=400, detail="'node_id' must be an integer")
    try:
        node_id = int(node_id)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail="'node_id' must be an integer") from exc

    # Validate node exists and belongs to user
    node = await ops.get_compute_node_by_id(db, node_id, user.id)
    if not node:
        raise HTTPException(status_code=404, detail="Compute node not found")

    extra = conv.extra or {}
    extra["compute_node_id"] = node.id
    extra["compute_node_name"] = node.name
    await ops.update_conversation_extra(db, conv.id, extra)

    # Update active session if it exists — must rebuild tool_router
    # since sandbox tools capture sandbox_manager in closures
    sm = _sm(request)
    active = sm.get_session(conv.id)
    if active:
        from ..compute import WorkspaceManager
        from ..sandbox.manager import SandboxManager
        from ..tools.registry import create_tool_router

        workspace_manager = WorkspaceManager()
        sandbox_manager = SandboxManager(
            workspace_manager=workspace_manager,
            conversation_uuid=conv.uuid,
        )
        # The replacement is created first so a failure here leaves the
        # session's current sandbox untouched.
        await sandbox_manager.create(node.type, node.config)
        # Destroy old sandbox (best effort: a broken old sandbox must not
        # block switching to the new one)
        try:
            await active.sandbox_manager.destroy()
        except Exception:
            pass
        active.sandbox_manager = sandbox_manager
        # Rebuild tool router with new sandbox_manager
        active.tool_router = create_tool_router(sandbox_manager)
        # NOTE(review): `db` is this request's session — confirm the tool
        # router does not keep using it after the request finishes.
        active.tool_router.set_context(user_id=user.id, db=db)

    return {
        "ok": True,
        "node": {
            "id": node.id,
            "name": node.name,
            "type": node.type,
        },
    }
clear_conversation_compute( + uuid: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Clear the compute override for a conversation (falls back to default).""" + conv = await _get_conv_or_404(db, uuid, user.id) + + extra = conv.extra or {} + extra.pop("compute_node_id", None) + extra.pop("compute_node_name", None) + await ops.update_conversation_extra(db, conv.id, extra) + + return {"ok": True} + + # ── Messaging ──────────────────────────────────────────── @router.post("/message") diff --git a/backend/openmlr/routes/compute.py b/backend/openmlr/routes/compute.py new file mode 100644 index 0000000..fbc2c25 --- /dev/null +++ b/backend/openmlr/routes/compute.py @@ -0,0 +1,388 @@ +"""Compute Node routes — CRUD, testing, probing, and defaults.""" + +from datetime import UTC, datetime + +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from ..compute import ComputeManager +from ..db import operations as ops +from ..db.engine import get_db +from ..db.models import User +from ..dependencies import get_current_user +from ..keys import KeyManager + +router = APIRouter(prefix="/api/compute", tags=["compute"]) + +key_manager = KeyManager() +compute_manager = ComputeManager(key_manager) + +# Fields to redact from config before sending to the frontend +_SENSITIVE_CONFIG_KEYS = {"password", "private_key", "secret", "token"} + + +def _redact_config(config: dict) -> dict: + """Return config with sensitive fields masked.""" + if not config: + return {} + redacted = {} + for k, v in config.items(): + if k in _SENSITIVE_CONFIG_KEYS and v: + redacted[k] = "***" + else: + redacted[k] = v + return redacted + + +def _node_dict(node) -> dict: + return { + "id": node.id, + "name": node.name, + "type": node.type, + "config": _redact_config(node.config), + "capabilities": node.capabilities or {}, + "health_status": node.health_status, + "last_probed_at": 
@router.post("/nodes")
async def create_node(
    request: Request,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a new compute node.

    Body fields: ``name`` (required), ``type`` (``local``/``ssh``/``modal``),
    ``config`` (object), ``is_default`` (bool), ``priority`` (int).

    Raises:
        HTTPException: 400 for malformed input or an invalid config,
            409 when the user already has a node with that name.
    """
    body = await request.json()
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    # Explicit nulls / non-strings must yield a 400, not an AttributeError.
    raw_name = body.get("name", "")
    raw_type = body.get("type", "")
    name = raw_name.strip() if isinstance(raw_name, str) else ""
    node_type = raw_type.strip() if isinstance(raw_type, str) else ""
    config = body.get("config", {})
    if config is None:
        config = {}
    is_default = bool(body.get("is_default", False))

    if not name:
        raise HTTPException(status_code=400, detail="Missing 'name'")
    # The name column is String(100); reject early instead of failing in the DB.
    if len(name) > 100:
        raise HTTPException(status_code=400, detail="'name' must be at most 100 characters")
    if node_type not in ("local", "ssh", "modal"):
        raise HTTPException(status_code=400, detail="type must be 'local', 'ssh', or 'modal'")
    if not isinstance(config, dict):
        raise HTTPException(status_code=400, detail="'config' must be an object")
    try:
        priority = int(body.get("priority", 0) or 0)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail="'priority' must be an integer") from exc

    # Validate config
    valid, error = compute_manager.validate_node_config(node_type, config)
    if not valid:
        raise HTTPException(status_code=400, detail=error)

    # Check for duplicate name
    existing = await ops.get_compute_node_by_name(db, user.id, name)
    if existing:
        raise HTTPException(status_code=409, detail=f"Node '{name}' already exists")

    # If setting as default, clear existing default
    if is_default:
        await ops.set_default_compute_node(db, user.id, None)

    node = await ops.create_compute_node(
        db, user.id, name, node_type, config,
        is_default=is_default, priority=priority,
    )

    return {"node": _node_dict(node)}
def _restore_redacted_values(new_config: dict, stored_config) -> dict:
    """Replace echoed ``"***"`` placeholders in *new_config* with stored secrets."""
    stored = stored_config if isinstance(stored_config, dict) else {}
    merged = {}
    for key, value in new_config.items():
        if isinstance(value, dict):
            merged[key] = _restore_redacted_values(value, stored.get(key))
        elif (
            value == "***"
            and isinstance(key, str)
            and key.lower() in _SENSITIVE_CONFIG_KEYS
            and key in stored
        ):
            # The client sent back the mask it was shown, not a new secret.
            merged[key] = stored[key]
        else:
            merged[key] = value
    return merged


@router.put("/nodes/{node_id}")
async def update_node(
    node_id: int,
    request: Request,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update a compute node's configuration.

    Accepts any subset of ``name``, ``config``, ``priority`` and
    ``is_default``. Sensitive config values that come back as the ``"***"``
    mask used in API responses are treated as "unchanged" and keep their
    stored value.

    Raises:
        HTTPException: 400 for malformed input or an invalid config,
            404 when the node does not exist for this user,
            409 when the new name is already taken.
    """
    body = await request.json()
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    node = await ops.get_compute_node_by_id(db, node_id, user.id)
    if not node:
        raise HTTPException(status_code=404, detail="Node not found")

    updates = {}
    if "name" in body:
        if not isinstance(body["name"], str):
            raise HTTPException(status_code=400, detail="'name' must be a string")
        new_name = body["name"].strip()
        # The name column is String(100); reject early instead of failing in the DB.
        if len(new_name) > 100:
            raise HTTPException(status_code=400, detail="'name' must be at most 100 characters")
        if new_name and new_name != node.name:
            existing = await ops.get_compute_node_by_name(db, user.id, new_name)
            if existing:
                raise HTTPException(status_code=409, detail=f"Node '{new_name}' already exists")
            updates["name"] = new_name

    if "config" in body:
        config = body["config"]
        if not isinstance(config, dict):
            raise HTTPException(status_code=400, detail="'config' must be an object")
        # Merge before validating so the validator sees the real secrets.
        config = _restore_redacted_values(config, node.config)
        valid, error = compute_manager.validate_node_config(node.type, config)
        if not valid:
            raise HTTPException(status_code=400, detail=error)
        updates["config"] = config

    if "priority" in body:
        try:
            updates["priority"] = int(body["priority"])
        except (TypeError, ValueError) as exc:
            raise HTTPException(status_code=400, detail="'priority' must be an integer") from exc

    if "is_default" in body:
        if body["is_default"]:
            await ops.set_default_compute_node(db, user.id, None)
        updates["is_default"] = bool(body["is_default"])

    updated = await ops.update_compute_node(db, node_id, user.id, **updates)
    if updated is None:
        # The node was removed between the lookup above and the update.
        raise HTTPException(status_code=404, detail="Node not found")
    return {"node": _node_dict(updated)}
@router.post("/nodes/{node_id}/test")
async def test_node(
    node_id: int,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run the lightweight connectivity check matching the node's type.

    Responds 404 when the node does not exist for the current user.
    """
    node = await ops.get_compute_node_by_id(db, node_id, user.id)
    if not node:
        raise HTTPException(status_code=404, detail="Node not found")

    # One checker per supported node type.
    checkers = {
        "ssh": _test_ssh_node,
        "local": _test_local_node,
        "modal": _test_modal_node,
    }
    checker = checkers.get(node.type)
    if checker is None:
        return {"ok": False, "error": "Unknown node type"}
    return await checker(node)
async def _test_ssh_node(node):
    """Test SSH connectivity and report the server's host key fingerprint.

    Connects with the node's configured key file or password, runs
    ``echo ok`` and returns the outcome. The blocking paramiko calls run in
    a worker thread.

    Returns:
        ``{"ok": bool, "host_key_fingerprint": str, "message": str}`` on a
        completed attempt, or ``{"ok": False, "error": str}`` on failure.
        The fingerprint is the hex form of ``PKey.get_fingerprint()``.
    """
    import asyncio

    import paramiko

    config = node.config or {}
    host = config.get("host", "")
    port = config.get("port", 22)
    username = config.get("username", "")
    key_filename = config.get("key_filename")
    password = config.get("password")

    if not host:
        # An empty hostname is never a meaningful target for this check.
        return {"ok": False, "error": "Missing 'host' in node config"}

    def _do_test():
        client = paramiko.SSHClient()
        # NOTE: WarningPolicy logs a warning for an unknown host key and then
        # ACCEPTS it — this test does not verify the server's identity; it
        # only reports the fingerprint so the caller can pin or compare it.
        client.set_missing_host_key_policy(paramiko.WarningPolicy())

        connect_kwargs = {
            "hostname": host,
            "port": int(port or 22),
            "username": username,
            "timeout": 10,
        }

        if key_filename:
            key_path = key_manager.get_key_path(key_filename)
            connect_kwargs["key_filename"] = str(key_path)
        elif password:
            connect_kwargs["password"] = password

        try:
            client.connect(**connect_kwargs)

            # Get host key fingerprint
            transport = client.get_transport()
            host_key = transport.get_remote_server_key()
            fingerprint = host_key.get_fingerprint().hex()

            # Run a simple command
            _stdin, stdout, _stderr = client.exec_command("echo ok", timeout=5)
            exit_code = stdout.channel.recv_exit_status()
            output = stdout.read().decode("utf-8", errors="replace").strip()
        finally:
            # Always release the connection, including when connect/exec fails.
            client.close()

        return {
            "connected": exit_code == 0 and output == "ok",
            "host_key_fingerprint": fingerprint,
            "output": output,
        }

    try:
        result = await asyncio.to_thread(_do_test)
    except Exception as e:
        # Any failure is reported to the caller as a failed test.
        return {"ok": False, "error": str(e)}

    return {
        "ok": result["connected"],
        "host_key_fingerprint": result.get("host_key_fingerprint"),
        "message": "Connected successfully" if result["connected"] else f"Unexpected output: {result['output']}",
    }
@router.post("/keys")
async def create_key(
    request: Request,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Upload or generate an SSH key pair.

    Body fields: ``action`` (``upload`` or ``generate``), ``filename``,
    plus ``private_key`` for uploads or optional ``algorithm``/``comment``
    for generation.

    Raises:
        HTTPException: 400 for malformed input or an invalid key,
            409 when the filename is already in use.
    """
    body = await request.json()
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    action = body.get("action")
    filename = body.get("filename", "")

    if not filename or not isinstance(filename, str):
        raise HTTPException(status_code=400, detail="Missing 'filename'")

    # Prevent path traversal in filename
    from pathlib import Path as PyPath
    safe_filename = PyPath(filename).name
    # The filename column is String(255).
    if not safe_filename or safe_filename.startswith(".") or len(safe_filename) > 255:
        raise HTTPException(status_code=400, detail="Invalid filename")

    if action not in ("upload", "generate"):
        raise HTTPException(status_code=400, detail="action must be 'upload' or 'generate'")

    existing = await ops.get_ssh_key_by_filename(db, user.id, safe_filename)
    # Key files share one directory across all users, so the per-user DB check
    # alone would let this request overwrite a file owned by someone else.
    on_disk = key_manager.key_exists(safe_filename) or key_manager.key_exists(f"{safe_filename}.pub")
    if existing or on_disk:
        raise HTTPException(status_code=409, detail=f"Key '{safe_filename}' already exists")

    comment = body.get("comment")

    if action == "upload":
        private_key = body.get("private_key", "")
        if not private_key:
            raise HTTPException(status_code=400, detail="Missing 'private_key' for upload")

        try:
            meta = key_manager.validate_key(private_key)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=f"Invalid key: {e}") from e

        key_manager.write_key(safe_filename, private_key)

    else:
        algorithm = body.get("algorithm", "ed25519")
        # Store the same comment that is embedded in the generated public key.
        comment = body.get("comment", f"openmlr-{user.id}")
        try:
            key_path, _pub_path = key_manager.generate_key_pair(safe_filename, algorithm, comment)
            private_key = key_path.read_text()
            meta = key_manager.validate_key(private_key)
        except ValueError as e:
            # Safe to remove: nothing with this name existed before this request.
            key_manager.delete_key(safe_filename)
            raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        key = await ops.create_ssh_key(
            db, user.id, safe_filename, meta["fingerprint"],
            meta["algorithm"], meta["public_key"], comment,
        )
    except Exception:
        # Do not leave an untracked private key on disk if the insert fails.
        key_manager.delete_key(safe_filename)
        raise

    return {
        "key": {
            "id": key.id,
            "filename": key.filename,
            "fingerprint": key.fingerprint,
            "algorithm": key.algorithm,
            "public_key": key.public_key,
            "comment": key.comment,
            "created_at": key.created_at.isoformat() if key.created_at else None,
        }
    }
async def execute_stream(self, command: str, timeout: int = 120, on_chunk=None):
    """Run *command* and report its output through an optional callback.

    This base implementation does not stream: it awaits ``execute`` and,
    when a callback was supplied and the output is non-empty, delivers the
    whole output as a single non-stderr chunk.

    Args:
        command: Shell command to execute
        timeout: Timeout in seconds
        on_chunk: Callback function(text: str, is_stderr: bool) called for each chunk

    Returns:
        ExecutionResult with full output
    """
    outcome = await self.execute(command, timeout)
    if not on_chunk:
        return outcome
    if outcome.output:
        on_chunk(outcome.output, False)
    return outcome
= str(ws_path) + elif self._workspace_manager: + # Fallback: create workspace without UUID + ws_path = self._workspace_manager.create_workspace("default") + self.workdir = str(ws_path) + return self async def execute(self, command: str, timeout: int = 120) -> ExecutionResult: + return await self.execute_stream(command, timeout) + + async def execute_stream(self, command: str, timeout: int = 120, on_chunk=None) -> ExecutionResult: + """Execute a command with optional streaming output.""" start = time.monotonic() try: proc = await asyncio.create_subprocess_shell( @@ -29,25 +46,47 @@ async def execute(self, command: str, timeout: int = 120) -> ExecutionResult: stderr=asyncio.subprocess.PIPE, cwd=self.workdir, ) - stdout, stderr = await asyncio.wait_for( - proc.communicate(), timeout=timeout - ) output_parts = [] - if stdout: - output_parts.append(stdout.decode("utf-8", errors="replace")) - if stderr: - output_parts.append(f"STDERR:\n{stderr.decode('utf-8', errors='replace')}") - output = "\n".join(output_parts) if output_parts else "(no output)" + async def _read_stream(stream, is_stderr): + """Read a stream and emit chunks.""" + while True: + try: + line = await asyncio.wait_for(stream.readline(), timeout=0.5) + if not line: + break + text = line.decode("utf-8", errors="replace") + if on_chunk: + on_chunk(text, is_stderr) + output_parts.append(text) + except TimeoutError: + # Check if process is done + if proc.returncode is not None: + break + continue + + # Read stdout and stderr concurrently + await asyncio.gather( + _read_stream(proc.stdout, False), + _read_stream(proc.stderr, True), + ) + + # Wait for process to complete + try: + returncode = await asyncio.wait_for(proc.wait(), timeout=1.0) + except TimeoutError: + returncode = proc.returncode if proc.returncode is not None else -1 + + output = "".join(output_parts) if output_parts else "(no output)" if len(output) > 50000: output = output[:50000] + "\n...[truncated]" duration = time.monotonic() - start return 
ExecutionResult( output=output, - success=proc.returncode == 0, - exit_code=proc.returncode, + success=returncode == 0, + exit_code=returncode, duration_seconds=duration, ) except TimeoutError: @@ -97,45 +136,8 @@ async def list_files(self, path: str = ".") -> list[str]: for e in target.iterdir() ]) - async def probe_environment(self) -> EnvironmentInfo: - info = EnvironmentInfo( - os=f"{platform.system()} {platform.release()}", - ) - - # Python version - result = await self.execute("python3 --version", timeout=5) - if result.success: - info.python_version = result.output.strip() - - # GPU - result = await self.execute( - "nvidia-smi --query-gpu=name,memory.total --format=csv,noheader", - timeout=5, - ) - if result.success and result.output.strip(): - info.gpu_available = True - info.gpu_info = result.output.strip() - - # Disk - total, used, free = shutil.disk_usage(self.workdir) - info.available_disk_gb = free / (1024 ** 3) - - # RAM - try: - import psutil - info.available_ram_gb = psutil.virtual_memory().available / (1024 ** 3) - except ImportError: - pass - - # Key packages - result = await self.execute("pip list --format=freeze 2>/dev/null | head -30", timeout=10) - if result.success: - info.installed_packages = [ - line.split("==")[0] for line in result.output.strip().split("\n") - if "==" in line - ] - - return info + async def probe_environment(self): + return await probe_sandbox(self) async def destroy(self) -> None: pass # Local sandbox has nothing to clean up diff --git a/backend/openmlr/sandbox/manager.py b/backend/openmlr/sandbox/manager.py index dc054e7..32060a0 100644 --- a/backend/openmlr/sandbox/manager.py +++ b/backend/openmlr/sandbox/manager.py @@ -1,6 +1,5 @@ """SandboxManager — lifecycle management and provider selection.""" - from .interface import SandboxInterface from .local import LocalSandbox from .modal_sandbox import ModalSandbox @@ -10,9 +9,11 @@ class SandboxManager: """Manages sandbox lifecycle: create, switch, destroy.""" - def 
__init__(self): + def __init__(self, workspace_manager=None, conversation_uuid: str = None): self._active: SandboxInterface | None = None self.active_type: str = "none" + self._workspace_manager = workspace_manager + self._conversation_uuid = conversation_uuid def get_active(self) -> SandboxInterface | None: return self._active @@ -25,8 +26,11 @@ async def create(self, provider: str, config: dict = None) -> SandboxInterface: config = config or {} + # Inject workspace and conversation context + config["conversation_uuid"] = self._conversation_uuid + if provider == "local": - sandbox = LocalSandbox() + sandbox = LocalSandbox(workspace_manager=self._workspace_manager) elif provider == "ssh": sandbox = SSHSandbox() elif provider == "modal": diff --git a/backend/openmlr/sandbox/modal_sandbox.py b/backend/openmlr/sandbox/modal_sandbox.py index 7713a44..ad3277c 100644 --- a/backend/openmlr/sandbox/modal_sandbox.py +++ b/backend/openmlr/sandbox/modal_sandbox.py @@ -3,7 +3,8 @@ import asyncio import time -from .interface import EnvironmentInfo, ExecutionResult, SandboxInterface +from ..compute.probe import probe_sandbox +from .interface import ExecutionResult, SandboxInterface class ModalSandbox(SandboxInterface): @@ -117,7 +118,7 @@ async def read_file(self, path: str) -> str: async def write_file(self, path: str, content: str) -> bool: self._ensure_active() # Use heredoc for safe content transfer - content.replace("'", "'\\''") + content = content.replace("'", "'\\''") result = await self.execute( f"mkdir -p $(dirname '{path}') && cat > '{path}' << 'OPEN_MLR_EOF'\n{content}\nOPEN_MLR_EOF", timeout=10, @@ -143,49 +144,8 @@ async def list_files(self, path: str = ".") -> list[str]: return [] return [line for line in result.output.strip().split("\n") if line] - async def probe_environment(self) -> EnvironmentInfo: - info = EnvironmentInfo() - - result = await self.execute("uname -s -r", timeout=5) - if result.success: - info.os = result.output.strip() - - result = await 
self.execute("python3 --version", timeout=5) - if result.success: - info.python_version = result.output.strip() - - result = await self.execute( - "nvidia-smi --query-gpu=name,memory.total --format=csv,noheader 2>/dev/null", - timeout=10, - ) - if result.success and result.output.strip(): - info.gpu_available = True - info.gpu_info = result.output.strip() - - result = await self.execute("df -BG --output=avail / 2>/dev/null | tail -1", timeout=5) - if result.success: - try: - info.available_disk_gb = float(result.output.strip().replace("G", "")) - except ValueError: - pass - - result = await self.execute( - "free -g 2>/dev/null | grep Mem | awk '{print $7}'", timeout=5 - ) - if result.success: - try: - info.available_ram_gb = float(result.output.strip()) - except ValueError: - pass - - result = await self.execute("pip list --format=freeze 2>/dev/null | head -30", timeout=10) - if result.success: - info.installed_packages = [ - line.split("==")[0] for line in result.output.strip().split("\n") - if "==" in line - ] - - return info + async def probe_environment(self): + return await probe_sandbox(self) async def destroy(self) -> None: if self._sandbox: diff --git a/backend/openmlr/sandbox/ssh.py b/backend/openmlr/sandbox/ssh.py index aa0c63d..7ef3ee1 100644 --- a/backend/openmlr/sandbox/ssh.py +++ b/backend/openmlr/sandbox/ssh.py @@ -1,9 +1,110 @@ -"""SSH sandbox — remote execution via SSH/SFTP.""" +"""SSH sandbox — remote execution via SSH/SFTP with strict host-key verification +and connection pooling.""" import asyncio +import logging import time -from .interface import EnvironmentInfo, ExecutionResult, SandboxInterface +from ..compute.probe import probe_sandbox +from .interface import ExecutionResult, SandboxInterface + +log = logging.getLogger(__name__) + + +class StrictHostKeyPolicy: + """Paramiko policy that verifies host keys against expected fingerprints.""" + + def __init__(self, expected_fingerprint: str | None = None): + self.expected = expected_fingerprint 
+ self.actual_fingerprint: str | None = None + + def missing_host_key(self, client, hostname, key): + import paramiko + actual = key.get_fingerprint().hex() + self.actual_fingerprint = actual + if self.expected and actual != self.expected.lower().replace(":", "").replace("sha256:", ""): + raise paramiko.SSHException( + f"Host key mismatch for {hostname}: expected {self.expected}, got {actual}" + ) + return + + +class SSHConnectionPool: + """Maintains persistent SSH connections per node with TTL-based eviction. + + Connections are keyed by (host, port, username) and reused across + sandbox instances. Idle connections are closed after ``ttl_seconds``. + """ + + _instance: "SSHConnectionPool | None" = None + + def __init__(self, ttl_seconds: int = 300): + self._connections: dict[str, tuple] = {} # key -> (client, sftp, fingerprint) + self._last_used: dict[str, float] = {} + self._ttl = ttl_seconds + + @classmethod + def get_pool(cls) -> "SSHConnectionPool": + if cls._instance is None: + cls._instance = SSHConnectionPool() + return cls._instance + + @staticmethod + def _make_key(host: str, port: int, username: str) -> str: + return f"{username}@{host}:{port}" + + def get(self, host: str, port: int, username: str): + """Return (client, sftp, fingerprint) if a healthy cached connection exists, else None.""" + key = self._make_key(host, port, username) + entry = self._connections.get(key) + if entry is None: + return None + + client, sftp, fp = entry + try: + transport = client.get_transport() + if transport and transport.is_active(): + self._last_used[key] = time.monotonic() + return client, sftp, fp + except Exception: + pass + + # Connection is dead — clean up + self._evict(key) + return None + + def put(self, host: str, port: int, username: str, client, sftp, fingerprint: str | None): + """Cache a connection for reuse.""" + key = self._make_key(host, port, username) + self._connections[key] = (client, sftp, fingerprint) + self._last_used[key] = time.monotonic() + + 
def remove(self, host: str, port: int, username: str): + """Remove and close a cached connection.""" + key = self._make_key(host, port, username) + self._evict(key) + + def _evict(self, key: str): + entry = self._connections.pop(key, None) + self._last_used.pop(key, None) + if entry: + client, sftp, _ = entry + try: + sftp.close() + except Exception: + pass + try: + client.close() + except Exception: + pass + + def cleanup_idle(self): + """Close connections idle beyond TTL. Call periodically.""" + now = time.monotonic() + stale = [k for k, t in self._last_used.items() if now - t > self._ttl] + for key in stale: + log.debug(f"SSH pool: evicting idle connection {key}") + self._evict(key) class SSHSandbox(SandboxInterface): @@ -12,33 +113,80 @@ class SSHSandbox(SandboxInterface): def __init__(self): self._client = None self._sftp = None + self._owns_connection = False # True if we created it (not from pool) self.host: str = "" self.port: int = 22 self.username: str = "" - self.key_path: str | None = None + self.key_filename: str | None = None self.password: str | None = None self.workdir: str = "~" + self.host_key_fingerprint: str | None = None + self._key_manager = None async def create(self, config: dict) -> "SSHSandbox": self.host = config.get("host", "") self.port = config.get("port", 22) self.username = config.get("username", "root") - self.key_path = config.get("key_path") + self.key_filename = config.get("key_filename") self.password = config.get("password") self.workdir = config.get("workdir", "~") + self.host_key_fingerprint = config.get("host_key_fingerprint") + self._conversation_uuid = config.get("conversation_uuid") + + if self.key_filename: + from ..keys import KeyManager + self._key_manager = KeyManager() if not self.host: raise ValueError("SSH config requires 'host'") await self._connect() + + # Ensure remote workspace exists if conversation UUID is set + if self._conversation_uuid: + remote_ws = f"{self.workdir}/workspace-{self._conversation_uuid}" + 
await self._ensure_remote_workspace(remote_ws) + self.workdir = remote_ws + return self + async def _ensure_remote_workspace(self, remote_path: str) -> None: + self._ensure_connected() + + def _do_mkdir(): + subdirs = " ".join(f"{remote_path}/{d}" for d in ["data", "models", "code", "outputs", ".openmlr-meta"]) + cmd = f"mkdir -p {subdirs}" + stdin, stdout, stderr = self._client.exec_command(cmd, timeout=10) + exit_code = stdout.channel.recv_exit_status() + if exit_code != 0: + err = stderr.read().decode("utf-8", errors="replace") + raise RuntimeError(f"Failed to create remote workspace: {err}") + + await asyncio.to_thread(_do_mkdir) + async def _connect(self): - """Establish SSH connection (run in thread to avoid blocking).""" + """Get a connection from the pool or create a new one.""" + pool = SSHConnectionPool.get_pool() + pool.cleanup_idle() + + cached = pool.get(self.host, self.port, self.username) + if cached: + self._client, self._sftp, fp = cached + self._owns_connection = False + if fp and not self.host_key_fingerprint: + self.host_key_fingerprint = fp + log.debug(f"SSH pool: reusing connection to {self.username}@{self.host}:{self.port}") + return + def _do_connect(): import paramiko client = paramiko.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + if self.host_key_fingerprint: + policy = StrictHostKeyPolicy(self.host_key_fingerprint) + client.set_missing_host_key_policy(policy) + else: + client.set_missing_host_key_policy(paramiko.WarningPolicy()) connect_kwargs = { "hostname": self.host, @@ -47,35 +195,84 @@ def _do_connect(): "timeout": 30, } - if self.key_path: - connect_kwargs["key_filename"] = self.key_path + if self.key_filename and self._key_manager: + key_path = self._key_manager.get_key_path(self.key_filename) + connect_kwargs["key_filename"] = str(key_path) elif self.password: connect_kwargs["password"] = self.password client.connect(**connect_kwargs) sftp = client.open_sftp() - return client, sftp - self._client, 
self._sftp = await asyncio.to_thread(_do_connect) + actual_fp = None + transport = client.get_transport() + if transport: + remote_key = transport.get_remote_server_key() + if remote_key: + actual_fp = remote_key.get_fingerprint().hex() + + return client, sftp, actual_fp + + self._client, self._sftp, actual_fp = await asyncio.to_thread(_do_connect) + self._owns_connection = True + + if actual_fp and not self.host_key_fingerprint: + self.host_key_fingerprint = actual_fp + + # Put the new connection into the pool + pool.put(self.host, self.port, self.username, self._client, self._sftp, actual_fp) def _ensure_connected(self): if not self._client or not self._client.get_transport() or not self._client.get_transport().is_active(): raise RuntimeError("SSH connection lost. Recreate the sandbox.") async def execute(self, command: str, timeout: int = 120) -> ExecutionResult: + return await self.execute_stream(command, timeout) + + async def execute_stream(self, command: str, timeout: int = 120, on_chunk=None) -> ExecutionResult: self._ensure_connected() start = time.monotonic() - def _do_exec(): + def _do_exec_stream(): full_cmd = f"cd {self.workdir} && {command}" stdin, stdout, stderr = self._client.exec_command(full_cmd, timeout=timeout) - exit_code = stdout.channel.recv_exit_status() - out = stdout.read().decode("utf-8", errors="replace") - err = stderr.read().decode("utf-8", errors="replace") - return out, err, exit_code + + out_buf = [] + err_buf = [] + channel = stdout.channel + + while not channel.exit_status_ready(): + if channel.recv_ready(): + data = channel.recv(4096).decode("utf-8", errors="replace") + out_buf.append(data) + if on_chunk: + on_chunk(data, False) + + if channel.recv_stderr_ready(): + data = channel.recv_stderr(4096).decode("utf-8", errors="replace") + err_buf.append(data) + if on_chunk: + on_chunk(data, True) + + time.sleep(0.05) + + while channel.recv_ready(): + data = channel.recv(4096).decode("utf-8", errors="replace") + out_buf.append(data) + 
if on_chunk: + on_chunk(data, False) + + while channel.recv_stderr_ready(): + data = channel.recv_stderr(4096).decode("utf-8", errors="replace") + err_buf.append(data) + if on_chunk: + on_chunk(data, True) + + exit_code = channel.recv_exit_status() + return "".join(out_buf), "".join(err_buf), exit_code try: - out, err, exit_code = await asyncio.to_thread(_do_exec) + out, err, exit_code = await asyncio.to_thread(_do_exec_stream) output_parts = [] if out: output_parts.append(out) @@ -155,64 +352,12 @@ def _do_list(): return await asyncio.to_thread(_do_list) - async def probe_environment(self) -> EnvironmentInfo: - info = EnvironmentInfo() - - result = await self.execute("uname -s -r", timeout=5) - if result.success: - info.os = result.output.strip() - - result = await self.execute("python3 --version", timeout=5) - if result.success: - info.python_version = result.output.strip() - - result = await self.execute( - "nvidia-smi --query-gpu=name,memory.total --format=csv,noheader 2>/dev/null", - timeout=5, - ) - if result.success and result.output.strip(): - info.gpu_available = True - info.gpu_info = result.output.strip() - - result = await self.execute("df -BG --output=avail / 2>/dev/null | tail -1", timeout=5) - if result.success: - try: - info.available_disk_gb = float(result.output.strip().replace("G", "")) - except ValueError: - pass - - result = await self.execute( - "free -g 2>/dev/null | grep Mem | awk '{print $7}'", - timeout=5, - ) - if result.success: - try: - info.available_ram_gb = float(result.output.strip()) - except ValueError: - pass - - result = await self.execute( - "pip list --format=freeze 2>/dev/null | head -30", - timeout=10, - ) - if result.success: - info.installed_packages = [ - line.split("==")[0] for line in result.output.strip().split("\n") - if "==" in line - ] - - return info + async def probe_environment(self): + return await probe_sandbox(self) async def destroy(self) -> None: - if self._sftp: - try: - self._sftp.close() - except 
Exception: - pass - if self._client: - try: - self._client.close() - except Exception: - pass + # Don't close pooled connections — they'll be reused. + # Only close if we own the connection and it's not pooled. + # The pool handles TTL-based eviction. self._client = None self._sftp = None diff --git a/backend/openmlr/services/session_manager.py b/backend/openmlr/services/session_manager.py index adfde1f..d59d7b1 100644 --- a/backend/openmlr/services/session_manager.py +++ b/backend/openmlr/services/session_manager.py @@ -85,8 +85,46 @@ async def get_or_create_session( ) session = Session(config=config, conversation_id=conversation_id) - sandbox_manager = SandboxManager() + + # Determine effective compute node + effective_node = None + if user_id and db: + try: + from ..db import operations as ops + # Check conversation override + conv = await ops.get_conversation_by_id(db, conversation_id) + if conv and conv.extra: + override_node_id = conv.extra.get("compute_node_id") + if override_node_id: + effective_node = await ops.get_compute_node_by_id(db, override_node_id, user_id) + + # Fall back to user default + if not effective_node: + effective_node = await ops.get_default_compute_node(db, user_id) + + if effective_node: + log.info(f"Session {conversation_id}: using compute node '{effective_node.name}' ({effective_node.type})") + except Exception as e: + log.warning(f"Session {conversation_id}: failed to load compute node - {e}") + + # Initialize workspace manager and sandbox manager + from ..compute import WorkspaceManager + workspace_manager = WorkspaceManager() + sandbox_manager = SandboxManager( + workspace_manager=workspace_manager, + conversation_uuid=uuid, + ) + + # If a compute node is configured, activate it + if effective_node: + try: + await sandbox_manager.create(effective_node.type, effective_node.config) + except Exception as e: + log.warning(f"Session {conversation_id}: failed to create sandbox for node '{effective_node.name}' - {e}") + tool_router = 
create_tool_router(sandbox_manager) + # Inject user/db context for compute tools + tool_router.set_context(user_id=user_id, db=db) mcp_manager = MCPManager() # Load MCP servers from user settings if available @@ -108,11 +146,53 @@ async def get_or_create_session( except Exception as e: log.warning(f"Session {conversation_id}: failed to load MCP servers - {e}") + # Build compute environment info for system prompt + compute_env = "" + if effective_node: + caps = effective_node.capabilities or {} + lines = [f"\n## Active Compute Environment: {effective_node.name} ({effective_node.type})"] + if caps.get("platform"): + lines.append(f"- Platform: {caps['platform']}") + if caps.get("cpu_cores"): + lines.append(f"- CPU: {caps['cpu_cores']} cores ({caps.get('cpu_arch', 'unknown')})") + if caps.get("available_ram_gb"): + lines.append(f"- RAM: {caps['available_ram_gb']:.1f} GB available") + if caps.get("gpu_available"): + gpu_info = caps.get("gpu_info", []) + for gpu in gpu_info[:1]: + lines.append(f"- GPU: {gpu.get('model', 'unknown')} ({gpu.get('vram_gb', 0):.0f} GB VRAM)") + if gpu.get("cuda_version"): + lines.append(f" - CUDA: {gpu['cuda_version']}") + if caps.get("python_versions"): + lines.append(f"- Python: {', '.join(caps['python_versions'])}") + if caps.get("docker_available"): + lines.append("- Docker: available") + if caps.get("installed_packages"): + pkgs = caps["installed_packages"][:10] + lines.append(f"- Key packages: {', '.join(pkgs)}") + + # Add available nodes for context + all_nodes = [] + if user_id and db: + try: + all_nodes = await ops.get_compute_nodes(db, user_id) + except Exception: + pass + if len(all_nodes) > 1: + lines.append("\n### Other Available Nodes") + for node in all_nodes: + if node.id != effective_node.id: + status = "online" if node.health_status == "online" else "offline" + lines.append(f"- {node.name} ({node.type}): {status}") + + compute_env = "\n".join(lines) + # Build and set system prompt (after MCP tools are registered) 
session.context_manager.system_prompt = build_system_prompt( tool_specs=tool_router.get_raw_specs(), mode=mode, username=username, + compute_env=compute_env, ) # Wire event broadcasting diff --git a/backend/openmlr/tasks/compute_tasks.py b/backend/openmlr/tasks/compute_tasks.py new file mode 100644 index 0000000..e369286 --- /dev/null +++ b/backend/openmlr/tasks/compute_tasks.py @@ -0,0 +1,113 @@ +"""Compute background tasks — health checks and periodic maintenance.""" + +import asyncio +import logging +from datetime import UTC, datetime + +from ..celery_app import celery_app +from ..compute import WorkspaceManager +from ..db import operations as ops +from ..db.engine import get_worker_session + +logger = logging.getLogger(__name__) + + +@celery_app.task +def cleanup_old_workspaces(): + """Clean up old workspace archives and orphaned workspaces.""" + wm = WorkspaceManager() + + # Clean old archives + archive_result = wm.cleanup_archives(max_age_days=30, max_count=100) + logger.info( + f"Archive cleanup: deleted {archive_result['deleted']} archives, " + f"freed {archive_result['freed_bytes'] / (1024**3):.1f} GB" + ) + + # Clean orphaned workspaces (conversations that no longer exist) + async def _cleanup_orphaned(): + session_factory = get_worker_session() + async with session_factory() as db: + from sqlalchemy import select + + from ..db.models import Conversation + result = await db.execute(select(Conversation.uuid)) + active_uuids = {row[0] for row in result.all()} + + ws_result = wm.cleanup_workspaces( + conversation_uuids=list(active_uuids), + archive=True, + ) + logger.info( + f"Workspace cleanup: deleted {ws_result['deleted']} workspaces, " + f"freed {ws_result['freed_bytes'] / (1024**3):.1f} GB" + ) + + asyncio.run(_cleanup_orphaned()) + + +@celery_app.task(bind=True, max_retries=3) +def check_compute_node_health(self, node_id: int, user_id: int): + """Check health of a single compute node.""" + async def _check(): + session_factory = get_worker_session() + 
async with session_factory() as db: + node = await ops.get_compute_node_by_id(db, node_id, user_id) + if not node: + logger.warning(f"Node {node_id} not found for health check") + return + + from ..compute.probe import probe_sandbox + from ..sandbox.manager import SandboxManager + + sm = SandboxManager(workspace_manager=WorkspaceManager()) + try: + await sm.create(node.type, node.config) + sandbox = sm.get_active() + + if sandbox: + caps = await probe_sandbox(sandbox) + await ops.update_compute_node( + db, node.id, user_id, + capabilities=caps.to_dict(), + health_status="online", + last_seen_at=datetime.now(UTC), + ) + logger.info(f"Health check passed for node '{node.name}'") + else: + await ops.update_compute_node( + db, node.id, user_id, + health_status="offline", + ) + logger.warning(f"Health check failed for node '{node.name}': sandbox not created") + except Exception as e: + await ops.update_compute_node( + db, node.id, user_id, + health_status="offline", + ) + logger.warning(f"Health check failed for node '{node.name}': {e}") + finally: + await sm.destroy() + + asyncio.run(_check()) + + +@celery_app.task +def health_check_all_nodes(): + """Run health checks on all compute nodes for all users.""" + async def _check_all(): + session_factory = get_worker_session() + async with session_factory() as db: + from sqlalchemy import select + + from ..db.models import User + result = await db.execute(select(User)) + users = result.scalars().all() + + for user in users: + nodes = await ops.get_compute_nodes(db, user.id) + for node in nodes: + check_compute_node_health.delay(node.id, user.id) + + asyncio.run(_check_all()) + logger.info("Queued health checks for all compute nodes") diff --git a/backend/openmlr/tools/compute_tools.py b/backend/openmlr/tools/compute_tools.py new file mode 100644 index 0000000..2d6c0c9 --- /dev/null +++ b/backend/openmlr/tools/compute_tools.py @@ -0,0 +1,527 @@ +"""Compute tools — agent-facing tools for compute node discovery and 
selection.""" + +import asyncio +import io +import os +from datetime import UTC, datetime +from pathlib import Path + +from ..agent.types import ToolSpec +from ..compute.probe import probe_sandbox + + +def _validate_sync_path(workspace: Path, rel_path: str) -> tuple[Path, str | None]: + """Validate that a relative path stays within the workspace. Returns (resolved, error).""" + target = (workspace / rel_path).resolve() + try: + target.relative_to(workspace.resolve()) + except ValueError: + return target, f"Path '{rel_path}' escapes workspace boundary" + return target, None + + +async def _handle_list(user_id: int = None, db=None, **kwargs): + """List all compute nodes with capabilities.""" + if not db: + return "Database connection required for compute_list", False + + from ..db import operations as ops + nodes = await ops.get_compute_nodes(db, user_id) + + if not nodes: + return "No compute nodes configured. Add nodes in Settings > Compute.", True + + lines = ["## Available Compute Nodes\n"] + for node in nodes: + caps = node.capabilities or {} + status = "●" if node.health_status == "online" else "○" + gpu = "" + if caps.get("gpu_available"): + gpu_info = caps.get("gpu_info", []) + if gpu_info: + gpu = f" — GPU: {gpu_info[0].get('model', 'unknown')}" + else: + gpu = " — GPU: yes" + + ram = "" + if caps.get("available_ram_gb"): + ram = f" — RAM: {caps['available_ram_gb']:.0f}GB" + + default = " ★" if node.is_default else "" + lines.append(f"{status} {node.name} ({node.type}){default}{gpu}{ram}") + + return "\n".join(lines), True + + +async def _handle_probe(node_name: str, user_id: int = None, db=None, **kwargs): + """Probe a compute node for capabilities.""" + if not db: + return "Database connection required for compute_probe", False + + from ..db import operations as ops + node = await ops.get_compute_node_by_name(db, user_id, node_name) + if not node: + return f"Node '{node_name}' not found", False + + # Create sandbox and probe + from ..compute import 
WorkspaceManager + from ..sandbox.manager import SandboxManager + + try: + wm = WorkspaceManager() + sm = SandboxManager(workspace_manager=wm) + await sm.create(node.type, node.config) + sandbox = sm.get_active() + if not sandbox: + return f"Failed to create sandbox for {node_name}", False + + caps = await probe_sandbox(sandbox) + + # Update node in database + await ops.update_compute_node( + db, node.id, user_id, + capabilities=caps.to_dict(), + health_status="online", + last_probed_at=datetime.now(UTC), + ) + + await sm.destroy() + + # Format response + lines = [f"## {node.name} Capabilities\n"] + lines.append(f"Platform: {caps.platform}") + lines.append(f"CPU: {caps.cpu_cores} cores ({caps.cpu_arch})") + lines.append(f"RAM: {caps.available_ram_gb:.1f} GB available / {caps.total_ram_gb:.1f} GB total") + lines.append(f"Disk: {caps.available_disk_gb:.1f} GB available / {caps.total_disk_gb:.1f} GB total") + + if caps.gpu_available: + for gpu in caps.gpu_info: + lines.append(f"GPU: {gpu.model} ({gpu.vram_gb:.0f} GB VRAM)") + if gpu.cuda_version: + lines.append(f" CUDA: {gpu.cuda_version}, Driver: {gpu.driver_version}") + + if caps.python_versions: + lines.append(f"Python: {', '.join(caps.python_versions)}") + + if caps.docker_available: + lines.append("Docker: available") + + if caps.installed_packages: + lines.append(f"\nKey packages: {', '.join(caps.installed_packages[:10])}") + if len(caps.installed_packages) > 10: + lines.append(f"... 
and {len(caps.installed_packages) - 10} more") + + return "\n".join(lines), True + + except Exception as e: + try: + await sm.destroy() + except Exception: + pass + await ops.update_compute_node( + db, node.id, user_id, + health_status="offline", + ) + return f"Probe failed for {node_name}: {str(e)}", False + + +async def _handle_select(node_name: str, user_id: int = None, db=None, session=None, **kwargs): + """Select a compute node as active for this conversation.""" + if not db: + return "Database connection required for compute_select", False + + from ..db import operations as ops + node = await ops.get_compute_node_by_name(db, user_id, node_name) + if not node: + return f"Node '{node_name}' not found", False + + # If session is provided, update the active sandbox + if session and hasattr(session, 'conversation_id'): + # Update conversation extra + conv_id = session.conversation_id + conv = await ops.get_conversation_by_id(db, conv_id) + if conv: + extra = conv.extra or {} + extra["compute_node_id"] = node.id + extra["compute_node_name"] = node.name + await ops.update_conversation_extra(db, conv_id, extra) + + return f"Active compute switched to: {node.name} ({node.type})", True + + +async def _handle_plan(task: str, requirements: dict = None, user_id: int = None, db=None, **kwargs): + """Recommend the best compute node for a task.""" + if not db: + return "Database connection required for compute_plan", False + + requirements = requirements or {} + from ..db import operations as ops + nodes = await ops.get_compute_nodes(db, user_id) + + if not nodes: + return "No compute nodes configured.", False + + # Score each node + scores = [] + for node in nodes: + if node.health_status != "online": + continue + + caps = node.capabilities or {} + score = 0 + reasons = [] + + # GPU requirement + if requirements.get("gpu"): + if not caps.get("gpu_available"): + continue + score += 10 + vram = 0 + for gpu in caps.get("gpu_info", []): + vram = max(vram, gpu.get("vram_gb", 0)) 
+ min_vram = requirements.get("min_vram_gb", 0) + if vram < min_vram: + continue + score += min(vram / 10, 5) + reasons.append(f"GPU with {vram:.0f}GB VRAM") + + # RAM requirement + min_ram = requirements.get("min_ram_gb", 0) + available_ram = caps.get("available_ram_gb", 0) + if available_ram < min_ram: + continue + score += min(available_ram / max(min_ram, 1), 3) + if available_ram > 0: + reasons.append(f"{available_ram:.0f}GB RAM") + + # Disk requirement + min_disk = requirements.get("min_disk_gb", 0) + available_disk = caps.get("available_disk_gb", 0) + if available_disk < min_disk: + continue + if available_disk > 0: + reasons.append(f"{available_disk:.0f}GB disk") + + # Prefer local > ssh > modal + if node.type == "local": + score += 5 + reasons.append("local (low latency)") + elif node.type == "ssh": + score += 2 + reasons.append("ssh (LAN)") + elif node.type == "modal": + reasons.append("modal (cloud)") + + scores.append({ + "node": node, + "score": score, + "reasons": reasons, + }) + + if not scores: + return "No compute nodes meet the requirements.", False + + scores.sort(key=lambda x: x["score"], reverse=True) + best = scores[0] + + lines = [f"## Recommended Compute for: {task}\n"] + lines.append(f"**Best choice: {best['node'].name}** ({best['node'].type})") + lines.append(f"Score: {best['score']:.1f}") + lines.append(f"Reasons: {', '.join(best['reasons'])}") + + if len(scores) > 1: + lines.append("\n### Alternatives") + for alt in scores[1:3]: + lines.append(f"- {alt['node'].name} (score: {alt['score']:.1f}, {', '.join(alt['reasons'])})") + + return "\n".join(lines), True + + +async def _get_sync_context(user_id, db, session): + """Helper: resolve conversation UUID and workspace path for sync ops.""" + from ..db import operations as ops + conv_uuid = None + if session and hasattr(session, 'conversation_id'): + conv = await ops.get_conversation_by_id(db, session.conversation_id) + if conv: + conv_uuid = conv.uuid + if not conv_uuid: + return None, None, 
"No active conversation workspace found" + from ..compute import WorkspaceManager + wm = WorkspaceManager() + local_ws = wm.get_workspace_path(conv_uuid) + return conv_uuid, local_ws, None + + +async def _handle_sync_up(paths: list, node_name: str, user_id: int = None, db=None, session=None, **kwargs): + """Sync files from local workspace to remote compute node.""" + if not db: + return "Database connection required", False + + from ..db import operations as ops + node = await ops.get_compute_node_by_name(db, user_id, node_name) + if not node: + return f"Node '{node_name}' not found", False + + conv_uuid, local_ws, err = await _get_sync_context(user_id, db, session) + if err: + return err, False + if not local_ws.exists(): + return f"Local workspace not found: {local_ws}", False + + if node.type == "local": + return "Local sync: files are already in the same workspace", True + + elif node.type == "ssh": + from ..sandbox.ssh import SSHSandbox + ssh_sandbox = SSHSandbox() + try: + config = dict(node.config) + config["conversation_uuid"] = conv_uuid + await ssh_sandbox.create(config) + + transferred = 0 + for rel_path in paths: + # Path traversal check + local_path, path_err = _validate_sync_path(local_ws, rel_path) + if path_err: + return path_err, False + if not local_path.exists(): + continue + + remote_base = ssh_sandbox.workdir + + if local_path.is_dir(): + for root, _, files in os.walk(local_path): + for file in files: + src = Path(root) / file + rel = src.relative_to(local_ws) + dst = f"{remote_base}/{rel}" + dst_dir = str(Path(dst).parent) + await ssh_sandbox.execute(f"mkdir -p '{dst_dir}'", timeout=5) + content = src.read_bytes() + await asyncio.to_thread( + lambda d=dst, c=content: ssh_sandbox._sftp.putfo(io.BytesIO(c), d) + ) + transferred += 1 + else: + rel = local_path.relative_to(local_ws) + dst = f"{remote_base}/{rel}" + dst_dir = str(Path(dst).parent) + await ssh_sandbox.execute(f"mkdir -p '{dst_dir}'", timeout=5) + content = local_path.read_bytes() + 
await asyncio.to_thread( + lambda d=dst, c=content: ssh_sandbox._sftp.putfo(io.BytesIO(c), d) + ) + transferred += 1 + + return f"Synced {transferred} item(s) to {node.name}", True + except Exception as e: + return f"Sync failed: {str(e)}", False + finally: + await ssh_sandbox.destroy() + + elif node.type == "modal": + return "File sync not supported for Modal nodes (ephemeral)", False + + return "Unsupported node type", False + + +async def _handle_sync_down(paths: list, node_name: str, user_id: int = None, db=None, session=None, **kwargs): + """Sync files from remote compute node to local workspace.""" + if not db: + return "Database connection required", False + + from ..db import operations as ops + node = await ops.get_compute_node_by_name(db, user_id, node_name) + if not node: + return f"Node '{node_name}' not found", False + + conv_uuid, local_ws, err = await _get_sync_context(user_id, db, session) + if err: + return err, False + local_ws.mkdir(parents=True, exist_ok=True) + + if node.type == "local": + return "Local sync: files are already in the same workspace", True + + elif node.type == "ssh": + from ..sandbox.ssh import SSHSandbox + ssh_sandbox = SSHSandbox() + try: + config = dict(node.config) + config["conversation_uuid"] = conv_uuid + await ssh_sandbox.create(config) + + transferred = 0 + for rel_path in paths: + # Path traversal check + local_path, path_err = _validate_sync_path(local_ws, rel_path) + if path_err: + return path_err, False + + remote_path = f"{ssh_sandbox.workdir}/{rel_path}" + + # Check remote type + result = await ssh_sandbox.execute( + f"test -d '{remote_path}' && echo dir || test -f '{remote_path}' && echo file || echo none", + timeout=5, + ) + remote_type = result.output.strip() + if remote_type == "none": + continue + + if remote_type == "file": + local_path.parent.mkdir(parents=True, exist_ok=True) + rp = remote_path # bind for closure + + def _do_get(rpath=rp): + buf = io.BytesIO() + ssh_sandbox._sftp.getfo(rpath, buf) + 
buf.seek(0) + return buf.read() + + data = await asyncio.to_thread(_do_get) + local_path.write_bytes(data) + transferred += 1 + + elif remote_type == "dir": + result = await ssh_sandbox.execute(f"find '{remote_path}' -type f", timeout=10) + remote_files = [ln.strip() for ln in result.output.strip().split("\n") if ln.strip()] + for rf in remote_files: + rel = rf.replace(remote_path + "/", "", 1) + dst = local_path / rel + # Path traversal check on each individual file + _, inner_err = _validate_sync_path(local_ws, str(Path(rel_path) / rel)) + if inner_err: + continue + dst.parent.mkdir(parents=True, exist_ok=True) + + # Bind rf in default arg to avoid closure-in-loop bug + def _do_get_file(rpath=rf): + buf = io.BytesIO() + ssh_sandbox._sftp.getfo(rpath, buf) + buf.seek(0) + return buf.read() + + data = await asyncio.to_thread(_do_get_file) + dst.write_bytes(data) + transferred += 1 + + return f"Synced {transferred} item(s) from {node.name}", True + except Exception as e: + return f"Sync failed: {str(e)}", False + finally: + await ssh_sandbox.destroy() + + elif node.type == "modal": + return "File sync not supported for Modal nodes (ephemeral)", False + + return "Unsupported node type", False + + +def create_compute_tools() -> list[ToolSpec]: + """Create agent tools for compute node management.""" + return [ + ToolSpec( + name="compute_list", + description="List all configured compute nodes with their capabilities and health status.", + parameters={"type": "object", "properties": {}}, + handler=_handle_list, + ), + ToolSpec( + name="compute_probe", + description="Probe a compute node to discover its capabilities (CPU, GPU, RAM, installed packages).", + parameters={ + "type": "object", + "properties": { + "node_name": { + "type": "string", + "description": "Name of the compute node to probe", + }, + }, + "required": ["node_name"], + }, + handler=_handle_probe, + ), + ToolSpec( + name="compute_select", + description="Switch the active compute node for this 
conversation. Use this before running tasks that need specific hardware.", + parameters={ + "type": "object", + "properties": { + "node_name": { + "type": "string", + "description": "Name of the compute node to activate", + }, + }, + "required": ["node_name"], + }, + handler=_handle_select, + ), + ToolSpec( + name="compute_plan", + description="Recommend the best compute node for a given task based on requirements.", + parameters={ + "type": "object", + "properties": { + "task": { + "type": "string", + "description": "Description of the task (e.g., 'Train a ResNet-50 with mixed precision')", + }, + "requirements": { + "type": "object", + "description": "Hardware requirements", + "properties": { + "gpu": {"type": "boolean", "description": "GPU required"}, + "min_vram_gb": {"type": "number", "description": "Minimum GPU VRAM in GB"}, + "min_ram_gb": {"type": "number", "description": "Minimum RAM in GB"}, + "min_disk_gb": {"type": "number", "description": "Minimum free disk in GB"}, + }, + }, + }, + "required": ["task"], + }, + handler=_handle_plan, + ), + ToolSpec( + name="compute_sync_up", + description="Sync files from local workspace to a remote compute node. Use before running code that needs data on the remote.", + parameters={ + "type": "object", + "properties": { + "paths": { + "type": "array", + "items": {"type": "string"}, + "description": "Relative paths to sync (e.g., ['data/', 'code/train.py'])", + }, + "node_name": { + "type": "string", + "description": "Name of the target compute node", + }, + }, + "required": ["paths", "node_name"], + }, + handler=_handle_sync_up, + ), + ToolSpec( + name="compute_sync_down", + description="Sync files from a remote compute node to local workspace. 
Use after training to download models, logs, and results.", + parameters={ + "type": "object", + "properties": { + "paths": { + "type": "array", + "items": {"type": "string"}, + "description": "Relative paths to sync (e.g., ['models/', 'outputs/'])", + }, + "node_name": { + "type": "string", + "description": "Name of the source compute node", + }, + }, + "required": ["paths", "node_name"], + }, + handler=_handle_sync_down, + ), + ] diff --git a/backend/openmlr/tools/local.py b/backend/openmlr/tools/local.py index ea4a8a9..73b016d 100644 --- a/backend/openmlr/tools/local.py +++ b/backend/openmlr/tools/local.py @@ -1,6 +1,9 @@ """Local tools — bash (via Docker), read, write, edit. -bash commands run inside a Docker container for isolation. +bash commands run inside a Docker container for isolation when running locally. +When running inside a container (Docker Compose), commands run directly since +the container itself provides isolation. + read/write/edit operate on the host filesystem (for project files). """ @@ -24,6 +27,32 @@ WORKSPACE_ROOT = os.environ.get("OPENMLR_WORKSPACE_ROOT", "") +def _running_in_container() -> bool: + """Detect if we're running inside a Docker container. + + This is useful for determining whether to use Docker-in-Docker or direct execution. + When running in a container, the container already provides isolation. + """ + # Check for common container indicators + if os.path.exists("/.dockerenv"): + return True + + # Check cgroup for docker/container runtime + try: + with open("/proc/1/cgroup") as f: + content = f.read() + if "docker" in content or "containerd" in content or "kubepods" in content: + return True + except (FileNotFoundError, PermissionError): + pass + + # Check for container-related environment variables + if os.environ.get("KUBERNETES_SERVICE_HOST"): + return True + + return False + + def _validate_path(path: Path) -> tuple[Path, str | None]: """Validate path is within allowed workspace. 
Returns (resolved_path, error_or_none).""" try: @@ -59,10 +88,9 @@ def create_local_tools() -> list[ToolSpec]: ToolSpec( name="bash", description=( - "Execute a shell command inside a Docker container for safe isolation. " - "The container has access to a /workspace volume mapped to the project directory. " - "Use for running scripts, installing packages, training models, etc. " - "If Docker is unavailable, falls back to direct execution with a warning." + "Execute a shell command. When running locally, uses Docker for isolation. " + "When running in a containerized deployment, commands run directly in the isolated environment. " + "Use for running scripts, installing packages, training models, etc." ), parameters={ "type": "object", @@ -138,8 +166,15 @@ async def _docker_available() -> bool: async def _handle_bash(command: str, timeout: int = 120, workdir: str = None, **kwargs) -> tuple[str, bool]: timeout = min(int(timeout), 3600) - cwd = os.getcwd() + cwd = workdir or os.getcwd() + + # If we're already running inside a container, execute directly + # The container itself provides isolation, so no need for Docker-in-Docker + if _running_in_container(): + logger.debug(f"Running in container, executing directly: {command[:100]}") + return await _direct_exec(command, timeout, cwd) + # When running on host, try Docker for isolation if await _docker_available(): return await _docker_exec(command, timeout, cwd, workdir) else: diff --git a/backend/openmlr/tools/registry.py b/backend/openmlr/tools/registry.py index cbd3184..67f9997 100644 --- a/backend/openmlr/tools/registry.py +++ b/backend/openmlr/tools/registry.py @@ -17,6 +17,8 @@ "github_search", "github_read_file", "github_read_repo", "github_find_examples", "github_search_repos", "github_get_readme", "github_list_repos", + # Compute planning (read-only / advisory) + "compute_list", "compute_plan", "compute_probe", }, "blocked_message": ( "Tool '{tool}' is not available in PLAN mode. 
" @@ -43,6 +45,13 @@ def __init__(self): self._mcp_client = None self._blocklist: set[str] = set() self._current_mode: str = "general" + self._user_id: int | None = None + self._db = None + + def set_context(self, user_id: int | None = None, db=None) -> None: + """Set per-request context (user_id, db) for tools that need them.""" + self._user_id = user_id + self._db = db def register(self, spec: ToolSpec) -> None: """Register a tool.""" @@ -157,6 +166,11 @@ async def call_tool( # Also pass tool_call_id if the handler accepts it if "tool_call_id" in sig.parameters and "tool_call_id" not in kwargs: kwargs["tool_call_id"] = kwargs.pop("id", "") + # Inject user_id and db for tools that need them (compute tools) + if "user_id" in sig.parameters and "user_id" not in kwargs: + kwargs["user_id"] = self._user_id + if "db" in sig.parameters and "db" not in kwargs: + kwargs["db"] = self._db try: return await tool.handler(**kwargs) if kwargs else await tool.handler(**arguments) except TypeError as e: @@ -237,6 +251,10 @@ def create_tool_router(sandbox_manager=None) -> ToolRouter: router.register(create_writing_tool()) router.register(create_ask_user_tool()) + # Register compute tools + from .compute_tools import create_compute_tools + router.register_many(create_compute_tools()) + # Register sandbox tools if manager provided if sandbox_manager: from .sandbox_tools import create_sandbox_tools diff --git a/backend/openmlr/tools/sandbox_tools.py b/backend/openmlr/tools/sandbox_tools.py index c73322a..2792927 100644 --- a/backend/openmlr/tools/sandbox_tools.py +++ b/backend/openmlr/tools/sandbox_tools.py @@ -1,6 +1,8 @@ """Sandbox tools — expose execution environments to the agent.""" -from ..agent.types import ToolSpec +import asyncio + +from ..agent.types import AgentEvent, ToolSpec def create_sandbox_tools(sandbox_manager) -> list[ToolSpec]: @@ -56,13 +58,14 @@ def create_sandbox_tools(sandbox_manager) -> list[ToolSpec]: name="sandbox_exec", description=( "Execute a command in 
the active sandbox. If no sandbox is active, " - "falls back to local execution." + "falls back to local execution. Use stream=true for long-running commands." ), parameters={ "type": "object", "properties": { "command": {"type": "string", "description": "Shell command to execute"}, - "timeout": {"type": "integer", "description": "Timeout in seconds (default 120)"}, + "timeout": {"type": "integer", "description": "Timeout in seconds (default 120, max 3600)"}, + "stream": {"type": "boolean", "description": "Stream output in real-time for long-running commands (default false)"}, }, "required": ["command"], }, @@ -102,17 +105,24 @@ async def _handle_probe(sandbox_manager, session=None, **kwargs) -> tuple[str, b return "No active sandbox. Using local environment.\n" + await _local_probe(), True try: - info = await sandbox.probe_environment() + caps = await sandbox.probe_environment() lines = [ f"## Sandbox Environment ({sandbox_manager.active_type})\n", - f"OS: {info.os}", - f"Python: {info.python_version}", - f"GPU: {'Yes — ' + info.gpu_info if info.gpu_available else 'No'}", - f"Disk: {info.available_disk_gb:.1f} GB free", - f"RAM: {info.available_ram_gb:.1f} GB free", + f"Platform: {caps.platform}", + f"CPU: {caps.cpu_cores} cores ({caps.cpu_arch})", + f"Python: {', '.join(caps.python_versions) if caps.python_versions else 'unknown'}", ] - if info.installed_packages: - lines.append(f"\nKey packages: {', '.join(info.installed_packages[:20])}") + if caps.gpu_available and caps.gpu_info: + for gpu in caps.gpu_info: + lines.append(f"GPU: {gpu.model} ({gpu.vram_gb:.0f} GB VRAM)") + elif caps.gpu_available: + lines.append("GPU: available") + else: + lines.append("GPU: not available") + lines.append(f"Disk: {caps.available_disk_gb:.1f} GB free") + lines.append(f"RAM: {caps.available_ram_gb:.1f} GB free") + if caps.installed_packages: + lines.append(f"\nKey packages: {', '.join(caps.installed_packages[:20])}") return "\n".join(lines), True except Exception as e: return 
f"Probe failed: {str(e)}", False @@ -156,7 +166,7 @@ async def _handle_create(sandbox_manager, provider: str, config: dict = None, se return f"Failed to create sandbox: {str(e)}", False -async def _handle_exec(sandbox_manager, command: str, timeout: int = 120, session=None, **kwargs) -> tuple[str, bool]: +async def _handle_exec(sandbox_manager, command: str, timeout: int = 120, stream: bool = False, session=None, **kwargs) -> tuple[str, bool]: sandbox = sandbox_manager.get_active() if not sandbox: # Fall back to local execution @@ -164,8 +174,25 @@ async def _handle_exec(sandbox_manager, command: str, timeout: int = 120, sessio return await _handle_bash(command=command, timeout=timeout) try: - result = await sandbox.execute(command, timeout=timeout) - return result.output, result.success + if stream and session: + # Stream output via tool_log events + # on_chunk may be called from a worker thread (SSH), so use + # call_soon_threadsafe to schedule the coroutine on the event loop. + loop = asyncio.get_running_loop() + + def on_chunk(text: str, is_stderr: bool): + prefix = "STDERR: " if is_stderr else "" + event = AgentEvent( + event_type="tool_log", + data={"message": f"{prefix}{text.rstrip()}"}, + ) + loop.call_soon_threadsafe(asyncio.ensure_future, session.emit(event)) + + result = await sandbox.execute_stream(command, timeout=timeout, on_chunk=on_chunk) + return result.output, result.success + else: + result = await sandbox.execute(command, timeout=timeout) + return result.output, result.success except Exception as e: return f"Execution error: {str(e)}", False diff --git a/backend/openmlr/tools/writing.py b/backend/openmlr/tools/writing.py index 827fccc..2956541 100644 --- a/backend/openmlr/tools/writing.py +++ b/backend/openmlr/tools/writing.py @@ -52,10 +52,10 @@ async def _load_project(conv_id: int) -> dict | None: async def _get_author_info(db, conv_id: int) -> dict | None: """Fetch author settings for the conversation's user.""" # Get conversation to find 
user_id - conv = await ops.get_conversation(db, conv_id) + conv = await ops.get_conversation_by_id(db, conv_id) if not conv or not conv.user_id: return None - + # Fetch author-related settings author_info = {} for key in ["author_name", "author_email", "author_affiliation", "author_orcid"]: @@ -63,7 +63,7 @@ async def _get_author_info(db, conv_id: int) -> dict | None: if setting: field = key.replace("author_", "") author_info[field] = setting - + return author_info if author_info else None @@ -320,21 +320,21 @@ async def _get_draft(conv_id: int) -> tuple[str, bool]: proj = _get_project(conv_id) if not proj: return "No paper project exists.", False - + # Fetch author info author_info = None if conv_id: session_factory = _get_session_factory() async with session_factory() as db: author_info = await _get_author_info(db, conv_id) - + return _get_draft_from_proj(proj, author_info) def _get_draft_from_proj(proj: dict, author_info: dict | None = None) -> tuple[str, bool]: """Generate the full markdown draft from a project dict.""" lines = [f"# {proj['title']}\n"] - + # Add author information block if available if author_info: author_lines = [] @@ -346,7 +346,7 @@ def _get_draft_from_proj(proj: dict, author_info: dict | None = None) -> tuple[s author_lines.append(f"Email: {author_info['email']}") if author_info.get("orcid"): author_lines.append(f"ORCID: [{author_info['orcid']}](https://orcid.org/{author_info['orcid']})") - + if author_lines: lines.append("\n".join(author_lines)) lines.append("\n---\n") diff --git a/backend/tests/test_compute.py b/backend/tests/test_compute.py new file mode 100644 index 0000000..d82d46d --- /dev/null +++ b/backend/tests/test_compute.py @@ -0,0 +1,693 @@ +"""Tests for the compute node ecosystem — KeyManager, WorkspaceManager, +ComputeCapabilities, SSHConnectionPool, compute tools, and routes.""" + +from unittest.mock import MagicMock + +import pytest + +pytestmark = pytest.mark.asyncio + +from openmlr.compute.capabilities import 
ComputeCapabilities, GPUInfo +from openmlr.compute.manager import ComputeManager +from openmlr.compute.workspace import WorkspaceManager +from openmlr.keys.manager import KeyManager +from openmlr.sandbox.ssh import SSHConnectionPool +from openmlr.tools.compute_tools import _validate_sync_path +from openmlr.tools.registry import MODE_TOOL_RESTRICTIONS, ToolRouter + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def tmp_keys_dir(tmp_path): + keys_dir = tmp_path / ".keys" + keys_dir.mkdir() + return keys_dir + + +@pytest.fixture +def key_manager(tmp_keys_dir): + return KeyManager(keys_dir=tmp_keys_dir) + + +@pytest.fixture +def tmp_workspace_dir(tmp_path): + return tmp_path / ".openmlr" + + +@pytest.fixture +def workspace_manager(tmp_workspace_dir): + return WorkspaceManager(base_dir=tmp_workspace_dir) + + +# --------------------------------------------------------------------------- +# KeyManager +# --------------------------------------------------------------------------- + +class TestKeyManager: + def test_init_creates_dir(self, tmp_keys_dir, key_manager): + assert tmp_keys_dir.exists() + assert oct(tmp_keys_dir.stat().st_mode)[-3:] == "700" + + def test_list_keys_empty(self, key_manager): + assert key_manager.list_keys() == [] + + def test_generate_key_pair_ed25519(self, key_manager): + priv, pub = key_manager.generate_key_pair("id_test_ed", "ed25519", "test@host") + assert priv.exists() + assert pub.exists() + assert "id_test_ed" in priv.name + # Private key should be 0o600 + mode = oct(priv.stat().st_mode)[-3:] + assert mode == "600" + + def test_generate_key_pair_rsa(self, key_manager): + priv, pub = key_manager.generate_key_pair("id_test_rsa", "rsa", "test@host") + assert priv.exists() + assert pub.exists() + + def test_generate_unsupported_algorithm(self, key_manager): + with pytest.raises(ValueError, 
match="Unsupported algorithm"): + key_manager.generate_key_pair("id_bad", "dsa", "test") + + def test_key_exists(self, key_manager): + key_manager.generate_key_pair("id_exist", "ed25519") + assert key_manager.key_exists("id_exist") is True + assert key_manager.key_exists("id_nope") is False + + def test_list_keys_after_generate(self, key_manager): + key_manager.generate_key_pair("id_list_test", "ed25519") + keys = key_manager.list_keys() + assert len(keys) == 1 + assert keys[0]["filename"] == "id_list_test" + assert keys[0]["has_public"] is True + + def test_delete_key(self, key_manager): + key_manager.generate_key_pair("id_del", "ed25519") + assert key_manager.key_exists("id_del") + result = key_manager.delete_key("id_del") + assert result is True + assert not key_manager.key_exists("id_del") + + def test_delete_nonexistent(self, key_manager): + result = key_manager.delete_key("nope") + assert result is False + + def test_write_and_read_key(self, key_manager): + key_manager.write_key("id_manual", "-----BEGIN FAKE KEY-----\ndata\n-----END FAKE KEY-----\n") + content = key_manager.read_key("id_manual") + assert "FAKE KEY" in content + + def test_read_nonexistent_key(self, key_manager): + with pytest.raises(FileNotFoundError): + key_manager.read_key("nope") + + def test_validate_key_ed25519(self, key_manager): + key_manager.generate_key_pair("id_val", "ed25519", "comment") + private = key_manager.read_key("id_val") + meta = key_manager.validate_key(private) + assert meta["algorithm"] == "ssh-ed25519" + assert meta["fingerprint"].startswith("SHA256:") + assert len(meta["fingerprint"]) > 10 + assert meta["public_key"].startswith("ssh-ed25519") + + def test_validate_key_rsa(self, key_manager): + key_manager.generate_key_pair("id_val_rsa", "rsa") + private = key_manager.read_key("id_val_rsa") + meta = key_manager.validate_key(private) + assert meta["algorithm"] == "ssh-rsa" + assert meta["fingerprint"].startswith("SHA256:") + + def test_validate_invalid_key(self, 
key_manager): + with pytest.raises(ValueError, match="Invalid private key"): + key_manager.validate_key("not a key") + + def test_get_key_path(self, key_manager, tmp_keys_dir): + path = key_manager.get_key_path("id_some") + assert path == tmp_keys_dir / "id_some" + + +# --------------------------------------------------------------------------- +# WorkspaceManager +# --------------------------------------------------------------------------- + +class TestWorkspaceManager: + def test_create_workspace(self, workspace_manager): + path = workspace_manager.create_workspace("test-uuid-123") + assert path.exists() + assert (path / "data").exists() + assert (path / "models").exists() + assert (path / "code").exists() + assert (path / "outputs").exists() + assert (path / ".openmlr-meta").exists() + + def test_get_workspace_path(self, workspace_manager): + path = workspace_manager.get_workspace_path("abc") + assert "workspace-abc" in str(path) + + def test_workspace_exists(self, workspace_manager): + assert workspace_manager.workspace_exists("nope") is False + workspace_manager.create_workspace("nope") + assert workspace_manager.workspace_exists("nope") is True + + def test_delete_workspace_with_archive(self, workspace_manager): + workspace_manager.create_workspace("del-test") + ws_path = workspace_manager.get_workspace_path("del-test") + (ws_path / "data" / "file.txt").write_text("hello") + + result = workspace_manager.delete_workspace("del-test", archive=True) + assert result is True + assert not ws_path.exists() + # Check archive was created + archives = list(workspace_manager.archive_dir.glob("*.tar.gz")) + assert len(archives) == 1 + + def test_delete_workspace_without_archive(self, workspace_manager): + workspace_manager.create_workspace("del-no-archive") + result = workspace_manager.delete_workspace("del-no-archive", archive=False) + assert result is True + archives = list(workspace_manager.archive_dir.glob("*.tar.gz")) + assert len(archives) == 0 + + def 
test_delete_nonexistent(self, workspace_manager): + result = workspace_manager.delete_workspace("nonexistent") + assert result is False + + def test_get_workspace_size(self, workspace_manager): + workspace_manager.create_workspace("size-test") + path = workspace_manager.get_workspace_path("size-test") + (path / "data" / "big.bin").write_bytes(b"x" * 1024) + size = workspace_manager.get_workspace_size("size-test") + assert size >= 1024 + + def test_list_workspaces(self, workspace_manager): + workspace_manager.create_workspace("ws-a") + workspace_manager.create_workspace("ws-b") + ws_list = workspace_manager.list_workspaces() + uuids = [w["uuid"] for w in ws_list] + assert "ws-a" in uuids + assert "ws-b" in uuids + + def test_cleanup_archives(self, workspace_manager): + # Create and archive 3 workspaces + for i in range(3): + workspace_manager.create_workspace(f"cleanup-{i}") + workspace_manager.archive_workspace(f"cleanup-{i}") + + result = workspace_manager.cleanup_archives(max_age_days=0, max_count=1) + assert result["deleted"] >= 2 + remaining = list(workspace_manager.archive_dir.glob("*.tar.gz")) + assert len(remaining) <= 1 + + def test_cleanup_workspaces_orphaned(self, workspace_manager): + workspace_manager.create_workspace("keep") + workspace_manager.create_workspace("orphan") + result = workspace_manager.cleanup_workspaces( + conversation_uuids=["keep"], + archive=False, + ) + assert result["deleted"] == 1 + assert workspace_manager.workspace_exists("keep") + assert not workspace_manager.workspace_exists("orphan") + + +# --------------------------------------------------------------------------- +# ComputeCapabilities +# --------------------------------------------------------------------------- + +class TestComputeCapabilities: + def test_defaults(self): + caps = ComputeCapabilities() + assert caps.platform == "unknown" + assert caps.cpu_cores == 0 + assert caps.gpu_available is False + assert caps.gpu_info == [] + + def test_to_dict(self): + caps = 
ComputeCapabilities( + cpu_cores=8, + gpu_available=True, + gpu_info=[GPUInfo(model="A100", vram_gb=80.0, cuda_version="12.4")], + ) + d = caps.to_dict() + assert d["cpu_cores"] == 8 + assert d["gpu_available"] is True + assert len(d["gpu_info"]) == 1 + assert d["gpu_info"][0]["model"] == "A100" + + def test_from_dict(self): + d = { + "platform": "Linux", + "cpu_cores": 4, + "gpu_available": True, + "gpu_info": [{"model": "RTX 4090", "vram_gb": 24, "cuda_version": "12.4", "driver_version": "545"}], + } + caps = ComputeCapabilities.from_dict(d) + assert caps.platform == "Linux" + assert caps.cpu_cores == 4 + assert len(caps.gpu_info) == 1 + assert caps.gpu_info[0].model == "RTX 4090" + + def test_roundtrip(self): + original = ComputeCapabilities( + platform="test", + cpu_cores=16, + available_ram_gb=32.5, + gpu_available=True, + gpu_count=2, + gpu_info=[ + GPUInfo(model="A100", vram_gb=80), + GPUInfo(model="A100", vram_gb=80), + ], + python_versions=["3.12", "3.11"], + docker_available=True, + ) + d = original.to_dict() + restored = ComputeCapabilities.from_dict(d) + assert restored.platform == "test" + assert restored.cpu_cores == 16 + assert restored.available_ram_gb == 32.5 + assert len(restored.gpu_info) == 2 + assert restored.docker_available is True + + +# --------------------------------------------------------------------------- +# ComputeManager (validation) +# --------------------------------------------------------------------------- + +class TestComputeManager: + def test_validate_ssh_missing_host(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("ssh", {"username": "user"}) + assert ok is False + assert "host" in err + + def test_validate_ssh_missing_username(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("ssh", {"host": "example.com"}) + assert ok is False + assert "username" in err + + def test_validate_ssh_ok(self, key_manager): + cm = ComputeManager(key_manager) + 
ok, err = cm.validate_node_config("ssh", {"host": "example.com", "username": "user"}) + assert ok is True + + def test_validate_ssh_missing_key(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("ssh", { + "host": "x", "username": "u", "key_filename": "nonexistent", + }) + assert ok is False + assert "not found" in err + + def test_validate_local_ok(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("local", {}) + assert ok is True + + def test_validate_local_file_not_dir(self, key_manager, tmp_path): + f = tmp_path / "not_a_dir" + f.write_text("data") + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("local", {"workdir": str(f)}) + assert ok is False + + def test_validate_modal_ok(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("modal", {}) + assert ok is True + + def test_validate_unknown_type(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("kubernetes", {}) + assert ok is False + assert "Unknown" in err + + +# --------------------------------------------------------------------------- +# SSHConnectionPool +# --------------------------------------------------------------------------- + +class TestSSHConnectionPool: + def test_singleton(self): + pool1 = SSHConnectionPool.get_pool() + pool2 = SSHConnectionPool.get_pool() + assert pool1 is pool2 + + def test_make_key(self): + assert SSHConnectionPool._make_key("host", 22, "user") == "user@host:22" + + def test_get_empty(self): + pool = SSHConnectionPool(ttl_seconds=300) + assert pool.get("host", 22, "user") is None + + def test_put_and_get(self): + pool = SSHConnectionPool(ttl_seconds=300) + # Mock a client with active transport + mock_client = MagicMock() + mock_transport = MagicMock() + mock_transport.is_active.return_value = True + mock_client.get_transport.return_value = mock_transport + mock_sftp = MagicMock() + + 
pool.put("host", 22, "user", mock_client, mock_sftp, "fp123") + result = pool.get("host", 22, "user") + assert result is not None + client, sftp, fp = result + assert client is mock_client + assert sftp is mock_sftp + assert fp == "fp123" + + def test_get_dead_connection(self): + pool = SSHConnectionPool(ttl_seconds=300) + mock_client = MagicMock() + mock_transport = MagicMock() + mock_transport.is_active.return_value = False + mock_client.get_transport.return_value = mock_transport + mock_sftp = MagicMock() + + pool.put("host", 22, "user", mock_client, mock_sftp, "fp") + result = pool.get("host", 22, "user") + assert result is None + + def test_cleanup_idle(self): + pool = SSHConnectionPool(ttl_seconds=0) # immediate expiry + mock_client = MagicMock() + mock_sftp = MagicMock() + pool.put("host", 22, "user", mock_client, mock_sftp, "fp") + pool._last_used["user@host:22"] = 0 # force stale + pool.cleanup_idle() + assert pool.get("host", 22, "user") is None + mock_sftp.close.assert_called_once() + mock_client.close.assert_called_once() + + def test_remove(self): + pool = SSHConnectionPool(ttl_seconds=300) + mock_client = MagicMock() + mock_sftp = MagicMock() + pool.put("host", 22, "user", mock_client, mock_sftp, "fp") + pool.remove("host", 22, "user") + assert pool.get("host", 22, "user") is None + + +# --------------------------------------------------------------------------- +# Path traversal validation +# --------------------------------------------------------------------------- + +class TestPathTraversal: + def test_valid_relative_path(self, tmp_path): + ws = tmp_path / "workspace" + ws.mkdir() + path, err = _validate_sync_path(ws, "data/file.txt") + assert err is None + assert str(ws) in str(path) + + def test_traversal_blocked(self, tmp_path): + ws = tmp_path / "workspace" + ws.mkdir() + path, err = _validate_sync_path(ws, "../../etc/passwd") + assert err is not None + assert "escapes" in err + + def test_absolute_path_blocked(self, tmp_path): + ws = tmp_path 
/ "workspace" + ws.mkdir() + path, err = _validate_sync_path(ws, "/etc/passwd") + assert err is not None + assert "escapes" in err + + def test_nested_valid_path(self, tmp_path): + ws = tmp_path / "workspace" + ws.mkdir() + path, err = _validate_sync_path(ws, "data/subdir/deep/file.csv") + assert err is None + + +# --------------------------------------------------------------------------- +# ToolRouter compute context injection +# --------------------------------------------------------------------------- + +class TestToolRouterContext: + def test_set_context(self): + router = ToolRouter() + router.set_context(user_id=42, db="fake_db") + assert router._user_id == 42 + assert router._db == "fake_db" + + async def test_context_injected_into_handler(self): + router = ToolRouter() + router.set_context(user_id=42, db="fake_db") + + async def handler(user_id: int = None, db=None, arg: str = "") -> tuple[str, bool]: + return f"uid={user_id},db={db},arg={arg}", True + + from openmlr.agent.types import ToolSpec + tool = ToolSpec( + name="ctx_test", description="test", parameters={"type": "object", "properties": {}}, + handler=handler, + ) + router.register(tool) + result, ok = await router.call_tool("ctx_test", {"arg": "hello"}) + assert ok is True + assert "uid=42" in result + assert "db=fake_db" in result + assert "arg=hello" in result + + +# --------------------------------------------------------------------------- +# Plan mode allows compute tools +# --------------------------------------------------------------------------- + +class TestPlanModeComputeTools: + def test_compute_list_allowed(self): + assert "compute_list" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + + def test_compute_plan_allowed(self): + assert "compute_plan" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + + def test_compute_probe_allowed(self): + assert "compute_probe" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + + def test_compute_select_not_in_plan(self): + assert "compute_select" not in 
MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + + +# --------------------------------------------------------------------------- +# Config redaction (routes/compute.py) +# --------------------------------------------------------------------------- + +class TestConfigRedaction: + def test_redact_password(self): + from openmlr.routes.compute import _redact_config + config = {"host": "example.com", "password": "secret123", "username": "user"} + redacted = _redact_config(config) + assert redacted["host"] == "example.com" + assert redacted["password"] == "***" + assert redacted["username"] == "user" + + def test_redact_empty_config(self): + from openmlr.routes.compute import _redact_config + assert _redact_config({}) == {} + assert _redact_config(None) == {} + + def test_redact_no_sensitive_fields(self): + from openmlr.routes.compute import _redact_config + config = {"host": "x", "port": 22} + assert _redact_config(config) == config + + +# --------------------------------------------------------------------------- +# Routes (keys + compute) — integration via httpx +# --------------------------------------------------------------------------- + +class TestKeyRoutes: + async def test_list_keys_empty(self, auth_client): + resp = await auth_client.get("/api/keys") + assert resp.status_code == 200 + assert resp.json()["keys"] == [] + + async def test_generate_key(self, auth_client): + resp = await auth_client.post("/api/keys", json={ + "action": "generate", + "filename": "id_test_route", + "algorithm": "ed25519", + "comment": "test", + }) + assert resp.status_code == 200 + data = resp.json()["key"] + assert data["filename"] == "id_test_route" + assert data["algorithm"] == "ssh-ed25519" + assert data["fingerprint"].startswith("SHA256:") + + async def test_generate_duplicate(self, auth_client): + await auth_client.post("/api/keys", json={ + "action": "generate", "filename": "id_dup", "algorithm": "ed25519", + }) + resp = await auth_client.post("/api/keys", json={ + "action": 
"generate", "filename": "id_dup", "algorithm": "ed25519", + }) + assert resp.status_code == 409 + + async def test_delete_key(self, auth_client): + await auth_client.post("/api/keys", json={ + "action": "generate", "filename": "id_to_del", "algorithm": "ed25519", + }) + resp = await auth_client.delete("/api/keys/id_to_del") + assert resp.status_code == 200 + assert resp.json()["ok"] is True + + async def test_delete_nonexistent_key(self, auth_client): + resp = await auth_client.delete("/api/keys/id_nope") + assert resp.status_code == 404 + + async def test_create_key_missing_filename(self, auth_client): + resp = await auth_client.post("/api/keys", json={"action": "generate"}) + assert resp.status_code == 400 + + async def test_create_key_invalid_action(self, auth_client): + resp = await auth_client.post("/api/keys", json={ + "action": "nope", "filename": "id_x", + }) + assert resp.status_code == 400 + + async def test_unauthenticated_keys(self, client): + resp = await client.get("/api/keys") + assert resp.status_code == 401 + + +class TestComputeNodeRoutes: + async def test_list_empty(self, auth_client): + resp = await auth_client.get("/api/compute/nodes") + assert resp.status_code == 200 + assert resp.json()["nodes"] == [] + + async def test_create_local_node(self, auth_client): + resp = await auth_client.post("/api/compute/nodes", json={ + "name": "My Laptop", + "type": "local", + "config": {}, + }) + assert resp.status_code == 200 + node = resp.json()["node"] + assert node["name"] == "My Laptop" + assert node["type"] == "local" + assert node["health_status"] == "unknown" + + async def test_create_duplicate_name(self, auth_client): + await auth_client.post("/api/compute/nodes", json={ + "name": "Dup", "type": "local", "config": {}, + }) + resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Dup", "type": "local", "config": {}, + }) + assert resp.status_code == 409 + + async def test_create_invalid_type(self, auth_client): + resp = await 
auth_client.post("/api/compute/nodes", json={ + "name": "Bad", "type": "kubernetes", "config": {}, + }) + assert resp.status_code == 400 + + async def test_get_node(self, auth_client): + create_resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Get Test", "type": "local", "config": {}, + }) + node_id = create_resp.json()["node"]["id"] + resp = await auth_client.get(f"/api/compute/nodes/{node_id}") + assert resp.status_code == 200 + assert resp.json()["node"]["name"] == "Get Test" + + async def test_update_node(self, auth_client): + create_resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Update Test", "type": "local", "config": {}, + }) + node_id = create_resp.json()["node"]["id"] + resp = await auth_client.put(f"/api/compute/nodes/{node_id}", json={ + "name": "Updated Name", + }) + assert resp.status_code == 200 + assert resp.json()["node"]["name"] == "Updated Name" + + async def test_delete_node(self, auth_client): + create_resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Delete Test", "type": "local", "config": {}, + }) + node_id = create_resp.json()["node"]["id"] + resp = await auth_client.delete(f"/api/compute/nodes/{node_id}") + assert resp.status_code == 200 + assert resp.json()["ok"] is True + + async def test_set_default(self, auth_client): + create_resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Default Test", "type": "local", "config": {}, + }) + node_id = create_resp.json()["node"]["id"] + resp = await auth_client.post(f"/api/compute/nodes/{node_id}/set-default") + assert resp.status_code == 200 + + # Verify it's now default + get_resp = await auth_client.get(f"/api/compute/nodes/{node_id}") + assert get_resp.json()["node"]["is_default"] is True + + async def test_config_redacted_in_response(self, auth_client): + resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Redact Test", + "type": "ssh", + "config": {"host": "x", "username": "u", "password": 
"secret"}, + }) + assert resp.status_code == 200 + node = resp.json()["node"] + assert node["config"]["password"] == "***" + assert node["config"]["host"] == "x" + + async def test_test_local_node(self, auth_client): + create_resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Test Local", "type": "local", "config": {}, + }) + node_id = create_resp.json()["node"]["id"] + resp = await auth_client.post(f"/api/compute/nodes/{node_id}/test") + assert resp.status_code == 200 + # Local test should pass (workspace will be CWD) + assert resp.json()["ok"] is True + + async def test_test_config_endpoint(self, auth_client): + resp = await auth_client.post("/api/compute/test", json={ + "type": "local", + "config": {}, + }) + assert resp.status_code == 200 + assert resp.json()["ok"] is True + + async def test_test_config_invalid_type(self, auth_client): + resp = await auth_client.post("/api/compute/test", json={ + "type": "kubernetes", + "config": {}, + }) + assert resp.status_code == 200 + assert resp.json()["ok"] is False + + async def test_unauthenticated(self, client): + resp = await client.get("/api/compute/nodes") + assert resp.status_code == 401 + + +# --------------------------------------------------------------------------- +# System prompt includes compute_env +# --------------------------------------------------------------------------- + +class TestSystemPromptCompute: + def test_prompt_includes_compute_env(self): + from openmlr.agent.prompts import build_system_prompt + prompt = build_system_prompt( + tool_specs=[], + compute_env="## Active Compute: TestNode (ssh)\n- CPU: 8 cores", + ) + assert "TestNode" in prompt + assert "8 cores" in prompt + + def test_prompt_without_compute_env(self): + from openmlr.agent.prompts import build_system_prompt + prompt = build_system_prompt(tool_specs=[], compute_env="") + assert "Active Compute" not in prompt diff --git a/backend/tests/test_sandbox_types.py b/backend/tests/test_sandbox_types.py index 
fc7553f..e64aeb6 100644 --- a/backend/tests/test_sandbox_types.py +++ b/backend/tests/test_sandbox_types.py @@ -2,34 +2,43 @@ import pytest -from openmlr.sandbox.interface import EnvironmentInfo, ExecutionResult, SandboxInterface +from openmlr.compute.capabilities import ComputeCapabilities, GPUInfo +from openmlr.sandbox.interface import ExecutionResult, SandboxInterface from openmlr.sandbox.local import LocalSandbox -class TestEnvironmentInfo: +class TestComputeCapabilities: def test_defaults(self): - info = EnvironmentInfo() - assert info.os == "unknown" - assert info.python_version == "unknown" - assert info.gpu_available is False - assert info.gpu_info is None - assert info.installed_packages == [] - assert info.available_disk_gb == 0.0 - assert info.available_ram_gb == 0.0 + caps = ComputeCapabilities() + assert caps.platform == "unknown" + assert caps.cpu_cores == 0 + assert caps.gpu_available is False + assert caps.gpu_info == [] + assert caps.installed_packages == [] + assert caps.available_disk_gb == 0.0 + assert caps.available_ram_gb == 0.0 def test_custom_values(self): - info = EnvironmentInfo( - os="Linux", - python_version="3.12.0", + caps = ComputeCapabilities( + platform="Linux 6.5.0", + cpu_cores=8, gpu_available=True, - gpu_info="NVIDIA A100", - installed_packages=["torch", "numpy"], + gpu_info=[GPUInfo(model="NVIDIA A100", vram_gb=80.0)], + installed_packages=["torch==2.3.0", "numpy==1.26.0"], available_disk_gb=50.0, available_ram_gb=32.0, ) - assert info.os == "Linux" - assert info.gpu_available is True - assert "torch" in info.installed_packages + assert caps.platform == "Linux 6.5.0" + assert caps.gpu_available is True + assert len(caps.gpu_info) == 1 + assert caps.gpu_info[0].model == "NVIDIA A100" + + def test_to_dict_roundtrip(self): + caps = ComputeCapabilities(cpu_cores=4, gpu_available=True) + d = caps.to_dict() + caps2 = ComputeCapabilities.from_dict(d) + assert caps2.cpu_cores == 4 + assert caps2.gpu_available is True class 
TestExecutionResult: diff --git a/backend/tests/test_tools_local.py b/backend/tests/test_tools_local.py index e2608d0..a1f007f 100644 --- a/backend/tests/test_tools_local.py +++ b/backend/tests/test_tools_local.py @@ -11,6 +11,7 @@ _handle_edit, _handle_read, _handle_write, + _running_in_container, _validate_path, create_local_tools, ) @@ -222,3 +223,26 @@ def test_allow_direct_exec_default(self, monkeypatch): import openmlr.tools.local allow = openmlr.tools.local.ALLOW_DIRECT_EXEC assert allow is False + + +class TestRunningInContainer: + def test_returns_bool(self): + # Just verify it returns a boolean (actual detection depends on environment) + result = _running_in_container() + assert isinstance(result, bool) + + def test_kubernetes_env_detected(self, monkeypatch): + # Simulate Kubernetes environment + monkeypatch.setenv("KUBERNETES_SERVICE_HOST", "10.0.0.1") + # Reload the function to pick up the env var + assert _running_in_container() is True + + def test_dockerenv_file_not_present_outside_container(self, monkeypatch, tmp_path): + # When /.dockerenv doesn't exist and no other indicators + monkeypatch.delenv("KUBERNETES_SERVICE_HOST", raising=False) + # The function checks for /.dockerenv at system root, so on a host system + # this should return False (unless we're actually in a container) + # This is more of a smoke test + result = _running_in_container() + # Can't assert specific value as it depends on actual runtime environment + assert isinstance(result, bool) diff --git a/backend/tests/test_tools_writing.py b/backend/tests/test_tools_writing.py index 30b18bf..d2d7cee 100644 --- a/backend/tests/test_tools_writing.py +++ b/backend/tests/test_tools_writing.py @@ -103,18 +103,26 @@ async def test_writes_section(self): class TestGetDraft: async def test_no_project(self): + from unittest.mock import AsyncMock, patch + from openmlr.tools.writing import _projects _projects.clear() - result, ok = _get_draft(conv_id=999) + # Mock _get_author_info to avoid 
database calls + with patch('openmlr.tools.writing._get_author_info', new_callable=AsyncMock, return_value=None): + result, ok = await _get_draft(conv_id=999) assert ok is False async def test_generates_draft(self): + from unittest.mock import AsyncMock, patch + from openmlr.tools.writing import _projects _projects.clear() _create_project(conv_id=1, title="The Paper") _set_outline(conv_id=1, outline=[{"id": "intro", "title": "Introduction"}]) _write_section(conv_id=1, section_id="intro", content="This is the intro.") - result, ok = _get_draft(conv_id=1) + # Mock _get_author_info to avoid database calls + with patch('openmlr.tools.writing._get_author_info', new_callable=AsyncMock, return_value=None): + result, ok = await _get_draft(conv_id=1) assert ok is True assert "# The Paper" in result assert "Introduction" in result diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml index d83bd14..d479428 100644 --- a/docker-compose.prod.yml +++ b/docker-compose.prod.yml @@ -66,6 +66,7 @@ services: condition: service_healthy volumes: - ./backend/configs:/app/backend/configs + - ./.keys:/app/.keys security_opt: - no-new-privileges:true restart: unless-stopped @@ -93,6 +94,7 @@ services: condition: service_healthy volumes: - ./backend/configs:/app/backend/configs + - ./.keys:/app/.keys security_opt: - no-new-privileges:true restart: unless-stopped diff --git a/docker-compose.yml b/docker-compose.yml index 572272d..6eaadfd 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -70,6 +70,7 @@ services: - ./backend:/app/backend - backend-venv:/app/backend/.venv - ./frontend/dist:/app/frontend/dist + - ./.keys:/app/.keys # Worker with auto-restart on code changes worker: @@ -103,6 +104,7 @@ services: volumes: - ./backend:/app/backend - backend-venv:/app/backend/.venv + - ./.keys:/app/.keys # Docs site with live reload docs: diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 22490d2..c3cd07f 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ 
-1,6 +1,7 @@ import { useState, useCallback, useEffect, useRef } from 'react'; import { Routes, Route, Navigate, useNavigate, useParams } from 'react-router-dom'; import { Copy, Check } from 'lucide-react'; +import { ComputeSelector } from './components/ComputeSelector'; import { useSSE } from './hooks/useSSE'; import { useJobStatus } from './hooks/useJobStatus'; import { api } from './api'; @@ -20,7 +21,7 @@ import { SettingsPage } from './components/SettingsPage'; import { ProvidersSettings } from './components/settings/ProvidersSettings'; import { AgentSettings } from './components/settings/AgentSettings'; import { McpSettings } from './components/settings/McpSettings'; -import { SandboxSettings } from './components/settings/SandboxSettings'; +import { ComputeSettings } from './components/settings/ComputeSettings'; import { WritingSettings } from './components/settings/WritingSettings'; let msgId = 0; @@ -100,6 +101,8 @@ function ChatUI({ const [viewingReport, setViewingReport] = useState(null); const [inputMode, setInputMode] = useState('plan'); const [inputText, setInputText] = useState(''); + const [computeNodes, setComputeNodes] = useState([]); + const [activeCompute, setActiveCompute] = useState(null); // Ref to always have current conv UUID in SSE callback (avoids stale closure) const currentConvUuidRef = useRef(currentConvUuid); @@ -120,14 +123,33 @@ function ChatUI({ } }, []); + const loadComputeNodes = useCallback(async () => { + try { + const data = await api.getComputeNodes(); + setComputeNodes(data.nodes || []); + } catch { + setComputeNodes([]); + } + }, []); + + const loadActiveCompute = useCallback(async (uuid: string) => { + try { + const data = await api.getConversationCompute(uuid); + setActiveCompute(data.node || null); + } catch { + setActiveCompute(null); + } + }, []); + // Initial load - load conversations and activate the correct one useEffect(() => { const init = async () => { + await loadComputeNodes(); const convs = await 
loadConversations(); // If URL has a conversation UUID, load it directly if (routeUuid) { - switchConv(routeUuid); + await switchConv(routeUuid); return; } @@ -146,7 +168,7 @@ function ChatUI({ const first = convs[0]; setCurrentConvUuid(first.uuid); navigate(`/${first.uuid}`, { replace: true }); - switchConv(first.uuid); + await switchConv(first.uuid); } }; init(); @@ -197,6 +219,9 @@ function ChatUI({ } return { id: nextId(), role: m.role, content: m.content }; }) || []); + + // Load active compute for this conversation + await loadActiveCompute(uuid); } catch { /* */ } }; @@ -214,6 +239,8 @@ function ChatUI({ setMessages([]); setTasks([]); setResources([]); setContextUsage(null); setSearchBudget(null); setApprovalEvent(null); setQuestionsPayload(null); if (conv.model) setModel(conv.model); + // Load default compute for new conversation + await loadActiveCompute(conv.uuid); navigate(`/${conv.uuid}`, { replace: true }); } catch { /* */ } }; @@ -226,11 +253,24 @@ function ChatUI({ if (currentConvUuid === uuid) { setCurrentConvUuid(null); setMessages([]); setTasks([]); setResources([]); setApprovalEvent(null); setQuestionsPayload(null); + setActiveCompute(null); navigate('/', { replace: true }); } } catch { /* */ } }; + const handleComputeChange = useCallback(async (nodeId: number | null) => { + if (!currentConvUuid) return; + try { + if (nodeId === null) { + await api.clearConversationCompute(currentConvUuid); + } else { + await api.setConversationCompute(currentConvUuid, nodeId); + } + await loadActiveCompute(currentConvUuid); + } catch { /* */ } + }, [currentConvUuid, loadActiveCompute]); + // Helper to reload messages from DB for a given conversation const reloadConversationMessages = useCallback(async (uuid: string) => { try { @@ -401,7 +441,7 @@ function ChatUI({ }); break; case 'questions': setCurrentConvStatus('waiting_input'); setQuestionsPayload(data as QuestionsPayload); break; - case 'plan_update': + case 'plan_update': { setTasks(data?.tasks || []); 
setRightPanelOpen(true); // Auto-compact after all tasks are completed @@ -410,6 +450,7 @@ function ChatUI({ setTimeout(() => api.compact().catch(() => {}), 1000); } break; + } case 'resources_update': setResources(data?.resources || []); setRightPanelOpen(true); break; case 'context_usage': if (data) setContextUsage(data as ContextUsage); break; case 'search_budget': if (data) setSearchBudget(data as SearchBudget); break; @@ -512,15 +553,20 @@ function ChatUI({ return (
{/* Header */} -
-
- OpenMLR +
+
+ OpenMLR
-
+
+
@@ -535,32 +581,35 @@ function ChatUI({ onDelete={handleDeleteConversation} /> -
+
{/* Empty state */} {messages.length === 0 && !effectiveProcessing && ( -
+
{/* Large embossed background text */} {/* Foreground prompt */} -

What would you like to research?

+

What would you like to research?

)} - + {approvalEvent && setApprovalEvent(null)} />} {questionsPayload && { setQuestionsPayload(null); @@ -579,6 +628,7 @@ function ChatUI({ />
+ {/* RightPanel is fixed position, doesn't affect flex layout */} setRightPanelOpen((v) => !v)} onViewReport={(r) => setViewingReport(r)} />
@@ -626,7 +676,7 @@ export default function App() { } /> } /> } /> - } /> + } /> } /> diff --git a/frontend/src/__tests__/SettingsPage.test.tsx b/frontend/src/__tests__/SettingsPage.test.tsx index bb89a1e..694fc1c 100644 --- a/frontend/src/__tests__/SettingsPage.test.tsx +++ b/frontend/src/__tests__/SettingsPage.test.tsx @@ -26,7 +26,8 @@ describe('SettingsPage', () => { renderSettings(); expect(screen.getByText('Providers')).toBeInTheDocument(); expect(screen.getByText('Agent')).toBeInTheDocument(); - expect(screen.getByText('Sandbox')).toBeInTheDocument(); + expect(screen.getByText('MCP Servers')).toBeInTheDocument(); + expect(screen.getByText('Compute')).toBeInTheDocument(); expect(screen.getByText('Writing')).toBeInTheDocument(); }); diff --git a/frontend/src/api.ts b/frontend/src/api.ts index 09dc7f3..7d5a3ee 100644 --- a/frontend/src/api.ts +++ b/frontend/src/api.ts @@ -85,6 +85,10 @@ export const api = { getConversation: (uuid: string) => get(`/api/conversations/${uuid}`), deleteConversation: (uuid: string) => del(`/api/conversations/${uuid}`), switchConversation: (uuid: string) => post(`/api/conversations/${uuid}/switch`, {}), + getConversationCompute: (uuid: string) => get(`/api/conversations/${uuid}/compute`), + setConversationCompute: (uuid: string, nodeId: number | null) => + post(`/api/conversations/${uuid}/compute`, { node_id: nodeId }), + clearConversationCompute: (uuid: string) => del(`/api/conversations/${uuid}/compute`), // Settings getSettings: () => get('/api/settings'), @@ -106,4 +110,21 @@ export const api = { getModels: () => get('/api/models'), getStatus: () => get('/api/status'), saveConfig: (config: Record) => post('/api/config', config), + + // SSH Keys + getKeys: () => get('/api/keys'), + createKey: (body: Record) => post('/api/keys', body), + deleteKey: (filename: string) => del(`/api/keys/${filename}`), + + // Compute Nodes + getComputeNodes: () => get('/api/compute/nodes'), + createComputeNode: (body: Record) => 
post('/api/compute/nodes', body), + getComputeNode: (id: number) => get(`/api/compute/nodes/${id}`), + updateComputeNode: (id: number, body: Record) => put(`/api/compute/nodes/${id}`, body), + deleteComputeNode: (id: number) => del(`/api/compute/nodes/${id}`), + testComputeNode: (id: number) => post(`/api/compute/nodes/${id}/test`, {}), + testComputeConfig: (type: string, config: Record) => + post('/api/compute/test', { type, config }), + probeComputeNode: (id: number) => post(`/api/compute/nodes/${id}/probe`, {}), + setDefaultComputeNode: (id: number) => post(`/api/compute/nodes/${id}/set-default`, {}), }; diff --git a/frontend/src/components/ComputeSelector.tsx b/frontend/src/components/ComputeSelector.tsx new file mode 100644 index 0000000..c7b836f --- /dev/null +++ b/frontend/src/components/ComputeSelector.tsx @@ -0,0 +1,105 @@ +import { useState, useEffect, useRef } from 'react'; +import { Cpu, ChevronDown, Monitor } from 'lucide-react'; + +interface ComputeNode { + id: number; + name: string; + type: string; + health_status: string; +} + +interface ComputeSelectorProps { + currentNode: ComputeNode | null; + nodes: ComputeNode[]; + onChange: (nodeId: number | null) => void; +} + +export function ComputeSelector({ currentNode, nodes, onChange }: ComputeSelectorProps) { + const [open, setOpen] = useState(false); + const ref = useRef(null); + + useEffect(() => { + function handleClickOutside(event: MouseEvent) { + if (ref.current && !ref.current.contains(event.target as Node)) { + setOpen(false); + } + } + function handleEsc(event: KeyboardEvent) { + if (event.key === 'Escape') setOpen(false); + } + document.addEventListener('mousedown', handleClickOutside); + document.addEventListener('keydown', handleEsc); + return () => { + document.removeEventListener('mousedown', handleClickOutside); + document.removeEventListener('keydown', handleEsc); + }; + }, []); + + const getStatusColor = (status: string) => { + switch (status) { + case 'online': return 'bg-success'; + 
case 'offline': return 'bg-error'; + case 'degraded': return 'bg-warning'; + default: return 'bg-text-dim'; + } + }; + + return ( +
+ + + {open && ( +
+ {/* Local workspace option */} + + + {nodes.length > 0 &&
} + + {/* Node list */} + {nodes.map((node) => ( + + ))} + + {nodes.length === 0 && ( +
+ No compute nodes configured +
+ )} +
+ )} +
+ ); +} diff --git a/frontend/src/components/InputArea.tsx b/frontend/src/components/InputArea.tsx index 74e72dc..95e8860 100644 --- a/frontend/src/components/InputArea.tsx +++ b/frontend/src/components/InputArea.tsx @@ -1,4 +1,4 @@ -import { useRef, useEffect, useCallback } from 'react'; +import { useRef, useEffect, useCallback, useState } from 'react'; import { ArrowUp, Square } from 'lucide-react'; export type Mode = 'plan' | 'execute'; @@ -52,10 +52,24 @@ export function InputArea({ disabled, showStop, mode, onModeChange, onSend, onSt }, [mode, onModeChange]); const isPlan = mode === 'plan'; + + // Use shorter placeholders on mobile (check window width) + const [isMobile, setIsMobile] = useState(false); + + useEffect(() => { + const checkMobile = () => setIsMobile(window.innerWidth < 640); + checkMobile(); + window.addEventListener('resize', checkMobile); + return () => window.removeEventListener('resize', checkMobile); + }, []); + + const placeholder = isPlan + ? (isMobile ? 'Plan your research...' : 'Plan: ask questions, gather context, create plan...') + : (isMobile ? 'Execute task...' : 'Execute: tell the agent what to do...'); return ( -
-
+
+
{/* Mode toggle button - fixed height to match input */}
-
- {tasks.map((t, i) => ( -
- - {STATUS_ICONS[t.status] || } - - {t.title} -
- ))} - {tasks.length === 0 && ( -
No tasks yet
- )} -
+ + {!tasksCollapsed && ( +
+ {tasks.map((t, i) => ( +
+ + {STATUS_ICONS[t.status] || } + + {t.title} +
+ ))} + {tasks.length === 0 && ( +
No tasks yet
+ )} +
+ )}
{/* Draggable separator */} @@ -277,50 +293,56 @@ export function RightPanel({ tasks, resources, contextUsage, searchBudget, visib /> {/* Resources section */} -
-
+
+
-
- {[...otherResources].sort((a, b) => (a.type === 'plan' ? -1 : b.type === 'plan' ? 1 : 0)).map((r, i) => ( -
- - {RES_ICONS[r.type] || } - -
- {(r.type === 'report' || r.type === 'plan') && r.id ? ( - - ) : ( - {r.title} - )} - {r.url && ( - - - {r.url.length > 35 ? r.url.slice(0, 35) + '...' : r.url} - - )} + + {!resourcesCollapsed && ( +
+ {[...otherResources].sort((a, b) => (a.type === 'plan' ? -1 : b.type === 'plan' ? 1 : 0)).map((r, i) => ( +
+ + {RES_ICONS[r.type] || } + +
+ {(r.type === 'report' || r.type === 'plan') && r.id ? ( + + ) : ( + {r.title} + )} + {r.url && ( + + + {r.url.length > 35 ? r.url.slice(0, 35) + '...' : r.url} + + )} +
-
- ))} - {otherResources.length === 0 && ( -
No resources yet
- )} -
+ ))} + {otherResources.length === 0 && ( +
No resources yet
+ )} +
+ )}
); diff --git a/frontend/src/components/SettingsPage.tsx b/frontend/src/components/SettingsPage.tsx index a2ad7ce..b2f4b71 100644 --- a/frontend/src/components/SettingsPage.tsx +++ b/frontend/src/components/SettingsPage.tsx @@ -1,11 +1,11 @@ import { Link, Outlet, useLocation } from 'react-router-dom'; -import { ArrowLeft, Key, Bot, Server, Box, PenTool } from 'lucide-react'; +import { ArrowLeft, Key, Bot, Server, Cpu, PenTool } from 'lucide-react'; const navItems = [ { path: '/settings/providers', label: 'Providers', icon: Key }, { path: '/settings/agent', label: 'Agent', icon: Bot }, { path: '/settings/mcp', label: 'MCP Servers', icon: Server }, - { path: '/settings/sandbox', label: 'Sandbox', icon: Box }, + { path: '/settings/compute', label: 'Compute', icon: Cpu }, { path: '/settings/writing', label: 'Writing', icon: PenTool }, ]; diff --git a/frontend/src/components/settings/AddKeyModal.tsx b/frontend/src/components/settings/AddKeyModal.tsx new file mode 100644 index 0000000..441786a --- /dev/null +++ b/frontend/src/components/settings/AddKeyModal.tsx @@ -0,0 +1,158 @@ +import { useState, useEffect } from 'react'; +import { X, Upload, KeyRound } from 'lucide-react'; + +interface AddKeyModalProps { + onClose: () => void; + onSubmit: (data: any) => void; +} + +export function AddKeyModal({ onClose, onSubmit }: AddKeyModalProps) { + useEffect(() => { + const handleEsc = (e: KeyboardEvent) => { if (e.key === 'Escape') onClose(); }; + document.addEventListener('keydown', handleEsc); + return () => document.removeEventListener('keydown', handleEsc); + }, [onClose]); + const [mode, setMode] = useState<'upload' | 'generate'>('upload'); + const [filename, setFilename] = useState(''); + const [privateKey, setPrivateKey] = useState(''); + const [algorithm, setAlgorithm] = useState('ed25519'); + const [comment, setComment] = useState(''); + const [submitting, setSubmitting] = useState(false); + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + 
if (!filename.trim()) return; + + setSubmitting(true); + try { + if (mode === 'upload') { + await onSubmit({ + action: 'upload', + filename: filename.trim(), + private_key: privateKey, + comment: comment || undefined, + }); + } else { + await onSubmit({ + action: 'generate', + filename: filename.trim(), + algorithm, + comment: comment || `openmlr-key`, + }); + } + } finally { + setSubmitting(false); + } + }; + + return ( +
+
e.stopPropagation()}> +
+

+ + Add SSH Key +

+ +
+ +
+ {/* Mode toggle */} +
+ + +
+ + {/* Filename */} +
+ + setFilename(e.target.value)} + className="w-full bg-bg border border-border rounded-lg px-3 py-2 text-text text-sm focus:border-primary focus:outline-none" + /> +

Stored in .keys/ directory

+
+ + {mode === 'upload' ? ( +
+ +