From 6ff3346c93fddffb4275cca4601a9b61265b8d94 Mon Sep 17 00:00:00 2001 From: xprilion Date: Mon, 27 Apr 2026 01:02:49 +0530 Subject: [PATCH 1/8] Add compute nodes support --- .keys/.gitignore | 2 + .opencode/plans/compute-nodes.md | 752 ++++++++++++++++++ backend/configs/prompts/system_prompt.yaml | 10 + backend/openmlr/agent/prompts.py | 2 + backend/openmlr/app.py | 4 + backend/openmlr/celery_app.py | 14 +- backend/openmlr/compute/__init__.py | 6 + backend/openmlr/compute/capabilities.py | 85 ++ backend/openmlr/compute/manager.py | 45 ++ backend/openmlr/compute/probe.py | 170 ++++ backend/openmlr/compute/workspace.py | 166 ++++ backend/openmlr/db/models.py | 45 ++ backend/openmlr/db/operations.py | 153 ++++ backend/openmlr/keys/__init__.py | 3 + backend/openmlr/keys/manager.py | 169 ++++ backend/openmlr/routes/agent.py | 122 +++ backend/openmlr/routes/compute.py | 388 +++++++++ backend/openmlr/routes/keys.py | 136 ++++ backend/openmlr/sandbox/interface.py | 33 +- backend/openmlr/sandbox/local.py | 108 +-- backend/openmlr/sandbox/manager.py | 10 +- backend/openmlr/sandbox/modal_sandbox.py | 50 +- backend/openmlr/sandbox/ssh.py | 291 +++++-- backend/openmlr/services/session_manager.py | 82 +- backend/openmlr/tasks/compute_tasks.py | 113 +++ backend/openmlr/tools/compute_tools.py | 527 ++++++++++++ backend/openmlr/tools/registry.py | 18 + backend/openmlr/tools/sandbox_tools.py | 55 +- backend/openmlr/tools/writing.py | 12 +- backend/tests/test_compute.py | 693 ++++++++++++++++ backend/tests/test_sandbox_types.py | 45 +- docker-compose.prod.yml | 2 + docker-compose.yml | 2 + frontend/src/App.tsx | 56 +- frontend/src/api.ts | 21 + frontend/src/components/ComputeSelector.tsx | 101 +++ frontend/src/components/SettingsPage.tsx | 4 +- .../src/components/settings/AddKeyModal.tsx | 158 ++++ .../src/components/settings/AddNodeModal.tsx | 372 +++++++++ .../components/settings/ComputeSettings.tsx | 383 +++++++++ frontend/src/index.css | 14 + 41 files changed, 5188 
insertions(+), 234 deletions(-) create mode 100644 .keys/.gitignore create mode 100644 .opencode/plans/compute-nodes.md create mode 100644 backend/openmlr/compute/__init__.py create mode 100644 backend/openmlr/compute/capabilities.py create mode 100644 backend/openmlr/compute/manager.py create mode 100644 backend/openmlr/compute/probe.py create mode 100644 backend/openmlr/compute/workspace.py create mode 100644 backend/openmlr/keys/__init__.py create mode 100644 backend/openmlr/keys/manager.py create mode 100644 backend/openmlr/routes/compute.py create mode 100644 backend/openmlr/routes/keys.py create mode 100644 backend/openmlr/tasks/compute_tasks.py create mode 100644 backend/openmlr/tools/compute_tools.py create mode 100644 backend/tests/test_compute.py create mode 100644 frontend/src/components/ComputeSelector.tsx create mode 100644 frontend/src/components/settings/AddKeyModal.tsx create mode 100644 frontend/src/components/settings/AddNodeModal.tsx create mode 100644 frontend/src/components/settings/ComputeSettings.tsx diff --git a/.keys/.gitignore b/.keys/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/.keys/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/.opencode/plans/compute-nodes.md b/.opencode/plans/compute-nodes.md new file mode 100644 index 0000000..5a91c27 --- /dev/null +++ b/.opencode/plans/compute-nodes.md @@ -0,0 +1,752 @@ +# Design Document: Compute Node Ecosystem & Workspace Isolation + +**Status:** Draft +**Author:** OpenCode Agent +**Date:** 2026-04-26 + +--- + +## 1. Goals + +1. **Secure SSH/LAN compute node management** with host-key verification, key-based auth, and per-node credential pairing. +2. **Filesystem isolation** via a per-conversation workspace at `~/.openmlr/workspace-{conv-uuid}`. +3. **Per-conversation compute binding** (sticky default + per-convo override), matching the existing model-selection UX pattern. +4. 
**Agent-aware compute planning** — the agent can inspect available nodes, match tasks to capabilities, and execute remotely with streamed results. +5. **Zero auto-discovery** — all nodes are added manually. + +--- + +## 2. Non-Goals + +- Auto-discovery via mDNS/Zeroconf (explicitly out of scope). +- Kubernetes or Slurm cluster management. +- Billing/cost tracking for cloud nodes. + +--- + +## 3. Terminology + +| Term | Meaning | +|------|---------| +| **Compute Node** | A target execution environment: local workspace, SSH remote, or Modal sandbox. | +| **Key Asset** | An SSH private key stored on disk in `.keys/`. Metadata (filename, fingerprint, associated nodes) lives in the DB. | +| **Workspace** | A dedicated directory `~/.openmlr/workspace-{conv-uuid}` mounted into all local executions for that conversation. | +| **Active Compute** | The compute node currently bound to a conversation. Defaults to the user's sticky default; can be overridden per conversation. | + +--- + +## 4. Architecture Overview + +``` +┌──────────────────────────────────────────────────────────────┐ +│ Frontend (React) │ +│ Chat Header: [Model] + [Compute] selectors │ +│ Settings > Compute Nodes: CRUD, key pairing, health checks │ +└──────────────────────────┬───────────────────────────────────┘ + │ REST + SSE +┌──────────────────────────▼───────────────────────────────────┐ +│ Backend (FastAPI) │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ KeyManager │ │ ComputeNode │ │ Workspace │ │ +│ │ (.keys dir) │ │ Registry │ │ Manager │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +│ │ │ │ │ +│ ▼ ▼ ▼ │ +│ ┌────────────────────────────────────────────────────────┐ │ +│ │ SessionManager (per-conv) │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ Agent │ │ Tool │ │ Sandbox │ │ │ +│ │ │ Session │──│ Router │──│ Manager │ │ │ +│ │ │ │ │ │ │ (active) │ │ │ +│ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │ +│ 
└────────────────────────────────────────────────────────┘ │ +│ │ │ │ +│ ▼ ▼ │ +│ ┌──────────────┐ ┌──────────────────────┐ │ +│ │ PostgreSQL │ │ Local / SSH / Modal │ │ +│ │ (metadata) │ │ Compute Nodes │ │ +│ └──────────────┘ └──────────────────────┘ │ +└──────────────────────────────────────────────────────────────┘ +``` + +--- + +## 5. Key Asset Management (`.keys/`) + +### 5.1 Directory Layout + +``` +/.keys/ +├── id_ed25519_workstation +├── id_ed25519_workstation.pub +├── id_rsa_labserver +├── id_rsa_labserver.pub +└── .gitignore # ignores everything (keys must NEVER be committed) +``` + +The `.keys/` directory is created on first run. It is mounted into the backend container via Docker Compose volume: + +```yaml +volumes: + - ./.keys:/app/.keys:ro # read-only in container; write via API +``` + +### 5.2 Key Lifecycle + +| Action | User Flow | Backend Behavior | +|--------|-----------|------------------| +| **Upload** | User pastes or uploads a private key in Settings > Compute > Keys | Backend writes to `.keys/{filename}` with `0o600` permissions; stores `{filename, fingerprint, algorithm, comment, created_at}` in `ssh_keys` table | +| **Generate** | User clicks "Generate Key Pair" | Backend runs `ssh-keygen -t ed25519 -f .keys/{name} -C "openmlr-{user_id}@{timestamp}"`; stores metadata | +| **Delete** | User clicks "Delete" | Backend deletes file from `.keys/` and all `compute_node` rows referencing it; warns if nodes will break | +| **List** | Settings page loads | Backend returns metadata (no private key content ever transmitted) | + +### 5.3 Security + +- **Filesystem**: Private keys are written with `0o600`, directory with `0o700`. +- **Network**: Private key content is NEVER returned in API responses. Only filenames, fingerprints, and public key content are exposed. +- **Validation**: Uploaded keys are validated with `cryptography` library before writing. + +--- + +## 6. 
Compute Node Registry + +### 6.1 Database Schema + +**`compute_nodes` table** (new) + +```sql +CREATE TABLE compute_nodes ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + name VARCHAR(100) NOT NULL, -- e.g. "Workstation", "Lab Server" + type VARCHAR(20) NOT NULL, -- "local", "ssh", "modal" + config JSONB NOT NULL DEFAULT '{}', -- host, port, username, key_filename, workdir, etc. + capabilities JSONB DEFAULT '{}', -- cached probe results + health_status VARCHAR(20) DEFAULT 'unknown', -- online, offline, degraded, unknown + last_probed_at TIMESTAMP WITH TIME ZONE, + last_seen_at TIMESTAMP WITH TIME ZONE, + is_default BOOLEAN DEFAULT FALSE, + priority INTEGER DEFAULT 0, -- fallback ordering + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + UNIQUE(user_id, name) +); +``` + +**`ssh_keys` table** (new) + +```sql +CREATE TABLE ssh_keys ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + filename VARCHAR(255) NOT NULL UNIQUE, + fingerprint VARCHAR(255) NOT NULL, -- SHA256 fingerprint + algorithm VARCHAR(50) NOT NULL, -- ssh-ed25519, rsa, etc. + public_key TEXT NOT NULL, + comment TEXT, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); +``` + +### 6.2 Config JSONB Structure + +```json +{ + "host": "ml-workstation.local", + "port": 22, + "username": "researcher", + "key_filename": "id_ed25519_workstation", + "workdir": "/home/researcher/openmlr-workspaces", + "host_key_fingerprint": "SHA256:abc123...", + "jump_host": null, + "timeout_seconds": 30, + "modal": { + "image": "python:3.12", + "gpu": "A100", + "packages": ["torch", "transformers"] + } +} +``` + +### 6.3 API Endpoints + +**Key Management** +- `GET /api/keys` — list key metadata (no private content) +- `POST /api/keys` — upload or generate a key + - Body: `{ "action": "upload" | "generate", "filename": "...", "private_key": "..." 
(upload only), "passphrase": "" }` +- `DELETE /api/keys/{filename}` — delete key + warn of dependent nodes + +**Compute Nodes** +- `GET /api/compute/nodes` — list all nodes with capabilities +- `POST /api/compute/nodes` — create node +- `GET /api/compute/nodes/{id}` — get single node +- `PUT /api/compute/nodes/{id}` — update node config +- `DELETE /api/compute/nodes/{id}` — delete node +- `POST /api/compute/nodes/{id}/test` — connectivity check (lightweight) +- `POST /api/compute/nodes/{id}/probe` — deep capability discovery (heavyweight) +- `POST /api/compute/nodes/{id}/set-default` — set as sticky default + +**Per-Conversation Compute** +- `GET /api/conversations/{uuid}/compute` — get active compute for conversation +- `POST /api/conversations/{uuid}/compute` — bind compute to conversation +- `DELETE /api/conversations/{uuid}/compute` — unbind (falls back to default) + +### 6.4 SSH Security (Critical Fix) + +Replace `paramiko.AutoAddPolicy()` with **strict host-key verification**: + +```python +class StrictHostKeyPolicy(paramiko.MissingHostKeyPolicy): + def __init__(self, expected_fingerprint: str): + self.expected = expected_fingerprint + + def missing_host_key(self, client, hostname, key): + actual = key.fingerprint # OpenSSH-style "SHA256:..." string (paramiko >= 3.2), same format as config.host_key_fingerprint + if self.expected and actual != self.expected: + raise paramiko.SSHException( + f"Host key mismatch for {hostname}: expected {self.expected}, got {actual}" + ) + # If no fingerprint stored yet, accept and save + return +``` + +On first connect to a new SSH node: +1. Backend connects with `WarningPolicy()` to retrieve the host key. +2. Returns the fingerprint to the frontend with a "Verify Host Key" prompt. +3. User confirms → fingerprint is saved to `config.host_key_fingerprint`. +4. All subsequent connections use `StrictHostKeyPolicy`. + +--- + +## 7. 
Workspace Isolation + +### 7.1 Directory Layout + +``` +~/.openmlr/ +├── workspace-550e8400-e29b-41d4-a716-446655440000/ +│ ├── .openmlr-meta/ # internal state (not visible to agent) +│ ├── data/ # datasets uploaded or downloaded +│ ├── models/ # trained model checkpoints +│ ├── code/ # scripts written by agent +│ └── outputs/ # plots, logs, results +├── workspace-6ba7b810-9dad-11d1-80b4-00c04fd430c8/ +│ └── ... +└── config.json # global user preferences (optional future) +``` + +### 7.2 Lifecycle + +| Event | Action | +|-------|--------| +| Conversation created | `mkdir -p ~/.openmlr/workspace-{uuid}` | +| `sandbox_exec` in local mode | Commands run with `cwd=~/.openmlr/workspace-{uuid}` | +| `sandbox_write` in local mode | Files written relative to workspace root | +| `sandbox_read` in local mode | Files read relative to workspace root | +| Conversation deleted | Workspace is **archived** to `~/.openmlr/archive/workspace-{uuid}-{timestamp}.tar.gz` | +| User setting changed | N/A — workspaces are immutable boundaries | + +### 7.3 SSH Remote Workspaces + +For SSH nodes, the workspace concept maps to a **remote directory** on the target machine: + +```json +{ + "workdir": "/home/researcher/openmlr-workspaces/workspace-550e8400-..." +} +``` + +The backend ensures the remote directory exists before first execution: +```bash +ssh user@host "mkdir -p /home/researcher/openmlr-workspaces/workspace-{uuid}" +``` + +### 7.4 Modal Workspaces + +Modal sandboxes are ephemeral; "workspace" maps to the sandbox's working directory. Files are not persisted across sandbox destruction. + +--- + +## 8. 
Per-Conversation Compute Binding + +### 8.1 Sticky Defaults (User-Level) + +Stored in `user_settings` under category `compute`: + +```json +{ + "compute": { + "default_node_id": 3, + "default_node_name": "Workstation" + } +} +``` + +### 8.2 Per-Conversation Override + +Stored in `Conversation.extra` JSONB (existing column): + +```json +{ + "compute_node_id": 5, + "compute_node_name": "Lab Server" +} +``` + +If `extra.compute_node_id` is null, the conversation uses the user's sticky default. + +### 8.3 UX Pattern (Mirrors Model Selection) + +**Header selectors in Chat UI:** +``` +[Model: anthropic/claude-sonnet-4] [Compute: ★ Workstation (SSH)] +``` + +- Clicking **Compute** opens a dropdown: list of all nodes + "Default" option. +- Selecting a compute node: + 1. Calls `POST /api/conversations/{uuid}/compute` with `{node_id}`. + 2. Updates `Conversation.extra`. + 3. Session manager re-creates the sandbox manager for that conversation. +- Selecting "Default" clears the override (deletes key from `extra`). + +### 8.4 Session Creation Flow + +In `SessionManager.get_or_create_session()`: + +```python +# 1. Load user's sticky default compute +user_settings = await ops.get_all_settings(db, user_id, category="compute") +default_node_id = user_settings.get("compute", {}).get("default_node_id") + +# 2. Check conversation override +conv = await ops.get_conversation_by_id(db, conversation_id) +override_node_id = conv.extra.get("compute_node_id") if conv.extra else None + +effective_node_id = override_node_id or default_node_id + +# 3. Initialize sandbox manager with effective node +sandbox_manager = SandboxManager() +if effective_node_id: + node = await ops.get_compute_node(db, effective_node_id) + if node: + await sandbox_manager.create(node.type, node.config) +``` + +--- + +## 9. 
Agent Compute Tools + +New tools registered in the `ToolRouter`: + +### 9.1 Tool Specifications + +| Tool | Parameters | Description | +|------|------------|-------------| +| `compute_list` | `{}` | List all compute nodes with capabilities and health | +| `compute_probe` | `{"node_name": "..."}` | Run deep capability discovery on a node | +| `compute_select` | `{"node_name": "..."}` | Switch active compute for this conversation | +| `compute_plan` | `{"task": "...", "requirements": {"gpu": true, "min_ram_gb": 32}}` | Recommend best node for a task | + +### 9.2 System Prompt Enhancement + +The system prompt includes a **Compute Environment** section: + +```markdown +## Compute Environment + +Active compute: Workstation (SSH) — ml-workstation.local +- OS: Ubuntu 22.04 +- CPU: 32 cores +- RAM: 128 GB +- GPU: RTX 4090 (24 GB VRAM) +- CUDA: 12.4 +- Python: 3.11 +- Key packages: torch 2.3, transformers 4.40, jax 0.4 + +Other available nodes: +- Laptop (Local): CPU-only, 16 GB RAM +- Cloud (Modal): A100 80 GB, offline + +Use `compute_plan` before starting long-running tasks to verify the active node meets requirements. 
+``` + +### 9.3 Compute Planning Algorithm + +```python +def plan_compute(task_description: str, requirements: dict, nodes: list) -> dict: + scores = [] + for node in nodes: + if node.health_status != "online": + continue + score = 0 + caps = node.capabilities + + # GPU requirement + if requirements.get("gpu"): + if not caps.get("gpu_available"): + continue + score += 10 + # Prefer more VRAM + vram = max((g.get("vram_gb", 0) for g in caps.get("gpu_info") or []), default=0) + score += min(vram / 10, 5) + + # RAM requirement + min_ram = requirements.get("min_ram_gb", 0) + available_ram = caps.get("available_ram_gb", 0) + if available_ram < min_ram: + continue + score += min(available_ram / min_ram, 3) if min_ram > 0 else 3 + + # Prefer lower latency (local > LAN > cloud) + if node.type == "local": + score += 5 + elif node.type == "ssh": + score += 2 + + scores.append({"node": node, "score": score}) + + scores.sort(key=lambda x: x["score"], reverse=True) + return scores[0] if scores else None +``` + +--- + +## 10. Capability Discovery + +### 10.1 Probed Attributes + +```python +@dataclass +class ComputeCapabilities: + platform: str # "Linux 6.5.0-15-generic" + cpu_cores: int + cpu_arch: str # "x86_64" + total_ram_gb: float + available_ram_gb: float + total_disk_gb: float + available_disk_gb: float + gpu_available: bool + gpu_count: int + gpu_info: list[dict] # [{"model": "RTX 4090", "vram_gb": 24, "cuda": "12.4"}] + python_versions: list[str] # ["3.11.4", "3.10.12"] + docker_available: bool + conda_envs: list[str] + installed_packages: list[str] # ["torch==2.3.0", "transformers==4.40.0"] + has_internet: bool + latency_ms: float +``` + +### 10.2 Probe Commands + +| Attribute | Local / SSH Command | Modal Command | +|-----------|---------------------|---------------| +| OS | `uname -s -r` | Same | +| CPU | `nproc && uname -m` | Same | +| RAM | `free -g` | Same | +| Disk | `df -BG /` | Same | +| GPU | `nvidia-smi --query-gpu=name,memory.total --format=csv,noheader` | Same | +| Python | `python3 --version; ls /usr/bin/python*` | Same | +| 
Packages | `pip list --format=freeze 2>/dev/null | head -50` | Same | +| Docker | `docker info 2>/dev/null` | N/A | + +### 10.3 Caching Strategy + +- Probe results are stored in `compute_nodes.capabilities` JSONB. +- **On-demand refresh**: `compute_probe` tool forces a refresh. +- **Background refresh**: Celery beat task runs every 5 minutes for nodes marked `online`. +- **Stale threshold**: Capabilities older than 1 hour are considered stale; UI shows warning. + +--- + +## 11. Execution Enhancements + +### 11.1 Connection Pooling (SSH) + +```python +class SSHConnectionPool: + """Maintains persistent SSH connections per node with TTL.""" + + def __init__(self, ttl_seconds: int = 300): + self._pools: dict[int, paramiko.SSHClient] = {} + self._last_used: dict[int, float] = {} + self._ttl = ttl_seconds + + async def get(self, node_id: int, config: dict) -> paramiko.SSHClient: + client = self._pools.get(node_id) + if client and client.get_transport() and client.get_transport().is_active(): + self._last_used[node_id] = time.monotonic() + return client + # Reconnect + client = await self._connect(config) + self._pools[node_id] = client + self._last_used[node_id] = time.monotonic() + return client + + async def cleanup(self): + now = time.monotonic() + for node_id, last in list(self._last_used.items()): + if now - last > self._ttl: + self._pools.pop(node_id, None) +``` + +### 11.2 Streaming Output + +Extend `sandbox_exec` to support streaming for long-running jobs: + +```python +# Tool parameter +{ + "command": "python train.py", + "timeout": 3600, + "stream": true # NEW +} +``` + +When `stream=true`: +1. Backend opens the SSH channel / Docker exec / Modal exec. +2. Reads stdout/stderr in chunks. +3. Broadcasts `tool_log` SSE events with partial output. +4. On completion, sends final `tool_output` event. 
+ +### 11.3 File Sync + +New tools for batch transfer: + +| Tool | Description | +|------|-------------| +| `compute_sync_up` | Sync local workspace files to remote node (rsync/scp) | +| `compute_sync_down` | Sync remote files back to local workspace | + +```python +# compute_sync_up +{ + "paths": ["data/", "code/train.py"], + "direction": "up" +} +``` + +--- + +## 12. Frontend: Settings UI + +### 12.1 Settings Navigation + +Replace "Sandbox" with "Compute" in the settings nav: + +```typescript +const navItems = [ + { path: '/settings/providers', label: 'Providers', icon: Key }, + { path: '/settings/agent', label: 'Agent', icon: Bot }, + { path: '/settings/mcp', label: 'MCP Servers', icon: Server }, + { path: '/settings/compute', label: 'Compute', icon: Cpu }, // NEW + { path: '/settings/writing', label: 'Writing', icon: PenTool }, +]; +``` + +### 12.2 Compute Settings Page Structure + +``` +┌──────────────────────────────────────────────────────────────┐ +│ Compute │ +├──────────────────────────────────────────────────────────────┤ +│ SSH Keys [+ Add]│ +├──────────────────────────────────────────────────────────────┤ +│ • id_ed25519_workstation SHA256:abc... [Delete] │ +│ • id_rsa_labserver SHA256:def... [Delete] │ +├──────────────────────────────────────────────────────────────┤ +│ Compute Nodes [+ Add]│ +├──────────────────────────────────────────────────────────────┤ +│ ★ Workstation (SSH) ● Online RTX 4090 [Default] │ +│ Host: ml-workstation.local │ +│ Key: id_ed25519_workstation │ +│ Workspace: ~/.openmlr/workspace-... 
│ +│ [Test] [Probe] [Edit] [Delete] │ +├──────────────────────────────────────────────────────────────┤ +│ Laptop (Local) ● Online CPU-only │ +│ [Set Default] [Probe] [Edit] [Delete] │ +├──────────────────────────────────────────────────────────────┤ +│ Cloud GPU (Modal) ○ Offline A100 │ +│ [Set Default] [Probe] [Edit] [Delete] │ +└──────────────────────────────────────────────────────────────┘ +``` + +### 12.3 Add Node Modal + +**Step 1: Type** +- Radio: Local Workspace / SSH Remote / Modal Cloud + +**Step 2: Config** +- SSH: host, port, username, key selector (dropdown of uploaded keys), workdir +- Local: workdir (defaults to `~/.openmlr/workspace-{conv-uuid}`) +- Modal: image, GPU type, packages + +**Step 3: Test & Verify** +- "Test Connection" button +- If SSH and first connect: show host key fingerprint, ask user to verify +- On success: save node + +### 12.4 Chat Header Compute Selector + +```tsx +// In ChatUI header, next to ModelModal + api.setConversationCompute(currentConvUuid, nodeId)} +/> +``` + +Dropdown shows: +- "Default (★ Workstation)" — selects null override +- Separator +- All nodes with status dot (green/orange/gray) + +--- + +## 13. 
Implementation Phases + +### Phase 1: Foundation — Keys, Registry, Secure SSH +**Backend:** +- Create `.keys/` directory manager (read/write/validate/list) +- Create `ssh_keys` table and CRUD API (`/api/keys`) +- Create `compute_nodes` table and CRUD API (`/api/compute/nodes`) +- Fix SSH security: `StrictHostKeyPolicy`, fingerprint verification +- Update `SSHSandbox` to use `SSHConnectionPool` and key assets from `.keys/` + +**Frontend:** +- Replace "Sandbox" nav with "Compute" +- Build `ComputeSettings` page with KeyManager and NodeRegistry sub-components +- Build `AddNodeModal` and `AddKeyModal` + +**DB:** +- Alembic migration for `ssh_keys` and `compute_nodes` + +### Phase 2: Workspaces & Per-Conversation Binding +**Backend:** +- `WorkspaceManager` — create/get workspace directory per conversation UUID +- Update `LocalSandbox` to use workspace as `workdir` +- Update `SSHSandbox` to ensure remote workspace exists +- Add `/api/conversations/{uuid}/compute` endpoints +- Store sticky default in `user_settings` category `compute` +- Store override in `Conversation.extra` + +**Frontend:** +- `ComputeSelector` component in chat header +- Update `ChatUI` to load and display active compute +- Update conversation switch logic to restore active compute + +### Phase 3: Capability Discovery & Agent Tools +**Backend:** +- Enhanced `probe_environment()` with structured `ComputeCapabilities` +- Background Celery task for health checks +- Add `compute_list`, `compute_probe`, `compute_select`, `compute_plan` tools +- Update system prompt builder to include compute environment context + +**Frontend:** +- Node health status indicators (polling every 30s) +- "Probe" button in settings with progress spinner + +### Phase 4: Streaming & Advanced Execution +**Backend:** +- Streaming `sandbox_exec` with `tool_log` SSE events +- `compute_sync_up` / `compute_sync_down` tools (rsync wrapper) +- Modal workspace persistence (optional) + +**Frontend:** +- Live output streaming in message list 
for long-running commands +- Sync progress indicators + +--- + +## 14. Security Checklist + +| # | Concern | Mitigation | +|---|---------|------------| +| 1 | SSH key exposure | Keys stored on disk only; DB holds only metadata; API never returns private content | +| 2 | Host key spoofing | Strict fingerprint verification; warn on mismatch; reject unknown keys | +| 3 | Path traversal | All file operations validated against workspace root | +| 4 | Privilege escalation | Default to non-root; `can_sudo` not exposed to agent by default | +| 5 | Key file permissions | `0o600` on private keys, `0o700` on `.keys/` directory | +| 6 | Credential persistence | `.keys/` mounted via Docker volume; survives container restarts | +| 7 | Workspace isolation | Each conversation has unique workspace UUID; no cross-conversation access | + +--- + +## 15. Migration Path + +### From Existing `SandboxConfig` + +The existing `sandbox_configs` table is superseded by `compute_nodes`. + +**Migration strategy:** +1. Create `compute_nodes` and `ssh_keys` tables. +2. Migrate existing `sandbox_configs` rows: + - `type=local` → `compute_nodes` with `type=local` + - `type=ssh` → `compute_nodes` with `type=ssh`; extract `key_path` into a key asset if it points to `.keys/` + - `type=modal` → `compute_nodes` with `type=modal` +3. Deprecate `sandbox_configs` table (drop in a future migration). +4. Update `SandboxSettings.tsx` → `ComputeSettings.tsx`. + +--- + +## 16. Open Questions + +1. **Key passphrase support**: Should we support SSH keys with passphrases? If yes, how do we cache the decrypted key securely in memory? +2. **Workspace cleanup policy**: Should we auto-delete archived workspaces after N days, or keep them indefinitely? +3. **Modal integration depth**: Should Modal nodes support `Modal.App.lookup()` reuse, or always create ephemeral sandboxes? +4. 
**Multi-node parallel execution**: Should the agent be able to run commands on multiple nodes simultaneously (e.g., distributed training), or is single-node-at-a-time sufficient for V1? + +--- + +## 17. Files to Create / Modify + +### New Files +``` +backend/openmlr/keys/__init__.py +backend/openmlr/keys/manager.py +backend/openmlr/compute/__init__.py +backend/openmlr/compute/manager.py +backend/openmlr/compute/workspace.py +backend/openmlr/compute/probe.py +backend/openmlr/compute/planner.py +backend/openmlr/routes/keys.py +backend/openmlr/routes/compute.py +backend/openmlr/db/migrations/..._add_compute_nodes_and_ssh_keys.py +frontend/src/components/settings/ComputeSettings.tsx +frontend/src/components/settings/AddNodeModal.tsx +frontend/src/components/settings/AddKeyModal.tsx +frontend/src/components/ComputeSelector.tsx +``` + +### Modified Files +``` +backend/openmlr/db/models.py # Add ComputeNode, SSHKey models +backend/openmlr/db/operations.py # Add compute node CRUD ops +backend/openmlr/routes/settings.py # Add compute to provider list (optional) +backend/openmlr/sandbox/ssh.py # StrictHostKeyPolicy, connection pool +backend/openmlr/sandbox/manager.py # Integrate ComputeManager +backend/openmlr/sandbox/local.py # Workspace-aware workdir +backend/openmlr/services/session_manager.py # Bind compute on session creation +backend/openmlr/tools/registry.py # Register compute_* tools +backend/openmlr/agent/prompts.py # Inject compute env into system prompt +frontend/src/App.tsx # Add ComputeSelector to header +frontend/src/api.ts # Add compute endpoints +frontend/src/components/SettingsPage.tsx # Update nav items +``` + +--- + +## Appendix A: Example User Journey + +**Scenario:** User has a laptop and an SSH workstation. + +1. **User opens Settings > Compute** +2. **Uploads key**: Pastes `id_ed25519_workstation` private key → saved to `.keys/` +3. **Adds node**: Creates "Workstation" (SSH, host=ml.local, key=id_ed25519_workstation) +4. 
**Verifies host key**: Backend shows fingerprint, user clicks "Trust" +5. **Sets default**: Clicks "Set Default" on Workstation +6. **Creates conversation**: New chat auto-binds to Workstation +7. **Agent probes**: User asks "What GPU do I have?" → agent runs `sandbox_probe` → sees RTX 4090 +8. **Switches compute**: User clicks header dropdown, selects "Laptop (Local)" +9. **Agent adapts**: Next `sandbox_probe` shows CPU-only; agent avoids GPU tasks +10. **Runs training**: User asks "Train ResNet on CIFAR-10" → agent calls `compute_plan` → recommends Workstation → `compute_select` → runs training with streamed output diff --git a/backend/configs/prompts/system_prompt.yaml b/backend/configs/prompts/system_prompt.yaml index f73a71c..bd340c8 100644 --- a/backend/configs/prompts/system_prompt.yaml +++ b/backend/configs/prompts/system_prompt.yaml @@ -145,3 +145,13 @@ prompt: | - Date: {{ date }} - User: {{ username }} - Mode: {{ mode }} + + {% if compute_env %} + {{ compute_env }} + {% endif %} + + # Compute Planning + When starting tasks that require significant computation (training models, processing large datasets, etc.): + 1. Use `compute_plan` to verify the active node meets requirements + 2. If not, use `compute_select` to switch to a suitable node + 3. 
Always `sandbox_probe` before executing code on a node for the first time diff --git a/backend/openmlr/agent/prompts.py b/backend/openmlr/agent/prompts.py index 6004fcd..8528cf7 100644 --- a/backend/openmlr/agent/prompts.py +++ b/backend/openmlr/agent/prompts.py @@ -25,6 +25,7 @@ def build_system_prompt( mode: str = "general", username: str = "user", sandbox_info: str = "none", + compute_env: str = "", config: AgentConfig | None = None, ) -> str: """Build the full system prompt from YAML template.""" @@ -58,6 +59,7 @@ def build_system_prompt( timezone="UTC", username=username, sandbox_info=sandbox_info, + compute_env=compute_env, ) return prompt diff --git a/backend/openmlr/app.py b/backend/openmlr/app.py index d034841..b150bfc 100644 --- a/backend/openmlr/app.py +++ b/backend/openmlr/app.py @@ -71,13 +71,17 @@ async def lifespan(app: FastAPI): # ── API routers ────────────────────────────────────────── from .auth.router import router as auth_router from .routes.agent import router as agent_router +from .routes.compute import router as compute_router from .routes.health import router as health_router +from .routes.keys import router as keys_router from .routes.settings import router as settings_router app.include_router(auth_router) app.include_router(agent_router) app.include_router(settings_router) app.include_router(health_router) +app.include_router(keys_router) +app.include_router(compute_router) # ── Global error handler ──────────────────────────────── diff --git a/backend/openmlr/celery_app.py b/backend/openmlr/celery_app.py index 91b43ab..7190264 100644 --- a/backend/openmlr/celery_app.py +++ b/backend/openmlr/celery_app.py @@ -13,7 +13,7 @@ "openmlr", broker=REDIS_URL, backend=REDIS_URL, - include=["openmlr.tasks.agent_tasks"], + include=["openmlr.tasks.agent_tasks", "openmlr.tasks.compute_tasks"], ) # Celery configuration @@ -43,6 +43,18 @@ # Default queue task_default_queue="default", + + # Beat schedule for periodic tasks + beat_schedule={ + 
"health-check-all-nodes": { + "task": "openmlr.tasks.compute_tasks.health_check_all_nodes", + "schedule": 300.0, # Every 5 minutes + }, + "cleanup-old-workspaces": { + "task": "openmlr.tasks.compute_tasks.cleanup_old_workspaces", + "schedule": 86400.0, # Every 24 hours + }, + }, ) diff --git a/backend/openmlr/compute/__init__.py b/backend/openmlr/compute/__init__.py new file mode 100644 index 0000000..ff1c488 --- /dev/null +++ b/backend/openmlr/compute/__init__.py @@ -0,0 +1,6 @@ +from .capabilities import ComputeCapabilities, GPUInfo +from .manager import ComputeManager +from .probe import probe_sandbox +from .workspace import WorkspaceManager + +__all__ = ["ComputeCapabilities", "ComputeManager", "GPUInfo", "probe_sandbox", "WorkspaceManager"] diff --git a/backend/openmlr/compute/capabilities.py b/backend/openmlr/compute/capabilities.py new file mode 100644 index 0000000..21c36e8 --- /dev/null +++ b/backend/openmlr/compute/capabilities.py @@ -0,0 +1,85 @@ +"""Compute capability discovery and planning.""" + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class GPUInfo: + """Information about a GPU.""" + model: str = "" + vram_gb: float = 0.0 + cuda_version: str = "" + driver_version: str = "" + + +@dataclass +class ComputeCapabilities: + """Comprehensive capabilities of a compute node.""" + # Platform + platform: str = "unknown" + cpu_cores: int = 0 + cpu_arch: str = "unknown" + + # Memory + total_ram_gb: float = 0.0 + available_ram_gb: float = 0.0 + + # Storage + total_disk_gb: float = 0.0 + available_disk_gb: float = 0.0 + + # GPU + gpu_available: bool = False + gpu_count: int = 0 + gpu_info: list[GPUInfo] = field(default_factory=list) + + # Software + python_versions: list[str] = field(default_factory=list) + docker_available: bool = False + conda_envs: list[str] = field(default_factory=list) + installed_packages: list[str] = field(default_factory=list) + + # Network + has_internet: bool = True + latency_ms: float = 0.0 + + 
def to_dict(self) -> dict[str, Any]: + """Serialize to dict.""" + return { + "platform": self.platform, + "cpu_cores": self.cpu_cores, + "cpu_arch": self.cpu_arch, + "total_ram_gb": self.total_ram_gb, + "available_ram_gb": self.available_ram_gb, + "total_disk_gb": self.total_disk_gb, + "available_disk_gb": self.available_disk_gb, + "gpu_available": self.gpu_available, + "gpu_count": self.gpu_count, + "gpu_info": [ + { + "model": g.model, + "vram_gb": g.vram_gb, + "cuda_version": g.cuda_version, + "driver_version": g.driver_version, + } + for g in self.gpu_info + ], + "python_versions": self.python_versions, + "docker_available": self.docker_available, + "conda_envs": self.conda_envs, + "installed_packages": self.installed_packages, + "has_internet": self.has_internet, + "latency_ms": self.latency_ms, + } + + @classmethod + def from_dict(cls, data: dict) -> "ComputeCapabilities": + """Deserialize from dict.""" + caps = cls() + for key, value in data.items(): + if key == "gpu_info" and value: + caps.gpu_info = [GPUInfo(**g) for g in value] + elif hasattr(caps, key): + setattr(caps, key, value) + return caps diff --git a/backend/openmlr/compute/manager.py b/backend/openmlr/compute/manager.py new file mode 100644 index 0000000..af4b8ef --- /dev/null +++ b/backend/openmlr/compute/manager.py @@ -0,0 +1,45 @@ +"""Compute Node Manager — registry, validation, and lifecycle.""" + +from pathlib import Path + + +class ComputeManager: + """High-level operations for compute node management.""" + + def __init__(self, key_manager): + self.key_manager = key_manager + + def validate_node_config(self, node_type: str, config: dict) -> tuple[bool, str]: + """Validate a compute node configuration. 
Pure check, no side effects.""" + if node_type == "ssh": + return self._validate_ssh_config(config) + elif node_type == "local": + return self._validate_local_config(config) + elif node_type == "modal": + return self._validate_modal_config(config) + else: + return False, f"Unknown node type: {node_type}" + + def _validate_ssh_config(self, config: dict) -> tuple[bool, str]: + required = ["host", "username"] + for field in required: + if not config.get(field): + return False, f"SSH config requires '{field}'" + + key_filename = config.get("key_filename") + if key_filename and not self.key_manager.key_exists(key_filename): + return False, f"SSH key not found: {key_filename}" + + return True, "" + + def _validate_local_config(self, config: dict) -> tuple[bool, str]: + workdir = config.get("workdir", "") + if workdir: + path = Path(workdir).expanduser() + # Only validate — don't create directories as a side effect + if path.exists() and not path.is_dir(): + return False, f"Path exists but is not a directory: {path}" + return True, "" + + def _validate_modal_config(self, config: dict) -> tuple[bool, str]: + return True, "" diff --git a/backend/openmlr/compute/probe.py b/backend/openmlr/compute/probe.py new file mode 100644 index 0000000..c8f9e8f --- /dev/null +++ b/backend/openmlr/compute/probe.py @@ -0,0 +1,170 @@ +"""Environment probing for all sandbox types.""" + +import time + +from .capabilities import ComputeCapabilities, GPUInfo + + +async def probe_sandbox(sandbox) -> ComputeCapabilities: + """Deep capability discovery for any sandbox implementation.""" + caps = ComputeCapabilities() + start = time.monotonic() + + # Platform + result = await sandbox.execute("uname -s -r 2>/dev/null || echo 'unknown'", timeout=5) + if result.success: + caps.platform = result.output.strip() + + # CPU cores and architecture + result = await sandbox.execute("nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo '0'", timeout=5) + if result.success: + try: + caps.cpu_cores = 
int(result.output.strip()) + except ValueError: + pass + + result = await sandbox.execute("uname -m 2>/dev/null || echo 'unknown'", timeout=5) + if result.success: + caps.cpu_arch = result.output.strip() + + # RAM (Linux) + result = await sandbox.execute( + "free -g 2>/dev/null | grep Mem | awk '{print $2, $7}' || " + "echo '0 0'", + timeout=5, + ) + if result.success: + parts = result.output.strip().split() + if len(parts) >= 2: + try: + caps.total_ram_gb = float(parts[0]) + caps.available_ram_gb = float(parts[1]) + except ValueError: + pass + + # Disk + result = await sandbox.execute( + "df -BG / 2>/dev/null | tail -1 | awk '{print $2, $4}' || echo '0 0'", + timeout=5, + ) + if result.success: + parts = result.output.strip().split() + if len(parts) >= 2: + try: + caps.total_disk_gb = float(parts[0].replace("G", "")) + caps.available_disk_gb = float(parts[1].replace("G", "")) + except ValueError: + pass + + # GPU — query model, memory, driver; then get CUDA version separately + result = await sandbox.execute( + "nvidia-smi --query-gpu=name,memory.total,driver_version " + "--format=csv,noheader 2>/dev/null || echo ''", + timeout=10, + ) + if result.success and result.output.strip(): + lines = [ln.strip() for ln in result.output.strip().split("\n") if ln.strip()] + caps.gpu_count = len(lines) + caps.gpu_available = caps.gpu_count > 0 + + # Get CUDA toolkit version + cuda_ver = "" + cuda_result = await sandbox.execute( + "nvidia-smi 2>/dev/null | grep 'CUDA Version' | awk '{print $9}'", + timeout=5, + ) + if cuda_result.success and cuda_result.output.strip(): + cuda_ver = cuda_result.output.strip() + + for line in lines: + parts = [p.strip() for p in line.split(",")] + if len(parts) >= 3: + gpu = GPUInfo( + model=parts[0], + vram_gb=_parse_vram(parts[1]), + cuda_version=cuda_ver, + driver_version=parts[2], + ) + caps.gpu_info.append(gpu) + + # Python versions + result = await sandbox.execute( + "python3 --version 2>/dev/null; ls /usr/bin/python* 2>/dev/null || true", 
+ timeout=5, + ) + if result.success: + versions = [] + for line in result.output.strip().split("\n"): + line = line.strip() + if line.startswith("Python "): + versions.append(line.replace("Python ", "")) + elif "/python" in line and not line.endswith("*"): + # Extract version from path like /usr/bin/python3.11 + ver = line.split("/")[-1].replace("python", "") + if ver and ver not in versions: + versions.append(ver) + caps.python_versions = versions + + # Docker + result = await sandbox.execute( + "docker info >/dev/null 2>&1 && echo 'DOCKER_OK' || echo 'DOCKER_FAIL'", + timeout=5, + ) + if result.success and "DOCKER_OK" in result.output: + caps.docker_available = True + + # Conda envs + result = await sandbox.execute( + "conda env list 2>/dev/null | grep -v '^#' | awk '{print $1}' || true", + timeout=5, + ) + if result.success: + envs = [ln.strip() for ln in result.output.strip().split("\n") if ln.strip()] + caps.conda_envs = envs + + # Key packages + result = await sandbox.execute( + "pip list --format=freeze 2>/dev/null | head -50 || true", + timeout=10, + ) + if result.success: + caps.installed_packages = [ + line.strip() for line in result.output.strip().split("\n") + if line.strip() and "==" in line + ] + + # Internet connectivity + result = await sandbox.execute( + "curl -s -o /dev/null -w '%{http_code}' --max-time 5 https://pypi.org/simple/ 2>/dev/null || echo '000'", + timeout=10, + ) + if result.success and result.output.strip() == "200": + caps.has_internet = True + else: + # Fallback ping + result = await sandbox.execute( + "ping -c 1 -W 3 8.8.8.8 2>/dev/null || true", + timeout=10, + ) + caps.has_internet = result.success and "1 received" in result.output + + caps.latency_ms = (time.monotonic() - start) * 1000 + return caps + + +def _parse_vram(vram_str: str) -> float: + """Parse VRAM string like '24576 MiB' or '24 GB' to GB.""" + vram_str = vram_str.strip().lower() + try: + if "mib" in vram_str: + return float(vram_str.replace("mib", "").strip()) / 
1024 + elif "gib" in vram_str: + return float(vram_str.replace("gib", "").strip()) + elif "gb" in vram_str: + return float(vram_str.replace("gb", "").strip()) + elif "mb" in vram_str: + return float(vram_str.replace("mb", "").strip()) / 1024 + else: + return float(vram_str) + except ValueError: + return 0.0 diff --git a/backend/openmlr/compute/workspace.py b/backend/openmlr/compute/workspace.py new file mode 100644 index 0000000..9517bf3 --- /dev/null +++ b/backend/openmlr/compute/workspace.py @@ -0,0 +1,166 @@ +"""Workspace Manager — per-conversation filesystem isolation.""" + +import os +import shutil +import tarfile +from datetime import UTC, datetime +from pathlib import Path + + +class WorkspaceManager: + """Manages isolated workspace directories for each conversation.""" + + def __init__(self, base_dir: str | Path = None): + self.base_dir = Path(base_dir) if base_dir else Path.home() / ".openmlr" + self.workspace_dir = self.base_dir / "workspaces" + self.archive_dir = self.base_dir / "archive" + self._ensure_dirs() + + def _ensure_dirs(self) -> None: + """Ensure workspace and archive directories exist.""" + self.workspace_dir.mkdir(parents=True, exist_ok=True) + self.archive_dir.mkdir(parents=True, exist_ok=True) + + def get_workspace_path(self, conversation_uuid: str) -> Path: + """Get the workspace directory for a conversation.""" + return self.workspace_dir / f"workspace-{conversation_uuid}" + + def create_workspace(self, conversation_uuid: str) -> Path: + """Create a new workspace directory for a conversation.""" + path = self.get_workspace_path(conversation_uuid) + path.mkdir(parents=True, exist_ok=True) + # Create standard subdirectories + for subdir in ["data", "models", "code", "outputs"]: + (path / subdir).mkdir(exist_ok=True) + # Create meta directory (hidden from agent) + (path / ".openmlr-meta").mkdir(exist_ok=True) + return path + + def workspace_exists(self, conversation_uuid: str) -> bool: + """Check if a workspace exists.""" + return 
self.get_workspace_path(conversation_uuid).exists() + + def archive_workspace(self, conversation_uuid: str) -> Path | None: + """Archive a workspace before deletion. Returns archive path.""" + path = self.get_workspace_path(conversation_uuid) + if not path.exists(): + return None + + timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S") + archive_name = f"workspace-{conversation_uuid}-{timestamp}.tar.gz" + archive_path = self.archive_dir / archive_name + + with tarfile.open(archive_path, "w:gz") as tar: + tar.add(path, arcname=path.name) + + return archive_path + + def delete_workspace(self, conversation_uuid: str, archive: bool = True) -> bool: + """Delete a workspace. If archive=True, archive it first.""" + path = self.get_workspace_path(conversation_uuid) + if not path.exists(): + return False + + if archive: + self.archive_workspace(conversation_uuid) + + shutil.rmtree(path) + return True + + def get_workspace_size(self, conversation_uuid: str) -> int: + """Get total size of a workspace in bytes.""" + path = self.get_workspace_path(conversation_uuid) + if not path.exists(): + return 0 + + total = 0 + for dirpath, _, filenames in os.walk(path): + for f in filenames: + fp = Path(dirpath) / f + if fp.exists(): + total += fp.stat().st_size + return total + + def list_workspaces(self) -> list[dict]: + """List all workspaces with metadata.""" + workspaces = [] + for path in self.workspace_dir.glob("workspace-*"): + if path.is_dir(): + uuid = path.name.replace("workspace-", "") + size = self.get_workspace_size(uuid) + workspaces.append({ + "uuid": uuid, + "path": str(path), + "size_bytes": size, + "created": datetime.fromtimestamp(path.stat().st_ctime, UTC).isoformat(), + }) + return sorted(workspaces, key=lambda x: x["created"], reverse=True) + + def cleanup_archives(self, max_age_days: int = 30, max_count: int = 100) -> dict: + """Clean up old workspace archives. 
+ + Args: + max_age_days: Delete archives older than this many days + max_count: Keep at most this many archives, delete oldest first + + Returns: + Dict with deleted count and freed bytes + """ + deleted = 0 + freed_bytes = 0 + + # Get all archives sorted by modification time (oldest first) + archives = [] + for path in self.archive_dir.glob("workspace-*.tar.gz"): + if path.is_file(): + mtime = datetime.fromtimestamp(path.stat().st_mtime, UTC) + archives.append({"path": path, "mtime": mtime, "size": path.stat().st_size}) + + archives.sort(key=lambda x: x["mtime"]) + + # Delete old archives + now = datetime.now(UTC) + for archive in archives: + age_days = (now - archive["mtime"]).days + if age_days > max_age_days: + freed_bytes += archive["size"] + archive["path"].unlink() + deleted += 1 + + # Delete excess archives (oldest first) + remaining = [a for a in archives if a["path"].exists()] + while len(remaining) > max_count: + oldest = remaining.pop(0) + freed_bytes += oldest["size"] + oldest["path"].unlink() + deleted += 1 + + return {"deleted": deleted, "freed_bytes": freed_bytes} + + def cleanup_workspaces(self, conversation_uuids: list[str], archive: bool = True) -> dict: + """Clean up workspaces for deleted conversations. 
+ + Args: + conversation_uuids: List of conversation UUIDs to keep + archive: Whether to archive before deleting + + Returns: + Dict with deleted count and freed bytes + """ + deleted = 0 + freed_bytes = 0 + keep_set = set(conversation_uuids) + + for path in self.workspace_dir.glob("workspace-*"): + if not path.is_dir(): + continue + uuid = path.name.replace("workspace-", "") + if uuid not in keep_set: + size = self.get_workspace_size(uuid) + if archive: + self.archive_workspace(uuid) + shutil.rmtree(path) + freed_bytes += size + deleted += 1 + + return {"deleted": deleted, "freed_bytes": freed_bytes} diff --git a/backend/openmlr/db/models.py b/backend/openmlr/db/models.py index bb0cd93..ef40f74 100644 --- a/backend/openmlr/db/models.py +++ b/backend/openmlr/db/models.py @@ -39,6 +39,8 @@ class User(Base): settings = relationship("UserSetting", back_populates="user", cascade="all, delete-orphan") conversations = relationship("Conversation", back_populates="user", cascade="all, delete-orphan") sandbox_configs = relationship("SandboxConfig", back_populates="user", cascade="all, delete-orphan") + ssh_keys = relationship("SSHKey", back_populates="user", cascade="all, delete-orphan") + compute_nodes = relationship("ComputeNode", back_populates="user", cascade="all, delete-orphan") research_corpus = relationship("ResearchCorpus", back_populates="user", cascade="all, delete-orphan") writing_projects = relationship("WritingProject", back_populates="user", cascade="all, delete-orphan") @@ -111,6 +113,49 @@ class SandboxConfig(Base): user = relationship("User", back_populates="sandbox_configs") +class SSHKey(Base): + __tablename__ = "ssh_keys" + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + filename = Column(String(255), nullable=False) + fingerprint = Column(String(255), nullable=False) + algorithm = Column(String(50), nullable=False) + public_key = Column(Text, nullable=False) + comment = 
Column(Text, nullable=True) + created_at = Column(DateTime(timezone=True), default=_utcnow, nullable=False) + + user = relationship("User", back_populates="ssh_keys") + __table_args__ = ( + # Unique constraint on (user_id, filename) + {}, + ) + + +class ComputeNode(Base): + __tablename__ = "compute_nodes" + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + name = Column(String(100), nullable=False) + type = Column(String(20), nullable=False) # local, ssh, modal + config = Column(JSON, nullable=False, default=dict) + capabilities = Column(JSON, nullable=True) + health_status = Column(String(20), default="unknown", nullable=False) + last_probed_at = Column(DateTime(timezone=True), nullable=True) + last_seen_at = Column(DateTime(timezone=True), nullable=True) + is_default = Column(Boolean, default=False) + priority = Column(Integer, default=0, nullable=False) + created_at = Column(DateTime(timezone=True), default=_utcnow, nullable=False) + updated_at = Column(DateTime(timezone=True), default=_utcnow, onupdate=_utcnow, nullable=False) + + user = relationship("User", back_populates="compute_nodes") + __table_args__ = ( + # Unique constraint on (user_id, name) + {}, + ) + + class ResearchCorpus(Base): __tablename__ = "research_corpus" diff --git a/backend/openmlr/db/operations.py b/backend/openmlr/db/operations.py index 663a440..2ea72a1 100644 --- a/backend/openmlr/db/operations.py +++ b/backend/openmlr/db/operations.py @@ -7,10 +7,12 @@ from .models import ( AgentJob, + ComputeNode, Conversation, ConversationResource, ConversationTask, Message, + SSHKey, UserSetting, ) @@ -77,6 +79,13 @@ async def update_conversation_model(db: AsyncSession, conv_id: int, model: str): await db.commit() +async def update_conversation_extra(db: AsyncSession, conv_id: int, extra: dict): + await db.execute( + update(Conversation).where(Conversation.id == conv_id).values(extra=extra) + ) + await db.commit() + + 
async def increment_user_message_count(db: AsyncSession, conv_id: int): await db.execute( update(Conversation) @@ -495,3 +504,147 @@ async def get_user_agent_settings(db: AsyncSession, user_id: int) -> dict: for s in result.scalars().all(): settings[s.key] = _clean_json_value(s.value) return settings + + +# ---- SSH Keys ---- + +async def create_ssh_key( + db: AsyncSession, user_id: int, filename: str, fingerprint: str, + algorithm: str, public_key: str, comment: str | None = None, +) -> SSHKey: + key = SSHKey( + user_id=user_id, + filename=filename, + fingerprint=fingerprint, + algorithm=algorithm, + public_key=public_key, + comment=comment, + ) + db.add(key) + await db.commit() + await db.refresh(key) + return key + + +async def get_ssh_keys(db: AsyncSession, user_id: int) -> list[SSHKey]: + result = await db.execute( + select(SSHKey).where(SSHKey.user_id == user_id).order_by(SSHKey.created_at.desc()) + ) + return list(result.scalars().all()) + + +async def get_ssh_key_by_filename(db: AsyncSession, user_id: int, filename: str) -> SSHKey | None: + result = await db.execute( + select(SSHKey).where(SSHKey.user_id == user_id, SSHKey.filename == filename) + ) + return result.scalar_one_or_none() + + +async def delete_ssh_key(db: AsyncSession, user_id: int, filename: str) -> bool: + result = await db.execute( + select(SSHKey).where(SSHKey.user_id == user_id, SSHKey.filename == filename) + ) + key = result.scalar_one_or_none() + if not key: + return False + await db.delete(key) + await db.commit() + return True + + +# ---- Compute Nodes ---- + +async def create_compute_node( + db: AsyncSession, user_id: int, name: str, node_type: str, config: dict, + is_default: bool = False, priority: int = 0, +) -> ComputeNode: + node = ComputeNode( + user_id=user_id, + name=name, + type=node_type, + config=config, + is_default=is_default, + priority=priority, + ) + db.add(node) + await db.commit() + await db.refresh(node) + return node + + +async def get_compute_nodes(db: 
AsyncSession, user_id: int) -> list[ComputeNode]: + result = await db.execute( + select(ComputeNode).where(ComputeNode.user_id == user_id).order_by(ComputeNode.priority.desc(), ComputeNode.created_at.desc()) + ) + return list(result.scalars().all()) + + +async def get_compute_node_by_id(db: AsyncSession, node_id: int, user_id: int | None = None) -> ComputeNode | None: + query = select(ComputeNode).where(ComputeNode.id == node_id) + if user_id is not None: + query = query.where(ComputeNode.user_id == user_id) + result = await db.execute(query) + return result.scalar_one_or_none() + + +async def get_compute_node_by_name(db: AsyncSession, user_id: int, name: str) -> ComputeNode | None: + result = await db.execute( + select(ComputeNode).where(ComputeNode.user_id == user_id, ComputeNode.name == name) + ) + return result.scalar_one_or_none() + + +async def update_compute_node( + db: AsyncSession, node_id: int, user_id: int, **kwargs, +) -> ComputeNode | None: + result = await db.execute( + select(ComputeNode).where(ComputeNode.id == node_id, ComputeNode.user_id == user_id) + ) + node = result.scalar_one_or_none() + if not node: + return None + for key, value in kwargs.items(): + if hasattr(node, key): + setattr(node, key, value) + await db.commit() + await db.refresh(node) + return node + + +async def delete_compute_node(db: AsyncSession, node_id: int, user_id: int) -> bool: + result = await db.execute( + select(ComputeNode).where(ComputeNode.id == node_id, ComputeNode.user_id == user_id) + ) + node = result.scalar_one_or_none() + if not node: + return False + await db.delete(node) + await db.commit() + return True + + +async def set_default_compute_node(db: AsyncSession, user_id: int, node_id: int | None) -> None: + # Clear existing default + await db.execute( + update(ComputeNode) + .where(ComputeNode.user_id == user_id, ComputeNode.is_default.is_(True)) + .values(is_default=False) + ) + # Set new default + if node_id is not None: + await db.execute( + 
update(ComputeNode) + .where(ComputeNode.id == node_id, ComputeNode.user_id == user_id) + .values(is_default=True) + ) + await db.commit() + + +async def get_default_compute_node(db: AsyncSession, user_id: int) -> ComputeNode | None: + result = await db.execute( + select(ComputeNode).where( + ComputeNode.user_id == user_id, + ComputeNode.is_default.is_(True), + ) + ) + return result.scalar_one_or_none() diff --git a/backend/openmlr/keys/__init__.py b/backend/openmlr/keys/__init__.py new file mode 100644 index 0000000..4e02442 --- /dev/null +++ b/backend/openmlr/keys/__init__.py @@ -0,0 +1,3 @@ +from .manager import KeyManager + +__all__ = ["KeyManager"] diff --git a/backend/openmlr/keys/manager.py b/backend/openmlr/keys/manager.py new file mode 100644 index 0000000..a7798d4 --- /dev/null +++ b/backend/openmlr/keys/manager.py @@ -0,0 +1,169 @@ +"""SSH Key Asset Manager — handles .keys/ directory lifecycle.""" + +import os +import stat +from pathlib import Path + +from cryptography.hazmat.primitives import serialization as crypto_serialization +from cryptography.hazmat.primitives.asymmetric import ed25519, rsa + + +class KeyManager: + """Manages SSH private keys stored in a dedicated directory.""" + + def __init__(self, keys_dir: str | Path = None): + self.keys_dir = Path(keys_dir) if keys_dir else Path(__file__).parent.parent.parent.parent.parent / ".keys" + self._ensure_dir() + + def _ensure_dir(self) -> None: + """Ensure .keys/ directory exists with correct permissions.""" + self.keys_dir.mkdir(parents=True, exist_ok=True) + # Set directory permissions to 0o700 (owner read/write/execute only) + os.chmod(self.keys_dir, 0o700) + + def list_keys(self) -> list[dict]: + """List all key files (metadata only, no private content).""" + keys = [] + for path in sorted(self.keys_dir.glob("id_*")): + if path.suffix == ".pub": + continue + pub_path = path.with_suffix(path.suffix + ".pub") + keys.append({ + "filename": path.name, + "has_public": pub_path.exists(), + 
"size_bytes": path.stat().st_size, + }) + return keys + + def key_exists(self, filename: str) -> bool: + """Check if a key file exists.""" + return (self.keys_dir / filename).exists() + + def get_key_path(self, filename: str) -> Path: + """Get the absolute path to a key file.""" + return self.keys_dir / filename + + def write_key(self, filename: str, private_key_pem: str | bytes) -> Path: + """Write a private key to disk with restrictive permissions.""" + key_path = self.keys_dir / filename + if isinstance(private_key_pem, str): + private_key_pem = private_key_pem.encode("utf-8") + + key_path.write_bytes(private_key_pem) + # Set file permissions to 0o600 (owner read/write only) + os.chmod(key_path, stat.S_IRUSR | stat.S_IWUSR) + return key_path + + def read_key(self, filename: str) -> str: + """Read a private key from disk. Use sparingly.""" + key_path = self.keys_dir / filename + if not key_path.exists(): + raise FileNotFoundError(f"Key not found: {filename}") + return key_path.read_text("utf-8") + + def delete_key(self, filename: str) -> bool: + """Delete a key pair from disk.""" + key_path = self.keys_dir / filename + pub_path = key_path.with_suffix(key_path.suffix + ".pub") + deleted = False + if key_path.exists(): + key_path.unlink() + deleted = True + if pub_path.exists(): + pub_path.unlink() + deleted = True + return deleted + + def generate_key_pair(self, filename: str, algorithm: str = "ed25519", comment: str = "") -> tuple[Path, Path]: + """Generate a new SSH key pair and write to disk.""" + key_path = self.keys_dir / filename + pub_path = key_path.with_suffix(key_path.suffix + ".pub") + + if algorithm == "ed25519": + private_key = ed25519.Ed25519PrivateKey.generate() + private_pem = private_key.private_bytes( + encoding=crypto_serialization.Encoding.PEM, + format=crypto_serialization.PrivateFormat.OpenSSH, + encryption_algorithm=crypto_serialization.NoEncryption(), + ) + public_bytes = private_key.public_key().public_bytes( + 
encoding=crypto_serialization.Encoding.OpenSSH, + format=crypto_serialization.PublicFormat.OpenSSH, + ) + elif algorithm == "rsa": + private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096) + private_pem = private_key.private_bytes( + encoding=crypto_serialization.Encoding.PEM, + format=crypto_serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=crypto_serialization.NoEncryption(), + ) + public_bytes = private_key.public_key().public_bytes( + encoding=crypto_serialization.Encoding.OpenSSH, + format=crypto_serialization.PublicFormat.OpenSSH, + ) + else: + raise ValueError(f"Unsupported algorithm: {algorithm}. Use 'ed25519' or 'rsa'.") + + # Write private key with 0o600 + key_path.write_bytes(private_pem) + os.chmod(key_path, stat.S_IRUSR | stat.S_IWUSR) + + # Write public key with 0o644 + pub_pem = public_bytes + (f" {comment}".encode() if comment else b"") + pub_path.write_bytes(pub_pem) + os.chmod(pub_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH) + + return key_path, pub_path + + def validate_key(self, private_key_pem: str | bytes) -> dict: + """Validate an SSH private key and return metadata.""" + if isinstance(private_key_pem, str): + private_key_pem = private_key_pem.encode("utf-8") + + # Try to load as OpenSSH format + try: + key = crypto_serialization.load_ssh_private_key(private_key_pem, password=None) + except Exception: + # Try PEM format + try: + key = crypto_serialization.load_pem_private_key(private_key_pem, password=None) + except Exception as e: + raise ValueError(f"Invalid private key: {e}") + + # Determine algorithm + key_type = type(key).__name__.lower() + if "ed25519" in key_type: + algorithm = "ssh-ed25519" + elif "rsa" in key_type: + algorithm = "ssh-rsa" + else: + algorithm = key_type + + # Generate public key for fingerprint + public_key = key.public_key() + public_bytes = public_key.public_bytes( + encoding=crypto_serialization.Encoding.OpenSSH, + 
format=crypto_serialization.PublicFormat.OpenSSH, + ) + + # Compute SHA256 fingerprint matching ssh-keygen format: + # The fingerprint is the SHA256 of the raw key blob (base64-decoded + # portion of the OpenSSH public key line), base64-encoded. + import base64 + import hashlib + + pub_line = public_bytes.decode("utf-8").strip() + # OpenSSH format: "ssh-ed25519 AAAA... comment" + parts = pub_line.split() + if len(parts) >= 2: + key_blob = base64.b64decode(parts[1]) + else: + key_blob = public_bytes + raw_hash = hashlib.sha256(key_blob).digest() + fingerprint = base64.b64encode(raw_hash).decode("ascii").rstrip("=") + + return { + "algorithm": algorithm, + "fingerprint": f"SHA256:{fingerprint}", + "public_key": pub_line, + } diff --git a/backend/openmlr/routes/agent.py b/backend/openmlr/routes/agent.py index 5afdfd0..bc17185 100644 --- a/backend/openmlr/routes/agent.py +++ b/backend/openmlr/routes/agent.py @@ -207,6 +207,128 @@ async def switch_conversation( return {"ok": True} +# ── Per-Conversation Compute ───────────────────────────── + +@router.get("/conversations/{uuid}/compute") +async def get_conversation_compute( + uuid: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Get the active compute node for a conversation.""" + conv = await _get_conv_or_404(db, uuid, user.id) + + # Check conversation override first + if conv.extra and conv.extra.get("compute_node_id"): + node = await ops.get_compute_node_by_id(db, conv.extra["compute_node_id"], user.id) + if node: + return { + "node": { + "id": node.id, + "name": node.name, + "type": node.type, + }, + "source": "conversation", + } + + # Fall back to user's default + default_node = await ops.get_default_compute_node(db, user.id) + if default_node: + return { + "node": { + "id": default_node.id, + "name": default_node.name, + "type": default_node.type, + }, + "source": "default", + } + + return {"node": None, "source": None} + + 
+@router.post("/conversations/{uuid}/compute") +async def set_conversation_compute( + uuid: str, + request: Request, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Bind a compute node to a conversation.""" + body = await request.json() + node_id = body.get("node_id") + + conv = await _get_conv_or_404(db, uuid, user.id) + + if node_id is None: + # Clear override + extra = conv.extra or {} + extra.pop("compute_node_id", None) + extra.pop("compute_node_name", None) + await ops.update_conversation_extra(db, conv.id, extra) + return {"ok": True, "node": None} + + # Validate node exists and belongs to user + node = await ops.get_compute_node_by_id(db, node_id, user.id) + if not node: + raise HTTPException(status_code=404, detail="Compute node not found") + + extra = conv.extra or {} + extra["compute_node_id"] = node.id + extra["compute_node_name"] = node.name + await ops.update_conversation_extra(db, conv.id, extra) + + # Update active session if it exists — must rebuild tool_router + # since sandbox tools capture sandbox_manager in closures + sm = _sm(request) + active = sm.get_session(conv.id) + if active: + from ..compute import WorkspaceManager + from ..sandbox.manager import SandboxManager + from ..tools.registry import create_tool_router + + workspace_manager = WorkspaceManager() + sandbox_manager = SandboxManager( + workspace_manager=workspace_manager, + conversation_uuid=conv.uuid, + ) + await sandbox_manager.create(node.type, node.config) + # Destroy old sandbox + try: + await active.sandbox_manager.destroy() + except Exception: + pass + active.sandbox_manager = sandbox_manager + # Rebuild tool router with new sandbox_manager + active.tool_router = create_tool_router(sandbox_manager) + active.tool_router.set_context(user_id=user.id, db=db) + + return { + "ok": True, + "node": { + "id": node.id, + "name": node.name, + "type": node.type, + }, + } + + +@router.delete("/conversations/{uuid}/compute") +async def 
clear_conversation_compute( + uuid: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Clear the compute override for a conversation (falls back to default).""" + conv = await _get_conv_or_404(db, uuid, user.id) + + extra = conv.extra or {} + extra.pop("compute_node_id", None) + extra.pop("compute_node_name", None) + await ops.update_conversation_extra(db, conv.id, extra) + + return {"ok": True} + + # ── Messaging ──────────────────────────────────────────── @router.post("/message") diff --git a/backend/openmlr/routes/compute.py b/backend/openmlr/routes/compute.py new file mode 100644 index 0000000..fbc2c25 --- /dev/null +++ b/backend/openmlr/routes/compute.py @@ -0,0 +1,388 @@ +"""Compute Node routes — CRUD, testing, probing, and defaults.""" + +from datetime import UTC, datetime + +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from ..compute import ComputeManager +from ..db import operations as ops +from ..db.engine import get_db +from ..db.models import User +from ..dependencies import get_current_user +from ..keys import KeyManager + +router = APIRouter(prefix="/api/compute", tags=["compute"]) + +key_manager = KeyManager() +compute_manager = ComputeManager(key_manager) + +# Fields to redact from config before sending to the frontend +_SENSITIVE_CONFIG_KEYS = {"password", "private_key", "secret", "token"} + + +def _redact_config(config: dict) -> dict: + """Return config with sensitive fields masked.""" + if not config: + return {} + redacted = {} + for k, v in config.items(): + if k in _SENSITIVE_CONFIG_KEYS and v: + redacted[k] = "***" + else: + redacted[k] = v + return redacted + + +def _node_dict(node) -> dict: + return { + "id": node.id, + "name": node.name, + "type": node.type, + "config": _redact_config(node.config), + "capabilities": node.capabilities or {}, + "health_status": node.health_status, + "last_probed_at": 
node.last_probed_at.isoformat() if node.last_probed_at else None, + "last_seen_at": node.last_seen_at.isoformat() if node.last_seen_at else None, + "is_default": node.is_default, + "priority": node.priority, + "created_at": node.created_at.isoformat() if node.created_at else None, + "updated_at": node.updated_at.isoformat() if node.updated_at else None, + } + + +# ── Compute Nodes ──────────────────────────────────────── + +@router.get("/nodes") +async def list_nodes( + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """List all compute nodes for the current user.""" + nodes = await ops.get_compute_nodes(db, user.id) + return {"nodes": [_node_dict(n) for n in nodes]} + + +@router.post("/nodes") +async def create_node( + request: Request, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Create a new compute node.""" + body = await request.json() + name = body.get("name", "").strip() + node_type = body.get("type", "").strip() + config = body.get("config", {}) + is_default = body.get("is_default", False) + priority = body.get("priority", 0) + + if not name: + raise HTTPException(status_code=400, detail="Missing 'name'") + if node_type not in ("local", "ssh", "modal"): + raise HTTPException(status_code=400, detail="type must be 'local', 'ssh', or 'modal'") + + # Validate config + valid, error = compute_manager.validate_node_config(node_type, config) + if not valid: + raise HTTPException(status_code=400, detail=error) + + # Check for duplicate name + existing = await ops.get_compute_node_by_name(db, user.id, name) + if existing: + raise HTTPException(status_code=409, detail=f"Node '{name}' already exists") + + # If setting as default, clear existing default + if is_default: + await ops.set_default_compute_node(db, user.id, None) + + node = await ops.create_compute_node( + db, user.id, name, node_type, config, + is_default=is_default, priority=priority, + ) + + return {"node": _node_dict(node)} + 
+
+@router.get("/nodes/{node_id}")
+async def get_node(
+    node_id: int,
+    user: User = Depends(get_current_user),
+    db: AsyncSession = Depends(get_db),
+):
+    """Get a single compute node owned by the current user (404 if absent)."""
+    node = await ops.get_compute_node_by_id(db, node_id, user.id)
+    if not node:
+        raise HTTPException(status_code=404, detail="Node not found")
+    return {"node": _node_dict(node)}
+
+
+@router.put("/nodes/{node_id}")
+async def update_node(
+    node_id: int,
+    request: Request,
+    user: User = Depends(get_current_user),
+    db: AsyncSession = Depends(get_db),
+):
+    """Update a compute node's name, config, priority and/or default flag.
+
+    Only keys present in the JSON body are applied.
+
+    Raises:
+        HTTPException 400: malformed body, invalid config, or bad field type.
+        HTTPException 404: node not found for this user.
+        HTTPException 409: another node of this user already has the new name.
+    """
+    body = await request.json()
+    if not isinstance(body, dict):
+        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
+    node = await ops.get_compute_node_by_id(db, node_id, user.id)
+    if not node:
+        raise HTTPException(status_code=404, detail="Node not found")
+
+    updates = {}
+    if "name" in body:
+        if not isinstance(body["name"], str):
+            raise HTTPException(status_code=400, detail="'name' must be a string")
+        new_name = body["name"].strip()
+        # An empty or unchanged name is silently ignored
+        if new_name and new_name != node.name:
+            existing = await ops.get_compute_node_by_name(db, user.id, new_name)
+            if existing:
+                raise HTTPException(status_code=409, detail=f"Node '{new_name}' already exists")
+            updates["name"] = new_name
+
+    if "config" in body:
+        config = body["config"]
+        if not isinstance(config, dict):
+            raise HTTPException(status_code=400, detail="'config' must be an object")
+        # Responses mask secrets as "***" (_redact_config). A client that
+        # sends the config back unchanged must not overwrite the stored
+        # secret with the mask, so restore the saved value for masked keys.
+        stored = node.config or {}
+        config = {
+            k: stored[k] if (k in _SENSITIVE_CONFIG_KEYS and v == "***" and k in stored) else v
+            for k, v in config.items()
+        }
+        valid, error = compute_manager.validate_node_config(node.type, config)
+        if not valid:
+            raise HTTPException(status_code=400, detail=error)
+        updates["config"] = config
+
+    if "priority" in body:
+        try:
+            updates["priority"] = int(body["priority"])
+        except (TypeError, ValueError):
+            raise HTTPException(status_code=400, detail="'priority' must be an integer") from None
+
+    if "is_default" in body:
+        if body["is_default"]:
+            # Clear the current default first so only one node is flagged
+            await ops.set_default_compute_node(db, user.id, None)
+        updates["is_default"] = bool(body["is_default"])
+
+    updated = await ops.update_compute_node(db, node_id, user.id, **updates)
+    return {"node": _node_dict(updated)}
+
+
+@router.delete("/nodes/{node_id}")
+async def delete_node(
+    node_id: int,
+    user: User = Depends(get_current_user),
+    db: AsyncSession = Depends(get_db),
+):
+    """Delete a compute node."""
+    deleted = await ops.delete_compute_node(db, node_id, user.id)
+    if not deleted:
+        raise HTTPException(status_code=404, detail="Node not found")
+
return {"ok": True} + + +@router.post("/nodes/{node_id}/set-default") +async def set_default_node( + node_id: int, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Set a compute node as the user's default.""" + node = await ops.get_compute_node_by_id(db, node_id, user.id) + if not node: + raise HTTPException(status_code=404, detail="Node not found") + await ops.set_default_compute_node(db, user.id, node_id) + return {"ok": True} + + +@router.post("/nodes/{node_id}/test") +async def test_node( + node_id: int, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Test connectivity to a compute node (lightweight).""" + node = await ops.get_compute_node_by_id(db, node_id, user.id) + if not node: + raise HTTPException(status_code=404, detail="Node not found") + + if node.type == "ssh": + return await _test_ssh_node(node) + elif node.type == "local": + return await _test_local_node(node) + elif node.type == "modal": + return await _test_modal_node(node) + + return {"ok": False, "error": "Unknown node type"} + + +@router.post("/test") +async def test_node_config( + request: Request, + user: User = Depends(get_current_user), +): + """Test connectivity for an unsaved node config. + + Used before creating a node so the user can verify credentials work. 
+    """
+    body = await request.json()
+    node_type = body.get("type", "")
+    config = body.get("config", {})
+
+    if node_type not in ("local", "ssh", "modal"):
+        return {"ok": False, "error": "Invalid node type"}
+
+    # Build a lightweight mock object that _test_* functions can read
+    class _MockNode:
+        def __init__(self, t, c):
+            self.type = t
+            self.config = c
+
+    mock = _MockNode(node_type, config)
+
+    if node_type == "ssh":
+        return await _test_ssh_node(mock)
+    elif node_type == "local":
+        return await _test_local_node(mock)
+    elif node_type == "modal":
+        return await _test_modal_node(mock)
+
+    return {"ok": False, "error": "Unknown node type"}
+
+
+@router.post("/nodes/{node_id}/probe")
+async def probe_node(
+    node_id: int,
+    user: User = Depends(get_current_user),
+    db: AsyncSession = Depends(get_db),
+):
+    """Deep capability discovery for a compute node.
+
+    Creates a temporary sandbox on the node, runs ``probe_sandbox`` against it
+    and stores the result. On success the node is marked "online" with fresh
+    capabilities and ``last_probed_at``; on any failure it is marked "offline"
+    and the error text is returned. The temporary sandbox is always destroyed.
+
+    Raises:
+        HTTPException 404: node not found for this user.
+    """
+    node = await ops.get_compute_node_by_id(db, node_id, user.id)
+    if not node:
+        raise HTTPException(status_code=404, detail="Node not found")
+
+    import logging
+
+    from ..compute import WorkspaceManager
+    from ..compute.probe import probe_sandbox
+    from ..sandbox.manager import SandboxManager
+
+    sm = None
+    try:
+        wm = WorkspaceManager()
+        sm = SandboxManager(workspace_manager=wm)
+        # SandboxManager.create() writes "conversation_uuid" into the dict it
+        # is given, so pass a copy rather than the node's stored config.
+        await sm.create(node.type, dict(node.config or {}))
+        sandbox = sm.get_active()
+
+        if not sandbox:
+            raise RuntimeError("Failed to create sandbox")
+
+        caps = await probe_sandbox(sandbox)
+        caps_dict = caps.to_dict()
+
+        # Update node in database
+        await ops.update_compute_node(
+            db, node.id, user.id,
+            capabilities=caps_dict,
+            health_status="online",
+            last_probed_at=datetime.now(UTC),
+        )
+
+        return {
+            "ok": True,
+            "capabilities": caps_dict,
+        }
+
+    except Exception as e:
+        await ops.update_compute_node(
+            db, node.id, user.id,
+            health_status="offline",
+        )
+        return {"ok": False, "error": str(e)}
+
+    finally:
+        # Release the probe sandbox on every path; previously a failing
+        # create()/probe left it (and its connection) behind.
+        if sm is not None:
+            try:
+                await sm.destroy()
+            except Exception:
+                logging.getLogger(__name__).warning(
+                    "Failed to destroy probe sandbox for node %s", node.id, exc_info=True,
+                )
+
+
+async def _test_ssh_node(node):
+    """Test SSH connectivity and retrieve host key fingerprint if not set."""
+    import asyncio
+
+    import paramiko
+
+    config =
node.config
+    host = config.get("host", "")
+    port = config.get("port", 22)
+    username = config.get("username", "")
+    key_filename = config.get("key_filename")
+    password = config.get("password")
+
+    try:
+        def _do_test():
+            """Blocking probe: connect, read the host key, run ``echo ok``."""
+            client = paramiko.SSHClient()
+            # WarningPolicy accepts an unknown host key and only emits a
+            # warning, so the key can be read below without being persisted.
+            # NOTE(review): the credentials are therefore offered to a host
+            # whose key has not been verified yet.
+            client.set_missing_host_key_policy(paramiko.WarningPolicy())
+
+            connect_kwargs = {
+                "hostname": host,
+                "port": port,
+                "username": username,
+                "timeout": 10,
+            }
+
+            # A stored key file takes precedence over a password
+            if key_filename:
+                key_path = key_manager.get_key_path(key_filename)
+                connect_kwargs["key_filename"] = str(key_path)
+            elif password:
+                connect_kwargs["password"] = password
+
+            try:
+                client.connect(**connect_kwargs)
+
+                # Get host key fingerprint (hex of the digest paramiko returns)
+                transport = client.get_transport()
+                host_key = transport.get_remote_server_key()
+                fingerprint = host_key.get_fingerprint().hex()
+
+                # Run a simple command
+                _stdin, stdout, _stderr = client.exec_command("echo ok", timeout=5)
+                exit_code = stdout.channel.recv_exit_status()
+                output = stdout.read().decode("utf-8", errors="replace").strip()
+            finally:
+                # Always release the transport/socket, including when
+                # connect() or the command fails part-way through.
+                client.close()
+
+            return {
+                "connected": exit_code == 0 and output == "ok",
+                "host_key_fingerprint": fingerprint,
+                "output": output,
+            }
+
+        # paramiko is blocking; keep it off the event loop
+        result = await asyncio.to_thread(_do_test)
+        return {
+            "ok": result["connected"],
+            "host_key_fingerprint": result.get("host_key_fingerprint"),
+            "message": "Connected successfully" if result["connected"] else f"Unexpected output: {result['output']}",
+        }
+
+    except Exception as e:
+        # Any failure (DNS, auth, timeout, ...) is reported, never raised
+        return {"ok": False, "error": str(e)}
+
+
+async def _test_local_node(node):
+    """Test local workspace directory."""
+    import os
+    from pathlib import Path
+
+    config = node.config
+    workdir = config.get("workdir", "")
+    if not workdir:
+ workdir = os.getcwd() + + path = Path(workdir).expanduser() + if path.exists() and path.is_dir(): + return {"ok": True, "message": f"Workspace ready: {path}"} + else: + return {"ok": False, "error": f"Workspace not found: {path}"} + + +async def _test_modal_node(node): + """Test Modal connectivity.""" + try: + import importlib.util + if importlib.util.find_spec("modal") is not None: + return {"ok": True, "message": "Modal client available"} + return {"ok": False, "error": "Modal client not installed"} + except Exception: + return {"ok": False, "error": "Modal client not installed"} diff --git a/backend/openmlr/routes/keys.py b/backend/openmlr/routes/keys.py new file mode 100644 index 0000000..866bc96 --- /dev/null +++ b/backend/openmlr/routes/keys.py @@ -0,0 +1,136 @@ +"""SSH Key routes — CRUD for key assets stored in .keys/.""" + +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from ..db import operations as ops +from ..db.engine import get_db +from ..db.models import User +from ..dependencies import get_current_user +from ..keys import KeyManager + +router = APIRouter(prefix="/api", tags=["keys"]) + +key_manager = KeyManager() + + +@router.get("/keys") +async def list_keys( + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """List all SSH key metadata for the current user.""" + keys = await ops.get_ssh_keys(db, user.id) + return { + "keys": [ + { + "id": k.id, + "filename": k.filename, + "fingerprint": k.fingerprint, + "algorithm": k.algorithm, + "comment": k.comment, + "created_at": k.created_at.isoformat() if k.created_at else None, + } + for k in keys + ] + } + + +@router.post("/keys") +async def create_key( + request: Request, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Upload or generate an SSH key pair.""" + body = await request.json() + action = body.get("action") + filename = body.get("filename", "") + + if 
not filename:
+        raise HTTPException(status_code=400, detail="Missing 'filename'")
+
+    # Prevent path traversal in filename
+    from pathlib import Path as PyPath
+    safe_filename = PyPath(filename).name
+    if not safe_filename or safe_filename.startswith("."):
+        raise HTTPException(status_code=400, detail="Invalid filename")
+
+    # NOTE(review): uniqueness is checked per user, but the key is written
+    # under the bare filename. Confirm KeyManager namespaces files per user,
+    # otherwise two users choosing the same name overwrite each other's key.
+    existing = await ops.get_ssh_key_by_filename(db, user.id, safe_filename)
+    if existing:
+        raise HTTPException(status_code=409, detail=f"Key '{safe_filename}' already exists")
+
+    if action == "upload":
+        private_key = body.get("private_key", "")
+        if not private_key:
+            raise HTTPException(status_code=400, detail="Missing 'private_key' for upload")
+
+        # Validate before anything touches the disk
+        try:
+            meta = key_manager.validate_key(private_key)
+        except ValueError as e:
+            raise HTTPException(status_code=400, detail=f"Invalid key: {e}")
+
+        key_manager.write_key(safe_filename, private_key)
+
+    elif action == "generate":
+        algorithm = body.get("algorithm", "ed25519")
+        comment = body.get("comment", f"openmlr-{user.id}")
+        try:
+            key_path, pub_path = key_manager.generate_key_pair(safe_filename, algorithm, comment)
+            private_key = key_path.read_text()
+            meta = key_manager.validate_key(private_key)
+        except ValueError as e:
+            raise HTTPException(status_code=400, detail=str(e))
+
+    else:
+        raise HTTPException(status_code=400, detail="action must be 'upload' or 'generate'")
+
+    try:
+        key = await ops.create_ssh_key(
+            db, user.id, safe_filename, meta["fingerprint"],
+            meta["algorithm"], meta["public_key"], body.get("comment"),
+        )
+    except Exception:
+        # The key file is already on disk at this point; remove it so a
+        # failed insert does not leave an untracked private key behind.
+        key_manager.delete_key(safe_filename)
+        raise
+
+    return {
+        "key": {
+            "id": key.id,
+            "filename": key.filename,
+            "fingerprint": key.fingerprint,
+            "algorithm": key.algorithm,
+            "public_key": key.public_key,
+            "comment": key.comment,
+            "created_at": key.created_at.isoformat() if key.created_at else None,
+        }
+    }
+
+
+@router.delete("/keys/{filename}")
+async def delete_key(
+    filename: str,
+    user: User = Depends(get_current_user),
+    db: AsyncSession = Depends(get_db),
+):
+    """Delete an SSH key and its public
counterpart.""" + # Sanitize filename to prevent path traversal + from pathlib import Path as PyPath + safe_filename = PyPath(filename).name + if not safe_filename or safe_filename != filename or safe_filename.startswith("."): + raise HTTPException(status_code=400, detail="Invalid filename") + filename = safe_filename + + # Check if any compute nodes reference this key + nodes = await ops.get_compute_nodes(db, user.id) + dependent_nodes = [n for n in nodes if n.config.get("key_filename") == filename] + + if dependent_nodes: + node_names = ", ".join(n.name for n in dependent_nodes) + raise HTTPException( + status_code=409, + detail=f"Cannot delete key: used by compute nodes: {node_names}" + ) + + deleted_db = await ops.delete_ssh_key(db, user.id, filename) + if not deleted_db: + raise HTTPException(status_code=404, detail="Key not found") + + key_manager.delete_key(filename) + return {"ok": True} diff --git a/backend/openmlr/sandbox/interface.py b/backend/openmlr/sandbox/interface.py index ef61c28..19b904a 100644 --- a/backend/openmlr/sandbox/interface.py +++ b/backend/openmlr/sandbox/interface.py @@ -1,19 +1,9 @@ """Abstract sandbox interface and data types.""" from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass - -@dataclass -class EnvironmentInfo: - """Information about a sandbox environment.""" - os: str = "unknown" - python_version: str = "unknown" - gpu_available: bool = False - gpu_info: str | None = None - installed_packages: list[str] = field(default_factory=list) - available_disk_gb: float = 0.0 - available_ram_gb: float = 0.0 +from ..compute.capabilities import ComputeCapabilities @dataclass @@ -41,6 +31,23 @@ async def execute(self, command: str, timeout: int = 120) -> ExecutionResult: """Execute a shell command in the sandbox.""" ... + async def execute_stream(self, command: str, timeout: int = 120, on_chunk=None): + """Execute a command and stream output chunks via callback. 
+ + Args: + command: Shell command to execute + timeout: Timeout in seconds + on_chunk: Callback function(text: str, is_stderr: bool) called for each chunk + + Returns: + ExecutionResult with full output + """ + # Default implementation falls back to regular execute + result = await self.execute(command, timeout) + if on_chunk and result.output: + on_chunk(result.output, False) + return result + @abstractmethod async def read_file(self, path: str) -> str: """Read a file from the sandbox filesystem.""" @@ -67,7 +74,7 @@ async def list_files(self, path: str = ".") -> list[str]: ... @abstractmethod - async def probe_environment(self) -> EnvironmentInfo: + async def probe_environment(self) -> ComputeCapabilities: """Probe the sandbox environment for capabilities.""" ... diff --git a/backend/openmlr/sandbox/local.py b/backend/openmlr/sandbox/local.py index 89143bc..4b578d3 100644 --- a/backend/openmlr/sandbox/local.py +++ b/backend/openmlr/sandbox/local.py @@ -2,25 +2,42 @@ import asyncio import os -import platform -import shutil import time from pathlib import Path -from .interface import EnvironmentInfo, ExecutionResult, SandboxInterface +from ..compute.probe import probe_sandbox +from .interface import ExecutionResult, SandboxInterface class LocalSandbox(SandboxInterface): """Execute commands directly on the local machine.""" - def __init__(self, workdir: str = None): + def __init__(self, workdir: str = None, workspace_manager=None): + self._workspace_manager = workspace_manager + self._conversation_uuid = None self.workdir = workdir or os.getcwd() async def create(self, config: dict) -> "LocalSandbox": self.workdir = config.get("workdir", os.getcwd()) + self._conversation_uuid = config.get("conversation_uuid") + + # If workspace manager is available and conversation UUID is set, + # use the per-conversation workspace + if self._workspace_manager and self._conversation_uuid: + ws_path = self._workspace_manager.create_workspace(self._conversation_uuid) + self.workdir 
= str(ws_path)
+        elif self._workspace_manager:
+            # Fallback: create workspace without UUID
+            ws_path = self._workspace_manager.create_workspace("default")
+            self.workdir = str(ws_path)
+
         return self
 
     async def execute(self, command: str, timeout: int = 120) -> ExecutionResult:
+        return await self.execute_stream(command, timeout)
+
+    async def execute_stream(self, command: str, timeout: int = 120, on_chunk=None) -> ExecutionResult:
+        """Execute a command with optional streaming output.
+
+        Args:
+            command: Shell command, run with ``cwd=self.workdir``.
+            timeout: Overall limit in seconds for reading the command's output.
+            on_chunk: Optional ``callback(text, is_stderr)`` invoked per line.
+
+        Returns:
+            ExecutionResult with the collected output (stdout and stderr lines
+            in arrival order), truncated at 50000 characters.
+        """
         start = time.monotonic()
         try:
             proc = await asyncio.create_subprocess_shell(
@@ -29,25 +46,47 @@ async def execute(self, command: str, timeout: int = 120) -> ExecutionResult:
                 stderr=asyncio.subprocess.PIPE,
                 cwd=self.workdir,
             )
-            stdout, stderr = await asyncio.wait_for(
-                proc.communicate(), timeout=timeout
-            )
 
             output_parts = []
-            if stdout:
-                output_parts.append(stdout.decode("utf-8", errors="replace"))
-            if stderr:
-                output_parts.append(f"STDERR:\n{stderr.decode('utf-8', errors='replace')}")
-            output = "\n".join(output_parts) if output_parts else "(no output)"
+            async def _read_stream(stream, is_stderr):
+                """Read a stream and emit chunks."""
+                while True:
+                    try:
+                        # Short poll so a quiet stream cannot block forever
+                        line = await asyncio.wait_for(stream.readline(), timeout=0.5)
+                        if not line:
+                            break
+                        text = line.decode("utf-8", errors="replace")
+                        if on_chunk:
+                            on_chunk(text, is_stderr)
+                        output_parts.append(text)
+                    except TimeoutError:
+                        # Check if process is done
+                        if proc.returncode is not None:
+                            break
+                        continue
+
+            # Read stdout and stderr concurrently, bounded by the caller's
+            # `timeout` (the rewrite had dropped it, so a hung command never
+            # returned). On expiry TimeoutError reaches the existing
+            # `except TimeoutError` handler of this method.
+            # NOTE(review): confirm that handler also terminates `proc`.
+            await asyncio.wait_for(
+                asyncio.gather(
+                    _read_stream(proc.stdout, False),
+                    _read_stream(proc.stderr, True),
+                ),
+                timeout=timeout,
+            )
+
+            # Wait for process to complete
+            try:
+                returncode = await asyncio.wait_for(proc.wait(), timeout=1.0)
+            except TimeoutError:
+                # Pipes closed but the process has not exited yet: report -1
+                returncode = proc.returncode if proc.returncode is not None else -1
+
+            output = "".join(output_parts) if output_parts else "(no output)"
 
             if len(output) > 50000:
                 output = output[:50000] + "\n...[truncated]"
 
             duration = time.monotonic() - start
             return
ExecutionResult( output=output, - success=proc.returncode == 0, - exit_code=proc.returncode, + success=returncode == 0, + exit_code=returncode, duration_seconds=duration, ) except TimeoutError: @@ -97,45 +136,8 @@ async def list_files(self, path: str = ".") -> list[str]: for e in target.iterdir() ]) - async def probe_environment(self) -> EnvironmentInfo: - info = EnvironmentInfo( - os=f"{platform.system()} {platform.release()}", - ) - - # Python version - result = await self.execute("python3 --version", timeout=5) - if result.success: - info.python_version = result.output.strip() - - # GPU - result = await self.execute( - "nvidia-smi --query-gpu=name,memory.total --format=csv,noheader", - timeout=5, - ) - if result.success and result.output.strip(): - info.gpu_available = True - info.gpu_info = result.output.strip() - - # Disk - total, used, free = shutil.disk_usage(self.workdir) - info.available_disk_gb = free / (1024 ** 3) - - # RAM - try: - import psutil - info.available_ram_gb = psutil.virtual_memory().available / (1024 ** 3) - except ImportError: - pass - - # Key packages - result = await self.execute("pip list --format=freeze 2>/dev/null | head -30", timeout=10) - if result.success: - info.installed_packages = [ - line.split("==")[0] for line in result.output.strip().split("\n") - if "==" in line - ] - - return info + async def probe_environment(self): + return await probe_sandbox(self) async def destroy(self) -> None: pass # Local sandbox has nothing to clean up diff --git a/backend/openmlr/sandbox/manager.py b/backend/openmlr/sandbox/manager.py index dc054e7..32060a0 100644 --- a/backend/openmlr/sandbox/manager.py +++ b/backend/openmlr/sandbox/manager.py @@ -1,6 +1,5 @@ """SandboxManager — lifecycle management and provider selection.""" - from .interface import SandboxInterface from .local import LocalSandbox from .modal_sandbox import ModalSandbox @@ -10,9 +9,11 @@ class SandboxManager: """Manages sandbox lifecycle: create, switch, destroy.""" - def 
__init__(self): + def __init__(self, workspace_manager=None, conversation_uuid: str = None): self._active: SandboxInterface | None = None self.active_type: str = "none" + self._workspace_manager = workspace_manager + self._conversation_uuid = conversation_uuid def get_active(self) -> SandboxInterface | None: return self._active @@ -25,8 +26,11 @@ async def create(self, provider: str, config: dict = None) -> SandboxInterface: config = config or {} + # Inject workspace and conversation context + config["conversation_uuid"] = self._conversation_uuid + if provider == "local": - sandbox = LocalSandbox() + sandbox = LocalSandbox(workspace_manager=self._workspace_manager) elif provider == "ssh": sandbox = SSHSandbox() elif provider == "modal": diff --git a/backend/openmlr/sandbox/modal_sandbox.py b/backend/openmlr/sandbox/modal_sandbox.py index 7713a44..ad3277c 100644 --- a/backend/openmlr/sandbox/modal_sandbox.py +++ b/backend/openmlr/sandbox/modal_sandbox.py @@ -3,7 +3,8 @@ import asyncio import time -from .interface import EnvironmentInfo, ExecutionResult, SandboxInterface +from ..compute.probe import probe_sandbox +from .interface import ExecutionResult, SandboxInterface class ModalSandbox(SandboxInterface): @@ -117,7 +118,7 @@ async def read_file(self, path: str) -> str: async def write_file(self, path: str, content: str) -> bool: self._ensure_active() # Use heredoc for safe content transfer - content.replace("'", "'\\''") + content = content.replace("'", "'\\''") result = await self.execute( f"mkdir -p $(dirname '{path}') && cat > '{path}' << 'OPEN_MLR_EOF'\n{content}\nOPEN_MLR_EOF", timeout=10, @@ -143,49 +144,8 @@ async def list_files(self, path: str = ".") -> list[str]: return [] return [line for line in result.output.strip().split("\n") if line] - async def probe_environment(self) -> EnvironmentInfo: - info = EnvironmentInfo() - - result = await self.execute("uname -s -r", timeout=5) - if result.success: - info.os = result.output.strip() - - result = await 
self.execute("python3 --version", timeout=5) - if result.success: - info.python_version = result.output.strip() - - result = await self.execute( - "nvidia-smi --query-gpu=name,memory.total --format=csv,noheader 2>/dev/null", - timeout=10, - ) - if result.success and result.output.strip(): - info.gpu_available = True - info.gpu_info = result.output.strip() - - result = await self.execute("df -BG --output=avail / 2>/dev/null | tail -1", timeout=5) - if result.success: - try: - info.available_disk_gb = float(result.output.strip().replace("G", "")) - except ValueError: - pass - - result = await self.execute( - "free -g 2>/dev/null | grep Mem | awk '{print $7}'", timeout=5 - ) - if result.success: - try: - info.available_ram_gb = float(result.output.strip()) - except ValueError: - pass - - result = await self.execute("pip list --format=freeze 2>/dev/null | head -30", timeout=10) - if result.success: - info.installed_packages = [ - line.split("==")[0] for line in result.output.strip().split("\n") - if "==" in line - ] - - return info + async def probe_environment(self): + return await probe_sandbox(self) async def destroy(self) -> None: if self._sandbox: diff --git a/backend/openmlr/sandbox/ssh.py b/backend/openmlr/sandbox/ssh.py index aa0c63d..7ef3ee1 100644 --- a/backend/openmlr/sandbox/ssh.py +++ b/backend/openmlr/sandbox/ssh.py @@ -1,9 +1,110 @@ -"""SSH sandbox — remote execution via SSH/SFTP.""" +"""SSH sandbox — remote execution via SSH/SFTP with strict host-key verification +and connection pooling.""" import asyncio +import logging import time -from .interface import EnvironmentInfo, ExecutionResult, SandboxInterface +from ..compute.probe import probe_sandbox +from .interface import ExecutionResult, SandboxInterface + +log = logging.getLogger(__name__) + + +class StrictHostKeyPolicy: + """Paramiko policy that verifies host keys against expected fingerprints.""" + + def __init__(self, expected_fingerprint: str | None = None): + self.expected = expected_fingerprint 
+        self.actual_fingerprint: str | None = None
+
+    @staticmethod
+    def _normalize(fingerprint: str) -> str:
+        """Reduce a configured fingerprint to bare lowercase hex.
+
+        Accepts ``aa:bb:..``, ``MD5:aa:bb:..`` and the ``sha256:``-prefixed
+        form the original check tried to allow. The algorithm prefix must be
+        stripped *before* the colons are removed; the previous order removed
+        the colon first, so the prefix was never recognised.
+        """
+        fp = fingerprint.strip().lower()
+        for prefix in ("sha256:", "md5:"):
+            if fp.startswith(prefix):
+                fp = fp[len(prefix):]
+                break
+        return fp.replace(":", "")
+
+    def missing_host_key(self, client, hostname, key):
+        """Paramiko hook called for a host key that is not already known.
+
+        Records the presented key's fingerprint in ``actual_fingerprint`` and
+        raises ``paramiko.SSHException`` when an expected fingerprint is
+        configured and does not match. With no expected fingerprint the key
+        is accepted.
+        """
+        import paramiko
+        # Hex of the digest paramiko computes for the key.
+        # NOTE(review): a base64 "SHA256:..." fingerprint as printed by
+        # OpenSSH cannot match this hex value — confirm the stored format.
+        actual = key.get_fingerprint().hex()
+        self.actual_fingerprint = actual
+        if self.expected and actual != self._normalize(self.expected):
+            raise paramiko.SSHException(
+                f"Host key mismatch for {hostname}: expected {self.expected}, got {actual}"
+            )
+        return
+
+
+class SSHConnectionPool:
+    """Maintains persistent SSH connections per node with TTL-based eviction.
+
+    Connections are keyed by (host, port, username) and reused across
+    sandbox instances. Idle connections are closed after ``ttl_seconds``.
+    """
+
+    # Process-wide singleton, created lazily by get_pool()
+    _instance: "SSHConnectionPool | None" = None
+
+    def __init__(self, ttl_seconds: int = 300):
+        self._connections: dict[str, tuple] = {}  # key -> (client, sftp, fingerprint)
+        self._last_used: dict[str, float] = {}
+        self._ttl = ttl_seconds
+
+    @classmethod
+    def get_pool(cls) -> "SSHConnectionPool":
+        """Return the shared pool, creating it on first use."""
+        if cls._instance is None:
+            cls._instance = SSHConnectionPool()
+        return cls._instance
+
+    @staticmethod
+    def _make_key(host: str, port: int, username: str) -> str:
+        # NOTE(review): the key ignores credentials and the owning app user,
+        # so an authenticated connection is handed to any caller asking for
+        # the same user@host:port — confirm callers authorise before get().
+        return f"{username}@{host}:{port}"
+
+    def get(self, host: str, port: int, username: str):
+        """Return (client, sftp, fingerprint) if a healthy cached connection exists, else None."""
+        key = self._make_key(host, port, username)
+        entry = self._connections.get(key)
+        if entry is None:
+            return None
+
+        client, sftp, fp = entry
+        try:
+            transport = client.get_transport()
+            if transport and transport.is_active():
+                self._last_used[key] = time.monotonic()
+                return client, sftp, fp
+        except Exception:
+            # A failing liveness check is treated the same as a dead connection
+            pass
+
+        # Connection is dead — clean up
+        self._evict(key)
+        return None
+
+    def put(self, host: str, port: int, username: str, client, sftp, fingerprint: str | None):
+        """Cache a connection for reuse."""
+        key = self._make_key(host, port, username)
+        self._connections[key] = (client, sftp, fingerprint)
+        self._last_used[key] = time.monotonic()
+
+
def remove(self, host: str, port: int, username: str): + """Remove and close a cached connection.""" + key = self._make_key(host, port, username) + self._evict(key) + + def _evict(self, key: str): + entry = self._connections.pop(key, None) + self._last_used.pop(key, None) + if entry: + client, sftp, _ = entry + try: + sftp.close() + except Exception: + pass + try: + client.close() + except Exception: + pass + + def cleanup_idle(self): + """Close connections idle beyond TTL. Call periodically.""" + now = time.monotonic() + stale = [k for k, t in self._last_used.items() if now - t > self._ttl] + for key in stale: + log.debug(f"SSH pool: evicting idle connection {key}") + self._evict(key) class SSHSandbox(SandboxInterface): @@ -12,33 +113,80 @@ class SSHSandbox(SandboxInterface): def __init__(self): self._client = None self._sftp = None + self._owns_connection = False # True if we created it (not from pool) self.host: str = "" self.port: int = 22 self.username: str = "" - self.key_path: str | None = None + self.key_filename: str | None = None self.password: str | None = None self.workdir: str = "~" + self.host_key_fingerprint: str | None = None + self._key_manager = None async def create(self, config: dict) -> "SSHSandbox": self.host = config.get("host", "") self.port = config.get("port", 22) self.username = config.get("username", "root") - self.key_path = config.get("key_path") + self.key_filename = config.get("key_filename") self.password = config.get("password") self.workdir = config.get("workdir", "~") + self.host_key_fingerprint = config.get("host_key_fingerprint") + self._conversation_uuid = config.get("conversation_uuid") + + if self.key_filename: + from ..keys import KeyManager + self._key_manager = KeyManager() if not self.host: raise ValueError("SSH config requires 'host'") await self._connect() + + # Ensure remote workspace exists if conversation UUID is set + if self._conversation_uuid: + remote_ws = f"{self.workdir}/workspace-{self._conversation_uuid}" + 
await self._ensure_remote_workspace(remote_ws) + self.workdir = remote_ws + return self + async def _ensure_remote_workspace(self, remote_path: str) -> None: + self._ensure_connected() + + def _do_mkdir(): + subdirs = " ".join(f"{remote_path}/{d}" for d in ["data", "models", "code", "outputs", ".openmlr-meta"]) + cmd = f"mkdir -p {subdirs}" + stdin, stdout, stderr = self._client.exec_command(cmd, timeout=10) + exit_code = stdout.channel.recv_exit_status() + if exit_code != 0: + err = stderr.read().decode("utf-8", errors="replace") + raise RuntimeError(f"Failed to create remote workspace: {err}") + + await asyncio.to_thread(_do_mkdir) + async def _connect(self): - """Establish SSH connection (run in thread to avoid blocking).""" + """Get a connection from the pool or create a new one.""" + pool = SSHConnectionPool.get_pool() + pool.cleanup_idle() + + cached = pool.get(self.host, self.port, self.username) + if cached: + self._client, self._sftp, fp = cached + self._owns_connection = False + if fp and not self.host_key_fingerprint: + self.host_key_fingerprint = fp + log.debug(f"SSH pool: reusing connection to {self.username}@{self.host}:{self.port}") + return + def _do_connect(): import paramiko client = paramiko.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + if self.host_key_fingerprint: + policy = StrictHostKeyPolicy(self.host_key_fingerprint) + client.set_missing_host_key_policy(policy) + else: + client.set_missing_host_key_policy(paramiko.WarningPolicy()) connect_kwargs = { "hostname": self.host, @@ -47,35 +195,84 @@ def _do_connect(): "timeout": 30, } - if self.key_path: - connect_kwargs["key_filename"] = self.key_path + if self.key_filename and self._key_manager: + key_path = self._key_manager.get_key_path(self.key_filename) + connect_kwargs["key_filename"] = str(key_path) elif self.password: connect_kwargs["password"] = self.password client.connect(**connect_kwargs) sftp = client.open_sftp() - return client, sftp - self._client, 
self._sftp = await asyncio.to_thread(_do_connect) + actual_fp = None + transport = client.get_transport() + if transport: + remote_key = transport.get_remote_server_key() + if remote_key: + actual_fp = remote_key.get_fingerprint().hex() + + return client, sftp, actual_fp + + self._client, self._sftp, actual_fp = await asyncio.to_thread(_do_connect) + self._owns_connection = True + + if actual_fp and not self.host_key_fingerprint: + self.host_key_fingerprint = actual_fp + + # Put the new connection into the pool + pool.put(self.host, self.port, self.username, self._client, self._sftp, actual_fp) def _ensure_connected(self): if not self._client or not self._client.get_transport() or not self._client.get_transport().is_active(): raise RuntimeError("SSH connection lost. Recreate the sandbox.") async def execute(self, command: str, timeout: int = 120) -> ExecutionResult: + return await self.execute_stream(command, timeout) + + async def execute_stream(self, command: str, timeout: int = 120, on_chunk=None) -> ExecutionResult: self._ensure_connected() start = time.monotonic() - def _do_exec(): + def _do_exec_stream(): full_cmd = f"cd {self.workdir} && {command}" stdin, stdout, stderr = self._client.exec_command(full_cmd, timeout=timeout) - exit_code = stdout.channel.recv_exit_status() - out = stdout.read().decode("utf-8", errors="replace") - err = stderr.read().decode("utf-8", errors="replace") - return out, err, exit_code + + out_buf = [] + err_buf = [] + channel = stdout.channel + + while not channel.exit_status_ready(): + if channel.recv_ready(): + data = channel.recv(4096).decode("utf-8", errors="replace") + out_buf.append(data) + if on_chunk: + on_chunk(data, False) + + if channel.recv_stderr_ready(): + data = channel.recv_stderr(4096).decode("utf-8", errors="replace") + err_buf.append(data) + if on_chunk: + on_chunk(data, True) + + time.sleep(0.05) + + while channel.recv_ready(): + data = channel.recv(4096).decode("utf-8", errors="replace") + out_buf.append(data) + 
if on_chunk: + on_chunk(data, False) + + while channel.recv_stderr_ready(): + data = channel.recv_stderr(4096).decode("utf-8", errors="replace") + err_buf.append(data) + if on_chunk: + on_chunk(data, True) + + exit_code = channel.recv_exit_status() + return "".join(out_buf), "".join(err_buf), exit_code try: - out, err, exit_code = await asyncio.to_thread(_do_exec) + out, err, exit_code = await asyncio.to_thread(_do_exec_stream) output_parts = [] if out: output_parts.append(out) @@ -155,64 +352,12 @@ def _do_list(): return await asyncio.to_thread(_do_list) - async def probe_environment(self) -> EnvironmentInfo: - info = EnvironmentInfo() - - result = await self.execute("uname -s -r", timeout=5) - if result.success: - info.os = result.output.strip() - - result = await self.execute("python3 --version", timeout=5) - if result.success: - info.python_version = result.output.strip() - - result = await self.execute( - "nvidia-smi --query-gpu=name,memory.total --format=csv,noheader 2>/dev/null", - timeout=5, - ) - if result.success and result.output.strip(): - info.gpu_available = True - info.gpu_info = result.output.strip() - - result = await self.execute("df -BG --output=avail / 2>/dev/null | tail -1", timeout=5) - if result.success: - try: - info.available_disk_gb = float(result.output.strip().replace("G", "")) - except ValueError: - pass - - result = await self.execute( - "free -g 2>/dev/null | grep Mem | awk '{print $7}'", - timeout=5, - ) - if result.success: - try: - info.available_ram_gb = float(result.output.strip()) - except ValueError: - pass - - result = await self.execute( - "pip list --format=freeze 2>/dev/null | head -30", - timeout=10, - ) - if result.success: - info.installed_packages = [ - line.split("==")[0] for line in result.output.strip().split("\n") - if "==" in line - ] - - return info + async def probe_environment(self): + return await probe_sandbox(self) async def destroy(self) -> None: - if self._sftp: - try: - self._sftp.close() - except 
Exception: - pass - if self._client: - try: - self._client.close() - except Exception: - pass + # Don't close pooled connections — they'll be reused. + # Only close if we own the connection and it's not pooled. + # The pool handles TTL-based eviction. self._client = None self._sftp = None diff --git a/backend/openmlr/services/session_manager.py b/backend/openmlr/services/session_manager.py index adfde1f..d59d7b1 100644 --- a/backend/openmlr/services/session_manager.py +++ b/backend/openmlr/services/session_manager.py @@ -85,8 +85,46 @@ async def get_or_create_session( ) session = Session(config=config, conversation_id=conversation_id) - sandbox_manager = SandboxManager() + + # Determine effective compute node + effective_node = None + if user_id and db: + try: + from ..db import operations as ops + # Check conversation override + conv = await ops.get_conversation_by_id(db, conversation_id) + if conv and conv.extra: + override_node_id = conv.extra.get("compute_node_id") + if override_node_id: + effective_node = await ops.get_compute_node_by_id(db, override_node_id, user_id) + + # Fall back to user default + if not effective_node: + effective_node = await ops.get_default_compute_node(db, user_id) + + if effective_node: + log.info(f"Session {conversation_id}: using compute node '{effective_node.name}' ({effective_node.type})") + except Exception as e: + log.warning(f"Session {conversation_id}: failed to load compute node - {e}") + + # Initialize workspace manager and sandbox manager + from ..compute import WorkspaceManager + workspace_manager = WorkspaceManager() + sandbox_manager = SandboxManager( + workspace_manager=workspace_manager, + conversation_uuid=uuid, + ) + + # If a compute node is configured, activate it + if effective_node: + try: + await sandbox_manager.create(effective_node.type, effective_node.config) + except Exception as e: + log.warning(f"Session {conversation_id}: failed to create sandbox for node '{effective_node.name}' - {e}") + tool_router = 
create_tool_router(sandbox_manager) + # Inject user/db context for compute tools + tool_router.set_context(user_id=user_id, db=db) mcp_manager = MCPManager() # Load MCP servers from user settings if available @@ -108,11 +146,53 @@ async def get_or_create_session( except Exception as e: log.warning(f"Session {conversation_id}: failed to load MCP servers - {e}") + # Build compute environment info for system prompt + compute_env = "" + if effective_node: + caps = effective_node.capabilities or {} + lines = [f"\n## Active Compute Environment: {effective_node.name} ({effective_node.type})"] + if caps.get("platform"): + lines.append(f"- Platform: {caps['platform']}") + if caps.get("cpu_cores"): + lines.append(f"- CPU: {caps['cpu_cores']} cores ({caps.get('cpu_arch', 'unknown')})") + if caps.get("available_ram_gb"): + lines.append(f"- RAM: {caps['available_ram_gb']:.1f} GB available") + if caps.get("gpu_available"): + gpu_info = caps.get("gpu_info", []) + for gpu in gpu_info[:1]: + lines.append(f"- GPU: {gpu.get('model', 'unknown')} ({gpu.get('vram_gb', 0):.0f} GB VRAM)") + if gpu.get("cuda_version"): + lines.append(f" - CUDA: {gpu['cuda_version']}") + if caps.get("python_versions"): + lines.append(f"- Python: {', '.join(caps['python_versions'])}") + if caps.get("docker_available"): + lines.append("- Docker: available") + if caps.get("installed_packages"): + pkgs = caps["installed_packages"][:10] + lines.append(f"- Key packages: {', '.join(pkgs)}") + + # Add available nodes for context + all_nodes = [] + if user_id and db: + try: + all_nodes = await ops.get_compute_nodes(db, user_id) + except Exception: + pass + if len(all_nodes) > 1: + lines.append("\n### Other Available Nodes") + for node in all_nodes: + if node.id != effective_node.id: + status = "online" if node.health_status == "online" else "offline" + lines.append(f"- {node.name} ({node.type}): {status}") + + compute_env = "\n".join(lines) + # Build and set system prompt (after MCP tools are registered) 
@celery_app.task
def cleanup_old_workspaces():
    """Prune workspace archives, then clean up workspaces by conversation.

    Step 1 calls ``WorkspaceManager.cleanup_archives`` with a 30-day age
    limit and a cap of 100 archives. Step 2 loads every conversation UUID
    from the database and passes that list to
    ``WorkspaceManager.cleanup_workspaces`` with ``archive=True``
    (presumably the listed UUIDs are the ones to keep; confirm against
    ``WorkspaceManager``). Both steps log the deleted count and freed size.
    """
    wm = WorkspaceManager()

    # Step 1: old archives.
    archive_result = wm.cleanup_archives(max_age_days=30, max_count=100)
    logger.info(
        f"Archive cleanup: deleted {archive_result['deleted']} archives, "
        f"freed {archive_result['freed_bytes'] / (1024**3):.1f} GB"
    )

    # Step 2: workspaces, driven by the conversations currently in the DB.
    async def _cleanup_orphaned():
        async with get_worker_session()() as db:
            from sqlalchemy import select

            from ..db.models import Conversation
            rows = (await db.execute(select(Conversation.uuid))).all()
            active_uuids = {row[0] for row in rows}

            ws_result = wm.cleanup_workspaces(
                conversation_uuids=list(active_uuids),
                archive=True,
            )
            logger.info(
                f"Workspace cleanup: deleted {ws_result['deleted']} workspaces, "
                f"freed {ws_result['freed_bytes'] / (1024**3):.1f} GB"
            )

    # Celery tasks are synchronous, so drive the coroutine to completion here.
    asyncio.run(_cleanup_orphaned())
@celery_app.task
def health_check_all_nodes():
    """Queue a ``check_compute_node_health`` task for every node of every user.

    This task only enqueues work: it loads all users, lists each user's
    compute nodes, and calls ``.delay(node.id, user.id)`` once per node.
    """
    async def _enqueue_checks():
        async with get_worker_session()() as db:
            from sqlalchemy import select

            from ..db.models import User
            users = (await db.execute(select(User))).scalars().all()

            for user in users:
                for node in await ops.get_compute_nodes(db, user.id):
                    check_compute_node_health.delay(node.id, user.id)

    asyncio.run(_enqueue_checks())
    logger.info("Queued health checks for all compute nodes")
selection.""" + +import asyncio +import io +import os +from datetime import UTC, datetime +from pathlib import Path + +from ..agent.types import ToolSpec +from ..compute.probe import probe_sandbox + + +def _validate_sync_path(workspace: Path, rel_path: str) -> tuple[Path, str | None]: + """Validate that a relative path stays within the workspace. Returns (resolved, error).""" + target = (workspace / rel_path).resolve() + try: + target.relative_to(workspace.resolve()) + except ValueError: + return target, f"Path '{rel_path}' escapes workspace boundary" + return target, None + + +async def _handle_list(user_id: int = None, db=None, **kwargs): + """List all compute nodes with capabilities.""" + if not db: + return "Database connection required for compute_list", False + + from ..db import operations as ops + nodes = await ops.get_compute_nodes(db, user_id) + + if not nodes: + return "No compute nodes configured. Add nodes in Settings > Compute.", True + + lines = ["## Available Compute Nodes\n"] + for node in nodes: + caps = node.capabilities or {} + status = "●" if node.health_status == "online" else "○" + gpu = "" + if caps.get("gpu_available"): + gpu_info = caps.get("gpu_info", []) + if gpu_info: + gpu = f" — GPU: {gpu_info[0].get('model', 'unknown')}" + else: + gpu = " — GPU: yes" + + ram = "" + if caps.get("available_ram_gb"): + ram = f" — RAM: {caps['available_ram_gb']:.0f}GB" + + default = " ★" if node.is_default else "" + lines.append(f"{status} {node.name} ({node.type}){default}{gpu}{ram}") + + return "\n".join(lines), True + + +async def _handle_probe(node_name: str, user_id: int = None, db=None, **kwargs): + """Probe a compute node for capabilities.""" + if not db: + return "Database connection required for compute_probe", False + + from ..db import operations as ops + node = await ops.get_compute_node_by_name(db, user_id, node_name) + if not node: + return f"Node '{node_name}' not found", False + + # Create sandbox and probe + from ..compute import 
WorkspaceManager + from ..sandbox.manager import SandboxManager + + try: + wm = WorkspaceManager() + sm = SandboxManager(workspace_manager=wm) + await sm.create(node.type, node.config) + sandbox = sm.get_active() + if not sandbox: + return f"Failed to create sandbox for {node_name}", False + + caps = await probe_sandbox(sandbox) + + # Update node in database + await ops.update_compute_node( + db, node.id, user_id, + capabilities=caps.to_dict(), + health_status="online", + last_probed_at=datetime.now(UTC), + ) + + await sm.destroy() + + # Format response + lines = [f"## {node.name} Capabilities\n"] + lines.append(f"Platform: {caps.platform}") + lines.append(f"CPU: {caps.cpu_cores} cores ({caps.cpu_arch})") + lines.append(f"RAM: {caps.available_ram_gb:.1f} GB available / {caps.total_ram_gb:.1f} GB total") + lines.append(f"Disk: {caps.available_disk_gb:.1f} GB available / {caps.total_disk_gb:.1f} GB total") + + if caps.gpu_available: + for gpu in caps.gpu_info: + lines.append(f"GPU: {gpu.model} ({gpu.vram_gb:.0f} GB VRAM)") + if gpu.cuda_version: + lines.append(f" CUDA: {gpu.cuda_version}, Driver: {gpu.driver_version}") + + if caps.python_versions: + lines.append(f"Python: {', '.join(caps.python_versions)}") + + if caps.docker_available: + lines.append("Docker: available") + + if caps.installed_packages: + lines.append(f"\nKey packages: {', '.join(caps.installed_packages[:10])}") + if len(caps.installed_packages) > 10: + lines.append(f"... 
and {len(caps.installed_packages) - 10} more") + + return "\n".join(lines), True + + except Exception as e: + try: + await sm.destroy() + except Exception: + pass + await ops.update_compute_node( + db, node.id, user_id, + health_status="offline", + ) + return f"Probe failed for {node_name}: {str(e)}", False + + +async def _handle_select(node_name: str, user_id: int = None, db=None, session=None, **kwargs): + """Select a compute node as active for this conversation.""" + if not db: + return "Database connection required for compute_select", False + + from ..db import operations as ops + node = await ops.get_compute_node_by_name(db, user_id, node_name) + if not node: + return f"Node '{node_name}' not found", False + + # If session is provided, update the active sandbox + if session and hasattr(session, 'conversation_id'): + # Update conversation extra + conv_id = session.conversation_id + conv = await ops.get_conversation_by_id(db, conv_id) + if conv: + extra = conv.extra or {} + extra["compute_node_id"] = node.id + extra["compute_node_name"] = node.name + await ops.update_conversation_extra(db, conv_id, extra) + + return f"Active compute switched to: {node.name} ({node.type})", True + + +async def _handle_plan(task: str, requirements: dict = None, user_id: int = None, db=None, **kwargs): + """Recommend the best compute node for a task.""" + if not db: + return "Database connection required for compute_plan", False + + requirements = requirements or {} + from ..db import operations as ops + nodes = await ops.get_compute_nodes(db, user_id) + + if not nodes: + return "No compute nodes configured.", False + + # Score each node + scores = [] + for node in nodes: + if node.health_status != "online": + continue + + caps = node.capabilities or {} + score = 0 + reasons = [] + + # GPU requirement + if requirements.get("gpu"): + if not caps.get("gpu_available"): + continue + score += 10 + vram = 0 + for gpu in caps.get("gpu_info", []): + vram = max(vram, gpu.get("vram_gb", 0)) 
async def _get_sync_context(user_id, db, session):
    """Resolve the conversation UUID and local workspace path for a sync.

    Returns ``(conv_uuid, local_ws, error)``. On success ``error`` is
    ``None``. When ``session`` carries no ``conversation_id``, the
    conversation cannot be loaded, or it has no UUID, the first two items are
    ``None`` and ``error`` holds the message. ``user_id`` is accepted for
    signature symmetry with the sync handlers and is not used here.
    """
    from ..db import operations as ops

    conversation = None
    if session and hasattr(session, 'conversation_id'):
        conversation = await ops.get_conversation_by_id(db, session.conversation_id)

    conv_uuid = conversation.uuid if conversation else None
    if not conv_uuid:
        return None, None, "No active conversation workspace found"

    from ..compute import WorkspaceManager
    local_ws = WorkspaceManager().get_workspace_path(conv_uuid)
    return conv_uuid, local_ws, None
await asyncio.to_thread( + lambda d=dst, c=content: ssh_sandbox._sftp.putfo(io.BytesIO(c), d) + ) + transferred += 1 + + return f"Synced {transferred} item(s) to {node.name}", True + except Exception as e: + return f"Sync failed: {str(e)}", False + finally: + await ssh_sandbox.destroy() + + elif node.type == "modal": + return "File sync not supported for Modal nodes (ephemeral)", False + + return "Unsupported node type", False + + +async def _handle_sync_down(paths: list, node_name: str, user_id: int = None, db=None, session=None, **kwargs): + """Sync files from remote compute node to local workspace.""" + if not db: + return "Database connection required", False + + from ..db import operations as ops + node = await ops.get_compute_node_by_name(db, user_id, node_name) + if not node: + return f"Node '{node_name}' not found", False + + conv_uuid, local_ws, err = await _get_sync_context(user_id, db, session) + if err: + return err, False + local_ws.mkdir(parents=True, exist_ok=True) + + if node.type == "local": + return "Local sync: files are already in the same workspace", True + + elif node.type == "ssh": + from ..sandbox.ssh import SSHSandbox + ssh_sandbox = SSHSandbox() + try: + config = dict(node.config) + config["conversation_uuid"] = conv_uuid + await ssh_sandbox.create(config) + + transferred = 0 + for rel_path in paths: + # Path traversal check + local_path, path_err = _validate_sync_path(local_ws, rel_path) + if path_err: + return path_err, False + + remote_path = f"{ssh_sandbox.workdir}/{rel_path}" + + # Check remote type + result = await ssh_sandbox.execute( + f"test -d '{remote_path}' && echo dir || test -f '{remote_path}' && echo file || echo none", + timeout=5, + ) + remote_type = result.output.strip() + if remote_type == "none": + continue + + if remote_type == "file": + local_path.parent.mkdir(parents=True, exist_ok=True) + rp = remote_path # bind for closure + + def _do_get(rpath=rp): + buf = io.BytesIO() + ssh_sandbox._sftp.getfo(rpath, buf) + 
buf.seek(0) + return buf.read() + + data = await asyncio.to_thread(_do_get) + local_path.write_bytes(data) + transferred += 1 + + elif remote_type == "dir": + result = await ssh_sandbox.execute(f"find '{remote_path}' -type f", timeout=10) + remote_files = [ln.strip() for ln in result.output.strip().split("\n") if ln.strip()] + for rf in remote_files: + rel = rf.replace(remote_path + "/", "", 1) + dst = local_path / rel + # Path traversal check on each individual file + _, inner_err = _validate_sync_path(local_ws, str(Path(rel_path) / rel)) + if inner_err: + continue + dst.parent.mkdir(parents=True, exist_ok=True) + + # Bind rf in default arg to avoid closure-in-loop bug + def _do_get_file(rpath=rf): + buf = io.BytesIO() + ssh_sandbox._sftp.getfo(rpath, buf) + buf.seek(0) + return buf.read() + + data = await asyncio.to_thread(_do_get_file) + dst.write_bytes(data) + transferred += 1 + + return f"Synced {transferred} item(s) from {node.name}", True + except Exception as e: + return f"Sync failed: {str(e)}", False + finally: + await ssh_sandbox.destroy() + + elif node.type == "modal": + return "File sync not supported for Modal nodes (ephemeral)", False + + return "Unsupported node type", False + + +def create_compute_tools() -> list[ToolSpec]: + """Create agent tools for compute node management.""" + return [ + ToolSpec( + name="compute_list", + description="List all configured compute nodes with their capabilities and health status.", + parameters={"type": "object", "properties": {}}, + handler=_handle_list, + ), + ToolSpec( + name="compute_probe", + description="Probe a compute node to discover its capabilities (CPU, GPU, RAM, installed packages).", + parameters={ + "type": "object", + "properties": { + "node_name": { + "type": "string", + "description": "Name of the compute node to probe", + }, + }, + "required": ["node_name"], + }, + handler=_handle_probe, + ), + ToolSpec( + name="compute_select", + description="Switch the active compute node for this 
conversation. Use this before running tasks that need specific hardware.", + parameters={ + "type": "object", + "properties": { + "node_name": { + "type": "string", + "description": "Name of the compute node to activate", + }, + }, + "required": ["node_name"], + }, + handler=_handle_select, + ), + ToolSpec( + name="compute_plan", + description="Recommend the best compute node for a given task based on requirements.", + parameters={ + "type": "object", + "properties": { + "task": { + "type": "string", + "description": "Description of the task (e.g., 'Train a ResNet-50 with mixed precision')", + }, + "requirements": { + "type": "object", + "description": "Hardware requirements", + "properties": { + "gpu": {"type": "boolean", "description": "GPU required"}, + "min_vram_gb": {"type": "number", "description": "Minimum GPU VRAM in GB"}, + "min_ram_gb": {"type": "number", "description": "Minimum RAM in GB"}, + "min_disk_gb": {"type": "number", "description": "Minimum free disk in GB"}, + }, + }, + }, + "required": ["task"], + }, + handler=_handle_plan, + ), + ToolSpec( + name="compute_sync_up", + description="Sync files from local workspace to a remote compute node. Use before running code that needs data on the remote.", + parameters={ + "type": "object", + "properties": { + "paths": { + "type": "array", + "items": {"type": "string"}, + "description": "Relative paths to sync (e.g., ['data/', 'code/train.py'])", + }, + "node_name": { + "type": "string", + "description": "Name of the target compute node", + }, + }, + "required": ["paths", "node_name"], + }, + handler=_handle_sync_up, + ), + ToolSpec( + name="compute_sync_down", + description="Sync files from a remote compute node to local workspace. 
Use after training to download models, logs, and results.", + parameters={ + "type": "object", + "properties": { + "paths": { + "type": "array", + "items": {"type": "string"}, + "description": "Relative paths to sync (e.g., ['models/', 'outputs/'])", + }, + "node_name": { + "type": "string", + "description": "Name of the source compute node", + }, + }, + "required": ["paths", "node_name"], + }, + handler=_handle_sync_down, + ), + ] diff --git a/backend/openmlr/tools/registry.py b/backend/openmlr/tools/registry.py index cbd3184..67f9997 100644 --- a/backend/openmlr/tools/registry.py +++ b/backend/openmlr/tools/registry.py @@ -17,6 +17,8 @@ "github_search", "github_read_file", "github_read_repo", "github_find_examples", "github_search_repos", "github_get_readme", "github_list_repos", + # Compute planning (read-only / advisory) + "compute_list", "compute_plan", "compute_probe", }, "blocked_message": ( "Tool '{tool}' is not available in PLAN mode. " @@ -43,6 +45,13 @@ def __init__(self): self._mcp_client = None self._blocklist: set[str] = set() self._current_mode: str = "general" + self._user_id: int | None = None + self._db = None + + def set_context(self, user_id: int | None = None, db=None) -> None: + """Set per-request context (user_id, db) for tools that need them.""" + self._user_id = user_id + self._db = db def register(self, spec: ToolSpec) -> None: """Register a tool.""" @@ -157,6 +166,11 @@ async def call_tool( # Also pass tool_call_id if the handler accepts it if "tool_call_id" in sig.parameters and "tool_call_id" not in kwargs: kwargs["tool_call_id"] = kwargs.pop("id", "") + # Inject user_id and db for tools that need them (compute tools) + if "user_id" in sig.parameters and "user_id" not in kwargs: + kwargs["user_id"] = self._user_id + if "db" in sig.parameters and "db" not in kwargs: + kwargs["db"] = self._db try: return await tool.handler(**kwargs) if kwargs else await tool.handler(**arguments) except TypeError as e: @@ -237,6 +251,10 @@ def 
create_tool_router(sandbox_manager=None) -> ToolRouter: router.register(create_writing_tool()) router.register(create_ask_user_tool()) + # Register compute tools + from .compute_tools import create_compute_tools + router.register_many(create_compute_tools()) + # Register sandbox tools if manager provided if sandbox_manager: from .sandbox_tools import create_sandbox_tools diff --git a/backend/openmlr/tools/sandbox_tools.py b/backend/openmlr/tools/sandbox_tools.py index c73322a..2792927 100644 --- a/backend/openmlr/tools/sandbox_tools.py +++ b/backend/openmlr/tools/sandbox_tools.py @@ -1,6 +1,8 @@ """Sandbox tools — expose execution environments to the agent.""" -from ..agent.types import ToolSpec +import asyncio + +from ..agent.types import AgentEvent, ToolSpec def create_sandbox_tools(sandbox_manager) -> list[ToolSpec]: @@ -56,13 +58,14 @@ def create_sandbox_tools(sandbox_manager) -> list[ToolSpec]: name="sandbox_exec", description=( "Execute a command in the active sandbox. If no sandbox is active, " - "falls back to local execution." + "falls back to local execution. Use stream=true for long-running commands." ), parameters={ "type": "object", "properties": { "command": {"type": "string", "description": "Shell command to execute"}, - "timeout": {"type": "integer", "description": "Timeout in seconds (default 120)"}, + "timeout": {"type": "integer", "description": "Timeout in seconds (default 120, max 3600)"}, + "stream": {"type": "boolean", "description": "Stream output in real-time for long-running commands (default false)"}, }, "required": ["command"], }, @@ -102,17 +105,24 @@ async def _handle_probe(sandbox_manager, session=None, **kwargs) -> tuple[str, b return "No active sandbox. 
Using local environment.\n" + await _local_probe(), True try: - info = await sandbox.probe_environment() + caps = await sandbox.probe_environment() lines = [ f"## Sandbox Environment ({sandbox_manager.active_type})\n", - f"OS: {info.os}", - f"Python: {info.python_version}", - f"GPU: {'Yes — ' + info.gpu_info if info.gpu_available else 'No'}", - f"Disk: {info.available_disk_gb:.1f} GB free", - f"RAM: {info.available_ram_gb:.1f} GB free", + f"Platform: {caps.platform}", + f"CPU: {caps.cpu_cores} cores ({caps.cpu_arch})", + f"Python: {', '.join(caps.python_versions) if caps.python_versions else 'unknown'}", ] - if info.installed_packages: - lines.append(f"\nKey packages: {', '.join(info.installed_packages[:20])}") + if caps.gpu_available and caps.gpu_info: + for gpu in caps.gpu_info: + lines.append(f"GPU: {gpu.model} ({gpu.vram_gb:.0f} GB VRAM)") + elif caps.gpu_available: + lines.append("GPU: available") + else: + lines.append("GPU: not available") + lines.append(f"Disk: {caps.available_disk_gb:.1f} GB free") + lines.append(f"RAM: {caps.available_ram_gb:.1f} GB free") + if caps.installed_packages: + lines.append(f"\nKey packages: {', '.join(caps.installed_packages[:20])}") return "\n".join(lines), True except Exception as e: return f"Probe failed: {str(e)}", False @@ -156,7 +166,7 @@ async def _handle_create(sandbox_manager, provider: str, config: dict = None, se return f"Failed to create sandbox: {str(e)}", False -async def _handle_exec(sandbox_manager, command: str, timeout: int = 120, session=None, **kwargs) -> tuple[str, bool]: +async def _handle_exec(sandbox_manager, command: str, timeout: int = 120, stream: bool = False, session=None, **kwargs) -> tuple[str, bool]: sandbox = sandbox_manager.get_active() if not sandbox: # Fall back to local execution @@ -164,8 +174,25 @@ async def _handle_exec(sandbox_manager, command: str, timeout: int = 120, sessio return await _handle_bash(command=command, timeout=timeout) try: - result = await sandbox.execute(command, 
timeout=timeout) - return result.output, result.success + if stream and session: + # Stream output via tool_log events + # on_chunk may be called from a worker thread (SSH), so use + # call_soon_threadsafe to schedule the coroutine on the event loop. + loop = asyncio.get_running_loop() + + def on_chunk(text: str, is_stderr: bool): + prefix = "STDERR: " if is_stderr else "" + event = AgentEvent( + event_type="tool_log", + data={"message": f"{prefix}{text.rstrip()}"}, + ) + loop.call_soon_threadsafe(asyncio.ensure_future, session.emit(event)) + + result = await sandbox.execute_stream(command, timeout=timeout, on_chunk=on_chunk) + return result.output, result.success + else: + result = await sandbox.execute(command, timeout=timeout) + return result.output, result.success except Exception as e: return f"Execution error: {str(e)}", False diff --git a/backend/openmlr/tools/writing.py b/backend/openmlr/tools/writing.py index 827fccc..d79b21b 100644 --- a/backend/openmlr/tools/writing.py +++ b/backend/openmlr/tools/writing.py @@ -55,7 +55,7 @@ async def _get_author_info(db, conv_id: int) -> dict | None: conv = await ops.get_conversation(db, conv_id) if not conv or not conv.user_id: return None - + # Fetch author-related settings author_info = {} for key in ["author_name", "author_email", "author_affiliation", "author_orcid"]: @@ -63,7 +63,7 @@ async def _get_author_info(db, conv_id: int) -> dict | None: if setting: field = key.replace("author_", "") author_info[field] = setting - + return author_info if author_info else None @@ -320,21 +320,21 @@ async def _get_draft(conv_id: int) -> tuple[str, bool]: proj = _get_project(conv_id) if not proj: return "No paper project exists.", False - + # Fetch author info author_info = None if conv_id: session_factory = _get_session_factory() async with session_factory() as db: author_info = await _get_author_info(db, conv_id) - + return _get_draft_from_proj(proj, author_info) def _get_draft_from_proj(proj: dict, author_info: dict | None 
= None) -> tuple[str, bool]: """Generate the full markdown draft from a project dict.""" lines = [f"# {proj['title']}\n"] - + # Add author information block if available if author_info: author_lines = [] @@ -346,7 +346,7 @@ def _get_draft_from_proj(proj: dict, author_info: dict | None = None) -> tuple[s author_lines.append(f"Email: {author_info['email']}") if author_info.get("orcid"): author_lines.append(f"ORCID: [{author_info['orcid']}](https://orcid.org/{author_info['orcid']})") - + if author_lines: lines.append("\n".join(author_lines)) lines.append("\n---\n") diff --git a/backend/tests/test_compute.py b/backend/tests/test_compute.py new file mode 100644 index 0000000..d82d46d --- /dev/null +++ b/backend/tests/test_compute.py @@ -0,0 +1,693 @@ +"""Tests for the compute node ecosystem — KeyManager, WorkspaceManager, +ComputeCapabilities, SSHConnectionPool, compute tools, and routes.""" + +from unittest.mock import MagicMock + +import pytest + +pytestmark = pytest.mark.asyncio + +from openmlr.compute.capabilities import ComputeCapabilities, GPUInfo +from openmlr.compute.manager import ComputeManager +from openmlr.compute.workspace import WorkspaceManager +from openmlr.keys.manager import KeyManager +from openmlr.sandbox.ssh import SSHConnectionPool +from openmlr.tools.compute_tools import _validate_sync_path +from openmlr.tools.registry import MODE_TOOL_RESTRICTIONS, ToolRouter + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def tmp_keys_dir(tmp_path): + keys_dir = tmp_path / ".keys" + keys_dir.mkdir() + return keys_dir + + +@pytest.fixture +def key_manager(tmp_keys_dir): + return KeyManager(keys_dir=tmp_keys_dir) + + +@pytest.fixture +def tmp_workspace_dir(tmp_path): + return tmp_path / ".openmlr" + + +@pytest.fixture +def workspace_manager(tmp_workspace_dir): + return WorkspaceManager(base_dir=tmp_workspace_dir) + + +# 
--------------------------------------------------------------------------- +# KeyManager +# --------------------------------------------------------------------------- + +class TestKeyManager: + def test_init_creates_dir(self, tmp_keys_dir, key_manager): + assert tmp_keys_dir.exists() + assert oct(tmp_keys_dir.stat().st_mode)[-3:] == "700" + + def test_list_keys_empty(self, key_manager): + assert key_manager.list_keys() == [] + + def test_generate_key_pair_ed25519(self, key_manager): + priv, pub = key_manager.generate_key_pair("id_test_ed", "ed25519", "test@host") + assert priv.exists() + assert pub.exists() + assert "id_test_ed" in priv.name + # Private key should be 0o600 + mode = oct(priv.stat().st_mode)[-3:] + assert mode == "600" + + def test_generate_key_pair_rsa(self, key_manager): + priv, pub = key_manager.generate_key_pair("id_test_rsa", "rsa", "test@host") + assert priv.exists() + assert pub.exists() + + def test_generate_unsupported_algorithm(self, key_manager): + with pytest.raises(ValueError, match="Unsupported algorithm"): + key_manager.generate_key_pair("id_bad", "dsa", "test") + + def test_key_exists(self, key_manager): + key_manager.generate_key_pair("id_exist", "ed25519") + assert key_manager.key_exists("id_exist") is True + assert key_manager.key_exists("id_nope") is False + + def test_list_keys_after_generate(self, key_manager): + key_manager.generate_key_pair("id_list_test", "ed25519") + keys = key_manager.list_keys() + assert len(keys) == 1 + assert keys[0]["filename"] == "id_list_test" + assert keys[0]["has_public"] is True + + def test_delete_key(self, key_manager): + key_manager.generate_key_pair("id_del", "ed25519") + assert key_manager.key_exists("id_del") + result = key_manager.delete_key("id_del") + assert result is True + assert not key_manager.key_exists("id_del") + + def test_delete_nonexistent(self, key_manager): + result = key_manager.delete_key("nope") + assert result is False + + def test_write_and_read_key(self, key_manager): + 
key_manager.write_key("id_manual", "-----BEGIN FAKE KEY-----\ndata\n-----END FAKE KEY-----\n") + content = key_manager.read_key("id_manual") + assert "FAKE KEY" in content + + def test_read_nonexistent_key(self, key_manager): + with pytest.raises(FileNotFoundError): + key_manager.read_key("nope") + + def test_validate_key_ed25519(self, key_manager): + key_manager.generate_key_pair("id_val", "ed25519", "comment") + private = key_manager.read_key("id_val") + meta = key_manager.validate_key(private) + assert meta["algorithm"] == "ssh-ed25519" + assert meta["fingerprint"].startswith("SHA256:") + assert len(meta["fingerprint"]) > 10 + assert meta["public_key"].startswith("ssh-ed25519") + + def test_validate_key_rsa(self, key_manager): + key_manager.generate_key_pair("id_val_rsa", "rsa") + private = key_manager.read_key("id_val_rsa") + meta = key_manager.validate_key(private) + assert meta["algorithm"] == "ssh-rsa" + assert meta["fingerprint"].startswith("SHA256:") + + def test_validate_invalid_key(self, key_manager): + with pytest.raises(ValueError, match="Invalid private key"): + key_manager.validate_key("not a key") + + def test_get_key_path(self, key_manager, tmp_keys_dir): + path = key_manager.get_key_path("id_some") + assert path == tmp_keys_dir / "id_some" + + +# --------------------------------------------------------------------------- +# WorkspaceManager +# --------------------------------------------------------------------------- + +class TestWorkspaceManager: + def test_create_workspace(self, workspace_manager): + path = workspace_manager.create_workspace("test-uuid-123") + assert path.exists() + assert (path / "data").exists() + assert (path / "models").exists() + assert (path / "code").exists() + assert (path / "outputs").exists() + assert (path / ".openmlr-meta").exists() + + def test_get_workspace_path(self, workspace_manager): + path = workspace_manager.get_workspace_path("abc") + assert "workspace-abc" in str(path) + + def test_workspace_exists(self, 
workspace_manager): + assert workspace_manager.workspace_exists("nope") is False + workspace_manager.create_workspace("nope") + assert workspace_manager.workspace_exists("nope") is True + + def test_delete_workspace_with_archive(self, workspace_manager): + workspace_manager.create_workspace("del-test") + ws_path = workspace_manager.get_workspace_path("del-test") + (ws_path / "data" / "file.txt").write_text("hello") + + result = workspace_manager.delete_workspace("del-test", archive=True) + assert result is True + assert not ws_path.exists() + # Check archive was created + archives = list(workspace_manager.archive_dir.glob("*.tar.gz")) + assert len(archives) == 1 + + def test_delete_workspace_without_archive(self, workspace_manager): + workspace_manager.create_workspace("del-no-archive") + result = workspace_manager.delete_workspace("del-no-archive", archive=False) + assert result is True + archives = list(workspace_manager.archive_dir.glob("*.tar.gz")) + assert len(archives) == 0 + + def test_delete_nonexistent(self, workspace_manager): + result = workspace_manager.delete_workspace("nonexistent") + assert result is False + + def test_get_workspace_size(self, workspace_manager): + workspace_manager.create_workspace("size-test") + path = workspace_manager.get_workspace_path("size-test") + (path / "data" / "big.bin").write_bytes(b"x" * 1024) + size = workspace_manager.get_workspace_size("size-test") + assert size >= 1024 + + def test_list_workspaces(self, workspace_manager): + workspace_manager.create_workspace("ws-a") + workspace_manager.create_workspace("ws-b") + ws_list = workspace_manager.list_workspaces() + uuids = [w["uuid"] for w in ws_list] + assert "ws-a" in uuids + assert "ws-b" in uuids + + def test_cleanup_archives(self, workspace_manager): + # Create and archive 3 workspaces + for i in range(3): + workspace_manager.create_workspace(f"cleanup-{i}") + workspace_manager.archive_workspace(f"cleanup-{i}") + + result = 
workspace_manager.cleanup_archives(max_age_days=0, max_count=1) + assert result["deleted"] >= 2 + remaining = list(workspace_manager.archive_dir.glob("*.tar.gz")) + assert len(remaining) <= 1 + + def test_cleanup_workspaces_orphaned(self, workspace_manager): + workspace_manager.create_workspace("keep") + workspace_manager.create_workspace("orphan") + result = workspace_manager.cleanup_workspaces( + conversation_uuids=["keep"], + archive=False, + ) + assert result["deleted"] == 1 + assert workspace_manager.workspace_exists("keep") + assert not workspace_manager.workspace_exists("orphan") + + +# --------------------------------------------------------------------------- +# ComputeCapabilities +# --------------------------------------------------------------------------- + +class TestComputeCapabilities: + def test_defaults(self): + caps = ComputeCapabilities() + assert caps.platform == "unknown" + assert caps.cpu_cores == 0 + assert caps.gpu_available is False + assert caps.gpu_info == [] + + def test_to_dict(self): + caps = ComputeCapabilities( + cpu_cores=8, + gpu_available=True, + gpu_info=[GPUInfo(model="A100", vram_gb=80.0, cuda_version="12.4")], + ) + d = caps.to_dict() + assert d["cpu_cores"] == 8 + assert d["gpu_available"] is True + assert len(d["gpu_info"]) == 1 + assert d["gpu_info"][0]["model"] == "A100" + + def test_from_dict(self): + d = { + "platform": "Linux", + "cpu_cores": 4, + "gpu_available": True, + "gpu_info": [{"model": "RTX 4090", "vram_gb": 24, "cuda_version": "12.4", "driver_version": "545"}], + } + caps = ComputeCapabilities.from_dict(d) + assert caps.platform == "Linux" + assert caps.cpu_cores == 4 + assert len(caps.gpu_info) == 1 + assert caps.gpu_info[0].model == "RTX 4090" + + def test_roundtrip(self): + original = ComputeCapabilities( + platform="test", + cpu_cores=16, + available_ram_gb=32.5, + gpu_available=True, + gpu_count=2, + gpu_info=[ + GPUInfo(model="A100", vram_gb=80), + GPUInfo(model="A100", vram_gb=80), + ], + 
python_versions=["3.12", "3.11"], + docker_available=True, + ) + d = original.to_dict() + restored = ComputeCapabilities.from_dict(d) + assert restored.platform == "test" + assert restored.cpu_cores == 16 + assert restored.available_ram_gb == 32.5 + assert len(restored.gpu_info) == 2 + assert restored.docker_available is True + + +# --------------------------------------------------------------------------- +# ComputeManager (validation) +# --------------------------------------------------------------------------- + +class TestComputeManager: + def test_validate_ssh_missing_host(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("ssh", {"username": "user"}) + assert ok is False + assert "host" in err + + def test_validate_ssh_missing_username(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("ssh", {"host": "example.com"}) + assert ok is False + assert "username" in err + + def test_validate_ssh_ok(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("ssh", {"host": "example.com", "username": "user"}) + assert ok is True + + def test_validate_ssh_missing_key(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("ssh", { + "host": "x", "username": "u", "key_filename": "nonexistent", + }) + assert ok is False + assert "not found" in err + + def test_validate_local_ok(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("local", {}) + assert ok is True + + def test_validate_local_file_not_dir(self, key_manager, tmp_path): + f = tmp_path / "not_a_dir" + f.write_text("data") + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("local", {"workdir": str(f)}) + assert ok is False + + def test_validate_modal_ok(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("modal", {}) + assert ok is True + + def 
test_validate_unknown_type(self, key_manager): + cm = ComputeManager(key_manager) + ok, err = cm.validate_node_config("kubernetes", {}) + assert ok is False + assert "Unknown" in err + + +# --------------------------------------------------------------------------- +# SSHConnectionPool +# --------------------------------------------------------------------------- + +class TestSSHConnectionPool: + def test_singleton(self): + pool1 = SSHConnectionPool.get_pool() + pool2 = SSHConnectionPool.get_pool() + assert pool1 is pool2 + + def test_make_key(self): + assert SSHConnectionPool._make_key("host", 22, "user") == "user@host:22" + + def test_get_empty(self): + pool = SSHConnectionPool(ttl_seconds=300) + assert pool.get("host", 22, "user") is None + + def test_put_and_get(self): + pool = SSHConnectionPool(ttl_seconds=300) + # Mock a client with active transport + mock_client = MagicMock() + mock_transport = MagicMock() + mock_transport.is_active.return_value = True + mock_client.get_transport.return_value = mock_transport + mock_sftp = MagicMock() + + pool.put("host", 22, "user", mock_client, mock_sftp, "fp123") + result = pool.get("host", 22, "user") + assert result is not None + client, sftp, fp = result + assert client is mock_client + assert sftp is mock_sftp + assert fp == "fp123" + + def test_get_dead_connection(self): + pool = SSHConnectionPool(ttl_seconds=300) + mock_client = MagicMock() + mock_transport = MagicMock() + mock_transport.is_active.return_value = False + mock_client.get_transport.return_value = mock_transport + mock_sftp = MagicMock() + + pool.put("host", 22, "user", mock_client, mock_sftp, "fp") + result = pool.get("host", 22, "user") + assert result is None + + def test_cleanup_idle(self): + pool = SSHConnectionPool(ttl_seconds=0) # immediate expiry + mock_client = MagicMock() + mock_sftp = MagicMock() + pool.put("host", 22, "user", mock_client, mock_sftp, "fp") + pool._last_used["user@host:22"] = 0 # force stale + pool.cleanup_idle() + assert 
pool.get("host", 22, "user") is None + mock_sftp.close.assert_called_once() + mock_client.close.assert_called_once() + + def test_remove(self): + pool = SSHConnectionPool(ttl_seconds=300) + mock_client = MagicMock() + mock_sftp = MagicMock() + pool.put("host", 22, "user", mock_client, mock_sftp, "fp") + pool.remove("host", 22, "user") + assert pool.get("host", 22, "user") is None + + +# --------------------------------------------------------------------------- +# Path traversal validation +# --------------------------------------------------------------------------- + +class TestPathTraversal: + def test_valid_relative_path(self, tmp_path): + ws = tmp_path / "workspace" + ws.mkdir() + path, err = _validate_sync_path(ws, "data/file.txt") + assert err is None + assert str(ws) in str(path) + + def test_traversal_blocked(self, tmp_path): + ws = tmp_path / "workspace" + ws.mkdir() + path, err = _validate_sync_path(ws, "../../etc/passwd") + assert err is not None + assert "escapes" in err + + def test_absolute_path_blocked(self, tmp_path): + ws = tmp_path / "workspace" + ws.mkdir() + path, err = _validate_sync_path(ws, "/etc/passwd") + assert err is not None + assert "escapes" in err + + def test_nested_valid_path(self, tmp_path): + ws = tmp_path / "workspace" + ws.mkdir() + path, err = _validate_sync_path(ws, "data/subdir/deep/file.csv") + assert err is None + + +# --------------------------------------------------------------------------- +# ToolRouter compute context injection +# --------------------------------------------------------------------------- + +class TestToolRouterContext: + def test_set_context(self): + router = ToolRouter() + router.set_context(user_id=42, db="fake_db") + assert router._user_id == 42 + assert router._db == "fake_db" + + async def test_context_injected_into_handler(self): + router = ToolRouter() + router.set_context(user_id=42, db="fake_db") + + async def handler(user_id: int = None, db=None, arg: str = "") -> tuple[str, bool]: + return 
f"uid={user_id},db={db},arg={arg}", True + + from openmlr.agent.types import ToolSpec + tool = ToolSpec( + name="ctx_test", description="test", parameters={"type": "object", "properties": {}}, + handler=handler, + ) + router.register(tool) + result, ok = await router.call_tool("ctx_test", {"arg": "hello"}) + assert ok is True + assert "uid=42" in result + assert "db=fake_db" in result + assert "arg=hello" in result + + +# --------------------------------------------------------------------------- +# Plan mode allows compute tools +# --------------------------------------------------------------------------- + +class TestPlanModeComputeTools: + def test_compute_list_allowed(self): + assert "compute_list" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + + def test_compute_plan_allowed(self): + assert "compute_plan" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + + def test_compute_probe_allowed(self): + assert "compute_probe" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + + def test_compute_select_not_in_plan(self): + assert "compute_select" not in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + + +# --------------------------------------------------------------------------- +# Config redaction (routes/compute.py) +# --------------------------------------------------------------------------- + +class TestConfigRedaction: + def test_redact_password(self): + from openmlr.routes.compute import _redact_config + config = {"host": "example.com", "password": "secret123", "username": "user"} + redacted = _redact_config(config) + assert redacted["host"] == "example.com" + assert redacted["password"] == "***" + assert redacted["username"] == "user" + + def test_redact_empty_config(self): + from openmlr.routes.compute import _redact_config + assert _redact_config({}) == {} + assert _redact_config(None) == {} + + def test_redact_no_sensitive_fields(self): + from openmlr.routes.compute import _redact_config + config = {"host": "x", "port": 22} + assert _redact_config(config) == config + 
+ +# --------------------------------------------------------------------------- +# Routes (keys + compute) — integration via httpx +# --------------------------------------------------------------------------- + +class TestKeyRoutes: + async def test_list_keys_empty(self, auth_client): + resp = await auth_client.get("/api/keys") + assert resp.status_code == 200 + assert resp.json()["keys"] == [] + + async def test_generate_key(self, auth_client): + resp = await auth_client.post("/api/keys", json={ + "action": "generate", + "filename": "id_test_route", + "algorithm": "ed25519", + "comment": "test", + }) + assert resp.status_code == 200 + data = resp.json()["key"] + assert data["filename"] == "id_test_route" + assert data["algorithm"] == "ssh-ed25519" + assert data["fingerprint"].startswith("SHA256:") + + async def test_generate_duplicate(self, auth_client): + await auth_client.post("/api/keys", json={ + "action": "generate", "filename": "id_dup", "algorithm": "ed25519", + }) + resp = await auth_client.post("/api/keys", json={ + "action": "generate", "filename": "id_dup", "algorithm": "ed25519", + }) + assert resp.status_code == 409 + + async def test_delete_key(self, auth_client): + await auth_client.post("/api/keys", json={ + "action": "generate", "filename": "id_to_del", "algorithm": "ed25519", + }) + resp = await auth_client.delete("/api/keys/id_to_del") + assert resp.status_code == 200 + assert resp.json()["ok"] is True + + async def test_delete_nonexistent_key(self, auth_client): + resp = await auth_client.delete("/api/keys/id_nope") + assert resp.status_code == 404 + + async def test_create_key_missing_filename(self, auth_client): + resp = await auth_client.post("/api/keys", json={"action": "generate"}) + assert resp.status_code == 400 + + async def test_create_key_invalid_action(self, auth_client): + resp = await auth_client.post("/api/keys", json={ + "action": "nope", "filename": "id_x", + }) + assert resp.status_code == 400 + + async def 
test_unauthenticated_keys(self, client): + resp = await client.get("/api/keys") + assert resp.status_code == 401 + + +class TestComputeNodeRoutes: + async def test_list_empty(self, auth_client): + resp = await auth_client.get("/api/compute/nodes") + assert resp.status_code == 200 + assert resp.json()["nodes"] == [] + + async def test_create_local_node(self, auth_client): + resp = await auth_client.post("/api/compute/nodes", json={ + "name": "My Laptop", + "type": "local", + "config": {}, + }) + assert resp.status_code == 200 + node = resp.json()["node"] + assert node["name"] == "My Laptop" + assert node["type"] == "local" + assert node["health_status"] == "unknown" + + async def test_create_duplicate_name(self, auth_client): + await auth_client.post("/api/compute/nodes", json={ + "name": "Dup", "type": "local", "config": {}, + }) + resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Dup", "type": "local", "config": {}, + }) + assert resp.status_code == 409 + + async def test_create_invalid_type(self, auth_client): + resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Bad", "type": "kubernetes", "config": {}, + }) + assert resp.status_code == 400 + + async def test_get_node(self, auth_client): + create_resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Get Test", "type": "local", "config": {}, + }) + node_id = create_resp.json()["node"]["id"] + resp = await auth_client.get(f"/api/compute/nodes/{node_id}") + assert resp.status_code == 200 + assert resp.json()["node"]["name"] == "Get Test" + + async def test_update_node(self, auth_client): + create_resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Update Test", "type": "local", "config": {}, + }) + node_id = create_resp.json()["node"]["id"] + resp = await auth_client.put(f"/api/compute/nodes/{node_id}", json={ + "name": "Updated Name", + }) + assert resp.status_code == 200 + assert resp.json()["node"]["name"] == "Updated Name" + + async def 
test_delete_node(self, auth_client): + create_resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Delete Test", "type": "local", "config": {}, + }) + node_id = create_resp.json()["node"]["id"] + resp = await auth_client.delete(f"/api/compute/nodes/{node_id}") + assert resp.status_code == 200 + assert resp.json()["ok"] is True + + async def test_set_default(self, auth_client): + create_resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Default Test", "type": "local", "config": {}, + }) + node_id = create_resp.json()["node"]["id"] + resp = await auth_client.post(f"/api/compute/nodes/{node_id}/set-default") + assert resp.status_code == 200 + + # Verify it's now default + get_resp = await auth_client.get(f"/api/compute/nodes/{node_id}") + assert get_resp.json()["node"]["is_default"] is True + + async def test_config_redacted_in_response(self, auth_client): + resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Redact Test", + "type": "ssh", + "config": {"host": "x", "username": "u", "password": "secret"}, + }) + assert resp.status_code == 200 + node = resp.json()["node"] + assert node["config"]["password"] == "***" + assert node["config"]["host"] == "x" + + async def test_test_local_node(self, auth_client): + create_resp = await auth_client.post("/api/compute/nodes", json={ + "name": "Test Local", "type": "local", "config": {}, + }) + node_id = create_resp.json()["node"]["id"] + resp = await auth_client.post(f"/api/compute/nodes/{node_id}/test") + assert resp.status_code == 200 + # Local test should pass (workspace will be CWD) + assert resp.json()["ok"] is True + + async def test_test_config_endpoint(self, auth_client): + resp = await auth_client.post("/api/compute/test", json={ + "type": "local", + "config": {}, + }) + assert resp.status_code == 200 + assert resp.json()["ok"] is True + + async def test_test_config_invalid_type(self, auth_client): + resp = await auth_client.post("/api/compute/test", json={ + 
"type": "kubernetes", + "config": {}, + }) + assert resp.status_code == 200 + assert resp.json()["ok"] is False + + async def test_unauthenticated(self, client): + resp = await client.get("/api/compute/nodes") + assert resp.status_code == 401 + + +# --------------------------------------------------------------------------- +# System prompt includes compute_env +# --------------------------------------------------------------------------- + +class TestSystemPromptCompute: + def test_prompt_includes_compute_env(self): + from openmlr.agent.prompts import build_system_prompt + prompt = build_system_prompt( + tool_specs=[], + compute_env="## Active Compute: TestNode (ssh)\n- CPU: 8 cores", + ) + assert "TestNode" in prompt + assert "8 cores" in prompt + + def test_prompt_without_compute_env(self): + from openmlr.agent.prompts import build_system_prompt + prompt = build_system_prompt(tool_specs=[], compute_env="") + assert "Active Compute" not in prompt diff --git a/backend/tests/test_sandbox_types.py b/backend/tests/test_sandbox_types.py index fc7553f..e64aeb6 100644 --- a/backend/tests/test_sandbox_types.py +++ b/backend/tests/test_sandbox_types.py @@ -2,34 +2,43 @@ import pytest -from openmlr.sandbox.interface import EnvironmentInfo, ExecutionResult, SandboxInterface +from openmlr.compute.capabilities import ComputeCapabilities, GPUInfo +from openmlr.sandbox.interface import ExecutionResult, SandboxInterface from openmlr.sandbox.local import LocalSandbox -class TestEnvironmentInfo: +class TestComputeCapabilities: def test_defaults(self): - info = EnvironmentInfo() - assert info.os == "unknown" - assert info.python_version == "unknown" - assert info.gpu_available is False - assert info.gpu_info is None - assert info.installed_packages == [] - assert info.available_disk_gb == 0.0 - assert info.available_ram_gb == 0.0 + caps = ComputeCapabilities() + assert caps.platform == "unknown" + assert caps.cpu_cores == 0 + assert caps.gpu_available is False + assert 
caps.gpu_info == [] + assert caps.installed_packages == [] + assert caps.available_disk_gb == 0.0 + assert caps.available_ram_gb == 0.0 def test_custom_values(self): - info = EnvironmentInfo( - os="Linux", - python_version="3.12.0", + caps = ComputeCapabilities( + platform="Linux 6.5.0", + cpu_cores=8, gpu_available=True, - gpu_info="NVIDIA A100", - installed_packages=["torch", "numpy"], + gpu_info=[GPUInfo(model="NVIDIA A100", vram_gb=80.0)], + installed_packages=["torch==2.3.0", "numpy==1.26.0"], available_disk_gb=50.0, available_ram_gb=32.0, ) - assert info.os == "Linux" - assert info.gpu_available is True - assert "torch" in info.installed_packages + assert caps.platform == "Linux 6.5.0" + assert caps.gpu_available is True + assert len(caps.gpu_info) == 1 + assert caps.gpu_info[0].model == "NVIDIA A100" + + def test_to_dict_roundtrip(self): + caps = ComputeCapabilities(cpu_cores=4, gpu_available=True) + d = caps.to_dict() + caps2 = ComputeCapabilities.from_dict(d) + assert caps2.cpu_cores == 4 + assert caps2.gpu_available is True class TestExecutionResult: diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml index d83bd14..d479428 100644 --- a/docker-compose.prod.yml +++ b/docker-compose.prod.yml @@ -66,6 +66,7 @@ services: condition: service_healthy volumes: - ./backend/configs:/app/backend/configs + - ./.keys:/app/.keys security_opt: - no-new-privileges:true restart: unless-stopped @@ -93,6 +94,7 @@ services: condition: service_healthy volumes: - ./backend/configs:/app/backend/configs + - ./.keys:/app/.keys security_opt: - no-new-privileges:true restart: unless-stopped diff --git a/docker-compose.yml b/docker-compose.yml index 572272d..6eaadfd 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -70,6 +70,7 @@ services: - ./backend:/app/backend - backend-venv:/app/backend/.venv - ./frontend/dist:/app/frontend/dist + - ./.keys:/app/.keys # Worker with auto-restart on code changes worker: @@ -103,6 +104,7 @@ services: volumes: - 
./backend:/app/backend - backend-venv:/app/backend/.venv + - ./.keys:/app/.keys # Docs site with live reload docs: diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 22490d2..8ba3112 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,6 +1,7 @@ import { useState, useCallback, useEffect, useRef } from 'react'; import { Routes, Route, Navigate, useNavigate, useParams } from 'react-router-dom'; import { Copy, Check } from 'lucide-react'; +import { ComputeSelector } from './components/ComputeSelector'; import { useSSE } from './hooks/useSSE'; import { useJobStatus } from './hooks/useJobStatus'; import { api } from './api'; @@ -20,7 +21,7 @@ import { SettingsPage } from './components/SettingsPage'; import { ProvidersSettings } from './components/settings/ProvidersSettings'; import { AgentSettings } from './components/settings/AgentSettings'; import { McpSettings } from './components/settings/McpSettings'; -import { SandboxSettings } from './components/settings/SandboxSettings'; +import { ComputeSettings } from './components/settings/ComputeSettings'; import { WritingSettings } from './components/settings/WritingSettings'; let msgId = 0; @@ -100,6 +101,8 @@ function ChatUI({ const [viewingReport, setViewingReport] = useState(null); const [inputMode, setInputMode] = useState('plan'); const [inputText, setInputText] = useState(''); + const [computeNodes, setComputeNodes] = useState([]); + const [activeCompute, setActiveCompute] = useState(null); // Ref to always have current conv UUID in SSE callback (avoids stale closure) const currentConvUuidRef = useRef(currentConvUuid); @@ -120,14 +123,33 @@ function ChatUI({ } }, []); + const loadComputeNodes = useCallback(async () => { + try { + const data = await api.getComputeNodes(); + setComputeNodes(data.nodes || []); + } catch { + setComputeNodes([]); + } + }, []); + + const loadActiveCompute = useCallback(async (uuid: string) => { + try { + const data = await api.getConversationCompute(uuid); + 
setActiveCompute(data.node || null); + } catch { + setActiveCompute(null); + } + }, []); + // Initial load - load conversations and activate the correct one useEffect(() => { const init = async () => { + await loadComputeNodes(); const convs = await loadConversations(); // If URL has a conversation UUID, load it directly if (routeUuid) { - switchConv(routeUuid); + await switchConv(routeUuid); return; } @@ -146,7 +168,7 @@ function ChatUI({ const first = convs[0]; setCurrentConvUuid(first.uuid); navigate(`/${first.uuid}`, { replace: true }); - switchConv(first.uuid); + await switchConv(first.uuid); } }; init(); @@ -197,6 +219,9 @@ function ChatUI({ } return { id: nextId(), role: m.role, content: m.content }; }) || []); + + // Load active compute for this conversation + await loadActiveCompute(uuid); } catch { /* */ } }; @@ -214,6 +239,8 @@ function ChatUI({ setMessages([]); setTasks([]); setResources([]); setContextUsage(null); setSearchBudget(null); setApprovalEvent(null); setQuestionsPayload(null); if (conv.model) setModel(conv.model); + // Load default compute for new conversation + await loadActiveCompute(conv.uuid); navigate(`/${conv.uuid}`, { replace: true }); } catch { /* */ } }; @@ -226,11 +253,24 @@ function ChatUI({ if (currentConvUuid === uuid) { setCurrentConvUuid(null); setMessages([]); setTasks([]); setResources([]); setApprovalEvent(null); setQuestionsPayload(null); + setActiveCompute(null); navigate('/', { replace: true }); } } catch { /* */ } }; + const handleComputeChange = useCallback(async (nodeId: number | null) => { + if (!currentConvUuid) return; + try { + if (nodeId === null) { + await api.clearConversationCompute(currentConvUuid); + } else { + await api.setConversationCompute(currentConvUuid, nodeId); + } + await loadActiveCompute(currentConvUuid); + } catch { /* */ } + }, [currentConvUuid, loadActiveCompute]); + // Helper to reload messages from DB for a given conversation const reloadConversationMessages = useCallback(async (uuid: string) 
=> { try { @@ -401,7 +441,7 @@ function ChatUI({ }); break; case 'questions': setCurrentConvStatus('waiting_input'); setQuestionsPayload(data as QuestionsPayload); break; - case 'plan_update': + case 'plan_update': { setTasks(data?.tasks || []); setRightPanelOpen(true); // Auto-compact after all tasks are completed @@ -410,6 +450,7 @@ function ChatUI({ setTimeout(() => api.compact().catch(() => {}), 1000); } break; + } case 'resources_update': setResources(data?.resources || []); setRightPanelOpen(true); break; case 'context_usage': if (data) setContextUsage(data as ContextUsage); break; case 'search_budget': if (data) setSearchBudget(data as SearchBudget); break; @@ -521,6 +562,11 @@ function ChatUI({ />
+
@@ -626,7 +672,7 @@ export default function App() { } /> } /> } /> - } /> + } /> } /> diff --git a/frontend/src/api.ts b/frontend/src/api.ts index 09dc7f3..7d5a3ee 100644 --- a/frontend/src/api.ts +++ b/frontend/src/api.ts @@ -85,6 +85,10 @@ export const api = { getConversation: (uuid: string) => get(`/api/conversations/${uuid}`), deleteConversation: (uuid: string) => del(`/api/conversations/${uuid}`), switchConversation: (uuid: string) => post(`/api/conversations/${uuid}/switch`, {}), + getConversationCompute: (uuid: string) => get(`/api/conversations/${uuid}/compute`), + setConversationCompute: (uuid: string, nodeId: number | null) => + post(`/api/conversations/${uuid}/compute`, { node_id: nodeId }), + clearConversationCompute: (uuid: string) => del(`/api/conversations/${uuid}/compute`), // Settings getSettings: () => get('/api/settings'), @@ -106,4 +110,21 @@ export const api = { getModels: () => get('/api/models'), getStatus: () => get('/api/status'), saveConfig: (config: Record) => post('/api/config', config), + + // SSH Keys + getKeys: () => get('/api/keys'), + createKey: (body: Record) => post('/api/keys', body), + deleteKey: (filename: string) => del(`/api/keys/${filename}`), + + // Compute Nodes + getComputeNodes: () => get('/api/compute/nodes'), + createComputeNode: (body: Record) => post('/api/compute/nodes', body), + getComputeNode: (id: number) => get(`/api/compute/nodes/${id}`), + updateComputeNode: (id: number, body: Record) => put(`/api/compute/nodes/${id}`, body), + deleteComputeNode: (id: number) => del(`/api/compute/nodes/${id}`), + testComputeNode: (id: number) => post(`/api/compute/nodes/${id}/test`, {}), + testComputeConfig: (type: string, config: Record) => + post('/api/compute/test', { type, config }), + probeComputeNode: (id: number) => post(`/api/compute/nodes/${id}/probe`, {}), + setDefaultComputeNode: (id: number) => post(`/api/compute/nodes/${id}/set-default`, {}), }; diff --git a/frontend/src/components/ComputeSelector.tsx 
b/frontend/src/components/ComputeSelector.tsx new file mode 100644 index 0000000..a812b16 --- /dev/null +++ b/frontend/src/components/ComputeSelector.tsx @@ -0,0 +1,101 @@ +import { useState, useEffect, useRef } from 'react'; +import { Cpu, ChevronDown, Star } from 'lucide-react'; + +interface ComputeNode { + id: number; + name: string; + type: string; + health_status: string; +} + +interface ComputeSelectorProps { + currentNode: ComputeNode | null; + nodes: ComputeNode[]; + onChange: (nodeId: number | null) => void; +} + +export function ComputeSelector({ currentNode, nodes, onChange }: ComputeSelectorProps) { + const [open, setOpen] = useState(false); + const ref = useRef(null); + + useEffect(() => { + function handleClickOutside(event: MouseEvent) { + if (ref.current && !ref.current.contains(event.target as Node)) { + setOpen(false); + } + } + function handleEsc(event: KeyboardEvent) { + if (event.key === 'Escape') setOpen(false); + } + document.addEventListener('mousedown', handleClickOutside); + document.addEventListener('keydown', handleEsc); + return () => { + document.removeEventListener('mousedown', handleClickOutside); + document.removeEventListener('keydown', handleEsc); + }; + }, []); + + const getStatusColor = (status: string) => { + switch (status) { + case 'online': return 'bg-success'; + case 'offline': return 'bg-error'; + case 'degraded': return 'bg-warning'; + default: return 'bg-text-dim'; + } + }; + + return ( +
+ + + {open && ( +
+ {/* Default option */} + + + {nodes.length > 0 &&
} + + {/* Node list */} + {nodes.map((node) => ( + + ))} + + {nodes.length === 0 && ( +
+ No compute nodes configured +
+ )} +
+ )} +
+ ); +} diff --git a/frontend/src/components/SettingsPage.tsx b/frontend/src/components/SettingsPage.tsx index a2ad7ce..b2f4b71 100644 --- a/frontend/src/components/SettingsPage.tsx +++ b/frontend/src/components/SettingsPage.tsx @@ -1,11 +1,11 @@ import { Link, Outlet, useLocation } from 'react-router-dom'; -import { ArrowLeft, Key, Bot, Server, Box, PenTool } from 'lucide-react'; +import { ArrowLeft, Key, Bot, Server, Cpu, PenTool } from 'lucide-react'; const navItems = [ { path: '/settings/providers', label: 'Providers', icon: Key }, { path: '/settings/agent', label: 'Agent', icon: Bot }, { path: '/settings/mcp', label: 'MCP Servers', icon: Server }, - { path: '/settings/sandbox', label: 'Sandbox', icon: Box }, + { path: '/settings/compute', label: 'Compute', icon: Cpu }, { path: '/settings/writing', label: 'Writing', icon: PenTool }, ]; diff --git a/frontend/src/components/settings/AddKeyModal.tsx b/frontend/src/components/settings/AddKeyModal.tsx new file mode 100644 index 0000000..441786a --- /dev/null +++ b/frontend/src/components/settings/AddKeyModal.tsx @@ -0,0 +1,158 @@ +import { useState, useEffect } from 'react'; +import { X, Upload, KeyRound } from 'lucide-react'; + +interface AddKeyModalProps { + onClose: () => void; + onSubmit: (data: any) => void; +} + +export function AddKeyModal({ onClose, onSubmit }: AddKeyModalProps) { + useEffect(() => { + const handleEsc = (e: KeyboardEvent) => { if (e.key === 'Escape') onClose(); }; + document.addEventListener('keydown', handleEsc); + return () => document.removeEventListener('keydown', handleEsc); + }, [onClose]); + const [mode, setMode] = useState<'upload' | 'generate'>('upload'); + const [filename, setFilename] = useState(''); + const [privateKey, setPrivateKey] = useState(''); + const [algorithm, setAlgorithm] = useState('ed25519'); + const [comment, setComment] = useState(''); + const [submitting, setSubmitting] = useState(false); + + const handleSubmit = async (e: React.FormEvent) => { + 
e.preventDefault(); + if (!filename.trim()) return; + + setSubmitting(true); + try { + if (mode === 'upload') { + await onSubmit({ + action: 'upload', + filename: filename.trim(), + private_key: privateKey, + comment: comment || undefined, + }); + } else { + await onSubmit({ + action: 'generate', + filename: filename.trim(), + algorithm, + comment: comment || `openmlr-key`, + }); + } + } finally { + setSubmitting(false); + } + }; + + return ( +
+
e.stopPropagation()}> +
+

+ + Add SSH Key +

+ +
+ +
+ {/* Mode toggle */} +
+ + +
+ + {/* Filename */} +
+ + setFilename(e.target.value)} + className="w-full bg-bg border border-border rounded-lg px-3 py-2 text-text text-sm focus:border-primary focus:outline-none" + /> +

Stored in .keys/ directory

+
+ + {mode === 'upload' ? ( +
+ +