diff --git a/backend/app/api/agent.py b/backend/app/api/agent.py index 98bbe34..7f46ce0 100644 --- a/backend/app/api/agent.py +++ b/backend/app/api/agent.py @@ -169,12 +169,15 @@ async def get_or_create_agent( llm_api_key = llm_config.api_key llm_model = llm_config.model + # Track if tool calling is supported + supports_tool_calling = True + if llm_config.provider == "system" and llm_config.deployment_id: - # Look up the deployment + # Look up the deployment with model info result = await db.execute( select(Deployment) .where(Deployment.id == llm_config.deployment_id) - .options(selectinload(Deployment.worker)) + .options(selectinload(Deployment.worker), selectinload(Deployment.model)) ) deployment = result.scalar_one_or_none() @@ -184,9 +187,16 @@ async def get_or_create_agent( raise HTTPException(status_code=400, detail="Deployment is not running") worker = deployment.worker - llm_base_url = f"http://{worker.host}:{deployment.port}/v1" + # Extract IP from worker address (format: IP:Port) + worker_ip = worker.address.split(":")[0] + llm_base_url = f"http://{worker_ip}:{deployment.port}/v1" llm_api_key = "dummy" - llm_model = llm_config.model or "default" + # Use the actual model_id from the LLMModel (e.g., "Qwen/Qwen2.5-0.5B-Instruct") + llm_model = deployment.model.model_id + + # Check if deployment has tool calling enabled + extra_params = deployment.extra_params or {} + supports_tool_calling = extra_params.get("enable-auto-tool-choice", False) elif llm_config.provider == "openai": llm_base_url = "https://api.openai.com/v1" @@ -216,6 +226,7 @@ async def get_or_create_agent( llm_model=llm_model, mcp_api_url=mcp_api_url, mcp_api_token=api_token, + supports_tool_calling=supports_tool_calling, ) await agent.initialize() @@ -443,7 +454,7 @@ async def agent_chat_simple( result = await db.execute( select(Deployment) .where(Deployment.id == request.llm_config.deployment_id) - .options(selectinload(Deployment.worker)) + .options(selectinload(Deployment.worker), selectinload(Deployment.model)) ) deployment = result.scalar_one_or_none() @@ -453,9 +464,12 @@ async def agent_chat_simple( raise HTTPException(status_code=400, detail="Deployment is not running") worker = deployment.worker - llm_base_url = f"http://{worker.host}:{deployment.port}/v1" + # Extract IP from worker address (format: IP:Port) + worker_ip = worker.address.split(":")[0] + llm_base_url = f"http://{worker_ip}:{deployment.port}/v1" llm_api_key = "dummy" - llm_model = request.llm_config.model or "default" + # Use the actual model_id from the LLMModel (e.g., "Qwen/Qwen2.5-0.5B-Instruct") + llm_model = deployment.model.model_id elif request.llm_config.provider == "openai": llm_base_url = "https://api.openai.com/v1" diff --git a/backend/app/api/auto_tuning.py b/backend/app/api/auto_tuning.py index f6b2f99..2b2285d 100644 --- a/backend/app/api/auto_tuning.py +++ b/backend/app/api/auto_tuning.py @@ -34,10 +34,17 @@ BenchmarkRequest, BenchmarkResultListResponse, BenchmarkResultResponse, + ComprehensiveBenchmarkMetrics, + ComprehensiveBenchmarkRequest, + ComprehensiveBenchmarkResponse, + ConcurrencyLevelResult, KnowledgeQuery, KnowledgeQueryResponse, KnowledgeRecord, KnowledgeSaveRequest, + LatencyPercentiles, + SaturationDetectionRequest, + SaturationDetectionResponse, TuningJobCreate, TuningJobListResponse, TuningJobProgress, @@ -397,6 +404,198 @@ async def list_benchmark_results( ) +# ============================================================================ +# Comprehensive Benchmark Endpoints +# ============================================================================ + + +@router.post("/benchmarks/comprehensive", response_model=ComprehensiveBenchmarkResponse) +async def run_comprehensive_benchmark( + request: ComprehensiveBenchmarkRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_operator), +): + """ + Run a comprehensive benchmark with detailed metrics. + + Returns metrics including: + - TTFT, ITL, TPOT with percentiles (p50, p90, p95, p99) + - Throughput (tokens/sec, requests/sec) + - Success rate and error statistics + """ + from app.services.benchmark import BenchmarkConfig, BenchmarkRunner, LoadPattern + + # Verify deployment exists and is running + result = await db.execute( + select(Deployment) + .where(Deployment.id == request.deployment_id) + .options( + selectinload(Deployment.model), + selectinload(Deployment.worker), + ) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + raise HTTPException(status_code=404, detail="Deployment not found") + + if deployment.status != DeploymentStatus.RUNNING.value: + raise HTTPException(status_code=400, detail="Deployment is not running") + + worker = deployment.worker + if not worker: + raise HTTPException(status_code=400, detail="Worker not found") + + endpoint = f"http://{worker.host}:{deployment.port}/v1" + model_name = deployment.model.model_id if deployment.model else "default" + + # Configure benchmark + config = BenchmarkConfig( + endpoint=endpoint, + model_name=model_name, + load_pattern=LoadPattern.FIXED, + concurrency=request.concurrency, + num_requests=request.num_requests, + warmup_requests=request.warmup_requests, + prompt_tokens=request.prompt_tokens, + output_tokens=request.output_tokens, + custom_prompt=request.custom_prompt, + stream=True, + ) + + # Run benchmark + runner = BenchmarkRunner(config) + bench_result = await runner.run() + + # Convert metrics to response format + m = bench_result.metrics + + def to_percentiles(lm) -> LatencyPercentiles: + return LatencyPercentiles( + mean=lm.mean, + median=lm.median, + min=lm.min, + max=lm.max, + std=lm.std, + p50=lm.p50, + p90=lm.p90, + p95=lm.p95, + p99=lm.p99, + ) + + metrics = ComprehensiveBenchmarkMetrics( + ttft=to_percentiles(m.ttft), + itl=to_percentiles(m.itl), + tpot=to_percentiles(m.tpot), + e2e_latency=to_percentiles(m.e2e_latency), + throughput_tps=m.throughput_tps, + throughput_rps=m.throughput_rps, + output_tps=m.output_tps, + total_requests=m.total_requests, + successful_requests=m.successful_requests, + failed_requests=m.failed_requests, + success_rate=m.success_rate, + total_prompt_tokens=m.total_prompt_tokens, + total_completion_tokens=m.total_completion_tokens, + avg_prompt_tokens=m.avg_prompt_tokens, + avg_completion_tokens=m.avg_completion_tokens, + total_duration_seconds=m.total_duration_seconds, + concurrency=m.concurrency, + ) + + return ComprehensiveBenchmarkResponse( + metrics=metrics, + config={ + "endpoint": endpoint, + "model_name": model_name, + "concurrency": request.concurrency, + "num_requests": request.num_requests, + }, + error=bench_result.error, + started_at=bench_result.started_at, + completed_at=bench_result.completed_at, + duration_seconds=bench_result.duration_seconds, + ) + + +@router.post("/benchmarks/saturation", response_model=SaturationDetectionResponse) +async def run_saturation_detection( + request: SaturationDetectionRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_operator), +): + """ + Run saturation detection to find optimal concurrency. + + Automatically increases concurrency and detects: + - Throughput plateau (where adding more concurrency doesn't help) + - Latency degradation (where latency increases significantly) + - Error rate increases + + Returns the optimal concurrency level for the deployment. + """ + from app.services.benchmark import SaturationConfig, SaturationDetector + + # Verify deployment exists and is running + result = await db.execute( + select(Deployment) + .where(Deployment.id == request.deployment_id) + .options( + selectinload(Deployment.model), + selectinload(Deployment.worker), + ) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + raise HTTPException(status_code=404, detail="Deployment not found") + + if deployment.status != DeploymentStatus.RUNNING.value: + raise HTTPException(status_code=400, detail="Deployment is not running") + + worker = deployment.worker + if not worker: + raise HTTPException(status_code=400, detail="Worker not found") + + endpoint = f"http://{worker.host}:{deployment.port}/v1" + model_name = deployment.model.model_id if deployment.model else "default" + + # Configure saturation detection + config = SaturationConfig( + enabled=True, + start_concurrency=request.start_concurrency, + max_concurrency=request.max_concurrency, + requests_per_level=request.requests_per_level, + use_exponential=request.use_exponential, + step_size=request.step_size, + step_multiplier=request.step_multiplier, + ) + + # Run saturation detection + detector = SaturationDetector(endpoint, model_name, config) + sat_result = await detector.detect() + + # Convert to response + return SaturationDetectionResponse( + optimal_concurrency=sat_result.optimal_concurrency, + max_throughput_tps=sat_result.max_throughput_tps, + latency_at_optimal_ms=sat_result.latency_at_optimal_ms, + saturation_concurrency=sat_result.saturation_concurrency, + saturation_detected=sat_result.saturation_detected, + stop_reason=sat_result.stop_reason, + concurrency_results=[ + ConcurrencyLevelResult( + concurrency=r.concurrency, + throughput_tps=r.throughput_tps, + avg_latency_ms=r.avg_latency_ms, + p95_latency_ms=r.p95_latency_ms, + success_rate=r.success_rate, + ) + for r in sat_result.results_by_concurrency + ], + ) + + # ============================================================================ # Knowledge Base Endpoints # ============================================================================ @@ -558,7 +757,7 @@ async def agent_chat( current_user: User = Depends(require_viewer), ): """Chat with the Auto-Tuning Agent""" - from app.services.tuning_agent import AGENT_SYSTEM_PROMPT, AgentToolExecutor, get_agent_tools + from app.services.tuning import AGENT_SYSTEM_PROMPT, AgentToolExecutor, get_agent_tools config = request.config provider = config.get("provider", "system") @@ -706,34 +905,44 @@ async def run_auto_tuning(job_id: int, llm_config: dict | None = None): async def _run_benchmark_test(deployment: Deployment, request: BenchmarkRequest) -> dict: - """Run actual benchmark test on a deployment using HTTP requests""" - from app.services.tuning_agent import _run_http_benchmark + """Run actual benchmark test on a deployment using the benchmark module""" + from app.services.benchmark import BenchmarkConfig, BenchmarkRunner, LoadPattern # Get worker info worker = deployment.worker if not worker: return {"error": "Worker not found"} - base_url = f"http://{worker.host}:{deployment.port}/v1" + endpoint = f"http://{worker.host}:{deployment.port}/v1" + model_name = deployment.model.model_id if deployment.model else "default" - result = await _run_http_benchmark( - base_url=base_url, - num_requests=max(10, request.concurrency * 5), + # Configure benchmark + config = BenchmarkConfig( + endpoint=endpoint, + model_name=model_name, + load_pattern=LoadPattern.FIXED, concurrency=request.concurrency, - input_tokens=request.input_length, + num_requests=max(20, request.concurrency * 5), + warmup_requests=5, + prompt_tokens=request.input_length, output_tokens=request.output_length, + stream=True, ) - if not result.get("success"): - return {"error": result.get("error", "Benchmark failed")} + # Run benchmark + runner = BenchmarkRunner(config) + result = await runner.run() + + if result.error: + return {"error": result.error} - metrics = result.get("metrics", {}) + metrics = result.metrics return { - "throughput_tps": metrics.get("throughput_tps"), - "ttft_ms": metrics.get("avg_ttft_ms"), - "tpot_ms": metrics.get("avg_tpot_ms"), - "total_latency_ms": None, # Not directly measured + "throughput_tps": metrics.output_tps, + "ttft_ms": metrics.ttft.mean, + "tpot_ms": metrics.tpot.mean, + "total_latency_ms": metrics.e2e_latency.mean, "gpu_utilization": None, # Would need GPU monitoring "vram_usage_gb": None, # Would need GPU monitoring - "raw": result.get("summary"), + "raw": metrics.to_dict(), } diff --git a/backend/app/api/workers.py b/backend/app/api/workers.py index 398fd20..ac448c4 100644 --- a/backend/app/api/workers.py +++ b/backend/app/api/workers.py @@ -71,6 +71,14 @@ async def list_workers( "status": worker.status, "gpu_info": worker.gpu_info, "system_info": worker.system_info, + "os_type": worker.os_type, + "gpu_type": worker.gpu_type, + "capabilities": worker.capabilities, + "available_backends": worker.available_backends, + "connection_type": worker.connection_type, + "tailscale_ip": worker.tailscale_ip, + "headscale_node_id": worker.headscale_node_id, + "effective_address": worker.effective_address, "created_at": worker.created_at, "updated_at": worker.updated_at, "last_heartbeat": worker.last_heartbeat, @@ -149,6 +157,14 @@ async def create_worker( original_worker.system_info = ( worker_in.system_info.model_dump() if worker_in.system_info else None ) + # Update os_type, gpu_type, capabilities from system_info + if worker_in.system_info: + if worker_in.system_info.os_type: + original_worker.os_type = worker_in.system_info.os_type + if worker_in.system_info.gpu_type: + original_worker.gpu_type = worker_in.system_info.gpu_type + if worker_in.system_info.capabilities: + original_worker.capabilities = worker_in.system_info.capabilities.model_dump() original_worker.status = WorkerStatus.ONLINE.value original_worker.last_heartbeat = datetime.now(UTC) @@ -173,6 +189,14 @@ async def create_worker( status=original_worker.status, gpu_info=original_worker.gpu_info, system_info=original_worker.system_info, + os_type=original_worker.os_type, + gpu_type=original_worker.gpu_type, + capabilities=original_worker.capabilities, + available_backends=original_worker.available_backends, + connection_type=original_worker.connection_type, + tailscale_ip=original_worker.tailscale_ip, + headscale_node_id=original_worker.headscale_node_id, + effective_address=original_worker.effective_address, created_at=original_worker.created_at, updated_at=original_worker.updated_at, last_heartbeat=original_worker.last_heartbeat, @@ -206,6 +230,18 @@ async def create_worker( if token.is_local: worker_labels["type"] = "local" + # Extract os_type, gpu_type, capabilities from system_info + os_type = "linux" + gpu_type = "nvidia" + capabilities = None + if worker_in.system_info: + if worker_in.system_info.os_type: + os_type = worker_in.system_info.os_type + if worker_in.system_info.gpu_type: + gpu_type = worker_in.system_info.gpu_type + if worker_in.system_info.capabilities: + capabilities = worker_in.system_info.capabilities.model_dump() + worker = Worker( name=worker_in.name, address=real_address, @@ -213,6 +249,9 @@ async def create_worker( labels=worker_labels if worker_labels else None, gpu_info=([gpu.model_dump() for gpu in worker_in.gpu_info] if worker_in.gpu_info else None), system_info=(worker_in.system_info.model_dump() if worker_in.system_info else None), + os_type=os_type, + gpu_type=gpu_type, + capabilities=capabilities, status=WorkerStatus.ONLINE.value, last_heartbeat=datetime.now(UTC), ) @@ -236,6 +275,14 @@ async def create_worker( status=worker.status, gpu_info=worker.gpu_info, system_info=worker.system_info, + os_type=worker.os_type, + gpu_type=worker.gpu_type, + capabilities=worker.capabilities, + available_backends=worker.available_backends, + connection_type=worker.connection_type, + tailscale_ip=worker.tailscale_ip, + headscale_node_id=worker.headscale_node_id, + effective_address=worker.effective_address, created_at=worker.created_at, updated_at=worker.updated_at, last_heartbeat=worker.last_heartbeat, @@ -268,6 +315,14 @@ async def get_worker( status=worker.status, gpu_info=worker.gpu_info, system_info=worker.system_info, + os_type=worker.os_type, + gpu_type=worker.gpu_type, + capabilities=worker.capabilities, + available_backends=worker.available_backends, + connection_type=worker.connection_type, + tailscale_ip=worker.tailscale_ip, + headscale_node_id=worker.headscale_node_id, + effective_address=worker.effective_address, created_at=worker.created_at, updated_at=worker.updated_at, last_heartbeat=worker.last_heartbeat, @@ -318,6 +373,14 @@ async def update_worker( status=worker.status, gpu_info=worker.gpu_info, system_info=worker.system_info, + os_type=worker.os_type, + gpu_type=worker.gpu_type, + capabilities=worker.capabilities, + available_backends=worker.available_backends, + connection_type=worker.connection_type, + tailscale_ip=worker.tailscale_ip, + headscale_node_id=worker.headscale_node_id, + effective_address=worker.effective_address, created_at=worker.created_at, updated_at=worker.updated_at, last_heartbeat=worker.last_heartbeat, @@ -377,7 +440,15 @@ async def worker_heartbeat( worker.gpu_info = [gpu.model_dump() for gpu in heartbeat.gpu_info] if heartbeat.system_info: - worker.system_info = heartbeat.system_info.model_dump() + system_data = heartbeat.system_info.model_dump() + worker.system_info = system_data + # Extract os_type, gpu_type, capabilities from system_info + if heartbeat.system_info.os_type: + worker.os_type = heartbeat.system_info.os_type + if heartbeat.system_info.gpu_type: + worker.gpu_type = heartbeat.system_info.gpu_type + if heartbeat.system_info.capabilities: + worker.capabilities = heartbeat.system_info.capabilities.model_dump() # Check if worker is going offline is_going_offline = heartbeat.status == WorkerStatus.OFFLINE @@ -731,21 +802,44 @@ async def _refresh_worker_resources(worker_id: int): continue try: async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get( - f"http://{worker.address}/containers/{deployment.container_id}" - ) - if response.status_code == 200: - container_info = response.json() - state = container_info.get("state", "").lower() - if state == "running": - deployment.status = DeploymentStatus.RUNNING.value - deployment.status_message = "Model ready" - elif state in ("exited", "dead"): - deployment.status = DeploymentStatus.STOPPED.value - deployment.status_message = f"Container {state}" - elif response.status_code == 404: - deployment.status = DeploymentStatus.ERROR.value - deployment.status_message = "Container not found" + # Check if this is a native deployment + if deployment.container_id.startswith("native-"): + # For native deployments, check via /native/processes + response = await client.get(f"http://{worker.address}/native/processes") + if response.status_code == 200: + processes = response.json().get("processes", []) + found = False + for p in processes: + if p.get("process_id") == deployment.container_id: + found = True + if p.get("running"): + deployment.status = DeploymentStatus.RUNNING.value + deployment.status_message = "Model ready" + else: + deployment.status = DeploymentStatus.STOPPED.value + deployment.status_message = "Process stopped" + break + if not found: + # Process not in manager, but might still be running via Ollama + # Don't mark as error, just skip + pass + else: + # Docker container check + response = await client.get( + f"http://{worker.address}/containers/{deployment.container_id}" + ) + if response.status_code == 200: + container_info = response.json() + state = container_info.get("state", "").lower() + if state == "running": + deployment.status = DeploymentStatus.RUNNING.value + deployment.status_message = "Model ready" + elif state in ("exited", "dead"): + deployment.status = DeploymentStatus.STOPPED.value + deployment.status_message = f"Container {state}" + elif response.status_code == 404: + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = "Container not found" except Exception as e: logger.warning(f"Failed to check deployment {deployment.id}: {e}") diff --git a/backend/app/models/llm_model.py b/backend/app/models/llm_model.py index ebca271..f34d6c0 100644 --- a/backend/app/models/llm_model.py +++ b/backend/app/models/llm_model.py @@ -12,10 +12,15 @@ class BackendType(str, Enum): """Inference backend type""" + # Docker-based backends (Linux with NVIDIA GPU) VLLM = "vllm" SGLANG = "sglang" OLLAMA = "ollama" + # Native backends (macOS with Apple Silicon) + MLX = "mlx" # MLX-LM for Apple Silicon + LLAMA_CPP = "llama_cpp" # llama.cpp with Metal + class ModelSource(str, Enum): """Model source type""" diff --git a/backend/app/models/worker.py b/backend/app/models/worker.py index 83fe603..3be352d 100644 --- a/backend/app/models/worker.py +++ b/backend/app/models/worker.py @@ -24,6 +24,23 @@ class ConnectionType(str, Enum): TAILSCALE = "tailscale" # Via Tailscale/Headscale VPN +class OSType(str, Enum): + """Worker operating system type""" + + LINUX = "linux" + DARWIN = "darwin" # macOS + WINDOWS = "windows" + + +class GPUType(str, Enum): + """GPU type for inference""" + + NVIDIA = "nvidia" + APPLE_SILICON = "apple_silicon" # M1/M2/M3/M4 + AMD = "amd" + NONE = "none" + + class Worker(Base): """Worker node model - represents a GPU node that can run LLM inference""" @@ -43,6 +60,13 @@ class Worker(Base): Integer, nullable=True ) # Node ID in Headscale + # System detection + os_type: Mapped[str] = mapped_column(String(50), default=OSType.LINUX.value) + gpu_type: Mapped[str] = mapped_column(String(50), default=GPUType.NVIDIA.value) + capabilities: Mapped[dict | None] = mapped_column( + JSON, nullable=True + ) # Available backends/tools + gpu_info: Mapped[dict | None] = mapped_column(JSON, nullable=True) system_info: Mapped[dict | None] = mapped_column(JSON, nullable=True) labels: Mapped[dict | None] = mapped_column(JSON, nullable=True) @@ -72,5 +96,40 @@ def effective_address(self) -> str: return f"{self.tailscale_ip}:{port}" return self.address + @property + def supports_docker(self) -> bool: + """Check if this worker supports Docker deployments.""" + # macOS without Docker support uses native process management + caps = self.capabilities or {} + return caps.get("docker", self.os_type == OSType.LINUX.value) + + @property + def is_mac(self) -> bool: + """Check if this is a macOS worker.""" + return self.os_type == OSType.DARWIN.value + + @property + def available_backends(self) -> list[str]: + """Get list of available backends for this worker.""" + caps = self.capabilities or {} + backends = [] + + if self.supports_docker: + # Docker-based backends + if self.gpu_type == GPUType.NVIDIA.value: + backends.extend(["vllm", "sglang", "ollama"]) + else: + backends.append("ollama") + else: + # Native backends (Mac) + if caps.get("ollama"): + backends.append("ollama") + if caps.get("mlx"): + backends.append("mlx") + if caps.get("llama_cpp"): + backends.append("llama_cpp") + + return backends + def __repr__(self) -> str: - return f"" + return f"" diff --git a/backend/app/schemas/tuning.py b/backend/app/schemas/tuning.py index 616f872..8a21a60 100644 --- a/backend/app/schemas/tuning.py +++ b/backend/app/schemas/tuning.py @@ -270,3 +270,116 @@ class KnowledgeSaveRequest(BaseModel): benchmark_result_id: int = Field(..., description="ID of the benchmark result to save") model_family: str = Field(..., description="Model family for categorization") model_params_b: float | None = Field(default=None, description="Model parameters in billions") + + +# ============================================================================ +# Comprehensive Benchmark Schemas +# ============================================================================ + + +class LatencyPercentiles(BaseModel): + """Latency metrics with percentiles""" + + mean: float = 0.0 + median: float = 0.0 + min: float = 0.0 + max: float = 0.0 + std: float = 0.0 + p50: float = 0.0 + p90: float = 0.0 + p95: float = 0.0 + p99: float = 0.0 + + +class ComprehensiveBenchmarkMetrics(BaseModel): + """Comprehensive benchmark metrics with percentiles""" + + # Latency metrics with percentiles + ttft: LatencyPercentiles = Field(default_factory=LatencyPercentiles) + itl: LatencyPercentiles = Field(default_factory=LatencyPercentiles) + tpot: LatencyPercentiles = Field(default_factory=LatencyPercentiles) + e2e_latency: LatencyPercentiles = Field(default_factory=LatencyPercentiles) + + # Throughput metrics + throughput_tps: float = Field(0.0, description="Total tokens per second") + throughput_rps: float = Field(0.0, description="Requests per second") + output_tps: float = Field(0.0, description="Output tokens per second") + + # Request statistics + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + success_rate: float = 0.0 + + # Token statistics + total_prompt_tokens: int = 0 + total_completion_tokens: int = 0 + avg_prompt_tokens: float = 0.0 + avg_completion_tokens: float = 0.0 + + # Timing + total_duration_seconds: float = 0.0 + concurrency: int = 0 + + +class ComprehensiveBenchmarkRequest(BaseModel): + """Request for running a comprehensive benchmark""" + + deployment_id: int = Field(..., description="ID of the deployment to benchmark") + concurrency: int = Field(default=10, ge=1, le=128, description="Concurrent requests") + num_requests: int = Field(default=50, ge=10, le=1000, description="Total requests") + warmup_requests: int = Field(default=5, ge=0, le=20, description="Warmup requests") + prompt_tokens: int = Field(default=256, ge=32, le=8192, description="Approximate input tokens") + output_tokens: int = Field(default=128, ge=16, le=2048, description="Max output tokens") + custom_prompt: str | None = Field(default=None, description="Custom prompt to use") + + +class ComprehensiveBenchmarkResponse(BaseModel): + """Response for comprehensive benchmark""" + + metrics: ComprehensiveBenchmarkMetrics + config: dict + error: str | None = None + started_at: float + completed_at: float + duration_seconds: float + + +class SaturationDetectionRequest(BaseModel): + """Request for saturation detection""" + + deployment_id: int = Field(..., description="ID of the deployment to test") + start_concurrency: int = Field(default=1, ge=1, description="Starting concurrency") + max_concurrency: int = Field( + default=64, ge=1, le=256, description="Maximum concurrency to test" + ) + requests_per_level: int = Field(default=20, ge=10, le=100, description="Requests per level") + use_exponential: bool = Field(default=True, description="Use exponential stepping") + step_size: int = Field(default=2, ge=1, description="Linear step size") + step_multiplier: float = Field( + default=1.5, ge=1.1, le=3.0, description="Exponential multiplier" + ) + + +class ConcurrencyLevelResult(BaseModel): + """Result for a single concurrency level""" + + concurrency: int + throughput_tps: float + avg_latency_ms: float + p95_latency_ms: float + success_rate: float + + +class SaturationDetectionResponse(BaseModel): + """Response for saturation detection""" + + optimal_concurrency: int = Field(description="Recommended concurrency level") + max_throughput_tps: float = Field(description="Maximum throughput achieved") + latency_at_optimal_ms: float = Field(description="Latency at optimal concurrency") + saturation_concurrency: int = Field(description="Concurrency where saturation started") + saturation_detected: bool = Field(description="Whether saturation was detected") + stop_reason: str = Field(description="Reason for stopping") + concurrency_results: list[ConcurrencyLevelResult] = Field( + default_factory=list, description="Results for each concurrency level" + ) diff --git a/backend/app/schemas/worker.py b/backend/app/schemas/worker.py index e89b026..4c405e4 100644 --- a/backend/app/schemas/worker.py +++ b/backend/app/schemas/worker.py @@ -45,12 +45,27 @@ class DiskInfo(BaseModel): percent: float = 0 # percentage +class CapabilitiesInfo(BaseModel): + """Worker capabilities schema (available backends)""" + + os_type: str = "linux" # linux, darwin, windows + gpu_type: str = "nvidia" # nvidia, apple_silicon, amd, none + docker: bool = True # Docker available + ollama: bool = False # Ollama installed + ollama_running: bool = False # Ollama service running + mlx: bool = False # MLX-LM available (Mac only) + llama_cpp: bool = False # llama.cpp available + + class SystemInfo(BaseModel): """System information schema (CPU, Memory, Disk)""" cpu: CPUInfo | None = None memory: MemoryInfo | None = None disk: DiskInfo | None = None + os_type: str | None = None # linux, darwin, windows + gpu_type: str | None = None # nvidia, apple_silicon, amd, none + capabilities: CapabilitiesInfo | None = None # Available backends class WorkerBase(BaseModel): @@ -97,6 +112,10 @@ class WorkerResponse(WorkerBase): status: str gpu_info: list[dict] | None = None system_info: dict | None = None + os_type: str = "linux" # linux, darwin, windows + gpu_type: str = "nvidia" # nvidia, apple_silicon, amd, none + capabilities: dict | None = None # Available backends + available_backends: list[str] | None = None # List of available backend names tailscale_ip: str | None = None headscale_node_id: int | None = None effective_address: str | None = None # The actual address to use for connections diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py index 97b5f64..d13eb9e 100644 --- a/backend/app/services/__init__.py +++ b/backend/app/services/__init__.py @@ -4,6 +4,7 @@ from app.services.deployer import DeployerService from app.services.deployment_sync import DeploymentSyncService, deployment_sync_service from app.services.gateway import GatewayService, gateway_service +from app.services.tuning import run_tuning_agent __all__ = [ "DeployerService", @@ -13,4 +14,5 @@ "gateway_service", "AuthService", "auth_service", + "run_tuning_agent", ] diff --git a/backend/app/services/bayesian_tuner.py b/backend/app/services/bayesian_tuner.py index 1abe1d7..1ba34d7 100644 --- a/backend/app/services/bayesian_tuner.py +++ b/backend/app/services/bayesian_tuner.py @@ -13,7 +13,6 @@ """ import asyncio -import json import logging import time from dataclasses import dataclass, field @@ -21,7 +20,6 @@ from enum import Enum from typing import Any -import httpx import optuna from optuna.samplers import TPESampler from sqlalchemy import select @@ -32,6 +30,7 @@ from app.models.deployment import Deployment, DeploymentStatus from app.models.tuning import OptimizationTarget, PerformanceKnowledge, TuningJob, TuningJobStatus from app.models.worker import Worker +from app.services.benchmark import BenchmarkConfig, BenchmarkRunner, LoadPattern from app.services.deployer import DeployerService # Configure logging with detailed format @@ -251,6 +250,7 @@ class ObjectiveCalculator: Computes optimization objective from benchmark metrics. Implements Scorer phase: score valid configurations by performance. + Supports comprehensive metrics including percentiles. """ def __init__(self, target: OptimizationTarget): @@ -262,27 +262,42 @@ def compute(self, metrics: dict[str, float]) -> float: For Optuna, we negate values when minimizing so all objectives are treated as maximization internally. + + Uses p95 latencies for SLA-focused optimization. """ throughput = metrics.get("throughput_tps", 0.0) - ttft = metrics.get("avg_ttft_ms", float("inf")) - tpot = metrics.get("avg_tpot_ms", float("inf")) + + # Use mean for basic calculation, p95 for SLA + ttft_mean = metrics.get("avg_ttft_ms", float("inf")) + tpot_mean = metrics.get("avg_tpot_ms", float("inf")) + ttft_p95 = metrics.get("ttft_p95_ms", ttft_mean) + tpot_p95 = metrics.get("tpot_p95_ms", tpot_mean) + itl_p95 = metrics.get("itl_p95_ms", 0.0) if self.target == OptimizationTarget.THROUGHPUT: # Maximize throughput return throughput elif self.target == OptimizationTarget.LATENCY: - # Minimize latency (negate for maximization) - if ttft == float("inf"): + # Minimize latency using p95 values for SLA compliance + if ttft_p95 == float("inf") or ttft_p95 == 0: return float("-inf") - return -1.0 * (ttft + tpot * 10) # Weight TPOT more + # Combined latency score: TTFT + (ITL * typical_output_tokens) + # Assume ~50 output tokens for a typical response + estimated_generation_latency = itl_p95 * 50 if itl_p95 > 0 else tpot_p95 * 50 + total_latency = ttft_p95 + estimated_generation_latency + return -1.0 * total_latency elif self.target == OptimizationTarget.BALANCED: - # Combined score: throughput / latency - if ttft == 0 or throughput == 0: + # Combined score: throughput / latency with p95 consideration + if ttft_mean == 0 or throughput == 0: return float("-inf") - latency_factor = 1 + (ttft / 100) + (tpot / 10) - return throughput / latency_factor + + # Penalize high p95 latencies more than mean + latency_factor = 1 + (ttft_mean / 100) + (tpot_mean / 10) + p95_penalty = 1 + (ttft_p95 - ttft_mean) / 200 + (tpot_p95 - tpot_mean) / 20 + + return throughput / (latency_factor * p95_penalty) elif self.target == OptimizationTarget.COST: # Maximize efficiency (throughput per resource unit) @@ -763,10 +778,14 @@ async def _wait_for_deployment(self, deployment_id: int, timeout: int = 600) -> async def _run_benchmark( self, deployment_id: int, - num_requests: int = 20, - concurrency: int = 4, + num_requests: int = 30, + concurrency: int = 8, ) -> dict[str, float]: - """Run benchmark against deployment""" + """ + Run benchmark against deployment using the benchmark module. + + Returns comprehensive metrics including percentiles. + """ result = await self.db.execute( select(Deployment) .where(Deployment.id == deployment_id) @@ -779,106 +798,32 @@ async def _run_benchmark( # Build endpoint URL worker_ip = deployment.worker.address.split(":")[0] - base_url = f"http://{worker_ip}:{deployment.port}/v1" + endpoint = f"http://{worker_ip}:{deployment.port}/v1" model_name = deployment.model.model_id - # Run HTTP benchmark (reuse existing implementation pattern) - return await self._http_benchmark(base_url, model_name, num_requests, concurrency) + # Configure benchmark + config = BenchmarkConfig( + endpoint=endpoint, + model_name=model_name, + load_pattern=LoadPattern.FIXED, + concurrency=concurrency, + num_requests=num_requests, + warmup_requests=5, + output_tokens=64, + prompt_tokens=200, + stream=True, + verbose=False, + ) - async def _http_benchmark( - self, - base_url: str, - model_name: str, - num_requests: int, - concurrency: int, - ) -> dict[str, float]: - """Execute HTTP benchmark against OpenAI-compatible endpoint""" - test_prompt = "Explain the concept of machine learning in simple terms. " * 20 - - results = [] - semaphore = asyncio.Semaphore(concurrency) - - async def make_request(client: httpx.AsyncClient) -> dict | None: - async with semaphore: - start = time.perf_counter() - first_token_time = None - token_count = 0 - - try: - async with client.stream( - "POST", - f"{base_url}/chat/completions", - json={ - "model": model_name, - "messages": [{"role": "user", "content": test_prompt}], - "max_tokens": 64, - "stream": True, - }, - timeout=60.0, - ) as resp: - if resp.status_code != 200: - return None - - async for line in resp.aiter_lines(): - if line.startswith("data: ") and line != "data: [DONE]": - try: - chunk = json.loads(line[6:]) - content = ( - chunk.get("choices", [{}])[0] - .get("delta", {}) - .get("content", "") - ) - if content: - if first_token_time is None: - first_token_time = time.perf_counter() - token_count += 1 - except json.JSONDecodeError: - pass - - end = time.perf_counter() - - if first_token_time and token_count > 0: - return { - "ttft_ms": (first_token_time - start) * 1000, - "tpot_ms": ( - ((end - first_token_time) / max(1, token_count - 1)) * 1000 - if token_count > 1 - else 0 - ), - "tokens": token_count, - "total_time": end - start, - } - except Exception: - pass - return None - - async with httpx.AsyncClient() as client: - # Warmup - for _ in range(2): - await make_request(client) - - # Actual benchmark - tasks = [make_request(client) for _ in range(num_requests)] - results = await asyncio.gather(*tasks) - - valid = [r for r in results if r] - if not valid: - return {"throughput_tps": 0, "avg_ttft_ms": 0, "avg_tpot_ms": 0} - - total_tokens = sum(r["tokens"] for r in valid) - total_time = sum(r["total_time"] for r in valid) - - return { - "throughput_tps": round(total_tokens / total_time, 2) if total_time > 0 else 0, - "avg_ttft_ms": round(sum(r["ttft_ms"] for r in valid) / len(valid), 2), - "avg_tpot_ms": round( - sum(r["tpot_ms"] for r in valid if r["tpot_ms"] > 0) - / max(1, len([r for r in valid if r["tpot_ms"] > 0])), - 2, - ), - "successful_requests": len(valid), - "total_requests": num_requests, - } + # Run benchmark + runner = BenchmarkRunner(config) + bench_result = await runner.run() + + if bench_result.error: + await self._log("WARNING", f"Benchmark warning: {bench_result.error}") + + # Return comprehensive metrics in simple dict format + return bench_result.metrics.to_simple_dict() async def _cleanup_deployment(self): """Stop and remove current deployment""" diff --git a/backend/app/services/benchmark/__init__.py b/backend/app/services/benchmark/__init__.py new file mode 100644 index 0000000..3bfb8cb --- /dev/null +++ b/backend/app/services/benchmark/__init__.py @@ -0,0 +1,46 @@ +""" +Benchmark Module + +Independent benchmarking system for LLM inference performance evaluation. +Supports comprehensive metrics, multiple load patterns, and saturation detection. + +Usage: + from app.services.benchmark import BenchmarkRunner, BenchmarkConfig, LoadPattern + + config = BenchmarkConfig( + endpoint="http://localhost:8000/v1", + model_name="Qwen/Qwen2.5-7B-Instruct", + load_pattern=LoadPattern.FIXED, + concurrency=10, + duration_seconds=60, + ) + + runner = BenchmarkRunner(config) + result = await runner.run() + + print(f"Throughput: {result.metrics.throughput_tps:.2f} TPS") + print(f"TTFT p95: {result.metrics.ttft.p95:.2f} ms") +""" + +from .config import BenchmarkConfig, LoadPattern, SaturationConfig +from .metrics import BenchmarkMetrics, LatencyMetrics, RequestResult, compute_percentiles +from .runner import BenchmarkResult, BenchmarkRunner +from .saturation import SaturationDetector, SaturationResult + +__all__ = [ + # Config + "BenchmarkConfig", + "LoadPattern", + "SaturationConfig", + # Metrics + "LatencyMetrics", + "BenchmarkMetrics", + "RequestResult", + "compute_percentiles", + # Runner + "BenchmarkRunner", + "BenchmarkResult", + # Saturation + "SaturationDetector", + "SaturationResult", +] diff --git a/backend/app/services/benchmark/config.py b/backend/app/services/benchmark/config.py new file mode 100644 index 0000000..e2b3f5b --- /dev/null +++ b/backend/app/services/benchmark/config.py @@ -0,0 +1,147 @@ +""" +Benchmark Configuration + +Defines configuration options for benchmark execution. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class LoadPattern(Enum): + """Load pattern types for benchmark execution""" + + # Fixed concurrency throughout the test + FIXED = "fixed" + + # Gradually increase concurrency to find saturation point + INCREMENTAL = "incremental" + + # Burst load pattern with spikes + BURST = "burst" + + # Step-wise increase for finding optimal concurrency + STEP = "step" + + +@dataclass +class SaturationConfig: + """Configuration for saturation detection""" + + # Enable saturation detection (auto-find optimal concurrency) + enabled: bool = False + + # Starting concurrency level + start_concurrency: int = 1 + + # Maximum concurrency to test + max_concurrency: int = 64 + + # Concurrency step size + step_size: int = 2 + + # Multiplier for exponential stepping (alternative to linear) + step_multiplier: float = 1.5 + + # Use exponential stepping instead of linear + use_exponential: bool = True + + # Minimum requests per concurrency level + requests_per_level: int = 20 + + # Throughput degradation threshold (e.g., 0.95 = 5% drop triggers saturation) + degradation_threshold: float = 0.95 + + # Latency increase threshold (e.g., 1.5 = 50% increase triggers saturation) + latency_threshold: float = 1.5 + + # Number of consecutive degradations before declaring saturation + consecutive_degradations: int = 2 + + +@dataclass +class BenchmarkConfig: + """Configuration for benchmark execution""" + + # Target endpoint (OpenAI-compatible API) + endpoint: str + + # Model name/ID for API requests + model_name: str + + # Load pattern to use + load_pattern: LoadPattern = LoadPattern.FIXED + + # Concurrency level (for FIXED pattern) + concurrency: int = 10 + + # Duration in seconds (0 = use num_requests instead) + duration_seconds: int = 0 + + # Number of requests (used when duration_seconds = 0) + num_requests: int = 50 + + # Warmup requests before actual benchmark + warmup_requests: int = 5 + + # Request timeout in seconds + request_timeout: float = 120.0 + + # Prompt configuration + prompt_tokens: int = 256 # Approximate input tokens + output_tokens: int = 128 # Max output tokens + + # Custom prompt (overrides prompt_tokens if set) + custom_prompt: str | None = None + + # Saturation detection config + saturation: SaturationConfig = field(default_factory=SaturationConfig) + + # Extra parameters for requests + extra_params: dict[str, Any] = field(default_factory=dict) + + # Stream responses (required for accurate ITL measurement) + stream: bool = True + + # Verbose logging + verbose: bool = False + + def get_prompt(self) -> str: + """Get the prompt to use for benchmarking""" + if self.custom_prompt: + return self.custom_prompt + + # Generate synthetic prompt of approximately prompt_tokens length + # Average English word is ~4 characters, ~1.3 tokens + words_needed = int(self.prompt_tokens / 1.3) + base_prompt = "Explain the following concept in detail: " + filler = "artificial intelligence machine learning neural network deep learning natural language processing computer vision reinforcement learning transformer architecture attention mechanism gradient descent backpropagation optimization algorithm data science statistical analysis pattern recognition feature extraction dimensionality reduction clustering classification regression prediction inference training validation testing deployment scalability performance efficiency accuracy precision recall " + + # Repeat filler to reach desired length + repeated = (filler * ((words_needed // len(filler.split())) + 1)).split() + prompt = base_prompt + " ".join(repeated[:words_needed]) + + return prompt + + def validate(self) -> tuple[bool, str]: + """Validate configuration""" + if not self.endpoint: + return False, "Endpoint is required" + + if not self.model_name: + return False, "Model name is required" + + if self.concurrency < 1: + return False, "Concurrency must be at least 1" + + if self.duration_seconds < 0: + return False, "Duration cannot be negative" + + if self.num_requests < 1 and self.duration_seconds == 0: + return False, "Either duration_seconds or num_requests must be positive" + + if self.request_timeout < 1: + return False, "Request timeout must be at least 1 second" + + return True, "" diff --git a/backend/app/services/benchmark/metrics.py b/backend/app/services/benchmark/metrics.py new file mode 100644 index 0000000..0792213 --- /dev/null +++ b/backend/app/services/benchmark/metrics.py @@ -0,0 +1,254 @@ +""" +Benchmark Metrics + +Comprehensive metrics calculation including: +- TTFT (Time to First Token) +- ITL (Inter-Token Latency) +- TPOT (Time Per Output Token) +- Percentiles (p50, p90, p95, p99) +- Throughput (tokens/sec, requests/sec) +""" + +import statistics +from collections.abc import Sequence +from dataclasses import dataclass, field + + +@dataclass +class RequestResult: + """Result from a single request""" + + # Timing metrics (in milliseconds) + ttft_ms: float = 0.0 # Time to first token + total_latency_ms: float = 0.0 # Total request duration + token_timestamps_ms: list[float] = field(default_factory=list) # Timestamp of each token + + # Token counts + prompt_tokens: int = 0 + completion_tokens: int = 0 + + # Request metadata + success: bool = True + error_message: str | None = None + start_time: float = 0.0 + end_time: float = 0.0 + + @property + def itl_values_ms(self) -> list[float]: + """Calculate inter-token latencies from timestamps""" + if len(self.token_timestamps_ms) < 2: + return [] + + itl_values = [] + for i in range(1, len(self.token_timestamps_ms)): + itl = self.token_timestamps_ms[i] - self.token_timestamps_ms[i - 1] + itl_values.append(itl) + return itl_values + + @property + def tpot_ms(self) -> float: + """Time per output token (excluding TTFT)""" + if self.completion_tokens <= 1: + return 0.0 + + generation_time = self.total_latency_ms - self.ttft_ms + return generation_time / (self.completion_tokens - 1) + + +@dataclass +class LatencyMetrics: + """Latency metrics with percentiles""" + + mean: float = 0.0 + median: float = 0.0 + min: float = 0.0 + max: float = 0.0 + std: float = 0.0 + p50: float = 0.0 + p90: float = 0.0 + p95: float = 0.0 + p99: float = 0.0 + + def to_dict(self) -> dict[str, float]: + """Convert to dictionary""" + return { + "mean": round(self.mean, 2), + "median": round(self.median, 2), + "min": round(self.min, 2), + "max": round(self.max, 2), + "std": round(self.std, 2), + "p50": round(self.p50, 2), + "p90": round(self.p90, 2), + "p95": round(self.p95, 2), + "p99": round(self.p99, 2), + } + + +def compute_percentiles(values: Sequence[float]) -> LatencyMetrics: + """Compute latency metrics with percentiles from a sequence of values""" + if not values: + return LatencyMetrics() + + sorted_values = sorted(values) + n = len(sorted_values) + + def percentile(p: float) -> float: + """Calculate percentile value""" + if n == 1: + return sorted_values[0] + k = (n - 1) * (p / 100) + f = int(k) + c = f + 1 if f + 1 < n else f + return sorted_values[f] + (k - f) * (sorted_values[c] - sorted_values[f]) + + return LatencyMetrics( + mean=statistics.mean(values), + median=statistics.median(values), + min=min(values), + max=max(values), + std=statistics.stdev(values) if n > 1 else 0.0, + p50=percentile(50), + p90=percentile(90), + p95=percentile(95), + p99=percentile(99), + ) + + +@dataclass +class BenchmarkMetrics: + """Comprehensive benchmark metrics""" + + # Latency metrics + ttft: LatencyMetrics = field(default_factory=LatencyMetrics) + itl: LatencyMetrics = field(default_factory=LatencyMetrics) + tpot: LatencyMetrics = field(default_factory=LatencyMetrics) + e2e_latency: LatencyMetrics = field(default_factory=LatencyMetrics) + + # Throughput metrics + throughput_tps: float = 0.0 # Tokens per second + throughput_rps: float = 0.0 # Requests per second + output_tps: float = 0.0 # Output tokens per second + + # Request statistics + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + success_rate: float = 0.0 + + # Token statistics + total_prompt_tokens: int = 0 + total_completion_tokens: int = 0 + avg_prompt_tokens: float = 0.0 + avg_completion_tokens: float = 0.0 + + # Timing + total_duration_seconds: float = 0.0 + concurrency: int = 0 + + def to_dict(self) -> dict: + """Convert to dictionary for serialization""" + return { + "ttft": self.ttft.to_dict(), + "itl": self.itl.to_dict(), + "tpot": self.tpot.to_dict(), + "e2e_latency": self.e2e_latency.to_dict(), + "throughput_tps": round(self.throughput_tps, 2), + "throughput_rps": round(self.throughput_rps, 4), + "output_tps": round(self.output_tps, 2), + "total_requests": self.total_requests, + "successful_requests": self.successful_requests, + "failed_requests": self.failed_requests, + "success_rate": round(self.success_rate, 4), + "total_prompt_tokens": self.total_prompt_tokens, + "total_completion_tokens": self.total_completion_tokens, + "avg_prompt_tokens": round(self.avg_prompt_tokens, 1), + "avg_completion_tokens": round(self.avg_completion_tokens, 1), + "total_duration_seconds": round(self.total_duration_seconds, 2), + "concurrency": self.concurrency, + } + + @classmethod + def from_results( + cls, + results: list[RequestResult], + total_duration: float, + concurrency: int, + ) -> "BenchmarkMetrics": + """Compute metrics from a list of request results""" + if not results: + return cls() + + successful = [r for r in results if r.success] + failed = [r for r in results if not r.success] + + if not successful: + return cls( + total_requests=len(results), + failed_requests=len(failed), + total_duration_seconds=total_duration, + concurrency=concurrency, + ) + + # Collect latency values + ttft_values = [r.ttft_ms for r in successful if r.ttft_ms > 0] + e2e_values = [r.total_latency_ms for r in successful if r.total_latency_ms > 0] + tpot_values = [r.tpot_ms for r in successful if r.tpot_ms > 0] + + # Collect all ITL values + all_itl_values: list[float] = [] + for r in successful: + all_itl_values.extend(r.itl_values_ms) + + # Token counts + total_prompt = sum(r.prompt_tokens for r in successful) + total_completion = sum(r.completion_tokens for r in successful) + total_tokens = total_prompt + total_completion + + # Throughput + throughput_tps = total_tokens / total_duration if total_duration > 0 else 0 + throughput_rps = len(successful) / total_duration if total_duration > 0 else 0 + output_tps = total_completion / total_duration if total_duration > 0 else 0 + + return cls( + ttft=compute_percentiles(ttft_values), + itl=compute_percentiles(all_itl_values) if all_itl_values else LatencyMetrics(), + tpot=compute_percentiles(tpot_values), + e2e_latency=compute_percentiles(e2e_values), + throughput_tps=throughput_tps, + throughput_rps=throughput_rps, + output_tps=output_tps, + total_requests=len(results), + successful_requests=len(successful), + failed_requests=len(failed), + success_rate=len(successful) / len(results) if results else 0, + total_prompt_tokens=total_prompt, + total_completion_tokens=total_completion, + avg_prompt_tokens=total_prompt / len(successful) if successful else 0, + avg_completion_tokens=total_completion / len(successful) if successful else 0, + total_duration_seconds=total_duration, + concurrency=concurrency, + ) + + def to_simple_dict(self) -> dict[str, float]: + """Convert to simple dictionary for backward compatibility with BayesianTuner""" + return { + "throughput_tps": self.output_tps, + "avg_ttft_ms": self.ttft.mean, + "avg_tpot_ms": self.tpot.mean, + "avg_itl_ms": self.itl.mean, + "ttft_p50_ms": self.ttft.p50, + "ttft_p95_ms": self.ttft.p95, + "ttft_p99_ms": self.ttft.p99, + "tpot_p50_ms": self.tpot.p50, + "tpot_p95_ms": self.tpot.p95, + "tpot_p99_ms": self.tpot.p99, + "itl_p50_ms": self.itl.p50, + "itl_p95_ms": self.itl.p95, + "itl_p99_ms": self.itl.p99, + "e2e_latency_p50_ms": self.e2e_latency.p50, + "e2e_latency_p95_ms": self.e2e_latency.p95, + "e2e_latency_p99_ms": self.e2e_latency.p99, + "successful_requests": float(self.successful_requests), + "total_requests": float(self.total_requests), + "success_rate": self.success_rate, + } diff --git a/backend/app/services/benchmark/runner.py b/backend/app/services/benchmark/runner.py new file mode 100644 index 0000000..6c37dd2 --- /dev/null +++ b/backend/app/services/benchmark/runner.py @@ -0,0 +1,484 @@ +""" +Benchmark Runner + +Executes benchmarks against OpenAI-compatible endpoints with accurate +token-level timing measurements using streaming responses. +""" + +import asyncio +import json +import logging +import time +from dataclasses import dataclass, field +from typing import Any + +import httpx + +from .config import BenchmarkConfig, LoadPattern +from .metrics import BenchmarkMetrics, RequestResult + +logger = logging.getLogger(__name__) + + +@dataclass +class BenchmarkResult: + """Complete benchmark result""" + + metrics: BenchmarkMetrics + config: BenchmarkConfig + raw_results: list[RequestResult] = field(default_factory=list) + error: str | None = None + started_at: float = 0.0 + completed_at: float = 0.0 + + @property + def duration_seconds(self) -> float: + """Total benchmark duration""" + return self.completed_at - self.started_at + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary""" + return { + "metrics": self.metrics.to_dict(), + "config": { + "endpoint": self.config.endpoint, + "model_name": self.config.model_name, + "load_pattern": self.config.load_pattern.value, + "concurrency": self.config.concurrency, + "num_requests": self.config.num_requests, + "duration_seconds": self.config.duration_seconds, + }, + "error": self.error, + "started_at": self.started_at, + "completed_at": self.completed_at, + "duration_seconds": self.duration_seconds, + } + + +class BenchmarkRunner: + """ + Executes benchmarks against LLM inference endpoints. + + Supports multiple load patterns and accurate token-level timing. + """ + + def __init__(self, config: BenchmarkConfig): + self.config = config + self._cancelled = False + self._results: list[RequestResult] = [] + self._active_requests = 0 + self._completed_requests = 0 + self._start_time: float = 0.0 + + async def run(self) -> BenchmarkResult: + """Execute the benchmark""" + # Validate config + valid, error = self.config.validate() + if not valid: + return BenchmarkResult( + metrics=BenchmarkMetrics(), + config=self.config, + error=error, + ) + + self._start_time = time.time() + result = BenchmarkResult( + metrics=BenchmarkMetrics(), + config=self.config, + started_at=self._start_time, + ) + + try: + # Run warmup + if self.config.warmup_requests > 0: + await self._run_warmup() + + # Execute based on load pattern + if self.config.load_pattern == LoadPattern.FIXED: + await self._run_fixed_load() + elif self.config.load_pattern == LoadPattern.INCREMENTAL: + await self._run_incremental_load() + elif self.config.load_pattern == LoadPattern.BURST: + await self._run_burst_load() + elif self.config.load_pattern == LoadPattern.STEP: + await self._run_step_load() + else: + await self._run_fixed_load() + + # Calculate metrics + end_time = time.time() + total_duration = end_time - self._start_time + + result.metrics = BenchmarkMetrics.from_results( + self._results, + total_duration, + self.config.concurrency, + ) + result.raw_results = self._results + result.completed_at = end_time + + except Exception as e: + logger.exception(f"Benchmark failed: {e}") + result.error = str(e) + result.completed_at = time.time() + + return result + + def cancel(self): + """Cancel the running benchmark""" + self._cancelled = True + + async def _run_warmup(self): + """Run warmup requests""" + if self.config.verbose: + logger.info(f"Running {self.config.warmup_requests} warmup requests...") + + async with httpx.AsyncClient(timeout=self.config.request_timeout) as client: + tasks = [ + self._make_request(client, warmup=True) for _ in range(self.config.warmup_requests) + ] + await asyncio.gather(*tasks, return_exceptions=True) + + if self.config.verbose: + logger.info("Warmup complete") + + async def _run_fixed_load(self): + """Run with fixed concurrency""" + semaphore = asyncio.Semaphore(self.config.concurrency) + + async def limited_request(client: httpx.AsyncClient): + async with semaphore: + return await self._make_request(client) + + async with httpx.AsyncClient(timeout=self.config.request_timeout) as client: + if self.config.duration_seconds > 0: + # Duration-based + end_time = time.time() + self.config.duration_seconds + tasks: list[asyncio.Task] = [] + + while time.time() < end_time and not self._cancelled: + if len(tasks) < self.config.concurrency * 2: + task = asyncio.create_task(limited_request(client)) + tasks.append(task) + + # Clean up completed tasks + done = [t for t in tasks if t.done()] + for t in done: + try: + result = t.result() + if result: + self._results.append(result) + except Exception: + pass + tasks.remove(t) + + await asyncio.sleep(0.01) + + # Wait for remaining tasks + for task in tasks: + try: + result = await task + if result: + self._results.append(result) + except Exception: + pass + else: + # Request count-based + tasks = [limited_request(client) for _ in range(self.config.num_requests)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for r in results: + if isinstance(r, RequestResult): + self._results.append(r) + + async def _run_incremental_load(self): + """Run with incrementally increasing concurrency""" + sat_config = self.config.saturation + current_concurrency = sat_config.start_concurrency + + async with httpx.AsyncClient(timeout=self.config.request_timeout) as client: + while current_concurrency <= sat_config.max_concurrency and not self._cancelled: + if self.config.verbose: + logger.info(f"Testing concurrency level: {current_concurrency}") + + semaphore = asyncio.Semaphore(current_concurrency) + + async def limited_request(sem=semaphore): + async with sem: + return await self._make_request(client) + + tasks = [limited_request() for _ in range(sat_config.requests_per_level)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for r in results: + if isinstance(r, RequestResult): + self._results.append(r) + + # Increase concurrency + if sat_config.use_exponential: + current_concurrency = int(current_concurrency * sat_config.step_multiplier) + else: + current_concurrency += sat_config.step_size + + async def _run_burst_load(self): + """Run with burst load pattern""" + # Burst pattern: low -> high -> low + concurrency_levels = [ + self.config.concurrency // 4, # Low + self.config.concurrency, # High (burst) + self.config.concurrency // 4, # Low + self.config.concurrency, # High (burst) + self.config.concurrency // 4, # Low + ] + + requests_per_phase = self.config.num_requests // len(concurrency_levels) + + async with httpx.AsyncClient(timeout=self.config.request_timeout) as client: + for phase, concurrency in enumerate(concurrency_levels): + if self._cancelled: + break + + if self.config.verbose: + logger.info(f"Burst phase {phase + 1}: concurrency={concurrency}") + + semaphore = asyncio.Semaphore(max(1, concurrency)) + + async def limited_request(sem=semaphore): + async with sem: + return await self._make_request(client) + + tasks = [limited_request() for _ in range(requests_per_phase)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for r in results: + if isinstance(r, RequestResult): + self._results.append(r) + + async def _run_step_load(self): + """Run with step-wise load increase""" + sat_config = self.config.saturation + steps = [] + + # Generate step levels + current = sat_config.start_concurrency + while current <= sat_config.max_concurrency: + steps.append(current) + if sat_config.use_exponential: + current = int(current * sat_config.step_multiplier) + else: + current += sat_config.step_size + + async with httpx.AsyncClient(timeout=self.config.request_timeout) as client: + for step_idx, concurrency in enumerate(steps): + if self._cancelled: + break + + if self.config.verbose: + logger.info(f"Step {step_idx + 1}/{len(steps)}: concurrency={concurrency}") + + semaphore = asyncio.Semaphore(concurrency) + + async def limited_request(sem=semaphore): + async with sem: + return await self._make_request(client) + + tasks = [limited_request() for _ in range(sat_config.requests_per_level)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for r in results: + if isinstance(r, RequestResult): + self._results.append(r) + + async def _make_request( + self, + client: httpx.AsyncClient, + warmup: bool = False, + ) -> RequestResult | None: + """Make a single request with token-level timing""" + if self._cancelled: + return None + + result = RequestResult(start_time=time.time()) + prompt = self.config.get_prompt() + + request_body = { + "model": self.config.model_name, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": self.config.output_tokens, + "stream": self.config.stream, + **self.config.extra_params, + } + + endpoint = self.config.endpoint.rstrip("/") + url = f"{endpoint}/chat/completions" + + try: + start_time = time.perf_counter() + + if self.config.stream: + result = await self._make_streaming_request(client, url, request_body, start_time) + else: + result = await self._make_non_streaming_request( + client, url, request_body, start_time + ) + + result.end_time = time.time() + + if self.config.verbose and not warmup: + logger.info( + f"Request completed: TTFT={result.ttft_ms:.1f}ms, " + f"tokens={result.completion_tokens}, " + f"latency={result.total_latency_ms:.1f}ms" + ) + + except httpx.TimeoutException: + result.success = False + result.error_message = "Request timeout" + result.end_time = time.time() + + except Exception as e: + result.success = False + result.error_message = str(e) + result.end_time = time.time() + + return result + + async def _make_streaming_request( + self, + client: httpx.AsyncClient, + url: str, + request_body: dict, + start_time: float, + ) -> RequestResult: + """Make a streaming request with token-level timing""" + result = RequestResult(start_time=time.time()) + first_token_received = False + token_timestamps: list[float] = [] + token_count = 0 + + async with client.stream("POST", url, json=request_body) as response: + if response.status_code != 200: + result.success = False + result.error_message = f"HTTP {response.status_code}" + return result + + async for line in response.aiter_lines(): + if self._cancelled: + break + + if not line.startswith("data: "): + continue + + data = line[6:] + if data == "[DONE]": + break + + try: + chunk = json.loads(data) + choices = chunk.get("choices", []) + if not choices: + continue + + delta = choices[0].get("delta", {}) + content = delta.get("content", "") + + if content: + current_time = time.perf_counter() + + if not first_token_received: + result.ttft_ms = (current_time - start_time) * 1000 + first_token_received = True + + # Record timestamp for each token + token_timestamps.append((current_time - start_time) * 1000) + token_count += 1 + + # Extract usage if available (some backends include it) + usage = chunk.get("usage", {}) + if usage: + result.prompt_tokens = usage.get("prompt_tokens", result.prompt_tokens) + result.completion_tokens = usage.get("completion_tokens", token_count) + + except json.JSONDecodeError: + continue + + end_time = time.perf_counter() + result.total_latency_ms = (end_time - start_time) * 1000 + result.token_timestamps_ms = token_timestamps + result.completion_tokens = ( + token_count if result.completion_tokens == 0 else result.completion_tokens + ) + + return result + + async def _make_non_streaming_request( + self, + client: httpx.AsyncClient, + url: str, + request_body: dict, + start_time: float, + ) -> RequestResult: + """Make a non-streaming request""" + result = RequestResult(start_time=time.time()) + + response = await client.post(url, json=request_body) + end_time = time.perf_counter() + + if response.status_code != 200: + result.success = False + result.error_message = f"HTTP {response.status_code}" + return result + + try: + data = response.json() + usage = data.get("usage", {}) + result.prompt_tokens = usage.get("prompt_tokens", 0) + result.completion_tokens = usage.get("completion_tokens", 0) + + # For non-streaming, TTFT = total latency + result.total_latency_ms = (end_time - start_time) * 1000 + result.ttft_ms = result.total_latency_ms + + except Exception as e: + result.success = False + result.error_message = str(e) + + return result + + +async def run_benchmark( + endpoint: str, + model_name: str, + concurrency: int = 10, + num_requests: int = 50, + warmup_requests: int = 5, + output_tokens: int = 128, + verbose: bool = False, +) -> BenchmarkResult: + """ + Convenience function to run a simple benchmark. + + Args: + endpoint: OpenAI-compatible API endpoint + model_name: Model name/ID + concurrency: Number of concurrent requests + num_requests: Total number of requests + warmup_requests: Number of warmup requests + output_tokens: Max output tokens per request + verbose: Enable verbose logging + + Returns: + BenchmarkResult with metrics + """ + config = BenchmarkConfig( + endpoint=endpoint, + model_name=model_name, + concurrency=concurrency, + num_requests=num_requests, + warmup_requests=warmup_requests, + output_tokens=output_tokens, + verbose=verbose, + ) + + runner = BenchmarkRunner(config) + return await runner.run() diff --git a/backend/app/services/benchmark/saturation.py b/backend/app/services/benchmark/saturation.py new file mode 100644 index 0000000..c6a14bd --- /dev/null +++ b/backend/app/services/benchmark/saturation.py @@ -0,0 +1,289 @@ +""" +Saturation Detection + +Automatically finds the optimal concurrency level by detecting when +performance starts to degrade (throughput plateaus or latency spikes). +""" + +import logging +from dataclasses import dataclass, field +from typing import Any + +from .config import BenchmarkConfig, LoadPattern, SaturationConfig +from .metrics import BenchmarkMetrics +from .runner import BenchmarkRunner + +logger = logging.getLogger(__name__) + + +@dataclass +class ConcurrencyResult: + """Result for a single concurrency level""" + + concurrency: int + metrics: BenchmarkMetrics + throughput_tps: float = 0.0 + avg_latency_ms: float = 0.0 + p95_latency_ms: float = 0.0 + success_rate: float = 0.0 + + +@dataclass +class SaturationResult: + """Result from saturation detection""" + + # Optimal concurrency found + optimal_concurrency: int = 0 + + # Maximum throughput achieved + max_throughput_tps: float = 0.0 + + # Latency at optimal concurrency + latency_at_optimal_ms: float = 0.0 + + # Saturation concurrency (where degradation starts) + saturation_concurrency: int = 0 + + # All tested concurrency levels + results_by_concurrency: list[ConcurrencyResult] = field(default_factory=list) + + # Was saturation detected? + saturation_detected: bool = False + + # Reason for stopping + stop_reason: str = "" + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary""" + return { + "optimal_concurrency": self.optimal_concurrency, + "max_throughput_tps": round(self.max_throughput_tps, 2), + "latency_at_optimal_ms": round(self.latency_at_optimal_ms, 2), + "saturation_concurrency": self.saturation_concurrency, + "saturation_detected": self.saturation_detected, + "stop_reason": self.stop_reason, + "concurrency_results": [ + { + "concurrency": r.concurrency, + "throughput_tps": round(r.throughput_tps, 2), + "avg_latency_ms": round(r.avg_latency_ms, 2), + "p95_latency_ms": round(r.p95_latency_ms, 2), + "success_rate": round(r.success_rate, 4), + } + for r in self.results_by_concurrency + ], + } + + +class SaturationDetector: + """ + Detects optimal concurrency by running incremental load tests. + + Algorithm: + 1. Start with low concurrency + 2. Increase concurrency and measure throughput/latency + 3. Stop when: + - Throughput stops increasing (plateau) + - Latency increases significantly + - Error rate increases + 4. Return the concurrency level with best throughput/latency balance + """ + + def __init__( + self, + endpoint: str, + model_name: str, + config: SaturationConfig | None = None, + ): + self.endpoint = endpoint + self.model_name = model_name + self.config = config or SaturationConfig(enabled=True) + self._cancelled = False + + async def detect(self) -> SaturationResult: + """Run saturation detection""" + result = SaturationResult() + results_by_level: list[ConcurrencyResult] = [] + + current_concurrency = self.config.start_concurrency + peak_throughput = 0.0 + baseline_latency = 0.0 + consecutive_degradations = 0 + + logger.info( + f"Starting saturation detection: " + f"start={self.config.start_concurrency}, " + f"max={self.config.max_concurrency}" + ) + + while current_concurrency <= self.config.max_concurrency and not self._cancelled: + logger.info(f"Testing concurrency: {current_concurrency}") + + # Run benchmark at current concurrency + benchmark_config = BenchmarkConfig( + endpoint=self.endpoint, + model_name=self.model_name, + load_pattern=LoadPattern.FIXED, + concurrency=current_concurrency, + num_requests=self.config.requests_per_level, + warmup_requests=2, + verbose=False, + ) + + runner = BenchmarkRunner(benchmark_config) + bench_result = await runner.run() + + if bench_result.error: + logger.warning( + f"Benchmark error at concurrency {current_concurrency}: {bench_result.error}" + ) + consecutive_degradations += 1 + + if consecutive_degradations >= self.config.consecutive_degradations: + result.stop_reason = f"Too many errors: {bench_result.error}" + result.saturation_detected = True + result.saturation_concurrency = current_concurrency + break + + # Skip to next level + current_concurrency = self._next_concurrency(current_concurrency) + continue + + metrics = bench_result.metrics + + # Record result + level_result = ConcurrencyResult( + concurrency=current_concurrency, + metrics=metrics, + throughput_tps=metrics.output_tps, + avg_latency_ms=metrics.e2e_latency.mean, + p95_latency_ms=metrics.e2e_latency.p95, + success_rate=metrics.success_rate, + ) + results_by_level.append(level_result) + + logger.info( + f"Concurrency {current_concurrency}: " + f"throughput={metrics.output_tps:.1f} TPS, " + f"latency={metrics.e2e_latency.mean:.1f}ms (p95={metrics.e2e_latency.p95:.1f}ms), " + f"success={metrics.success_rate:.1%}" + ) + + # Set baseline latency from first successful run + if baseline_latency == 0.0 and metrics.e2e_latency.mean > 0: + baseline_latency = metrics.e2e_latency.mean + + # Check for peak throughput + if metrics.output_tps > peak_throughput: + peak_throughput = metrics.output_tps + consecutive_degradations = 0 + + # Check for degradation + degradation_detected = False + + # Throughput degradation + if peak_throughput > 0: + throughput_ratio = metrics.output_tps / peak_throughput + if throughput_ratio < self.config.degradation_threshold: + logger.info( + f"Throughput degradation detected: " + f"{metrics.output_tps:.1f} vs peak {peak_throughput:.1f}" + ) + degradation_detected = True + + # Latency degradation + if baseline_latency > 0: + latency_ratio = metrics.e2e_latency.mean / baseline_latency + if latency_ratio > self.config.latency_threshold: + logger.info( + f"Latency degradation detected: " + f"{metrics.e2e_latency.mean:.1f}ms vs baseline {baseline_latency:.1f}ms" + ) + degradation_detected = True + + # Success rate check + if metrics.success_rate < 0.95: + logger.info(f"High error rate detected: {1 - metrics.success_rate:.1%}") + degradation_detected = True + + if degradation_detected: + consecutive_degradations += 1 + if consecutive_degradations >= self.config.consecutive_degradations: + result.saturation_detected = True + result.saturation_concurrency = current_concurrency + result.stop_reason = "Performance degradation detected" + break + else: + consecutive_degradations = 0 + + # Move to next concurrency level + current_concurrency = self._next_concurrency(current_concurrency) + + # Find optimal concurrency (best throughput/latency ratio) + if results_by_level: + result.results_by_concurrency = results_by_level + + # Find best result (highest throughput with acceptable latency) + best = max( + results_by_level, + key=lambda r: ( + r.throughput_tps / max(1, r.avg_latency_ms / 100) if r.success_rate > 0.9 else 0 + ), + ) + + result.optimal_concurrency = best.concurrency + result.max_throughput_tps = peak_throughput + result.latency_at_optimal_ms = best.avg_latency_ms + + if not result.saturation_detected: + result.stop_reason = "Reached max concurrency" + + logger.info( + f"Saturation detection complete: " + f"optimal={result.optimal_concurrency}, " + f"max_throughput={result.max_throughput_tps:.1f} TPS" + ) + + return result + + def _next_concurrency(self, current: int) -> int: + """Calculate next concurrency level""" + if self.config.use_exponential: + next_val = int(current * self.config.step_multiplier) + # Ensure we make at least step_size progress + return max(next_val, current + self.config.step_size) + else: + return current + self.config.step_size + + def cancel(self): + """Cancel the saturation detection""" + self._cancelled = True + + +async def find_optimal_concurrency( + endpoint: str, + model_name: str, + max_concurrency: int = 64, + requests_per_level: int = 20, +) -> SaturationResult: + """ + Convenience function to find optimal concurrency. + + Args: + endpoint: OpenAI-compatible API endpoint + model_name: Model name/ID + max_concurrency: Maximum concurrency to test + requests_per_level: Requests per concurrency level + + Returns: + SaturationResult with optimal concurrency and metrics + """ + config = SaturationConfig( + enabled=True, + start_concurrency=1, + max_concurrency=max_concurrency, + requests_per_level=requests_per_level, + ) + + detector = SaturationDetector(endpoint, model_name, config) + return await detector.detect() diff --git a/backend/app/services/deployer.py b/backend/app/services/deployer.py deleted file mode 100644 index 537d50d..0000000 --- a/backend/app/services/deployer.py +++ /dev/null @@ -1,928 +0,0 @@ -"""Deployment service - handles model deployment on workers""" - -import asyncio -import logging -import socket - -import docker -import httpx -from sqlalchemy import select -from sqlalchemy.orm import selectinload - -from app.config import get_settings -from app.database import async_session_maker -from app.models.deployment import Deployment, DeploymentStatus -from app.models.llm_model import BackendType, LLMModel - -logger = logging.getLogger(__name__) -settings = get_settings() - - -async def _update_semantic_router_config_background(): - """Background task to update semantic router config after deployment changes.""" - try: - from app.api.apps.deployment import update_semantic_router_config_if_deployed - - async with async_session_maker() as db: - await update_semantic_router_config_if_deployed(db) - except Exception as e: - logger.debug(f"Failed to update semantic router config: {e}") - - -class DeployerService: - """Service for deploying models to workers""" - - # Health check configuration - HEALTH_CHECK_INTERVAL = 5 # seconds between checks - HEALTH_CHECK_SLOW_THRESHOLD = 600 # seconds before showing "slow loading" message (10 min) - HEALTH_CHECK_REQUEST_TIMEOUT = 10 # timeout for each health check request - - async def deploy(self, deployment_id: int) -> None: - """Deploy a model to a worker""" - async with async_session_maker() as db: - result = await db.execute( - select(Deployment) - .where(Deployment.id == deployment_id) - .options( - selectinload(Deployment.worker), - selectinload(Deployment.model), - ) - ) - deployment = result.scalar_one_or_none() - - if not deployment: - logger.error(f"Deployment {deployment_id} not found") - return - - try: - # Update status to starting - deployment.status = DeploymentStatus.STARTING.value - deployment.status_message = "Sending deployment request to worker..." - await db.commit() - - # Build deployment request - deploy_request = self._build_deploy_request(deployment) - - # Check if this is a local worker - is_local = self._is_local_worker(deployment.worker.address) - - if is_local: - # Check if image needs to be pulled - image = deploy_request["image"] - if not self._image_exists_local(image): - deployment.status_message = f"Pulling image: {image}..." - await db.commit() - - pull_success = await self._pull_image_local(image) - if not pull_success: - deployment.status = DeploymentStatus.ERROR.value - deployment.status_message = f"Failed to pull image: {image}" - await db.commit() - return - - deployment.status_message = "Starting container..." - await db.commit() - - # Deploy locally using Docker directly - result = await self._deploy_local(deploy_request) - if result.get("error"): - deployment.status = DeploymentStatus.ERROR.value - deployment.status_message = result["error"] - await db.commit() - return - deployment.container_id = result.get("container_id") - deployment.port = result.get("port") - # Store container_name for internal Docker network communication - local_container_name = result.get("container_name") - deployment.container_name = local_container_name - else: - local_container_name = None # Remote workers use IP:port - # Send to remote worker agent - worker_url = f"http://{deployment.worker.address}/deploy" - progress_url = ( - f"http://{deployment.worker.address}/pull-progress/{deployment.id}" - ) - - # Start deployment request and poll for progress - async with httpx.AsyncClient(timeout=300.0) as client: - # Start the deployment in a task - deploy_task = asyncio.create_task( - client.post(worker_url, json=deploy_request) - ) - - # Poll for progress while waiting - while not deploy_task.done(): - try: - progress_resp = await client.get(progress_url, timeout=5.0) - if progress_resp.status_code == 200: - progress_data = progress_resp.json() - status = progress_data.get("status", "") - image = progress_data.get("image", "") - progress = progress_data.get("progress", 0) - - if status == "pulling": - deployment.status_message = ( - f"Pulling image {image}... ({progress}%)" - ) - await db.commit() - elif status == "completed": - deployment.status_message = ( - "Image pulled, starting container..." - ) - await db.commit() - elif status == "starting": - deployment.status_message = "Starting container..." - await db.commit() - except Exception: - pass # Progress polling is best-effort - - await asyncio.sleep(2) - - response = await deploy_task - - if response.status_code != 200: - deployment.status = DeploymentStatus.ERROR.value - deployment.status_message = f"Worker returned error: {response.text}" - await db.commit() - return - - result_data = response.json() - deployment.container_id = result_data.get("container_id") - deployment.port = result_data.get("port") - - # Container started, now waiting for model to load - deployment.status = DeploymentStatus.STARTING.value - deployment.status_message = "Downloading model and Loading model into GPU memory..." - await db.commit() - - # For Ollama, we need to pull the model first - if deployment.backend == BackendType.OLLAMA.value: - deployment.status_message = "Waiting for Ollama container to start..." - await db.commit() - - # Wait for Ollama API to be available before pulling - ollama_ready = await self._wait_for_ollama_ready( - deployment.worker.address, - deployment.port, - container_name=local_container_name, - ) - if not ollama_ready: - deployment.status = DeploymentStatus.ERROR.value - deployment.status_message = "Ollama container failed to start" - await db.commit() - return - - deployment.status_message = "Pulling model with Ollama..." - await db.commit() - - pull_success = await self._ollama_pull_model( - deployment.worker.address, - deployment.port, - deployment.model.model_id, - container_name=local_container_name, - ) - if not pull_success: - deployment.status = DeploymentStatus.ERROR.value - deployment.status_message = "Failed to pull model with Ollama" - await db.commit() - return - - deployment.status_message = "Model pulled, waiting for API..." - await db.commit() - - # Wait for the API endpoint to become ready - api_ready = await self._wait_for_api_ready( - deployment.worker.address, - deployment.port, - deployment_id, - db, - backend=deployment.backend, - container_name=local_container_name, - ) - - # Refresh deployment object after health check updates - await db.refresh(deployment) - - if api_ready is None: - # Deployment was cancelled, don't update status - logger.info(f"Deployment {deployment_id} cancelled during startup") - return - else: - # api_ready is True, model is ready - deployment.status = DeploymentStatus.RUNNING.value - deployment.status_message = "Model ready" - - # Update semantic router config if deployed - asyncio.create_task(_update_semantic_router_config_background()) - - except httpx.ConnectError: - deployment.status = DeploymentStatus.ERROR.value - deployment.status_message = ( - f"Cannot connect to worker at {deployment.worker.address}" - ) - except Exception as e: - logger.exception(f"Error deploying {deployment_id}") - deployment.status = DeploymentStatus.ERROR.value - deployment.status_message = str(e) - - await db.commit() - - async def _wait_for_ollama_ready( - self, - worker_address: str, - port: int, - timeout: int = 60, - container_name: str | None = None, - ) -> bool: - """Wait for Ollama API to be available. - - Args: - worker_address: Worker address (host:port) - port: Ollama container port - timeout: Maximum wait time in seconds - container_name: Container name for Docker network (Windows compatibility) - - Returns: - True if Ollama is ready, False on timeout - """ - # Ollama is configured to use port 8000 (OLLAMA_HOST=0.0.0.0:8000) - if container_name: - api_url = f"http://{container_name}:8000/api/tags" - else: - worker_ip = worker_address.split(":")[0] - api_url = f"http://{worker_ip}:{port}/api/tags" - - logger.info(f"Waiting for Ollama API at {api_url}") - - elapsed = 0 - check_interval = 2 - - async with httpx.AsyncClient(timeout=10.0) as client: - while elapsed < timeout: - try: - response = await client.get(api_url) - if response.status_code == 200: - logger.info(f"Ollama API ready after {elapsed}s") - return True - except httpx.ConnectError: - logger.debug(f"Ollama not ready yet ({elapsed}s)") - except Exception as e: - logger.debug(f"Ollama check error: {e}") - - await asyncio.sleep(check_interval) - elapsed += check_interval - - logger.error(f"Ollama API not ready after {timeout}s") - return False - - async def _ollama_pull_model( - self, - worker_address: str, - port: int, - model_id: str, - container_name: str | None = None, - ) -> bool: - """Pull a model using Ollama API. - - Ollama requires models to be pulled before they can be used. - This method calls the /api/pull endpoint and waits for completion. - """ - # Ollama is configured to use port 8000 (OLLAMA_HOST=0.0.0.0:8000) - if container_name: - api_url = f"http://{container_name}:8000/api/pull" - else: - worker_ip = worker_address.split(":")[0] - api_url = f"http://{worker_ip}:{port}/api/pull" - - logger.info(f"Pulling Ollama model: {model_id}") - - try: - async with httpx.AsyncClient(timeout=1800.0) as client: # 30 min timeout - # Ollama pull is a streaming endpoint - async with client.stream( - "POST", - api_url, - json={"name": model_id, "stream": True}, - ) as response: - if response.status_code != 200: - logger.error(f"Ollama pull failed: {response.status_code}") - return False - - # Process the streaming response - async for line in response.aiter_lines(): - if line: - try: - import json - - data = json.loads(line) - status = data.get("status", "") - - # Log progress - if "completed" in data and "total" in data: - pct = int(data["completed"] / data["total"] * 100) - logger.debug(f"Ollama pull: {status} ({pct}%)") - elif status: - logger.debug(f"Ollama pull: {status}") - - # Check for completion - if status == "success": - logger.info(f"Ollama model {model_id} pulled successfully") - return True - - except Exception as e: - logger.debug(f"Error parsing Ollama response: {e}") - - logger.info(f"Ollama model {model_id} pull completed") - return True - - except httpx.ConnectError: - logger.error(f"Cannot connect to Ollama at {api_url}") - return False - except Exception as e: - logger.error(f"Ollama pull error: {e}") - return False - - async def _wait_for_api_ready( - self, - worker_address: str, - port: int, - deployment_id: int, - db, - backend: str = BackendType.VLLM.value, - container_name: str | None = None, - ) -> bool | None: - """ - Poll the OpenAI API endpoint until it's ready or cancelled. - - Args: - worker_address: Worker address (host:port) - port: Host port for the model API - deployment_id: Deployment ID for status updates - db: Database session - backend: Backend type (vllm, ollama, etc.) - container_name: Container name for local Docker network communication. - If set, uses container_name:8000 instead of worker_ip:port. - This is needed for Windows Docker Desktop compatibility. - - Returns: - True: API is ready - None: Cancelled (user stopped deployment) - """ - # For local deployments with container_name, use Docker internal networking - # All backends (vLLM, SGLang, Ollama) are configured to use port 8000 - if container_name: - api_base_url = f"http://{container_name}:8000" - logger.info(f"Using Docker network for API: {api_base_url}") - else: - worker_ip = worker_address.split(":")[0] - api_base_url = f"http://{worker_ip}:{port}" - - # Both vLLM and Ollama support OpenAI-compatible /v1/models endpoint - health_endpoint = f"{api_base_url}/v1/models" - - # For Ollama, we can also check /api/tags as a fallback - is_ollama = backend == BackendType.OLLAMA.value - - elapsed = 0 - check_count = 0 - shown_slow_message = False - - logger.info(f"Waiting for API to be ready at {health_endpoint} (backend={backend})") - - async with httpx.AsyncClient(timeout=self.HEALTH_CHECK_REQUEST_TIMEOUT) as client: - while True: # Wait indefinitely until ready or cancelled - check_count += 1 - - # Check if deployment was cancelled - try: - result = await db.execute( - select(Deployment).where(Deployment.id == deployment_id) - ) - deployment = result.scalar_one_or_none() - if deployment and deployment.status in [ - DeploymentStatus.STOPPED.value, - DeploymentStatus.STOPPING.value, - ]: - logger.info(f"Deployment {deployment_id} was cancelled") - return None - except Exception as e: - logger.debug(f"Error checking deployment status: {e}") - - try: - response = await client.get(health_endpoint) - - if response.status_code == 200: - data = response.json() - # vLLM returns {"object": "list", "data": [...]} - # Ollama returns {"object": "list", "data": [...]} (OpenAI compat) - if data.get("data") and len(data["data"]) > 0: - logger.info( - f"API ready at {health_endpoint} after {elapsed}s " - f"({check_count} checks)" - ) - return True - - # For Ollama, also try the native /api/tags endpoint - if is_ollama and response.status_code != 200: - ollama_endpoint = f"{api_base_url}/api/tags" - ollama_response = await client.get(ollama_endpoint) - if ollama_response.status_code == 200: - ollama_data = ollama_response.json() - if ollama_data.get("models") and len(ollama_data["models"]) > 0: - logger.info( - f"Ollama API ready at {ollama_endpoint} after {elapsed}s" - ) - return True - - logger.debug(f"Health check {check_count}: status={response.status_code}") - - except httpx.ConnectError: - # Container not ready yet, this is expected during startup - logger.debug(f"Health check {check_count}: connection refused") - except httpx.ReadTimeout: - logger.debug(f"Health check {check_count}: read timeout") - except Exception as e: - logger.debug(f"Health check {check_count}: {type(e).__name__}: {e}") - - # Update status message periodically - if check_count % 6 == 0: # Every 30 seconds - try: - result = await db.execute( - select(Deployment).where(Deployment.id == deployment_id) - ) - deployment = result.scalar_one_or_none() - if deployment and deployment.status == DeploymentStatus.STARTING.value: - mins = elapsed // 60 - secs = elapsed % 60 - time_str = f"{mins}m {secs}s" if mins > 0 else f"{secs}s" - - # Show patience message after threshold - if ( - elapsed >= self.HEALTH_CHECK_SLOW_THRESHOLD - and not shown_slow_message - ): - deployment.status_message = ( - f"Loading model... ({time_str}) - " - "Large model or slow network detected. Please be patient." - ) - shown_slow_message = True - elif shown_slow_message: - deployment.status_message = ( - f"Loading model... ({time_str}) - Please be patient." - ) - else: - deployment.status_message = ( - f"Loading model into GPU memory... ({time_str})" - ) - await db.commit() - except Exception as e: - logger.debug(f"Error updating deployment status message: {e}") - - await asyncio.sleep(self.HEALTH_CHECK_INTERVAL) - elapsed += self.HEALTH_CHECK_INTERVAL - - def _is_local_worker(self, address: str) -> bool: - """Check if the worker address refers to the local machine.""" - if not address: - return False - host = address.split(":")[0].lower() - return host in ("localhost", "127.0.0.1", "local") - - def _image_exists_local(self, image: str) -> bool: - """Check if a Docker image exists locally.""" - try: - client = docker.from_env() - client.images.get(image) - return True - except docker.errors.ImageNotFound: - return False - except docker.errors.APIError as e: - logger.warning(f"Docker API error checking image {image}: {e}") - return False - except Exception as e: - logger.warning(f"Unexpected error checking image {image}: {e}") - return False - - async def _pull_image_local(self, image: str) -> bool: - """Pull a Docker image locally with progress logging.""" - try: - client = docker.from_env() - logger.info(f"Pulling image: {image}") - - # Pull with progress - for line in client.api.pull(image, stream=True, decode=True): - if "status" in line: - status = line.get("status", "") - progress = line.get("progress", "") - if progress: - logger.debug(f"{status}: {progress}") - - logger.info(f"Image pulled successfully: {image}") - return True - except Exception as e: - logger.error(f"Failed to pull image {image}: {e}") - return False - - def _find_available_port(self, start_port: int = 8001, end_port: int = 9000) -> int: - """Find an available port on the local machine.""" - for port in range(start_port, end_port): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - try: - s.bind(("", port)) - return port - except OSError: - continue - raise RuntimeError(f"No available ports in range {start_port}-{end_port}") - - async def _deploy_local(self, deploy_request: dict) -> dict: - """Deploy a container locally using Docker. - - This is used for local workers where we don't need to go through - a remote worker agent. - - On Windows Docker Desktop, containers must be on the same network - to communicate. We put model containers on the 'lmstack' network - so the backend can reach them via container name. - """ - try: - client = docker.from_env() - - image = deploy_request["image"] - command = deploy_request.get("command", []) - environment = deploy_request.get("environment", {}) - gpu_indexes = deploy_request.get("gpu_indexes", [0]) - deployment_name = deploy_request.get("deployment_name", "lmstack-deployment") - - # Find available port (still used for external access) - host_port = self._find_available_port() - - # Container name - used for internal Docker network communication - container_name = f"lmstack-{deployment_name}-{deploy_request['deployment_id']}" - - # Ensure lmstack network exists (for Windows Docker Desktop compatibility) - network_name = "lmstack_lmstack" - try: - client.networks.get(network_name) - except docker.errors.NotFound: - # Try alternative network name (depends on compose project name) - try: - network_name = "lmstack" - client.networks.get(network_name) - except docker.errors.NotFound: - # Create the network if it doesn't exist - logger.info(f"Creating Docker network: {network_name}") - client.networks.create(network_name, driver="bridge") - - # Build GPU device requests - device_requests = [ - docker.types.DeviceRequest( - device_ids=[str(i) for i in gpu_indexes], - capabilities=[["gpu"]], - ) - ] - - # Remove existing container with same name if exists - try: - existing = client.containers.get(container_name) - existing.remove(force=True) - except docker.errors.NotFound: - pass - - # Run container - # Use configurable HF cache directory from settings - hf_cache = settings.hf_cache_dir - container = client.containers.run( - image=image, - command=command, - name=container_name, - detach=True, - ports={"8000/tcp": host_port}, - environment=environment, - device_requests=device_requests, - volumes={ - hf_cache: {"bind": "/root/.cache/huggingface", "mode": "rw"}, - }, - shm_size="16g", # Required for large model inference - restart_policy={"Name": "unless-stopped"}, - network=network_name, # Join lmstack network for Windows compatibility - ) - - logger.info( - f"Started local container: {container.id[:12]} " - f"(name={container_name}) on network={network_name}, port={host_port}" - ) - - return { - "container_id": container.id, - "container_name": container_name, - "port": host_port, - } - - except docker.errors.ImageNotFound as e: - logger.error(f"Docker image not found: {e}") - return {"error": f"Docker image not found: {deploy_request['image']}"} - except docker.errors.APIError as e: - logger.error(f"Docker API error: {e}") - return {"error": f"Docker error: {str(e)}"} - except Exception as e: - logger.exception(f"Error deploying locally: {e}") - return {"error": str(e)} - - def _build_deploy_request(self, deployment: Deployment) -> dict: - """Build the deployment request for worker agent. - - Supports multiple backends: - - vLLM: High-throughput inference with OpenAI-compatible API - - Ollama: Simple local LLM inference with OpenAI-compatible API - """ - model = deployment.model - - # Determine docker image based on backend - # Priority: deployment extra_params > model docker_image > backend default - deployment_image = ( - deployment.extra_params.get("docker_image") if deployment.extra_params else None - ) - - backend = deployment.backend - - if deployment_image: - image = deployment_image - elif model.docker_image: - image = model.docker_image - elif backend == BackendType.VLLM.value: - image = settings.vllm_default_image - elif backend == BackendType.SGLANG.value: - image = settings.sglang_default_image - elif backend == BackendType.OLLAMA.value: - image = settings.ollama_default_image - else: - logger.warning(f"Unknown backend: {backend}, defaulting to vLLM") - image = settings.vllm_default_image - - # Build command based on backend type - if backend == BackendType.OLLAMA.value: - cmd, env = self._build_ollama_config(model, deployment) - elif backend == BackendType.SGLANG.value: - cmd, env = self._build_sglang_config(model, deployment) - else: - cmd, env = self._build_vllm_config(model, deployment) - - request = { - "deployment_id": deployment.id, - "deployment_name": deployment.name, - "image": image, - "command": cmd, - "model_id": model.model_id, - "gpu_indexes": deployment.gpu_indexes or [0], - "environment": env, - } - - # Note: We don't reuse existing port to avoid conflicts. - # Worker will automatically allocate an available port. - - return request - - def _build_vllm_config( - self, - model: "LLMModel", - deployment: Deployment, - ) -> tuple[list[str], dict[str, str]]: - """Build vLLM container command and environment.""" - cmd = [ - "--model", - model.model_id, - "--host", - "0.0.0.0", - "--port", - "8000", - ] - - # Add default params if any - if model.default_params: - for key, value in model.default_params.items(): - if value is True: - cmd.append(f"--{key}") - elif value is not False and value is not None: - cmd.extend([f"--{key}", str(value)]) - - # Add extra params if any (skip special keys like docker_image, custom_args) - if deployment.extra_params: - skip_keys = {"docker_image", "custom_args"} - for key, value in deployment.extra_params.items(): - if key in skip_keys: - continue - if value is True: - cmd.append(f"--{key}") - elif value is not False and value is not None: - cmd.extend([f"--{key}", str(value)]) - - # Handle custom CLI arguments - custom_args = deployment.extra_params.get("custom_args") - if custom_args and isinstance(custom_args, str): - # Parse custom args: split by newlines and spaces - for line in custom_args.strip().split("\n"): - line = line.strip() - if line and not line.startswith("#"): - # Split each line by spaces for multi-arg support - cmd.extend(line.split()) - - env = { - "HF_HOME": "/root/.cache/huggingface", - } - - return cmd, env - - def _build_sglang_config( - self, - model: "LLMModel", - deployment: Deployment, - ) -> tuple[list[str], dict[str, str]]: - """Build SGLang container command and environment. - - SGLang uses similar command-line arguments to vLLM but with some - differences in parameter names. Unlike vLLM, the sglang Docker image - does not have a proper ENTRYPOINT, so we need to explicitly specify - the launch command. - """ - cmd = [ - "python", - "-m", - "sglang.launch_server", - "--model-path", - model.model_id, - "--host", - "0.0.0.0", - "--port", - "8000", - ] - - # Add default params if any - if model.default_params: - for key, value in model.default_params.items(): - if value is True: - cmd.append(f"--{key}") - elif value is not False and value is not None: - cmd.extend([f"--{key}", str(value)]) - - # Add extra params if any (skip special keys like docker_image, custom_args) - if deployment.extra_params: - skip_keys = {"docker_image", "custom_args"} - for key, value in deployment.extra_params.items(): - if key in skip_keys: - continue - if value is True: - cmd.append(f"--{key}") - elif value is not False and value is not None: - cmd.extend([f"--{key}", str(value)]) - - # Handle custom CLI arguments - custom_args = deployment.extra_params.get("custom_args") - if custom_args and isinstance(custom_args, str): - # Parse custom args: split by newlines and spaces - for line in custom_args.strip().split("\n"): - line = line.strip() - if line and not line.startswith("#"): - # Split each line by spaces for multi-arg support - cmd.extend(line.split()) - - env = { - "HF_HOME": "/root/.cache/huggingface", - } - - return cmd, env - - def _build_ollama_config( - self, - model: "LLMModel", - deployment: Deployment, - ) -> tuple[list[str], dict[str, str]]: - """Build Ollama container command and environment. - - Ollama uses environment variables for configuration instead of - command-line arguments. The model is pulled and run via API after - the container starts. - """ - # Ollama's default entrypoint is "ollama serve" - cmd = ["serve"] - - # Ollama environment configuration - env = { - "OLLAMA_HOST": "0.0.0.0:8000", # Bind to container port 8000 - "OLLAMA_ORIGINS": "*", # Allow CORS from all origins (required for web UI) - "OLLAMA_NUM_PARALLEL": str( - deployment.extra_params.get("num_parallel", 4) if deployment.extra_params else "4" - ), - "OLLAMA_MAX_LOADED_MODELS": str( - deployment.extra_params.get("max_loaded_models", 1) - if deployment.extra_params - else "1" - ), - # GPU settings - "OLLAMA_GPU_OVERHEAD": "0", - "CUDA_VISIBLE_DEVICES": ",".join(str(i) for i in (deployment.gpu_indexes or [0])), - } - - # Add any custom environment variables from extra_params - if deployment.extra_params: - for key, value in deployment.extra_params.items(): - if key.startswith("OLLAMA_") and value is not None: - env[key] = str(value) - - # Handle custom environment variables from custom_args - custom_args = deployment.extra_params.get("custom_args") - if custom_args and isinstance(custom_args, str): - # Parse custom args as environment variables (KEY=VALUE format) - for line in custom_args.strip().split("\n"): - line = line.strip() - if line and not line.startswith("#") and "=" in line: - key, _, value = line.partition("=") - env[key.strip()] = value.strip() - - return cmd, env - - async def stop(self, deployment_id: int) -> None: - """Stop a deployment""" - async with async_session_maker() as db: - result = await db.execute( - select(Deployment) - .where(Deployment.id == deployment_id) - .options(selectinload(Deployment.worker)) - ) - deployment = result.scalar_one_or_none() - - if not deployment or not deployment.container_id: - return - - try: - is_local = self._is_local_worker(deployment.worker.address) - - if is_local: - # Stop locally using Docker directly - await self._stop_local(deployment.container_id) - else: - worker_url = f"http://{deployment.worker.address}/stop" - - async with httpx.AsyncClient(timeout=60.0) as client: - await client.post( - worker_url, json={"container_id": deployment.container_id} - ) - - except Exception as e: - logger.warning(f"Error stopping deployment {deployment_id}: {e}") - - async def _stop_local(self, container_id: str) -> None: - """Stop a container locally.""" - try: - client = docker.from_env() - container = client.containers.get(container_id) - container.stop(timeout=30) - container.remove() - logger.info(f"Stopped local container: {container_id[:12]}") - except docker.errors.NotFound: - logger.warning(f"Container not found: {container_id}") - except Exception as e: - logger.warning(f"Error stopping local container: {e}") - - async def get_logs(self, deployment: Deployment, tail: int = 100) -> str: - """Get logs from a deployment""" - if not deployment.container_id or not deployment.worker: - return "No container running" - - try: - is_local = self._is_local_worker(deployment.worker.address) - - if is_local: - return self._get_logs_local(deployment.container_id, tail) - else: - worker_url = f"http://{deployment.worker.address}/logs" - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get( - worker_url, - params={ - "container_id": deployment.container_id, - "tail": tail, - }, - ) - - if response.status_code == 200: - return response.json().get("logs", "") - else: - return f"Error fetching logs: {response.text}" - - except httpx.ConnectError: - return f"Cannot connect to worker at {deployment.worker.address}" - except Exception as e: - return f"Error: {str(e)}" - - def _get_logs_local(self, container_id: str, tail: int = 100) -> str: - """Get logs from a local container.""" - try: - client = docker.from_env() - container = client.containers.get(container_id) - logs = container.logs(tail=tail, timestamps=True).decode("utf-8") - return logs - except docker.errors.NotFound: - return "Container not found" - except Exception as e: - return f"Error: {str(e)}" diff --git a/backend/app/services/deployer/__init__.py b/backend/app/services/deployer/__init__.py new file mode 100644 index 0000000..5576b80 --- /dev/null +++ b/backend/app/services/deployer/__init__.py @@ -0,0 +1,10 @@ +"""Deployer service package - handles model deployment on workers. + +This package provides the DeployerService class for deploying models +to workers using various backends (vLLM, SGLang, Ollama) and deployment +methods (Docker, native). +""" + +from .service import DeployerService, _update_semantic_router_config_background + +__all__ = ["DeployerService", "_update_semantic_router_config_background"] diff --git a/backend/app/services/deployer/config.py b/backend/app/services/deployer/config.py new file mode 100644 index 0000000..0beeac1 --- /dev/null +++ b/backend/app/services/deployer/config.py @@ -0,0 +1,232 @@ +"""Backend configuration builders for deployment. + +This module contains functions that build deployment configurations +for different inference backends (vLLM, SGLang, Ollama). +""" + +import logging +from typing import TYPE_CHECKING + +from app.config import get_settings +from app.models.llm_model import BackendType + +if TYPE_CHECKING: + from app.models.deployment import Deployment + from app.models.llm_model import LLMModel + +logger = logging.getLogger(__name__) +settings = get_settings() + + +def build_deploy_request(deployment: "Deployment") -> dict: + """Build the deployment request for worker agent. + + Supports multiple backends: + - vLLM: High-throughput inference with OpenAI-compatible API + - Ollama: Simple local LLM inference with OpenAI-compatible API + """ + model = deployment.model + + # Determine docker image based on backend + # Priority: deployment extra_params > model docker_image > backend default + deployment_image = ( + deployment.extra_params.get("docker_image") if deployment.extra_params else None + ) + + backend = deployment.backend + + if deployment_image: + image = deployment_image + elif model.docker_image: + image = model.docker_image + elif backend == BackendType.VLLM.value: + image = settings.vllm_default_image + elif backend == BackendType.SGLANG.value: + image = settings.sglang_default_image + elif backend == BackendType.OLLAMA.value: + image = settings.ollama_default_image + else: + logger.warning(f"Unknown backend: {backend}, defaulting to vLLM") + image = settings.vllm_default_image + + # Build command based on backend type + if backend == BackendType.OLLAMA.value: + cmd, env = build_ollama_config(model, deployment) + elif backend == BackendType.SGLANG.value: + cmd, env = build_sglang_config(model, deployment) + else: + cmd, env = build_vllm_config(model, deployment) + + request = { + "deployment_id": deployment.id, + "deployment_name": deployment.name, + "image": image, + "command": cmd, + "model_id": model.model_id, + "gpu_indexes": deployment.gpu_indexes or [0], + "environment": env, + } + + # Note: We don't reuse existing port to avoid conflicts. + # Worker will automatically allocate an available port. + + return request + + +def build_vllm_config( + model: "LLMModel", + deployment: "Deployment", +) -> tuple[list[str], dict[str, str]]: + """Build vLLM container command and environment.""" + cmd = [ + "--model", + model.model_id, + "--host", + "0.0.0.0", + "--port", + "8000", + ] + + # Add default params if any + if model.default_params: + for key, value in model.default_params.items(): + if value is True: + cmd.append(f"--{key}") + elif value is not False and value is not None: + cmd.extend([f"--{key}", str(value)]) + + # Add extra params if any (skip special keys like docker_image, custom_args) + if deployment.extra_params: + skip_keys = {"docker_image", "custom_args"} + for key, value in deployment.extra_params.items(): + if key in skip_keys: + continue + if value is True: + cmd.append(f"--{key}") + # Auto-add tool-call-parser when enable-auto-tool-choice is enabled + if key == "enable-auto-tool-choice": + cmd.extend(["--tool-call-parser", "hermes"]) + elif value is not False and value is not None: + cmd.extend([f"--{key}", str(value)]) + + # Handle custom CLI arguments + custom_args = deployment.extra_params.get("custom_args") + if custom_args and isinstance(custom_args, str): + # Parse custom args: split by newlines and spaces + for line in custom_args.strip().split("\n"): + line = line.strip() + if line and not line.startswith("#"): + # Split each line by spaces for multi-arg support + cmd.extend(line.split()) + + env = { + "HF_HOME": "/root/.cache/huggingface", + } + + return cmd, env + + +def build_sglang_config( + model: "LLMModel", + deployment: "Deployment", +) -> tuple[list[str], dict[str, str]]: + """Build SGLang container command and environment. + + SGLang uses similar command-line arguments to vLLM but with some + differences in parameter names. Unlike vLLM, the sglang Docker image + does not have a proper ENTRYPOINT, so we need to explicitly specify + the launch command. + """ + cmd = [ + "python", + "-m", + "sglang.launch_server", + "--model-path", + model.model_id, + "--host", + "0.0.0.0", + "--port", + "8000", + ] + + # Add default params if any + if model.default_params: + for key, value in model.default_params.items(): + if value is True: + cmd.append(f"--{key}") + elif value is not False and value is not None: + cmd.extend([f"--{key}", str(value)]) + + # Add extra params if any (skip special keys like docker_image, custom_args) + if deployment.extra_params: + skip_keys = {"docker_image", "custom_args"} + for key, value in deployment.extra_params.items(): + if key in skip_keys: + continue + if value is True: + cmd.append(f"--{key}") + elif value is not False and value is not None: + cmd.extend([f"--{key}", str(value)]) + + # Handle custom CLI arguments + custom_args = deployment.extra_params.get("custom_args") + if custom_args and isinstance(custom_args, str): + # Parse custom args: split by newlines and spaces + for line in custom_args.strip().split("\n"): + line = line.strip() + if line and not line.startswith("#"): + # Split each line by spaces for multi-arg support + cmd.extend(line.split()) + + env = { + "HF_HOME": "/root/.cache/huggingface", + } + + return cmd, env + + +def build_ollama_config( + model: "LLMModel", + deployment: "Deployment", +) -> tuple[list[str], dict[str, str]]: + """Build Ollama container command and environment. + + Ollama uses environment variables for configuration instead of + command-line arguments. The model is pulled and run via API after + the container starts. + """ + # Ollama's default entrypoint is "ollama serve" + cmd = ["serve"] + + # Ollama environment configuration + env = { + "OLLAMA_HOST": "0.0.0.0:8000", # Bind to container port 8000 + "OLLAMA_ORIGINS": "*", # Allow CORS from all origins (required for web UI) + "OLLAMA_NUM_PARALLEL": str( + deployment.extra_params.get("num_parallel", 4) if deployment.extra_params else "4" + ), + "OLLAMA_MAX_LOADED_MODELS": str( + deployment.extra_params.get("max_loaded_models", 1) if deployment.extra_params else "1" + ), + # GPU settings + "OLLAMA_GPU_OVERHEAD": "0", + "CUDA_VISIBLE_DEVICES": ",".join(str(i) for i in (deployment.gpu_indexes or [0])), + } + + # Add any custom environment variables from extra_params + if deployment.extra_params: + for key, value in deployment.extra_params.items(): + if key.startswith("OLLAMA_") and value is not None: + env[key] = str(value) + + # Handle custom environment variables from custom_args + custom_args = deployment.extra_params.get("custom_args") + if custom_args and isinstance(custom_args, str): + # Parse custom args as environment variables (KEY=VALUE format) + for line in custom_args.strip().split("\n"): + line = line.strip() + if line and not line.startswith("#") and "=" in line: + key, _, value = line.partition("=") + env[key.strip()] = value.strip() + + return cmd, env diff --git a/backend/app/services/deployer/docker.py b/backend/app/services/deployer/docker.py new file mode 100644 index 0000000..95eb1fd --- /dev/null +++ b/backend/app/services/deployer/docker.py @@ -0,0 +1,186 @@ +"""Docker deployment operations. + +This module handles Docker-specific deployment operations including +image management, container lifecycle, and local Docker operations. +""" + +import logging +import socket + +import docker + +from app.config import get_settings + +logger = logging.getLogger(__name__) +settings = get_settings() + + +def image_exists_local(image: str) -> bool: + """Check if a Docker image exists locally.""" + try: + client = docker.from_env() + client.images.get(image) + return True + except docker.errors.ImageNotFound: + return False + except docker.errors.APIError as e: + logger.warning(f"Docker API error checking image {image}: {e}") + return False + except Exception as e: + logger.warning(f"Unexpected error checking image {image}: {e}") + return False + + +async def pull_image_local(image: str) -> bool: + """Pull a Docker image locally with progress logging.""" + try: + client = docker.from_env() + logger.info(f"Pulling image: {image}") + + # Pull with progress + for line in client.api.pull(image, stream=True, decode=True): + if "status" in line: + status = line.get("status", "") + progress = line.get("progress", "") + if progress: + logger.debug(f"{status}: {progress}") + + logger.info(f"Image pulled successfully: {image}") + return True + except Exception as e: + logger.error(f"Failed to pull image {image}: {e}") + return False + + +def find_available_port(start_port: int = 8001, end_port: int = 9000) -> int: + """Find an available port on the local machine.""" + for port in range(start_port, end_port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("", port)) + return port + except OSError: + continue + raise RuntimeError(f"No available ports in range {start_port}-{end_port}") + + +async def deploy_local(deploy_request: dict) -> dict: + """Deploy a container locally using Docker. + + This is used for local workers where we don't need to go through + a remote worker agent. + + On Windows Docker Desktop, containers must be on the same network + to communicate. We put model containers on the 'lmstack' network + so the backend can reach them via container name. + """ + try: + client = docker.from_env() + + image = deploy_request["image"] + command = deploy_request.get("command", []) + environment = deploy_request.get("environment", {}) + gpu_indexes = deploy_request.get("gpu_indexes", [0]) + deployment_name = deploy_request.get("deployment_name", "lmstack-deployment") + + # Find available port (still used for external access) + host_port = find_available_port() + + # Container name - used for internal Docker network communication + container_name = f"lmstack-{deployment_name}-{deploy_request['deployment_id']}" + + # Ensure lmstack network exists (for Windows Docker Desktop compatibility) + network_name = "lmstack_lmstack" + try: + client.networks.get(network_name) + except docker.errors.NotFound: + # Try alternative network name (depends on compose project name) + try: + network_name = "lmstack" + client.networks.get(network_name) + except docker.errors.NotFound: + # Create the network if it doesn't exist + logger.info(f"Creating Docker network: {network_name}") + client.networks.create(network_name, driver="bridge") + + # Build GPU device requests + device_requests = [ + docker.types.DeviceRequest( + device_ids=[str(i) for i in gpu_indexes], + capabilities=[["gpu"]], + ) + ] + + # Remove existing container with same name if exists + try: + existing = client.containers.get(container_name) + existing.remove(force=True) + except docker.errors.NotFound: + pass + + # Run container + # Use configurable HF cache directory from settings + hf_cache = settings.hf_cache_dir + container = client.containers.run( + image=image, + command=command, + name=container_name, + detach=True, + ports={"8000/tcp": host_port}, + environment=environment, + device_requests=device_requests, + volumes={ + hf_cache: {"bind": "/root/.cache/huggingface", "mode": "rw"}, + }, + shm_size="16g", # Required for large model inference + restart_policy={"Name": "unless-stopped"}, + network=network_name, # Join lmstack network for Windows compatibility + ) + + logger.info( + f"Started local container: {container.id[:12]} " + f"(name={container_name}) on network={network_name}, port={host_port}" + ) + + return { + "container_id": container.id, + "container_name": container_name, + "port": host_port, + } + + except docker.errors.ImageNotFound as e: + logger.error(f"Docker image not found: {e}") + return {"error": f"Docker image not found: {deploy_request['image']}"} + except docker.errors.APIError as e: + logger.error(f"Docker API error: {e}") + return {"error": f"Docker error: {str(e)}"} + except Exception as e: + logger.exception(f"Error deploying locally: {e}") + return {"error": str(e)} + + +async def stop_local(container_id: str) -> None: + """Stop a container locally.""" + try: + client = docker.from_env() + container = client.containers.get(container_id) + container.stop(timeout=30) + container.remove() + logger.info(f"Stopped local container: {container_id[:12]}") + except docker.errors.NotFound: + logger.warning(f"Container not found: {container_id}") + except Exception as e: + logger.warning(f"Error stopping local container: {e}") + + +def get_logs_local(container_id: str, tail: int = 100) -> str: + """Get logs from a local container.""" + try: + client = docker.from_env() + container = client.containers.get(container_id) + logs = container.logs(tail=tail, timestamps=True).decode("utf-8") + return logs + except docker.errors.NotFound: + return "Container not found" + except Exception as e: + return f"Error: {str(e)}" diff --git a/backend/app/services/deployer/health.py b/backend/app/services/deployer/health.py new file mode 100644 index 0000000..09ed0d6 --- /dev/null +++ b/backend/app/services/deployer/health.py @@ -0,0 +1,336 @@ +"""Health check and API readiness operations. + +This module handles health checking for deployed models, +including waiting for APIs to become ready. +""" + +import asyncio +import json +import logging + +import httpx +from sqlalchemy import select + +from app.models.deployment import Deployment, DeploymentStatus +from app.models.llm_model import BackendType + +logger = logging.getLogger(__name__) + +# Health check configuration constants +HEALTH_CHECK_INTERVAL = 5 # seconds between checks +HEALTH_CHECK_SLOW_THRESHOLD = 600 # seconds before showing "slow loading" message (10 min) +HEALTH_CHECK_REQUEST_TIMEOUT = 10 # timeout for each health check request + + +async def wait_for_ollama_ready( + worker_address: str, + port: int, + timeout: int = 60, + container_name: str | None = None, +) -> bool: + """Wait for Ollama API to be available. + + Args: + worker_address: Worker address (host:port) + port: Ollama container port + timeout: Maximum wait time in seconds + container_name: Container name for Docker network (Windows compatibility) + + Returns: + True if Ollama is ready, False on timeout + """ + # Ollama is configured to use port 8000 (OLLAMA_HOST=0.0.0.0:8000) + if container_name: + api_url = f"http://{container_name}:8000/api/tags" + else: + worker_ip = worker_address.split(":")[0] + api_url = f"http://{worker_ip}:{port}/api/tags" + + logger.info(f"Waiting for Ollama API at {api_url}") + + elapsed = 0 + check_interval = 2 + + async with httpx.AsyncClient(timeout=10.0) as client: + while elapsed < timeout: + try: + response = await client.get(api_url) + if response.status_code == 200: + logger.info(f"Ollama API ready after {elapsed}s") + return True + except httpx.ConnectError: + logger.debug(f"Ollama not ready yet ({elapsed}s)") + except Exception as e: + logger.debug(f"Ollama check error: {e}") + + await asyncio.sleep(check_interval) + elapsed += check_interval + + logger.error(f"Ollama API not ready after {timeout}s") + return False + + +async def ollama_pull_model( + worker_address: str, + port: int, + model_id: str, + container_name: str | None = None, +) -> bool: + """Pull a model using Ollama API. + + Ollama requires models to be pulled before they can be used. + This method calls the /api/pull endpoint and waits for completion. + """ + # Ollama is configured to use port 8000 (OLLAMA_HOST=0.0.0.0:8000) + if container_name: + api_url = f"http://{container_name}:8000/api/pull" + else: + worker_ip = worker_address.split(":")[0] + api_url = f"http://{worker_ip}:{port}/api/pull" + + logger.info(f"Pulling Ollama model: {model_id}") + + try: + async with httpx.AsyncClient(timeout=1800.0) as client: # 30 min timeout + # Ollama pull is a streaming endpoint + async with client.stream( + "POST", + api_url, + json={"name": model_id, "stream": True}, + ) as response: + if response.status_code != 200: + logger.error(f"Ollama pull failed: {response.status_code}") + return False + + # Process the streaming response + async for line in response.aiter_lines(): + if line: + try: + data = json.loads(line) + status = data.get("status", "") + + # Log progress + if "completed" in data and "total" in data: + pct = int(data["completed"] / data["total"] * 100) + logger.debug(f"Ollama pull: {status} ({pct}%)") + elif status: + logger.debug(f"Ollama pull: {status}") + + # Check for completion + if status == "success": + logger.info(f"Ollama model {model_id} pulled successfully") + return True + + except Exception as e: + logger.debug(f"Error parsing Ollama response: {e}") + + logger.info(f"Ollama model {model_id} pull completed") + return True + + except httpx.ConnectError: + logger.error(f"Cannot connect to Ollama at {api_url}") + return False + except Exception as e: + logger.error(f"Ollama pull error: {e}") + return False + + +async def wait_for_api_ready( + worker_address: str, + port: int, + deployment_id: int, + db, + backend: str = BackendType.VLLM.value, + container_name: str | None = None, +) -> bool | None: + """ + Poll the OpenAI API endpoint until it's ready or cancelled. + + Args: + worker_address: Worker address (host:port) + port: Host port for the model API + deployment_id: Deployment ID for status updates + db: Database session + backend: Backend type (vllm, ollama, etc.) + container_name: Container name for local Docker network communication. + If set, uses container_name:8000 instead of worker_ip:port. + This is needed for Windows Docker Desktop compatibility. + + Returns: + True: API is ready + None: Cancelled (user stopped deployment) + """ + # For local deployments with container_name, use Docker internal networking + # All backends (vLLM, SGLang, Ollama) are configured to use port 8000 + if container_name: + api_base_url = f"http://{container_name}:8000" + logger.info(f"Using Docker network for API: {api_base_url}") + else: + worker_ip = worker_address.split(":")[0] + api_base_url = f"http://{worker_ip}:{port}" + + # Both vLLM and Ollama support OpenAI-compatible /v1/models endpoint + health_endpoint = f"{api_base_url}/v1/models" + + # For Ollama, we can also check /api/tags as a fallback + is_ollama = backend == BackendType.OLLAMA.value + + elapsed = 0 + check_count = 0 + shown_slow_message = False + + logger.info(f"Waiting for API to be ready at {health_endpoint} (backend={backend})") + + async with httpx.AsyncClient(timeout=HEALTH_CHECK_REQUEST_TIMEOUT) as client: + while True: # Wait indefinitely until ready or cancelled + check_count += 1 + + # Check if deployment was cancelled + try: + result = await db.execute(select(Deployment).where(Deployment.id == deployment_id)) + deployment = result.scalar_one_or_none() + if deployment and deployment.status in [ + DeploymentStatus.STOPPED.value, + DeploymentStatus.STOPPING.value, + ]: + logger.info(f"Deployment {deployment_id} was cancelled") + return None + except Exception as e: + logger.debug(f"Error checking deployment status: {e}") + + try: + response = await client.get(health_endpoint) + + if response.status_code == 200: + data = response.json() + # vLLM returns {"object": "list", "data": [...]} + # Ollama returns {"object": "list", "data": [...]} (OpenAI compat) + if data.get("data") and len(data["data"]) > 0: + logger.info( + f"API ready at {health_endpoint} after {elapsed}s " + f"({check_count} checks)" + ) + return True + + # For Ollama, also try the native /api/tags endpoint + if is_ollama and response.status_code != 200: + ollama_endpoint = f"{api_base_url}/api/tags" + ollama_response = await client.get(ollama_endpoint) + if ollama_response.status_code == 200: + ollama_data = ollama_response.json() + if ollama_data.get("models") and len(ollama_data["models"]) > 0: + logger.info(f"Ollama API ready at {ollama_endpoint} after {elapsed}s") + return True + + logger.debug(f"Health check {check_count}: status={response.status_code}") + + except httpx.ConnectError: + # Container not ready yet, this is expected during startup + logger.debug(f"Health check {check_count}: connection refused") + except httpx.ReadTimeout: + logger.debug(f"Health check {check_count}: read timeout") + except Exception as e: + logger.debug(f"Health check {check_count}: {type(e).__name__}: {e}") + + # Update status message periodically + if check_count % 6 == 0: # Every 30 seconds + try: + result = await db.execute( + select(Deployment).where(Deployment.id == deployment_id) + ) + deployment = result.scalar_one_or_none() + if deployment and deployment.status == DeploymentStatus.STARTING.value: + mins = elapsed // 60 + secs = elapsed % 60 + time_str = f"{mins}m {secs}s" if mins > 0 else f"{secs}s" + + # Show patience message after threshold + if elapsed >= HEALTH_CHECK_SLOW_THRESHOLD and not shown_slow_message: + deployment.status_message = ( + f"Loading model... ({time_str}) - " + "Large model or slow network detected. Please be patient." + ) + shown_slow_message = True + elif shown_slow_message: + deployment.status_message = ( + f"Loading model... ({time_str}) - Please be patient." + ) + else: + deployment.status_message = ( + f"Loading model into GPU memory... ({time_str})" + ) + await db.commit() + except Exception as e: + logger.debug(f"Error updating deployment status message: {e}") + + await asyncio.sleep(HEALTH_CHECK_INTERVAL) + elapsed += HEALTH_CHECK_INTERVAL + + +async def wait_for_native_api_ready( + worker_address: str, + port: int, + deployment_id: int, + db, + backend: str = BackendType.OLLAMA.value, + timeout: int = 300, +) -> bool | None: + """Wait for native API to be ready. + + For native deployments, we check via worker agent because + Ollama only listens on localhost. + + Returns: + True: API is ready + False: Timeout + None: Cancelled + """ + # For native deployments, check via worker agent's health endpoint + worker_health_url = f"http://{worker_address}/native/health" + + elapsed = 0 + check_interval = 5 + + async with httpx.AsyncClient(timeout=10.0) as client: + while elapsed < timeout: + # Check if cancelled + try: + result = await db.execute(select(Deployment).where(Deployment.id == deployment_id)) + deployment = result.scalar_one_or_none() + if deployment and deployment.status in [ + DeploymentStatus.STOPPED.value, + DeploymentStatus.STOPPING.value, + ]: + return None + except Exception: + pass + + try: + response = await client.get( + worker_health_url, params={"backend": backend, "port": port} + ) + if response.status_code == 200: + data = response.json() + if data.get("ready"): + logger.info("Native API ready (checked via worker)") + return True + except Exception: + pass + + await asyncio.sleep(check_interval) + elapsed += check_interval + + # Update status periodically + if elapsed % 30 == 0: + try: + result = await db.execute( + select(Deployment).where(Deployment.id == deployment_id) + ) + deployment = result.scalar_one_or_none() + if deployment: + deployment.status_message = f"Loading model... ({elapsed}s)" + await db.commit() + except Exception: + pass + + return False diff --git a/backend/app/services/deployer/native.py b/backend/app/services/deployer/native.py new file mode 100644 index 0000000..4127b41 --- /dev/null +++ b/backend/app/services/deployer/native.py @@ -0,0 +1,112 @@ +"""Native Mac deployment operations. + +This module handles native deployment operations for macOS, +including Ollama, MLX, and llama.cpp backends. +""" + +import asyncio +import logging + +import httpx + +from app.models.deployment import Deployment, DeploymentStatus +from app.models.llm_model import BackendType + +logger = logging.getLogger(__name__) + + +async def deploy_native(deployment: Deployment, db) -> dict: + """Deploy using native backend (Mac without Docker). + + Supports Ollama, MLX, and llama.cpp backends on macOS. + """ + # Import here to avoid circular imports + from app.services.deployer.health import wait_for_native_api_ready + + worker = deployment.worker + model = deployment.model + backend = deployment.backend + + # Validate backend is supported + available_backends = worker.available_backends + if backend not in available_backends: + return { + "error": f"Backend '{backend}' not available on this worker. " + f"Available backends: {', '.join(available_backends)}" + } + + try: + worker_url = f"http://{worker.effective_address}/native/deploy" + + deploy_request = { + "deployment_id": deployment.id, + "deployment_name": deployment.name, + "model_id": model.model_id, + "backend": backend, + "port": 0, # Auto-assign + "extra_params": deployment.extra_params, + } + + deployment.status_message = f"Starting {backend} deployment..." + await db.commit() + + async with httpx.AsyncClient(timeout=600.0) as client: + response = await client.post(worker_url, json=deploy_request) + + if response.status_code != 200: + error_detail = response.json().get("detail", response.text) + return {"error": f"Native deployment failed: {error_detail}"} + + result = response.json() + deployment.port = result.get("port") + # Use process_id as container_id for native deployments + deployment.container_id = result.get("process_id") + + # Wait for API to be ready + deployment.status_message = "Waiting for model to be ready..." + await db.commit() + + # For Ollama on native, the API is at port 11434 + if backend == BackendType.OLLAMA.value: + api_port = 11434 + else: + api_port = deployment.port + + api_ready = await wait_for_native_api_ready( + worker.effective_address, + api_port, + deployment.id, + db, + backend=backend, + ) + + if api_ready is None: + return {} # Cancelled + elif api_ready: + deployment.status = DeploymentStatus.RUNNING.value + deployment.status_message = "Model ready" + asyncio.create_task(_update_semantic_router_config_background()) + else: + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = "API failed to start" + + await db.commit() + return {} + + except httpx.ConnectError as e: + return {"error": f"Cannot connect to worker: {e}"} + except Exception as e: + logger.exception(f"Native deployment error: {e}") + return {"error": str(e)} + + +async def _update_semantic_router_config_background(): + """Background task to update semantic router config after deployment changes.""" + try: + from app.api.apps.deployment import update_semantic_router_config_if_deployed + from app.database import async_session_maker + + async with async_session_maker() as db: + await update_semantic_router_config_if_deployed(db) + except Exception as e: + logger.debug(f"Failed to update semantic router config: {e}") diff --git a/backend/app/services/deployer/service.py b/backend/app/services/deployer/service.py new file mode 100644 index 0000000..33b1aa3 --- /dev/null +++ b/backend/app/services/deployer/service.py @@ -0,0 +1,363 @@ +"""DeployerService - Main orchestration for model deployment. + +This module contains the main DeployerService class that orchestrates +model deployment operations across different backends and workers. +""" + +import asyncio +import logging + +import httpx +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from app.database import async_session_maker +from app.models.deployment import Deployment, DeploymentStatus +from app.models.llm_model import BackendType +from app.models.worker import OSType + +from .config import build_deploy_request +from .docker import deploy_local, get_logs_local, image_exists_local, pull_image_local, stop_local +from .health import ( + HEALTH_CHECK_INTERVAL, + HEALTH_CHECK_REQUEST_TIMEOUT, + HEALTH_CHECK_SLOW_THRESHOLD, + ollama_pull_model, + wait_for_api_ready, + wait_for_ollama_ready, +) +from .native import deploy_native + +logger = logging.getLogger(__name__) + + +async def _update_semantic_router_config_background(): + """Background task to update semantic router config after deployment changes.""" + try: + from app.api.apps.deployment import update_semantic_router_config_if_deployed + + async with async_session_maker() as db: + await update_semantic_router_config_if_deployed(db) + except Exception as e: + logger.debug(f"Failed to update semantic router config: {e}") + + +class DeployerService: + """Service for deploying models to workers""" + + # Health check configuration (exposed as class attributes for backwards compatibility) + HEALTH_CHECK_INTERVAL = HEALTH_CHECK_INTERVAL + HEALTH_CHECK_SLOW_THRESHOLD = HEALTH_CHECK_SLOW_THRESHOLD + HEALTH_CHECK_REQUEST_TIMEOUT = HEALTH_CHECK_REQUEST_TIMEOUT + + async def deploy(self, deployment_id: int) -> None: + """Deploy a model to a worker""" + async with async_session_maker() as db: + result = await db.execute( + select(Deployment) + .where(Deployment.id == deployment_id) + .options( + selectinload(Deployment.worker), + selectinload(Deployment.model), + ) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + logger.error(f"Deployment {deployment_id} not found") + return + + try: + # Update status to starting + deployment.status = DeploymentStatus.STARTING.value + deployment.status_message = "Sending deployment request to worker..." + await db.commit() + + # Check if worker supports Docker or needs native deployment + worker = deployment.worker + backend = deployment.backend + + # Mac with Ollama should always use native deployment (use local Ollama) + # Mac without Docker should also use native deployment + is_mac = worker.os_type == OSType.DARWIN.value + is_mac_native = is_mac and ( + backend == BackendType.OLLAMA.value or not worker.supports_docker + ) + + # Use native deployment for Mac + if is_mac_native: + result = await deploy_native(deployment, db) + if result.get("error"): + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = result["error"] + await db.commit() + return + + # Build deployment request + deploy_request = build_deploy_request(deployment) + + # Check if this is a local worker + is_local = self._is_local_worker(deployment.worker.address) + + if is_local: + # Check if image needs to be pulled + image = deploy_request["image"] + if not image_exists_local(image): + deployment.status_message = f"Pulling image: {image}..." + await db.commit() + + pull_success = await pull_image_local(image) + if not pull_success: + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = f"Failed to pull image: {image}" + await db.commit() + return + + deployment.status_message = "Starting container..." + await db.commit() + + # Deploy locally using Docker directly + result = await deploy_local(deploy_request) + if result.get("error"): + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = result["error"] + await db.commit() + return + deployment.container_id = result.get("container_id") + deployment.port = result.get("port") + # Store container_name for internal Docker network communication + local_container_name = result.get("container_name") + deployment.container_name = local_container_name + else: + local_container_name = None # Remote workers use IP:port + # Send to remote worker agent + worker_url = f"http://{deployment.worker.address}/deploy" + progress_url = ( + f"http://{deployment.worker.address}/pull-progress/{deployment.id}" + ) + + # Start deployment request and poll for progress + async with httpx.AsyncClient(timeout=300.0) as client: + # Start the deployment in a task + deploy_task = asyncio.create_task( + client.post(worker_url, json=deploy_request) + ) + + # Poll for progress while waiting + while not deploy_task.done(): + try: + progress_resp = await client.get(progress_url, timeout=5.0) + if progress_resp.status_code == 200: + progress_data = progress_resp.json() + status = progress_data.get("status", "") + image = progress_data.get("image", "") + progress = progress_data.get("progress", 0) + + if status == "pulling": + deployment.status_message = ( + f"Pulling image {image}... ({progress}%)" + ) + await db.commit() + elif status == "completed": + deployment.status_message = ( + "Image pulled, starting container..." + ) + await db.commit() + elif status == "starting": + deployment.status_message = "Starting container..." + await db.commit() + except Exception: + pass # Progress polling is best-effort + + await asyncio.sleep(2) + + response = await deploy_task + + if response.status_code != 200: + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = f"Worker returned error: {response.text}" + await db.commit() + return + + result_data = response.json() + deployment.container_id = result_data.get("container_id") + deployment.port = result_data.get("port") + + # Container started, now waiting for model to load + deployment.status = DeploymentStatus.STARTING.value + deployment.status_message = "Downloading model and Loading model into GPU memory..." + await db.commit() + + # For Ollama, we need to pull the model first + if deployment.backend == BackendType.OLLAMA.value: + deployment.status_message = "Waiting for Ollama container to start..." + await db.commit() + + # Wait for Ollama API to be available before pulling + ollama_ready = await wait_for_ollama_ready( + deployment.worker.address, + deployment.port, + container_name=local_container_name, + ) + if not ollama_ready: + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = "Ollama container failed to start" + await db.commit() + return + + deployment.status_message = "Pulling model with Ollama..." + await db.commit() + + pull_success = await ollama_pull_model( + deployment.worker.address, + deployment.port, + deployment.model.model_id, + container_name=local_container_name, + ) + if not pull_success: + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = "Failed to pull model with Ollama" + await db.commit() + return + + deployment.status_message = "Model pulled, waiting for API..." + await db.commit() + + # Wait for the API endpoint to become ready + api_ready = await wait_for_api_ready( + deployment.worker.address, + deployment.port, + deployment_id, + db, + backend=deployment.backend, + container_name=local_container_name, + ) + + # Refresh deployment object after health check updates + await db.refresh(deployment) + + if api_ready is None: + # Deployment was cancelled, don't update status + logger.info(f"Deployment {deployment_id} cancelled during startup") + return + else: + # api_ready is True, model is ready + deployment.status = DeploymentStatus.RUNNING.value + deployment.status_message = "Model ready" + + # Update semantic router config if deployed + asyncio.create_task(_update_semantic_router_config_background()) + + except httpx.ConnectError: + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = ( + f"Cannot connect to worker at {deployment.worker.address}" + ) + except Exception as e: + logger.exception(f"Error deploying {deployment_id}") + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = str(e) + + await db.commit() + + def _is_local_worker(self, address: str) -> bool: + """Check if the worker address refers to the local machine.""" + if not address: + return False + host = address.split(":")[0].lower() + return host in ("localhost", "127.0.0.1", "local") + + async def stop(self, deployment_id: int) -> None: + """Stop a deployment""" + async with async_session_maker() as db: + result = await db.execute( + select(Deployment) + .where(Deployment.id == deployment_id) + .options(selectinload(Deployment.worker)) + ) + deployment = result.scalar_one_or_none() + + if not deployment or not deployment.container_id: + return + + try: + worker = deployment.worker + + # Check if this is a native deployment (process_id starts with "native-") + is_native = deployment.container_id.startswith("native-") + + if is_native: + # Stop native process + worker_url = f"http://{worker.effective_address}/native/stop" + async with httpx.AsyncClient(timeout=60.0) as client: + await client.post( + worker_url, + json={"process_id": deployment.container_id}, + ) + else: + # Docker-based deployment + is_local = self._is_local_worker(worker.address) + + if is_local: + # Stop locally using Docker directly + await stop_local(deployment.container_id) + else: + worker_url = f"http://{worker.address}/stop" + + async with httpx.AsyncClient(timeout=60.0) as client: + await client.post( + worker_url, json={"container_id": deployment.container_id} + ) + + except Exception as e: + logger.warning(f"Error stopping deployment {deployment_id}: {e}") + + async def get_logs(self, deployment: Deployment, tail: int = 100) -> str: + """Get logs from a deployment""" + if not deployment.container_id or not deployment.worker: + return "No container running" + + try: + worker = deployment.worker + + # Check if this is a native deployment + is_native = deployment.container_id.startswith("native-") + + if is_native: + # Get logs from native process + worker_url = ( + f"http://{worker.effective_address}/native/logs/{deployment.container_id}" + ) + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(worker_url, params={"tail": tail}) + if response.status_code == 200: + return response.json().get("logs", "") + else: + return f"Error fetching logs: {response.text}" + else: + # Docker-based deployment + is_local = self._is_local_worker(worker.address) + + if is_local: + return get_logs_local(deployment.container_id, tail) + else: + worker_url = f"http://{worker.address}/logs" + + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get( + worker_url, + params={ + "container_id": deployment.container_id, + "tail": tail, + }, + ) + + if response.status_code == 200: + return response.json().get("logs", "") + else: + return f"Error fetching logs: {response.text}" + + except httpx.ConnectError: + return f"Cannot connect to worker at {deployment.worker.address}" + except Exception as e: + return f"Error: {str(e)}" diff --git a/backend/app/services/mcp/agent.py b/backend/app/services/mcp/agent.py index 016c541..be5d491 100644 --- a/backend/app/services/mcp/agent.py +++ b/backend/app/services/mcp/agent.py @@ -210,6 +210,7 @@ def __init__( mcp_api_url: str | None = None, mcp_api_token: str | None = None, max_iterations: int = 20, + supports_tool_calling: bool = True, ): """ Initialize the Agent Service. @@ -222,12 +223,14 @@ def __init__( mcp_api_url: LMStack API URL for MCP server. mcp_api_token: LMStack API token for MCP server. max_iterations: Maximum tool call iterations. + supports_tool_calling: Whether the LLM supports tool/function calling. """ self.llm_client = llm_client self.llm_model = llm_model self.llm_base_url = llm_base_url self.llm_api_key = llm_api_key self.max_iterations = max_iterations + self.supports_tool_calling = supports_tool_calling # MCP configuration self.mcp_api_url = mcp_api_url @@ -537,6 +540,15 @@ async def chat( self.conversation.append(user_msg) try: + # Warn if tool calling is not supported + if not self.supports_tool_calling: + yield AgentEvent( + type=EventType.MESSAGE, + content="**Note:** This deployment does not have tool calling enabled. " + "The agent will respond as a regular chat model without executing tools. " + "To enable Agent features, redeploy the model with 'Enable Tool Calling' option in Advanced Settings.\n\n", + ) + # Initial thinking event yield AgentEvent(type=EventType.THINKING, content="Analyzing your request...") @@ -555,15 +567,18 @@ async def chat( # For continuation, don't add the user message again messages = messages[:-1] - tools_schema = self._build_tools_schema() + # Only include tools if tool calling is supported + tools_schema = self._build_tools_schema() if self.supports_tool_calling else None try: # Use streaming for better UX + # Note: Don't pass tool_choice="auto" as it requires vLLM to have + # --enable-auto-tool-choice and --tool-call-parser flags enabled. + # Without tool_choice, the API uses default behavior. stream = await self.llm_client.chat.completions.create( model=self.llm_model, messages=messages, tools=tools_schema if tools_schema else None, - tool_choice="auto" if tools_schema else None, stream=True, ) diff --git a/backend/app/services/tuning/__init__.py b/backend/app/services/tuning/__init__.py new file mode 100644 index 0000000..9a454ce --- /dev/null +++ b/backend/app/services/tuning/__init__.py @@ -0,0 +1,19 @@ +"""Auto-Tuning Agent Service package. + +A true LLM-driven agent that: +1. Uses an LLM to reason about configurations +2. Actually deploys models with different configs +3. Runs real benchmarks against deployed endpoints +4. Analyzes results and decides next steps +""" + +from .agent import run_tuning_agent +from .executor import AgentToolExecutor +from .tools import AGENT_SYSTEM_PROMPT, get_agent_tools + +__all__ = [ + "run_tuning_agent", + "AgentToolExecutor", + "AGENT_SYSTEM_PROMPT", + "get_agent_tools", +] diff --git a/backend/app/services/tuning/agent.py b/backend/app/services/tuning/agent.py new file mode 100644 index 0000000..ea300ad --- /dev/null +++ b/backend/app/services/tuning/agent.py @@ -0,0 +1,324 @@ +"""Main agent runner for auto-tuning. + +This module contains the main run_tuning_agent function that orchestrates +the auto-tuning process using an LLM-driven agent. +""" + +import json +import logging +from datetime import UTC, datetime + +from openai import AsyncOpenAI +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from app.config import get_settings +from app.database import async_session_maker +from app.models.tuning import TuningJob, TuningJobStatus + +from .executor import AgentToolExecutor +from .tools import AGENT_SYSTEM_PROMPT, get_agent_tools + +logger = logging.getLogger(__name__) + + +async def run_tuning_agent(job_id: int, llm_config: dict | None = None): + """Run the Auto-Tuning Agent for a job + + Args: + job_id: The tuning job ID + llm_config: Optional LLM configuration from chat panel: + - deployment_id: Use a local deployment + - base_url: Custom endpoint URL + - api_key: API key for the endpoint + - model: Model name + """ + settings = get_settings() + + async with async_session_maker() as db: + # Load job with relationships + result = await db.execute( + select(TuningJob) + .where(TuningJob.id == job_id) + .options( + selectinload(TuningJob.model), + selectinload(TuningJob.worker), + ) + ) + job = result.scalar_one_or_none() + + if not job: + logger.error(f"Tuning job {job_id} not found") + return + + # Initialize tool executor + executor = AgentToolExecutor(db, job) + + try: + # Determine LLM configuration (priority: llm_config > settings > auto-detect) + api_key = None + base_url = None + model_name = "gpt-4o" + + if llm_config: + # Use config from chat panel + if llm_config.get("deployment_id"): + # Use specified local deployment + from app.models.deployment import Deployment + + deploy_result = await db.execute( + select(Deployment) + .where(Deployment.id == llm_config["deployment_id"]) + .options(selectinload(Deployment.worker), selectinload(Deployment.model)) + ) + deployment = deploy_result.scalar_one_or_none() + + if deployment and deployment.worker: + worker_ip = deployment.worker.address.split(":")[0] + base_url = f"http://{worker_ip}:{deployment.port}/v1" + api_key = "dummy" + model_name = deployment.model.model_id if deployment.model else model_name + logger.info( + f"Using specified deployment as agent LLM: {base_url} ({model_name})" + ) + else: + job.status = TuningJobStatus.FAILED.value + job.status_message = ( + f"Deployment {llm_config['deployment_id']} not found or not running" + ) + await db.commit() + return + elif llm_config.get("base_url"): + # Use custom endpoint + base_url = llm_config["base_url"] + api_key = llm_config.get("api_key") or "dummy" + model_name = llm_config.get("model") or model_name + logger.info(f"Using custom endpoint as agent LLM: {base_url} ({model_name})") + + # Fall back to settings if no llm_config + if not api_key: + api_key = settings.openai_api_key + base_url = settings.openai_base_url + model_name = settings.openai_model or model_name + + # If still no API key, try to find any running deployment + if not api_key: + from app.models.deployment import Deployment, DeploymentStatus + + deploy_result = await db.execute( + select(Deployment) + .where(Deployment.status == DeploymentStatus.RUNNING.value) + .options(selectinload(Deployment.worker), selectinload(Deployment.model)) + .limit(1) + ) + local_deployment = deploy_result.scalar_one_or_none() + + if local_deployment and local_deployment.worker: + worker_ip = local_deployment.worker.address.split(":")[0] + base_url = f"http://{worker_ip}:{local_deployment.port}/v1" + api_key = "dummy" + model_name = ( + local_deployment.model.model_id if local_deployment.model else model_name + ) + logger.info( + f"Auto-detected local deployment as agent LLM: {base_url} ({model_name})" + ) + else: + job.status = TuningJobStatus.FAILED.value + job.status_message = ( + "No LLM configured for Auto-Tuning Agent. " + "Please select a model in the chat panel, or deploy a model first." + ) + await db.commit() + return + + # Initialize OpenAI client (supports OpenAI-compatible endpoints) + client = AsyncOpenAI(api_key=api_key, base_url=base_url or "https://api.openai.com/v1") + + # Build initial user message with explicit steps + user_message = f"""Find the optimal deployment configuration for {job.model.name} on {job.worker.name}. +Optimization target: {job.optimization_target} +Model ID: {job.model_id}, Worker ID: {job.worker_id} + +REQUIRED STEPS (you must complete all of these): +1. Call get_hardware_info(worker_id={job.worker_id}) to check GPU specs +2. Call query_knowledge_base() to check historical data +3. Deploy the model with deploy_model() and wait for it +4. Run run_benchmark() to test performance +5. Stop the deployment and optionally test other configurations +6. Call finish_tuning() with best_config and all benchmark results + +Start with Step 1: get_hardware_info""" + + messages = [ + {"role": "system", "content": AGENT_SYSTEM_PROMPT}, + {"role": "user", "content": user_message}, + ] + + # Initialize conversation log for UI display + conversation_log = [ + { + "role": "user", + "content": user_message, + "timestamp": datetime.now(UTC).isoformat(), + } + ] + + # Helper to save conversation log + async def save_log(): + job.conversation_log = conversation_log + await db.commit() + + # Update job status + job.status = TuningJobStatus.ANALYZING.value + job.status_message = "Agent is analyzing the environment..." + job.conversation_log = conversation_log + await db.commit() + + # Agent loop - limit iterations to prevent infinite loops + max_iterations = 15 + iteration = 0 + + while iteration < max_iterations: + iteration += 1 + + # Check if cancelled + await db.refresh(job) + if job.status == TuningJobStatus.CANCELLED.value: + logger.info(f"Job {job_id} was cancelled") + await executor.cleanup() + return + + # Call LLM + logger.info(f"Agent iteration {iteration}, calling LLM with model: {model_name}...") + + # Force tool calls if essential steps not completed + # Use "required" to ensure tool is called when needed + if not executor.hardware_checked or ( + not executor.benchmark_results and iteration < 10 + ): + tool_choice = "required" + else: + tool_choice = "auto" + + response = await client.chat.completions.create( + model=model_name, + messages=messages, + tools=get_agent_tools(), + tool_choice=tool_choice, + max_tokens=4096, + ) + + assistant_message = response.choices[0].message + messages.append(assistant_message.model_dump(exclude_none=True)) + + # Add assistant message to conversation log + log_entry = { + "role": "assistant", + "content": assistant_message.content or "", + "timestamp": datetime.now(UTC).isoformat(), + } + if assistant_message.tool_calls: + log_entry["tool_calls"] = [ + { + "id": tc.id, + "name": tc.function.name, + "arguments": tc.function.arguments, + } + for tc in assistant_message.tool_calls + ] + conversation_log.append(log_entry) + await save_log() + + # Check if no tool calls - prompt the agent to take action + if not assistant_message.tool_calls: + logger.warning(f"Agent responded without tool calls at iteration {iteration}") + # Build a context-aware prompt based on current state + if not executor.hardware_checked: + prompt_message = ( + f"You must call get_hardware_info(worker_id={job.worker_id}) first " + "to check the GPU environment before proceeding." + ) + elif not executor.benchmark_results: + prompt_message = ( + "You must run at least one benchmark before finishing. " + f"Call deploy_model(model_id={job.model_id}, worker_id={job.worker_id}, engine='vllm') " + "to deploy the model, then run run_benchmark() after it's ready." + ) + else: + prompt_message = ( + "You have benchmark results. Call finish_tuning() with the best configuration " + "to complete the tuning process." + ) + messages.append({"role": "user", "content": prompt_message}) + conversation_log.append( + { + "role": "user", + "content": prompt_message, + "timestamp": datetime.now(UTC).isoformat(), + } + ) + await save_log() + continue # Continue the loop to get tool calls + + # Execute tool calls + for tool_call in assistant_message.tool_calls: + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + + logger.info(f"Executing tool: {tool_name}({tool_args})") + + # Update job progress + job.status_message = f"Executing: {tool_name}" + job.progress = { + "step": iteration, + "total_steps": max_iterations, + "step_name": tool_name, + "step_description": f"Executing {tool_name} with args: {tool_args}", + "configs_tested": 0, + "configs_total": 0, + } + await db.commit() + + # Execute tool + result = await executor.execute(tool_name, tool_args) + + # Add tool result to conversation log + conversation_log.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "name": tool_name, + "content": result, + "timestamp": datetime.now(UTC).isoformat(), + } + ) + await save_log() + + # Check if this was a termination tool + if tool_name == "finish_tuning": + logger.info(f"Agent completed tuning for job {job_id}") + return + if tool_name == "abort_tuning": + logger.info(f"Agent aborted tuning for job {job_id}") + return + + # Add tool result to messages + messages.append( + {"role": "tool", "tool_call_id": tool_call.id, "content": result} + ) + + # If we reached max iterations without finishing + job.status = TuningJobStatus.FAILED.value + job.status_message = "Agent reached maximum iterations without completing" + await db.commit() + + except Exception as e: + logger.exception(f"Agent error for job {job_id}: {e}") + job.status = TuningJobStatus.FAILED.value + job.status_message = f"Agent error: {str(e)}" + await db.commit() + + finally: + # Cleanup any test deployments + await executor.cleanup() diff --git a/backend/app/services/tuning/benchmark.py b/backend/app/services/tuning/benchmark.py new file mode 100644 index 0000000..47b6708 --- /dev/null +++ b/backend/app/services/tuning/benchmark.py @@ -0,0 +1,163 @@ +"""HTTP benchmark implementation for tuning agent. + +This module contains the benchmark functions used to measure +model performance during auto-tuning. +""" + +import asyncio +import json +import logging +import time + +import httpx + +logger = logging.getLogger(__name__) + + +async def run_http_benchmark( + base_url: str, + model_name: str = "default", + num_requests: int = 20, + concurrency: int = 4, + input_tokens: int = 128, + output_tokens: int = 64, +) -> dict: + """Run actual HTTP benchmark against an OpenAI-compatible endpoint. + + Args: + base_url: Base URL of the OpenAI-compatible API (e.g., "http://localhost:8000/v1") + model_name: Model name to use for API calls + num_requests: Number of requests to send + concurrency: Number of concurrent requests + input_tokens: Approximate input token count + output_tokens: Max output tokens + + Returns: + Dictionary with benchmark results including throughput and latency metrics + """ + # Generate test prompt with approximate token count + test_prompt = "Write a detailed explanation about " + " ".join( + ["artificial intelligence"] * (input_tokens // 3) + ) + + results = [] + errors = 0 + + semaphore = asyncio.Semaphore(concurrency) + + async def make_request(client: httpx.AsyncClient) -> dict | None: + nonlocal errors + async with semaphore: + start_time = time.perf_counter() + first_token_time = None + token_times = [] + total_tokens = 0 + + try: + async with client.stream( + "POST", + f"{base_url}/chat/completions", + json={ + "model": model_name, + "messages": [{"role": "user", "content": test_prompt}], + "max_tokens": output_tokens, + "stream": True, + }, + timeout=60.0, + ) as response: + if response.status_code != 200: + errors += 1 + return None + + async for line in response.aiter_lines(): + if line.startswith("data: "): + data = line[6:] + if data == "[DONE]": + break + try: + chunk = json.loads(data) + content = ( + chunk.get("choices", [{}])[0] + .get("delta", {}) + .get("content", "") + ) + if content: + current_time = time.perf_counter() + if first_token_time is None: + first_token_time = current_time + token_times.append(current_time) + total_tokens += 1 + except json.JSONDecodeError: + pass + + end_time = time.perf_counter() + + if first_token_time and total_tokens > 0: + ttft = (first_token_time - start_time) * 1000 # ms + total_time = end_time - start_time + + # Calculate TPOT (time per output token) excluding TTFT + if total_tokens > 1: + generation_time = end_time - first_token_time + tpot = (generation_time / (total_tokens - 1)) * 1000 # ms + else: + tpot = 0 + + return { + "ttft_ms": ttft, + "tpot_ms": tpot, + "total_tokens": total_tokens, + "total_time_s": total_time, + } + except Exception as e: + logger.warning(f"Benchmark request failed: {e}") + errors += 1 + return None + + async with httpx.AsyncClient() as client: + # Warm up with a few requests + logger.info("Warming up benchmark endpoint...") + for _ in range(min(2, num_requests)): + await make_request(client) + + # Run actual benchmark + logger.info(f"Running {num_requests} benchmark requests with concurrency {concurrency}...") + tasks = [make_request(client) for _ in range(num_requests)] + results = await asyncio.gather(*tasks) + + # Filter out failed requests + valid_results = [r for r in results if r is not None] + + if not valid_results: + return {"success": False, "error": "All requests failed", "errors": errors} + + # Calculate metrics + ttft_values = [r["ttft_ms"] for r in valid_results] + tpot_values = [r["tpot_ms"] for r in valid_results if r["tpot_ms"] > 0] + total_tokens = sum(r["total_tokens"] for r in valid_results) + total_time = sum(r["total_time_s"] for r in valid_results) + + avg_ttft = sum(ttft_values) / len(ttft_values) + avg_tpot = sum(tpot_values) / len(tpot_values) if tpot_values else 0 + throughput = total_tokens / total_time if total_time > 0 else 0 + + return { + "success": True, + "metrics": { + "throughput_tps": round(throughput, 2), + "avg_ttft_ms": round(avg_ttft, 2), + "avg_tpot_ms": round(avg_tpot, 2), + "p50_ttft_ms": round(sorted(ttft_values)[len(ttft_values) // 2], 2), + "p99_ttft_ms": ( + round(sorted(ttft_values)[int(len(ttft_values) * 0.99)], 2) + if len(ttft_values) > 1 + else round(ttft_values[0], 2) + ), + }, + "summary": { + "total_requests": num_requests, + "successful_requests": len(valid_results), + "failed_requests": errors, + "total_tokens_generated": total_tokens, + }, + } diff --git a/backend/app/services/tuning/executor.py b/backend/app/services/tuning/executor.py new file mode 100644 index 0000000..8170eee --- /dev/null +++ b/backend/app/services/tuning/executor.py @@ -0,0 +1,756 @@ +"""AgentToolExecutor - Executes agent tools with real system interactions. + +This module contains the AgentToolExecutor class that handles executing +tools called by the auto-tuning LLM agent. +""" + +import asyncio +import json +import logging +import time +from datetime import UTC, datetime + +import httpx +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.models.deployment import Deployment, DeploymentStatus +from app.models.llm_model import LLMModel +from app.models.tuning import PerformanceKnowledge, TuningJob, TuningJobStatus +from app.models.worker import Worker + +from .benchmark import run_http_benchmark +from .helpers import extract_model_family + +logger = logging.getLogger(__name__) + + +class AgentToolExecutor: + """Execute agent tools with real system interactions""" + + def __init__(self, db: AsyncSession, job: TuningJob): + self.db = db + self.job = job + self.created_deployments: list[int] = [] + self.benchmark_results: list[dict] = [] # Track completed benchmarks + self.hardware_checked: bool = False # Track if hardware was checked + + async def execute(self, tool_name: str, args: dict) -> str: + """Execute a tool and return result as string""" + try: + method = getattr(self, f"_tool_{tool_name}", None) + if method: + result = await method(**args) + return json.dumps(result, indent=2, default=str) + return json.dumps({"error": f"Unknown tool: {tool_name}"}) + except Exception as e: + logger.error(f"Tool {tool_name} failed: {e}") + return json.dumps({"error": str(e)}) + + async def _tool_get_hardware_info(self, worker_id: int) -> dict: + """Get hardware info for a worker""" + self.hardware_checked = True # Mark hardware as checked + result = await self.db.execute(select(Worker).where(Worker.id == worker_id)) + worker = result.scalar_one_or_none() + + if not worker: + return {"error": "Worker not found"} + + gpu_info = worker.gpu_info or [] + + # Determine the unit divisor from memory_total (which is always large) + # memory_total for a typical GPU should be 8-80 GB + def get_divisor(memory_total: int | float) -> float: + """Determine the divisor to convert memory values to GB. + + We use memory_total to figure out what unit the values are in: + - If memory_total > 1 billion: values are in bytes + - If memory_total > 1 million: values are in KB + - If memory_total > 1000: values are in MB + - Otherwise: values are already in GB + """ + if memory_total > 1_000_000_000: + return 1024 * 1024 * 1024 # bytes to GB + elif memory_total > 1_000_000: + return 1024 * 1024 # KB to GB + elif memory_total > 1000: + return 1024 # MB to GB + else: + return 1 # already GB + + def convert_gpu_memory(gpu: dict) -> dict: + """Convert a single GPU's memory values to GB.""" + mem_total = gpu.get("memory_total", 0) + mem_used = gpu.get("memory_used", 0) + + divisor = get_divisor(mem_total) + + return { + "memory_total_gb": round(mem_total / divisor, 1) if mem_total else 0, + "memory_used_gb": round(mem_used / divisor, 1) if mem_used else 0, + "memory_free_gb": round((mem_total - mem_used) / divisor, 1) if mem_total else 0, + } + + # Convert GPU memory values + gpus_converted = [] + total_vram_gb = 0 + for i, g in enumerate(gpu_info): + mem = convert_gpu_memory(g) + gpus_converted.append( + { + "index": g.get("index", i), + "name": g.get("name", "Unknown"), + "memory_total_gb": mem["memory_total_gb"], + "memory_used_gb": mem["memory_used_gb"], + "memory_free_gb": mem["memory_free_gb"], + "utilization_percent": g.get("utilization_gpu", 0), + } + ) + total_vram_gb += mem["memory_total_gb"] + + return { + "worker_id": worker.id, + "worker_name": worker.name, + "status": worker.status, + "gpu_count": len(gpu_info), + "gpus": gpus_converted, + "total_vram_gb": round(total_vram_gb, 1), + } + + async def _tool_get_model_info(self, model_id: int) -> dict: + """Get model info""" + result = await self.db.execute(select(LLMModel).where(LLMModel.id == model_id)) + model = result.scalar_one_or_none() + + if not model: + return {"error": "Model not found"} + + # Extract model family from name + model_family = extract_model_family(model.name) + + return { + "model_id": model.id, + "name": model.name, + "model_id_hf": model.model_id, + "source": model.source, + "model_family": model_family, + "default_backend": model.backend, + } + + async def _tool_query_knowledge_base( + self, model_family: str | None = None, gpu_model: str | None = None + ) -> dict: + """Query knowledge base for similar configurations""" + stmt = select(PerformanceKnowledge) + + if model_family: + stmt = stmt.where(PerformanceKnowledge.model_family.ilike(f"%{model_family}%")) + if gpu_model: + stmt = stmt.where(PerformanceKnowledge.gpu_model.ilike(f"%{gpu_model}%")) + + stmt = stmt.order_by(PerformanceKnowledge.score.desc().nulls_last()).limit(5) + + result = await self.db.execute(stmt) + records = result.scalars().all() + + if not records: + return { + "found": 0, + "message": "No historical data found. You'll need to run benchmarks to gather data.", + "records": [], + } + + return { + "found": len(records), + "records": [ + { + "model_name": r.model_name, + "model_family": r.model_family, + "gpu_model": r.gpu_model, + "gpu_count": r.gpu_count, + "engine": r.engine, + "quantization": r.quantization, + "tensor_parallel": r.tensor_parallel, + "throughput_tps": r.throughput_tps, + "ttft_ms": r.ttft_ms, + "tpot_ms": r.tpot_ms, + "score": r.score, + } + for r in records + ], + } + + async def _tool_list_deployments(self, worker_id: int) -> dict: + """List all deployments on a worker""" + try: + result = await self.db.execute( + select(Deployment) + .where(Deployment.worker_id == worker_id) + .options(selectinload(Deployment.model)) + ) + deployments = result.scalars().all() + + if not deployments: + return { + "worker_id": worker_id, + "count": 0, + "deployments": [], + "message": "No deployments found on this worker. GPU memory may be used by processes outside LMStack.", + } + + deployment_list = [] + for d in deployments: + deployment_list.append( + { + "deployment_id": d.id, + "name": d.name, + "model_name": d.model.name if d.model else "Unknown", + "status": d.status, + "backend": d.backend, + "port": d.port, + "container_id": d.container_id[:12] if d.container_id else None, + } + ) + + return { + "worker_id": worker_id, + "count": len(deployments), + "deployments": deployment_list, + "message": f"Found {len(deployments)} deployment(s). Stop running deployments to free GPU memory.", + } + except Exception as e: + logger.exception(f"Failed to list deployments: {e}") + return {"error": str(e)} + + async def _tool_deploy_model( + self, + model_id: int, + worker_id: int, + engine: str, + gpu_indexes: list[int] | None = None, + extra_params: dict | None = None, + ) -> dict: + """Deploy a model with specific configuration""" + from app.services.deployer import DeployerService + + try: + # Check if there are any pending deployments from this tuning job + if self.created_deployments: + return { + "success": False, + "error": f"You still have active deployments: {self.created_deployments}. " + f"Please stop them first using stop_deployment before creating a new one.", + } + + # Check GPU memory availability + worker_result = await self.db.execute(select(Worker).where(Worker.id == worker_id)) + worker = worker_result.scalar_one_or_none() + if worker and worker.gpu_info: + for g in worker.gpu_info: + mem_total = g.get("memory_total", 0) + mem_used = g.get("memory_used", 0) + # Check if less than 20% memory is free + if mem_total > 0 and (mem_total - mem_used) / mem_total < 0.2: + free_pct = round((mem_total - mem_used) / mem_total * 100, 1) + return { + "success": False, + "error": f"GPU memory is low (only {free_pct}% free). " + f"Please stop any existing deployments first.", + } + + # Get model to generate deployment name + model_result = await self.db.execute(select(LLMModel).where(LLMModel.id == model_id)) + model = model_result.scalar_one_or_none() + if not model: + return {"success": False, "error": "Model not found"} + + # Generate unique deployment name + deploy_name = f"tuning-{model.name.replace('/', '-')[:30]}-{int(time.time())}" + + # Create deployment + deployment = Deployment( + name=deploy_name, + model_id=model_id, + worker_id=worker_id, + backend=engine, + gpu_indexes=gpu_indexes or [0], + extra_params=extra_params or {}, + status=DeploymentStatus.PENDING.value, + ) + + self.db.add(deployment) + await self.db.commit() + await self.db.refresh(deployment) + + self.created_deployments.append(deployment.id) + + # Start deployment in background using DeployerService + deployer = DeployerService() + asyncio.create_task(deployer.deploy(deployment.id)) + + return { + "success": True, + "deployment_id": deployment.id, + "deployment_name": deploy_name, + "config": { + "engine": engine, + "gpu_indexes": gpu_indexes or [0], + "extra_params": extra_params, + }, + "message": "Deployment created. Use wait_for_deployment to wait until ready.", + } + except Exception as e: + logger.exception(f"Failed to deploy model: {e}") + return {"success": False, "error": str(e)} + + async def _tool_wait_for_deployment( + self, deployment_id: int, timeout_seconds: int = 300 + ) -> dict: + """Wait for deployment to be ready""" + start_time = time.time() + + while time.time() - start_time < timeout_seconds: + result = await self.db.execute( + select(Deployment) + .where(Deployment.id == deployment_id) + .options(selectinload(Deployment.worker)) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + return { + "success": False, + "deployment_id": deployment_id, + "error": "Deployment not found. It may have been deleted.", + } + + # Update job status to show deployment progress + elapsed = int(time.time() - start_time) + status_map = { + "pending": "Preparing deployment...", + "starting": "Loading model into GPU memory...", + "running": "Model ready!", + "error": "Deployment failed", + "stopped": "Deployment stopped", + } + status_desc = status_map.get(deployment.status, deployment.status) + self.job.status_message = f"Waiting for model: {status_desc} ({elapsed}s)" + self.job.progress = { + "step": self.job.progress.get("step", 0) if self.job.progress else 0, + "total_steps": ( + self.job.progress.get("total_steps", 15) if self.job.progress else 15 + ), + "step_name": "wait_for_deployment", + "step_description": f"Deployment #{deployment_id}: {status_desc}", + "deployment_status": deployment.status, + "deployment_message": deployment.status_message or "", + "elapsed_seconds": elapsed, + } + await self.db.commit() + + if deployment.status == DeploymentStatus.RUNNING.value: + return { + "success": True, + "deployment_id": deployment_id, + "status": "running", + "port": deployment.port, + "endpoint": f"http://{deployment.worker.address.split(':')[0] if deployment.worker else 'localhost'}:{deployment.port}/v1", + "wait_time_seconds": round(time.time() - start_time, 1), + } + elif deployment.status in [ + DeploymentStatus.ERROR.value, + DeploymentStatus.STOPPED.value, + ]: + return { + "success": False, + "deployment_id": deployment_id, + "status": deployment.status, + "error": deployment.status_message or "Deployment failed", + "action_required": "Call stop_deployment to clean up before trying again", + } + + await asyncio.sleep(5) + + return { + "success": False, + "deployment_id": deployment_id, + "error": f"Timeout after {timeout_seconds}s", + "action_required": ( + f"1. Call test_deployment_endpoint({deployment_id}) to check if it's actually ready\n" + f"2. If not ready, call stop_deployment({deployment_id}) and try next config\n" + f"DO NOT wait again - move on to the next configuration!" + ), + } + + async def _tool_run_benchmark( + self, + deployment_id: int, + num_requests: int = 20, + concurrency: int = 4, + input_tokens: int = 128, + output_tokens: int = 64, + ) -> dict: + """Run actual benchmark against deployment""" + result = await self.db.execute( + select(Deployment) + .where(Deployment.id == deployment_id) + .options( + selectinload(Deployment.worker), + selectinload(Deployment.model), + ) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + return {"error": "Deployment not found"} + + if deployment.status != DeploymentStatus.RUNNING.value: + return {"error": f"Deployment is not running (status: {deployment.status})"} + + # Build endpoint URL + worker = deployment.worker + worker_ip = worker.address.split(":")[0] + base_url = f"http://{worker_ip}:{deployment.port}/v1" + + # Get the model name for API calls + model_name = deployment.model.model_id if deployment.model else "default" + + # Run benchmark + metrics = await run_http_benchmark( + base_url=base_url, + model_name=model_name, + num_requests=num_requests, + concurrency=concurrency, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + + # Save successful benchmark results for tracking + if metrics.get("success"): + self.benchmark_results.append( + { + "deployment_id": deployment_id, + "engine": deployment.backend, + "gpu_indexes": deployment.gpu_indexes or [0], + "extra_params": deployment.extra_params or {}, + "metrics": metrics.get("metrics", {}), + } + ) + + return metrics + + async def _tool_stop_deployment(self, deployment_id: int) -> dict: + """Stop and remove a deployment""" + from app.services.deployer import DeployerService + + try: + # Get deployment + result = await self.db.execute(select(Deployment).where(Deployment.id == deployment_id)) + deployment = result.scalar_one_or_none() + + if not deployment: + return {"success": False, "error": "Deployment not found"} + + # Stop container if running + if deployment.container_id: + try: + deployer = DeployerService() + await deployer.stop(deployment_id) + except Exception as e: + logger.warning(f"Failed to stop container for deployment {deployment_id}: {e}") + + # Always update status to stopped first (in case delete fails) + deployment.status = DeploymentStatus.STOPPED.value + deployment.status_message = "Stopped by tuning agent" + await self.db.commit() + + # Then try to delete + try: + await self.db.delete(deployment) + await self.db.commit() + except Exception as e: + logger.warning(f"Failed to delete deployment {deployment_id}: {e}") + # Status is already stopped, so it's ok + + if deployment_id in self.created_deployments: + self.created_deployments.remove(deployment_id) + + return {"success": True, "message": f"Deployment {deployment_id} stopped and removed"} + except Exception as e: + logger.exception(f"Failed to stop deployment: {e}") + # Try to at least mark it stopped + try: + result = await self.db.execute( + select(Deployment).where(Deployment.id == deployment_id) + ) + deployment = result.scalar_one_or_none() + if deployment: + deployment.status = DeploymentStatus.STOPPED.value + await self.db.commit() + except Exception: + pass + return {"success": False, "error": str(e)} + + async def _tool_check_deployment_status(self, deployment_id: int) -> dict: + """Check the current status of a deployment without waiting""" + try: + result = await self.db.execute( + select(Deployment) + .where(Deployment.id == deployment_id) + .options(selectinload(Deployment.worker)) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + return {"error": "Deployment not found"} + + return { + "deployment_id": deployment_id, + "status": deployment.status, + "status_message": deployment.status_message, + "container_id": deployment.container_id, + "port": deployment.port, + "backend": deployment.backend, + "is_ready": deployment.status == DeploymentStatus.RUNNING.value, + "is_failed": deployment.status == DeploymentStatus.ERROR.value, + "is_loading": deployment.status == DeploymentStatus.STARTING.value, + } + except Exception as e: + logger.exception(f"Failed to check deployment status: {e}") + return {"error": str(e)} + + async def _tool_test_deployment_endpoint(self, deployment_id: int) -> dict: + """Test if the deployment API endpoint is responding""" + try: + result = await self.db.execute( + select(Deployment) + .where(Deployment.id == deployment_id) + .options(selectinload(Deployment.worker)) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + return {"error": "Deployment not found"} + + if not deployment.worker or not deployment.port: + return { + "deployment_id": deployment_id, + "ready": False, + "error": "Deployment not fully initialized (no worker or port)", + } + + # Build endpoint URL + worker = deployment.worker + worker_ip = worker.address.split(":")[0] + base_url = f"http://{worker_ip}:{deployment.port}/v1" + + # Test the /v1/models endpoint + async with httpx.AsyncClient(timeout=10.0) as client: + try: + response = await client.get(f"{base_url}/models") + if response.status_code == 200: + data = response.json() + models = data.get("data", []) + if models: + return { + "deployment_id": deployment_id, + "ready": True, + "endpoint": base_url, + "models": [m.get("id") for m in models], + "message": "Deployment is ready! You can now run benchmarks.", + } + else: + return { + "deployment_id": deployment_id, + "ready": False, + "endpoint": base_url, + "message": "API responding but no models loaded yet", + } + else: + return { + "deployment_id": deployment_id, + "ready": False, + "endpoint": base_url, + "status_code": response.status_code, + "message": f"API returned status {response.status_code}", + } + except httpx.ConnectError: + return { + "deployment_id": deployment_id, + "ready": False, + "endpoint": base_url, + "message": "Cannot connect to endpoint - container may still be starting", + } + except httpx.ReadTimeout: + return { + "deployment_id": deployment_id, + "ready": False, + "endpoint": base_url, + "message": "Connection timeout - model may still be loading", + } + except Exception as e: + logger.exception(f"Failed to test deployment endpoint: {e}") + return {"error": str(e)} + + async def _tool_get_deployment_logs(self, deployment_id: int, tail: int = 100) -> dict: + """Get Docker container logs for a deployment""" + from app.services.deployer import DeployerService + + try: + # Get deployment with worker + result = await self.db.execute( + select(Deployment) + .where(Deployment.id == deployment_id) + .options(selectinload(Deployment.worker)) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + return {"error": "Deployment not found"} + + if not deployment.container_id: + return { + "deployment_id": deployment_id, + "status": deployment.status, + "error": "No container ID - deployment may not have started yet", + "status_message": deployment.status_message, + } + + # Use DeployerService to get logs (handles both local and remote) + deployer = DeployerService() + logs = await deployer.get_logs(deployment, tail=tail) + + return { + "deployment_id": deployment_id, + "container_id": deployment.container_id, + "status": deployment.status, + "status_message": deployment.status_message, + "logs": logs, + } + except Exception as e: + logger.exception(f"Failed to get deployment logs: {e}") + return {"error": str(e)} + + async def _tool_finish_tuning( + self, best_config: dict, reasoning: str, all_results: list | None = None + ) -> dict: + """Mark tuning as complete and save to knowledge base""" + # Validate that proper steps were completed + if not self.hardware_checked: + return { + "success": False, + "error": "Cannot finish tuning: You must call get_hardware_info first to check the GPU environment.", + "required_action": "Call get_hardware_info(worker_id=...) before finishing.", + } + + if not self.benchmark_results and not all_results: + return { + "success": False, + "error": "Cannot finish tuning: No benchmark results found. You must run at least one benchmark.", + "required_action": "Deploy a model, run run_benchmark(), then call finish_tuning with the results.", + } + + # Use tracked benchmark results if all_results not provided + if not all_results: + all_results = self.benchmark_results + + # Update job status + self.job.status = TuningJobStatus.COMPLETED.value + self.job.status_message = "Auto-tuning completed successfully" + self.job.best_config = {**best_config, "reasoning": reasoning} + self.job.all_results = all_results or [] + self.job.completed_at = datetime.now(UTC) + + # Update progress to 100% + # Use the total_steps from current progress (set during agent loop) or default + current_total = self.job.progress.get("total_steps", 20) if self.job.progress else 20 + self.job.current_step = current_total + self.job.total_steps = current_total + self.job.progress = { + "step": current_total, + "total_steps": current_total, + "step_name": "completed", + "step_description": "Tuning completed successfully", + "configs_tested": len(all_results) if all_results else 1, + "configs_total": len(all_results) if all_results else 1, + } + + # Save results to knowledge base + saved_count = 0 + if all_results: + # Get model and worker info for knowledge base + model = self.job.model + worker = self.job.worker + gpu_info = worker.gpu_info[0] if worker.gpu_info else {} + gpu_name = gpu_info.get("name", "Unknown GPU") + + for result in all_results: + metrics = result.get("metrics", {}) + if not metrics: + continue + + # Create knowledge record + knowledge = PerformanceKnowledge( + gpu_model=gpu_name, + gpu_count=len(result.get("gpu_indexes", [0])), + total_vram_gb=sum( + ( + g.get("memory_total", 0) / (1024**3) + if g.get("memory_total", 0) > 1_000_000 + else g.get("memory_total", 0) + ) + for g in (worker.gpu_info or []) + ), + model_name=model.name, + model_family=extract_model_family(model.name), + engine=result.get("engine", best_config.get("engine", "vllm")), + quantization=result.get("extra_params", {}).get("quantization"), + tensor_parallel=len(result.get("gpu_indexes", [0])), + extra_args=result.get("extra_params"), + throughput_tps=metrics.get("throughput_tps", 0), + ttft_ms=metrics.get("avg_ttft_ms", 0), + tpot_ms=metrics.get("avg_tpot_ms", 0), + input_length=128, # Default test params + output_length=64, + concurrency=4, + score=metrics.get("throughput_tps", 0), # For throughput optimization + source_tuning_job_id=self.job.id, + ) + self.db.add(knowledge) + saved_count += 1 + + await self.db.commit() + + return { + "success": True, + "message": f"Tuning completed. Saved {saved_count} result(s) to knowledge base.", + "best_config": best_config, + "reasoning": reasoning, + } + + async def _tool_abort_tuning(self, reason: str) -> dict: + """Abort the tuning process""" + self.job.status = TuningJobStatus.FAILED.value + self.job.status_message = f"Aborted: {reason}" + self.job.completed_at = datetime.now(UTC) + + # Update progress to show aborted state + self.job.progress = { + "step": self.job.current_step, + "total_steps": self.job.total_steps, + "step_name": "aborted", + "step_description": reason, + } + + await self.db.commit() + + return {"success": True, "message": "Tuning aborted", "reason": reason} + + async def cleanup(self): + """Clean up any deployments created during tuning""" + for deployment_id in self.created_deployments: + try: + await self._tool_stop_deployment(deployment_id) + except Exception as e: + logger.warning(f"Failed to cleanup deployment {deployment_id}: {e}") diff --git a/backend/app/services/tuning/helpers.py b/backend/app/services/tuning/helpers.py new file mode 100644 index 0000000..9bc1ced --- /dev/null +++ b/backend/app/services/tuning/helpers.py @@ -0,0 +1,30 @@ +"""Utility functions for the tuning agent. + +This module contains helper functions used across the tuning package. +""" + + +def extract_model_family(model_name: str) -> str: + """Extract model family from name. + + Args: + model_name: The model name or ID (e.g., "Qwen/Qwen3-0.6B") + + Returns: + The model family name (e.g., "Qwen") or "Unknown" if not recognized + """ + name_lower = model_name.lower() + families = { + "qwen": "Qwen", + "llama": "Llama", + "mistral": "Mistral", + "deepseek": "DeepSeek", + "phi": "Phi", + "gemma": "Gemma", + "yi": "Yi", + "glm": "GLM", + } + for key, value in families.items(): + if key in name_lower: + return value + return "Unknown" diff --git a/backend/app/services/tuning/tools.py b/backend/app/services/tuning/tools.py new file mode 100644 index 0000000..14f4138 --- /dev/null +++ b/backend/app/services/tuning/tools.py @@ -0,0 +1,382 @@ +"""Tool definitions and system prompt for the tuning agent. + +This module contains the agent system prompt and tool definitions +used by the auto-tuning LLM agent. +""" + +# ============================================================================= +# Agent System Prompt +# ============================================================================= + +AGENT_SYSTEM_PROMPT = """You are an Auto-Tuning Agent helping to find the optimal deployment configuration for LLM models. + +IMPORTANT COMMUNICATION RULES: +1. ALWAYS explain what you're about to do BEFORE calling any tool +2. After each tool result, briefly summarize what you learned +3. Be conversational - talk like you're explaining to a colleague +4. No emojis, keep it professional but friendly + +=== OPTIMIZATION TARGETS === + +**Throughput** (tokens per second): +- Goal: Maximize TPS for batch processing / high volume +- Strategy: Use vLLM with large batch sizes, enable continuous batching +- Key metric: throughput_tps (higher is better) +- Trade-off: May have higher latency per request + +**Latency** (response time): +- Goal: Minimize time-to-first-token (TTFT) and time-per-output-token (TPOT) +- Strategy: Use smaller batch sizes, consider sglang for multi-turn +- Key metrics: avg_ttft_ms, avg_tpot_ms (lower is better) +- Trade-off: Lower overall throughput + +**Balanced**: +- Goal: Good balance between throughput and latency +- Strategy: Test multiple configs, calculate combined score +- Score formula: throughput_tps / (avg_ttft_ms * 0.01) - balance speed and responsiveness +- Pick config with best combined score + +**Cost** (minimum resources): +- Goal: Use minimum GPU memory while maintaining acceptable performance +- Strategy: Try quantization (awq, gptq), use fewer GPUs if possible +- Key consideration: memory_used_gb, still need decent throughput +- Trade-off: May sacrifice some performance for efficiency + +=== AVAILABLE ENGINES === +- vllm: Best throughput, tensor parallelism, supports fp8/awq/gptq quantization +- sglang: Good for multi-turn, efficient memory, fast prefix caching +- ollama: Simple deployment, good for smaller models, easy setup + +=== QUANTIZATION NOTES === +- AWQ/GPTQ: Requires a pre-quantized model (e.g., "Qwen/Qwen3-0.6B-AWQ") + Do NOT use quantization=awq with a base model like "Qwen/Qwen3-0.6B" +- FP8: Only works on Hopper+ GPUs (H100, etc.), not consumer GPUs +- For consumer GPUs (RTX 4090, etc.), use default FP16 or find a pre-quantized model + +=== PROCESS === +1. Check hardware (GPU model, VRAM, count) +2. Query knowledge base for similar setups +3. Based on optimization target, choose 2-3 promising configs to test +4. For EACH config: + a. Deploy model + b. Wait for deployment (use short timeout like 120s) + c. If timeout/slow: Check logs with get_deployment_logs to diagnose + d. If failed: STOP deployment, analyze error, try next config + e. If success: Run benchmark, record results, STOP deployment +5. Compare all results, call finish_tuning with recommendation + +=== DIAGNOSING DEPLOYMENT ISSUES === +When wait_for_deployment times out: +1. Call test_deployment_endpoint ONCE to check if API is ready + - If ready=true: Proceed to run_benchmark immediately + - If ready=false: Call get_deployment_logs ONCE +2. Based on logs, make a QUICK decision: + - If "Loading model" in logs: Wait ONE more time with wait_for_deployment(timeout_seconds=120) + - If any error: Call stop_deployment and try next config + - If unclear: Call stop_deployment and try next config + +STRICT RULES TO AVOID LOOPS: +- Maximum 2 calls to wait_for_deployment per config +- Maximum 2 calls to test_deployment_endpoint per config +- If deployment not ready after 2 waits, STOP it and move to next config +- Do NOT repeatedly check status - make a decision and move on! +- Small models (< 1B) should load in 60s, if not working after 2 attempts, skip it + +A 0.6B model should load in under 60 seconds. If it takes longer, something is wrong. + +=== HANDLING LOW GPU MEMORY === +If deploy_model fails with "GPU memory is low": +1. Call list_deployments(worker_id=X) to find existing deployments +2. Stop all running deployments using stop_deployment(deployment_id=X) +3. If no deployments found, GPU is used by external processes - inform user +4. After stopping, retry deploy_model + +IMPORTANT: ALWAYS stop a deployment before starting a new one! +- If deployment times out → check logs, then stop_deployment +- If deployment fails → stop_deployment immediately +- After benchmark complete → stop_deployment before next test +- Never have multiple test deployments running at once! + +=== EXAMPLE FLOW === +"Let me first check what hardware we have available..." +[call get_hardware_info] +"I can see we have 1x RTX 4090 with 24GB VRAM. Let me check if we have historical data..." +[call query_knowledge_base] +"No historical data found. Since we're optimizing for throughput, I'll test vLLM first..." +[call deploy_model] +"Deployment created with ID 1. Let me wait for it to become ready..." +[call wait_for_deployment(deployment_id=1, timeout_seconds=120)] +-- If timeout occurs -- +"Wait timed out after 120s. Let me first test if the endpoint is actually ready..." +[call test_deployment_endpoint(deployment_id=1)] +-- If ready=true -- +"The endpoint is responding! The model is ready. Let me run the benchmark now..." +[call run_benchmark(deployment_id=1)] +-- If ready=false -- +"Endpoint not ready yet. Let me check the container logs..." +[call get_deployment_logs(deployment_id=1, tail=100)] +"I see from the logs: 'Loading checkpoint shards: 3/4 (75%)' - model is still loading. +Let me test the endpoint again in a moment..." +[call test_deployment_endpoint(deployment_id=1)] +-- Keep testing until ready, then run benchmark -- +-- OR if logs show an error -- +"The logs show 'CUDA out of memory'. I need to stop and try a different config..." +[call stop_deployment(deployment_id=1)] +[call deploy_model with different params] + +ALWAYS provide context. Never call tools silently. +ALWAYS test endpoint and check logs before giving up on a deployment. +""" + + +# ============================================================================= +# Tools for the Agent +# ============================================================================= + + +def get_agent_tools() -> list[dict]: + """Define tools available to the agent. + + Returns: + List of tool definitions in OpenAI function calling format + """ + return [ + { + "type": "function", + "function": { + "name": "get_hardware_info", + "description": "Get detailed hardware information for a worker node including GPU model, VRAM, count, and current utilization.", + "parameters": { + "type": "object", + "properties": { + "worker_id": {"type": "integer", "description": "ID of the worker to query"} + }, + "required": ["worker_id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_model_info", + "description": "Get information about the model to be deployed.", + "parameters": { + "type": "object", + "properties": { + "model_id": {"type": "integer", "description": "ID of the model"} + }, + "required": ["model_id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "query_knowledge_base", + "description": "Query historical performance data for similar model/hardware combinations.", + "parameters": { + "type": "object", + "properties": { + "model_family": { + "type": "string", + "description": "Model family (e.g., Qwen, Llama, Mistral)", + }, + "gpu_model": { + "type": "string", + "description": "GPU model pattern (e.g., RTX 4090, A100)", + }, + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "list_deployments", + "description": "List all deployments on a worker. Use this to find existing deployments that may be using GPU memory.", + "parameters": { + "type": "object", + "properties": { + "worker_id": {"type": "integer", "description": "Worker ID to query"} + }, + "required": ["worker_id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "deploy_model", + "description": "Deploy a model with specific configuration. Returns deployment ID if successful.", + "parameters": { + "type": "object", + "properties": { + "model_id": {"type": "integer"}, + "worker_id": {"type": "integer"}, + "engine": {"type": "string", "enum": ["vllm", "sglang", "ollama"]}, + "gpu_indexes": { + "type": "array", + "items": {"type": "integer"}, + "description": "GPU indices to use", + }, + "extra_params": { + "type": "object", + "description": "Additional engine parameters", + }, + }, + "required": ["model_id", "worker_id", "engine"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "wait_for_deployment", + "description": "Wait for a deployment to be ready (running status).", + "parameters": { + "type": "object", + "properties": { + "deployment_id": {"type": "integer"}, + "timeout_seconds": { + "type": "integer", + "default": 300, + "description": "Maximum time to wait", + }, + }, + "required": ["deployment_id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "run_benchmark", + "description": "Run performance benchmark on a running deployment. Returns throughput, TTFT, TPOT metrics.", + "parameters": { + "type": "object", + "properties": { + "deployment_id": {"type": "integer"}, + "num_requests": { + "type": "integer", + "default": 20, + "description": "Number of requests to send", + }, + "concurrency": { + "type": "integer", + "default": 4, + "description": "Concurrent requests", + }, + "input_tokens": { + "type": "integer", + "default": 128, + "description": "Approximate input token count", + }, + "output_tokens": { + "type": "integer", + "default": 64, + "description": "Max output tokens", + }, + }, + "required": ["deployment_id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "stop_deployment", + "description": "Stop and remove a deployment.", + "parameters": { + "type": "object", + "properties": {"deployment_id": {"type": "integer"}}, + "required": ["deployment_id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_deployment_logs", + "description": "Get the Docker container logs for a deployment. Use this to check why a deployment is slow to start or failing.", + "parameters": { + "type": "object", + "properties": { + "deployment_id": {"type": "integer"}, + "tail": { + "type": "integer", + "default": 100, + "description": "Number of log lines to retrieve", + }, + }, + "required": ["deployment_id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "check_deployment_status", + "description": "Check the current status of a deployment without waiting. Use this after wait_for_deployment times out to see if the model is still loading or has failed.", + "parameters": { + "type": "object", + "properties": {"deployment_id": {"type": "integer"}}, + "required": ["deployment_id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "test_deployment_endpoint", + "description": "Test if the deployment API endpoint is responding. Use this to check if a model is ready even if wait_for_deployment timed out.", + "parameters": { + "type": "object", + "properties": {"deployment_id": {"type": "integer"}}, + "required": ["deployment_id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "finish_tuning", + "description": "Complete the tuning process with final recommendation.", + "parameters": { + "type": "object", + "properties": { + "best_config": { + "type": "object", + "description": "The recommended configuration", + }, + "reasoning": { + "type": "string", + "description": "Explanation of why this is the best config", + }, + "all_results": { + "type": "array", + "description": "All benchmark results collected", + }, + }, + "required": ["best_config", "reasoning"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "abort_tuning", + "description": "Abort the tuning process when it cannot be completed (e.g., GPU memory used by external processes, hardware issues).", + "parameters": { + "type": "object", + "properties": { + "reason": { + "type": "string", + "description": "Explanation of why tuning cannot be completed", + } + }, + "required": ["reason"], + }, + }, + }, + ] diff --git a/backend/app/services/tuning_agent.py b/backend/app/services/tuning_agent.py deleted file mode 100644 index 9d4fead..0000000 --- a/backend/app/services/tuning_agent.py +++ /dev/null @@ -1,1610 +0,0 @@ -""" -Auto-Tuning Agent Service - -A true LLM-driven agent that: -1. Uses an LLM to reason about configurations -2. Actually deploys models with different configs -3. Runs real benchmarks against deployed endpoints -4. Analyzes results and decides next steps -""" - -import asyncio -import json -import logging -import time -from datetime import UTC, datetime - -import httpx -from openai import AsyncOpenAI -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload - -from app.config import get_settings -from app.database import async_session_maker -from app.models.deployment import Deployment, DeploymentStatus -from app.models.llm_model import LLMModel -from app.models.tuning import PerformanceKnowledge, TuningJob, TuningJobStatus -from app.models.worker import Worker - -logger = logging.getLogger(__name__) - - -# ============================================================================= -# Agent System Prompt -# ============================================================================= - -AGENT_SYSTEM_PROMPT = """You are an Auto-Tuning Agent helping to find the optimal deployment configuration for LLM models. - -IMPORTANT COMMUNICATION RULES: -1. ALWAYS explain what you're about to do BEFORE calling any tool -2. After each tool result, briefly summarize what you learned -3. Be conversational - talk like you're explaining to a colleague -4. No emojis, keep it professional but friendly - -=== OPTIMIZATION TARGETS === - -**Throughput** (tokens per second): -- Goal: Maximize TPS for batch processing / high volume -- Strategy: Use vLLM with large batch sizes, enable continuous batching -- Key metric: throughput_tps (higher is better) -- Trade-off: May have higher latency per request - -**Latency** (response time): -- Goal: Minimize time-to-first-token (TTFT) and time-per-output-token (TPOT) -- Strategy: Use smaller batch sizes, consider sglang for multi-turn -- Key metrics: avg_ttft_ms, avg_tpot_ms (lower is better) -- Trade-off: Lower overall throughput - -**Balanced**: -- Goal: Good balance between throughput and latency -- Strategy: Test multiple configs, calculate combined score -- Score formula: throughput_tps / (avg_ttft_ms * 0.01) - balance speed and responsiveness -- Pick config with best combined score - -**Cost** (minimum resources): -- Goal: Use minimum GPU memory while maintaining acceptable performance -- Strategy: Try quantization (awq, gptq), use fewer GPUs if possible -- Key consideration: memory_used_gb, still need decent throughput -- Trade-off: May sacrifice some performance for efficiency - -=== AVAILABLE ENGINES === -- vllm: Best throughput, tensor parallelism, supports fp8/awq/gptq quantization -- sglang: Good for multi-turn, efficient memory, fast prefix caching -- ollama: Simple deployment, good for smaller models, easy setup - -=== QUANTIZATION NOTES === -- AWQ/GPTQ: Requires a pre-quantized model (e.g., "Qwen/Qwen3-0.6B-AWQ") - Do NOT use quantization=awq with a base model like "Qwen/Qwen3-0.6B" -- FP8: Only works on Hopper+ GPUs (H100, etc.), not consumer GPUs -- For consumer GPUs (RTX 4090, etc.), use default FP16 or find a pre-quantized model - -=== PROCESS === -1. Check hardware (GPU model, VRAM, count) -2. Query knowledge base for similar setups -3. Based on optimization target, choose 2-3 promising configs to test -4. For EACH config: - a. Deploy model - b. Wait for deployment (use short timeout like 120s) - c. If timeout/slow: Check logs with get_deployment_logs to diagnose - d. If failed: STOP deployment, analyze error, try next config - e. If success: Run benchmark, record results, STOP deployment -5. Compare all results, call finish_tuning with recommendation - -=== DIAGNOSING DEPLOYMENT ISSUES === -When wait_for_deployment times out: -1. Call test_deployment_endpoint ONCE to check if API is ready - - If ready=true: Proceed to run_benchmark immediately - - If ready=false: Call get_deployment_logs ONCE -2. Based on logs, make a QUICK decision: - - If "Loading model" in logs: Wait ONE more time with wait_for_deployment(timeout_seconds=120) - - If any error: Call stop_deployment and try next config - - If unclear: Call stop_deployment and try next config - -STRICT RULES TO AVOID LOOPS: -- Maximum 2 calls to wait_for_deployment per config -- Maximum 2 calls to test_deployment_endpoint per config -- If deployment not ready after 2 waits, STOP it and move to next config -- Do NOT repeatedly check status - make a decision and move on! -- Small models (< 1B) should load in 60s, if not working after 2 attempts, skip it - -A 0.6B model should load in under 60 seconds. If it takes longer, something is wrong. - -=== HANDLING LOW GPU MEMORY === -If deploy_model fails with "GPU memory is low": -1. Call list_deployments(worker_id=X) to find existing deployments -2. Stop all running deployments using stop_deployment(deployment_id=X) -3. If no deployments found, GPU is used by external processes - inform user -4. After stopping, retry deploy_model - -IMPORTANT: ALWAYS stop a deployment before starting a new one! -- If deployment times out → check logs, then stop_deployment -- If deployment fails → stop_deployment immediately -- After benchmark complete → stop_deployment before next test -- Never have multiple test deployments running at once! - -=== EXAMPLE FLOW === -"Let me first check what hardware we have available..." -[call get_hardware_info] -"I can see we have 1x RTX 4090 with 24GB VRAM. Let me check if we have historical data..." -[call query_knowledge_base] -"No historical data found. Since we're optimizing for throughput, I'll test vLLM first..." -[call deploy_model] -"Deployment created with ID 1. Let me wait for it to become ready..." -[call wait_for_deployment(deployment_id=1, timeout_seconds=120)] --- If timeout occurs -- -"Wait timed out after 120s. Let me first test if the endpoint is actually ready..." -[call test_deployment_endpoint(deployment_id=1)] --- If ready=true -- -"The endpoint is responding! The model is ready. Let me run the benchmark now..." -[call run_benchmark(deployment_id=1)] --- If ready=false -- -"Endpoint not ready yet. Let me check the container logs..." -[call get_deployment_logs(deployment_id=1, tail=100)] -"I see from the logs: 'Loading checkpoint shards: 3/4 (75%)' - model is still loading. -Let me test the endpoint again in a moment..." -[call test_deployment_endpoint(deployment_id=1)] --- Keep testing until ready, then run benchmark -- --- OR if logs show an error -- -"The logs show 'CUDA out of memory'. I need to stop and try a different config..." -[call stop_deployment(deployment_id=1)] -[call deploy_model with different params] - -ALWAYS provide context. Never call tools silently. -ALWAYS test endpoint and check logs before giving up on a deployment. -""" - - -# ============================================================================= -# Tools for the Agent -# ============================================================================= - - -def get_agent_tools() -> list[dict]: - """Define tools available to the agent""" - return [ - { - "type": "function", - "function": { - "name": "get_hardware_info", - "description": "Get detailed hardware information for a worker node including GPU model, VRAM, count, and current utilization.", - "parameters": { - "type": "object", - "properties": { - "worker_id": {"type": "integer", "description": "ID of the worker to query"} - }, - "required": ["worker_id"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "get_model_info", - "description": "Get information about the model to be deployed.", - "parameters": { - "type": "object", - "properties": { - "model_id": {"type": "integer", "description": "ID of the model"} - }, - "required": ["model_id"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "query_knowledge_base", - "description": "Query historical performance data for similar model/hardware combinations.", - "parameters": { - "type": "object", - "properties": { - "model_family": { - "type": "string", - "description": "Model family (e.g., Qwen, Llama, Mistral)", - }, - "gpu_model": { - "type": "string", - "description": "GPU model pattern (e.g., RTX 4090, A100)", - }, - }, - "required": [], - }, - }, - }, - { - "type": "function", - "function": { - "name": "list_deployments", - "description": "List all deployments on a worker. Use this to find existing deployments that may be using GPU memory.", - "parameters": { - "type": "object", - "properties": { - "worker_id": {"type": "integer", "description": "Worker ID to query"} - }, - "required": ["worker_id"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "deploy_model", - "description": "Deploy a model with specific configuration. Returns deployment ID if successful.", - "parameters": { - "type": "object", - "properties": { - "model_id": {"type": "integer"}, - "worker_id": {"type": "integer"}, - "engine": {"type": "string", "enum": ["vllm", "sglang", "ollama"]}, - "gpu_indexes": { - "type": "array", - "items": {"type": "integer"}, - "description": "GPU indices to use", - }, - "extra_params": { - "type": "object", - "description": "Additional engine parameters", - }, - }, - "required": ["model_id", "worker_id", "engine"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "wait_for_deployment", - "description": "Wait for a deployment to be ready (running status).", - "parameters": { - "type": "object", - "properties": { - "deployment_id": {"type": "integer"}, - "timeout_seconds": { - "type": "integer", - "default": 300, - "description": "Maximum time to wait", - }, - }, - "required": ["deployment_id"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "run_benchmark", - "description": "Run performance benchmark on a running deployment. Returns throughput, TTFT, TPOT metrics.", - "parameters": { - "type": "object", - "properties": { - "deployment_id": {"type": "integer"}, - "num_requests": { - "type": "integer", - "default": 20, - "description": "Number of requests to send", - }, - "concurrency": { - "type": "integer", - "default": 4, - "description": "Concurrent requests", - }, - "input_tokens": { - "type": "integer", - "default": 128, - "description": "Approximate input token count", - }, - "output_tokens": { - "type": "integer", - "default": 64, - "description": "Max output tokens", - }, - }, - "required": ["deployment_id"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "stop_deployment", - "description": "Stop and remove a deployment.", - "parameters": { - "type": "object", - "properties": {"deployment_id": {"type": "integer"}}, - "required": ["deployment_id"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "get_deployment_logs", - "description": "Get the Docker container logs for a deployment. Use this to check why a deployment is slow to start or failing.", - "parameters": { - "type": "object", - "properties": { - "deployment_id": {"type": "integer"}, - "tail": { - "type": "integer", - "default": 100, - "description": "Number of log lines to retrieve", - }, - }, - "required": ["deployment_id"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "check_deployment_status", - "description": "Check the current status of a deployment without waiting. Use this after wait_for_deployment times out to see if the model is still loading or has failed.", - "parameters": { - "type": "object", - "properties": {"deployment_id": {"type": "integer"}}, - "required": ["deployment_id"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "test_deployment_endpoint", - "description": "Test if the deployment API endpoint is responding. Use this to check if a model is ready even if wait_for_deployment timed out.", - "parameters": { - "type": "object", - "properties": {"deployment_id": {"type": "integer"}}, - "required": ["deployment_id"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "finish_tuning", - "description": "Complete the tuning process with final recommendation.", - "parameters": { - "type": "object", - "properties": { - "best_config": { - "type": "object", - "description": "The recommended configuration", - }, - "reasoning": { - "type": "string", - "description": "Explanation of why this is the best config", - }, - "all_results": { - "type": "array", - "description": "All benchmark results collected", - }, - }, - "required": ["best_config", "reasoning"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "abort_tuning", - "description": "Abort the tuning process when it cannot be completed (e.g., GPU memory used by external processes, hardware issues).", - "parameters": { - "type": "object", - "properties": { - "reason": { - "type": "string", - "description": "Explanation of why tuning cannot be completed", - } - }, - "required": ["reason"], - }, - }, - }, - ] - - -# ============================================================================= -# Tool Implementations -# ============================================================================= - - -class AgentToolExecutor: - """Execute agent tools with real system interactions""" - - def __init__(self, db: AsyncSession, job: TuningJob): - self.db = db - self.job = job - self.created_deployments: list[int] = [] - self.benchmark_results: list[dict] = [] # Track completed benchmarks - self.hardware_checked: bool = False # Track if hardware was checked - - async def execute(self, tool_name: str, args: dict) -> str: - """Execute a tool and return result as string""" - try: - method = getattr(self, f"_tool_{tool_name}", None) - if method: - result = await method(**args) - return json.dumps(result, indent=2, default=str) - return json.dumps({"error": f"Unknown tool: {tool_name}"}) - except Exception as e: - logger.error(f"Tool {tool_name} failed: {e}") - return json.dumps({"error": str(e)}) - - async def _tool_get_hardware_info(self, worker_id: int) -> dict: - """Get hardware info for a worker""" - self.hardware_checked = True # Mark hardware as checked - result = await self.db.execute(select(Worker).where(Worker.id == worker_id)) - worker = result.scalar_one_or_none() - - if not worker: - return {"error": "Worker not found"} - - gpu_info = worker.gpu_info or [] - - # Determine the unit divisor from memory_total (which is always large) - # memory_total for a typical GPU should be 8-80 GB - def get_divisor(memory_total: int | float) -> float: - """Determine the divisor to convert memory values to GB. - - We use memory_total to figure out what unit the values are in: - - If memory_total > 1 billion: values are in bytes - - If memory_total > 1 million: values are in KB - - If memory_total > 1000: values are in MB - - Otherwise: values are already in GB - """ - if memory_total > 1_000_000_000: - return 1024 * 1024 * 1024 # bytes to GB - elif memory_total > 1_000_000: - return 1024 * 1024 # KB to GB - elif memory_total > 1000: - return 1024 # MB to GB - else: - return 1 # already GB - - def convert_gpu_memory(gpu: dict) -> dict: - """Convert a single GPU's memory values to GB.""" - mem_total = gpu.get("memory_total", 0) - mem_used = gpu.get("memory_used", 0) - - divisor = get_divisor(mem_total) - - return { - "memory_total_gb": round(mem_total / divisor, 1) if mem_total else 0, - "memory_used_gb": round(mem_used / divisor, 1) if mem_used else 0, - "memory_free_gb": round((mem_total - mem_used) / divisor, 1) if mem_total else 0, - } - - # Convert GPU memory values - gpus_converted = [] - total_vram_gb = 0 - for i, g in enumerate(gpu_info): - mem = convert_gpu_memory(g) - gpus_converted.append( - { - "index": g.get("index", i), - "name": g.get("name", "Unknown"), - "memory_total_gb": mem["memory_total_gb"], - "memory_used_gb": mem["memory_used_gb"], - "memory_free_gb": mem["memory_free_gb"], - "utilization_percent": g.get("utilization_gpu", 0), - } - ) - total_vram_gb += mem["memory_total_gb"] - - return { - "worker_id": worker.id, - "worker_name": worker.name, - "status": worker.status, - "gpu_count": len(gpu_info), - "gpus": gpus_converted, - "total_vram_gb": round(total_vram_gb, 1), - } - - async def _tool_get_model_info(self, model_id: int) -> dict: - """Get model info""" - result = await self.db.execute(select(LLMModel).where(LLMModel.id == model_id)) - model = result.scalar_one_or_none() - - if not model: - return {"error": "Model not found"} - - # Extract model family from name - model_family = _extract_model_family(model.name) - - return { - "model_id": model.id, - "name": model.name, - "model_id_hf": model.model_id, - "source": model.source, - "model_family": model_family, - "default_backend": model.backend, - } - - async def _tool_query_knowledge_base( - self, model_family: str | None = None, gpu_model: str | None = None - ) -> dict: - """Query knowledge base for similar configurations""" - stmt = select(PerformanceKnowledge) - - if model_family: - stmt = stmt.where(PerformanceKnowledge.model_family.ilike(f"%{model_family}%")) - if gpu_model: - stmt = stmt.where(PerformanceKnowledge.gpu_model.ilike(f"%{gpu_model}%")) - - stmt = stmt.order_by(PerformanceKnowledge.score.desc().nulls_last()).limit(5) - - result = await self.db.execute(stmt) - records = result.scalars().all() - - if not records: - return { - "found": 0, - "message": "No historical data found. You'll need to run benchmarks to gather data.", - "records": [], - } - - return { - "found": len(records), - "records": [ - { - "model_name": r.model_name, - "model_family": r.model_family, - "gpu_model": r.gpu_model, - "gpu_count": r.gpu_count, - "engine": r.engine, - "quantization": r.quantization, - "tensor_parallel": r.tensor_parallel, - "throughput_tps": r.throughput_tps, - "ttft_ms": r.ttft_ms, - "tpot_ms": r.tpot_ms, - "score": r.score, - } - for r in records - ], - } - - async def _tool_list_deployments(self, worker_id: int) -> dict: - """List all deployments on a worker""" - try: - result = await self.db.execute( - select(Deployment) - .where(Deployment.worker_id == worker_id) - .options(selectinload(Deployment.model)) - ) - deployments = result.scalars().all() - - if not deployments: - return { - "worker_id": worker_id, - "count": 0, - "deployments": [], - "message": "No deployments found on this worker. GPU memory may be used by processes outside LMStack.", - } - - deployment_list = [] - for d in deployments: - deployment_list.append( - { - "deployment_id": d.id, - "name": d.name, - "model_name": d.model.name if d.model else "Unknown", - "status": d.status, - "backend": d.backend, - "port": d.port, - "container_id": d.container_id[:12] if d.container_id else None, - } - ) - - return { - "worker_id": worker_id, - "count": len(deployments), - "deployments": deployment_list, - "message": f"Found {len(deployments)} deployment(s). Stop running deployments to free GPU memory.", - } - except Exception as e: - logger.exception(f"Failed to list deployments: {e}") - return {"error": str(e)} - - async def _tool_deploy_model( - self, - model_id: int, - worker_id: int, - engine: str, - gpu_indexes: list[int] | None = None, - extra_params: dict | None = None, - ) -> dict: - """Deploy a model with specific configuration""" - from app.services.deployer import DeployerService - - try: - # Check if there are any pending deployments from this tuning job - if self.created_deployments: - return { - "success": False, - "error": f"You still have active deployments: {self.created_deployments}. " - f"Please stop them first using stop_deployment before creating a new one.", - } - - # Check GPU memory availability - worker_result = await self.db.execute(select(Worker).where(Worker.id == worker_id)) - worker = worker_result.scalar_one_or_none() - if worker and worker.gpu_info: - for g in worker.gpu_info: - mem_total = g.get("memory_total", 0) - mem_used = g.get("memory_used", 0) - # Check if less than 20% memory is free - if mem_total > 0 and (mem_total - mem_used) / mem_total < 0.2: - free_pct = round((mem_total - mem_used) / mem_total * 100, 1) - return { - "success": False, - "error": f"GPU memory is low (only {free_pct}% free). " - f"Please stop any existing deployments first.", - } - - # Get model to generate deployment name - model_result = await self.db.execute(select(LLMModel).where(LLMModel.id == model_id)) - model = model_result.scalar_one_or_none() - if not model: - return {"success": False, "error": "Model not found"} - - # Generate unique deployment name - import time - - deploy_name = f"tuning-{model.name.replace('/', '-')[:30]}-{int(time.time())}" - - # Create deployment - deployment = Deployment( - name=deploy_name, - model_id=model_id, - worker_id=worker_id, - backend=engine, - gpu_indexes=gpu_indexes or [0], - extra_params=extra_params or {}, - status=DeploymentStatus.PENDING.value, - ) - - self.db.add(deployment) - await self.db.commit() - await self.db.refresh(deployment) - - self.created_deployments.append(deployment.id) - - # Start deployment in background using DeployerService - deployer = DeployerService() - asyncio.create_task(deployer.deploy(deployment.id)) - - return { - "success": True, - "deployment_id": deployment.id, - "deployment_name": deploy_name, - "config": { - "engine": engine, - "gpu_indexes": gpu_indexes or [0], - "extra_params": extra_params, - }, - "message": "Deployment created. Use wait_for_deployment to wait until ready.", - } - except Exception as e: - logger.exception(f"Failed to deploy model: {e}") - return {"success": False, "error": str(e)} - - async def _tool_wait_for_deployment( - self, deployment_id: int, timeout_seconds: int = 300 - ) -> dict: - """Wait for deployment to be ready""" - start_time = time.time() - - while time.time() - start_time < timeout_seconds: - result = await self.db.execute( - select(Deployment) - .where(Deployment.id == deployment_id) - .options(selectinload(Deployment.worker)) - ) - deployment = result.scalar_one_or_none() - - if not deployment: - return { - "success": False, - "deployment_id": deployment_id, - "error": "Deployment not found. It may have been deleted.", - } - - # Update job status to show deployment progress - elapsed = int(time.time() - start_time) - status_map = { - "pending": "Preparing deployment...", - "starting": "Loading model into GPU memory...", - "running": "Model ready!", - "error": "Deployment failed", - "stopped": "Deployment stopped", - } - status_desc = status_map.get(deployment.status, deployment.status) - self.job.status_message = f"Waiting for model: {status_desc} ({elapsed}s)" - self.job.progress = { - "step": self.job.progress.get("step", 0) if self.job.progress else 0, - "total_steps": ( - self.job.progress.get("total_steps", 15) if self.job.progress else 15 - ), - "step_name": "wait_for_deployment", - "step_description": f"Deployment #{deployment_id}: {status_desc}", - "deployment_status": deployment.status, - "deployment_message": deployment.status_message or "", - "elapsed_seconds": elapsed, - } - await self.db.commit() - - if deployment.status == DeploymentStatus.RUNNING.value: - return { - "success": True, - "deployment_id": deployment_id, - "status": "running", - "port": deployment.port, - "endpoint": f"http://{deployment.worker.address.split(':')[0] if deployment.worker else 'localhost'}:{deployment.port}/v1", - "wait_time_seconds": round(time.time() - start_time, 1), - } - elif deployment.status in [ - DeploymentStatus.ERROR.value, - DeploymentStatus.STOPPED.value, - ]: - return { - "success": False, - "deployment_id": deployment_id, - "status": deployment.status, - "error": deployment.status_message or "Deployment failed", - "action_required": "Call stop_deployment to clean up before trying again", - } - - await asyncio.sleep(5) - - return { - "success": False, - "deployment_id": deployment_id, - "error": f"Timeout after {timeout_seconds}s", - "action_required": ( - f"1. Call test_deployment_endpoint({deployment_id}) to check if it's actually ready\n" - f"2. If not ready, call stop_deployment({deployment_id}) and try next config\n" - f"DO NOT wait again - move on to the next configuration!" - ), - } - - async def _tool_run_benchmark( - self, - deployment_id: int, - num_requests: int = 20, - concurrency: int = 4, - input_tokens: int = 128, - output_tokens: int = 64, - ) -> dict: - """Run actual benchmark against deployment""" - result = await self.db.execute( - select(Deployment) - .where(Deployment.id == deployment_id) - .options( - selectinload(Deployment.worker), - selectinload(Deployment.model), - ) - ) - deployment = result.scalar_one_or_none() - - if not deployment: - return {"error": "Deployment not found"} - - if deployment.status != DeploymentStatus.RUNNING.value: - return {"error": f"Deployment is not running (status: {deployment.status})"} - - # Build endpoint URL - worker = deployment.worker - worker_ip = worker.address.split(":")[0] - base_url = f"http://{worker_ip}:{deployment.port}/v1" - - # Get the model name for API calls - model_name = deployment.model.model_id if deployment.model else "default" - - # Run benchmark - metrics = await _run_http_benchmark( - base_url=base_url, - model_name=model_name, - num_requests=num_requests, - concurrency=concurrency, - input_tokens=input_tokens, - output_tokens=output_tokens, - ) - - # Save successful benchmark results for tracking - if metrics.get("success"): - self.benchmark_results.append( - { - "deployment_id": deployment_id, - "engine": deployment.backend, - "gpu_indexes": deployment.gpu_indexes or [0], - "extra_params": deployment.extra_params or {}, - "metrics": metrics.get("metrics", {}), - } - ) - - return metrics - - async def _tool_stop_deployment(self, deployment_id: int) -> dict: - """Stop and remove a deployment""" - from app.services.deployer import DeployerService - - try: - # Get deployment - result = await self.db.execute(select(Deployment).where(Deployment.id == deployment_id)) - deployment = result.scalar_one_or_none() - - if not deployment: - return {"success": False, "error": "Deployment not found"} - - # Stop container if running - if deployment.container_id: - try: - deployer = DeployerService() - await deployer.stop(deployment_id) - except Exception as e: - logger.warning(f"Failed to stop container for deployment {deployment_id}: {e}") - - # Always update status to stopped first (in case delete fails) - deployment.status = DeploymentStatus.STOPPED.value - deployment.status_message = "Stopped by tuning agent" - await self.db.commit() - - # Then try to delete - try: - await self.db.delete(deployment) - await self.db.commit() - except Exception as e: - logger.warning(f"Failed to delete deployment {deployment_id}: {e}") - # Status is already stopped, so it's ok - - if deployment_id in self.created_deployments: - self.created_deployments.remove(deployment_id) - - return {"success": True, "message": f"Deployment {deployment_id} stopped and removed"} - except Exception as e: - logger.exception(f"Failed to stop deployment: {e}") - # Try to at least mark it stopped - try: - result = await self.db.execute( - select(Deployment).where(Deployment.id == deployment_id) - ) - deployment = result.scalar_one_or_none() - if deployment: - deployment.status = DeploymentStatus.STOPPED.value - await self.db.commit() - except Exception: - pass - return {"success": False, "error": str(e)} - - async def _tool_check_deployment_status(self, deployment_id: int) -> dict: - """Check the current status of a deployment without waiting""" - try: - result = await self.db.execute( - select(Deployment) - .where(Deployment.id == deployment_id) - .options(selectinload(Deployment.worker)) - ) - deployment = result.scalar_one_or_none() - - if not deployment: - return {"error": "Deployment not found"} - - return { - "deployment_id": deployment_id, - "status": deployment.status, - "status_message": deployment.status_message, - "container_id": deployment.container_id, - "port": deployment.port, - "backend": deployment.backend, - "is_ready": deployment.status == DeploymentStatus.RUNNING.value, - "is_failed": deployment.status == DeploymentStatus.ERROR.value, - "is_loading": deployment.status == DeploymentStatus.STARTING.value, - } - except Exception as e: - logger.exception(f"Failed to check deployment status: {e}") - return {"error": str(e)} - - async def _tool_test_deployment_endpoint(self, deployment_id: int) -> dict: - """Test if the deployment API endpoint is responding""" - try: - result = await self.db.execute( - select(Deployment) - .where(Deployment.id == deployment_id) - .options(selectinload(Deployment.worker)) - ) - deployment = result.scalar_one_or_none() - - if not deployment: - return {"error": "Deployment not found"} - - if not deployment.worker or not deployment.port: - return { - "deployment_id": deployment_id, - "ready": False, - "error": "Deployment not fully initialized (no worker or port)", - } - - # Build endpoint URL - worker = deployment.worker - worker_ip = worker.address.split(":")[0] - base_url = f"http://{worker_ip}:{deployment.port}/v1" - - # Test the /v1/models endpoint - async with httpx.AsyncClient(timeout=10.0) as client: - try: - response = await client.get(f"{base_url}/models") - if response.status_code == 200: - data = response.json() - models = data.get("data", []) - if models: - return { - "deployment_id": deployment_id, - "ready": True, - "endpoint": base_url, - "models": [m.get("id") for m in models], - "message": "Deployment is ready! You can now run benchmarks.", - } - else: - return { - "deployment_id": deployment_id, - "ready": False, - "endpoint": base_url, - "message": "API responding but no models loaded yet", - } - else: - return { - "deployment_id": deployment_id, - "ready": False, - "endpoint": base_url, - "status_code": response.status_code, - "message": f"API returned status {response.status_code}", - } - except httpx.ConnectError: - return { - "deployment_id": deployment_id, - "ready": False, - "endpoint": base_url, - "message": "Cannot connect to endpoint - container may still be starting", - } - except httpx.ReadTimeout: - return { - "deployment_id": deployment_id, - "ready": False, - "endpoint": base_url, - "message": "Connection timeout - model may still be loading", - } - except Exception as e: - logger.exception(f"Failed to test deployment endpoint: {e}") - return {"error": str(e)} - - async def _tool_get_deployment_logs(self, deployment_id: int, tail: int = 100) -> dict: - """Get Docker container logs for a deployment""" - from app.services.deployer import DeployerService - - try: - # Get deployment with worker - result = await self.db.execute( - select(Deployment) - .where(Deployment.id == deployment_id) - .options(selectinload(Deployment.worker)) - ) - deployment = result.scalar_one_or_none() - - if not deployment: - return {"error": "Deployment not found"} - - if not deployment.container_id: - return { - "deployment_id": deployment_id, - "status": deployment.status, - "error": "No container ID - deployment may not have started yet", - "status_message": deployment.status_message, - } - - # Use DeployerService to get logs (handles both local and remote) - deployer = DeployerService() - logs = await deployer.get_logs(deployment, tail=tail) - - return { - "deployment_id": deployment_id, - "container_id": deployment.container_id, - "status": deployment.status, - "status_message": deployment.status_message, - "logs": logs, - } - except Exception as e: - logger.exception(f"Failed to get deployment logs: {e}") - return {"error": str(e)} - - async def _tool_finish_tuning( - self, best_config: dict, reasoning: str, all_results: list | None = None - ) -> dict: - """Mark tuning as complete and save to knowledge base""" - # Validate that proper steps were completed - if not self.hardware_checked: - return { - "success": False, - "error": "Cannot finish tuning: You must call get_hardware_info first to check the GPU environment.", - "required_action": "Call get_hardware_info(worker_id=...) before finishing.", - } - - if not self.benchmark_results and not all_results: - return { - "success": False, - "error": "Cannot finish tuning: No benchmark results found. You must run at least one benchmark.", - "required_action": "Deploy a model, run run_benchmark(), then call finish_tuning with the results.", - } - - # Use tracked benchmark results if all_results not provided - if not all_results: - all_results = self.benchmark_results - - # Update job status - self.job.status = TuningJobStatus.COMPLETED.value - self.job.status_message = "Auto-tuning completed successfully" - self.job.best_config = {**best_config, "reasoning": reasoning} - self.job.all_results = all_results or [] - self.job.completed_at = datetime.now(UTC) - - # Update progress to 100% - # Use the total_steps from current progress (set during agent loop) or default - current_total = self.job.progress.get("total_steps", 20) if self.job.progress else 20 - self.job.current_step = current_total - self.job.total_steps = current_total - self.job.progress = { - "step": current_total, - "total_steps": current_total, - "step_name": "completed", - "step_description": "Tuning completed successfully", - "configs_tested": len(all_results) if all_results else 1, - "configs_total": len(all_results) if all_results else 1, - } - - # Save results to knowledge base - saved_count = 0 - if all_results: - # Get model and worker info for knowledge base - model = self.job.model - worker = self.job.worker - gpu_info = worker.gpu_info[0] if worker.gpu_info else {} - gpu_name = gpu_info.get("name", "Unknown GPU") - - for result in all_results: - metrics = result.get("metrics", {}) - if not metrics: - continue - - # Create knowledge record - knowledge = PerformanceKnowledge( - gpu_model=gpu_name, - gpu_count=len(result.get("gpu_indexes", [0])), - total_vram_gb=sum( - ( - g.get("memory_total", 0) / (1024**3) - if g.get("memory_total", 0) > 1_000_000 - else g.get("memory_total", 0) - ) - for g in (worker.gpu_info or []) - ), - model_name=model.name, - model_family=_extract_model_family(model.name), - engine=result.get("engine", best_config.get("engine", "vllm")), - quantization=result.get("extra_params", {}).get("quantization"), - tensor_parallel=len(result.get("gpu_indexes", [0])), - extra_args=result.get("extra_params"), - throughput_tps=metrics.get("throughput_tps", 0), - ttft_ms=metrics.get("avg_ttft_ms", 0), - tpot_ms=metrics.get("avg_tpot_ms", 0), - input_length=128, # Default test params - output_length=64, - concurrency=4, - score=metrics.get("throughput_tps", 0), # For throughput optimization - source_tuning_job_id=self.job.id, - ) - self.db.add(knowledge) - saved_count += 1 - - await self.db.commit() - - return { - "success": True, - "message": f"Tuning completed. Saved {saved_count} result(s) to knowledge base.", - "best_config": best_config, - "reasoning": reasoning, - } - - async def _tool_abort_tuning(self, reason: str) -> dict: - """Abort the tuning process""" - self.job.status = TuningJobStatus.FAILED.value - self.job.status_message = f"Aborted: {reason}" - self.job.completed_at = datetime.now(UTC) - - # Update progress to show aborted state - self.job.progress = { - "step": self.job.current_step, - "total_steps": self.job.total_steps, - "step_name": "aborted", - "step_description": reason, - } - - await self.db.commit() - - return {"success": True, "message": "Tuning aborted", "reason": reason} - - async def cleanup(self): - """Clean up any deployments created during tuning""" - for deployment_id in self.created_deployments: - try: - await self._tool_stop_deployment(deployment_id) - except Exception as e: - logger.warning(f"Failed to cleanup deployment {deployment_id}: {e}") - - -# ============================================================================= -# Benchmark Implementation -# ============================================================================= - - -async def _run_http_benchmark( - base_url: str, - model_name: str = "default", - num_requests: int = 20, - concurrency: int = 4, - input_tokens: int = 128, - output_tokens: int = 64, -) -> dict: - """Run actual HTTP benchmark against an OpenAI-compatible endpoint""" - - # Generate test prompt with approximate token count - test_prompt = "Write a detailed explanation about " + " ".join( - ["artificial intelligence"] * (input_tokens // 3) - ) - - results = [] - errors = 0 - - semaphore = asyncio.Semaphore(concurrency) - - async def make_request(client: httpx.AsyncClient) -> dict | None: - nonlocal errors - async with semaphore: - start_time = time.perf_counter() - first_token_time = None - token_times = [] - total_tokens = 0 - - try: - async with client.stream( - "POST", - f"{base_url}/chat/completions", - json={ - "model": model_name, - "messages": [{"role": "user", "content": test_prompt}], - "max_tokens": output_tokens, - "stream": True, - }, - timeout=60.0, - ) as response: - if response.status_code != 200: - errors += 1 - return None - - async for line in response.aiter_lines(): - if line.startswith("data: "): - data = line[6:] - if data == "[DONE]": - break - try: - chunk = json.loads(data) - content = ( - chunk.get("choices", [{}])[0] - .get("delta", {}) - .get("content", "") - ) - if content: - current_time = time.perf_counter() - if first_token_time is None: - first_token_time = current_time - token_times.append(current_time) - total_tokens += 1 - except json.JSONDecodeError: - pass - - end_time = time.perf_counter() - - if first_token_time and total_tokens > 0: - ttft = (first_token_time - start_time) * 1000 # ms - total_time = end_time - start_time - - # Calculate TPOT (time per output token) excluding TTFT - if total_tokens > 1: - generation_time = end_time - first_token_time - tpot = (generation_time / (total_tokens - 1)) * 1000 # ms - else: - tpot = 0 - - return { - "ttft_ms": ttft, - "tpot_ms": tpot, - "total_tokens": total_tokens, - "total_time_s": total_time, - } - except Exception as e: - logger.warning(f"Benchmark request failed: {e}") - errors += 1 - return None - - async with httpx.AsyncClient() as client: - # Warm up with a few requests - logger.info("Warming up benchmark endpoint...") - for _ in range(min(2, num_requests)): - await make_request(client) - - # Run actual benchmark - logger.info(f"Running {num_requests} benchmark requests with concurrency {concurrency}...") - tasks = [make_request(client) for _ in range(num_requests)] - results = await asyncio.gather(*tasks) - - # Filter out failed requests - valid_results = [r for r in results if r is not None] - - if not valid_results: - return {"success": False, "error": "All requests failed", "errors": errors} - - # Calculate metrics - ttft_values = [r["ttft_ms"] for r in valid_results] - tpot_values = [r["tpot_ms"] for r in valid_results if r["tpot_ms"] > 0] - total_tokens = sum(r["total_tokens"] for r in valid_results) - total_time = sum(r["total_time_s"] for r in valid_results) - - avg_ttft = sum(ttft_values) / len(ttft_values) - avg_tpot = sum(tpot_values) / len(tpot_values) if tpot_values else 0 - throughput = total_tokens / total_time if total_time > 0 else 0 - - return { - "success": True, - "metrics": { - "throughput_tps": round(throughput, 2), - "avg_ttft_ms": round(avg_ttft, 2), - "avg_tpot_ms": round(avg_tpot, 2), - "p50_ttft_ms": round(sorted(ttft_values)[len(ttft_values) // 2], 2), - "p99_ttft_ms": ( - round(sorted(ttft_values)[int(len(ttft_values) * 0.99)], 2) - if len(ttft_values) > 1 - else round(ttft_values[0], 2) - ), - }, - "summary": { - "total_requests": num_requests, - "successful_requests": len(valid_results), - "failed_requests": errors, - "total_tokens_generated": total_tokens, - }, - } - - -def _extract_model_family(model_name: str) -> str: - """Extract model family from name""" - name_lower = model_name.lower() - families = { - "qwen": "Qwen", - "llama": "Llama", - "mistral": "Mistral", - "deepseek": "DeepSeek", - "phi": "Phi", - "gemma": "Gemma", - "yi": "Yi", - "glm": "GLM", - } - for key, value in families.items(): - if key in name_lower: - return value - return "Unknown" - - -# ============================================================================= -# Main Agent Runner -# ============================================================================= - - -async def run_tuning_agent(job_id: int, llm_config: dict | None = None): - """Run the Auto-Tuning Agent for a job - - Args: - job_id: The tuning job ID - llm_config: Optional LLM configuration from chat panel: - - deployment_id: Use a local deployment - - base_url: Custom endpoint URL - - api_key: API key for the endpoint - - model: Model name - """ - settings = get_settings() - - async with async_session_maker() as db: - # Load job with relationships - result = await db.execute( - select(TuningJob) - .where(TuningJob.id == job_id) - .options( - selectinload(TuningJob.model), - selectinload(TuningJob.worker), - ) - ) - job = result.scalar_one_or_none() - - if not job: - logger.error(f"Tuning job {job_id} not found") - return - - # Initialize tool executor - executor = AgentToolExecutor(db, job) - - try: - # Determine LLM configuration (priority: llm_config > settings > auto-detect) - api_key = None - base_url = None - model_name = "gpt-4o" - - if llm_config: - # Use config from chat panel - if llm_config.get("deployment_id"): - # Use specified local deployment - from app.models.deployment import Deployment, DeploymentStatus - - deploy_result = await db.execute( - select(Deployment) - .where(Deployment.id == llm_config["deployment_id"]) - .options(selectinload(Deployment.worker), selectinload(Deployment.model)) - ) - deployment = deploy_result.scalar_one_or_none() - - if deployment and deployment.worker: - worker_ip = deployment.worker.address.split(":")[0] - base_url = f"http://{worker_ip}:{deployment.port}/v1" - api_key = "dummy" - model_name = deployment.model.model_id if deployment.model else model_name - logger.info( - f"Using specified deployment as agent LLM: {base_url} ({model_name})" - ) - else: - job.status = TuningJobStatus.FAILED.value - job.status_message = ( - f"Deployment {llm_config['deployment_id']} not found or not running" - ) - await db.commit() - return - elif llm_config.get("base_url"): - # Use custom endpoint - base_url = llm_config["base_url"] - api_key = llm_config.get("api_key") or "dummy" - model_name = llm_config.get("model") or model_name - logger.info(f"Using custom endpoint as agent LLM: {base_url} ({model_name})") - - # Fall back to settings if no llm_config - if not api_key: - api_key = settings.openai_api_key - base_url = settings.openai_base_url - model_name = settings.openai_model or model_name - - # If still no API key, try to find any running deployment - if not api_key: - from app.models.deployment import Deployment, DeploymentStatus - - deploy_result = await db.execute( - select(Deployment) - .where(Deployment.status == DeploymentStatus.RUNNING.value) - .options(selectinload(Deployment.worker), selectinload(Deployment.model)) - .limit(1) - ) - local_deployment = deploy_result.scalar_one_or_none() - - if local_deployment and local_deployment.worker: - worker_ip = local_deployment.worker.address.split(":")[0] - base_url = f"http://{worker_ip}:{local_deployment.port}/v1" - api_key = "dummy" - model_name = ( - local_deployment.model.model_id if local_deployment.model else model_name - ) - logger.info( - f"Auto-detected local deployment as agent LLM: {base_url} ({model_name})" - ) - else: - job.status = TuningJobStatus.FAILED.value - job.status_message = ( - "No LLM configured for Auto-Tuning Agent. " - "Please select a model in the chat panel, or deploy a model first." - ) - await db.commit() - return - - # Initialize OpenAI client (supports OpenAI-compatible endpoints) - client = AsyncOpenAI(api_key=api_key, base_url=base_url or "https://api.openai.com/v1") - - # Build initial user message with explicit steps - user_message = f"""Find the optimal deployment configuration for {job.model.name} on {job.worker.name}. -Optimization target: {job.optimization_target} -Model ID: {job.model_id}, Worker ID: {job.worker_id} - -REQUIRED STEPS (you must complete all of these): -1. Call get_hardware_info(worker_id={job.worker_id}) to check GPU specs -2. Call query_knowledge_base() to check historical data -3. Deploy the model with deploy_model() and wait for it -4. Run run_benchmark() to test performance -5. Stop the deployment and optionally test other configurations -6. Call finish_tuning() with best_config and all benchmark results - -Start with Step 1: get_hardware_info""" - - messages = [ - {"role": "system", "content": AGENT_SYSTEM_PROMPT}, - {"role": "user", "content": user_message}, - ] - - # Initialize conversation log for UI display - conversation_log = [ - { - "role": "user", - "content": user_message, - "timestamp": datetime.now(UTC).isoformat(), - } - ] - - # Helper to save conversation log - async def save_log(): - job.conversation_log = conversation_log - await db.commit() - - # Update job status - job.status = TuningJobStatus.ANALYZING.value - job.status_message = "Agent is analyzing the environment..." - job.conversation_log = conversation_log - await db.commit() - - # Agent loop - limit iterations to prevent infinite loops - max_iterations = 15 - iteration = 0 - - while iteration < max_iterations: - iteration += 1 - - # Check if cancelled - await db.refresh(job) - if job.status == TuningJobStatus.CANCELLED.value: - logger.info(f"Job {job_id} was cancelled") - await executor.cleanup() - return - - # Call LLM - logger.info(f"Agent iteration {iteration}, calling LLM with model: {model_name}...") - - # Force tool calls if essential steps not completed - # Use "required" to ensure tool is called when needed - if not executor.hardware_checked or ( - not executor.benchmark_results and iteration < 10 - ): - tool_choice = "required" - else: - tool_choice = "auto" - - response = await client.chat.completions.create( - model=model_name, - messages=messages, - tools=get_agent_tools(), - tool_choice=tool_choice, - max_tokens=4096, - ) - - assistant_message = response.choices[0].message - messages.append(assistant_message.model_dump(exclude_none=True)) - - # Add assistant message to conversation log - log_entry = { - "role": "assistant", - "content": assistant_message.content or "", - "timestamp": datetime.now(UTC).isoformat(), - } - if assistant_message.tool_calls: - log_entry["tool_calls"] = [ - { - "id": tc.id, - "name": tc.function.name, - "arguments": tc.function.arguments, - } - for tc in assistant_message.tool_calls - ] - conversation_log.append(log_entry) - await save_log() - - # Check if no tool calls - prompt the agent to take action - if not assistant_message.tool_calls: - logger.warning(f"Agent responded without tool calls at iteration {iteration}") - # Build a context-aware prompt based on current state - if not executor.hardware_checked: - prompt_message = ( - f"You must call get_hardware_info(worker_id={job.worker_id}) first " - "to check the GPU environment before proceeding." - ) - elif not executor.benchmark_results: - prompt_message = ( - "You must run at least one benchmark before finishing. " - f"Call deploy_model(model_id={job.model_id}, worker_id={job.worker_id}, engine='vllm') " - "to deploy the model, then run run_benchmark() after it's ready." - ) - else: - prompt_message = ( - "You have benchmark results. Call finish_tuning() with the best configuration " - "to complete the tuning process." - ) - messages.append({"role": "user", "content": prompt_message}) - conversation_log.append( - { - "role": "user", - "content": prompt_message, - "timestamp": datetime.now(UTC).isoformat(), - } - ) - await save_log() - continue # Continue the loop to get tool calls - - # Execute tool calls - for tool_call in assistant_message.tool_calls: - tool_name = tool_call.function.name - tool_args = json.loads(tool_call.function.arguments) - - logger.info(f"Executing tool: {tool_name}({tool_args})") - - # Update job progress - job.status_message = f"Executing: {tool_name}" - job.progress = { - "step": iteration, - "total_steps": max_iterations, - "step_name": tool_name, - "step_description": f"Executing {tool_name} with args: {tool_args}", - "configs_tested": 0, - "configs_total": 0, - } - await db.commit() - - # Execute tool - result = await executor.execute(tool_name, tool_args) - - # Add tool result to conversation log - conversation_log.append( - { - "role": "tool", - "tool_call_id": tool_call.id, - "name": tool_name, - "content": result, - "timestamp": datetime.now(UTC).isoformat(), - } - ) - await save_log() - - # Check if this was a termination tool - if tool_name == "finish_tuning": - logger.info(f"Agent completed tuning for job {job_id}") - return - if tool_name == "abort_tuning": - logger.info(f"Agent aborted tuning for job {job_id}") - return - - # Add tool result to messages - messages.append( - {"role": "tool", "tool_call_id": tool_call.id, "content": result} - ) - - # If we reached max iterations without finishing - job.status = TuningJobStatus.FAILED.value - job.status_message = "Agent reached maximum iterations without completing" - await db.commit() - - except Exception as e: - logger.exception(f"Agent error for job {job_id}: {e}") - job.status = TuningJobStatus.FAILED.value - job.status_message = f"Agent error: {str(e)}" - await db.commit() - - finally: - # Cleanup any test deployments - await executor.cleanup() diff --git a/backend/migrations/011_add_worker_mac_support.py b/backend/migrations/011_add_worker_mac_support.py new file mode 100644 index 0000000..3e66bc8 --- /dev/null +++ b/backend/migrations/011_add_worker_mac_support.py @@ -0,0 +1,85 @@ +""" +Migration: Add Mac native deployment support to workers table + +Adds os_type, gpu_type, and capabilities columns to track worker environment +and available backends (Docker, Ollama, MLX, llama.cpp). + +Run with: python -m migrations.011_add_worker_mac_support +""" + +import asyncio +import sys +from pathlib import Path + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine + +from app.config import get_settings + + +async def column_exists(conn, table_name: str, column_name: str) -> bool: + """Check if a column exists in a table (SQLite compatible)""" + result = await conn.execute(text(f"PRAGMA table_info({table_name})")) + columns = [row[1] for row in result.fetchall()] + return column_name in columns + + +async def migrate(): + settings = get_settings() + engine = create_async_engine(settings.database_url, echo=True) + + async with engine.begin() as conn: + # Add os_type column + if not await column_exists(conn, "workers", "os_type"): + print("Adding 'os_type' column to workers table...") + await conn.execute( + text( + """ + ALTER TABLE workers ADD COLUMN os_type VARCHAR(50) DEFAULT 'linux' + """ + ) + ) + print("'os_type' column added!") + else: + print("'os_type' column already exists") + + # Add gpu_type column + if not await column_exists(conn, "workers", "gpu_type"): + print("Adding 'gpu_type' column to workers table...") + await conn.execute( + text( + """ + ALTER TABLE workers ADD COLUMN gpu_type VARCHAR(50) DEFAULT 'nvidia' + """ + ) + ) + print("'gpu_type' column added!") + else: + print("'gpu_type' column already exists") + + # Add capabilities column (JSON) + if not await column_exists(conn, "workers", "capabilities"): + print("Adding 'capabilities' column to workers table...") + await conn.execute( + text( + """ + ALTER TABLE workers ADD COLUMN capabilities JSON + """ + ) + ) + print("'capabilities' column added!") + else: + print("'capabilities' column already exists") + + print("\n" + "=" * 50) + print("Migration completed successfully!") + print("=" * 50) + + await engine.dispose() + + +if __name__ == "__main__": + asyncio.run(migrate()) diff --git a/frontend/src/components/DeploymentAdvancedForm.tsx b/frontend/src/components/DeploymentAdvancedForm.tsx index 62871e8..b141d1d 100644 --- a/frontend/src/components/DeploymentAdvancedForm.tsx +++ b/frontend/src/components/DeploymentAdvancedForm.tsx @@ -94,6 +94,14 @@ const VllmAdvancedParams = () => ( > + + + ; + return ( + + ); } return ( @@ -128,11 +132,20 @@ export function AgentChatView({ function EmptyState({ colors, isDark: _isDark, + onSendMessage, }: { colors: ThemeColors; isDark: boolean; + onSendMessage?: (message: string) => void; }) { const isDark = _isDark; + + const quickActions = [ + { icon: , text: "列出所有 Worker 狀態" }, + { icon: , text: "GPU 記憶體使用狀況" }, + { icon: , text: "列出所有容器" }, + ]; + return (
- +
- LMStack AI Agent + Agent
- Powered by MCP. Deploy models, run benchmarks, and optimize - configurations through natural language. + 部署模型、執行基準測試、查詢系統狀態
- {/* Capability cards */} + {/* Quick actions */}
- {[ - { icon: , text: "Deploy LLM models to GPU workers" }, - { icon: , text: "Run performance benchmarks" }, - { - icon: , - text: "Query knowledge base for optimal configs", - }, - ].map((item, idx) => ( + {quickActions.map((item, idx) => (
onSendMessage?.(item.text)} style={{ display: "flex", alignItems: "center", - gap: 12, - padding: "12px 16px", - borderRadius: 10, + gap: 10, + padding: "10px 14px", + borderRadius: 8, background: isDark ? "rgba(255,255,255,0.03)" : "rgba(0,0,0,0.02)", border: `1px solid ${isDark ? "rgba(255,255,255,0.06)" : "rgba(0,0,0,0.06)"}`, + cursor: "pointer", + transition: "background 0.15s", + }} + onMouseEnter={(e) => { + e.currentTarget.style.background = isDark + ? "rgba(255,255,255,0.06)" + : "rgba(0,0,0,0.04)"; + }} + onMouseLeave={(e) => { + e.currentTarget.style.background = isDark + ? "rgba(255,255,255,0.03)" + : "rgba(0,0,0,0.02)"; }} > -
{item.icon}
+
+ {item.icon} +
{item.text}
))}
- -
- Try: "Deploy Qwen-7B on Worker 1" or "What's the GPU memory status?" -
); } @@ -309,17 +319,19 @@ function MessageBlock({ {/* Avatar */}
- +
{/* Content */} @@ -342,19 +354,22 @@ function MessageBlock({ style={{ display: "flex", alignItems: "center", - gap: 10, - padding: "12px 14px", - borderRadius: 12, + gap: 8, + padding: "10px 12px", + borderRadius: 8, background: isDark - ? "rgba(59, 130, 246, 0.1)" - : "rgba(59, 130, 246, 0.06)", - border: `1px solid ${isDark ? "rgba(59, 130, 246, 0.2)" : "rgba(59, 130, 246, 0.15)"}`, + ? "rgba(255,255,255,0.03)" + : "rgba(0,0,0,0.02)", + border: `1px solid ${isDark ? "rgba(255,255,255,0.06)" : "rgba(0,0,0,0.06)"}`, marginBottom: 12, }} > - - - Analyzing your request... + + + 處理中... )} @@ -436,7 +451,7 @@ function PageReferenceCard({ colors, }: PageReferenceCardProps) { const getIcon = () => { - const iconStyle = { fontSize: 16, color: "#3b82f6" }; + const iconStyle = { fontSize: 14, color: colors.textMuted }; switch (reference.icon) { case "cluster": return ; @@ -462,66 +477,44 @@ function PageReferenceCard({ display: "flex", alignItems: "center", gap: 10, - padding: "10px 14px", - borderRadius: 10, - background: isDark - ? "rgba(59, 130, 246, 0.1)" - : "rgba(59, 130, 246, 0.06)", + padding: "8px 12px", + borderRadius: 6, + background: isDark ? "rgba(255,255,255,0.03)" : "rgba(0,0,0,0.02)", border: `1px solid ${ - isDark ? "rgba(59, 130, 246, 0.25)" : "rgba(59, 130, 246, 0.2)" + isDark ? "rgba(255,255,255,0.08)" : "rgba(0,0,0,0.08)" }`, cursor: "pointer", - transition: "all 0.2s ease", - minWidth: 180, + transition: "background 0.15s", }} onMouseEnter={(e) => { e.currentTarget.style.background = isDark - ? "rgba(59, 130, 246, 0.15)" - : "rgba(59, 130, 246, 0.1)"; - e.currentTarget.style.borderColor = isDark - ? "rgba(59, 130, 246, 0.4)" - : "rgba(59, 130, 246, 0.35)"; + ? "rgba(255,255,255,0.06)" + : "rgba(0,0,0,0.04)"; }} onMouseLeave={(e) => { e.currentTarget.style.background = isDark - ? "rgba(59, 130, 246, 0.1)" - : "rgba(59, 130, 246, 0.06)"; - e.currentTarget.style.borderColor = isDark - ? "rgba(59, 130, 246, 0.25)" - : "rgba(59, 130, 246, 0.2)"; + ? "rgba(255,255,255,0.03)" + : "rgba(0,0,0,0.02)"; }} > -
- {getIcon()} -
+ {getIcon()}
{reference.title}
-
- {reference.description} -
+ {reference.description && ( +
+ {reference.description} +
+ )}
- + ); } @@ -543,7 +536,7 @@ function ActionSuggestionButton({ colors, }: ActionSuggestionButtonProps) { const getIcon = () => { - const iconStyle = { fontSize: 14 }; + const iconStyle = { fontSize: 12 }; switch (suggestion.icon) { case "rocket": return ; @@ -558,60 +551,11 @@ function ActionSuggestionButton({ case "tool": return ; default: - return ; - } - }; - - const getButtonStyle = () => { - switch (suggestion.type) { - case "primary": - return { - background: isDark - ? "rgba(59, 130, 246, 0.15)" - : "rgba(59, 130, 246, 0.1)", - border: `1px solid ${ - isDark ? "rgba(59, 130, 246, 0.4)" : "rgba(59, 130, 246, 0.3)" - }`, - color: "#3b82f6", - hoverBg: isDark - ? "rgba(59, 130, 246, 0.25)" - : "rgba(59, 130, 246, 0.15)", - hoverBorder: isDark - ? "rgba(59, 130, 246, 0.6)" - : "rgba(59, 130, 246, 0.5)", - }; - case "danger": - return { - background: isDark - ? "rgba(239, 68, 68, 0.1)" - : "rgba(239, 68, 68, 0.06)", - border: `1px solid ${ - isDark ? "rgba(239, 68, 68, 0.3)" : "rgba(239, 68, 68, 0.2)" - }`, - color: "#ef4444", - hoverBg: isDark ? "rgba(239, 68, 68, 0.2)" : "rgba(239, 68, 68, 0.1)", - hoverBorder: isDark - ? "rgba(239, 68, 68, 0.5)" - : "rgba(239, 68, 68, 0.4)", - }; - default: - return { - background: isDark - ? "rgba(255, 255, 255, 0.05)" - : "rgba(0, 0, 0, 0.04)", - border: `1px solid ${ - isDark ? "rgba(255, 255, 255, 0.1)" : "rgba(0, 0, 0, 0.1)" - }`, - color: colors.textSecondary, - hoverBg: isDark ? "rgba(255, 255, 255, 0.1)" : "rgba(0, 0, 0, 0.08)", - hoverBorder: isDark - ? "rgba(255, 255, 255, 0.2)" - : "rgba(0, 0, 0, 0.2)", - }; + return ; } }; - const style = getButtonStyle(); + const isDanger = suggestion.type === "danger"; return (
{ - e.currentTarget.style.background = style.hoverBg; - e.currentTarget.style.borderColor = style.hoverBorder; + e.currentTarget.style.background = isDanger + ? isDark + ? "rgba(239, 68, 68, 0.15)" + : "rgba(239, 68, 68, 0.1)" + : isDark + ? "rgba(255, 255, 255, 0.08)" + : "rgba(0, 0, 0, 0.06)"; }} onMouseLeave={(e) => { - e.currentTarget.style.background = style.background; - e.currentTarget.style.borderColor = style.border.replace( - "1px solid ", - "", - ); + e.currentTarget.style.background = isDanger + ? isDark + ? "rgba(239, 68, 68, 0.1)" + : "rgba(239, 68, 68, 0.06)" + : isDark + ? "rgba(255, 255, 255, 0.04)" + : "rgba(0, 0, 0, 0.03)"; }} > {getIcon()} @@ -751,12 +715,10 @@ function StepItem({ step, onToggle, isDark, colors }: StepItemProps) { return (
{/* Header */} @@ -768,27 +730,27 @@ function StepItem({ step, onToggle, isDark, colors }: StepItemProps) { gap: 8, padding: "10px 12px", cursor: "pointer", - background: isDark - ? "rgba(139,92,246,0.08)" - : "rgba(139,92,246,0.05)", + background: isDark ? "rgba(255,255,255,0.02)" : "rgba(0,0,0,0.02)", borderBottom: step.expanded - ? `1px solid ${isDark ? "rgba(139,92,246,0.2)" : "rgba(139,92,246,0.15)"}` + ? `1px solid ${isDark ? "rgba(255,255,255,0.06)" : "rgba(0,0,0,0.06)"}` : "none", }} > - + {step.title} {step.status === "running" && ( - + )} void; @@ -90,32 +60,6 @@ function savePanelState(state: Partial) { } } -/** - * Load chat mode from localStorage - */ -function loadChatMode(): ChatMode { - try { - const saved = localStorage.getItem(CHAT_MODE_STORAGE_KEY); - if (saved === "chat" || saved === "agent") { - return saved; - } - } catch { - // Ignore parse errors - } - return "agent"; // Default to agent mode -} - -/** - * Save chat mode to localStorage - */ -function saveChatMode(mode: ChatMode) { - try { - localStorage.setItem(CHAT_MODE_STORAGE_KEY, mode); - } catch { - // Ignore save errors - } -} - /** * Global chat panel component */ @@ -140,26 +84,8 @@ export function ChatPanel({ () => loadPanelState().customEndpoints || [], ); - // Chat mode state - const [chatMode, setChatMode] = useState(() => loadChatMode()); - - // Traditional chat state + // Input state const [inputValue, setInputValue] = useState(""); - const { - messages: chatMessages, - isStreaming: chatIsStreaming, - isExecutingTool, - currentToolName, - pendingTools, - showConfirmModal, - systemContext, - refreshContext, - sendMessage: chatSendMessage, - stopStreaming: chatStopStreaming, - clearMessages: chatClearMessages, - confirmToolExecution, - cancelToolExecution, - } = useChat(); // Agent chat state const { @@ -172,6 +98,7 @@ export function ChatPanel({ stopStreaming: agentStopStreaming, clearMessages: agentClearMessages, toggleStepExpanded, + startNewConversation, } = useAgentChat(); // Chat panel context - for external access @@ -185,49 +112,28 @@ export function ChatPanel({ }; }, [_registerSendFunction, agentSendMessage, selectedModel]); - // Derived state based on mode - const isStreaming = chatMode === "agent" ? agentIsStreaming : chatIsStreaming; - const hasMessages = - chatMode === "agent" ? agentMessages.length > 0 : chatMessages.length > 0; - - // Handle mode change - const handleModeChange = useCallback((value: string | number) => { - const newMode = value as ChatMode; - setChatMode(newMode); - saveChatMode(newMode); - }, []); + // Derived state + const isStreaming = agentIsStreaming; + const hasMessages = agentMessages.length > 0; - // Handle send message based on mode + // Handle send message const handleSendMessage = useCallback(() => { if (!inputValue.trim() || !selectedModel) return; - - if (chatMode === "agent") { - agentSendMessage(inputValue, selectedModel); - } else { - chatSendMessage(inputValue, selectedModel); - } + agentSendMessage(inputValue, selectedModel); setInputValue(""); - }, [inputValue, selectedModel, chatMode, agentSendMessage, chatSendMessage]); + }, [inputValue, selectedModel, agentSendMessage]); - // Handle stop streaming based on mode + // Handle stop streaming const handleStopStreaming = useCallback(() => { - if (chatMode === "agent") { - agentStopStreaming(); - } else { - chatStopStreaming(); - } - }, [chatMode, agentStopStreaming, chatStopStreaming]); + agentStopStreaming(); + }, [agentStopStreaming]); - // Handle clear messages based on mode + // Handle clear messages const handleClearMessages = useCallback(() => { - if (chatMode === "agent") { - agentClearMessages(); - } else { - chatClearMessages(); - } - }, [chatMode, agentClearMessages, chatClearMessages]); + agentClearMessages(); + }, [agentClearMessages]); - // Handle sending message from action suggestions (agent mode) + // Handle sending message from action suggestions const handleAgentSendSuggestion = useCallback( (message: string) => { if (!selectedModel || agentIsStreaming) return; @@ -236,9 +142,25 @@ export function ChatPanel({ [selectedModel, agentIsStreaming, agentSendMessage], ); + // Handle model change - start new conversation when model changes + const handleModelChange = useCallback( + (model: ChatModelConfig | null) => { + // If model type or identity changes, start a new conversation + const modelChanged = + selectedModel?.type !== model?.type || + selectedModel?.deploymentId !== model?.deploymentId || + selectedModel?.endpoint !== model?.endpoint; + + if (modelChanged && agentMessages.length > 0) { + startNewConversation(); + } + setSelectedModel(model); + }, + [selectedModel, agentMessages.length, startNewConversation], + ); + // Refs const messagesContainerRef = useRef(null); - const messagesEndRef = useRef(null); const resizeHandleRef = useRef(null); const [isResizing, setIsResizing] = useState(false); const [showScrollButton, setShowScrollButton] = useState(false); @@ -336,18 +258,10 @@ export function ChatPanel({ // Scroll on new messages useEffect(() => { - const messageCount = - chatMode === "agent" ? agentMessages.length : chatMessages.length; - if (!isStreaming && messageCount > 0) { + if (!isStreaming && agentMessages.length > 0) { scrollToBottom(); } - }, [ - chatMode, - agentMessages.length, - chatMessages.length, - isStreaming, - scrollToBottom, - ]); + }, [agentMessages.length, isStreaming, scrollToBottom]); if (!isOpen) return null; @@ -426,7 +340,7 @@ export function ChatPanel({ {/* Left side: Model selector */} - {/* Right side: Mode toggle and actions */} + {/* Right side: Actions */}
- {/* Mode toggle */} - , - label: "Agent", - }, - { - value: "chat", - icon: , - label: "Chat", - }, - ]} - style={{ - background: isDark ? "#27272a" : "#f4f4f5", - }} - /> - - {/* System status indicator (only in chat mode) */} - {chatMode === "chat" && systemContext && ( - -
Workers: {systemContext.workers.length}
-
- Deployments:{" "} - { - systemContext.deployments.filter( - (d) => d.status === "running", - ).length - } -
-
Models: {systemContext.models.length}
-
- Click to refresh -
-
- } - > - - - )} - {/* Clear button */} {hasMessages && ( @@ -544,56 +391,19 @@ export function ChatPanel({ position: "relative", }} > - {chatMode === "agent" ? ( - /* Agent Mode - Claude Code style */ - - ) : /* Chat Mode - Traditional tool calling */ - chatMessages.length === 0 ? ( - - ) : ( -
- {chatMessages.map((msg, index) => { - const isLast = index === chatMessages.length - 1; - const showStreaming = - isLast && chatIsStreaming && msg.role === "assistant"; - const showToolExecution = - isLast && isExecutingTool && msg.role === "assistant"; - - return ( - - ); - })} -
-
- )} +
{/* Scroll to bottom button */} @@ -642,394 +452,6 @@ export function ChatPanel({
- - {/* Tool Confirmation Modal (only for chat mode) */} - {chatMode === "chat" && ( - - )} ); } - -/** - * Empty state component - */ -interface EmptyStateProps { - selectedModel: ChatModelConfig | null; - systemContext: import("./systemContext").SystemContext | null; - colors: ThemeColors; -} - -function EmptyState({ selectedModel, systemContext, colors }: EmptyStateProps) { - const activeDeployments = - systemContext?.deployments.filter((d) => d.status === "running") || []; - const runningContainers = - systemContext?.containers.filter( - (c) => - c.status.toLowerCase().includes("running") || - c.status.toLowerCase().includes("up"), - ) || []; - - return ( -
-
- -
-
- {selectedModel ? "LMStack AI Assistant" : "Select a Model"} -
-
- {selectedModel - ? "I can help you manage your LLM infrastructure" - : "Choose a model from the dropdown above"} -
- - {/* System summary */} - {selectedModel && systemContext && ( -
-
- System Overview -
-
-
- • {systemContext.workers.length} Worker - {systemContext.workers.length !== 1 ? "s" : ""} -
-
- • {runningContainers.length}/{systemContext.containers.length}{" "} - Container{systemContext.containers.length !== 1 ? "s" : ""}{" "} - running -
-
- • {activeDeployments.length} Model deployment - {activeDeployments.length !== 1 ? "s" : ""} active -
-
- • {systemContext.models.length} Model - {systemContext.models.length !== 1 ? "s" : ""} available -
-
- • {systemContext.images.length} Docker image - {systemContext.images.length !== 1 ? "s" : ""} -
-
-
- Try: "有幾個容器在運行?" or "GPU 記憶體剩多少?" -
-
- )} -
- ); -} - -/** - * Format tool name for display - */ -function formatToolName(name: string): string { - return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); -} - -/** - * Message bubble component - */ -interface MessageBubbleProps { - message: ChatMessage; - isStreaming: boolean; - isExecutingTool: boolean; - currentToolName: string | null; - isDark: boolean; - colors: ThemeColors; -} - -function MessageBubble({ - message, - isStreaming, - isExecutingTool, - currentToolName, - isDark, - colors, -}: MessageBubbleProps) { - const isUser = message.role === "user"; - - return ( -
- {/* Avatar */} - {!isUser && ( -
- -
- )} - - {/* Content */} -
- {isUser ? ( -
- {message.content} -
- ) : ( - <> - {/* Thinking section for reasoning models */} - {message.thinking && ( - - - Thinking Process - - ), - children: ( -
- {message.thinking} -
- ), - }, - ]} - /> - )} - - {/* Tool execution indicator - show when executing */} - {isExecutingTool && currentToolName && ( -
- - - Executing: {formatToolName(currentToolName)} - -
- )} - - {/* Tool calls display - always show in history */} - {message.toolCalls && - message.toolCalls.length > 0 && - !isExecutingTool && ( -
-
- - Tool Calls Executed -
- {message.toolCalls.map((tc) => { - let args: Record = {}; - try { - args = JSON.parse(tc.function.arguments); - } catch { - args = {}; - } - return ( -
-
- - - {formatToolName(tc.function.name)} - -
- {Object.keys(args).length > 0 && ( -
- {Object.entries(args).map(([key, value]) => ( -
- {key}:{" "} - {typeof value === "object" - ? JSON.stringify(value) - : String(value)} -
- ))} -
- )} -
- ); - })} -
- )} - - - {message.model && ( -
- via {message.model} -
- )} - - )} -
- - {/* User avatar */} - {isUser && ( -
- -
- )} -
- ); -} diff --git a/frontend/src/components/chat-panel/ToolConfirmModal.tsx b/frontend/src/components/chat-panel/ToolConfirmModal.tsx deleted file mode 100644 index 5627359..0000000 --- a/frontend/src/components/chat-panel/ToolConfirmModal.tsx +++ /dev/null @@ -1,218 +0,0 @@ -/** - * Tool Confirmation Modal - * - * A modal dialog that asks for user confirmation before executing - * AI-requested tool actions. - */ -import { Modal, Button, Tag, Descriptions } from "antd"; -import { - ExclamationCircleOutlined, - CheckCircleOutlined, - DeleteOutlined, - RocketOutlined, - PauseCircleOutlined, - PlayCircleOutlined, - PlusOutlined, - StopOutlined, - ThunderboltOutlined, - KeyOutlined, - CloudDownloadOutlined, - DatabaseOutlined, - PieChartOutlined, - ClearOutlined, -} from "@ant-design/icons"; -import type { PendingToolExecution } from "./tools"; - -interface ToolConfirmModalProps { - visible: boolean; - pendingTools: PendingToolExecution[]; - onConfirm: () => void; - onCancel: () => void; - isDark: boolean; -} - -/** - * Get icon for tool - */ -function getToolIcon(iconName: string) { - const iconMap: Record = { - delete: , - rocket: , - "pause-circle": , - "play-circle": , - plus: , - stop: , - thunderbolt: , - key: , - download: , - database: , - "pie-chart": , - clear: , - }; - return iconMap[iconName] || ; -} - -/** - * Format argument value for display - */ -function formatArgValue(value: any): string { - if (value === null || value === undefined) { - return "-"; - } - if (Array.isArray(value)) { - return value.join(", "); - } - if (typeof value === "object") { - return JSON.stringify(value); - } - return String(value); -} - -/** - * Tool Confirmation Modal Component - */ -export function ToolConfirmModal({ - visible, - pendingTools, - onConfirm, - onCancel, - isDark, -}: ToolConfirmModalProps) { - if (pendingTools.length === 0) return null; - - const hasDangerous = pendingTools.some((t) => t.meta.dangerous); - - return ( - - {hasDangerous ? ( - - ) : ( - - )} - Confirm Action - - } - onCancel={onCancel} - footer={[ - , - , - ]} - width={500} - centered - styles={{ - body: { - background: isDark ? "#1f1f1f" : "#ffffff", - }, - header: { - background: isDark ? "#1f1f1f" : "#ffffff", - }, - content: { - background: isDark ? "#1f1f1f" : "#ffffff", - }, - }} - > -
-
- AI assistant wants to execute the following actions: -
- - {pendingTools.map((tool, index) => ( -
- {/* Tool header */} -
- - {getToolIcon(tool.meta.icon)} - - - {tool.meta.displayName} - - {tool.meta.dangerous && ( - - Dangerous - - )} -
- - {/* Tool arguments */} - - {Object.entries(tool.parsedArgs).map(([key, value]) => ( - - {formatArgValue(value)} - - ))} - -
- ))} - - {hasDangerous && ( -
- - This action may be irreversible. Please confirm before proceeding. -
- )} -
-
- ); -} diff --git a/frontend/src/components/chat-panel/index.ts b/frontend/src/components/chat-panel/index.ts index cdda4aa..ca1b850 100644 --- a/frontend/src/components/chat-panel/index.ts +++ b/frontend/src/components/chat-panel/index.ts @@ -2,19 +2,15 @@ * Chat Panel Components * * Global chat panel for AI conversations accessible from any page. - * Supports two modes: - * - Agent Mode: MCP-based agent with Claude Code-style interaction - * - Chat Mode: Traditional tool calling via LLM API + * Uses MCP-based agent with Claude Code-style interaction. */ // Components export { ChatPanel } from "./ChatPanel"; export { ModelSelector } from "./ModelSelector"; -export { ToolConfirmModal } from "./ToolConfirmModal"; export { AgentChatView } from "./AgentChatView"; // Hooks -export { useChat } from "./useChat"; export { useAgentChat } from "./useAgentChat"; // Types @@ -25,14 +21,6 @@ export type { ModelSourceType, } from "./types"; -export type { - ToolDefinition, - ToolCall, - ToolResult, - ToolMeta, - PendingToolExecution, -} from "./tools"; - export type { AgentEventType, AgentEvent, @@ -47,11 +35,3 @@ export { MAX_PANEL_WIDTH, CHAT_PANEL_STORAGE_KEY, } from "./types"; - -export { - CHAT_TOOLS, - TOOL_META, - requiresConfirmation, - getToolMeta, - executeTool, -} from "./tools"; diff --git a/frontend/src/components/chat-panel/systemContext.ts b/frontend/src/components/chat-panel/systemContext.ts deleted file mode 100644 index a6631a1..0000000 --- a/frontend/src/components/chat-panel/systemContext.ts +++ /dev/null @@ -1,515 +0,0 @@ -/** - * System Context Builder - * - * Builds system context for the AI assistant to understand the current - * state of the LMStack platform. - */ -import { api } from "../../api/client"; - -export interface SystemContext { - workers: WorkerInfo[]; - deployments: DeploymentInfo[]; - models: ModelInfo[]; - containers: ContainerInfo[]; - images: ImageInfo[]; - storageVolumes: StorageVolumeInfo[]; - semanticRouter: SemanticRouterInfo | null; - timestamp: string; -} - -interface WorkerInfo { - id: number; - name: string; - host: string; - status: string; - gpus: GpuInfo[]; -} - -interface GpuInfo { - index: number; - name: string; - memoryTotal: number; - memoryUsed: number; - utilizationGpu: number; -} - -interface DeploymentInfo { - id: number; - name: string; - modelName: string; - workerName: string; - status: string; - endpoint?: string; -} - -interface ModelInfo { - id: number; - name: string; - source: string; - parameters?: string; - quantization?: string; -} - -interface ContainerInfo { - id: string; - name: string; - image: string; - status: string; - workerName: string; -} - -interface ImageInfo { - id: string; - name: string; - tag: string; - size: number; - workerName: string; -} - -interface StorageVolumeInfo { - name: string; - driver: string; - mountpoint: string; - workerName: string; -} - -interface SemanticRouterInfo { - deployed: boolean; - status: string; - models: string[]; -} - -/** - * Fetch current system state - */ -export async function fetchSystemContext(): Promise { - try { - const [ - workersRes, - deploymentsRes, - modelsRes, - containersRes, - imagesRes, - storageRes, - srStatus, - ] = await Promise.all([ - api.get("/workers").catch(() => ({ data: { items: [] } })), - api.get("/deployments").catch(() => ({ data: { items: [] } })), - api.get("/models").catch(() => ({ data: { items: [] } })), - api.get("/containers").catch(() => ({ data: { items: [] } })), - api.get("/images").catch(() => ({ data: { items: [] } })), - api.get("/storage/volumes").catch(() => ({ data: [] })), - api.get("/semantic-router/status").catch(() => ({ data: null })), - ]); - - const workers: WorkerInfo[] = (workersRes.data.items || []).map( - (w: any) => ({ - id: w.id, - name: w.name, - host: w.host, - status: w.status, - gpus: (w.gpu_info || []).map((g: any) => ({ - index: g.index, - name: g.name, - memoryTotal: g.memory_total, - memoryUsed: g.memory_used, - utilizationGpu: g.utilization_gpu, - })), - }), - ); - - const deployments: DeploymentInfo[] = (deploymentsRes.data.items || []).map( - (d: any) => ({ - id: d.id, - name: d.name, - modelName: d.model?.name || d.name, - workerName: d.worker?.name || "unknown", - status: d.status, - endpoint: - d.status === "running" ? `/api/deployments/${d.id}/chat` : undefined, - }), - ); - - const models: ModelInfo[] = (modelsRes.data.items || []).map((m: any) => ({ - id: m.id, - name: m.name, - source: m.source, - parameters: m.parameters, - quantization: m.quantization, - })); - - const containers: ContainerInfo[] = (containersRes.data.items || []).map( - (c: any) => ({ - id: c.id || c.container_id, - name: c.name, - image: c.image, - status: c.status, - workerName: c.worker?.name || c.worker_name || "unknown", - }), - ); - - const images: ImageInfo[] = (imagesRes.data.items || []).map((i: any) => ({ - id: i.id || i.image_id, - name: i.name || i.repository, - tag: i.tag || "latest", - size: i.size || 0, - workerName: i.worker?.name || i.worker_name || "unknown", - })); - - // Backend /storage/volumes returns a list directly, not { items: [] } - const storageVolumes: StorageVolumeInfo[] = ( - Array.isArray(storageRes.data) ? storageRes.data : [] - ).map((v: any) => ({ - name: v.name, - driver: v.driver || "local", - mountpoint: v.mountpoint || "", - workerName: v.worker_name || "unknown", - })); - - const semanticRouter: SemanticRouterInfo | null = srStatus.data - ? { - deployed: srStatus.data.deployed, - status: srStatus.data.status, - models: srStatus.data.models || [], - } - : null; - - return { - workers, - deployments, - models, - containers, - images, - storageVolumes, - semanticRouter, - timestamp: new Date().toISOString(), - }; - } catch (error) { - console.error("Failed to fetch system context:", error); - return { - workers: [], - deployments: [], - models: [], - containers: [], - images: [], - storageVolumes: [], - semanticRouter: null, - timestamp: new Date().toISOString(), - }; - } -} - -/** - * Page routes for navigation links - */ -const PAGE_ROUTES = { - dashboard: "/dashboard", - workers: "/workers", - containers: "/containers", - images: "/images", - storage: "/storage", - models: "/models", - deployments: "/deployments", - chat: "/chat", - apiKeys: "/api-keys", - settings: "/settings", -}; - -/** - * Action links that open modals directly - */ -const ACTION_LINKS = { - newDeployment: "/deployments?action=new", - newModel: "/models?action=new", - newApiKey: "/api-keys?action=new", -}; - -/** - * Format system context as a system message for the LLM - */ -export function formatSystemPrompt(context: SystemContext): string { - const lines: string[] = [ - "You are an AI assistant for the LMStack platform - an LLM deployment and management system.", - "You have access to real-time system information and can help users manage their AI infrastructure.", - "", - "## Navigation Links", - "When referencing pages, use markdown links so users can click to navigate:", - `- Workers: [Workers](${PAGE_ROUTES.workers})`, - `- Docker Containers: [Containers](${PAGE_ROUTES.containers})`, - `- Docker Images: [Images](${PAGE_ROUTES.images})`, - `- Storage Volumes: [Storage](${PAGE_ROUTES.storage})`, - `- Models: [Models](${PAGE_ROUTES.models})`, - `- Deployments: [Deployments](${PAGE_ROUTES.deployments})`, - `- API Keys: [API Keys](${PAGE_ROUTES.apiKeys})`, - "", - "## Quick Action Links (ALWAYS use for create/deploy/add operations)", - "**CRITICAL: For deploying models, adding models, or creating API keys, ALWAYS guide users to the UI. NEVER use deploy_model, add_model, or create_api_key tools directly.**", - "", - "These links open the action dialog directly:", - `- Deploy a Model: [New Deployment](${ACTION_LINKS.newDeployment})`, - `- Add a Model: [Add Model](${ACTION_LINKS.newModel})`, - `- Create API Key: [Create API Key](${ACTION_LINKS.newApiKey})`, - "", - "**TOOL USAGE RULES:**", - "- deploy_model, add_model, create_api_key → NEVER use these. Always guide to UI instead.", - "- stop_deployment, delete_deployment, stop_container, remove_container, delete_* → OK to use (destructive actions need confirmation)", - "- list_*, get_* → OK to use (query tools, no confirmation needed)", - "", - "**EXAMPLES:**", - `- User: '我想部署模型' → '請點擊 [New Deployment](${ACTION_LINKS.newDeployment}) 開啟部署表單。'`, - `- User: '幫我部署 Qwen' → '請點擊 [New Deployment](${ACTION_LINKS.newDeployment}) 來部署,選擇 Qwen 模型即可。'`, - `- User: '幫我新增模型' → '請點擊 [Add Model](${ACTION_LINKS.newModel}) 來新增模型。'`, - "- User: '有哪些模型?' → Use list_models tool", - "- User: '停止 deployment 1' → Use stop_deployment tool", - "", - "## Current System Status", - `Last updated: ${new Date(context.timestamp).toLocaleString()}`, - "", - ]; - - // Workers section - lines.push("### Workers"); - if (context.workers.length === 0) { - lines.push("No workers registered. Go to [Workers](/workers) to add one."); - } else { - lines.push(`Total: ${context.workers.length} worker(s)`); - for (const worker of context.workers) { - lines.push(`- **${worker.name}** (${worker.host}): ${worker.status}`); - for (const gpu of worker.gpus) { - const memUsedGB = (gpu.memoryUsed / 1024).toFixed(1); - const memTotalGB = (gpu.memoryTotal / 1024).toFixed(1); - const memFreeGB = ((gpu.memoryTotal - gpu.memoryUsed) / 1024).toFixed( - 1, - ); - lines.push( - ` - GPU ${gpu.index}: ${gpu.name}, Used: ${memUsedGB}GB, Free: ${memFreeGB}GB, Total: ${memTotalGB}GB, Util: ${gpu.utilizationGpu}%`, - ); - } - } - } - lines.push(""); - - // Docker Containers section - lines.push("### Docker Containers"); - if (context.containers.length === 0) { - lines.push( - "No containers running. View [Containers](/containers) page for details.", - ); - } else { - const runningContainers = context.containers.filter( - (c) => - c.status.toLowerCase().includes("running") || - c.status.toLowerCase().includes("up"), - ); - lines.push( - `Total: ${context.containers.length} container(s), Running: ${runningContainers.length}`, - ); - for (const container of context.containers.slice(0, 10)) { - lines.push( - `- **${container.name}** (${container.image}): ${container.status} on ${container.workerName}`, - ); - } - if (context.containers.length > 10) { - lines.push( - ` ... and ${context.containers.length - 10} more. See [Containers](/containers) for full list.`, - ); - } - } - lines.push(""); - - // Docker Images section - lines.push("### Docker Images"); - if (context.images.length === 0) { - lines.push("No images found. View [Images](/images) page for details."); - } else { - lines.push(`Total: ${context.images.length} image(s)`); - const imagesByWorker: Record = {}; - for (const img of context.images) { - if (!imagesByWorker[img.workerName]) imagesByWorker[img.workerName] = []; - imagesByWorker[img.workerName].push(img); - } - for (const [worker, imgs] of Object.entries(imagesByWorker)) { - lines.push(`- ${worker}: ${imgs.length} image(s)`); - } - } - lines.push(""); - - // Storage section - lines.push("### Storage Volumes"); - if (context.storageVolumes.length === 0) { - lines.push( - "No storage volumes. View [Storage](/storage) page for details.", - ); - } else { - lines.push(`Total: ${context.storageVolumes.length} volume(s)`); - } - lines.push(""); - - // Deployments section - lines.push("### Model Deployments"); - const activeDeployments = context.deployments.filter( - (d) => d.status === "running", - ); - const allDeployments = context.deployments; - lines.push( - `Total: ${allDeployments.length}, Running: ${activeDeployments.length}`, - ); - if (activeDeployments.length === 0) { - lines.push( - "No active model deployments. Go to [Deployments](/deployments) to deploy a model.", - ); - } else { - for (const dep of activeDeployments) { - lines.push( - `- **${dep.modelName}** on ${dep.workerName} (ID: ${dep.id}) - running`, - ); - } - } - lines.push(""); - - // Available models section - lines.push("### Available Models"); - if (context.models.length === 0) { - lines.push("No models registered. Go to [Models](/models) to add models."); - } else { - lines.push(`Total: ${context.models.length} model(s)`); - const modelsBySource: Record = {}; - for (const model of context.models) { - if (!modelsBySource[model.source]) { - modelsBySource[model.source] = []; - } - modelsBySource[model.source].push(model); - } - for (const [source, models] of Object.entries(modelsBySource)) { - lines.push(`- **${source}:** ${models.map((m) => m.name).join(", ")}`); - } - } - lines.push(""); - - // Semantic Router - if (context.semanticRouter) { - lines.push("### Semantic Router"); - lines.push( - `Status: ${context.semanticRouter.deployed ? "Deployed" : "Not deployed"}`, - ); - if (context.semanticRouter.models.length > 0) { - lines.push( - `Connected models: ${context.semanticRouter.models.join(", ")}`, - ); - } - lines.push(""); - } - - // Tool Calling Capabilities - lines.push("## Available Actions (Tool Calling)"); - lines.push( - "You have access to tools that allow you to TAKE ACTIONS on the system:", - ); - lines.push(""); - lines.push("### Query Tools (No confirmation needed)"); - lines.push("- `get_system_status`: Get complete system overview"); - lines.push("- `list_workers`: List all workers with GPU information"); - lines.push( - "- `list_containers`: List Docker containers (filter by status/worker_id)", - ); - lines.push("- `list_deployments`: List model deployments (filter by status)"); - lines.push("- `list_models`: List available models (filter by source)"); - lines.push("- `get_gpu_status`: Get detailed GPU status"); - lines.push(""); - lines.push("### Model Management (Requires user confirmation)"); - lines.push( - "- `add_model`: Add a new model (name, source: huggingface/ollama)", - ); - lines.push("- `delete_model`: Delete a model (model_id)"); - lines.push(""); - lines.push("### Deployment Management (Requires user confirmation)"); - lines.push( - "- `deploy_model`: Deploy a model to a worker (model_id, worker_id, gpu_ids?)", - ); - lines.push("- `stop_deployment`: Stop a running deployment (deployment_id)"); - lines.push( - "- `start_deployment`: Start a stopped deployment (deployment_id)", - ); - lines.push( - "- `delete_deployment`: Delete a deployment permanently (deployment_id)", - ); - lines.push(""); - lines.push("### Container Management (Requires user confirmation)"); - lines.push( - "- `stop_container`: Stop a Docker container (container_name, worker_id)", - ); - lines.push( - "- `remove_container`: Remove a Docker container (container_name, worker_id, force?)", - ); - lines.push(""); - lines.push("### API Key Management"); - lines.push("- `list_api_keys`: List all API keys (No confirmation needed)"); - lines.push( - "- `create_api_key`: Create a new API key (name, description?, expires_in_days?)", - ); - lines.push("- `delete_api_key`: Delete an API key (api_key_id)"); - lines.push(""); - lines.push("### Docker Image Management"); - lines.push( - "- `list_images`: List all Docker images (No confirmation needed)", - ); - lines.push("- `pull_image`: Pull a Docker image (worker_id, image)"); - lines.push( - "- `delete_image`: Delete a Docker image (image_id, worker_id, force?)", - ); - lines.push(""); - lines.push("### Storage Management"); - lines.push( - "- `list_storage_volumes`: List storage volumes (No confirmation needed)", - ); - lines.push( - "- `get_disk_usage`: Get disk usage statistics (No confirmation needed)", - ); - lines.push( - "- `delete_storage_volume`: Delete a storage volume (volume_name, worker_id, force?)", - ); - lines.push( - "- `prune_storage`: Clean up unused Docker resources (images?, containers?, volumes?, build_cache?)", - ); - lines.push(""); - lines.push("**WORKFLOW FOR CONTAINER/IMAGE OPERATIONS:**"); - lines.push( - "1. If user asks to stop/remove a container by name, FIRST call list_containers to find the worker_id", - ); - lines.push( - "2. Then call stop_container or remove_container with both container_name AND worker_id", - ); - lines.push( - "3. Same workflow applies to images - use list_images first to find worker_id", - ); - lines.push(""); - lines.push( - "**IMPORTANT:** When users ask you to perform actions, USE THE TOOLS to execute them. The user will see a confirmation dialog before any action is executed.", - ); - lines.push(""); - - // Capabilities and instructions - lines.push("## Instructions"); - lines.push( - "1. Always use markdown links when mentioning pages (e.g., [Containers](/containers))", - ); - lines.push("2. Provide specific numbers from the system data above"); - lines.push("3. Be concise and accurate"); - lines.push( - "4. When users ask about 'containers' or 'docker containers', refer to the Docker Containers section", - ); - lines.push( - "5. When users ask about 'deployments', distinguish between Docker containers and Model Deployments", - ); - lines.push( - "6. When users ask you to deploy a model, USE the deploy_model tool", - ); - lines.push( - "7. When users ask you to stop/start/delete a deployment, USE the corresponding tool", - ); - lines.push("8. After executing an action, report the result to the user"); - lines.push("9. Respond in the same language as the user's query"); - lines.push(""); - - return lines.join("\n"); -} diff --git a/frontend/src/components/chat-panel/tools.ts b/frontend/src/components/chat-panel/tools.ts deleted file mode 100644 index 125cce1..0000000 --- a/frontend/src/components/chat-panel/tools.ts +++ /dev/null @@ -1,2252 +0,0 @@ -/** - * Chat Tools Definition - * - * Tools that the AI assistant can call to interact with LMStack. - * Includes confirmation flow for dangerous operations. - */ -import { api } from "../../api/client"; -import type { ChatModelConfig } from "./types"; - -/** - * Tool definition for OpenAI-compatible API - */ -export interface ToolDefinition { - type: "function"; - function: { - name: string; - description: string; - parameters: { - type: "object"; - properties: Record; - required?: string[]; - }; - }; -} - -/** - * Tool call from LLM response - */ -export interface ToolCall { - id: string; - type: "function"; - function: { - name: string; - arguments: string; - }; -} - -/** - * Tool execution result - */ -export interface ToolResult { - tool_call_id: string; - role: "tool"; - content: string; -} - -/** - * Tool metadata for UI display - */ -export interface ToolMeta { - name: string; - displayName: string; - description: string; - category: "query" | "action"; - dangerous: boolean; - icon: string; -} - -/** - * Pending tool execution for confirmation - */ -export interface PendingToolExecution { - toolCall: ToolCall; - parsedArgs: Record; - meta: ToolMeta; -} - -/** - * Tool metadata registry - */ -export const TOOL_META: Record = { - // Query tools (no confirmation needed) - get_system_status: { - name: "get_system_status", - displayName: "Get System Status", - description: "Query complete system status", - category: "query", - dangerous: false, - icon: "dashboard", - }, - list_workers: { - name: "list_workers", - displayName: "List Workers", - description: "Query all worker nodes", - category: "query", - dangerous: false, - icon: "cluster", - }, - list_containers: { - name: "list_containers", - displayName: "List Containers", - description: "Query Docker containers", - category: "query", - dangerous: false, - icon: "container", - }, - list_deployments: { - name: "list_deployments", - displayName: "List Deployments", - description: "Query model deployments", - category: "query", - dangerous: false, - icon: "rocket", - }, - list_models: { - name: "list_models", - displayName: "List Models", - description: "Query available models", - category: "query", - dangerous: false, - icon: "robot", - }, - get_gpu_status: { - name: "get_gpu_status", - displayName: "Get GPU Status", - description: "Query GPU usage", - category: "query", - dangerous: false, - icon: "thunderbolt", - }, - - // Action tools (confirmation needed) - add_model: { - name: "add_model", - displayName: "Add Model", - description: "Add a new model to the system", - category: "action", - dangerous: false, - icon: "plus", - }, - delete_model: { - name: "delete_model", - displayName: "Delete Model", - description: "Delete a model from the system", - category: "action", - dangerous: true, - icon: "delete", - }, - deploy_model: { - name: "deploy_model", - displayName: "Deploy Model", - description: "Deploy a model to a worker", - category: "action", - dangerous: false, - icon: "rocket", - }, - stop_deployment: { - name: "stop_deployment", - displayName: "Stop Deployment", - description: "Stop a running deployment", - category: "action", - dangerous: true, - icon: "pause-circle", - }, - start_deployment: { - name: "start_deployment", - displayName: "Start Deployment", - description: "Start a stopped deployment", - category: "action", - dangerous: false, - icon: "play-circle", - }, - delete_deployment: { - name: "delete_deployment", - displayName: "Delete Deployment", - description: "Permanently delete a deployment", - category: "action", - dangerous: true, - icon: "delete", - }, - stop_container: { - name: "stop_container", - displayName: "Stop Container", - description: "Stop a Docker container", - category: "action", - dangerous: true, - icon: "stop", - }, - remove_container: { - name: "remove_container", - displayName: "Remove Container", - description: "Remove a Docker container", - category: "action", - dangerous: true, - icon: "delete", - }, - - // API Key tools - list_api_keys: { - name: "list_api_keys", - displayName: "List API Keys", - description: "Query all API keys", - category: "query", - dangerous: false, - icon: "key", - }, - create_api_key: { - name: "create_api_key", - displayName: "Create API Key", - description: "Create a new API key", - category: "action", - dangerous: false, - icon: "plus", - }, - delete_api_key: { - name: "delete_api_key", - displayName: "Delete API Key", - description: "Delete an API key", - category: "action", - dangerous: true, - icon: "delete", - }, - - // Docker Image tools - list_images: { - name: "list_images", - displayName: "List Images", - description: "Query Docker images", - category: "query", - dangerous: false, - icon: "container", - }, - pull_image: { - name: "pull_image", - displayName: "Pull Image", - description: "Pull a Docker image from registry", - category: "action", - dangerous: false, - icon: "download", - }, - delete_image: { - name: "delete_image", - displayName: "Delete Image", - description: "Delete a Docker image", - category: "action", - dangerous: true, - icon: "delete", - }, - - // Storage tools - list_storage_volumes: { - name: "list_storage_volumes", - displayName: "List Storage Volumes", - description: "Query Docker storage volumes", - category: "query", - dangerous: false, - icon: "database", - }, - get_disk_usage: { - name: "get_disk_usage", - displayName: "Get Disk Usage", - description: "Query disk usage statistics", - category: "query", - dangerous: false, - icon: "pie-chart", - }, - delete_storage_volume: { - name: "delete_storage_volume", - displayName: "Delete Storage Volume", - description: "Delete a Docker storage volume", - category: "action", - dangerous: true, - icon: "delete", - }, - prune_storage: { - name: "prune_storage", - displayName: "Prune Storage", - description: "Clean up unused Docker resources", - category: "action", - dangerous: true, - icon: "clear", - }, - - // Auto-Tuning tools - list_tuning_jobs: { - name: "list_tuning_jobs", - displayName: "List Tuning Jobs", - description: "Query all auto-tuning jobs", - category: "query", - dangerous: false, - icon: "thunderbolt", - }, - start_auto_tuning: { - name: "start_auto_tuning", - displayName: "Start Auto-Tuning", - description: "Start a new auto-tuning job", - category: "action", - dangerous: false, - icon: "experiment", - }, - get_tuning_job: { - name: "get_tuning_job", - displayName: "Get Tuning Job", - description: "Get details of a tuning job", - category: "query", - dangerous: false, - icon: "info-circle", - }, - cancel_tuning_job: { - name: "cancel_tuning_job", - displayName: "Cancel Tuning Job", - description: "Cancel a running tuning job", - category: "action", - dangerous: true, - icon: "stop", - }, - query_knowledge_base: { - name: "query_knowledge_base", - displayName: "Query Knowledge Base", - description: "Query performance knowledge base", - category: "query", - dangerous: false, - icon: "database", - }, - run_benchmark: { - name: "run_benchmark", - displayName: "Run Benchmark", - description: "Run performance benchmark on a deployment", - category: "action", - dangerous: false, - icon: "bar-chart", - }, -}; - -/** - * Check if a tool requires confirmation - */ -export function requiresConfirmation(toolName: string): boolean { - const meta = TOOL_META[toolName]; - return meta?.category === "action"; -} - -/** - * Get tool metadata - */ -export function getToolMeta(toolName: string): ToolMeta { - return ( - TOOL_META[toolName] || { - name: toolName, - displayName: toolName, - description: "Unknown tool", - category: "action", - dangerous: false, - icon: "question", - } - ); -} - -/** - * Available tools for the AI assistant - */ -export const CHAT_TOOLS: ToolDefinition[] = [ - // ============== Query Tools ============== - { - type: "function", - function: { - name: "get_system_status", - description: - "Get complete LMStack system status including workers, GPUs, containers, and deployments. Call this to get the latest system information.", - parameters: { - type: "object", - properties: {}, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "list_workers", - description: - "List all worker nodes with their GPU status, memory usage, and availability.", - parameters: { - type: "object", - properties: {}, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "list_containers", - description: "List all Docker containers running across all workers.", - parameters: { - type: "object", - properties: { - status: { - type: "string", - description: "Filter by status: running, stopped, or all", - enum: ["running", "stopped", "all"], - }, - worker_id: { - type: "number", - description: "Optional: Filter by specific worker ID", - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "list_deployments", - description: "List all model deployments with their status.", - parameters: { - type: "object", - properties: { - status: { - type: "string", - description: "Filter by status: running, stopped, or all", - enum: ["running", "stopped", "all"], - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "list_models", - description: "List all available models that can be deployed.", - parameters: { - type: "object", - properties: { - source: { - type: "string", - description: "Filter by source", - enum: ["huggingface", "ollama", "local"], - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "get_gpu_status", - description: - "Get detailed GPU status including memory usage, utilization, and temperature for all workers.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: "Optional: Filter by specific worker ID", - }, - }, - required: [], - }, - }, - }, - - // ============== Model Management Tools ============== - { - type: "function", - function: { - name: "add_model", - description: - "Add a new model to the system. Supports HuggingFace and Ollama models.", - parameters: { - type: "object", - properties: { - name: { - type: "string", - description: - "Model name/identifier (e.g., 'Qwen/Qwen2.5-7B-Instruct' for HuggingFace, 'llama3.2' for Ollama)", - }, - source: { - type: "string", - description: "Model source", - enum: ["huggingface", "ollama"], - }, - parameters: { - type: "string", - description: "Optional: Model parameters (e.g., '7B', '13B')", - }, - quantization: { - type: "string", - description: - "Optional: Quantization format (e.g., 'GPTQ', 'AWQ', 'GGUF')", - }, - }, - required: ["name", "source"], - }, - }, - }, - { - type: "function", - function: { - name: "delete_model", - description: - "Delete a model from the system. This will NOT delete any deployments using this model.", - parameters: { - type: "object", - properties: { - model_id: { - type: "number", - description: - "ID of the model to delete (use list_models to find IDs)", - }, - }, - required: ["model_id"], - }, - }, - }, - - // ============== Deployment Tools ============== - { - type: "function", - function: { - name: "deploy_model", - description: - "Deploy a model to a worker. This will start the model inference service.", - parameters: { - type: "object", - properties: { - model_id: { - type: "number", - description: - "ID of the model to deploy (use list_models to find IDs)", - }, - worker_id: { - type: "number", - description: - "ID of the worker to deploy to (use list_workers to find IDs)", - }, - gpu_ids: { - type: "array", - items: { type: "number" }, - description: - "Optional: Specific GPU indices to use. If not provided, GPUs will be auto-selected.", - }, - }, - required: ["model_id", "worker_id"], - }, - }, - }, - { - type: "function", - function: { - name: "stop_deployment", - description: "Stop a running model deployment.", - parameters: { - type: "object", - properties: { - deployment_id: { - type: "number", - description: - "ID of the deployment to stop (use list_deployments to find IDs)", - }, - }, - required: ["deployment_id"], - }, - }, - }, - { - type: "function", - function: { - name: "start_deployment", - description: "Start a stopped model deployment.", - parameters: { - type: "object", - properties: { - deployment_id: { - type: "number", - description: "ID of the deployment to start", - }, - }, - required: ["deployment_id"], - }, - }, - }, - { - type: "function", - function: { - name: "delete_deployment", - description: - "Delete a model deployment completely. This cannot be undone.", - parameters: { - type: "object", - properties: { - deployment_id: { - type: "number", - description: "ID of the deployment to delete", - }, - }, - required: ["deployment_id"], - }, - }, - }, - - // ============== Container Tools ============== - { - type: "function", - function: { - name: "stop_container", - description: - "Stop a running Docker container. If you don't know the worker_id, call list_containers first to find it.", - parameters: { - type: "object", - properties: { - container_name: { - type: "string", - description: - "Name of the container to stop (e.g., 'lmstack-llama')", - }, - worker_id: { - type: "number", - description: - "ID of the worker where the container is running. Use list_containers to find this.", - }, - }, - required: ["container_name", "worker_id"], - }, - }, - }, - { - type: "function", - function: { - name: "remove_container", - description: - "Remove/delete a Docker container. If you don't know the worker_id, call list_containers first to find it.", - parameters: { - type: "object", - properties: { - container_name: { - type: "string", - description: "Name of the container to remove", - }, - worker_id: { - type: "number", - description: - "ID of the worker where the container is located. Use list_containers to find this.", - }, - force: { - type: "boolean", - description: "Force remove even if running (default: false)", - }, - }, - required: ["container_name", "worker_id"], - }, - }, - }, - - // ============== API Key Tools ============== - { - type: "function", - function: { - name: "list_api_keys", - description: - "List all API keys in the system with their usage statistics.", - parameters: { - type: "object", - properties: {}, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "create_api_key", - description: "Create a new API key for accessing the LMStack API.", - parameters: { - type: "object", - properties: { - name: { - type: "string", - description: - "Name for the API key (e.g., 'production-key', 'test-key')", - }, - description: { - type: "string", - description: "Optional description for the API key", - }, - expires_in_days: { - type: "number", - description: - "Optional: Number of days until the key expires. If not set, the key never expires.", - }, - }, - required: ["name"], - }, - }, - }, - { - type: "function", - function: { - name: "delete_api_key", - description: "Delete an API key from the system.", - parameters: { - type: "object", - properties: { - api_key_id: { - type: "number", - description: - "ID of the API key to delete (use list_api_keys to find IDs)", - }, - }, - required: ["api_key_id"], - }, - }, - }, - - // ============== Docker Image Tools ============== - { - type: "function", - function: { - name: "list_images", - description: "List all Docker images across all workers.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: "Optional: Filter by specific worker ID", - }, - repository: { - type: "string", - description: "Optional: Filter by repository name", - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "pull_image", - description: "Pull a Docker image from a registry to a worker.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: "ID of the worker to pull the image to", - }, - image: { - type: "string", - description: - "Image reference (e.g., 'nginx:latest', 'python:3.11')", - }, - }, - required: ["worker_id", "image"], - }, - }, - }, - { - type: "function", - function: { - name: "delete_image", - description: "Delete a Docker image from a worker.", - parameters: { - type: "object", - properties: { - image_id: { - type: "string", - description: "ID or name of the image to delete", - }, - worker_id: { - type: "number", - description: "ID of the worker where the image is located", - }, - force: { - type: "boolean", - description: - "Force removal even if image is in use (default: false)", - }, - }, - required: ["image_id", "worker_id"], - }, - }, - }, - - // ============== Storage Tools ============== - { - type: "function", - function: { - name: "list_storage_volumes", - description: "List all Docker storage volumes across all workers.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: "Optional: Filter by specific worker ID", - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "get_disk_usage", - description: - "Get Docker disk usage statistics including images, containers, volumes, and build cache.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: "Optional: Filter by specific worker ID", - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "delete_storage_volume", - description: "Delete a Docker storage volume from a worker.", - parameters: { - type: "object", - properties: { - volume_name: { - type: "string", - description: "Name of the volume to delete", - }, - worker_id: { - type: "number", - description: "ID of the worker where the volume is located", - }, - force: { - type: "boolean", - description: "Force removal (default: false)", - }, - }, - required: ["volume_name", "worker_id"], - }, - }, - }, - { - type: "function", - function: { - name: "prune_storage", - description: - "Clean up unused Docker resources (images, containers, volumes, build cache) to free disk space.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: - "Optional: Only prune on specific worker. If not set, prunes on all workers.", - }, - images: { - type: "boolean", - description: "Prune unused images (default: true)", - }, - containers: { - type: "boolean", - description: "Prune stopped containers (default: true)", - }, - volumes: { - type: "boolean", - description: "Prune unused volumes (default: false - be careful!)", - }, - build_cache: { - type: "boolean", - description: "Prune build cache (default: true)", - }, - }, - required: [], - }, - }, - }, - - // ============== Auto-Tuning Tools ============== - { - type: "function", - function: { - name: "list_tuning_jobs", - description: "List all auto-tuning jobs with their status and progress.", - parameters: { - type: "object", - properties: { - status: { - type: "string", - description: "Filter by status", - enum: [ - "pending", - "analyzing", - "querying_kb", - "exploring", - "benchmarking", - "completed", - "failed", - "cancelled", - ], - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "start_auto_tuning", - description: - "Start a new auto-tuning job to find the best deployment configuration for a model. The agent will analyze the environment, query the knowledge base, explore configuration space, run benchmarks, and find the optimal settings.", - parameters: { - type: "object", - properties: { - model_id: { - type: "number", - description: - "ID of the model to tune (use list_models to find IDs)", - }, - worker_id: { - type: "number", - description: - "ID of the worker to use for tuning (use list_workers to find IDs)", - }, - optimization_target: { - type: "string", - description: "What to optimize for", - enum: ["throughput", "latency", "cost", "balanced"], - }, - }, - required: ["model_id", "worker_id"], - }, - }, - }, - { - type: "function", - function: { - name: "get_tuning_job", - description: - "Get detailed information about a specific tuning job including progress, best configuration, and all results.", - parameters: { - type: "object", - properties: { - job_id: { - type: "number", - description: "ID of the tuning job", - }, - }, - required: ["job_id"], - }, - }, - }, - { - type: "function", - function: { - name: "cancel_tuning_job", - description: "Cancel a running auto-tuning job.", - parameters: { - type: "object", - properties: { - job_id: { - type: "number", - description: "ID of the tuning job to cancel", - }, - }, - required: ["job_id"], - }, - }, - }, - { - type: "function", - function: { - name: "query_knowledge_base", - description: - "Query the performance knowledge base to find similar configurations and their benchmark results. This uses transfer learning from previous tuning results.", - parameters: { - type: "object", - properties: { - model_name: { - type: "string", - description: "Model name pattern to match (e.g., 'Qwen', 'Llama')", - }, - model_family: { - type: "string", - description: "Model family: Qwen, Llama, Mistral, etc.", - }, - gpu_model: { - type: "string", - description: "GPU model pattern (e.g., 'RTX 4090', 'A100')", - }, - optimization_target: { - type: "string", - description: "Optimization target for scoring", - enum: ["throughput", "latency", "cost", "balanced"], - }, - limit: { - type: "number", - description: "Maximum number of results to return (default: 10)", - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "run_benchmark", - description: - "Run a performance benchmark on a deployment to measure throughput, latency, and resource usage.", - parameters: { - type: "object", - properties: { - deployment_id: { - type: "number", - description: - "ID of the deployment to benchmark (use list_deployments to find IDs)", - }, - test_type: { - type: "string", - description: "Type of benchmark test", - enum: ["throughput", "latency"], - }, - duration_seconds: { - type: "number", - description: "Test duration in seconds (10-600, default: 60)", - }, - input_length: { - type: "number", - description: "Input token length (default: 512)", - }, - output_length: { - type: "number", - description: "Output token length (default: 128)", - }, - concurrency: { - type: "number", - description: "Number of concurrent requests (1-64, default: 1)", - }, - }, - required: ["deployment_id"], - }, - }, - }, -]; - -/** - * Execute a tool call and return the result - * @param toolCall - The tool call to execute - * @param modelConfig - Optional model config for tools that need LLM access (like auto-tuning) - */ -export async function executeTool( - toolCall: ToolCall, - modelConfig?: ChatModelConfig, -): Promise { - const { name, arguments: argsStr } = toolCall.function; - let args: Record = {}; - - try { - args = JSON.parse(argsStr); - } catch { - return { - tool_call_id: toolCall.id, - role: "tool", - content: `Error: Invalid arguments JSON: ${argsStr}`, - }; - } - - try { - let result: string; - - switch (name) { - // Query tools - case "get_system_status": - result = await getSystemStatus(); - break; - - case "list_workers": - result = await listWorkers(); - break; - - case "list_containers": - result = await listContainers(args.status, args.worker_id); - break; - - case "list_deployments": - result = await listDeployments(args.status); - break; - - case "list_models": - result = await listModels(args.source); - break; - - case "get_gpu_status": - result = await getGpuStatus(args.worker_id); - break; - - // Model management tools - case "add_model": - result = await addModel( - args.name, - args.source, - args.parameters, - args.quantization, - ); - break; - - case "delete_model": - result = await deleteModel(args.model_id); - break; - - // Deployment tools - case "deploy_model": - result = await deployModel(args.model_id, args.worker_id, args.gpu_ids); - break; - - case "stop_deployment": - result = await stopDeployment(args.deployment_id); - break; - - case "start_deployment": - result = await startDeployment(args.deployment_id); - break; - - case "delete_deployment": - result = await deleteDeployment(args.deployment_id); - break; - - // Container tools - case "stop_container": - result = await stopContainer(args.container_name, args.worker_id); - break; - - case "remove_container": - result = await removeContainer( - args.container_name, - args.worker_id, - args.force, - ); - break; - - // API Key tools - case "list_api_keys": - result = await listApiKeys(); - break; - - case "create_api_key": - result = await createApiKey( - args.name, - args.description, - args.expires_in_days, - ); - break; - - case "delete_api_key": - result = await deleteApiKey(args.api_key_id); - break; - - // Docker Image tools - case "list_images": - result = await listImages(args.worker_id, args.repository); - break; - - case "pull_image": - result = await pullImage(args.worker_id, args.image); - break; - - case "delete_image": - result = await deleteImage(args.image_id, args.worker_id, args.force); - break; - - // Storage tools - case "list_storage_volumes": - result = await listStorageVolumes(args.worker_id); - break; - - case "get_disk_usage": - result = await getDiskUsage(args.worker_id); - break; - - case "delete_storage_volume": - result = await deleteStorageVolume( - args.volume_name, - args.worker_id, - args.force, - ); - break; - - case "prune_storage": - result = await pruneStorage( - args.worker_id, - args.images, - args.containers, - args.volumes, - args.build_cache, - ); - break; - - // Auto-Tuning tools - case "list_tuning_jobs": - result = await listTuningJobs(args.status); - break; - - case "start_auto_tuning": - result = await startAutoTuning( - args.model_id, - args.worker_id, - args.optimization_target, - modelConfig, - ); - break; - - case "get_tuning_job": - result = await getTuningJob(args.job_id); - break; - - case "cancel_tuning_job": - result = await cancelTuningJob(args.job_id); - break; - - case "query_knowledge_base": - result = await queryKnowledgeBase( - args.model_name, - args.model_family, - args.gpu_model, - args.optimization_target, - args.limit, - ); - break; - - case "run_benchmark": - result = await runBenchmark( - args.deployment_id, - args.test_type, - args.duration_seconds, - args.input_length, - args.output_length, - args.concurrency, - ); - break; - - default: - result = `Unknown tool: ${name}`; - } - - return { - tool_call_id: toolCall.id, - role: "tool", - content: result, - }; - } catch (error) { - const message = error instanceof Error ? error.message : String(error); - return { - tool_call_id: toolCall.id, - role: "tool", - content: `Error executing ${name}: ${message}`, - }; - } -} - -// ============================================================================ -// Tool Implementations -// ============================================================================ - -async function getSystemStatus(): Promise { - const [workers, containers, deployments, models] = await Promise.all([ - api.get("/workers").then((r) => r.data.items || []), - api.get("/containers").then((r) => r.data.items || []), - api.get("/deployments").then((r) => r.data.items || []), - api.get("/models").then((r) => r.data.items || []), - ]); - - const onlineWorkers = workers.filter((w: any) => w.status === "online"); - const runningContainers = containers.filter( - (c: any) => - c.status?.toLowerCase().includes("running") || - c.status?.toLowerCase().includes("up"), - ); - const runningDeployments = deployments.filter( - (d: any) => d.status === "running", - ); - - let totalGpuMem = 0, - usedGpuMem = 0; - for (const w of workers) { - for (const g of w.gpu_info || []) { - totalGpuMem += g.memory_total || 0; - usedGpuMem += g.memory_used || 0; - } - } - - return JSON.stringify( - { - summary: { - workers: `${onlineWorkers.length}/${workers.length} online`, - containers: `${runningContainers.length}/${containers.length} running`, - deployments: `${runningDeployments.length}/${deployments.length} running`, - models: `${models.length} available`, - gpu_memory: { - used_gb: (usedGpuMem / 1024).toFixed(1), - free_gb: ((totalGpuMem - usedGpuMem) / 1024).toFixed(1), - total_gb: (totalGpuMem / 1024).toFixed(1), - }, - }, - workers: workers.map((w: any) => ({ - id: w.id, - name: w.name, - status: w.status, - gpus: (w.gpu_info || []).map((g: any) => ({ - index: g.index, - name: g.name, - memory_used_gb: (g.memory_used / 1024).toFixed(1), - memory_free_gb: ((g.memory_total - g.memory_used) / 1024).toFixed(1), - memory_total_gb: (g.memory_total / 1024).toFixed(1), - utilization: g.utilization_gpu, - })), - })), - running_deployments: runningDeployments.map((d: any) => ({ - id: d.id, - model: d.model?.name || d.name, - worker: d.worker?.name, - })), - }, - null, - 2, - ); -} - -async function listWorkers(): Promise { - const response = await api.get("/workers"); - const workers = response.data.items || []; - - return JSON.stringify( - workers.map((w: any) => ({ - id: w.id, - name: w.name, - host: w.host, - status: w.status, - gpus: (w.gpu_info || []).map((g: any) => ({ - index: g.index, - name: g.name, - memory_used_gb: (g.memory_used / 1024).toFixed(1), - memory_free_gb: ((g.memory_total - g.memory_used) / 1024).toFixed(1), - memory_total_gb: (g.memory_total / 1024).toFixed(1), - utilization_percent: g.utilization_gpu, - })), - })), - null, - 2, - ); -} - -async function listContainers( - status?: string, - workerId?: number, -): Promise { - const response = await api.get("/containers"); - let containers = response.data.items || []; - - if (workerId) { - containers = containers.filter( - (c: any) => c.worker?.id === workerId || c.worker_id === workerId, - ); - } - - if (status && status !== "all") { - containers = containers.filter((c: any) => { - const s = c.status?.toLowerCase() || ""; - if (status === "running") { - return s.includes("running") || s.includes("up"); - } - return s.includes(status); - }); - } - - return JSON.stringify( - containers.map((c: any) => ({ - id: c.id?.substring(0, 12), - name: c.name, - image: c.image, - status: c.status, - worker: c.worker?.name || c.worker_name, - worker_id: c.worker?.id || c.worker_id, - })), - null, - 2, - ); -} - -async function listDeployments(status?: string): Promise { - const response = await api.get("/deployments"); - let deployments = response.data.items || []; - - if (status && status !== "all") { - deployments = deployments.filter((d: any) => d.status === status); - } - - return JSON.stringify( - deployments.map((d: any) => ({ - id: d.id, - name: d.name, - model: d.model?.name, - model_id: d.model?.id, - worker: d.worker?.name, - worker_id: d.worker?.id, - status: d.status, - gpu_ids: d.gpu_ids, - port: d.port, - created_at: d.created_at, - })), - null, - 2, - ); -} - -async function listModels(source?: string): Promise { - const response = await api.get("/models"); - let models = response.data.items || []; - - if (source) { - models = models.filter((m: any) => m.source === source); - } - - return JSON.stringify( - models.map((m: any) => ({ - id: m.id, - name: m.name, - source: m.source, - parameters: m.parameters, - quantization: m.quantization, - })), - null, - 2, - ); -} - -async function getGpuStatus(workerId?: number): Promise { - const response = await api.get("/workers"); - let workers = response.data.items || []; - - if (workerId) { - workers = workers.filter((w: any) => w.id === workerId); - } - - const result = workers.map((w: any) => ({ - worker_id: w.id, - worker_name: w.name, - status: w.status, - gpus: (w.gpu_info || []).map((g: any) => ({ - index: g.index, - name: g.name, - memory_used_gb: (g.memory_used / 1024).toFixed(1), - memory_free_gb: ((g.memory_total - g.memory_used) / 1024).toFixed(1), - memory_total_gb: (g.memory_total / 1024).toFixed(1), - utilization_percent: g.utilization_gpu, - temperature: g.temperature, - })), - })); - - return JSON.stringify(result, null, 2); -} - -async function deployModel( - modelId: number, - workerId: number, - gpuIds?: number[], -): Promise { - if (!modelId || !workerId) { - return "Error: model_id and worker_id are required"; - } - - const response = await api.post("/deployments", { - model_id: modelId, - worker_id: workerId, - gpu_ids: gpuIds, - }); - - const deployment = response.data; - const deploymentId = deployment.id; - - // Poll deployment status for up to 60 seconds - const maxPollTime = 60000; - const pollInterval = 3000; - const startTime = Date.now(); - let lastStatus = deployment.status; - const statusUpdates: string[] = [`Initial status: ${deployment.status}`]; - - while (Date.now() - startTime < maxPollTime) { - await new Promise((resolve) => setTimeout(resolve, pollInterval)); - - try { - const statusResponse = await api.get(`/deployments/${deploymentId}`); - const currentStatus = statusResponse.data.status; - const statusMessage = statusResponse.data.status_message; - - if (currentStatus !== lastStatus) { - statusUpdates.push( - `Status changed: ${lastStatus} → ${currentStatus}${statusMessage ? ` (${statusMessage})` : ""}`, - ); - lastStatus = currentStatus; - } - - // Stop polling if deployment reached a terminal state - if (["running", "error", "stopped"].includes(currentStatus)) { - return JSON.stringify( - { - success: currentStatus === "running", - message: - currentStatus === "running" - ? `Deployment completed successfully! Model is now running.` - : currentStatus === "error" - ? `Deployment failed: ${statusMessage || "Unknown error"}` - : `Deployment stopped`, - deployment: { - id: deploymentId, - status: currentStatus, - status_message: statusMessage, - model: deployment.model?.name, - worker: deployment.worker?.name, - port: statusResponse.data.port, - }, - status_history: statusUpdates, - }, - null, - 2, - ); - } - } catch (error) { - // Continue polling even if one request fails - } - } - - // Timeout - return current status - return JSON.stringify( - { - success: false, - message: `Deployment is still in progress (status: ${lastStatus}). Check [Deployments](/deployments) page for updates.`, - deployment: { - id: deploymentId, - status: lastStatus, - model: deployment.model?.name, - worker: deployment.worker?.name, - }, - status_history: statusUpdates, - note: "Deployment is taking longer than expected. This is normal for large models that need to be downloaded.", - }, - null, - 2, - ); -} - -async function stopDeployment(deploymentId: number): Promise { - if (!deploymentId) { - return "Error: deployment_id is required"; - } - - await api.post(`/deployments/${deploymentId}/stop`); - return JSON.stringify( - { - success: true, - message: `Deployment ${deploymentId} stopped successfully`, - }, - null, - 2, - ); -} - -async function startDeployment(deploymentId: number): Promise { - if (!deploymentId) { - return "Error: deployment_id is required"; - } - - await api.post(`/deployments/${deploymentId}/start`); - return JSON.stringify( - { - success: true, - message: `Deployment ${deploymentId} started successfully`, - }, - null, - 2, - ); -} - -async function deleteDeployment(deploymentId: number): Promise { - if (!deploymentId) { - return "Error: deployment_id is required"; - } - - await api.delete(`/deployments/${deploymentId}`); - return JSON.stringify( - { - success: true, - message: `Deployment ${deploymentId} deleted successfully`, - }, - null, - 2, - ); -} - -// ============================================================================ -// Model Management Tools -// ============================================================================ - -async function addModel( - name: string, - source: string, - parameters?: string, - quantization?: string, -): Promise { - if (!name || !source) { - return "Error: name and source are required"; - } - - const response = await api.post("/models", { - name, - source, - parameters, - quantization, - }); - - const model = response.data; - return JSON.stringify( - { - success: true, - message: `Model added successfully`, - model: { - id: model.id, - name: model.name, - source: model.source, - parameters: model.parameters, - quantization: model.quantization, - }, - }, - null, - 2, - ); -} - -async function deleteModel(modelId: number): Promise { - if (!modelId) { - return "Error: model_id is required"; - } - - await api.delete(`/models/${modelId}`); - return JSON.stringify( - { - success: true, - message: `Model ${modelId} deleted successfully`, - }, - null, - 2, - ); -} - -// ============================================================================ -// Container Tools -// ============================================================================ - -async function stopContainer( - containerName: string, - workerId: number, -): Promise { - if (!containerName || !workerId) { - return "Error: container_name and worker_id are required"; - } - - // Backend expects: POST /containers/{container_id}/stop?worker_id=X - await api.post( - `/containers/${encodeURIComponent(containerName)}/stop`, - null, - { - params: { worker_id: workerId }, - }, - ); - - return JSON.stringify( - { - success: true, - message: `Container "${containerName}" stopped successfully`, - }, - null, - 2, - ); -} - -async function removeContainer( - containerName: string, - workerId: number, - force?: boolean, -): Promise { - if (!containerName || !workerId) { - return "Error: container_name and worker_id are required"; - } - - // Backend expects: DELETE /containers/{container_id}?worker_id=X&force=Y - await api.delete(`/containers/${encodeURIComponent(containerName)}`, { - params: { worker_id: workerId, force: force || false }, - }); - - return JSON.stringify( - { - success: true, - message: `Container "${containerName}" removed successfully`, - }, - null, - 2, - ); -} - -// ============================================================================ -// API Key Tools -// ============================================================================ - -async function listApiKeys(): Promise { - const response = await api.get("/api-keys"); - const apiKeys = response.data.items || []; - - return JSON.stringify( - { - total: response.data.total || apiKeys.length, - api_keys: apiKeys.map((k: any) => ({ - id: k.id, - name: k.name, - description: k.description, - access_key: k.access_key, - expires_at: k.expires_at, - created_at: k.created_at, - last_used_at: k.last_used_at, - })), - }, - null, - 2, - ); -} - -async function createApiKey( - name: string, - description?: string, - expiresInDays?: number, -): Promise { - if (!name) { - return "Error: name is required"; - } - - const response = await api.post("/api-keys", { - name, - description, - expires_in_days: expiresInDays, - }); - - const apiKey = response.data; - return JSON.stringify( - { - success: true, - message: "API key created successfully", - api_key: { - id: apiKey.id, - name: apiKey.name, - access_key: apiKey.access_key, - full_key: apiKey.api_key, // The full key is only shown once! - expires_at: apiKey.expires_at, - }, - warning: "Save the full API key now! It will not be shown again.", - }, - null, - 2, - ); -} - -async function deleteApiKey(apiKeyId: number): Promise { - if (!apiKeyId) { - return "Error: api_key_id is required"; - } - - await api.delete(`/api-keys/${apiKeyId}`); - return JSON.stringify( - { - success: true, - message: `API key ${apiKeyId} deleted successfully`, - }, - null, - 2, - ); -} - -// ============================================================================ -// Docker Image Tools -// ============================================================================ - -async function listImages( - workerId?: number, - repository?: string, -): Promise { - const params: any = {}; - if (workerId) params.worker_id = workerId; - if (repository) params.repository = repository; - - const response = await api.get("/images", { params }); - const images = response.data.items || []; - - return JSON.stringify( - { - total: response.data.total || images.length, - images: images.map((img: any) => ({ - id: img.id?.substring(0, 12), - repository: img.repository, - tag: img.tag, - full_name: img.full_name, - size_mb: (img.size / 1024 / 1024).toFixed(1), - created_at: img.created_at, - worker: img.worker_name, - worker_id: img.worker_id, - })), - }, - null, - 2, - ); -} - -async function pullImage(workerId: number, image: string): Promise { - if (!workerId || !image) { - return "Error: worker_id and image are required"; - } - - const response = await api.post("/images/pull", { - worker_id: workerId, - image, - }); - - return JSON.stringify( - { - success: true, - message: `Image "${image}" pulled successfully`, - image: response.data.image, - }, - null, - 2, - ); -} - -async function deleteImage( - imageId: string, - workerId: number, - force?: boolean, -): Promise { - if (!imageId || !workerId) { - return "Error: image_id and worker_id are required"; - } - - await api.delete(`/images/${encodeURIComponent(imageId)}`, { - params: { worker_id: workerId, force: force || false }, - }); - - return JSON.stringify( - { - success: true, - message: `Image "${imageId}" deleted successfully`, - }, - null, - 2, - ); -} - -// ============================================================================ -// Storage Tools -// ============================================================================ - -async function listStorageVolumes(workerId?: number): Promise { - const params: any = {}; - if (workerId) params.worker_id = workerId; - - const response = await api.get("/storage/volumes", { params }); - const volumes = Array.isArray(response.data) ? response.data : []; - - return JSON.stringify( - { - total: volumes.length, - volumes: volumes.map((v: any) => ({ - name: v.name, - driver: v.driver, - mountpoint: v.mountpoint, - created_at: v.created_at, - worker: v.worker_name, - worker_id: v.worker_id, - })), - }, - null, - 2, - ); -} - -async function getDiskUsage(workerId?: number): Promise { - const params: any = {}; - if (workerId) params.worker_id = workerId; - - const response = await api.get("/storage/disk-usage", { params }); - const usageList = Array.isArray(response.data) ? response.data : []; - - const formatSize = (bytes: number) => { - if (bytes >= 1024 * 1024 * 1024) { - return `${(bytes / 1024 / 1024 / 1024).toFixed(2)} GB`; - } - return `${(bytes / 1024 / 1024).toFixed(2)} MB`; - }; - - return JSON.stringify( - { - workers: usageList.map((u: any) => ({ - worker: u.worker_name, - worker_id: u.worker_id, - images: { - count: u.images.count, - size: formatSize(u.images.size), - reclaimable: formatSize(u.images.reclaimable), - }, - containers: { - count: u.containers.count, - size: formatSize(u.containers.size), - reclaimable: formatSize(u.containers.reclaimable), - }, - volumes: { - count: u.volumes.count, - size: formatSize(u.volumes.size), - reclaimable: formatSize(u.volumes.reclaimable), - }, - build_cache: { - count: u.build_cache.count, - size: formatSize(u.build_cache.size), - reclaimable: formatSize(u.build_cache.reclaimable), - }, - total_size: formatSize(u.total_size), - total_reclaimable: formatSize(u.total_reclaimable), - })), - }, - null, - 2, - ); -} - -async function deleteStorageVolume( - volumeName: string, - workerId: number, - force?: boolean, -): Promise { - if (!volumeName || !workerId) { - return "Error: volume_name and worker_id are required"; - } - - await api.delete(`/storage/volumes/${encodeURIComponent(volumeName)}`, { - params: { worker_id: workerId, force: force || false }, - }); - - return JSON.stringify( - { - success: true, - message: `Volume "${volumeName}" deleted successfully`, - }, - null, - 2, - ); -} - -async function pruneStorage( - workerId?: number, - images: boolean = true, - containers: boolean = true, - volumes: boolean = false, - buildCache: boolean = true, -): Promise { - const params: any = {}; - if (workerId) params.worker_id = workerId; - - const response = await api.post( - "/storage/prune", - { - images, - containers, - volumes, - build_cache: buildCache, - }, - { params }, - ); - - const results = Array.isArray(response.data) ? response.data : []; - - const formatSize = (bytes: number) => { - if (bytes >= 1024 * 1024 * 1024) { - return `${(bytes / 1024 / 1024 / 1024).toFixed(2)} GB`; - } - return `${(bytes / 1024 / 1024).toFixed(2)} MB`; - }; - - return JSON.stringify( - { - success: true, - message: "Storage pruned successfully", - results: results.map((r: any) => ({ - worker: r.worker_name, - images_deleted: r.images_deleted, - containers_deleted: r.containers_deleted, - volumes_deleted: r.volumes_deleted, - build_cache_deleted: r.build_cache_deleted, - space_reclaimed: formatSize(r.space_reclaimed), - })), - }, - null, - 2, - ); -} - -// ============================================================================ -// Auto-Tuning Tools -// ============================================================================ - -async function listTuningJobs(status?: string): Promise { - const response = await api.get("/auto-tuning/jobs"); - let jobs = response.data.items || []; - - if (status) { - jobs = jobs.filter((j: any) => j.status === status); - } - - return JSON.stringify( - { - total: jobs.length, - jobs: jobs.map((j: any) => ({ - id: j.id, - model: j.model_name, - worker: j.worker_name, - optimization_target: j.optimization_target, - status: j.status, - progress: j.progress - ? { - step: j.progress.step, - total_steps: j.progress.total_steps, - step_name: j.progress.step_name, - configs_tested: j.progress.configs_tested, - configs_total: j.progress.configs_total, - best_score: j.progress.best_score_so_far, - } - : null, - created_at: j.created_at, - })), - note: "View detailed results at [Auto-Tuning](/auto-tuning) page", - }, - null, - 2, - ); -} - -async function startAutoTuning( - modelId: number, - workerId: number, - optimizationTarget?: string, - modelConfig?: ChatModelConfig, -): Promise { - if (!modelId || !workerId) { - return "Error: model_id and worker_id are required"; - } - - // Build the LLM configuration for the agent - const llmConfig: Record = {}; - if (modelConfig) { - if (modelConfig.type === "deployment" && modelConfig.deploymentId) { - // Use local deployment - llmConfig.deployment_id = modelConfig.deploymentId; - } else if (modelConfig.type === "custom" && modelConfig.endpoint) { - // Use custom endpoint - llmConfig.base_url = modelConfig.endpoint; - llmConfig.api_key = modelConfig.apiKey || ""; - llmConfig.model = modelConfig.modelId || modelConfig.name; - } - } - - const response = await api.post("/auto-tuning/jobs", { - model_id: modelId, - worker_id: workerId, - optimization_target: optimizationTarget || "balanced", - llm_config: Object.keys(llmConfig).length > 0 ? llmConfig : undefined, - }); - - const job = response.data; - - return JSON.stringify( - { - success: true, - message: - "Auto-tuning job started! The agent will use the selected model for reasoning.", - job: { - id: job.id, - model: job.model_name, - worker: job.worker_name, - optimization_target: job.optimization_target, - status: job.status, - agent_llm: modelConfig?.name || "auto-detected", - }, - note: "Track progress at [Auto-Tuning](/auto-tuning) page. This may take several minutes depending on the number of configurations to test.", - }, - null, - 2, - ); -} - -async function getTuningJob(jobId: number): Promise { - if (!jobId) { - return "Error: job_id is required"; - } - - const response = await api.get(`/auto-tuning/jobs/${jobId}`); - const job = response.data; - - return JSON.stringify( - { - id: job.id, - model: job.model_name, - worker: job.worker_name, - optimization_target: job.optimization_target, - status: job.status, - status_message: job.status_message, - progress: job.progress - ? { - step: job.progress.step, - total_steps: job.progress.total_steps, - step_name: job.progress.step_name, - step_description: job.progress.step_description, - configs_tested: job.progress.configs_tested, - configs_total: job.progress.configs_total, - current_config: job.progress.current_config, - best_config_so_far: job.progress.best_config_so_far, - best_score_so_far: job.progress.best_score_so_far, - } - : null, - best_config: job.best_config, - all_results: job.all_results, - created_at: job.created_at, - completed_at: job.completed_at, - }, - null, - 2, - ); -} - -async function cancelTuningJob(jobId: number): Promise { - if (!jobId) { - return "Error: job_id is required"; - } - - await api.post(`/auto-tuning/jobs/${jobId}/cancel`); - - return JSON.stringify( - { - success: true, - message: `Tuning job ${jobId} cancelled successfully`, - }, - null, - 2, - ); -} - -async function queryKnowledgeBase( - modelName?: string, - modelFamily?: string, - gpuModel?: string, - optimizationTarget?: string, - limit?: number, -): Promise { - const response = await api.post("/auto-tuning/knowledge/query", { - model_name: modelName, - model_family: modelFamily, - gpu_model: gpuModel, - optimization_target: optimizationTarget || "balanced", - limit: limit || 10, - }); - - const data = response.data; - - return JSON.stringify( - { - total: data.total, - query: data.query, - results: data.items.map((r: any) => ({ - gpu: `${r.gpu_count}x ${r.gpu_model}`, - total_vram_gb: r.total_vram_gb, - model: r.model_name, - model_family: r.model_family, - model_params_b: r.model_params_b, - engine: r.engine, - quantization: r.quantization, - tensor_parallel: r.tensor_parallel, - throughput_tps: r.throughput_tps, - ttft_ms: r.ttft_ms, - tpot_ms: r.tpot_ms, - score: r.score, - })), - note: "These are benchmark results from similar configurations. Use this to guide your deployment decisions.", - }, - null, - 2, - ); -} - -async function runBenchmark( - deploymentId: number, - testType?: string, - durationSeconds?: number, - inputLength?: number, - outputLength?: number, - concurrency?: number, -): Promise { - if (!deploymentId) { - return "Error: deployment_id is required"; - } - - const response = await api.post("/auto-tuning/benchmarks/run", { - deployment_id: deploymentId, - test_type: testType || "throughput", - duration_seconds: durationSeconds || 60, - input_length: inputLength || 512, - output_length: outputLength || 128, - concurrency: concurrency || 1, - }); - - const result = response.data; - - return JSON.stringify( - { - success: true, - message: "Benchmark completed", - benchmark: { - id: result.id, - deployment_id: result.deployment_id, - test_type: result.test_type, - duration_seconds: result.test_duration_seconds, - config: { - input_length: result.input_length, - output_length: result.output_length, - concurrency: result.concurrency, - }, - metrics: { - throughput_tps: result.metrics?.throughput_tps, - ttft_ms: result.metrics?.ttft_ms, - tpot_ms: result.metrics?.tpot_ms, - total_latency_ms: result.metrics?.total_latency_ms, - gpu_utilization: result.metrics?.gpu_utilization, - vram_usage_gb: result.metrics?.vram_usage_gb, - }, - }, - note: "Higher throughput (TPS) is better. Lower latency (TTFT, TPOT) is better.", - }, - null, - 2, - ); -} diff --git a/frontend/src/components/chat-panel/useChat.ts b/frontend/src/components/chat-panel/useChat.ts deleted file mode 100644 index 4b3353e..0000000 --- a/frontend/src/components/chat-panel/useChat.ts +++ /dev/null @@ -1,762 +0,0 @@ -/** - * useChat Hook - * - * Encapsulates chat logic for streaming conversations with LLM endpoints. - * Supports deployments, Semantic Router, and custom OpenAI-compatible endpoints. - * Includes system context injection for AI assistant capabilities. - * Supports Tool Calling for LLM to interact with LMStack system. - */ -import { useState, useRef, useCallback, useEffect } from "react"; -import { message } from "antd"; -import { generateMessageId } from "../chat"; -import type { ChatMessage } from "../chat"; -import type { ChatModelConfig } from "./types"; -import { STORAGE_KEYS } from "../../constants"; -import { - fetchSystemContext, - formatSystemPrompt, - type SystemContext, -} from "./systemContext"; -import { - CHAT_TOOLS, - executeTool, - requiresConfirmation, - getToolMeta, - type ToolCall, - type ToolResult, - type PendingToolExecution, -} from "./tools"; - -interface UseChatOptions { - /** Called when a new message is added */ - onMessageAdded?: (message: ChatMessage) => void; - /** Called when streaming completes */ - onStreamComplete?: (userMsg: ChatMessage, assistantMsg: ChatMessage) => void; -} - -interface UseChatReturn { - messages: ChatMessage[]; - isStreaming: boolean; - isExecutingTool: boolean; - currentToolName: string | null; - pendingTools: PendingToolExecution[]; - showConfirmModal: boolean; - systemContext: SystemContext | null; - refreshContext: () => Promise; - sendMessage: (content: string, model: ChatModelConfig) => Promise; - stopStreaming: () => void; - clearMessages: () => void; - setMessages: React.Dispatch>; - confirmToolExecution: () => void; - cancelToolExecution: () => void; -} - -/** - * Hook for managing chat state and streaming - */ -export function useChat(options: UseChatOptions = {}): UseChatReturn { - const { onMessageAdded, onStreamComplete } = options; - - const [messages, setMessages] = useState([]); - const [isStreaming, setIsStreaming] = useState(false); - const [isExecutingTool, setIsExecutingTool] = useState(false); - const [currentToolName, setCurrentToolName] = useState(null); - const [systemContext, setSystemContext] = useState( - null, - ); - - // Tool confirmation state - const [pendingTools, setPendingTools] = useState([]); - const [showConfirmModal, setShowConfirmModal] = useState(false); - const pendingToolResolveRef = useRef<((confirmed: boolean) => void) | null>( - null, - ); - - const abortControllerRef = useRef(null); - - // Fetch system context on mount and periodically refresh - const refreshContext = useCallback(async () => { - const context = await fetchSystemContext(); - setSystemContext(context); - }, []); - - useEffect(() => { - refreshContext(); - const interval = setInterval(refreshContext, 30000); // Refresh every 30s - return () => clearInterval(interval); - }, [refreshContext]); - - /** - * Check if model uses proxy endpoint - */ - const isProxyRequest = useCallback((model: ChatModelConfig): boolean => { - return model.type === "custom"; - }, []); - - /** - * Get the chat endpoint URL based on model config - */ - const getEndpointUrl = useCallback((model: ChatModelConfig): string => { - switch (model.type) { - case "deployment": - return `/api/deployments/${model.deploymentId}/chat`; - case "semantic-router": - return `/api/semantic-router/chat`; - case "custom": - // Use backend proxy to avoid CORS issues - return `/api/chat-proxy`; - default: - return ""; - } - }, []); - - /** - * Get headers for the request - */ - const getHeaders = useCallback((): HeadersInit => { - const headers: HeadersInit = { - "Content-Type": "application/json", - }; - - const token = localStorage.getItem(STORAGE_KEYS.TOKEN); - if (token) { - headers["Authorization"] = `Bearer ${token}`; - } - - return headers; - }, []); - - /** - * Stream a chat completion request - */ - const streamChatCompletion = useCallback( - async ( - endpoint: string, - requestBody: any, - assistantMessageId: string, - signal: AbortSignal, - ): Promise<{ - content: string; - thinking: string; - model?: string; - toolCalls?: ToolCall[]; - finishReason?: string; - }> => { - const response = await fetch(endpoint, { - method: "POST", - headers: getHeaders(), - body: JSON.stringify(requestBody), - signal, - }); - - if (!response.ok) { - throw new Error(`API error: ${response.status} ${response.statusText}`); - } - - const reader = response.body?.getReader(); - if (!reader) throw new Error("No response body"); - - const decoder = new TextDecoder(); - let accumulatedContent = ""; - let accumulatedThinking = ""; - let responseModel: string | undefined; - let buffer = ""; - let finishReason: string | undefined; - - // Track tool calls being accumulated - const toolCallsMap: Map< - number, - { id: string; name: string; arguments: string } - > = new Map(); - - // eslint-disable-next-line no-constant-condition - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - buffer += decoder.decode(value, { stream: true }); - const lines = buffer.split("\n"); - buffer = lines.pop() || ""; - - for (const line of lines) { - const trimmedLine = line.trim(); - if (!trimmedLine || !trimmedLine.startsWith("data:")) continue; - - const data = trimmedLine.slice(5).trim(); - if (data === "[DONE]") continue; - - try { - const parsed = JSON.parse(data); - - if (parsed.error) { - throw new Error(parsed.error.message || "API error"); - } - - const choice = parsed.choices?.[0]; - const delta = choice?.delta; - - // Track finish reason - if (choice?.finish_reason) { - finishReason = choice.finish_reason; - } - - // Handle regular content - const deltaContent = delta?.content || ""; - accumulatedContent += deltaContent; - - // Handle thinking/reasoning content (for models like DeepSeek-R1) - const deltaThinking = delta?.reasoning_content || ""; - accumulatedThinking += deltaThinking; - - // Handle tool calls streaming - if (delta?.tool_calls) { - for (const tc of delta.tool_calls) { - const index = tc.index ?? 0; - if (!toolCallsMap.has(index)) { - toolCallsMap.set(index, { - id: tc.id || "", - name: "", - arguments: "", - }); - } - const existing = toolCallsMap.get(index)!; - if (tc.id) existing.id = tc.id; - if (tc.function?.name) existing.name = tc.function.name; - if (tc.function?.arguments) - existing.arguments += tc.function.arguments; - } - } - - if (!responseModel && parsed.model) { - responseModel = parsed.model; - } - - setMessages((prev) => - prev.map((m) => - m.id === assistantMessageId - ? { - ...m, - content: accumulatedContent, - thinking: accumulatedThinking || undefined, - model: responseModel, - } - : m, - ), - ); - } catch { - // Skip invalid JSON - } - } - } - - // Process remaining buffer - if (buffer.trim().startsWith("data:")) { - const data = buffer.trim().slice(5).trim(); - if (data !== "[DONE]") { - try { - const parsed = JSON.parse(data); - const choice = parsed.choices?.[0]; - const delta = choice?.delta; - if (choice?.finish_reason) finishReason = choice.finish_reason; - const deltaContent = delta?.content || ""; - const deltaThinking = delta?.reasoning_content || ""; - accumulatedContent += deltaContent; - accumulatedThinking += deltaThinking; - if (!responseModel && parsed.model) { - responseModel = parsed.model; - } - } catch { - // Skip invalid JSON - } - } - } - - // Convert tool calls map to array - const toolCalls: ToolCall[] = []; - for (const [, tc] of toolCallsMap) { - if (tc.id && tc.name) { - toolCalls.push({ - id: tc.id, - type: "function", - function: { name: tc.name, arguments: tc.arguments }, - }); - } - } - - return { - content: accumulatedContent, - thinking: accumulatedThinking, - model: responseModel, - toolCalls: toolCalls.length > 0 ? toolCalls : undefined, - finishReason, - }; - }, - [getHeaders], - ); - - /** - * Send a message and stream the response with tool calling support - */ - const sendMessage = useCallback( - async (content: string, model: ChatModelConfig) => { - if (!content.trim() || isStreaming) return; - - const endpoint = getEndpointUrl(model); - if (!endpoint) { - message.error("Invalid endpoint configuration"); - return; - } - - setIsStreaming(true); - - const userMessage: ChatMessage = { - id: generateMessageId(), - role: "user", - content: content.trim(), - timestamp: new Date(), - }; - - const assistantMessage: ChatMessage = { - id: generateMessageId(), - role: "assistant", - content: "", - timestamp: new Date(), - }; - - setMessages((prev) => [...prev, userMessage, assistantMessage]); - onMessageAdded?.(userMessage); - - try { - abortControllerRef.current = new AbortController(); - - const modelName = model.modelId || model.name; - - // Build messages array with system context - type ChatMessagePayload = { - role: string; - content: string | null; - tool_calls?: ToolCall[]; - tool_call_id?: string; - }; - const chatMessages: ChatMessagePayload[] = []; - - // Add system prompt with current context - if (systemContext) { - chatMessages.push({ - role: "system", - content: formatSystemPrompt(systemContext), - }); - } - - // Add conversation history - chatMessages.push( - ...messages.map((m) => ({ role: m.role, content: m.content })), - { role: "user", content: content.trim() }, - ); - - // Build the chat payload with tools - const chatPayload = { - model: modelName, - messages: chatMessages, - stream: true, - temperature: 0.7, - tools: CHAT_TOOLS, - tool_choice: "auto" as const, - }; - - // Build request body - const requestBody = isProxyRequest(model) - ? { - endpoint: model.endpoint, - api_key: model.apiKey || null, - payload: chatPayload, - } - : chatPayload; - - // Track the current message being streamed to - let currentMessageId = assistantMessage.id; - - // First streaming request - let result = await streamChatCompletion( - endpoint, - requestBody, - currentMessageId, - abortControllerRef.current.signal, - ); - - // Tool calling loop - continue until no more tool calls - let iterationCount = 0; - const maxIterations = 10; // Prevent infinite loops - - while ( - result.toolCalls && - result.toolCalls.length > 0 && - iterationCount < maxIterations - ) { - iterationCount++; - - // Check if any tool requires confirmation - const toolsNeedingConfirmation = result.toolCalls.filter((tc) => - requiresConfirmation(tc.function.name), - ); - - if (toolsNeedingConfirmation.length > 0) { - // Prepare pending tools for confirmation - const pending: PendingToolExecution[] = - toolsNeedingConfirmation.map((tc) => { - let parsedArgs: Record = {}; - try { - parsedArgs = JSON.parse(tc.function.arguments); - } catch { - parsedArgs = { raw: tc.function.arguments }; - } - return { - toolCall: tc, - parsedArgs, - meta: getToolMeta(tc.function.name), - }; - }); - - setPendingTools(pending); - setShowConfirmModal(true); - - // Update message to show waiting for confirmation - setMessages((prev) => - prev.map((m) => - m.id === currentMessageId - ? { - ...m, - content: result.content || "", - toolCalls: result.toolCalls, - } - : m, - ), - ); - - // Wait for user confirmation - const confirmed = await new Promise((resolve) => { - pendingToolResolveRef.current = resolve; - }); - - setPendingTools([]); - setShowConfirmModal(false); - pendingToolResolveRef.current = null; - - if (!confirmed) { - // User cancelled - stop tool execution and inform LLM - const cancelledResults: ToolResult[] = - toolsNeedingConfirmation.map((tc) => ({ - tool_call_id: tc.id, - role: "tool" as const, - content: JSON.stringify({ - success: false, - message: "User cancelled the operation", - }), - })); - - // Execute query tools that don't need confirmation - const queryTools = result.toolCalls.filter( - (tc) => !requiresConfirmation(tc.function.name), - ); - for (const tc of queryTools) { - const queryResult = await executeTool(tc, model); - cancelledResults.push(queryResult); - } - - // Update message - setMessages((prev) => - prev.map((m) => - m.id === currentMessageId - ? { - ...m, - content: result.content || "", - toolCalls: result.toolCalls, - } - : m, - ), - ); - - // Build new messages with cancelled results - const newMessages: ChatMessagePayload[] = [ - ...chatMessages, - { - role: "assistant", - content: result.content || null, - tool_calls: result.toolCalls, - }, - ...cancelledResults.map((tr) => ({ - role: "tool", - content: tr.content, - tool_call_id: tr.tool_call_id, - })), - ]; - - // Continue to let LLM know about cancellation - const continuationMessage: ChatMessage = { - id: generateMessageId(), - role: "assistant", - content: "", - timestamp: new Date(), - }; - - setMessages((prev) => - prev - .map((m) => - m.id === currentMessageId - ? { - ...m, - content: result.content || "", - toolCalls: result.toolCalls, - } - : m, - ) - .concat(continuationMessage), - ); - - // Update current message ID to the continuation - currentMessageId = continuationMessage.id; - - const continuationPayload = { - model: modelName, - messages: newMessages, - stream: true, - temperature: 0.7, - tools: CHAT_TOOLS, - tool_choice: "auto" as const, - }; - - const continuationRequestBody = isProxyRequest(model) - ? { - endpoint: model.endpoint, - api_key: model.apiKey || null, - payload: continuationPayload, - } - : continuationPayload; - - result = await streamChatCompletion( - endpoint, - continuationRequestBody, - currentMessageId, - abortControllerRef.current.signal, - ); - - chatMessages.length = 0; - chatMessages.push(...newMessages); - continue; - } - } - - // Execute all tool calls (confirmed or query-only) - setIsExecutingTool(true); - const toolResults: ToolResult[] = []; - - for (const toolCall of result.toolCalls) { - setCurrentToolName(toolCall.function.name); - - // Update message to show tool execution - setMessages((prev) => - prev.map((m) => - m.id === currentMessageId - ? { - ...m, - content: result.content || "", - toolCalls: result.toolCalls, - } - : m, - ), - ); - - const toolResult = await executeTool(toolCall, model); - toolResults.push(toolResult); - - // Refresh system context after tool execution (data may have changed) - await refreshContext(); - } - - setIsExecutingTool(false); - setCurrentToolName(null); - - // Build new messages array with tool calls and results - const newMessages: ChatMessagePayload[] = [ - ...chatMessages, - // Assistant message with tool calls - { - role: "assistant", - content: result.content || null, - tool_calls: result.toolCalls, - }, - // Tool results - ...toolResults.map((tr) => ({ - role: "tool", - content: tr.content, - tool_call_id: tr.tool_call_id, - })), - ]; - - // Create new assistant message for continued response - const continuationMessage: ChatMessage = { - id: generateMessageId(), - role: "assistant", - content: "", - timestamp: new Date(), - }; - - setMessages((prev) => { - // Update the current assistant message and add continuation - return prev - .map((m) => - m.id === currentMessageId - ? { - ...m, - content: result.content || "", - toolCalls: result.toolCalls, - } - : m, - ) - .concat(continuationMessage); - }); - - // Update current message ID to the continuation - currentMessageId = continuationMessage.id; - - // Build new chat payload with tool results - const continuationPayload = { - model: modelName, - messages: newMessages, - stream: true, - temperature: 0.7, - tools: CHAT_TOOLS, - tool_choice: "auto" as const, - }; - - const continuationRequestBody = isProxyRequest(model) - ? { - endpoint: model.endpoint, - api_key: model.apiKey || null, - payload: continuationPayload, - } - : continuationPayload; - - // Continue streaming - result = await streamChatCompletion( - endpoint, - continuationRequestBody, - currentMessageId, - abortControllerRef.current.signal, - ); - - // Update chat messages for next iteration if needed - chatMessages.length = 0; - chatMessages.push(...newMessages); - } - - // Final update - only update the current (last) message - setMessages((prev) => - prev.map((m) => { - if (m.id === currentMessageId) { - return { - ...m, - content: result.content, - thinking: result.thinking || undefined, - model: result.model, - // Only set toolCalls if there are any (don't overwrite with undefined) - ...(result.toolCalls ? { toolCalls: result.toolCalls } : {}), - }; - } - return m; - }), - ); - - // Get the final message for callback - const finalAssistantMsg: ChatMessage = { - id: currentMessageId, - role: "assistant", - content: result.content, - thinking: result.thinking || undefined, - model: result.model, - toolCalls: result.toolCalls, - timestamp: new Date(), - }; - - onStreamComplete?.(userMessage, finalAssistantMsg); - } catch (error: unknown) { - const err = error as Error; - if (err.name === "AbortError") { - message.info("Generation stopped"); - } else { - message.error(`Error: ${err.message}`); - // Remove the initial assistant message on error - setMessages((prev) => - prev.filter((m) => m.id !== assistantMessage.id), - ); - } - } finally { - setIsStreaming(false); - setIsExecutingTool(false); - setCurrentToolName(null); - abortControllerRef.current = null; - } - }, - [ - messages, - isStreaming, - systemContext, - getEndpointUrl, - isProxyRequest, - getHeaders, - onMessageAdded, - onStreamComplete, - streamChatCompletion, - refreshContext, - ], - ); - - /** - * Stop the current streaming response - */ - const stopStreaming = useCallback(() => { - abortControllerRef.current?.abort(); - }, []); - - /** - * Clear all messages - */ - const clearMessages = useCallback(() => { - setMessages([]); - }, []); - - /** - * Confirm pending tool execution - */ - const confirmToolExecution = useCallback(() => { - if (pendingToolResolveRef.current) { - pendingToolResolveRef.current(true); - } - }, []); - - /** - * Cancel pending tool execution - */ - const cancelToolExecution = useCallback(() => { - if (pendingToolResolveRef.current) { - pendingToolResolveRef.current(false); - } - }, []); - - return { - messages, - isStreaming, - isExecutingTool, - currentToolName, - pendingTools, - showConfirmModal, - systemContext, - refreshContext, - sendMessage, - stopStreaming, - clearMessages, - setMessages, - confirmToolExecution, - cancelToolExecution, - }; -} diff --git a/frontend/src/components/logos/index.tsx b/frontend/src/components/logos/index.tsx index 2fba673..ff49e6b 100644 --- a/frontend/src/components/logos/index.tsx +++ b/frontend/src/components/logos/index.tsx @@ -71,6 +71,62 @@ export function HuggingFaceLogo({ ); } +/** + * MLX Logo - Apple's ML framework for Apple Silicon + */ +export function MLXLogo({ height = 16, style }: Omit) { + // Use Apple-style gradient colors + const gradientId = `mlx-gradient-${Math.random().toString(36).substr(2, 9)}`; + return ( + + + + + + + + + + MLX + + + ); +} + +/** + * Llama.cpp Logo + */ +export function LlamaCppLogo({ + height = 16, + isDark = false, + style, +}: LogoProps) { + const textColor = isDark ? "#ffffff" : "#333333"; + return ( + + + llama.cpp + + + ); +} + // ============================================================================= // Icons // ============================================================================= @@ -129,5 +185,15 @@ export function getBackendConfig( color: tagColor, icon: , }, + mlx: { + label: "MLX", + color: tagColor, + icon: , + }, + llama_cpp: { + label: "llama.cpp", + color: tagColor, + icon: , + }, }; } diff --git a/frontend/src/pages/AutoTuning.tsx b/frontend/src/pages/AutoTuning.tsx index e1bdfd5..0bdedfb 100644 --- a/frontend/src/pages/AutoTuning.tsx +++ b/frontend/src/pages/AutoTuning.tsx @@ -920,7 +920,7 @@ export default function AutoTuning() { } + prefix={} /> @@ -931,9 +931,9 @@ export default function AutoTuning() { value={runningJobs} prefix={ runningJobs > 0 ? ( - + ) : ( - + ) } /> @@ -944,7 +944,7 @@ export default function AutoTuning() { } + prefix={} /> @@ -1214,7 +1214,18 @@ export default function AutoTuning() { - + {logModal && + [ + "pending", + "analyzing", + "querying_kb", + "exploring", + "benchmarking", + ].includes(logModal.status) ? ( + + ) : ( + + )} Live Logs {logModal && ( <> diff --git a/frontend/src/pages/Deployments.tsx b/frontend/src/pages/Deployments.tsx index bfe0b5f..922e529 100644 --- a/frontend/src/pages/Deployments.tsx +++ b/frontend/src/pages/Deployments.tsx @@ -1,6 +1,7 @@ import { useEffect, useState, useCallback, useRef } from "react"; import { useSearchParams } from "react-router-dom"; import { + Alert, Button, Card, Form, @@ -30,9 +31,7 @@ import { } from "@ant-design/icons"; import { useAppTheme } from "../hooks/useTheme"; import { - VllmLogo, OllamaLogo, - SGLangLogo, HuggingFaceLogo, DockerIcon, getBackendConfig, @@ -99,14 +98,41 @@ export default function Deployments() { // Get the selected model const selectedModel = models.find((m) => m.id === selectedModelId); - // Determine available backends based on model source - const availableBackends = - selectedModel?.source === "ollama" - ? (["ollama"] as const) - : (["vllm", "sglang"] as const); - // Get the selected worker's GPU info const selectedWorker = workers.find((w) => w.id === selectedWorkerId); + + // Determine available backends based on model source and worker capabilities + const availableBackends = (() => { + // Start with model-based restrictions + if (selectedModel?.source === "ollama") { + return ["ollama"] as const; + } + + // If no worker selected, show all HuggingFace-compatible backends + if (!selectedWorker) { + return ["vllm", "sglang"] as const; + } + + // macOS workers only support Ollama (vLLM/SGLang require NVIDIA GPU) + if (selectedWorker.os_type === "darwin") { + return ["ollama"] as const; + } + + // Use worker's available_backends if provided + if ( + selectedWorker.available_backends && + selectedWorker.available_backends.length > 0 + ) { + // Filter for HuggingFace-compatible backends from worker's list + const hfBackends = selectedWorker.available_backends.filter((b) => + ["vllm", "sglang", "ollama"].includes(b), + ); + return hfBackends.length > 0 ? hfBackends : (["vllm", "sglang"] as const); + } + + // Default fallback for Linux workers + return ["vllm", "sglang"] as const; + })(); const workerGpus = selectedWorker?.gpu_info || []; // Calculate total available GPU memory for selected GPUs @@ -351,11 +377,7 @@ export default function Deployments() { {record.model?.source !== "ollama" && renderSourceTag(record.model?.source, "small")} - {backend === "vllm" && } - {backend === "sglang" && } - {backend === "ollama" && ( - - )} + {config.icon} {config.label} {record.model?.name} @ {record.worker?.name} @@ -543,11 +565,7 @@ export default function Deployments() { {record.model?.source !== "ollama" && renderSourceTag(record.model?.source, "small")} - {backend === "vllm" && } - {backend === "sglang" && } - {backend === "ollama" && ( - - )} + {config.icon} {config.label} @@ -836,7 +854,9 @@ export default function Deployments() { extra={ selectedModel?.source === "ollama" ? "Ollama models can only use Ollama backend" - : "HuggingFace models can use vLLM or SGLang" + : selectedWorker?.os_type === "darwin" + ? "macOS workers only support Ollama backend" + : "HuggingFace models can use vLLM or SGLang" } >