Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .keys/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
!.gitignore
10 changes: 10 additions & 0 deletions backend/configs/prompts/system_prompt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,13 @@ prompt: |
- Date: {{ date }}
- User: {{ username }}
- Mode: {{ mode }}

{% if compute_env %}
{{ compute_env }}
{% endif %}

# Compute Planning
When starting tasks that require significant computation (training models, processing large datasets, etc.):
1. Use `compute_plan` to verify the active node meets requirements
2. If not, use `compute_select` to switch to a suitable node
3. Always run `sandbox_probe` before executing code on a node for the first time
2 changes: 2 additions & 0 deletions backend/openmlr/agent/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
mode: str = "general",
username: str = "user",
sandbox_info: str = "none",
compute_env: str = "",
config: AgentConfig | None = None,

Check notice on line 29 in backend/openmlr/agent/prompts.py

View workflow job for this annotation

GitHub Actions / Qodana for Python

Unused local symbols

Parameter 'config' value is not used
) -> str:
"""Build the full system prompt from YAML template."""
template_path = PROMPT_DIR / "system_prompt.yaml"
Expand Down Expand Up @@ -58,6 +59,7 @@
timezone="UTC",
username=username,
sandbox_info=sandbox_info,
compute_env=compute_env,
)

return prompt
Expand Down
4 changes: 4 additions & 0 deletions backend/openmlr/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


@asynccontextmanager
async def lifespan(app: FastAPI):

Check notice on line 22 in backend/openmlr/app.py

View workflow job for this annotation

GitHub Actions / Qodana for Python

Shadowing names from outer scopes

Shadows name 'app' from outer scope
"""Startup: create tables & shared state. Shutdown: teardown sessions."""
import logging
logger = logging.getLogger("openmlr.app")
Expand Down Expand Up @@ -71,18 +71,22 @@
# ── API routers ──────────────────────────────────────────
from .auth.router import router as auth_router
from .routes.agent import router as agent_router
from .routes.compute import router as compute_router
from .routes.health import router as health_router
from .routes.keys import router as keys_router
from .routes.settings import router as settings_router

app.include_router(auth_router)
app.include_router(agent_router)
app.include_router(settings_router)
app.include_router(health_router)
app.include_router(keys_router)
app.include_router(compute_router)


# ── Global error handler ────────────────────────────────
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):

Check notice on line 89 in backend/openmlr/app.py

View workflow job for this annotation

GitHub Actions / Qodana for Python

Unused local symbols

Parameter 'request' value is not used
import logging
logger = logging.getLogger(__name__)
logger.exception(f"Unhandled exception: {exc}")
Expand Down
14 changes: 13 additions & 1 deletion backend/openmlr/celery_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"openmlr",
broker=REDIS_URL,
backend=REDIS_URL,
include=["openmlr.tasks.agent_tasks"],
include=["openmlr.tasks.agent_tasks", "openmlr.tasks.compute_tasks"],
)

# Celery configuration
Expand Down Expand Up @@ -43,6 +43,18 @@

# Default queue
task_default_queue="default",

# Beat schedule for periodic tasks
beat_schedule={
"health-check-all-nodes": {
"task": "openmlr.tasks.compute_tasks.health_check_all_nodes",
"schedule": 300.0, # Every 5 minutes
},
"cleanup-old-workspaces": {
"task": "openmlr.tasks.compute_tasks.cleanup_old_workspaces",
"schedule": 86400.0, # Every 24 hours
},
},
)


Expand Down
6 changes: 6 additions & 0 deletions backend/openmlr/compute/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .capabilities import ComputeCapabilities, GPUInfo
from .manager import ComputeManager
from .probe import probe_sandbox
from .workspace import WorkspaceManager

__all__ = ["ComputeCapabilities", "ComputeManager", "GPUInfo", "probe_sandbox", "WorkspaceManager"]
85 changes: 85 additions & 0 deletions backend/openmlr/compute/capabilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Compute capability discovery and planning."""

from dataclasses import dataclass, field
from typing import Any


@dataclass
class GPUInfo:
    """Description of a single GPU on a compute node.

    Every field has a neutral default so an instance can be built from
    partial information.
    """

    model: str = ""  # GPU model name
    vram_gb: float = 0.0  # Video memory size, in GB
    cuda_version: str = ""  # CUDA version string ("" when unknown)
    driver_version: str = ""  # Driver version string ("" when unknown)


@dataclass
class ComputeCapabilities:
    """Comprehensive capabilities of a compute node.

    The instance is a plain record: it is filled in by whoever probes the
    node and can be round-tripped through ``to_dict`` / ``from_dict``.
    """

    # Platform
    platform: str = "unknown"
    cpu_cores: int = 0
    cpu_arch: str = "unknown"

    # Memory
    total_ram_gb: float = 0.0
    available_ram_gb: float = 0.0

    # Storage
    total_disk_gb: float = 0.0
    available_disk_gb: float = 0.0

    # GPU
    gpu_available: bool = False
    gpu_count: int = 0
    gpu_info: list[GPUInfo] = field(default_factory=list)

    # Software
    python_versions: list[str] = field(default_factory=list)
    docker_available: bool = False
    conda_envs: list[str] = field(default_factory=list)
    installed_packages: list[str] = field(default_factory=list)

    # Network
    has_internet: bool = True
    latency_ms: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict.

        Returns:
            A dict with one entry per field, in field order. ``gpu_info`` is
            a list of dicts. List values are copies, so mutating the result
            does not affect this instance.
        """
        return {
            "platform": self.platform,
            "cpu_cores": self.cpu_cores,
            "cpu_arch": self.cpu_arch,
            "total_ram_gb": self.total_ram_gb,
            "available_ram_gb": self.available_ram_gb,
            "total_disk_gb": self.total_disk_gb,
            "available_disk_gb": self.available_disk_gb,
            "gpu_available": self.gpu_available,
            "gpu_count": self.gpu_count,
            "gpu_info": [
                {
                    "model": g.model,
                    "vram_gb": g.vram_gb,
                    "cuda_version": g.cuda_version,
                    "driver_version": g.driver_version,
                }
                for g in self.gpu_info
            ],
            "python_versions": list(self.python_versions),
            "docker_available": self.docker_available,
            "conda_envs": list(self.conda_envs),
            "installed_packages": list(self.installed_packages),
            "has_internet": self.has_internet,
            "latency_ms": self.latency_ms,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ComputeCapabilities":
        """Deserialize from a dict such as the one produced by ``to_dict``.

        Args:
            data: Mapping of field names to values. Missing fields keep their
                defaults. An empty or ``None`` mapping yields a
                default-constructed instance.

        Returns:
            A new ``ComputeCapabilities``. Unknown keys are ignored, both at
            the top level and inside each ``gpu_info`` entry.
        """
        # Function-scope import: only this method needs ``fields``.
        from dataclasses import fields

        # Restrict assignment to declared fields. A plain ``hasattr`` check
        # would also accept method names such as "to_dict" and overwrite them.
        known = {f.name for f in fields(cls)}
        gpu_keys = {f.name for f in fields(GPUInfo)}

        caps = cls()
        for key, value in (data or {}).items():
            if key not in known:
                continue
            if key == "gpu_info":
                # ``None``/empty becomes an empty list; extra per-GPU keys are
                # dropped so the tolerance matches the top level.
                caps.gpu_info = [
                    g if isinstance(g, GPUInfo)
                    else GPUInfo(**{k: v for k, v in g.items() if k in gpu_keys})
                    for g in (value or [])
                ]
            elif isinstance(value, list):
                # Copy so the instance does not alias the caller's list.
                setattr(caps, key, list(value))
            else:
                setattr(caps, key, value)
        return caps
45 changes: 45 additions & 0 deletions backend/openmlr/compute/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Compute Node Manager — registry, validation, and lifecycle."""

from pathlib import Path


class ComputeManager:
    """High-level operations for compute node management.

    The methods here validate node configurations. They only inspect the
    given config (and, for SSH, ask the key manager whether a key exists);
    nothing is created or modified.
    """

    def __init__(self, key_manager):
        """Store the key manager used to look up SSH keys.

        Args:
            key_manager: Object exposing ``key_exists(filename)``. It is only
                consulted when an SSH config names a ``key_filename``.
        """
        self.key_manager = key_manager

    def validate_node_config(self, node_type: str, config: dict) -> tuple[bool, str]:
        """Validate a compute node configuration. Pure check, no side effects.

        Args:
            node_type: One of ``"ssh"``, ``"local"`` or ``"modal"``.
            config: Node configuration to check.

        Returns:
            ``(True, "")`` when the config is acceptable, otherwise
            ``(False, reason)`` with a human-readable reason.
        """
        if node_type == "ssh":
            return self._validate_ssh_config(config)
        if node_type == "local":
            return self._validate_local_config(config)
        if node_type == "modal":
            return self._validate_modal_config(config)
        return False, f"Unknown node type: {node_type}"

    def _validate_ssh_config(self, config: dict) -> tuple[bool, str]:
        """Check that an SSH config has a host, a username and a usable key."""
        for required_key in ("host", "username"):
            value = config.get(required_key)
            # A whitespace-only string is as unusable as a missing value.
            if not value or (isinstance(value, str) and not value.strip()):
                return False, f"SSH config requires '{required_key}'"

        # The key is optional, but when one is named it has to exist.
        key_filename = config.get("key_filename")
        if key_filename and not self.key_manager.key_exists(key_filename):
            return False, f"SSH key not found: {key_filename}"

        return True, ""

    @staticmethod
    def _validate_local_config(config: dict) -> tuple[bool, str]:
        """Check that an optional ``workdir`` is not an existing non-directory.

        A missing or empty ``workdir`` is accepted, and so is a path that
        does not exist yet.
        """
        workdir = config.get("workdir", "")
        if not workdir:
            return True, ""

        try:
            path = Path(workdir).expanduser()
            # Only validate — don't create directories as a side effect
            blocked = path.exists() and not path.is_dir()
        except (TypeError, ValueError, RuntimeError, OSError) as exc:
            # expanduser() raises RuntimeError when "~user" cannot be
            # resolved, and Path() raises TypeError for non-path values.
            # Report these as validation failures instead of crashing.
            return False, f"Invalid workdir {workdir!r}: {exc}"

        if blocked:
            return False, f"Path exists but is not a directory: {path}"
        return True, ""

    @staticmethod
    def _validate_modal_config(config: dict) -> tuple[bool, str]:
        """Accept any Modal config; no checks are performed on it here."""
        return True, ""
170 changes: 170 additions & 0 deletions backend/openmlr/compute/probe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Environment probing for all sandbox types."""

import time

from .capabilities import ComputeCapabilities, GPUInfo


async def probe_sandbox(sandbox) -> ComputeCapabilities:
    """Deep capability discovery for any sandbox implementation.

    Runs a fixed sequence of shell commands through ``sandbox.execute`` and
    parses their output. Each probe is best-effort: when a command fails or
    its output cannot be parsed, the corresponding field keeps its default.

    Args:
        sandbox: Object with an awaitable ``execute(command, timeout=...)``
            whose result exposes ``success`` and ``output``.

    Returns:
        A populated ``ComputeCapabilities``. ``latency_ms`` holds the
        wall-clock duration of the whole probe run, in milliseconds.
    """
    caps = ComputeCapabilities()
    start = time.monotonic()

    # Platform
    result = await sandbox.execute("uname -s -r 2>/dev/null || echo 'unknown'", timeout=5)
    if result.success:
        caps.platform = result.output.strip()

    # CPU cores and architecture
    result = await sandbox.execute("nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo '0'", timeout=5)
    if result.success:
        try:
            caps.cpu_cores = int(result.output.strip())
        except ValueError:
            pass  # Unparseable output: keep the default

    result = await sandbox.execute("uname -m 2>/dev/null || echo 'unknown'", timeout=5)
    if result.success:
        caps.cpu_arch = result.output.strip()

    # RAM (Linux): columns 2 and 7 of the "Mem" row are total and available
    result = await sandbox.execute(
        "free -g 2>/dev/null | grep Mem | awk '{print $2, $7}' || "
        "echo '0 0'",
        timeout=5,
    )
    if result.success:
        parts = result.output.strip().split()
        if len(parts) >= 2:
            try:
                caps.total_ram_gb = float(parts[0])
                caps.available_ram_gb = float(parts[1])
            except ValueError:
                pass

    # Disk: size and available space of the root filesystem, e.g. "100G 42G"
    result = await sandbox.execute(
        "df -BG / 2>/dev/null | tail -1 | awk '{print $2, $4}' || echo '0 0'",
        timeout=5,
    )
    if result.success:
        parts = result.output.strip().split()
        if len(parts) >= 2:
            try:
                caps.total_disk_gb = float(parts[0].replace("G", ""))
                caps.available_disk_gb = float(parts[1].replace("G", ""))
            except ValueError:
                pass

    # GPU — query model, memory, driver; then get CUDA version separately
    result = await sandbox.execute(
        "nvidia-smi --query-gpu=name,memory.total,driver_version "
        "--format=csv,noheader 2>/dev/null || echo ''",
        timeout=10,
    )
    if result.success and result.output.strip():
        # Keep only rows shaped like "name, memory, driver". Any other line
        # on stdout (for example a diagnostic message) must not be counted
        # as a GPU, otherwise gpu_count would disagree with gpu_info.
        gpu_rows = []
        for raw_line in result.output.strip().split("\n"):
            cells = [cell.strip() for cell in raw_line.split(",")]
            if len(cells) >= 3:
                gpu_rows.append(cells)

        if gpu_rows:
            # Get CUDA toolkit version (shared by every GPU on the node)
            cuda_ver = ""
            cuda_result = await sandbox.execute(
                "nvidia-smi 2>/dev/null | grep 'CUDA Version' | awk '{print $9}'",
                timeout=5,
            )
            if cuda_result.success and cuda_result.output.strip():
                cuda_ver = cuda_result.output.strip()

            caps.gpu_info = [
                GPUInfo(
                    model=cells[0],
                    vram_gb=_parse_vram(cells[1]),
                    cuda_version=cuda_ver,
                    driver_version=cells[2],
                )
                for cells in gpu_rows
            ]
            caps.gpu_count = len(caps.gpu_info)
            caps.gpu_available = True

    # Python versions
    result = await sandbox.execute(
        "python3 --version 2>/dev/null; ls /usr/bin/python* 2>/dev/null || true",
        timeout=5,
    )
    if result.success:
        versions = []
        for line in result.output.strip().split("\n"):
            line = line.strip()
            if line.startswith("Python "):
                versions.append(line.replace("Python ", ""))
            elif "/python" in line and not line.endswith("*"):
                # Extract version from path like /usr/bin/python3.11
                ver = line.split("/")[-1].replace("python", "")
                # Accept only dotted numbers; this skips helper binaries
                # such as python3-config whose suffix is not a version.
                is_version = all(part.isdigit() for part in ver.split("."))
                if is_version and ver not in versions:
                    versions.append(ver)
        caps.python_versions = versions

    # Docker
    result = await sandbox.execute(
        "docker info >/dev/null 2>&1 && echo 'DOCKER_OK' || echo 'DOCKER_FAIL'",
        timeout=5,
    )
    if result.success and "DOCKER_OK" in result.output:
        caps.docker_available = True

    # Conda envs
    result = await sandbox.execute(
        "conda env list 2>/dev/null | grep -v '^#' | awk '{print $1}' || true",
        timeout=5,
    )
    if result.success:
        envs = [ln.strip() for ln in result.output.strip().split("\n") if ln.strip()]
        caps.conda_envs = envs

    # Key packages (only the first 50 lines of the freeze listing are read)
    result = await sandbox.execute(
        "pip list --format=freeze 2>/dev/null | head -50 || true",
        timeout=10,
    )
    if result.success:
        caps.installed_packages = [
            line.strip() for line in result.output.strip().split("\n")
            if line.strip() and "==" in line
        ]

    # Internet connectivity
    result = await sandbox.execute(
        "curl -s -o /dev/null -w '%{http_code}' --max-time 5 https://pypi.org/simple/ 2>/dev/null || echo '000'",
        timeout=10,
    )
    if result.success and result.output.strip() == "200":
        caps.has_internet = True
    else:
        # Fallback ping
        result = await sandbox.execute(
            "ping -c 1 -W 3 8.8.8.8 2>/dev/null || true",
            timeout=10,
        )
        caps.has_internet = result.success and "1 received" in result.output

    caps.latency_ms = (time.monotonic() - start) * 1000
    return caps


def _parse_vram(vram_str: str) -> float:
"""Parse VRAM string like '24576 MiB' or '24 GB' to GB."""
vram_str = vram_str.strip().lower()
try:
if "mib" in vram_str:
return float(vram_str.replace("mib", "").strip()) / 1024
elif "gib" in vram_str:
return float(vram_str.replace("gib", "").strip())
elif "gb" in vram_str:
return float(vram_str.replace("gb", "").strip())
elif "mb" in vram_str:
return float(vram_str.replace("mb", "").strip()) / 1024
else:
return float(vram_str)
except ValueError:
return 0.0
Loading
Loading