From 4e8a402c2a89d84768990dfd919e682ec01ef259 Mon Sep 17 00:00:00 2001 From: rickychen-infinirc Date: Sun, 1 Feb 2026 20:29:44 +0800 Subject: [PATCH 1/9] feat: add macOS native worker support with Ollama --- backend/app/api/agent.py | 28 +- backend/app/api/auto_tuning.py | 239 +- backend/app/api/workers.py | 65 +- backend/app/models/llm_model.py | 5 + backend/app/models/worker.py | 61 +- backend/app/schemas/tuning.py | 113 + backend/app/schemas/worker.py | 19 + backend/app/services/bayesian_tuner.py | 167 +- backend/app/services/benchmark/__init__.py | 46 + backend/app/services/benchmark/config.py | 147 ++ backend/app/services/benchmark/metrics.py | 254 ++ backend/app/services/benchmark/runner.py | 484 ++++ backend/app/services/benchmark/saturation.py | 289 +++ backend/app/services/deployer.py | 241 +- backend/app/services/mcp/agent.py | 19 +- .../migrations/011_add_worker_mac_support.py | 85 + .../src/components/DeploymentAdvancedForm.tsx | 8 + .../components/chat-panel/AgentChatView.tsx | 316 +-- .../src/components/chat-panel/ChatPanel.tsx | 684 +---- .../chat-panel/ToolConfirmModal.tsx | 218 -- frontend/src/components/chat-panel/index.ts | 22 +- .../components/chat-panel/systemContext.ts | 515 ---- frontend/src/components/chat-panel/tools.ts | 2252 ----------------- frontend/src/components/chat-panel/useChat.ts | 762 ------ frontend/src/components/logos/index.tsx | 66 + frontend/src/pages/AutoTuning.tsx | 21 +- frontend/src/pages/Deployments.tsx | 138 +- frontend/src/pages/Workers.tsx | 184 +- frontend/src/types/worker.ts | 19 + mcp-server/src/client.ts | 42 + mcp-server/src/formatters.ts | 15 +- mcp-server/src/index.ts | 121 +- mcp-server/src/tools/benchmark.ts | 172 ++ worker/agent.py | 51 +- worker/docker_ops/system.py | 147 ++ worker/native_ops/__init__.py | 7 + worker/native_ops/mlx.py | 198 ++ worker/native_ops/ollama.py | 150 ++ worker/native_ops/process_manager.py | 348 +++ worker/requirements.txt | 19 +- worker/routes/__init__.py | 3 + worker/routes/native.py | 169 ++ worker/run_native.py | 45 + 43 files changed, 4168 insertions(+), 4786 deletions(-) create mode 100644 backend/app/services/benchmark/__init__.py create mode 100644 backend/app/services/benchmark/config.py create mode 100644 backend/app/services/benchmark/metrics.py create mode 100644 backend/app/services/benchmark/runner.py create mode 100644 backend/app/services/benchmark/saturation.py create mode 100644 backend/migrations/011_add_worker_mac_support.py delete mode 100644 frontend/src/components/chat-panel/ToolConfirmModal.tsx delete mode 100644 frontend/src/components/chat-panel/systemContext.ts delete mode 100644 frontend/src/components/chat-panel/tools.ts delete mode 100644 frontend/src/components/chat-panel/useChat.ts create mode 100644 worker/native_ops/__init__.py create mode 100644 worker/native_ops/mlx.py create mode 100644 worker/native_ops/ollama.py create mode 100644 worker/native_ops/process_manager.py create mode 100644 worker/routes/native.py create mode 100644 worker/run_native.py diff --git a/backend/app/api/agent.py b/backend/app/api/agent.py index 98bbe34..7f46ce0 100644 --- a/backend/app/api/agent.py +++ b/backend/app/api/agent.py @@ -169,12 +169,15 @@ async def get_or_create_agent( llm_api_key = llm_config.api_key llm_model = llm_config.model + # Track if tool calling is supported + supports_tool_calling = True + if llm_config.provider == "system" and llm_config.deployment_id: - # Look up the deployment + # Look up the deployment with model info result = await db.execute( select(Deployment) .where(Deployment.id == llm_config.deployment_id) - .options(selectinload(Deployment.worker)) + .options(selectinload(Deployment.worker), selectinload(Deployment.model)) ) deployment = result.scalar_one_or_none() @@ -184,9 +187,16 @@ async def get_or_create_agent( raise HTTPException(status_code=400, detail="Deployment is not running") worker = deployment.worker - llm_base_url = f"http://{worker.host}:{deployment.port}/v1" + # Extract IP from worker address (format: IP:Port) + worker_ip = worker.address.split(":")[0] + llm_base_url = f"http://{worker_ip}:{deployment.port}/v1" llm_api_key = "dummy" - llm_model = llm_config.model or "default" + # Use the actual model_id from the LLMModel (e.g., "Qwen/Qwen2.5-0.5B-Instruct") + llm_model = deployment.model.model_id + + # Check if deployment has tool calling enabled + extra_params = deployment.extra_params or {} + supports_tool_calling = extra_params.get("enable-auto-tool-choice", False) elif llm_config.provider == "openai": llm_base_url = "https://api.openai.com/v1" @@ -216,6 +226,7 @@ async def get_or_create_agent( llm_model=llm_model, mcp_api_url=mcp_api_url, mcp_api_token=api_token, + supports_tool_calling=supports_tool_calling, ) await agent.initialize() @@ -443,7 +454,7 @@ async def agent_chat_simple( result = await db.execute( select(Deployment) .where(Deployment.id == request.llm_config.deployment_id) - .options(selectinload(Deployment.worker)) + .options(selectinload(Deployment.worker), selectinload(Deployment.model)) ) deployment = result.scalar_one_or_none() @@ -453,9 +464,12 @@ async def agent_chat_simple( raise HTTPException(status_code=400, detail="Deployment is not running") worker = deployment.worker - llm_base_url = f"http://{worker.host}:{deployment.port}/v1" + # Extract IP from worker address (format: IP:Port) + worker_ip = worker.address.split(":")[0] + llm_base_url = f"http://{worker_ip}:{deployment.port}/v1" llm_api_key = "dummy" - llm_model = request.llm_config.model or "default" + # Use the actual model_id from the LLMModel (e.g., "Qwen/Qwen2.5-0.5B-Instruct") + llm_model = deployment.model.model_id elif request.llm_config.provider == "openai": llm_base_url = "https://api.openai.com/v1" diff --git a/backend/app/api/auto_tuning.py b/backend/app/api/auto_tuning.py index f6b2f99..fab5c4d 100644 --- a/backend/app/api/auto_tuning.py +++ b/backend/app/api/auto_tuning.py @@ -34,10 +34,17 @@ BenchmarkRequest, BenchmarkResultListResponse, BenchmarkResultResponse, + ComprehensiveBenchmarkMetrics, + ComprehensiveBenchmarkRequest, + ComprehensiveBenchmarkResponse, + ConcurrencyLevelResult, KnowledgeQuery, KnowledgeQueryResponse, KnowledgeRecord, KnowledgeSaveRequest, + LatencyPercentiles, + SaturationDetectionRequest, + SaturationDetectionResponse, TuningJobCreate, TuningJobListResponse, TuningJobProgress, @@ -397,6 +404,198 @@ async def list_benchmark_results( ) +# ============================================================================ +# Comprehensive Benchmark Endpoints +# ============================================================================ + + +@router.post("/benchmarks/comprehensive", response_model=ComprehensiveBenchmarkResponse) +async def run_comprehensive_benchmark( + request: ComprehensiveBenchmarkRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_operator), +): + """ + Run a comprehensive benchmark with detailed metrics. + + Returns metrics including: + - TTFT, ITL, TPOT with percentiles (p50, p90, p95, p99) + - Throughput (tokens/sec, requests/sec) + - Success rate and error statistics + """ + from app.services.benchmark import BenchmarkConfig, BenchmarkRunner, LoadPattern + + # Verify deployment exists and is running + result = await db.execute( + select(Deployment) + .where(Deployment.id == request.deployment_id) + .options( + selectinload(Deployment.model), + selectinload(Deployment.worker), + ) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + raise HTTPException(status_code=404, detail="Deployment not found") + + if deployment.status != DeploymentStatus.RUNNING.value: + raise HTTPException(status_code=400, detail="Deployment is not running") + + worker = deployment.worker + if not worker: + raise HTTPException(status_code=400, detail="Worker not found") + + endpoint = f"http://{worker.host}:{deployment.port}/v1" + model_name = deployment.model.model_id if deployment.model else "default" + + # Configure benchmark + config = BenchmarkConfig( + endpoint=endpoint, + model_name=model_name, + load_pattern=LoadPattern.FIXED, + concurrency=request.concurrency, + num_requests=request.num_requests, + warmup_requests=request.warmup_requests, + prompt_tokens=request.prompt_tokens, + output_tokens=request.output_tokens, + custom_prompt=request.custom_prompt, + stream=True, + ) + + # Run benchmark + runner = BenchmarkRunner(config) + bench_result = await runner.run() + + # Convert metrics to response format + m = bench_result.metrics + + def to_percentiles(lm) -> LatencyPercentiles: + return LatencyPercentiles( + mean=lm.mean, + median=lm.median, + min=lm.min, + max=lm.max, + std=lm.std, + p50=lm.p50, + p90=lm.p90, + p95=lm.p95, + p99=lm.p99, + ) + + metrics = ComprehensiveBenchmarkMetrics( + ttft=to_percentiles(m.ttft), + itl=to_percentiles(m.itl), + tpot=to_percentiles(m.tpot), + e2e_latency=to_percentiles(m.e2e_latency), + throughput_tps=m.throughput_tps, + throughput_rps=m.throughput_rps, + output_tps=m.output_tps, + total_requests=m.total_requests, + successful_requests=m.successful_requests, + failed_requests=m.failed_requests, + success_rate=m.success_rate, + total_prompt_tokens=m.total_prompt_tokens, + total_completion_tokens=m.total_completion_tokens, + avg_prompt_tokens=m.avg_prompt_tokens, + avg_completion_tokens=m.avg_completion_tokens, + total_duration_seconds=m.total_duration_seconds, + concurrency=m.concurrency, + ) + + return ComprehensiveBenchmarkResponse( + metrics=metrics, + config={ + "endpoint": endpoint, + "model_name": model_name, + "concurrency": request.concurrency, + "num_requests": request.num_requests, + }, + error=bench_result.error, + started_at=bench_result.started_at, + completed_at=bench_result.completed_at, + duration_seconds=bench_result.duration_seconds, + ) + + +@router.post("/benchmarks/saturation", response_model=SaturationDetectionResponse) +async def run_saturation_detection( + request: SaturationDetectionRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_operator), +): + """ + Run saturation detection to find optimal concurrency. + + Automatically increases concurrency and detects: + - Throughput plateau (where adding more concurrency doesn't help) + - Latency degradation (where latency increases significantly) + - Error rate increases + + Returns the optimal concurrency level for the deployment. + """ + from app.services.benchmark import SaturationConfig, SaturationDetector + + # Verify deployment exists and is running + result = await db.execute( + select(Deployment) + .where(Deployment.id == request.deployment_id) + .options( + selectinload(Deployment.model), + selectinload(Deployment.worker), + ) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + raise HTTPException(status_code=404, detail="Deployment not found") + + if deployment.status != DeploymentStatus.RUNNING.value: + raise HTTPException(status_code=400, detail="Deployment is not running") + + worker = deployment.worker + if not worker: + raise HTTPException(status_code=400, detail="Worker not found") + + endpoint = f"http://{worker.host}:{deployment.port}/v1" + model_name = deployment.model.model_id if deployment.model else "default" + + # Configure saturation detection + config = SaturationConfig( + enabled=True, + start_concurrency=request.start_concurrency, + max_concurrency=request.max_concurrency, + requests_per_level=request.requests_per_level, + use_exponential=request.use_exponential, + step_size=request.step_size, + step_multiplier=request.step_multiplier, + ) + + # Run saturation detection + detector = SaturationDetector(endpoint, model_name, config) + sat_result = await detector.detect() + + # Convert to response + return SaturationDetectionResponse( + optimal_concurrency=sat_result.optimal_concurrency, + max_throughput_tps=sat_result.max_throughput_tps, + latency_at_optimal_ms=sat_result.latency_at_optimal_ms, + saturation_concurrency=sat_result.saturation_concurrency, + saturation_detected=sat_result.saturation_detected, + stop_reason=sat_result.stop_reason, + concurrency_results=[ + ConcurrencyLevelResult( + concurrency=r.concurrency, + throughput_tps=r.throughput_tps, + avg_latency_ms=r.avg_latency_ms, + p95_latency_ms=r.p95_latency_ms, + success_rate=r.success_rate, + ) + for r in sat_result.results_by_concurrency + ], + ) + + # ============================================================================ # Knowledge Base Endpoints # ============================================================================ @@ -706,34 +905,44 @@ async def run_auto_tuning(job_id: int, llm_config: dict | None = None): async def _run_benchmark_test(deployment: Deployment, request: BenchmarkRequest) -> dict: - """Run actual benchmark test on a deployment using HTTP requests""" - from app.services.tuning_agent import _run_http_benchmark + """Run actual benchmark test on a deployment using the benchmark module""" + from app.services.benchmark import BenchmarkConfig, BenchmarkRunner, LoadPattern # Get worker info worker = deployment.worker if not worker: return {"error": "Worker not found"} - base_url = f"http://{worker.host}:{deployment.port}/v1" + endpoint = f"http://{worker.host}:{deployment.port}/v1" + model_name = deployment.model.model_id if deployment.model else "default" - result = await _run_http_benchmark( - base_url=base_url, - num_requests=max(10, request.concurrency * 5), + # Configure benchmark + config = BenchmarkConfig( + endpoint=endpoint, + model_name=model_name, + load_pattern=LoadPattern.FIXED, concurrency=request.concurrency, - input_tokens=request.input_length, + num_requests=max(20, request.concurrency * 5), + warmup_requests=5, + prompt_tokens=request.input_length, output_tokens=request.output_length, + stream=True, ) - if not result.get("success"): - return {"error": result.get("error", "Benchmark failed")} + # Run benchmark + runner = BenchmarkRunner(config) + result = await runner.run() + + if result.error: + return {"error": result.error} - metrics = result.get("metrics", {}) + metrics = result.metrics return { - "throughput_tps": metrics.get("throughput_tps"), - "ttft_ms": metrics.get("avg_ttft_ms"), - "tpot_ms": metrics.get("avg_tpot_ms"), - "total_latency_ms": None, # Not directly measured + "throughput_tps": metrics.output_tps, + "ttft_ms": metrics.ttft.mean, + "tpot_ms": metrics.tpot.mean, + "total_latency_ms": metrics.e2e_latency.mean, "gpu_utilization": None, # Would need GPU monitoring "vram_usage_gb": None, # Would need GPU monitoring - "raw": result.get("summary"), + "raw": metrics.to_dict(), } diff --git a/backend/app/api/workers.py b/backend/app/api/workers.py index 398fd20..42f7b47 100644 --- a/backend/app/api/workers.py +++ b/backend/app/api/workers.py @@ -71,6 +71,14 @@ async def list_workers( "status": worker.status, "gpu_info": worker.gpu_info, "system_info": worker.system_info, + "os_type": worker.os_type, + "gpu_type": worker.gpu_type, + "capabilities": worker.capabilities, + "available_backends": worker.available_backends, + "connection_type": worker.connection_type, + "tailscale_ip": worker.tailscale_ip, + "headscale_node_id": worker.headscale_node_id, + "effective_address": worker.effective_address, "created_at": worker.created_at, "updated_at": worker.updated_at, "last_heartbeat": worker.last_heartbeat, @@ -173,6 +181,14 @@ async def create_worker( status=original_worker.status, gpu_info=original_worker.gpu_info, system_info=original_worker.system_info, + os_type=original_worker.os_type, + gpu_type=original_worker.gpu_type, + capabilities=original_worker.capabilities, + available_backends=original_worker.available_backends, + connection_type=original_worker.connection_type, + tailscale_ip=original_worker.tailscale_ip, + headscale_node_id=original_worker.headscale_node_id, + effective_address=original_worker.effective_address, created_at=original_worker.created_at, updated_at=original_worker.updated_at, last_heartbeat=original_worker.last_heartbeat, @@ -206,6 +222,18 @@ async def create_worker( if token.is_local: worker_labels["type"] = "local" + # Extract os_type, gpu_type, capabilities from system_info + os_type = "linux" + gpu_type = "nvidia" + capabilities = None + if worker_in.system_info: + if worker_in.system_info.os_type: + os_type = worker_in.system_info.os_type + if worker_in.system_info.gpu_type: + gpu_type = worker_in.system_info.gpu_type + if worker_in.system_info.capabilities: + capabilities = worker_in.system_info.capabilities.model_dump() + worker = Worker( name=worker_in.name, address=real_address, @@ -213,6 +241,9 @@ async def create_worker( labels=worker_labels if worker_labels else None, gpu_info=([gpu.model_dump() for gpu in worker_in.gpu_info] if worker_in.gpu_info else None), system_info=(worker_in.system_info.model_dump() if worker_in.system_info else None), + os_type=os_type, + gpu_type=gpu_type, + capabilities=capabilities, status=WorkerStatus.ONLINE.value, last_heartbeat=datetime.now(UTC), ) @@ -236,6 +267,14 @@ async def create_worker( status=worker.status, gpu_info=worker.gpu_info, system_info=worker.system_info, + os_type=worker.os_type, + gpu_type=worker.gpu_type, + capabilities=worker.capabilities, + available_backends=worker.available_backends, + connection_type=worker.connection_type, + tailscale_ip=worker.tailscale_ip, + headscale_node_id=worker.headscale_node_id, + effective_address=worker.effective_address, created_at=worker.created_at, updated_at=worker.updated_at, last_heartbeat=worker.last_heartbeat, @@ -268,6 +307,14 @@ async def get_worker( status=worker.status, gpu_info=worker.gpu_info, system_info=worker.system_info, + os_type=worker.os_type, + gpu_type=worker.gpu_type, + capabilities=worker.capabilities, + available_backends=worker.available_backends, + connection_type=worker.connection_type, + tailscale_ip=worker.tailscale_ip, + headscale_node_id=worker.headscale_node_id, + effective_address=worker.effective_address, created_at=worker.created_at, updated_at=worker.updated_at, last_heartbeat=worker.last_heartbeat, @@ -318,6 +365,14 @@ async def update_worker( status=worker.status, gpu_info=worker.gpu_info, system_info=worker.system_info, + os_type=worker.os_type, + gpu_type=worker.gpu_type, + capabilities=worker.capabilities, + available_backends=worker.available_backends, + connection_type=worker.connection_type, + tailscale_ip=worker.tailscale_ip, + headscale_node_id=worker.headscale_node_id, + effective_address=worker.effective_address, created_at=worker.created_at, updated_at=worker.updated_at, last_heartbeat=worker.last_heartbeat, @@ -377,7 +432,15 @@ async def worker_heartbeat( worker.gpu_info = [gpu.model_dump() for gpu in heartbeat.gpu_info] if heartbeat.system_info: - worker.system_info = heartbeat.system_info.model_dump() + system_data = heartbeat.system_info.model_dump() + worker.system_info = system_data + # Extract os_type, gpu_type, capabilities from system_info + if heartbeat.system_info.os_type: + worker.os_type = heartbeat.system_info.os_type + if heartbeat.system_info.gpu_type: + worker.gpu_type = heartbeat.system_info.gpu_type + if heartbeat.system_info.capabilities: + worker.capabilities = heartbeat.system_info.capabilities.model_dump() # Check if worker is going offline is_going_offline = heartbeat.status == WorkerStatus.OFFLINE diff --git a/backend/app/models/llm_model.py b/backend/app/models/llm_model.py index ebca271..f34d6c0 100644 --- a/backend/app/models/llm_model.py +++ b/backend/app/models/llm_model.py @@ -12,10 +12,15 @@ class BackendType(str, Enum): """Inference backend type""" + # Docker-based backends (Linux with NVIDIA GPU) VLLM = "vllm" SGLANG = "sglang" OLLAMA = "ollama" + # Native backends (macOS with Apple Silicon) + MLX = "mlx" # MLX-LM for Apple Silicon + LLAMA_CPP = "llama_cpp" # llama.cpp with Metal + class ModelSource(str, Enum): """Model source type""" diff --git a/backend/app/models/worker.py b/backend/app/models/worker.py index 83fe603..3be352d 100644 --- a/backend/app/models/worker.py +++ b/backend/app/models/worker.py @@ -24,6 +24,23 @@ class ConnectionType(str, Enum): TAILSCALE = "tailscale" # Via Tailscale/Headscale VPN +class OSType(str, Enum): + """Worker operating system type""" + + LINUX = "linux" + DARWIN = "darwin" # macOS + WINDOWS = "windows" + + +class GPUType(str, Enum): + """GPU type for inference""" + + NVIDIA = "nvidia" + APPLE_SILICON = "apple_silicon" # M1/M2/M3/M4 + AMD = "amd" + NONE = "none" + + class Worker(Base): """Worker node model - represents a GPU node that can run LLM inference""" @@ -43,6 +60,13 @@ class Worker(Base): Integer, nullable=True ) # Node ID in Headscale + # System detection + os_type: Mapped[str] = mapped_column(String(50), default=OSType.LINUX.value) + gpu_type: Mapped[str] = mapped_column(String(50), default=GPUType.NVIDIA.value) + capabilities: Mapped[dict | None] = mapped_column( + JSON, nullable=True + ) # Available backends/tools + gpu_info: Mapped[dict | None] = mapped_column(JSON, nullable=True) system_info: Mapped[dict | None] = mapped_column(JSON, nullable=True) labels: Mapped[dict | None] = mapped_column(JSON, nullable=True) @@ -72,5 +96,40 @@ def effective_address(self) -> str: return f"{self.tailscale_ip}:{port}" return self.address + @property + def supports_docker(self) -> bool: + """Check if this worker supports Docker deployments.""" + # macOS without Docker support uses native process management + caps = self.capabilities or {} + return caps.get("docker", self.os_type == OSType.LINUX.value) + + @property + def is_mac(self) -> bool: + """Check if this is a macOS worker.""" + return self.os_type == OSType.DARWIN.value + + @property + def available_backends(self) -> list[str]: + """Get list of available backends for this worker.""" + caps = self.capabilities or {} + backends = [] + + if self.supports_docker: + # Docker-based backends + if self.gpu_type == GPUType.NVIDIA.value: + backends.extend(["vllm", "sglang", "ollama"]) + else: + backends.append("ollama") + else: + # Native backends (Mac) + if caps.get("ollama"): + backends.append("ollama") + if caps.get("mlx"): + backends.append("mlx") + if caps.get("llama_cpp"): + backends.append("llama_cpp") + + return backends + def __repr__(self) -> str: - return f"" + return f"" diff --git a/backend/app/schemas/tuning.py b/backend/app/schemas/tuning.py index 616f872..8a21a60 100644 --- a/backend/app/schemas/tuning.py +++ b/backend/app/schemas/tuning.py @@ -270,3 +270,116 @@ class KnowledgeSaveRequest(BaseModel): benchmark_result_id: int = Field(..., description="ID of the benchmark result to save") model_family: str = Field(..., description="Model family for categorization") model_params_b: float | None = Field(default=None, description="Model parameters in billions") + + +# ============================================================================ +# Comprehensive Benchmark Schemas +# ============================================================================ + + +class LatencyPercentiles(BaseModel): + """Latency metrics with percentiles""" + + mean: float = 0.0 + median: float = 0.0 + min: float = 0.0 + max: float = 0.0 + std: float = 0.0 + p50: float = 0.0 + p90: float = 0.0 + p95: float = 0.0 + p99: float = 0.0 + + +class ComprehensiveBenchmarkMetrics(BaseModel): + """Comprehensive benchmark metrics with percentiles""" + + # Latency metrics with percentiles + ttft: LatencyPercentiles = Field(default_factory=LatencyPercentiles) + itl: LatencyPercentiles = Field(default_factory=LatencyPercentiles) + tpot: LatencyPercentiles = Field(default_factory=LatencyPercentiles) + e2e_latency: LatencyPercentiles = Field(default_factory=LatencyPercentiles) + + # Throughput metrics + throughput_tps: float = Field(0.0, description="Total tokens per second") + throughput_rps: float = Field(0.0, description="Requests per second") + output_tps: float = Field(0.0, description="Output tokens per second") + + # Request statistics + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + success_rate: float = 0.0 + + # Token statistics + total_prompt_tokens: int = 0 + total_completion_tokens: int = 0 + avg_prompt_tokens: float = 0.0 + avg_completion_tokens: float = 0.0 + + # Timing + total_duration_seconds: float = 0.0 + concurrency: int = 0 + + +class ComprehensiveBenchmarkRequest(BaseModel): + """Request for running a comprehensive benchmark""" + + deployment_id: int = Field(..., description="ID of the deployment to benchmark") + concurrency: int = Field(default=10, ge=1, le=128, description="Concurrent requests") + num_requests: int = Field(default=50, ge=10, le=1000, description="Total requests") + warmup_requests: int = Field(default=5, ge=0, le=20, description="Warmup requests") + prompt_tokens: int = Field(default=256, ge=32, le=8192, description="Approximate input tokens") + output_tokens: int = Field(default=128, ge=16, le=2048, description="Max output tokens") + custom_prompt: str | None = Field(default=None, description="Custom prompt to use") + + +class ComprehensiveBenchmarkResponse(BaseModel): + """Response for comprehensive benchmark""" + + metrics: ComprehensiveBenchmarkMetrics + config: dict + error: str | None = None + started_at: float + completed_at: float + duration_seconds: float + + +class SaturationDetectionRequest(BaseModel): + """Request for saturation detection""" + + deployment_id: int = Field(..., description="ID of the deployment to test") + start_concurrency: int = Field(default=1, ge=1, description="Starting concurrency") + max_concurrency: int = Field( + default=64, ge=1, le=256, description="Maximum concurrency to test" + ) + requests_per_level: int = Field(default=20, ge=10, le=100, description="Requests per level") + use_exponential: bool = Field(default=True, description="Use exponential stepping") + step_size: int = Field(default=2, ge=1, description="Linear step size") + step_multiplier: float = Field( + default=1.5, ge=1.1, le=3.0, description="Exponential multiplier" + ) + + +class ConcurrencyLevelResult(BaseModel): + """Result for a single concurrency level""" + + concurrency: int + throughput_tps: float + avg_latency_ms: float + p95_latency_ms: float + success_rate: float + + +class SaturationDetectionResponse(BaseModel): + """Response for saturation detection""" + + optimal_concurrency: int = Field(description="Recommended concurrency level") + max_throughput_tps: float = Field(description="Maximum throughput achieved") + latency_at_optimal_ms: float = Field(description="Latency at optimal concurrency") + saturation_concurrency: int = Field(description="Concurrency where saturation started") + saturation_detected: bool = Field(description="Whether saturation was detected") + stop_reason: str = Field(description="Reason for stopping") + concurrency_results: list[ConcurrencyLevelResult] = Field( + default_factory=list, description="Results for each concurrency level" + ) diff --git a/backend/app/schemas/worker.py b/backend/app/schemas/worker.py index e89b026..4c405e4 100644 --- a/backend/app/schemas/worker.py +++ b/backend/app/schemas/worker.py @@ -45,12 +45,27 @@ class DiskInfo(BaseModel): percent: float = 0 # percentage +class CapabilitiesInfo(BaseModel): + """Worker capabilities schema (available backends)""" + + os_type: str = "linux" # linux, darwin, windows + gpu_type: str = "nvidia" # nvidia, apple_silicon, amd, none + docker: bool = True # Docker available + ollama: bool = False # Ollama installed + ollama_running: bool = False # Ollama service running + mlx: bool = False # MLX-LM available (Mac only) + llama_cpp: bool = False # llama.cpp available + + class SystemInfo(BaseModel): """System information schema (CPU, Memory, Disk)""" cpu: CPUInfo | None = None memory: MemoryInfo | None = None disk: DiskInfo | None = None + os_type: str | None = None # linux, darwin, windows + gpu_type: str | None = None # nvidia, apple_silicon, amd, none + capabilities: CapabilitiesInfo | None = None # Available backends class WorkerBase(BaseModel): @@ -97,6 +112,10 @@ class WorkerResponse(WorkerBase): status: str gpu_info: list[dict] | None = None system_info: dict | None = None + os_type: str = "linux" # linux, darwin, windows + gpu_type: str = "nvidia" # nvidia, apple_silicon, amd, none + capabilities: dict | None = None # Available backends + available_backends: list[str] | None = None # List of available backend names tailscale_ip: str | None = None headscale_node_id: int | None = None effective_address: str | None = None # The actual address to use for connections diff --git a/backend/app/services/bayesian_tuner.py b/backend/app/services/bayesian_tuner.py index 1abe1d7..1ba34d7 100644 --- a/backend/app/services/bayesian_tuner.py +++ b/backend/app/services/bayesian_tuner.py @@ -13,7 +13,6 @@ """ import asyncio -import json import logging import time from dataclasses import dataclass, field @@ -21,7 +20,6 @@ from enum import Enum from typing import Any -import httpx import optuna from optuna.samplers import TPESampler from sqlalchemy import select @@ -32,6 +30,7 @@ from app.models.deployment import Deployment, DeploymentStatus from app.models.tuning import OptimizationTarget, PerformanceKnowledge, TuningJob, TuningJobStatus from app.models.worker import Worker +from app.services.benchmark import BenchmarkConfig, BenchmarkRunner, LoadPattern from app.services.deployer import DeployerService # Configure logging with detailed format @@ -251,6 +250,7 @@ class ObjectiveCalculator: Computes optimization objective from benchmark metrics. Implements Scorer phase: score valid configurations by performance. + Supports comprehensive metrics including percentiles. """ def __init__(self, target: OptimizationTarget): @@ -262,27 +262,42 @@ def compute(self, metrics: dict[str, float]) -> float: For Optuna, we negate values when minimizing so all objectives are treated as maximization internally. + + Uses p95 latencies for SLA-focused optimization. """ throughput = metrics.get("throughput_tps", 0.0) - ttft = metrics.get("avg_ttft_ms", float("inf")) - tpot = metrics.get("avg_tpot_ms", float("inf")) + + # Use mean for basic calculation, p95 for SLA + ttft_mean = metrics.get("avg_ttft_ms", float("inf")) + tpot_mean = metrics.get("avg_tpot_ms", float("inf")) + ttft_p95 = metrics.get("ttft_p95_ms", ttft_mean) + tpot_p95 = metrics.get("tpot_p95_ms", tpot_mean) + itl_p95 = metrics.get("itl_p95_ms", 0.0) if self.target == OptimizationTarget.THROUGHPUT: # Maximize throughput return throughput elif self.target == OptimizationTarget.LATENCY: - # Minimize latency (negate for maximization) - if ttft == float("inf"): + # Minimize latency using p95 values for SLA compliance + if ttft_p95 == float("inf") or ttft_p95 == 0: return float("-inf") - return -1.0 * (ttft + tpot * 10) # Weight TPOT more + # Combined latency score: TTFT + (ITL * typical_output_tokens) + # Assume ~50 output tokens for a typical response + estimated_generation_latency = itl_p95 * 50 if itl_p95 > 0 else tpot_p95 * 50 + total_latency = ttft_p95 + estimated_generation_latency + return -1.0 * total_latency elif self.target == OptimizationTarget.BALANCED: - # Combined score: throughput / latency - if ttft == 0 or throughput == 0: + # Combined score: throughput / latency with p95 consideration + if ttft_mean == 0 or throughput == 0: return float("-inf") - latency_factor = 1 + (ttft / 100) + (tpot / 10) - return throughput / latency_factor + + # Penalize high p95 latencies more than mean + latency_factor = 1 + (ttft_mean / 100) + (tpot_mean / 10) + p95_penalty = 1 + (ttft_p95 - ttft_mean) / 200 + (tpot_p95 - tpot_mean) / 20 + + return throughput / (latency_factor * p95_penalty) elif self.target == OptimizationTarget.COST: # Maximize efficiency (throughput per resource unit) @@ -763,10 +778,14 @@ async def _wait_for_deployment(self, deployment_id: int, timeout: int = 600) -> async def _run_benchmark( self, deployment_id: int, - num_requests: int = 20, - concurrency: int = 4, + num_requests: int = 30, + concurrency: int = 8, ) -> dict[str, float]: - """Run benchmark against deployment""" + """ + Run benchmark against deployment using the benchmark module. + + Returns comprehensive metrics including percentiles. + """ result = await self.db.execute( select(Deployment) .where(Deployment.id == deployment_id) @@ -779,106 +798,32 @@ async def _run_benchmark( # Build endpoint URL worker_ip = deployment.worker.address.split(":")[0] - base_url = f"http://{worker_ip}:{deployment.port}/v1" + endpoint = f"http://{worker_ip}:{deployment.port}/v1" model_name = deployment.model.model_id - # Run HTTP benchmark (reuse existing implementation pattern) - return await self._http_benchmark(base_url, model_name, num_requests, concurrency) + # Configure benchmark + config = BenchmarkConfig( + endpoint=endpoint, + model_name=model_name, + load_pattern=LoadPattern.FIXED, + concurrency=concurrency, + num_requests=num_requests, + warmup_requests=5, + output_tokens=64, + prompt_tokens=200, + stream=True, + verbose=False, + ) - async def _http_benchmark( - self, - base_url: str, - model_name: str, - num_requests: int, - concurrency: int, - ) -> dict[str, float]: - """Execute HTTP benchmark against OpenAI-compatible endpoint""" - test_prompt = "Explain the concept of machine learning in simple terms. " * 20 - - results = [] - semaphore = asyncio.Semaphore(concurrency) - - async def make_request(client: httpx.AsyncClient) -> dict | None: - async with semaphore: - start = time.perf_counter() - first_token_time = None - token_count = 0 - - try: - async with client.stream( - "POST", - f"{base_url}/chat/completions", - json={ - "model": model_name, - "messages": [{"role": "user", "content": test_prompt}], - "max_tokens": 64, - "stream": True, - }, - timeout=60.0, - ) as resp: - if resp.status_code != 200: - return None - - async for line in resp.aiter_lines(): - if line.startswith("data: ") and line != "data: [DONE]": - try: - chunk = json.loads(line[6:]) - content = ( - chunk.get("choices", [{}])[0] - .get("delta", {}) - .get("content", "") - ) - if content: - if first_token_time is None: - first_token_time = time.perf_counter() - token_count += 1 - except json.JSONDecodeError: - pass - - end = time.perf_counter() - - if first_token_time and token_count > 0: - return { - "ttft_ms": (first_token_time - start) * 1000, - "tpot_ms": ( - ((end - first_token_time) / max(1, token_count - 1)) * 1000 - if token_count > 1 - else 0 - ), - "tokens": token_count, - "total_time": end - start, - } - except Exception: - pass - return None - - async with httpx.AsyncClient() as client: - # Warmup - for _ in range(2): - await make_request(client) - - # Actual benchmark - tasks = [make_request(client) for _ in range(num_requests)] - results = await asyncio.gather(*tasks) - - valid = [r for r in results if r] - if not valid: - return {"throughput_tps": 0, "avg_ttft_ms": 0, "avg_tpot_ms": 0} - - total_tokens = sum(r["tokens"] for r in valid) - total_time = sum(r["total_time"] for r in valid) - - return { - "throughput_tps": round(total_tokens / total_time, 2) if total_time > 0 else 0, - "avg_ttft_ms": round(sum(r["ttft_ms"] for r in valid) / len(valid), 2), - "avg_tpot_ms": round( - sum(r["tpot_ms"] for r in valid if r["tpot_ms"] > 0) - / max(1, len([r for r in valid if r["tpot_ms"] > 0])), - 2, - ), - "successful_requests": len(valid), - "total_requests": num_requests, - } + # Run benchmark + runner = BenchmarkRunner(config) + bench_result = await runner.run() + + if bench_result.error: + await self._log("WARNING", f"Benchmark warning: {bench_result.error}") + + # Return comprehensive metrics in simple dict format + return bench_result.metrics.to_simple_dict() async def _cleanup_deployment(self): """Stop and remove current deployment""" diff --git a/backend/app/services/benchmark/__init__.py b/backend/app/services/benchmark/__init__.py new file mode 100644 index 0000000..3bfb8cb --- /dev/null +++ b/backend/app/services/benchmark/__init__.py @@ -0,0 +1,46 @@ +""" +Benchmark Module + +Independent benchmarking system for LLM inference performance evaluation. +Supports comprehensive metrics, multiple load patterns, and saturation detection. + +Usage: + from app.services.benchmark import BenchmarkRunner, BenchmarkConfig, LoadPattern + + config = BenchmarkConfig( + endpoint="http://localhost:8000/v1", + model_name="Qwen/Qwen2.5-7B-Instruct", + load_pattern=LoadPattern.FIXED, + concurrency=10, + duration_seconds=60, + ) + + runner = BenchmarkRunner(config) + result = await runner.run() + + print(f"Throughput: {result.metrics.throughput_tps:.2f} TPS") + print(f"TTFT p95: {result.metrics.ttft.p95:.2f} ms") +""" + +from .config import BenchmarkConfig, LoadPattern, SaturationConfig +from .metrics import BenchmarkMetrics, LatencyMetrics, RequestResult, compute_percentiles +from .runner import BenchmarkResult, BenchmarkRunner +from .saturation import SaturationDetector, SaturationResult + +__all__ = [ + # Config + "BenchmarkConfig", + "LoadPattern", + "SaturationConfig", + # Metrics + "LatencyMetrics", + "BenchmarkMetrics", + "RequestResult", + "compute_percentiles", + # Runner + "BenchmarkRunner", + "BenchmarkResult", + # Saturation + "SaturationDetector", + "SaturationResult", +] diff --git a/backend/app/services/benchmark/config.py b/backend/app/services/benchmark/config.py new file mode 100644 index 0000000..e2b3f5b --- /dev/null +++ b/backend/app/services/benchmark/config.py @@ -0,0 +1,147 @@ +""" +Benchmark Configuration + +Defines configuration options for benchmark execution. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class LoadPattern(Enum): + """Load pattern types for benchmark execution""" + + # Fixed concurrency throughout the test + FIXED = "fixed" + + # Gradually increase concurrency to find saturation point + INCREMENTAL = "incremental" + + # Burst load pattern with spikes + BURST = "burst" + + # Step-wise increase for finding optimal concurrency + STEP = "step" + + +@dataclass +class SaturationConfig: + """Configuration for saturation detection""" + + # Enable saturation detection (auto-find optimal concurrency) + enabled: bool = False + + # Starting concurrency level + start_concurrency: int = 1 + + # Maximum concurrency to test + max_concurrency: int = 64 + + # Concurrency step size + step_size: int = 2 + + # Multiplier for exponential stepping (alternative to linear) + step_multiplier: float = 1.5 + + # Use exponential stepping instead of linear + use_exponential: bool = True + + # Minimum requests per concurrency level + requests_per_level: int = 20 + + # Throughput degradation threshold (e.g., 0.95 = 5% drop triggers saturation) + degradation_threshold: float = 0.95 + + # Latency increase threshold (e.g., 1.5 = 50% increase triggers saturation) + latency_threshold: float = 1.5 + + # Number of consecutive degradations before declaring saturation + consecutive_degradations: int = 2 + + +@dataclass +class BenchmarkConfig: + """Configuration for benchmark execution""" + + # Target endpoint (OpenAI-compatible API) + endpoint: str + + # Model name/ID for API requests + model_name: str + + # Load pattern to use + load_pattern: LoadPattern = LoadPattern.FIXED + + # Concurrency level (for FIXED pattern) + concurrency: int = 10 + + # Duration in seconds (0 = use num_requests instead) + duration_seconds: int = 0 + + # Number of requests (used when duration_seconds = 0) + num_requests: int = 50 + + # Warmup requests before actual benchmark + warmup_requests: int = 5 + + # Request timeout in seconds + request_timeout: float = 120.0 + + # Prompt configuration + prompt_tokens: int = 256 # Approximate input tokens + output_tokens: int = 128 # Max output tokens + + # Custom prompt (overrides prompt_tokens if set) + custom_prompt: str | None = None + + # Saturation detection config + saturation: SaturationConfig = field(default_factory=SaturationConfig) + + # Extra parameters for requests + extra_params: dict[str, Any] = field(default_factory=dict) + + # Stream responses (required for accurate ITL measurement) + stream: bool = True + + # Verbose logging + verbose: bool = False + + def get_prompt(self) -> str: + """Get the prompt to use for benchmarking""" + if self.custom_prompt: + return self.custom_prompt + + # Generate synthetic prompt of approximately prompt_tokens length + # Average English word is ~4 characters, ~1.3 tokens + words_needed = int(self.prompt_tokens / 1.3) + base_prompt = "Explain the following concept in detail: " + filler = "artificial intelligence machine learning neural network deep learning natural language processing computer vision reinforcement learning transformer architecture attention mechanism gradient descent backpropagation optimization algorithm data science statistical analysis pattern recognition feature extraction dimensionality reduction clustering classification regression prediction inference training validation testing deployment scalability performance efficiency accuracy precision recall " + + # Repeat filler to reach desired length + repeated = (filler * ((words_needed // len(filler.split())) + 1)).split() + prompt = base_prompt + " ".join(repeated[:words_needed]) + + return prompt + + def validate(self) -> tuple[bool, str]: + """Validate configuration""" + if not self.endpoint: + return False, "Endpoint is required" + + if not self.model_name: + return False, "Model name is required" + + if self.concurrency < 1: + return False, "Concurrency must be at least 1" + + if self.duration_seconds < 0: + return False, "Duration cannot be negative" + + if self.num_requests < 1 and self.duration_seconds == 0: + return False, "Either duration_seconds or num_requests must be positive" + + if self.request_timeout < 1: + return False, "Request timeout must be at least 1 second" + + return True, "" diff --git a/backend/app/services/benchmark/metrics.py b/backend/app/services/benchmark/metrics.py new file mode 100644 index 0000000..0792213 --- /dev/null +++ b/backend/app/services/benchmark/metrics.py @@ -0,0 +1,254 @@ +""" +Benchmark Metrics + +Comprehensive metrics calculation including: +- TTFT (Time to First Token) +- ITL (Inter-Token Latency) +- TPOT (Time Per Output Token) +- Percentiles (p50, p90, p95, p99) +- Throughput (tokens/sec, requests/sec) +""" + +import statistics +from collections.abc import Sequence +from dataclasses import dataclass, field + + +@dataclass +class RequestResult: + """Result from a single request""" + + # Timing metrics (in milliseconds) + ttft_ms: float = 0.0 # Time to first token + total_latency_ms: float = 0.0 # Total request duration + token_timestamps_ms: list[float] = field(default_factory=list) # Timestamp of each token + + # Token counts + prompt_tokens: int = 0 + completion_tokens: int = 0 + + # Request metadata + success: bool = True + error_message: str | None = None + start_time: float = 0.0 + end_time: float = 0.0 + + @property + def itl_values_ms(self) -> list[float]: + """Calculate inter-token latencies from timestamps""" + if len(self.token_timestamps_ms) < 2: + return [] + + itl_values = [] + for i in range(1, len(self.token_timestamps_ms)): + itl = self.token_timestamps_ms[i] - self.token_timestamps_ms[i - 1] + itl_values.append(itl) + return itl_values + + @property + def tpot_ms(self) -> float: + """Time per output token (excluding TTFT)""" + if self.completion_tokens <= 1: + return 0.0 + + generation_time = self.total_latency_ms - self.ttft_ms + return generation_time / (self.completion_tokens - 1) + + +@dataclass +class LatencyMetrics: + """Latency metrics with percentiles""" + + mean: float = 0.0 + median: float = 0.0 + min: float = 0.0 + max: float = 0.0 + std: float = 0.0 + p50: float = 0.0 + p90: float = 0.0 + p95: float = 0.0 + p99: float = 0.0 + + def to_dict(self) -> dict[str, float]: + """Convert to dictionary""" + return { + "mean": round(self.mean, 2), + "median": round(self.median, 2), + "min": round(self.min, 2), + "max": round(self.max, 2), + "std": round(self.std, 2), + "p50": round(self.p50, 2), + "p90": round(self.p90, 2), + "p95": round(self.p95, 2), + "p99": round(self.p99, 2), + } + + +def compute_percentiles(values: Sequence[float]) -> LatencyMetrics: + """Compute latency metrics with percentiles from a sequence of values""" + if not values: + return LatencyMetrics() + + sorted_values = sorted(values) + n = len(sorted_values) + + def percentile(p: float) -> float: + """Calculate percentile value""" + if n == 1: + return sorted_values[0] + k = (n - 1) * (p / 100) + f = int(k) + c = f + 1 if f + 1 < n else f + return sorted_values[f] + (k - f) * (sorted_values[c] - sorted_values[f]) + + return LatencyMetrics( + mean=statistics.mean(values), + median=statistics.median(values), + min=min(values), + max=max(values), + std=statistics.stdev(values) if n > 1 else 0.0, + p50=percentile(50), + p90=percentile(90), + p95=percentile(95), + p99=percentile(99), + ) + + +@dataclass +class BenchmarkMetrics: + """Comprehensive benchmark metrics""" + + # Latency metrics + ttft: LatencyMetrics = field(default_factory=LatencyMetrics) + itl: LatencyMetrics = field(default_factory=LatencyMetrics) + tpot: LatencyMetrics = field(default_factory=LatencyMetrics) + e2e_latency: LatencyMetrics = field(default_factory=LatencyMetrics) + + # Throughput metrics + throughput_tps: float = 0.0 # Tokens per second + throughput_rps: float = 0.0 # Requests per second + output_tps: float = 0.0 # Output tokens per second + + # Request statistics + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + success_rate: float = 0.0 + + # Token statistics + total_prompt_tokens: int = 0 + total_completion_tokens: int = 0 + avg_prompt_tokens: float = 0.0 + avg_completion_tokens: float = 0.0 + + # Timing + total_duration_seconds: float = 0.0 + concurrency: int = 0 + + def to_dict(self) -> dict: + """Convert to dictionary for serialization""" + return { + "ttft": self.ttft.to_dict(), + "itl": self.itl.to_dict(), + "tpot": self.tpot.to_dict(), + "e2e_latency": self.e2e_latency.to_dict(), + "throughput_tps": round(self.throughput_tps, 2), + "throughput_rps": round(self.throughput_rps, 4), + "output_tps": round(self.output_tps, 2), + "total_requests": self.total_requests, + "successful_requests": self.successful_requests, + "failed_requests": self.failed_requests, + "success_rate": round(self.success_rate, 4), + "total_prompt_tokens": self.total_prompt_tokens, + "total_completion_tokens": self.total_completion_tokens, + "avg_prompt_tokens": round(self.avg_prompt_tokens, 1), + "avg_completion_tokens": round(self.avg_completion_tokens, 1), + "total_duration_seconds": round(self.total_duration_seconds, 2), + "concurrency": self.concurrency, + } + + @classmethod + def from_results( + cls, + results: list[RequestResult], + total_duration: float, + concurrency: int, + ) -> "BenchmarkMetrics": + """Compute metrics from a list of request results""" + if not results: + return cls() + + successful = [r for r in results if r.success] + failed = [r for r in results if not r.success] + + if not successful: + return cls( + total_requests=len(results), + failed_requests=len(failed), + total_duration_seconds=total_duration, + concurrency=concurrency, + ) + + # Collect latency values + ttft_values = [r.ttft_ms for r in successful if r.ttft_ms > 0] + e2e_values = [r.total_latency_ms for r in successful if r.total_latency_ms > 0] + tpot_values = [r.tpot_ms for r in successful if r.tpot_ms > 0] + + # Collect all ITL values + all_itl_values: list[float] = [] + for r in successful: + all_itl_values.extend(r.itl_values_ms) + + # Token counts + total_prompt = sum(r.prompt_tokens for r in successful) + total_completion = sum(r.completion_tokens for r in successful) + total_tokens = total_prompt + total_completion + + # Throughput + throughput_tps = total_tokens / total_duration if total_duration > 0 else 0 + throughput_rps = len(successful) / total_duration if total_duration > 0 else 0 + output_tps = total_completion / total_duration if total_duration > 0 else 0 + + return cls( + ttft=compute_percentiles(ttft_values), + itl=compute_percentiles(all_itl_values) if all_itl_values else LatencyMetrics(), + tpot=compute_percentiles(tpot_values), + e2e_latency=compute_percentiles(e2e_values), + throughput_tps=throughput_tps, + throughput_rps=throughput_rps, + output_tps=output_tps, + total_requests=len(results), + successful_requests=len(successful), + failed_requests=len(failed), + success_rate=len(successful) / len(results) if results else 0, + total_prompt_tokens=total_prompt, + total_completion_tokens=total_completion, + avg_prompt_tokens=total_prompt / len(successful) if successful else 0, + avg_completion_tokens=total_completion / len(successful) if successful else 0, + total_duration_seconds=total_duration, + concurrency=concurrency, + ) + + def to_simple_dict(self) -> dict[str, float]: + """Convert to simple dictionary for backward compatibility with BayesianTuner""" + return { + "throughput_tps": self.output_tps, + "avg_ttft_ms": self.ttft.mean, + "avg_tpot_ms": self.tpot.mean, + "avg_itl_ms": self.itl.mean, + "ttft_p50_ms": self.ttft.p50, + "ttft_p95_ms": self.ttft.p95, + "ttft_p99_ms": self.ttft.p99, + "tpot_p50_ms": self.tpot.p50, + "tpot_p95_ms": self.tpot.p95, + "tpot_p99_ms": self.tpot.p99, + "itl_p50_ms": self.itl.p50, + "itl_p95_ms": self.itl.p95, + "itl_p99_ms": self.itl.p99, + "e2e_latency_p50_ms": self.e2e_latency.p50, + "e2e_latency_p95_ms": self.e2e_latency.p95, + "e2e_latency_p99_ms": self.e2e_latency.p99, + "successful_requests": float(self.successful_requests), + "total_requests": float(self.total_requests), + "success_rate": self.success_rate, + } diff --git a/backend/app/services/benchmark/runner.py b/backend/app/services/benchmark/runner.py new file mode 100644 index 0000000..6c37dd2 --- /dev/null +++ b/backend/app/services/benchmark/runner.py @@ -0,0 +1,484 @@ +""" +Benchmark Runner + +Executes benchmarks against OpenAI-compatible endpoints with accurate +token-level timing measurements using streaming responses. +""" + +import asyncio +import json +import logging +import time +from dataclasses import dataclass, field +from typing import Any + +import httpx + +from .config import BenchmarkConfig, LoadPattern +from .metrics import BenchmarkMetrics, RequestResult + +logger = logging.getLogger(__name__) + + +@dataclass +class BenchmarkResult: + """Complete benchmark result""" + + metrics: BenchmarkMetrics + config: BenchmarkConfig + raw_results: list[RequestResult] = field(default_factory=list) + error: str | None = None + started_at: float = 0.0 + completed_at: float = 0.0 + + @property + def duration_seconds(self) -> float: + """Total benchmark duration""" + return self.completed_at - self.started_at + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary""" + return { + "metrics": self.metrics.to_dict(), + "config": { + "endpoint": self.config.endpoint, + "model_name": self.config.model_name, + "load_pattern": self.config.load_pattern.value, + "concurrency": self.config.concurrency, + "num_requests": self.config.num_requests, + "duration_seconds": self.config.duration_seconds, + }, + "error": self.error, + "started_at": self.started_at, + "completed_at": self.completed_at, + "duration_seconds": self.duration_seconds, + } + + +class BenchmarkRunner: + """ + Executes benchmarks against LLM inference endpoints. + + Supports multiple load patterns and accurate token-level timing. + """ + + def __init__(self, config: BenchmarkConfig): + self.config = config + self._cancelled = False + self._results: list[RequestResult] = [] + self._active_requests = 0 + self._completed_requests = 0 + self._start_time: float = 0.0 + + async def run(self) -> BenchmarkResult: + """Execute the benchmark""" + # Validate config + valid, error = self.config.validate() + if not valid: + return BenchmarkResult( + metrics=BenchmarkMetrics(), + config=self.config, + error=error, + ) + + self._start_time = time.time() + result = BenchmarkResult( + metrics=BenchmarkMetrics(), + config=self.config, + started_at=self._start_time, + ) + + try: + # Run warmup + if self.config.warmup_requests > 0: + await self._run_warmup() + + # Execute based on load pattern + if self.config.load_pattern == LoadPattern.FIXED: + await self._run_fixed_load() + elif self.config.load_pattern == LoadPattern.INCREMENTAL: + await self._run_incremental_load() + elif self.config.load_pattern == LoadPattern.BURST: + await self._run_burst_load() + elif self.config.load_pattern == LoadPattern.STEP: + await self._run_step_load() + else: + await self._run_fixed_load() + + # Calculate metrics + end_time = time.time() + total_duration = end_time - self._start_time + + result.metrics = BenchmarkMetrics.from_results( + self._results, + total_duration, + self.config.concurrency, + ) + result.raw_results = self._results + result.completed_at = end_time + + except Exception as e: + logger.exception(f"Benchmark failed: {e}") + result.error = str(e) + result.completed_at = time.time() + + return result + + def cancel(self): + """Cancel the running benchmark""" + self._cancelled = True + + async def _run_warmup(self): + """Run warmup requests""" + if self.config.verbose: + logger.info(f"Running {self.config.warmup_requests} warmup requests...") + + async with httpx.AsyncClient(timeout=self.config.request_timeout) as client: + tasks = [ + self._make_request(client, warmup=True) for _ in range(self.config.warmup_requests) + ] + await asyncio.gather(*tasks, return_exceptions=True) + + if self.config.verbose: + logger.info("Warmup complete") + + async def _run_fixed_load(self): + """Run with fixed concurrency""" + semaphore = asyncio.Semaphore(self.config.concurrency) + + async def limited_request(client: httpx.AsyncClient): + async with semaphore: + return await self._make_request(client) + + async with httpx.AsyncClient(timeout=self.config.request_timeout) as client: + if self.config.duration_seconds > 0: + # Duration-based + end_time = time.time() + self.config.duration_seconds + tasks: list[asyncio.Task] = [] + + while time.time() < end_time and not self._cancelled: + if len(tasks) < self.config.concurrency * 2: + task = asyncio.create_task(limited_request(client)) + tasks.append(task) + + # Clean up completed tasks + done = [t for t in tasks if t.done()] + for t in done: + try: + result = t.result() + if result: + self._results.append(result) + except Exception: + pass + tasks.remove(t) + + await asyncio.sleep(0.01) + + # Wait for remaining tasks + for task in tasks: + try: + result = await task + if result: + self._results.append(result) + except Exception: + pass + else: + # Request count-based + tasks = [limited_request(client) for _ in range(self.config.num_requests)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for r in results: + if isinstance(r, RequestResult): + self._results.append(r) + + async def _run_incremental_load(self): + """Run with incrementally increasing concurrency""" + sat_config = self.config.saturation + current_concurrency = sat_config.start_concurrency + + async with httpx.AsyncClient(timeout=self.config.request_timeout) as client: + while current_concurrency <= sat_config.max_concurrency and not self._cancelled: + if self.config.verbose: + logger.info(f"Testing concurrency level: {current_concurrency}") + + semaphore = asyncio.Semaphore(current_concurrency) + + async def limited_request(sem=semaphore): + async with sem: + return await self._make_request(client) + + tasks = [limited_request() for _ in range(sat_config.requests_per_level)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for r in results: + if isinstance(r, RequestResult): + self._results.append(r) + + # Increase concurrency + if sat_config.use_exponential: + current_concurrency = int(current_concurrency * sat_config.step_multiplier) + else: + current_concurrency += sat_config.step_size + + async def _run_burst_load(self): + """Run with burst load pattern""" + # Burst pattern: low -> high -> low + concurrency_levels = [ + self.config.concurrency // 4, # Low + self.config.concurrency, # High (burst) + self.config.concurrency // 4, # Low + self.config.concurrency, # High (burst) + self.config.concurrency // 4, # Low + ] + + requests_per_phase = self.config.num_requests // len(concurrency_levels) + + async with httpx.AsyncClient(timeout=self.config.request_timeout) as client: + for phase, concurrency in enumerate(concurrency_levels): + if self._cancelled: + break + + if self.config.verbose: + logger.info(f"Burst phase {phase + 1}: concurrency={concurrency}") + + semaphore = asyncio.Semaphore(max(1, concurrency)) + + async def limited_request(sem=semaphore): + async with sem: + return await self._make_request(client) + + tasks = [limited_request() for _ in range(requests_per_phase)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for r in results: + if isinstance(r, RequestResult): + self._results.append(r) + + async def _run_step_load(self): + """Run with step-wise load increase""" + sat_config = self.config.saturation + steps = [] + + # Generate step levels + current = sat_config.start_concurrency + while current <= sat_config.max_concurrency: + steps.append(current) + if sat_config.use_exponential: + current = int(current * sat_config.step_multiplier) + else: + current += sat_config.step_size + + async with httpx.AsyncClient(timeout=self.config.request_timeout) as client: + for step_idx, concurrency in enumerate(steps): + if self._cancelled: + break + + if self.config.verbose: + logger.info(f"Step {step_idx + 1}/{len(steps)}: concurrency={concurrency}") + + semaphore = asyncio.Semaphore(concurrency) + + async def limited_request(sem=semaphore): + async with sem: + return await self._make_request(client) + + tasks = [limited_request() for _ in range(sat_config.requests_per_level)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for r in results: + if isinstance(r, RequestResult): + self._results.append(r) + + async def _make_request( + self, + client: httpx.AsyncClient, + warmup: bool = False, + ) -> RequestResult | None: + """Make a single request with token-level timing""" + if self._cancelled: + return None + + result = RequestResult(start_time=time.time()) + prompt = self.config.get_prompt() + + request_body = { + "model": self.config.model_name, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": self.config.output_tokens, + "stream": self.config.stream, + **self.config.extra_params, + } + + endpoint = self.config.endpoint.rstrip("/") + url = f"{endpoint}/chat/completions" + + try: + start_time = time.perf_counter() + + if self.config.stream: + result = await self._make_streaming_request(client, url, request_body, start_time) + else: + result = await self._make_non_streaming_request( + client, url, request_body, start_time + ) + + result.end_time = time.time() + + if self.config.verbose and not warmup: + logger.info( + f"Request completed: TTFT={result.ttft_ms:.1f}ms, " + f"tokens={result.completion_tokens}, " + f"latency={result.total_latency_ms:.1f}ms" + ) + + except httpx.TimeoutException: + result.success = False + result.error_message = "Request timeout" + result.end_time = time.time() + + except Exception as e: + result.success = False + result.error_message = str(e) + result.end_time = time.time() + + return result + + async def _make_streaming_request( + self, + client: httpx.AsyncClient, + url: str, + request_body: dict, + start_time: float, + ) -> RequestResult: + """Make a streaming request with token-level timing""" + result = RequestResult(start_time=time.time()) + first_token_received = False + token_timestamps: list[float] = [] + token_count = 0 + + async with client.stream("POST", url, json=request_body) as response: + if response.status_code != 200: + result.success = False + result.error_message = f"HTTP {response.status_code}" + return result + + async for line in response.aiter_lines(): + if self._cancelled: + break + + if not line.startswith("data: "): + continue + + data = line[6:] + if data == "[DONE]": + break + + try: + chunk = json.loads(data) + choices = chunk.get("choices", []) + if not choices: + continue + + delta = choices[0].get("delta", {}) + content = delta.get("content", "") + + if content: + current_time = time.perf_counter() + + if not first_token_received: + result.ttft_ms = (current_time - start_time) * 1000 + first_token_received = True + + # Record timestamp for each token + token_timestamps.append((current_time - start_time) * 1000) + token_count += 1 + + # Extract usage if available (some backends include it) + usage = chunk.get("usage", {}) + if usage: + result.prompt_tokens = usage.get("prompt_tokens", result.prompt_tokens) + result.completion_tokens = usage.get("completion_tokens", token_count) + + except json.JSONDecodeError: + continue + + end_time = time.perf_counter() + result.total_latency_ms = (end_time - start_time) * 1000 + result.token_timestamps_ms = token_timestamps + result.completion_tokens = ( + token_count if result.completion_tokens == 0 else result.completion_tokens + ) + + return result + + async def _make_non_streaming_request( + self, + client: httpx.AsyncClient, + url: str, + request_body: dict, + start_time: float, + ) -> RequestResult: + """Make a non-streaming request""" + result = RequestResult(start_time=time.time()) + + response = await client.post(url, json=request_body) + end_time = time.perf_counter() + + if response.status_code != 200: + result.success = False + result.error_message = f"HTTP {response.status_code}" + return result + + try: + data = response.json() + usage = data.get("usage", {}) + result.prompt_tokens = usage.get("prompt_tokens", 0) + result.completion_tokens = usage.get("completion_tokens", 0) + + # For non-streaming, TTFT = total latency + result.total_latency_ms = (end_time - start_time) * 1000 + result.ttft_ms = result.total_latency_ms + + except Exception as e: + result.success = False + result.error_message = str(e) + + return result + + +async def run_benchmark( + endpoint: str, + model_name: str, + concurrency: int = 10, + num_requests: int = 50, + warmup_requests: int = 5, + output_tokens: int = 128, + verbose: bool = False, +) -> BenchmarkResult: + """ + Convenience function to run a simple benchmark. + + Args: + endpoint: OpenAI-compatible API endpoint + model_name: Model name/ID + concurrency: Number of concurrent requests + num_requests: Total number of requests + warmup_requests: Number of warmup requests + output_tokens: Max output tokens per request + verbose: Enable verbose logging + + Returns: + BenchmarkResult with metrics + """ + config = BenchmarkConfig( + endpoint=endpoint, + model_name=model_name, + concurrency=concurrency, + num_requests=num_requests, + warmup_requests=warmup_requests, + output_tokens=output_tokens, + verbose=verbose, + ) + + runner = BenchmarkRunner(config) + return await runner.run() diff --git a/backend/app/services/benchmark/saturation.py b/backend/app/services/benchmark/saturation.py new file mode 100644 index 0000000..c6a14bd --- /dev/null +++ b/backend/app/services/benchmark/saturation.py @@ -0,0 +1,289 @@ +""" +Saturation Detection + +Automatically finds the optimal concurrency level by detecting when +performance starts to degrade (throughput plateaus or latency spikes). +""" + +import logging +from dataclasses import dataclass, field +from typing import Any + +from .config import BenchmarkConfig, LoadPattern, SaturationConfig +from .metrics import BenchmarkMetrics +from .runner import BenchmarkRunner + +logger = logging.getLogger(__name__) + + +@dataclass +class ConcurrencyResult: + """Result for a single concurrency level""" + + concurrency: int + metrics: BenchmarkMetrics + throughput_tps: float = 0.0 + avg_latency_ms: float = 0.0 + p95_latency_ms: float = 0.0 + success_rate: float = 0.0 + + +@dataclass +class SaturationResult: + """Result from saturation detection""" + + # Optimal concurrency found + optimal_concurrency: int = 0 + + # Maximum throughput achieved + max_throughput_tps: float = 0.0 + + # Latency at optimal concurrency + latency_at_optimal_ms: float = 0.0 + + # Saturation concurrency (where degradation starts) + saturation_concurrency: int = 0 + + # All tested concurrency levels + results_by_concurrency: list[ConcurrencyResult] = field(default_factory=list) + + # Was saturation detected? + saturation_detected: bool = False + + # Reason for stopping + stop_reason: str = "" + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary""" + return { + "optimal_concurrency": self.optimal_concurrency, + "max_throughput_tps": round(self.max_throughput_tps, 2), + "latency_at_optimal_ms": round(self.latency_at_optimal_ms, 2), + "saturation_concurrency": self.saturation_concurrency, + "saturation_detected": self.saturation_detected, + "stop_reason": self.stop_reason, + "concurrency_results": [ + { + "concurrency": r.concurrency, + "throughput_tps": round(r.throughput_tps, 2), + "avg_latency_ms": round(r.avg_latency_ms, 2), + "p95_latency_ms": round(r.p95_latency_ms, 2), + "success_rate": round(r.success_rate, 4), + } + for r in self.results_by_concurrency + ], + } + + +class SaturationDetector: + """ + Detects optimal concurrency by running incremental load tests. + + Algorithm: + 1. Start with low concurrency + 2. Increase concurrency and measure throughput/latency + 3. Stop when: + - Throughput stops increasing (plateau) + - Latency increases significantly + - Error rate increases + 4. Return the concurrency level with best throughput/latency balance + """ + + def __init__( + self, + endpoint: str, + model_name: str, + config: SaturationConfig | None = None, + ): + self.endpoint = endpoint + self.model_name = model_name + self.config = config or SaturationConfig(enabled=True) + self._cancelled = False + + async def detect(self) -> SaturationResult: + """Run saturation detection""" + result = SaturationResult() + results_by_level: list[ConcurrencyResult] = [] + + current_concurrency = self.config.start_concurrency + peak_throughput = 0.0 + baseline_latency = 0.0 + consecutive_degradations = 0 + + logger.info( + f"Starting saturation detection: " + f"start={self.config.start_concurrency}, " + f"max={self.config.max_concurrency}" + ) + + while current_concurrency <= self.config.max_concurrency and not self._cancelled: + logger.info(f"Testing concurrency: {current_concurrency}") + + # Run benchmark at current concurrency + benchmark_config = BenchmarkConfig( + endpoint=self.endpoint, + model_name=self.model_name, + load_pattern=LoadPattern.FIXED, + concurrency=current_concurrency, + num_requests=self.config.requests_per_level, + warmup_requests=2, + verbose=False, + ) + + runner = BenchmarkRunner(benchmark_config) + bench_result = await runner.run() + + if bench_result.error: + logger.warning( + f"Benchmark error at concurrency {current_concurrency}: {bench_result.error}" + ) + consecutive_degradations += 1 + + if consecutive_degradations >= self.config.consecutive_degradations: + result.stop_reason = f"Too many errors: {bench_result.error}" + result.saturation_detected = True + result.saturation_concurrency = current_concurrency + break + + # Skip to next level + current_concurrency = self._next_concurrency(current_concurrency) + continue + + metrics = bench_result.metrics + + # Record result + level_result = ConcurrencyResult( + concurrency=current_concurrency, + metrics=metrics, + throughput_tps=metrics.output_tps, + avg_latency_ms=metrics.e2e_latency.mean, + p95_latency_ms=metrics.e2e_latency.p95, + success_rate=metrics.success_rate, + ) + results_by_level.append(level_result) + + logger.info( + f"Concurrency {current_concurrency}: " + f"throughput={metrics.output_tps:.1f} TPS, " + f"latency={metrics.e2e_latency.mean:.1f}ms (p95={metrics.e2e_latency.p95:.1f}ms), " + f"success={metrics.success_rate:.1%}" + ) + + # Set baseline latency from first successful run + if baseline_latency == 0.0 and metrics.e2e_latency.mean > 0: + baseline_latency = metrics.e2e_latency.mean + + # Check for peak throughput + if metrics.output_tps > peak_throughput: + peak_throughput = metrics.output_tps + consecutive_degradations = 0 + + # Check for degradation + degradation_detected = False + + # Throughput degradation + if peak_throughput > 0: + throughput_ratio = metrics.output_tps / peak_throughput + if throughput_ratio < self.config.degradation_threshold: + logger.info( + f"Throughput degradation detected: " + f"{metrics.output_tps:.1f} vs peak {peak_throughput:.1f}" + ) + degradation_detected = True + + # Latency degradation + if baseline_latency > 0: + latency_ratio = metrics.e2e_latency.mean / baseline_latency + if latency_ratio > self.config.latency_threshold: + logger.info( + f"Latency degradation detected: " + f"{metrics.e2e_latency.mean:.1f}ms vs baseline {baseline_latency:.1f}ms" + ) + degradation_detected = True + + # Success rate check + if metrics.success_rate < 0.95: + logger.info(f"High error rate detected: {1 - metrics.success_rate:.1%}") + degradation_detected = True + + if degradation_detected: + consecutive_degradations += 1 + if consecutive_degradations >= self.config.consecutive_degradations: + result.saturation_detected = True + result.saturation_concurrency = current_concurrency + result.stop_reason = "Performance degradation detected" + break + else: + consecutive_degradations = 0 + + # Move to next concurrency level + current_concurrency = self._next_concurrency(current_concurrency) + + # Find optimal concurrency (best throughput/latency ratio) + if results_by_level: + result.results_by_concurrency = results_by_level + + # Find best result (highest throughput with acceptable latency) + best = max( + results_by_level, + key=lambda r: ( + r.throughput_tps / max(1, r.avg_latency_ms / 100) if r.success_rate > 0.9 else 0 + ), + ) + + result.optimal_concurrency = best.concurrency + result.max_throughput_tps = peak_throughput + result.latency_at_optimal_ms = best.avg_latency_ms + + if not result.saturation_detected: + result.stop_reason = "Reached max concurrency" + + logger.info( + f"Saturation detection complete: " + f"optimal={result.optimal_concurrency}, " + f"max_throughput={result.max_throughput_tps:.1f} TPS" + ) + + return result + + def _next_concurrency(self, current: int) -> int: + """Calculate next concurrency level""" + if self.config.use_exponential: + next_val = int(current * self.config.step_multiplier) + # Ensure we make at least step_size progress + return max(next_val, current + self.config.step_size) + else: + return current + self.config.step_size + + def cancel(self): + """Cancel the saturation detection""" + self._cancelled = True + + +async def find_optimal_concurrency( + endpoint: str, + model_name: str, + max_concurrency: int = 64, + requests_per_level: int = 20, +) -> SaturationResult: + """ + Convenience function to find optimal concurrency. + + Args: + endpoint: OpenAI-compatible API endpoint + model_name: Model name/ID + max_concurrency: Maximum concurrency to test + requests_per_level: Requests per concurrency level + + Returns: + SaturationResult with optimal concurrency and metrics + """ + config = SaturationConfig( + enabled=True, + start_concurrency=1, + max_concurrency=max_concurrency, + requests_per_level=requests_per_level, + ) + + detector = SaturationDetector(endpoint, model_name, config) + return await detector.detect() diff --git a/backend/app/services/deployer.py b/backend/app/services/deployer.py index 537d50d..a8800e6 100644 --- a/backend/app/services/deployer.py +++ b/backend/app/services/deployer.py @@ -13,6 +13,7 @@ from app.database import async_session_maker from app.models.deployment import Deployment, DeploymentStatus from app.models.llm_model import BackendType, LLMModel +from app.models.worker import OSType logger = logging.getLogger(__name__) settings = get_settings() @@ -60,6 +61,19 @@ async def deploy(self, deployment_id: int) -> None: deployment.status_message = "Sending deployment request to worker..." await db.commit() + # Check if worker supports Docker or needs native deployment + worker = deployment.worker + is_mac_native = worker.os_type == OSType.DARWIN.value and not worker.supports_docker + + # Use native deployment for Mac without Docker + if is_mac_native: + result = await self._deploy_native(deployment, db) + if result.get("error"): + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = result["error"] + await db.commit() + return + # Build deployment request deploy_request = self._build_deploy_request(deployment) @@ -488,6 +502,157 @@ def _is_local_worker(self, address: str) -> bool: host = address.split(":")[0].lower() return host in ("localhost", "127.0.0.1", "local") + async def _deploy_native(self, deployment: Deployment, db) -> dict: + """Deploy using native backend (Mac without Docker). + + Supports Ollama, MLX, and llama.cpp backends on macOS. + """ + worker = deployment.worker + model = deployment.model + backend = deployment.backend + + # Validate backend is supported + available_backends = worker.available_backends + if backend not in available_backends: + return { + "error": f"Backend '{backend}' not available on this worker. " + f"Available backends: {', '.join(available_backends)}" + } + + try: + worker_url = f"http://{worker.effective_address}/native/deploy" + + deploy_request = { + "deployment_id": deployment.id, + "deployment_name": deployment.name, + "model_id": model.model_id, + "backend": backend, + "port": 0, # Auto-assign + "extra_params": deployment.extra_params, + } + + deployment.status_message = f"Starting {backend} deployment..." + await db.commit() + + async with httpx.AsyncClient(timeout=600.0) as client: + response = await client.post(worker_url, json=deploy_request) + + if response.status_code != 200: + error_detail = response.json().get("detail", response.text) + return {"error": f"Native deployment failed: {error_detail}"} + + result = response.json() + deployment.port = result.get("port") + # Use process_id as container_id for native deployments + deployment.container_id = result.get("process_id") + + # Wait for API to be ready + deployment.status_message = "Waiting for model to be ready..." + await db.commit() + + # For Ollama on native, the API is at port 11434 + if backend == BackendType.OLLAMA.value: + api_port = 11434 + else: + api_port = deployment.port + + api_ready = await self._wait_for_native_api_ready( + worker.effective_address, + api_port, + deployment.id, + db, + backend=backend, + ) + + if api_ready is None: + return {} # Cancelled + elif api_ready: + deployment.status = DeploymentStatus.RUNNING.value + deployment.status_message = "Model ready" + asyncio.create_task(_update_semantic_router_config_background()) + else: + deployment.status = DeploymentStatus.ERROR.value + deployment.status_message = "API failed to start" + + await db.commit() + return {} + + except httpx.ConnectError as e: + return {"error": f"Cannot connect to worker: {e}"} + except Exception as e: + logger.exception(f"Native deployment error: {e}") + return {"error": str(e)} + + async def _wait_for_native_api_ready( + self, + worker_address: str, + port: int, + deployment_id: int, + db, + backend: str = BackendType.OLLAMA.value, + timeout: int = 300, + ) -> bool | None: + """Wait for native API to be ready. + + Returns: + True: API is ready + False: Timeout + None: Cancelled + """ + worker_ip = worker_address.split(":")[0] + api_base = f"http://{worker_ip}:{port}" + + # Determine health endpoint based on backend + if backend == BackendType.OLLAMA.value: + health_url = f"{api_base}/v1/models" + else: + health_url = f"{api_base}/v1/models" + + elapsed = 0 + check_interval = 5 + + async with httpx.AsyncClient(timeout=10.0) as client: + while elapsed < timeout: + # Check if cancelled + try: + result = await db.execute( + select(Deployment).where(Deployment.id == deployment_id) + ) + deployment = result.scalar_one_or_none() + if deployment and deployment.status in [ + DeploymentStatus.STOPPED.value, + DeploymentStatus.STOPPING.value, + ]: + return None + except Exception: + pass + + try: + response = await client.get(health_url) + if response.status_code == 200: + logger.info(f"Native API ready at {health_url}") + return True + except Exception: + pass + + await asyncio.sleep(check_interval) + elapsed += check_interval + + # Update status periodically + if elapsed % 30 == 0: + try: + result = await db.execute( + select(Deployment).where(Deployment.id == deployment_id) + ) + deployment = result.scalar_one_or_none() + if deployment: + deployment.status_message = f"Loading model... ({elapsed}s)" + await db.commit() + except Exception: + pass + + return False + def _image_exists_local(self, image: str) -> bool: """Check if a Docker image exists locally.""" try: @@ -713,6 +878,9 @@ def _build_vllm_config( continue if value is True: cmd.append(f"--{key}") + # Auto-add tool-call-parser when enable-auto-tool-choice is enabled + if key == "enable-auto-tool-choice": + cmd.extend(["--tool-call-parser", "hermes"]) elif value is not False and value is not None: cmd.extend([f"--{key}", str(value)]) @@ -854,18 +1022,33 @@ async def stop(self, deployment_id: int) -> None: return try: - is_local = self._is_local_worker(deployment.worker.address) + worker = deployment.worker - if is_local: - # Stop locally using Docker directly - await self._stop_local(deployment.container_id) - else: - worker_url = f"http://{deployment.worker.address}/stop" + # Check if this is a native deployment (process_id starts with "native-") + is_native = deployment.container_id.startswith("native-") + if is_native: + # Stop native process + worker_url = f"http://{worker.effective_address}/native/stop" async with httpx.AsyncClient(timeout=60.0) as client: await client.post( - worker_url, json={"container_id": deployment.container_id} + worker_url, + json={"process_id": deployment.container_id}, ) + else: + # Docker-based deployment + is_local = self._is_local_worker(worker.address) + + if is_local: + # Stop locally using Docker directly + await self._stop_local(deployment.container_id) + else: + worker_url = f"http://{worker.address}/stop" + + async with httpx.AsyncClient(timeout=60.0) as client: + await client.post( + worker_url, json={"container_id": deployment.container_id} + ) except Exception as e: logger.warning(f"Error stopping deployment {deployment_id}: {e}") @@ -889,26 +1072,44 @@ async def get_logs(self, deployment: Deployment, tail: int = 100) -> str: return "No container running" try: - is_local = self._is_local_worker(deployment.worker.address) + worker = deployment.worker - if is_local: - return self._get_logs_local(deployment.container_id, tail) - else: - worker_url = f"http://{deployment.worker.address}/logs" + # Check if this is a native deployment + is_native = deployment.container_id.startswith("native-") + if is_native: + # Get logs from native process + worker_url = ( + f"http://{worker.effective_address}/native/logs/{deployment.container_id}" + ) async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get( - worker_url, - params={ - "container_id": deployment.container_id, - "tail": tail, - }, - ) - + response = await client.get(worker_url, params={"tail": tail}) if response.status_code == 200: return response.json().get("logs", "") else: return f"Error fetching logs: {response.text}" + else: + # Docker-based deployment + is_local = self._is_local_worker(worker.address) + + if is_local: + return self._get_logs_local(deployment.container_id, tail) + else: + worker_url = f"http://{worker.address}/logs" + + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get( + worker_url, + params={ + "container_id": deployment.container_id, + "tail": tail, + }, + ) + + if response.status_code == 200: + return response.json().get("logs", "") + else: + return f"Error fetching logs: {response.text}" except httpx.ConnectError: return f"Cannot connect to worker at {deployment.worker.address}" diff --git a/backend/app/services/mcp/agent.py b/backend/app/services/mcp/agent.py index 016c541..be5d491 100644 --- a/backend/app/services/mcp/agent.py +++ b/backend/app/services/mcp/agent.py @@ -210,6 +210,7 @@ def __init__( mcp_api_url: str | None = None, mcp_api_token: str | None = None, max_iterations: int = 20, + supports_tool_calling: bool = True, ): """ Initialize the Agent Service. @@ -222,12 +223,14 @@ def __init__( mcp_api_url: LMStack API URL for MCP server. mcp_api_token: LMStack API token for MCP server. max_iterations: Maximum tool call iterations. + supports_tool_calling: Whether the LLM supports tool/function calling. """ self.llm_client = llm_client self.llm_model = llm_model self.llm_base_url = llm_base_url self.llm_api_key = llm_api_key self.max_iterations = max_iterations + self.supports_tool_calling = supports_tool_calling # MCP configuration self.mcp_api_url = mcp_api_url @@ -537,6 +540,15 @@ async def chat( self.conversation.append(user_msg) try: + # Warn if tool calling is not supported + if not self.supports_tool_calling: + yield AgentEvent( + type=EventType.MESSAGE, + content="**Note:** This deployment does not have tool calling enabled. " + "The agent will respond as a regular chat model without executing tools. " + "To enable Agent features, redeploy the model with 'Enable Tool Calling' option in Advanced Settings.\n\n", + ) + # Initial thinking event yield AgentEvent(type=EventType.THINKING, content="Analyzing your request...") @@ -555,15 +567,18 @@ async def chat( # For continuation, don't add the user message again messages = messages[:-1] - tools_schema = self._build_tools_schema() + # Only include tools if tool calling is supported + tools_schema = self._build_tools_schema() if self.supports_tool_calling else None try: # Use streaming for better UX + # Note: Don't pass tool_choice="auto" as it requires vLLM to have + # --enable-auto-tool-choice and --tool-call-parser flags enabled. + # Without tool_choice, the API uses default behavior. stream = await self.llm_client.chat.completions.create( model=self.llm_model, messages=messages, tools=tools_schema if tools_schema else None, - tool_choice="auto" if tools_schema else None, stream=True, ) diff --git a/backend/migrations/011_add_worker_mac_support.py b/backend/migrations/011_add_worker_mac_support.py new file mode 100644 index 0000000..3e66bc8 --- /dev/null +++ b/backend/migrations/011_add_worker_mac_support.py @@ -0,0 +1,85 @@ +""" +Migration: Add Mac native deployment support to workers table + +Adds os_type, gpu_type, and capabilities columns to track worker environment +and available backends (Docker, Ollama, MLX, llama.cpp). + +Run with: python -m migrations.011_add_worker_mac_support +""" + +import asyncio +import sys +from pathlib import Path + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine + +from app.config import get_settings + + +async def column_exists(conn, table_name: str, column_name: str) -> bool: + """Check if a column exists in a table (SQLite compatible)""" + result = await conn.execute(text(f"PRAGMA table_info({table_name})")) + columns = [row[1] for row in result.fetchall()] + return column_name in columns + + +async def migrate(): + settings = get_settings() + engine = create_async_engine(settings.database_url, echo=True) + + async with engine.begin() as conn: + # Add os_type column + if not await column_exists(conn, "workers", "os_type"): + print("Adding 'os_type' column to workers table...") + await conn.execute( + text( + """ + ALTER TABLE workers ADD COLUMN os_type VARCHAR(50) DEFAULT 'linux' + """ + ) + ) + print("'os_type' column added!") + else: + print("'os_type' column already exists") + + # Add gpu_type column + if not await column_exists(conn, "workers", "gpu_type"): + print("Adding 'gpu_type' column to workers table...") + await conn.execute( + text( + """ + ALTER TABLE workers ADD COLUMN gpu_type VARCHAR(50) DEFAULT 'nvidia' + """ + ) + ) + print("'gpu_type' column added!") + else: + print("'gpu_type' column already exists") + + # Add capabilities column (JSON) + if not await column_exists(conn, "workers", "capabilities"): + print("Adding 'capabilities' column to workers table...") + await conn.execute( + text( + """ + ALTER TABLE workers ADD COLUMN capabilities JSON + """ + ) + ) + print("'capabilities' column added!") + else: + print("'capabilities' column already exists") + + print("\n" + "=" * 50) + print("Migration completed successfully!") + print("=" * 50) + + await engine.dispose() + + +if __name__ == "__main__": + asyncio.run(migrate()) diff --git a/frontend/src/components/DeploymentAdvancedForm.tsx b/frontend/src/components/DeploymentAdvancedForm.tsx index 62871e8..b141d1d 100644 --- a/frontend/src/components/DeploymentAdvancedForm.tsx +++ b/frontend/src/components/DeploymentAdvancedForm.tsx @@ -94,6 +94,14 @@ const VllmAdvancedParams = () => ( > + + + ; + return ( + + ); } return ( @@ -128,11 +132,20 @@ export function AgentChatView({ function EmptyState({ colors, isDark: _isDark, + onSendMessage, }: { colors: ThemeColors; isDark: boolean; + onSendMessage?: (message: string) => void; }) { const isDark = _isDark; + + const quickActions = [ + { icon: , text: "列出所有 Worker 狀態" }, + { icon: , text: "GPU 記憶體使用狀況" }, + { icon: , text: "列出所有容器" }, + ]; + return (
- +
- LMStack AI Agent + Agent
- Powered by MCP. Deploy models, run benchmarks, and optimize - configurations through natural language. + 部署模型、執行基準測試、查詢系統狀態
- {/* Capability cards */} + {/* Quick actions */}
- {[ - { icon: , text: "Deploy LLM models to GPU workers" }, - { icon: , text: "Run performance benchmarks" }, - { - icon: , - text: "Query knowledge base for optimal configs", - }, - ].map((item, idx) => ( + {quickActions.map((item, idx) => (
onSendMessage?.(item.text)} style={{ display: "flex", alignItems: "center", - gap: 12, - padding: "12px 16px", - borderRadius: 10, + gap: 10, + padding: "10px 14px", + borderRadius: 8, background: isDark ? "rgba(255,255,255,0.03)" : "rgba(0,0,0,0.02)", border: `1px solid ${isDark ? "rgba(255,255,255,0.06)" : "rgba(0,0,0,0.06)"}`, + cursor: "pointer", + transition: "background 0.15s", + }} + onMouseEnter={(e) => { + e.currentTarget.style.background = isDark + ? "rgba(255,255,255,0.06)" + : "rgba(0,0,0,0.04)"; + }} + onMouseLeave={(e) => { + e.currentTarget.style.background = isDark + ? "rgba(255,255,255,0.03)" + : "rgba(0,0,0,0.02)"; }} > -
{item.icon}
+
+ {item.icon} +
{item.text}
))}
- -
- Try: "Deploy Qwen-7B on Worker 1" or "What's the GPU memory status?" -
); } @@ -309,17 +319,19 @@ function MessageBlock({ {/* Avatar */}
- +
{/* Content */} @@ -342,19 +354,22 @@ function MessageBlock({ style={{ display: "flex", alignItems: "center", - gap: 10, - padding: "12px 14px", - borderRadius: 12, + gap: 8, + padding: "10px 12px", + borderRadius: 8, background: isDark - ? "rgba(59, 130, 246, 0.1)" - : "rgba(59, 130, 246, 0.06)", - border: `1px solid ${isDark ? "rgba(59, 130, 246, 0.2)" : "rgba(59, 130, 246, 0.15)"}`, + ? "rgba(255,255,255,0.03)" + : "rgba(0,0,0,0.02)", + border: `1px solid ${isDark ? "rgba(255,255,255,0.06)" : "rgba(0,0,0,0.06)"}`, marginBottom: 12, }} > - - - Analyzing your request... + + + 處理中... )} @@ -436,7 +451,7 @@ function PageReferenceCard({ colors, }: PageReferenceCardProps) { const getIcon = () => { - const iconStyle = { fontSize: 16, color: "#3b82f6" }; + const iconStyle = { fontSize: 14, color: colors.textMuted }; switch (reference.icon) { case "cluster": return ; @@ -462,66 +477,44 @@ function PageReferenceCard({ display: "flex", alignItems: "center", gap: 10, - padding: "10px 14px", - borderRadius: 10, - background: isDark - ? "rgba(59, 130, 246, 0.1)" - : "rgba(59, 130, 246, 0.06)", + padding: "8px 12px", + borderRadius: 6, + background: isDark ? "rgba(255,255,255,0.03)" : "rgba(0,0,0,0.02)", border: `1px solid ${ - isDark ? "rgba(59, 130, 246, 0.25)" : "rgba(59, 130, 246, 0.2)" + isDark ? "rgba(255,255,255,0.08)" : "rgba(0,0,0,0.08)" }`, cursor: "pointer", - transition: "all 0.2s ease", - minWidth: 180, + transition: "background 0.15s", }} onMouseEnter={(e) => { e.currentTarget.style.background = isDark - ? "rgba(59, 130, 246, 0.15)" - : "rgba(59, 130, 246, 0.1)"; - e.currentTarget.style.borderColor = isDark - ? "rgba(59, 130, 246, 0.4)" - : "rgba(59, 130, 246, 0.35)"; + ? "rgba(255,255,255,0.06)" + : "rgba(0,0,0,0.04)"; }} onMouseLeave={(e) => { e.currentTarget.style.background = isDark - ? "rgba(59, 130, 246, 0.1)" - : "rgba(59, 130, 246, 0.06)"; - e.currentTarget.style.borderColor = isDark - ? "rgba(59, 130, 246, 0.25)" - : "rgba(59, 130, 246, 0.2)"; + ? "rgba(255,255,255,0.03)" + : "rgba(0,0,0,0.02)"; }} > -
- {getIcon()} -
+ {getIcon()}
{reference.title}
-
- {reference.description} -
+ {reference.description && ( +
+ {reference.description} +
+ )}
- + ); } @@ -543,7 +536,7 @@ function ActionSuggestionButton({ colors, }: ActionSuggestionButtonProps) { const getIcon = () => { - const iconStyle = { fontSize: 14 }; + const iconStyle = { fontSize: 12 }; switch (suggestion.icon) { case "rocket": return ; @@ -558,60 +551,11 @@ function ActionSuggestionButton({ case "tool": return ; default: - return ; - } - }; - - const getButtonStyle = () => { - switch (suggestion.type) { - case "primary": - return { - background: isDark - ? "rgba(59, 130, 246, 0.15)" - : "rgba(59, 130, 246, 0.1)", - border: `1px solid ${ - isDark ? "rgba(59, 130, 246, 0.4)" : "rgba(59, 130, 246, 0.3)" - }`, - color: "#3b82f6", - hoverBg: isDark - ? "rgba(59, 130, 246, 0.25)" - : "rgba(59, 130, 246, 0.15)", - hoverBorder: isDark - ? "rgba(59, 130, 246, 0.6)" - : "rgba(59, 130, 246, 0.5)", - }; - case "danger": - return { - background: isDark - ? "rgba(239, 68, 68, 0.1)" - : "rgba(239, 68, 68, 0.06)", - border: `1px solid ${ - isDark ? "rgba(239, 68, 68, 0.3)" : "rgba(239, 68, 68, 0.2)" - }`, - color: "#ef4444", - hoverBg: isDark ? "rgba(239, 68, 68, 0.2)" : "rgba(239, 68, 68, 0.1)", - hoverBorder: isDark - ? "rgba(239, 68, 68, 0.5)" - : "rgba(239, 68, 68, 0.4)", - }; - default: - return { - background: isDark - ? "rgba(255, 255, 255, 0.05)" - : "rgba(0, 0, 0, 0.04)", - border: `1px solid ${ - isDark ? "rgba(255, 255, 255, 0.1)" : "rgba(0, 0, 0, 0.1)" - }`, - color: colors.textSecondary, - hoverBg: isDark ? "rgba(255, 255, 255, 0.1)" : "rgba(0, 0, 0, 0.08)", - hoverBorder: isDark - ? "rgba(255, 255, 255, 0.2)" - : "rgba(0, 0, 0, 0.2)", - }; + return ; } }; - const style = getButtonStyle(); + const isDanger = suggestion.type === "danger"; return (
{ - e.currentTarget.style.background = style.hoverBg; - e.currentTarget.style.borderColor = style.hoverBorder; + e.currentTarget.style.background = isDanger + ? isDark + ? "rgba(239, 68, 68, 0.15)" + : "rgba(239, 68, 68, 0.1)" + : isDark + ? "rgba(255, 255, 255, 0.08)" + : "rgba(0, 0, 0, 0.06)"; }} onMouseLeave={(e) => { - e.currentTarget.style.background = style.background; - e.currentTarget.style.borderColor = style.border.replace( - "1px solid ", - "", - ); + e.currentTarget.style.background = isDanger + ? isDark + ? "rgba(239, 68, 68, 0.1)" + : "rgba(239, 68, 68, 0.06)" + : isDark + ? "rgba(255, 255, 255, 0.04)" + : "rgba(0, 0, 0, 0.03)"; }} > {getIcon()} @@ -751,12 +715,10 @@ function StepItem({ step, onToggle, isDark, colors }: StepItemProps) { return (
{/* Header */} @@ -768,27 +730,27 @@ function StepItem({ step, onToggle, isDark, colors }: StepItemProps) { gap: 8, padding: "10px 12px", cursor: "pointer", - background: isDark - ? "rgba(139,92,246,0.08)" - : "rgba(139,92,246,0.05)", + background: isDark ? "rgba(255,255,255,0.02)" : "rgba(0,0,0,0.02)", borderBottom: step.expanded - ? `1px solid ${isDark ? "rgba(139,92,246,0.2)" : "rgba(139,92,246,0.15)"}` + ? `1px solid ${isDark ? "rgba(255,255,255,0.06)" : "rgba(0,0,0,0.06)"}` : "none", }} > - + {step.title} {step.status === "running" && ( - + )} void; @@ -90,32 +60,6 @@ function savePanelState(state: Partial) { } } -/** - * Load chat mode from localStorage - */ -function loadChatMode(): ChatMode { - try { - const saved = localStorage.getItem(CHAT_MODE_STORAGE_KEY); - if (saved === "chat" || saved === "agent") { - return saved; - } - } catch { - // Ignore parse errors - } - return "agent"; // Default to agent mode -} - -/** - * Save chat mode to localStorage - */ -function saveChatMode(mode: ChatMode) { - try { - localStorage.setItem(CHAT_MODE_STORAGE_KEY, mode); - } catch { - // Ignore save errors - } -} - /** * Global chat panel component */ @@ -140,26 +84,8 @@ export function ChatPanel({ () => loadPanelState().customEndpoints || [], ); - // Chat mode state - const [chatMode, setChatMode] = useState(() => loadChatMode()); - - // Traditional chat state + // Input state const [inputValue, setInputValue] = useState(""); - const { - messages: chatMessages, - isStreaming: chatIsStreaming, - isExecutingTool, - currentToolName, - pendingTools, - showConfirmModal, - systemContext, - refreshContext, - sendMessage: chatSendMessage, - stopStreaming: chatStopStreaming, - clearMessages: chatClearMessages, - confirmToolExecution, - cancelToolExecution, - } = useChat(); // Agent chat state const { @@ -172,6 +98,7 @@ export function ChatPanel({ stopStreaming: agentStopStreaming, clearMessages: agentClearMessages, toggleStepExpanded, + startNewConversation, } = useAgentChat(); // Chat panel context - for external access @@ -185,49 +112,28 @@ export function ChatPanel({ }; }, [_registerSendFunction, agentSendMessage, selectedModel]); - // Derived state based on mode - const isStreaming = chatMode === "agent" ? agentIsStreaming : chatIsStreaming; - const hasMessages = - chatMode === "agent" ? agentMessages.length > 0 : chatMessages.length > 0; - - // Handle mode change - const handleModeChange = useCallback((value: string | number) => { - const newMode = value as ChatMode; - setChatMode(newMode); - saveChatMode(newMode); - }, []); + // Derived state + const isStreaming = agentIsStreaming; + const hasMessages = agentMessages.length > 0; - // Handle send message based on mode + // Handle send message const handleSendMessage = useCallback(() => { if (!inputValue.trim() || !selectedModel) return; - - if (chatMode === "agent") { - agentSendMessage(inputValue, selectedModel); - } else { - chatSendMessage(inputValue, selectedModel); - } + agentSendMessage(inputValue, selectedModel); setInputValue(""); - }, [inputValue, selectedModel, chatMode, agentSendMessage, chatSendMessage]); + }, [inputValue, selectedModel, agentSendMessage]); - // Handle stop streaming based on mode + // Handle stop streaming const handleStopStreaming = useCallback(() => { - if (chatMode === "agent") { - agentStopStreaming(); - } else { - chatStopStreaming(); - } - }, [chatMode, agentStopStreaming, chatStopStreaming]); + agentStopStreaming(); + }, [agentStopStreaming]); - // Handle clear messages based on mode + // Handle clear messages const handleClearMessages = useCallback(() => { - if (chatMode === "agent") { - agentClearMessages(); - } else { - chatClearMessages(); - } - }, [chatMode, agentClearMessages, chatClearMessages]); + agentClearMessages(); + }, [agentClearMessages]); - // Handle sending message from action suggestions (agent mode) + // Handle sending message from action suggestions const handleAgentSendSuggestion = useCallback( (message: string) => { if (!selectedModel || agentIsStreaming) return; @@ -236,9 +142,25 @@ export function ChatPanel({ [selectedModel, agentIsStreaming, agentSendMessage], ); + // Handle model change - start new conversation when model changes + const handleModelChange = useCallback( + (model: ChatModelConfig | null) => { + // If model type or identity changes, start a new conversation + const modelChanged = + selectedModel?.type !== model?.type || + selectedModel?.deploymentId !== model?.deploymentId || + selectedModel?.endpoint !== model?.endpoint; + + if (modelChanged && agentMessages.length > 0) { + startNewConversation(); + } + setSelectedModel(model); + }, + [selectedModel, agentMessages.length, startNewConversation], + ); + // Refs const messagesContainerRef = useRef(null); - const messagesEndRef = useRef(null); const resizeHandleRef = useRef(null); const [isResizing, setIsResizing] = useState(false); const [showScrollButton, setShowScrollButton] = useState(false); @@ -336,18 +258,10 @@ export function ChatPanel({ // Scroll on new messages useEffect(() => { - const messageCount = - chatMode === "agent" ? agentMessages.length : chatMessages.length; - if (!isStreaming && messageCount > 0) { + if (!isStreaming && agentMessages.length > 0) { scrollToBottom(); } - }, [ - chatMode, - agentMessages.length, - chatMessages.length, - isStreaming, - scrollToBottom, - ]); + }, [agentMessages.length, isStreaming, scrollToBottom]); if (!isOpen) return null; @@ -426,7 +340,7 @@ export function ChatPanel({ {/* Left side: Model selector */} - {/* Right side: Mode toggle and actions */} + {/* Right side: Actions */}
- {/* Mode toggle */} - , - label: "Agent", - }, - { - value: "chat", - icon: , - label: "Chat", - }, - ]} - style={{ - background: isDark ? "#27272a" : "#f4f4f5", - }} - /> - - {/* System status indicator (only in chat mode) */} - {chatMode === "chat" && systemContext && ( - -
Workers: {systemContext.workers.length}
-
- Deployments:{" "} - { - systemContext.deployments.filter( - (d) => d.status === "running", - ).length - } -
-
Models: {systemContext.models.length}
-
- Click to refresh -
-
- } - > - - - )} - {/* Clear button */} {hasMessages && ( @@ -544,56 +391,19 @@ export function ChatPanel({ position: "relative", }} > - {chatMode === "agent" ? ( - /* Agent Mode - Claude Code style */ - - ) : /* Chat Mode - Traditional tool calling */ - chatMessages.length === 0 ? ( - - ) : ( -
- {chatMessages.map((msg, index) => { - const isLast = index === chatMessages.length - 1; - const showStreaming = - isLast && chatIsStreaming && msg.role === "assistant"; - const showToolExecution = - isLast && isExecutingTool && msg.role === "assistant"; - - return ( - - ); - })} -
-
- )} +
{/* Scroll to bottom button */} @@ -642,394 +452,6 @@ export function ChatPanel({
- - {/* Tool Confirmation Modal (only for chat mode) */} - {chatMode === "chat" && ( - - )} ); } - -/** - * Empty state component - */ -interface EmptyStateProps { - selectedModel: ChatModelConfig | null; - systemContext: import("./systemContext").SystemContext | null; - colors: ThemeColors; -} - -function EmptyState({ selectedModel, systemContext, colors }: EmptyStateProps) { - const activeDeployments = - systemContext?.deployments.filter((d) => d.status === "running") || []; - const runningContainers = - systemContext?.containers.filter( - (c) => - c.status.toLowerCase().includes("running") || - c.status.toLowerCase().includes("up"), - ) || []; - - return ( -
-
- -
-
- {selectedModel ? "LMStack AI Assistant" : "Select a Model"} -
-
- {selectedModel - ? "I can help you manage your LLM infrastructure" - : "Choose a model from the dropdown above"} -
- - {/* System summary */} - {selectedModel && systemContext && ( -
-
- System Overview -
-
-
- • {systemContext.workers.length} Worker - {systemContext.workers.length !== 1 ? "s" : ""} -
-
- • {runningContainers.length}/{systemContext.containers.length}{" "} - Container{systemContext.containers.length !== 1 ? "s" : ""}{" "} - running -
-
- • {activeDeployments.length} Model deployment - {activeDeployments.length !== 1 ? "s" : ""} active -
-
- • {systemContext.models.length} Model - {systemContext.models.length !== 1 ? "s" : ""} available -
-
- • {systemContext.images.length} Docker image - {systemContext.images.length !== 1 ? "s" : ""} -
-
-
- Try: "有幾個容器在運行?" or "GPU 記憶體剩多少?" -
-
- )} -
- ); -} - -/** - * Format tool name for display - */ -function formatToolName(name: string): string { - return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); -} - -/** - * Message bubble component - */ -interface MessageBubbleProps { - message: ChatMessage; - isStreaming: boolean; - isExecutingTool: boolean; - currentToolName: string | null; - isDark: boolean; - colors: ThemeColors; -} - -function MessageBubble({ - message, - isStreaming, - isExecutingTool, - currentToolName, - isDark, - colors, -}: MessageBubbleProps) { - const isUser = message.role === "user"; - - return ( -
- {/* Avatar */} - {!isUser && ( -
- -
- )} - - {/* Content */} -
- {isUser ? ( -
- {message.content} -
- ) : ( - <> - {/* Thinking section for reasoning models */} - {message.thinking && ( - - - Thinking Process - - ), - children: ( -
- {message.thinking} -
- ), - }, - ]} - /> - )} - - {/* Tool execution indicator - show when executing */} - {isExecutingTool && currentToolName && ( -
- - - Executing: {formatToolName(currentToolName)} - -
- )} - - {/* Tool calls display - always show in history */} - {message.toolCalls && - message.toolCalls.length > 0 && - !isExecutingTool && ( -
-
- - Tool Calls Executed -
- {message.toolCalls.map((tc) => { - let args: Record = {}; - try { - args = JSON.parse(tc.function.arguments); - } catch { - args = {}; - } - return ( -
-
- - - {formatToolName(tc.function.name)} - -
- {Object.keys(args).length > 0 && ( -
- {Object.entries(args).map(([key, value]) => ( -
- {key}:{" "} - {typeof value === "object" - ? JSON.stringify(value) - : String(value)} -
- ))} -
- )} -
- ); - })} -
- )} - - - {message.model && ( -
- via {message.model} -
- )} - - )} -
- - {/* User avatar */} - {isUser && ( -
- -
- )} -
- ); -} diff --git a/frontend/src/components/chat-panel/ToolConfirmModal.tsx b/frontend/src/components/chat-panel/ToolConfirmModal.tsx deleted file mode 100644 index 5627359..0000000 --- a/frontend/src/components/chat-panel/ToolConfirmModal.tsx +++ /dev/null @@ -1,218 +0,0 @@ -/** - * Tool Confirmation Modal - * - * A modal dialog that asks for user confirmation before executing - * AI-requested tool actions. - */ -import { Modal, Button, Tag, Descriptions } from "antd"; -import { - ExclamationCircleOutlined, - CheckCircleOutlined, - DeleteOutlined, - RocketOutlined, - PauseCircleOutlined, - PlayCircleOutlined, - PlusOutlined, - StopOutlined, - ThunderboltOutlined, - KeyOutlined, - CloudDownloadOutlined, - DatabaseOutlined, - PieChartOutlined, - ClearOutlined, -} from "@ant-design/icons"; -import type { PendingToolExecution } from "./tools"; - -interface ToolConfirmModalProps { - visible: boolean; - pendingTools: PendingToolExecution[]; - onConfirm: () => void; - onCancel: () => void; - isDark: boolean; -} - -/** - * Get icon for tool - */ -function getToolIcon(iconName: string) { - const iconMap: Record = { - delete: , - rocket: , - "pause-circle": , - "play-circle": , - plus: , - stop: , - thunderbolt: , - key: , - download: , - database: , - "pie-chart": , - clear: , - }; - return iconMap[iconName] || ; -} - -/** - * Format argument value for display - */ -function formatArgValue(value: any): string { - if (value === null || value === undefined) { - return "-"; - } - if (Array.isArray(value)) { - return value.join(", "); - } - if (typeof value === "object") { - return JSON.stringify(value); - } - return String(value); -} - -/** - * Tool Confirmation Modal Component - */ -export function ToolConfirmModal({ - visible, - pendingTools, - onConfirm, - onCancel, - isDark, -}: ToolConfirmModalProps) { - if (pendingTools.length === 0) return null; - - const hasDangerous = pendingTools.some((t) => t.meta.dangerous); - - return ( - - {hasDangerous ? ( - - ) : ( - - )} - Confirm Action - - } - onCancel={onCancel} - footer={[ - , - , - ]} - width={500} - centered - styles={{ - body: { - background: isDark ? "#1f1f1f" : "#ffffff", - }, - header: { - background: isDark ? "#1f1f1f" : "#ffffff", - }, - content: { - background: isDark ? "#1f1f1f" : "#ffffff", - }, - }} - > -
-
- AI assistant wants to execute the following actions: -
- - {pendingTools.map((tool, index) => ( -
- {/* Tool header */} -
- - {getToolIcon(tool.meta.icon)} - - - {tool.meta.displayName} - - {tool.meta.dangerous && ( - - Dangerous - - )} -
- - {/* Tool arguments */} - - {Object.entries(tool.parsedArgs).map(([key, value]) => ( - - {formatArgValue(value)} - - ))} - -
- ))} - - {hasDangerous && ( -
- - This action may be irreversible. Please confirm before proceeding. -
- )} -
-
- ); -} diff --git a/frontend/src/components/chat-panel/index.ts b/frontend/src/components/chat-panel/index.ts index cdda4aa..ca1b850 100644 --- a/frontend/src/components/chat-panel/index.ts +++ b/frontend/src/components/chat-panel/index.ts @@ -2,19 +2,15 @@ * Chat Panel Components * * Global chat panel for AI conversations accessible from any page. - * Supports two modes: - * - Agent Mode: MCP-based agent with Claude Code-style interaction - * - Chat Mode: Traditional tool calling via LLM API + * Uses MCP-based agent with Claude Code-style interaction. */ // Components export { ChatPanel } from "./ChatPanel"; export { ModelSelector } from "./ModelSelector"; -export { ToolConfirmModal } from "./ToolConfirmModal"; export { AgentChatView } from "./AgentChatView"; // Hooks -export { useChat } from "./useChat"; export { useAgentChat } from "./useAgentChat"; // Types @@ -25,14 +21,6 @@ export type { ModelSourceType, } from "./types"; -export type { - ToolDefinition, - ToolCall, - ToolResult, - ToolMeta, - PendingToolExecution, -} from "./tools"; - export type { AgentEventType, AgentEvent, @@ -47,11 +35,3 @@ export { MAX_PANEL_WIDTH, CHAT_PANEL_STORAGE_KEY, } from "./types"; - -export { - CHAT_TOOLS, - TOOL_META, - requiresConfirmation, - getToolMeta, - executeTool, -} from "./tools"; diff --git a/frontend/src/components/chat-panel/systemContext.ts b/frontend/src/components/chat-panel/systemContext.ts deleted file mode 100644 index a6631a1..0000000 --- a/frontend/src/components/chat-panel/systemContext.ts +++ /dev/null @@ -1,515 +0,0 @@ -/** - * System Context Builder - * - * Builds system context for the AI assistant to understand the current - * state of the LMStack platform. - */ -import { api } from "../../api/client"; - -export interface SystemContext { - workers: WorkerInfo[]; - deployments: DeploymentInfo[]; - models: ModelInfo[]; - containers: ContainerInfo[]; - images: ImageInfo[]; - storageVolumes: StorageVolumeInfo[]; - semanticRouter: SemanticRouterInfo | null; - timestamp: string; -} - -interface WorkerInfo { - id: number; - name: string; - host: string; - status: string; - gpus: GpuInfo[]; -} - -interface GpuInfo { - index: number; - name: string; - memoryTotal: number; - memoryUsed: number; - utilizationGpu: number; -} - -interface DeploymentInfo { - id: number; - name: string; - modelName: string; - workerName: string; - status: string; - endpoint?: string; -} - -interface ModelInfo { - id: number; - name: string; - source: string; - parameters?: string; - quantization?: string; -} - -interface ContainerInfo { - id: string; - name: string; - image: string; - status: string; - workerName: string; -} - -interface ImageInfo { - id: string; - name: string; - tag: string; - size: number; - workerName: string; -} - -interface StorageVolumeInfo { - name: string; - driver: string; - mountpoint: string; - workerName: string; -} - -interface SemanticRouterInfo { - deployed: boolean; - status: string; - models: string[]; -} - -/** - * Fetch current system state - */ -export async function fetchSystemContext(): Promise { - try { - const [ - workersRes, - deploymentsRes, - modelsRes, - containersRes, - imagesRes, - storageRes, - srStatus, - ] = await Promise.all([ - api.get("/workers").catch(() => ({ data: { items: [] } })), - api.get("/deployments").catch(() => ({ data: { items: [] } })), - api.get("/models").catch(() => ({ data: { items: [] } })), - api.get("/containers").catch(() => ({ data: { items: [] } })), - api.get("/images").catch(() => ({ data: { items: [] } })), - api.get("/storage/volumes").catch(() => ({ data: [] })), - api.get("/semantic-router/status").catch(() => ({ data: null })), - ]); - - const workers: WorkerInfo[] = (workersRes.data.items || []).map( - (w: any) => ({ - id: w.id, - name: w.name, - host: w.host, - status: w.status, - gpus: (w.gpu_info || []).map((g: any) => ({ - index: g.index, - name: g.name, - memoryTotal: g.memory_total, - memoryUsed: g.memory_used, - utilizationGpu: g.utilization_gpu, - })), - }), - ); - - const deployments: DeploymentInfo[] = (deploymentsRes.data.items || []).map( - (d: any) => ({ - id: d.id, - name: d.name, - modelName: d.model?.name || d.name, - workerName: d.worker?.name || "unknown", - status: d.status, - endpoint: - d.status === "running" ? `/api/deployments/${d.id}/chat` : undefined, - }), - ); - - const models: ModelInfo[] = (modelsRes.data.items || []).map((m: any) => ({ - id: m.id, - name: m.name, - source: m.source, - parameters: m.parameters, - quantization: m.quantization, - })); - - const containers: ContainerInfo[] = (containersRes.data.items || []).map( - (c: any) => ({ - id: c.id || c.container_id, - name: c.name, - image: c.image, - status: c.status, - workerName: c.worker?.name || c.worker_name || "unknown", - }), - ); - - const images: ImageInfo[] = (imagesRes.data.items || []).map((i: any) => ({ - id: i.id || i.image_id, - name: i.name || i.repository, - tag: i.tag || "latest", - size: i.size || 0, - workerName: i.worker?.name || i.worker_name || "unknown", - })); - - // Backend /storage/volumes returns a list directly, not { items: [] } - const storageVolumes: StorageVolumeInfo[] = ( - Array.isArray(storageRes.data) ? storageRes.data : [] - ).map((v: any) => ({ - name: v.name, - driver: v.driver || "local", - mountpoint: v.mountpoint || "", - workerName: v.worker_name || "unknown", - })); - - const semanticRouter: SemanticRouterInfo | null = srStatus.data - ? { - deployed: srStatus.data.deployed, - status: srStatus.data.status, - models: srStatus.data.models || [], - } - : null; - - return { - workers, - deployments, - models, - containers, - images, - storageVolumes, - semanticRouter, - timestamp: new Date().toISOString(), - }; - } catch (error) { - console.error("Failed to fetch system context:", error); - return { - workers: [], - deployments: [], - models: [], - containers: [], - images: [], - storageVolumes: [], - semanticRouter: null, - timestamp: new Date().toISOString(), - }; - } -} - -/** - * Page routes for navigation links - */ -const PAGE_ROUTES = { - dashboard: "/dashboard", - workers: "/workers", - containers: "/containers", - images: "/images", - storage: "/storage", - models: "/models", - deployments: "/deployments", - chat: "/chat", - apiKeys: "/api-keys", - settings: "/settings", -}; - -/** - * Action links that open modals directly - */ -const ACTION_LINKS = { - newDeployment: "/deployments?action=new", - newModel: "/models?action=new", - newApiKey: "/api-keys?action=new", -}; - -/** - * Format system context as a system message for the LLM - */ -export function formatSystemPrompt(context: SystemContext): string { - const lines: string[] = [ - "You are an AI assistant for the LMStack platform - an LLM deployment and management system.", - "You have access to real-time system information and can help users manage their AI infrastructure.", - "", - "## Navigation Links", - "When referencing pages, use markdown links so users can click to navigate:", - `- Workers: [Workers](${PAGE_ROUTES.workers})`, - `- Docker Containers: [Containers](${PAGE_ROUTES.containers})`, - `- Docker Images: [Images](${PAGE_ROUTES.images})`, - `- Storage Volumes: [Storage](${PAGE_ROUTES.storage})`, - `- Models: [Models](${PAGE_ROUTES.models})`, - `- Deployments: [Deployments](${PAGE_ROUTES.deployments})`, - `- API Keys: [API Keys](${PAGE_ROUTES.apiKeys})`, - "", - "## Quick Action Links (ALWAYS use for create/deploy/add operations)", - "**CRITICAL: For deploying models, adding models, or creating API keys, ALWAYS guide users to the UI. NEVER use deploy_model, add_model, or create_api_key tools directly.**", - "", - "These links open the action dialog directly:", - `- Deploy a Model: [New Deployment](${ACTION_LINKS.newDeployment})`, - `- Add a Model: [Add Model](${ACTION_LINKS.newModel})`, - `- Create API Key: [Create API Key](${ACTION_LINKS.newApiKey})`, - "", - "**TOOL USAGE RULES:**", - "- deploy_model, add_model, create_api_key → NEVER use these. Always guide to UI instead.", - "- stop_deployment, delete_deployment, stop_container, remove_container, delete_* → OK to use (destructive actions need confirmation)", - "- list_*, get_* → OK to use (query tools, no confirmation needed)", - "", - "**EXAMPLES:**", - `- User: '我想部署模型' → '請點擊 [New Deployment](${ACTION_LINKS.newDeployment}) 開啟部署表單。'`, - `- User: '幫我部署 Qwen' → '請點擊 [New Deployment](${ACTION_LINKS.newDeployment}) 來部署,選擇 Qwen 模型即可。'`, - `- User: '幫我新增模型' → '請點擊 [Add Model](${ACTION_LINKS.newModel}) 來新增模型。'`, - "- User: '有哪些模型?' → Use list_models tool", - "- User: '停止 deployment 1' → Use stop_deployment tool", - "", - "## Current System Status", - `Last updated: ${new Date(context.timestamp).toLocaleString()}`, - "", - ]; - - // Workers section - lines.push("### Workers"); - if (context.workers.length === 0) { - lines.push("No workers registered. Go to [Workers](/workers) to add one."); - } else { - lines.push(`Total: ${context.workers.length} worker(s)`); - for (const worker of context.workers) { - lines.push(`- **${worker.name}** (${worker.host}): ${worker.status}`); - for (const gpu of worker.gpus) { - const memUsedGB = (gpu.memoryUsed / 1024).toFixed(1); - const memTotalGB = (gpu.memoryTotal / 1024).toFixed(1); - const memFreeGB = ((gpu.memoryTotal - gpu.memoryUsed) / 1024).toFixed( - 1, - ); - lines.push( - ` - GPU ${gpu.index}: ${gpu.name}, Used: ${memUsedGB}GB, Free: ${memFreeGB}GB, Total: ${memTotalGB}GB, Util: ${gpu.utilizationGpu}%`, - ); - } - } - } - lines.push(""); - - // Docker Containers section - lines.push("### Docker Containers"); - if (context.containers.length === 0) { - lines.push( - "No containers running. View [Containers](/containers) page for details.", - ); - } else { - const runningContainers = context.containers.filter( - (c) => - c.status.toLowerCase().includes("running") || - c.status.toLowerCase().includes("up"), - ); - lines.push( - `Total: ${context.containers.length} container(s), Running: ${runningContainers.length}`, - ); - for (const container of context.containers.slice(0, 10)) { - lines.push( - `- **${container.name}** (${container.image}): ${container.status} on ${container.workerName}`, - ); - } - if (context.containers.length > 10) { - lines.push( - ` ... and ${context.containers.length - 10} more. See [Containers](/containers) for full list.`, - ); - } - } - lines.push(""); - - // Docker Images section - lines.push("### Docker Images"); - if (context.images.length === 0) { - lines.push("No images found. View [Images](/images) page for details."); - } else { - lines.push(`Total: ${context.images.length} image(s)`); - const imagesByWorker: Record = {}; - for (const img of context.images) { - if (!imagesByWorker[img.workerName]) imagesByWorker[img.workerName] = []; - imagesByWorker[img.workerName].push(img); - } - for (const [worker, imgs] of Object.entries(imagesByWorker)) { - lines.push(`- ${worker}: ${imgs.length} image(s)`); - } - } - lines.push(""); - - // Storage section - lines.push("### Storage Volumes"); - if (context.storageVolumes.length === 0) { - lines.push( - "No storage volumes. View [Storage](/storage) page for details.", - ); - } else { - lines.push(`Total: ${context.storageVolumes.length} volume(s)`); - } - lines.push(""); - - // Deployments section - lines.push("### Model Deployments"); - const activeDeployments = context.deployments.filter( - (d) => d.status === "running", - ); - const allDeployments = context.deployments; - lines.push( - `Total: ${allDeployments.length}, Running: ${activeDeployments.length}`, - ); - if (activeDeployments.length === 0) { - lines.push( - "No active model deployments. Go to [Deployments](/deployments) to deploy a model.", - ); - } else { - for (const dep of activeDeployments) { - lines.push( - `- **${dep.modelName}** on ${dep.workerName} (ID: ${dep.id}) - running`, - ); - } - } - lines.push(""); - - // Available models section - lines.push("### Available Models"); - if (context.models.length === 0) { - lines.push("No models registered. Go to [Models](/models) to add models."); - } else { - lines.push(`Total: ${context.models.length} model(s)`); - const modelsBySource: Record = {}; - for (const model of context.models) { - if (!modelsBySource[model.source]) { - modelsBySource[model.source] = []; - } - modelsBySource[model.source].push(model); - } - for (const [source, models] of Object.entries(modelsBySource)) { - lines.push(`- **${source}:** ${models.map((m) => m.name).join(", ")}`); - } - } - lines.push(""); - - // Semantic Router - if (context.semanticRouter) { - lines.push("### Semantic Router"); - lines.push( - `Status: ${context.semanticRouter.deployed ? "Deployed" : "Not deployed"}`, - ); - if (context.semanticRouter.models.length > 0) { - lines.push( - `Connected models: ${context.semanticRouter.models.join(", ")}`, - ); - } - lines.push(""); - } - - // Tool Calling Capabilities - lines.push("## Available Actions (Tool Calling)"); - lines.push( - "You have access to tools that allow you to TAKE ACTIONS on the system:", - ); - lines.push(""); - lines.push("### Query Tools (No confirmation needed)"); - lines.push("- `get_system_status`: Get complete system overview"); - lines.push("- `list_workers`: List all workers with GPU information"); - lines.push( - "- `list_containers`: List Docker containers (filter by status/worker_id)", - ); - lines.push("- `list_deployments`: List model deployments (filter by status)"); - lines.push("- `list_models`: List available models (filter by source)"); - lines.push("- `get_gpu_status`: Get detailed GPU status"); - lines.push(""); - lines.push("### Model Management (Requires user confirmation)"); - lines.push( - "- `add_model`: Add a new model (name, source: huggingface/ollama)", - ); - lines.push("- `delete_model`: Delete a model (model_id)"); - lines.push(""); - lines.push("### Deployment Management (Requires user confirmation)"); - lines.push( - "- `deploy_model`: Deploy a model to a worker (model_id, worker_id, gpu_ids?)", - ); - lines.push("- `stop_deployment`: Stop a running deployment (deployment_id)"); - lines.push( - "- `start_deployment`: Start a stopped deployment (deployment_id)", - ); - lines.push( - "- `delete_deployment`: Delete a deployment permanently (deployment_id)", - ); - lines.push(""); - lines.push("### Container Management (Requires user confirmation)"); - lines.push( - "- `stop_container`: Stop a Docker container (container_name, worker_id)", - ); - lines.push( - "- `remove_container`: Remove a Docker container (container_name, worker_id, force?)", - ); - lines.push(""); - lines.push("### API Key Management"); - lines.push("- `list_api_keys`: List all API keys (No confirmation needed)"); - lines.push( - "- `create_api_key`: Create a new API key (name, description?, expires_in_days?)", - ); - lines.push("- `delete_api_key`: Delete an API key (api_key_id)"); - lines.push(""); - lines.push("### Docker Image Management"); - lines.push( - "- `list_images`: List all Docker images (No confirmation needed)", - ); - lines.push("- `pull_image`: Pull a Docker image (worker_id, image)"); - lines.push( - "- `delete_image`: Delete a Docker image (image_id, worker_id, force?)", - ); - lines.push(""); - lines.push("### Storage Management"); - lines.push( - "- `list_storage_volumes`: List storage volumes (No confirmation needed)", - ); - lines.push( - "- `get_disk_usage`: Get disk usage statistics (No confirmation needed)", - ); - lines.push( - "- `delete_storage_volume`: Delete a storage volume (volume_name, worker_id, force?)", - ); - lines.push( - "- `prune_storage`: Clean up unused Docker resources (images?, containers?, volumes?, build_cache?)", - ); - lines.push(""); - lines.push("**WORKFLOW FOR CONTAINER/IMAGE OPERATIONS:**"); - lines.push( - "1. If user asks to stop/remove a container by name, FIRST call list_containers to find the worker_id", - ); - lines.push( - "2. Then call stop_container or remove_container with both container_name AND worker_id", - ); - lines.push( - "3. Same workflow applies to images - use list_images first to find worker_id", - ); - lines.push(""); - lines.push( - "**IMPORTANT:** When users ask you to perform actions, USE THE TOOLS to execute them. The user will see a confirmation dialog before any action is executed.", - ); - lines.push(""); - - // Capabilities and instructions - lines.push("## Instructions"); - lines.push( - "1. Always use markdown links when mentioning pages (e.g., [Containers](/containers))", - ); - lines.push("2. Provide specific numbers from the system data above"); - lines.push("3. Be concise and accurate"); - lines.push( - "4. When users ask about 'containers' or 'docker containers', refer to the Docker Containers section", - ); - lines.push( - "5. When users ask about 'deployments', distinguish between Docker containers and Model Deployments", - ); - lines.push( - "6. When users ask you to deploy a model, USE the deploy_model tool", - ); - lines.push( - "7. When users ask you to stop/start/delete a deployment, USE the corresponding tool", - ); - lines.push("8. After executing an action, report the result to the user"); - lines.push("9. Respond in the same language as the user's query"); - lines.push(""); - - return lines.join("\n"); -} diff --git a/frontend/src/components/chat-panel/tools.ts b/frontend/src/components/chat-panel/tools.ts deleted file mode 100644 index 125cce1..0000000 --- a/frontend/src/components/chat-panel/tools.ts +++ /dev/null @@ -1,2252 +0,0 @@ -/** - * Chat Tools Definition - * - * Tools that the AI assistant can call to interact with LMStack. - * Includes confirmation flow for dangerous operations. - */ -import { api } from "../../api/client"; -import type { ChatModelConfig } from "./types"; - -/** - * Tool definition for OpenAI-compatible API - */ -export interface ToolDefinition { - type: "function"; - function: { - name: string; - description: string; - parameters: { - type: "object"; - properties: Record; - required?: string[]; - }; - }; -} - -/** - * Tool call from LLM response - */ -export interface ToolCall { - id: string; - type: "function"; - function: { - name: string; - arguments: string; - }; -} - -/** - * Tool execution result - */ -export interface ToolResult { - tool_call_id: string; - role: "tool"; - content: string; -} - -/** - * Tool metadata for UI display - */ -export interface ToolMeta { - name: string; - displayName: string; - description: string; - category: "query" | "action"; - dangerous: boolean; - icon: string; -} - -/** - * Pending tool execution for confirmation - */ -export interface PendingToolExecution { - toolCall: ToolCall; - parsedArgs: Record; - meta: ToolMeta; -} - -/** - * Tool metadata registry - */ -export const TOOL_META: Record = { - // Query tools (no confirmation needed) - get_system_status: { - name: "get_system_status", - displayName: "Get System Status", - description: "Query complete system status", - category: "query", - dangerous: false, - icon: "dashboard", - }, - list_workers: { - name: "list_workers", - displayName: "List Workers", - description: "Query all worker nodes", - category: "query", - dangerous: false, - icon: "cluster", - }, - list_containers: { - name: "list_containers", - displayName: "List Containers", - description: "Query Docker containers", - category: "query", - dangerous: false, - icon: "container", - }, - list_deployments: { - name: "list_deployments", - displayName: "List Deployments", - description: "Query model deployments", - category: "query", - dangerous: false, - icon: "rocket", - }, - list_models: { - name: "list_models", - displayName: "List Models", - description: "Query available models", - category: "query", - dangerous: false, - icon: "robot", - }, - get_gpu_status: { - name: "get_gpu_status", - displayName: "Get GPU Status", - description: "Query GPU usage", - category: "query", - dangerous: false, - icon: "thunderbolt", - }, - - // Action tools (confirmation needed) - add_model: { - name: "add_model", - displayName: "Add Model", - description: "Add a new model to the system", - category: "action", - dangerous: false, - icon: "plus", - }, - delete_model: { - name: "delete_model", - displayName: "Delete Model", - description: "Delete a model from the system", - category: "action", - dangerous: true, - icon: "delete", - }, - deploy_model: { - name: "deploy_model", - displayName: "Deploy Model", - description: "Deploy a model to a worker", - category: "action", - dangerous: false, - icon: "rocket", - }, - stop_deployment: { - name: "stop_deployment", - displayName: "Stop Deployment", - description: "Stop a running deployment", - category: "action", - dangerous: true, - icon: "pause-circle", - }, - start_deployment: { - name: "start_deployment", - displayName: "Start Deployment", - description: "Start a stopped deployment", - category: "action", - dangerous: false, - icon: "play-circle", - }, - delete_deployment: { - name: "delete_deployment", - displayName: "Delete Deployment", - description: "Permanently delete a deployment", - category: "action", - dangerous: true, - icon: "delete", - }, - stop_container: { - name: "stop_container", - displayName: "Stop Container", - description: "Stop a Docker container", - category: "action", - dangerous: true, - icon: "stop", - }, - remove_container: { - name: "remove_container", - displayName: "Remove Container", - description: "Remove a Docker container", - category: "action", - dangerous: true, - icon: "delete", - }, - - // API Key tools - list_api_keys: { - name: "list_api_keys", - displayName: "List API Keys", - description: "Query all API keys", - category: "query", - dangerous: false, - icon: "key", - }, - create_api_key: { - name: "create_api_key", - displayName: "Create API Key", - description: "Create a new API key", - category: "action", - dangerous: false, - icon: "plus", - }, - delete_api_key: { - name: "delete_api_key", - displayName: "Delete API Key", - description: "Delete an API key", - category: "action", - dangerous: true, - icon: "delete", - }, - - // Docker Image tools - list_images: { - name: "list_images", - displayName: "List Images", - description: "Query Docker images", - category: "query", - dangerous: false, - icon: "container", - }, - pull_image: { - name: "pull_image", - displayName: "Pull Image", - description: "Pull a Docker image from registry", - category: "action", - dangerous: false, - icon: "download", - }, - delete_image: { - name: "delete_image", - displayName: "Delete Image", - description: "Delete a Docker image", - category: "action", - dangerous: true, - icon: "delete", - }, - - // Storage tools - list_storage_volumes: { - name: "list_storage_volumes", - displayName: "List Storage Volumes", - description: "Query Docker storage volumes", - category: "query", - dangerous: false, - icon: "database", - }, - get_disk_usage: { - name: "get_disk_usage", - displayName: "Get Disk Usage", - description: "Query disk usage statistics", - category: "query", - dangerous: false, - icon: "pie-chart", - }, - delete_storage_volume: { - name: "delete_storage_volume", - displayName: "Delete Storage Volume", - description: "Delete a Docker storage volume", - category: "action", - dangerous: true, - icon: "delete", - }, - prune_storage: { - name: "prune_storage", - displayName: "Prune Storage", - description: "Clean up unused Docker resources", - category: "action", - dangerous: true, - icon: "clear", - }, - - // Auto-Tuning tools - list_tuning_jobs: { - name: "list_tuning_jobs", - displayName: "List Tuning Jobs", - description: "Query all auto-tuning jobs", - category: "query", - dangerous: false, - icon: "thunderbolt", - }, - start_auto_tuning: { - name: "start_auto_tuning", - displayName: "Start Auto-Tuning", - description: "Start a new auto-tuning job", - category: "action", - dangerous: false, - icon: "experiment", - }, - get_tuning_job: { - name: "get_tuning_job", - displayName: "Get Tuning Job", - description: "Get details of a tuning job", - category: "query", - dangerous: false, - icon: "info-circle", - }, - cancel_tuning_job: { - name: "cancel_tuning_job", - displayName: "Cancel Tuning Job", - description: "Cancel a running tuning job", - category: "action", - dangerous: true, - icon: "stop", - }, - query_knowledge_base: { - name: "query_knowledge_base", - displayName: "Query Knowledge Base", - description: "Query performance knowledge base", - category: "query", - dangerous: false, - icon: "database", - }, - run_benchmark: { - name: "run_benchmark", - displayName: "Run Benchmark", - description: "Run performance benchmark on a deployment", - category: "action", - dangerous: false, - icon: "bar-chart", - }, -}; - -/** - * Check if a tool requires confirmation - */ -export function requiresConfirmation(toolName: string): boolean { - const meta = TOOL_META[toolName]; - return meta?.category === "action"; -} - -/** - * Get tool metadata - */ -export function getToolMeta(toolName: string): ToolMeta { - return ( - TOOL_META[toolName] || { - name: toolName, - displayName: toolName, - description: "Unknown tool", - category: "action", - dangerous: false, - icon: "question", - } - ); -} - -/** - * Available tools for the AI assistant - */ -export const CHAT_TOOLS: ToolDefinition[] = [ - // ============== Query Tools ============== - { - type: "function", - function: { - name: "get_system_status", - description: - "Get complete LMStack system status including workers, GPUs, containers, and deployments. Call this to get the latest system information.", - parameters: { - type: "object", - properties: {}, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "list_workers", - description: - "List all worker nodes with their GPU status, memory usage, and availability.", - parameters: { - type: "object", - properties: {}, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "list_containers", - description: "List all Docker containers running across all workers.", - parameters: { - type: "object", - properties: { - status: { - type: "string", - description: "Filter by status: running, stopped, or all", - enum: ["running", "stopped", "all"], - }, - worker_id: { - type: "number", - description: "Optional: Filter by specific worker ID", - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "list_deployments", - description: "List all model deployments with their status.", - parameters: { - type: "object", - properties: { - status: { - type: "string", - description: "Filter by status: running, stopped, or all", - enum: ["running", "stopped", "all"], - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "list_models", - description: "List all available models that can be deployed.", - parameters: { - type: "object", - properties: { - source: { - type: "string", - description: "Filter by source", - enum: ["huggingface", "ollama", "local"], - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "get_gpu_status", - description: - "Get detailed GPU status including memory usage, utilization, and temperature for all workers.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: "Optional: Filter by specific worker ID", - }, - }, - required: [], - }, - }, - }, - - // ============== Model Management Tools ============== - { - type: "function", - function: { - name: "add_model", - description: - "Add a new model to the system. Supports HuggingFace and Ollama models.", - parameters: { - type: "object", - properties: { - name: { - type: "string", - description: - "Model name/identifier (e.g., 'Qwen/Qwen2.5-7B-Instruct' for HuggingFace, 'llama3.2' for Ollama)", - }, - source: { - type: "string", - description: "Model source", - enum: ["huggingface", "ollama"], - }, - parameters: { - type: "string", - description: "Optional: Model parameters (e.g., '7B', '13B')", - }, - quantization: { - type: "string", - description: - "Optional: Quantization format (e.g., 'GPTQ', 'AWQ', 'GGUF')", - }, - }, - required: ["name", "source"], - }, - }, - }, - { - type: "function", - function: { - name: "delete_model", - description: - "Delete a model from the system. This will NOT delete any deployments using this model.", - parameters: { - type: "object", - properties: { - model_id: { - type: "number", - description: - "ID of the model to delete (use list_models to find IDs)", - }, - }, - required: ["model_id"], - }, - }, - }, - - // ============== Deployment Tools ============== - { - type: "function", - function: { - name: "deploy_model", - description: - "Deploy a model to a worker. This will start the model inference service.", - parameters: { - type: "object", - properties: { - model_id: { - type: "number", - description: - "ID of the model to deploy (use list_models to find IDs)", - }, - worker_id: { - type: "number", - description: - "ID of the worker to deploy to (use list_workers to find IDs)", - }, - gpu_ids: { - type: "array", - items: { type: "number" }, - description: - "Optional: Specific GPU indices to use. If not provided, GPUs will be auto-selected.", - }, - }, - required: ["model_id", "worker_id"], - }, - }, - }, - { - type: "function", - function: { - name: "stop_deployment", - description: "Stop a running model deployment.", - parameters: { - type: "object", - properties: { - deployment_id: { - type: "number", - description: - "ID of the deployment to stop (use list_deployments to find IDs)", - }, - }, - required: ["deployment_id"], - }, - }, - }, - { - type: "function", - function: { - name: "start_deployment", - description: "Start a stopped model deployment.", - parameters: { - type: "object", - properties: { - deployment_id: { - type: "number", - description: "ID of the deployment to start", - }, - }, - required: ["deployment_id"], - }, - }, - }, - { - type: "function", - function: { - name: "delete_deployment", - description: - "Delete a model deployment completely. This cannot be undone.", - parameters: { - type: "object", - properties: { - deployment_id: { - type: "number", - description: "ID of the deployment to delete", - }, - }, - required: ["deployment_id"], - }, - }, - }, - - // ============== Container Tools ============== - { - type: "function", - function: { - name: "stop_container", - description: - "Stop a running Docker container. If you don't know the worker_id, call list_containers first to find it.", - parameters: { - type: "object", - properties: { - container_name: { - type: "string", - description: - "Name of the container to stop (e.g., 'lmstack-llama')", - }, - worker_id: { - type: "number", - description: - "ID of the worker where the container is running. Use list_containers to find this.", - }, - }, - required: ["container_name", "worker_id"], - }, - }, - }, - { - type: "function", - function: { - name: "remove_container", - description: - "Remove/delete a Docker container. If you don't know the worker_id, call list_containers first to find it.", - parameters: { - type: "object", - properties: { - container_name: { - type: "string", - description: "Name of the container to remove", - }, - worker_id: { - type: "number", - description: - "ID of the worker where the container is located. Use list_containers to find this.", - }, - force: { - type: "boolean", - description: "Force remove even if running (default: false)", - }, - }, - required: ["container_name", "worker_id"], - }, - }, - }, - - // ============== API Key Tools ============== - { - type: "function", - function: { - name: "list_api_keys", - description: - "List all API keys in the system with their usage statistics.", - parameters: { - type: "object", - properties: {}, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "create_api_key", - description: "Create a new API key for accessing the LMStack API.", - parameters: { - type: "object", - properties: { - name: { - type: "string", - description: - "Name for the API key (e.g., 'production-key', 'test-key')", - }, - description: { - type: "string", - description: "Optional description for the API key", - }, - expires_in_days: { - type: "number", - description: - "Optional: Number of days until the key expires. If not set, the key never expires.", - }, - }, - required: ["name"], - }, - }, - }, - { - type: "function", - function: { - name: "delete_api_key", - description: "Delete an API key from the system.", - parameters: { - type: "object", - properties: { - api_key_id: { - type: "number", - description: - "ID of the API key to delete (use list_api_keys to find IDs)", - }, - }, - required: ["api_key_id"], - }, - }, - }, - - // ============== Docker Image Tools ============== - { - type: "function", - function: { - name: "list_images", - description: "List all Docker images across all workers.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: "Optional: Filter by specific worker ID", - }, - repository: { - type: "string", - description: "Optional: Filter by repository name", - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "pull_image", - description: "Pull a Docker image from a registry to a worker.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: "ID of the worker to pull the image to", - }, - image: { - type: "string", - description: - "Image reference (e.g., 'nginx:latest', 'python:3.11')", - }, - }, - required: ["worker_id", "image"], - }, - }, - }, - { - type: "function", - function: { - name: "delete_image", - description: "Delete a Docker image from a worker.", - parameters: { - type: "object", - properties: { - image_id: { - type: "string", - description: "ID or name of the image to delete", - }, - worker_id: { - type: "number", - description: "ID of the worker where the image is located", - }, - force: { - type: "boolean", - description: - "Force removal even if image is in use (default: false)", - }, - }, - required: ["image_id", "worker_id"], - }, - }, - }, - - // ============== Storage Tools ============== - { - type: "function", - function: { - name: "list_storage_volumes", - description: "List all Docker storage volumes across all workers.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: "Optional: Filter by specific worker ID", - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "get_disk_usage", - description: - "Get Docker disk usage statistics including images, containers, volumes, and build cache.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: "Optional: Filter by specific worker ID", - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "delete_storage_volume", - description: "Delete a Docker storage volume from a worker.", - parameters: { - type: "object", - properties: { - volume_name: { - type: "string", - description: "Name of the volume to delete", - }, - worker_id: { - type: "number", - description: "ID of the worker where the volume is located", - }, - force: { - type: "boolean", - description: "Force removal (default: false)", - }, - }, - required: ["volume_name", "worker_id"], - }, - }, - }, - { - type: "function", - function: { - name: "prune_storage", - description: - "Clean up unused Docker resources (images, containers, volumes, build cache) to free disk space.", - parameters: { - type: "object", - properties: { - worker_id: { - type: "number", - description: - "Optional: Only prune on specific worker. If not set, prunes on all workers.", - }, - images: { - type: "boolean", - description: "Prune unused images (default: true)", - }, - containers: { - type: "boolean", - description: "Prune stopped containers (default: true)", - }, - volumes: { - type: "boolean", - description: "Prune unused volumes (default: false - be careful!)", - }, - build_cache: { - type: "boolean", - description: "Prune build cache (default: true)", - }, - }, - required: [], - }, - }, - }, - - // ============== Auto-Tuning Tools ============== - { - type: "function", - function: { - name: "list_tuning_jobs", - description: "List all auto-tuning jobs with their status and progress.", - parameters: { - type: "object", - properties: { - status: { - type: "string", - description: "Filter by status", - enum: [ - "pending", - "analyzing", - "querying_kb", - "exploring", - "benchmarking", - "completed", - "failed", - "cancelled", - ], - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "start_auto_tuning", - description: - "Start a new auto-tuning job to find the best deployment configuration for a model. The agent will analyze the environment, query the knowledge base, explore configuration space, run benchmarks, and find the optimal settings.", - parameters: { - type: "object", - properties: { - model_id: { - type: "number", - description: - "ID of the model to tune (use list_models to find IDs)", - }, - worker_id: { - type: "number", - description: - "ID of the worker to use for tuning (use list_workers to find IDs)", - }, - optimization_target: { - type: "string", - description: "What to optimize for", - enum: ["throughput", "latency", "cost", "balanced"], - }, - }, - required: ["model_id", "worker_id"], - }, - }, - }, - { - type: "function", - function: { - name: "get_tuning_job", - description: - "Get detailed information about a specific tuning job including progress, best configuration, and all results.", - parameters: { - type: "object", - properties: { - job_id: { - type: "number", - description: "ID of the tuning job", - }, - }, - required: ["job_id"], - }, - }, - }, - { - type: "function", - function: { - name: "cancel_tuning_job", - description: "Cancel a running auto-tuning job.", - parameters: { - type: "object", - properties: { - job_id: { - type: "number", - description: "ID of the tuning job to cancel", - }, - }, - required: ["job_id"], - }, - }, - }, - { - type: "function", - function: { - name: "query_knowledge_base", - description: - "Query the performance knowledge base to find similar configurations and their benchmark results. This uses transfer learning from previous tuning results.", - parameters: { - type: "object", - properties: { - model_name: { - type: "string", - description: "Model name pattern to match (e.g., 'Qwen', 'Llama')", - }, - model_family: { - type: "string", - description: "Model family: Qwen, Llama, Mistral, etc.", - }, - gpu_model: { - type: "string", - description: "GPU model pattern (e.g., 'RTX 4090', 'A100')", - }, - optimization_target: { - type: "string", - description: "Optimization target for scoring", - enum: ["throughput", "latency", "cost", "balanced"], - }, - limit: { - type: "number", - description: "Maximum number of results to return (default: 10)", - }, - }, - required: [], - }, - }, - }, - { - type: "function", - function: { - name: "run_benchmark", - description: - "Run a performance benchmark on a deployment to measure throughput, latency, and resource usage.", - parameters: { - type: "object", - properties: { - deployment_id: { - type: "number", - description: - "ID of the deployment to benchmark (use list_deployments to find IDs)", - }, - test_type: { - type: "string", - description: "Type of benchmark test", - enum: ["throughput", "latency"], - }, - duration_seconds: { - type: "number", - description: "Test duration in seconds (10-600, default: 60)", - }, - input_length: { - type: "number", - description: "Input token length (default: 512)", - }, - output_length: { - type: "number", - description: "Output token length (default: 128)", - }, - concurrency: { - type: "number", - description: "Number of concurrent requests (1-64, default: 1)", - }, - }, - required: ["deployment_id"], - }, - }, - }, -]; - -/** - * Execute a tool call and return the result - * @param toolCall - The tool call to execute - * @param modelConfig - Optional model config for tools that need LLM access (like auto-tuning) - */ -export async function executeTool( - toolCall: ToolCall, - modelConfig?: ChatModelConfig, -): Promise { - const { name, arguments: argsStr } = toolCall.function; - let args: Record = {}; - - try { - args = JSON.parse(argsStr); - } catch { - return { - tool_call_id: toolCall.id, - role: "tool", - content: `Error: Invalid arguments JSON: ${argsStr}`, - }; - } - - try { - let result: string; - - switch (name) { - // Query tools - case "get_system_status": - result = await getSystemStatus(); - break; - - case "list_workers": - result = await listWorkers(); - break; - - case "list_containers": - result = await listContainers(args.status, args.worker_id); - break; - - case "list_deployments": - result = await listDeployments(args.status); - break; - - case "list_models": - result = await listModels(args.source); - break; - - case "get_gpu_status": - result = await getGpuStatus(args.worker_id); - break; - - // Model management tools - case "add_model": - result = await addModel( - args.name, - args.source, - args.parameters, - args.quantization, - ); - break; - - case "delete_model": - result = await deleteModel(args.model_id); - break; - - // Deployment tools - case "deploy_model": - result = await deployModel(args.model_id, args.worker_id, args.gpu_ids); - break; - - case "stop_deployment": - result = await stopDeployment(args.deployment_id); - break; - - case "start_deployment": - result = await startDeployment(args.deployment_id); - break; - - case "delete_deployment": - result = await deleteDeployment(args.deployment_id); - break; - - // Container tools - case "stop_container": - result = await stopContainer(args.container_name, args.worker_id); - break; - - case "remove_container": - result = await removeContainer( - args.container_name, - args.worker_id, - args.force, - ); - break; - - // API Key tools - case "list_api_keys": - result = await listApiKeys(); - break; - - case "create_api_key": - result = await createApiKey( - args.name, - args.description, - args.expires_in_days, - ); - break; - - case "delete_api_key": - result = await deleteApiKey(args.api_key_id); - break; - - // Docker Image tools - case "list_images": - result = await listImages(args.worker_id, args.repository); - break; - - case "pull_image": - result = await pullImage(args.worker_id, args.image); - break; - - case "delete_image": - result = await deleteImage(args.image_id, args.worker_id, args.force); - break; - - // Storage tools - case "list_storage_volumes": - result = await listStorageVolumes(args.worker_id); - break; - - case "get_disk_usage": - result = await getDiskUsage(args.worker_id); - break; - - case "delete_storage_volume": - result = await deleteStorageVolume( - args.volume_name, - args.worker_id, - args.force, - ); - break; - - case "prune_storage": - result = await pruneStorage( - args.worker_id, - args.images, - args.containers, - args.volumes, - args.build_cache, - ); - break; - - // Auto-Tuning tools - case "list_tuning_jobs": - result = await listTuningJobs(args.status); - break; - - case "start_auto_tuning": - result = await startAutoTuning( - args.model_id, - args.worker_id, - args.optimization_target, - modelConfig, - ); - break; - - case "get_tuning_job": - result = await getTuningJob(args.job_id); - break; - - case "cancel_tuning_job": - result = await cancelTuningJob(args.job_id); - break; - - case "query_knowledge_base": - result = await queryKnowledgeBase( - args.model_name, - args.model_family, - args.gpu_model, - args.optimization_target, - args.limit, - ); - break; - - case "run_benchmark": - result = await runBenchmark( - args.deployment_id, - args.test_type, - args.duration_seconds, - args.input_length, - args.output_length, - args.concurrency, - ); - break; - - default: - result = `Unknown tool: ${name}`; - } - - return { - tool_call_id: toolCall.id, - role: "tool", - content: result, - }; - } catch (error) { - const message = error instanceof Error ? error.message : String(error); - return { - tool_call_id: toolCall.id, - role: "tool", - content: `Error executing ${name}: ${message}`, - }; - } -} - -// ============================================================================ -// Tool Implementations -// ============================================================================ - -async function getSystemStatus(): Promise { - const [workers, containers, deployments, models] = await Promise.all([ - api.get("/workers").then((r) => r.data.items || []), - api.get("/containers").then((r) => r.data.items || []), - api.get("/deployments").then((r) => r.data.items || []), - api.get("/models").then((r) => r.data.items || []), - ]); - - const onlineWorkers = workers.filter((w: any) => w.status === "online"); - const runningContainers = containers.filter( - (c: any) => - c.status?.toLowerCase().includes("running") || - c.status?.toLowerCase().includes("up"), - ); - const runningDeployments = deployments.filter( - (d: any) => d.status === "running", - ); - - let totalGpuMem = 0, - usedGpuMem = 0; - for (const w of workers) { - for (const g of w.gpu_info || []) { - totalGpuMem += g.memory_total || 0; - usedGpuMem += g.memory_used || 0; - } - } - - return JSON.stringify( - { - summary: { - workers: `${onlineWorkers.length}/${workers.length} online`, - containers: `${runningContainers.length}/${containers.length} running`, - deployments: `${runningDeployments.length}/${deployments.length} running`, - models: `${models.length} available`, - gpu_memory: { - used_gb: (usedGpuMem / 1024).toFixed(1), - free_gb: ((totalGpuMem - usedGpuMem) / 1024).toFixed(1), - total_gb: (totalGpuMem / 1024).toFixed(1), - }, - }, - workers: workers.map((w: any) => ({ - id: w.id, - name: w.name, - status: w.status, - gpus: (w.gpu_info || []).map((g: any) => ({ - index: g.index, - name: g.name, - memory_used_gb: (g.memory_used / 1024).toFixed(1), - memory_free_gb: ((g.memory_total - g.memory_used) / 1024).toFixed(1), - memory_total_gb: (g.memory_total / 1024).toFixed(1), - utilization: g.utilization_gpu, - })), - })), - running_deployments: runningDeployments.map((d: any) => ({ - id: d.id, - model: d.model?.name || d.name, - worker: d.worker?.name, - })), - }, - null, - 2, - ); -} - -async function listWorkers(): Promise { - const response = await api.get("/workers"); - const workers = response.data.items || []; - - return JSON.stringify( - workers.map((w: any) => ({ - id: w.id, - name: w.name, - host: w.host, - status: w.status, - gpus: (w.gpu_info || []).map((g: any) => ({ - index: g.index, - name: g.name, - memory_used_gb: (g.memory_used / 1024).toFixed(1), - memory_free_gb: ((g.memory_total - g.memory_used) / 1024).toFixed(1), - memory_total_gb: (g.memory_total / 1024).toFixed(1), - utilization_percent: g.utilization_gpu, - })), - })), - null, - 2, - ); -} - -async function listContainers( - status?: string, - workerId?: number, -): Promise { - const response = await api.get("/containers"); - let containers = response.data.items || []; - - if (workerId) { - containers = containers.filter( - (c: any) => c.worker?.id === workerId || c.worker_id === workerId, - ); - } - - if (status && status !== "all") { - containers = containers.filter((c: any) => { - const s = c.status?.toLowerCase() || ""; - if (status === "running") { - return s.includes("running") || s.includes("up"); - } - return s.includes(status); - }); - } - - return JSON.stringify( - containers.map((c: any) => ({ - id: c.id?.substring(0, 12), - name: c.name, - image: c.image, - status: c.status, - worker: c.worker?.name || c.worker_name, - worker_id: c.worker?.id || c.worker_id, - })), - null, - 2, - ); -} - -async function listDeployments(status?: string): Promise { - const response = await api.get("/deployments"); - let deployments = response.data.items || []; - - if (status && status !== "all") { - deployments = deployments.filter((d: any) => d.status === status); - } - - return JSON.stringify( - deployments.map((d: any) => ({ - id: d.id, - name: d.name, - model: d.model?.name, - model_id: d.model?.id, - worker: d.worker?.name, - worker_id: d.worker?.id, - status: d.status, - gpu_ids: d.gpu_ids, - port: d.port, - created_at: d.created_at, - })), - null, - 2, - ); -} - -async function listModels(source?: string): Promise { - const response = await api.get("/models"); - let models = response.data.items || []; - - if (source) { - models = models.filter((m: any) => m.source === source); - } - - return JSON.stringify( - models.map((m: any) => ({ - id: m.id, - name: m.name, - source: m.source, - parameters: m.parameters, - quantization: m.quantization, - })), - null, - 2, - ); -} - -async function getGpuStatus(workerId?: number): Promise { - const response = await api.get("/workers"); - let workers = response.data.items || []; - - if (workerId) { - workers = workers.filter((w: any) => w.id === workerId); - } - - const result = workers.map((w: any) => ({ - worker_id: w.id, - worker_name: w.name, - status: w.status, - gpus: (w.gpu_info || []).map((g: any) => ({ - index: g.index, - name: g.name, - memory_used_gb: (g.memory_used / 1024).toFixed(1), - memory_free_gb: ((g.memory_total - g.memory_used) / 1024).toFixed(1), - memory_total_gb: (g.memory_total / 1024).toFixed(1), - utilization_percent: g.utilization_gpu, - temperature: g.temperature, - })), - })); - - return JSON.stringify(result, null, 2); -} - -async function deployModel( - modelId: number, - workerId: number, - gpuIds?: number[], -): Promise { - if (!modelId || !workerId) { - return "Error: model_id and worker_id are required"; - } - - const response = await api.post("/deployments", { - model_id: modelId, - worker_id: workerId, - gpu_ids: gpuIds, - }); - - const deployment = response.data; - const deploymentId = deployment.id; - - // Poll deployment status for up to 60 seconds - const maxPollTime = 60000; - const pollInterval = 3000; - const startTime = Date.now(); - let lastStatus = deployment.status; - const statusUpdates: string[] = [`Initial status: ${deployment.status}`]; - - while (Date.now() - startTime < maxPollTime) { - await new Promise((resolve) => setTimeout(resolve, pollInterval)); - - try { - const statusResponse = await api.get(`/deployments/${deploymentId}`); - const currentStatus = statusResponse.data.status; - const statusMessage = statusResponse.data.status_message; - - if (currentStatus !== lastStatus) { - statusUpdates.push( - `Status changed: ${lastStatus} → ${currentStatus}${statusMessage ? ` (${statusMessage})` : ""}`, - ); - lastStatus = currentStatus; - } - - // Stop polling if deployment reached a terminal state - if (["running", "error", "stopped"].includes(currentStatus)) { - return JSON.stringify( - { - success: currentStatus === "running", - message: - currentStatus === "running" - ? `Deployment completed successfully! Model is now running.` - : currentStatus === "error" - ? `Deployment failed: ${statusMessage || "Unknown error"}` - : `Deployment stopped`, - deployment: { - id: deploymentId, - status: currentStatus, - status_message: statusMessage, - model: deployment.model?.name, - worker: deployment.worker?.name, - port: statusResponse.data.port, - }, - status_history: statusUpdates, - }, - null, - 2, - ); - } - } catch (error) { - // Continue polling even if one request fails - } - } - - // Timeout - return current status - return JSON.stringify( - { - success: false, - message: `Deployment is still in progress (status: ${lastStatus}). Check [Deployments](/deployments) page for updates.`, - deployment: { - id: deploymentId, - status: lastStatus, - model: deployment.model?.name, - worker: deployment.worker?.name, - }, - status_history: statusUpdates, - note: "Deployment is taking longer than expected. This is normal for large models that need to be downloaded.", - }, - null, - 2, - ); -} - -async function stopDeployment(deploymentId: number): Promise { - if (!deploymentId) { - return "Error: deployment_id is required"; - } - - await api.post(`/deployments/${deploymentId}/stop`); - return JSON.stringify( - { - success: true, - message: `Deployment ${deploymentId} stopped successfully`, - }, - null, - 2, - ); -} - -async function startDeployment(deploymentId: number): Promise { - if (!deploymentId) { - return "Error: deployment_id is required"; - } - - await api.post(`/deployments/${deploymentId}/start`); - return JSON.stringify( - { - success: true, - message: `Deployment ${deploymentId} started successfully`, - }, - null, - 2, - ); -} - -async function deleteDeployment(deploymentId: number): Promise { - if (!deploymentId) { - return "Error: deployment_id is required"; - } - - await api.delete(`/deployments/${deploymentId}`); - return JSON.stringify( - { - success: true, - message: `Deployment ${deploymentId} deleted successfully`, - }, - null, - 2, - ); -} - -// ============================================================================ -// Model Management Tools -// ============================================================================ - -async function addModel( - name: string, - source: string, - parameters?: string, - quantization?: string, -): Promise { - if (!name || !source) { - return "Error: name and source are required"; - } - - const response = await api.post("/models", { - name, - source, - parameters, - quantization, - }); - - const model = response.data; - return JSON.stringify( - { - success: true, - message: `Model added successfully`, - model: { - id: model.id, - name: model.name, - source: model.source, - parameters: model.parameters, - quantization: model.quantization, - }, - }, - null, - 2, - ); -} - -async function deleteModel(modelId: number): Promise { - if (!modelId) { - return "Error: model_id is required"; - } - - await api.delete(`/models/${modelId}`); - return JSON.stringify( - { - success: true, - message: `Model ${modelId} deleted successfully`, - }, - null, - 2, - ); -} - -// ============================================================================ -// Container Tools -// ============================================================================ - -async function stopContainer( - containerName: string, - workerId: number, -): Promise { - if (!containerName || !workerId) { - return "Error: container_name and worker_id are required"; - } - - // Backend expects: POST /containers/{container_id}/stop?worker_id=X - await api.post( - `/containers/${encodeURIComponent(containerName)}/stop`, - null, - { - params: { worker_id: workerId }, - }, - ); - - return JSON.stringify( - { - success: true, - message: `Container "${containerName}" stopped successfully`, - }, - null, - 2, - ); -} - -async function removeContainer( - containerName: string, - workerId: number, - force?: boolean, -): Promise { - if (!containerName || !workerId) { - return "Error: container_name and worker_id are required"; - } - - // Backend expects: DELETE /containers/{container_id}?worker_id=X&force=Y - await api.delete(`/containers/${encodeURIComponent(containerName)}`, { - params: { worker_id: workerId, force: force || false }, - }); - - return JSON.stringify( - { - success: true, - message: `Container "${containerName}" removed successfully`, - }, - null, - 2, - ); -} - -// ============================================================================ -// API Key Tools -// ============================================================================ - -async function listApiKeys(): Promise { - const response = await api.get("/api-keys"); - const apiKeys = response.data.items || []; - - return JSON.stringify( - { - total: response.data.total || apiKeys.length, - api_keys: apiKeys.map((k: any) => ({ - id: k.id, - name: k.name, - description: k.description, - access_key: k.access_key, - expires_at: k.expires_at, - created_at: k.created_at, - last_used_at: k.last_used_at, - })), - }, - null, - 2, - ); -} - -async function createApiKey( - name: string, - description?: string, - expiresInDays?: number, -): Promise { - if (!name) { - return "Error: name is required"; - } - - const response = await api.post("/api-keys", { - name, - description, - expires_in_days: expiresInDays, - }); - - const apiKey = response.data; - return JSON.stringify( - { - success: true, - message: "API key created successfully", - api_key: { - id: apiKey.id, - name: apiKey.name, - access_key: apiKey.access_key, - full_key: apiKey.api_key, // The full key is only shown once! - expires_at: apiKey.expires_at, - }, - warning: "Save the full API key now! It will not be shown again.", - }, - null, - 2, - ); -} - -async function deleteApiKey(apiKeyId: number): Promise { - if (!apiKeyId) { - return "Error: api_key_id is required"; - } - - await api.delete(`/api-keys/${apiKeyId}`); - return JSON.stringify( - { - success: true, - message: `API key ${apiKeyId} deleted successfully`, - }, - null, - 2, - ); -} - -// ============================================================================ -// Docker Image Tools -// ============================================================================ - -async function listImages( - workerId?: number, - repository?: string, -): Promise { - const params: any = {}; - if (workerId) params.worker_id = workerId; - if (repository) params.repository = repository; - - const response = await api.get("/images", { params }); - const images = response.data.items || []; - - return JSON.stringify( - { - total: response.data.total || images.length, - images: images.map((img: any) => ({ - id: img.id?.substring(0, 12), - repository: img.repository, - tag: img.tag, - full_name: img.full_name, - size_mb: (img.size / 1024 / 1024).toFixed(1), - created_at: img.created_at, - worker: img.worker_name, - worker_id: img.worker_id, - })), - }, - null, - 2, - ); -} - -async function pullImage(workerId: number, image: string): Promise { - if (!workerId || !image) { - return "Error: worker_id and image are required"; - } - - const response = await api.post("/images/pull", { - worker_id: workerId, - image, - }); - - return JSON.stringify( - { - success: true, - message: `Image "${image}" pulled successfully`, - image: response.data.image, - }, - null, - 2, - ); -} - -async function deleteImage( - imageId: string, - workerId: number, - force?: boolean, -): Promise { - if (!imageId || !workerId) { - return "Error: image_id and worker_id are required"; - } - - await api.delete(`/images/${encodeURIComponent(imageId)}`, { - params: { worker_id: workerId, force: force || false }, - }); - - return JSON.stringify( - { - success: true, - message: `Image "${imageId}" deleted successfully`, - }, - null, - 2, - ); -} - -// ============================================================================ -// Storage Tools -// ============================================================================ - -async function listStorageVolumes(workerId?: number): Promise { - const params: any = {}; - if (workerId) params.worker_id = workerId; - - const response = await api.get("/storage/volumes", { params }); - const volumes = Array.isArray(response.data) ? response.data : []; - - return JSON.stringify( - { - total: volumes.length, - volumes: volumes.map((v: any) => ({ - name: v.name, - driver: v.driver, - mountpoint: v.mountpoint, - created_at: v.created_at, - worker: v.worker_name, - worker_id: v.worker_id, - })), - }, - null, - 2, - ); -} - -async function getDiskUsage(workerId?: number): Promise { - const params: any = {}; - if (workerId) params.worker_id = workerId; - - const response = await api.get("/storage/disk-usage", { params }); - const usageList = Array.isArray(response.data) ? response.data : []; - - const formatSize = (bytes: number) => { - if (bytes >= 1024 * 1024 * 1024) { - return `${(bytes / 1024 / 1024 / 1024).toFixed(2)} GB`; - } - return `${(bytes / 1024 / 1024).toFixed(2)} MB`; - }; - - return JSON.stringify( - { - workers: usageList.map((u: any) => ({ - worker: u.worker_name, - worker_id: u.worker_id, - images: { - count: u.images.count, - size: formatSize(u.images.size), - reclaimable: formatSize(u.images.reclaimable), - }, - containers: { - count: u.containers.count, - size: formatSize(u.containers.size), - reclaimable: formatSize(u.containers.reclaimable), - }, - volumes: { - count: u.volumes.count, - size: formatSize(u.volumes.size), - reclaimable: formatSize(u.volumes.reclaimable), - }, - build_cache: { - count: u.build_cache.count, - size: formatSize(u.build_cache.size), - reclaimable: formatSize(u.build_cache.reclaimable), - }, - total_size: formatSize(u.total_size), - total_reclaimable: formatSize(u.total_reclaimable), - })), - }, - null, - 2, - ); -} - -async function deleteStorageVolume( - volumeName: string, - workerId: number, - force?: boolean, -): Promise { - if (!volumeName || !workerId) { - return "Error: volume_name and worker_id are required"; - } - - await api.delete(`/storage/volumes/${encodeURIComponent(volumeName)}`, { - params: { worker_id: workerId, force: force || false }, - }); - - return JSON.stringify( - { - success: true, - message: `Volume "${volumeName}" deleted successfully`, - }, - null, - 2, - ); -} - -async function pruneStorage( - workerId?: number, - images: boolean = true, - containers: boolean = true, - volumes: boolean = false, - buildCache: boolean = true, -): Promise { - const params: any = {}; - if (workerId) params.worker_id = workerId; - - const response = await api.post( - "/storage/prune", - { - images, - containers, - volumes, - build_cache: buildCache, - }, - { params }, - ); - - const results = Array.isArray(response.data) ? response.data : []; - - const formatSize = (bytes: number) => { - if (bytes >= 1024 * 1024 * 1024) { - return `${(bytes / 1024 / 1024 / 1024).toFixed(2)} GB`; - } - return `${(bytes / 1024 / 1024).toFixed(2)} MB`; - }; - - return JSON.stringify( - { - success: true, - message: "Storage pruned successfully", - results: results.map((r: any) => ({ - worker: r.worker_name, - images_deleted: r.images_deleted, - containers_deleted: r.containers_deleted, - volumes_deleted: r.volumes_deleted, - build_cache_deleted: r.build_cache_deleted, - space_reclaimed: formatSize(r.space_reclaimed), - })), - }, - null, - 2, - ); -} - -// ============================================================================ -// Auto-Tuning Tools -// ============================================================================ - -async function listTuningJobs(status?: string): Promise { - const response = await api.get("/auto-tuning/jobs"); - let jobs = response.data.items || []; - - if (status) { - jobs = jobs.filter((j: any) => j.status === status); - } - - return JSON.stringify( - { - total: jobs.length, - jobs: jobs.map((j: any) => ({ - id: j.id, - model: j.model_name, - worker: j.worker_name, - optimization_target: j.optimization_target, - status: j.status, - progress: j.progress - ? { - step: j.progress.step, - total_steps: j.progress.total_steps, - step_name: j.progress.step_name, - configs_tested: j.progress.configs_tested, - configs_total: j.progress.configs_total, - best_score: j.progress.best_score_so_far, - } - : null, - created_at: j.created_at, - })), - note: "View detailed results at [Auto-Tuning](/auto-tuning) page", - }, - null, - 2, - ); -} - -async function startAutoTuning( - modelId: number, - workerId: number, - optimizationTarget?: string, - modelConfig?: ChatModelConfig, -): Promise { - if (!modelId || !workerId) { - return "Error: model_id and worker_id are required"; - } - - // Build the LLM configuration for the agent - const llmConfig: Record = {}; - if (modelConfig) { - if (modelConfig.type === "deployment" && modelConfig.deploymentId) { - // Use local deployment - llmConfig.deployment_id = modelConfig.deploymentId; - } else if (modelConfig.type === "custom" && modelConfig.endpoint) { - // Use custom endpoint - llmConfig.base_url = modelConfig.endpoint; - llmConfig.api_key = modelConfig.apiKey || ""; - llmConfig.model = modelConfig.modelId || modelConfig.name; - } - } - - const response = await api.post("/auto-tuning/jobs", { - model_id: modelId, - worker_id: workerId, - optimization_target: optimizationTarget || "balanced", - llm_config: Object.keys(llmConfig).length > 0 ? llmConfig : undefined, - }); - - const job = response.data; - - return JSON.stringify( - { - success: true, - message: - "Auto-tuning job started! The agent will use the selected model for reasoning.", - job: { - id: job.id, - model: job.model_name, - worker: job.worker_name, - optimization_target: job.optimization_target, - status: job.status, - agent_llm: modelConfig?.name || "auto-detected", - }, - note: "Track progress at [Auto-Tuning](/auto-tuning) page. This may take several minutes depending on the number of configurations to test.", - }, - null, - 2, - ); -} - -async function getTuningJob(jobId: number): Promise { - if (!jobId) { - return "Error: job_id is required"; - } - - const response = await api.get(`/auto-tuning/jobs/${jobId}`); - const job = response.data; - - return JSON.stringify( - { - id: job.id, - model: job.model_name, - worker: job.worker_name, - optimization_target: job.optimization_target, - status: job.status, - status_message: job.status_message, - progress: job.progress - ? { - step: job.progress.step, - total_steps: job.progress.total_steps, - step_name: job.progress.step_name, - step_description: job.progress.step_description, - configs_tested: job.progress.configs_tested, - configs_total: job.progress.configs_total, - current_config: job.progress.current_config, - best_config_so_far: job.progress.best_config_so_far, - best_score_so_far: job.progress.best_score_so_far, - } - : null, - best_config: job.best_config, - all_results: job.all_results, - created_at: job.created_at, - completed_at: job.completed_at, - }, - null, - 2, - ); -} - -async function cancelTuningJob(jobId: number): Promise { - if (!jobId) { - return "Error: job_id is required"; - } - - await api.post(`/auto-tuning/jobs/${jobId}/cancel`); - - return JSON.stringify( - { - success: true, - message: `Tuning job ${jobId} cancelled successfully`, - }, - null, - 2, - ); -} - -async function queryKnowledgeBase( - modelName?: string, - modelFamily?: string, - gpuModel?: string, - optimizationTarget?: string, - limit?: number, -): Promise { - const response = await api.post("/auto-tuning/knowledge/query", { - model_name: modelName, - model_family: modelFamily, - gpu_model: gpuModel, - optimization_target: optimizationTarget || "balanced", - limit: limit || 10, - }); - - const data = response.data; - - return JSON.stringify( - { - total: data.total, - query: data.query, - results: data.items.map((r: any) => ({ - gpu: `${r.gpu_count}x ${r.gpu_model}`, - total_vram_gb: r.total_vram_gb, - model: r.model_name, - model_family: r.model_family, - model_params_b: r.model_params_b, - engine: r.engine, - quantization: r.quantization, - tensor_parallel: r.tensor_parallel, - throughput_tps: r.throughput_tps, - ttft_ms: r.ttft_ms, - tpot_ms: r.tpot_ms, - score: r.score, - })), - note: "These are benchmark results from similar configurations. Use this to guide your deployment decisions.", - }, - null, - 2, - ); -} - -async function runBenchmark( - deploymentId: number, - testType?: string, - durationSeconds?: number, - inputLength?: number, - outputLength?: number, - concurrency?: number, -): Promise { - if (!deploymentId) { - return "Error: deployment_id is required"; - } - - const response = await api.post("/auto-tuning/benchmarks/run", { - deployment_id: deploymentId, - test_type: testType || "throughput", - duration_seconds: durationSeconds || 60, - input_length: inputLength || 512, - output_length: outputLength || 128, - concurrency: concurrency || 1, - }); - - const result = response.data; - - return JSON.stringify( - { - success: true, - message: "Benchmark completed", - benchmark: { - id: result.id, - deployment_id: result.deployment_id, - test_type: result.test_type, - duration_seconds: result.test_duration_seconds, - config: { - input_length: result.input_length, - output_length: result.output_length, - concurrency: result.concurrency, - }, - metrics: { - throughput_tps: result.metrics?.throughput_tps, - ttft_ms: result.metrics?.ttft_ms, - tpot_ms: result.metrics?.tpot_ms, - total_latency_ms: result.metrics?.total_latency_ms, - gpu_utilization: result.metrics?.gpu_utilization, - vram_usage_gb: result.metrics?.vram_usage_gb, - }, - }, - note: "Higher throughput (TPS) is better. Lower latency (TTFT, TPOT) is better.", - }, - null, - 2, - ); -} diff --git a/frontend/src/components/chat-panel/useChat.ts b/frontend/src/components/chat-panel/useChat.ts deleted file mode 100644 index 4b3353e..0000000 --- a/frontend/src/components/chat-panel/useChat.ts +++ /dev/null @@ -1,762 +0,0 @@ -/** - * useChat Hook - * - * Encapsulates chat logic for streaming conversations with LLM endpoints. - * Supports deployments, Semantic Router, and custom OpenAI-compatible endpoints. - * Includes system context injection for AI assistant capabilities. - * Supports Tool Calling for LLM to interact with LMStack system. - */ -import { useState, useRef, useCallback, useEffect } from "react"; -import { message } from "antd"; -import { generateMessageId } from "../chat"; -import type { ChatMessage } from "../chat"; -import type { ChatModelConfig } from "./types"; -import { STORAGE_KEYS } from "../../constants"; -import { - fetchSystemContext, - formatSystemPrompt, - type SystemContext, -} from "./systemContext"; -import { - CHAT_TOOLS, - executeTool, - requiresConfirmation, - getToolMeta, - type ToolCall, - type ToolResult, - type PendingToolExecution, -} from "./tools"; - -interface UseChatOptions { - /** Called when a new message is added */ - onMessageAdded?: (message: ChatMessage) => void; - /** Called when streaming completes */ - onStreamComplete?: (userMsg: ChatMessage, assistantMsg: ChatMessage) => void; -} - -interface UseChatReturn { - messages: ChatMessage[]; - isStreaming: boolean; - isExecutingTool: boolean; - currentToolName: string | null; - pendingTools: PendingToolExecution[]; - showConfirmModal: boolean; - systemContext: SystemContext | null; - refreshContext: () => Promise; - sendMessage: (content: string, model: ChatModelConfig) => Promise; - stopStreaming: () => void; - clearMessages: () => void; - setMessages: React.Dispatch>; - confirmToolExecution: () => void; - cancelToolExecution: () => void; -} - -/** - * Hook for managing chat state and streaming - */ -export function useChat(options: UseChatOptions = {}): UseChatReturn { - const { onMessageAdded, onStreamComplete } = options; - - const [messages, setMessages] = useState([]); - const [isStreaming, setIsStreaming] = useState(false); - const [isExecutingTool, setIsExecutingTool] = useState(false); - const [currentToolName, setCurrentToolName] = useState(null); - const [systemContext, setSystemContext] = useState( - null, - ); - - // Tool confirmation state - const [pendingTools, setPendingTools] = useState([]); - const [showConfirmModal, setShowConfirmModal] = useState(false); - const pendingToolResolveRef = useRef<((confirmed: boolean) => void) | null>( - null, - ); - - const abortControllerRef = useRef(null); - - // Fetch system context on mount and periodically refresh - const refreshContext = useCallback(async () => { - const context = await fetchSystemContext(); - setSystemContext(context); - }, []); - - useEffect(() => { - refreshContext(); - const interval = setInterval(refreshContext, 30000); // Refresh every 30s - return () => clearInterval(interval); - }, [refreshContext]); - - /** - * Check if model uses proxy endpoint - */ - const isProxyRequest = useCallback((model: ChatModelConfig): boolean => { - return model.type === "custom"; - }, []); - - /** - * Get the chat endpoint URL based on model config - */ - const getEndpointUrl = useCallback((model: ChatModelConfig): string => { - switch (model.type) { - case "deployment": - return `/api/deployments/${model.deploymentId}/chat`; - case "semantic-router": - return `/api/semantic-router/chat`; - case "custom": - // Use backend proxy to avoid CORS issues - return `/api/chat-proxy`; - default: - return ""; - } - }, []); - - /** - * Get headers for the request - */ - const getHeaders = useCallback((): HeadersInit => { - const headers: HeadersInit = { - "Content-Type": "application/json", - }; - - const token = localStorage.getItem(STORAGE_KEYS.TOKEN); - if (token) { - headers["Authorization"] = `Bearer ${token}`; - } - - return headers; - }, []); - - /** - * Stream a chat completion request - */ - const streamChatCompletion = useCallback( - async ( - endpoint: string, - requestBody: any, - assistantMessageId: string, - signal: AbortSignal, - ): Promise<{ - content: string; - thinking: string; - model?: string; - toolCalls?: ToolCall[]; - finishReason?: string; - }> => { - const response = await fetch(endpoint, { - method: "POST", - headers: getHeaders(), - body: JSON.stringify(requestBody), - signal, - }); - - if (!response.ok) { - throw new Error(`API error: ${response.status} ${response.statusText}`); - } - - const reader = response.body?.getReader(); - if (!reader) throw new Error("No response body"); - - const decoder = new TextDecoder(); - let accumulatedContent = ""; - let accumulatedThinking = ""; - let responseModel: string | undefined; - let buffer = ""; - let finishReason: string | undefined; - - // Track tool calls being accumulated - const toolCallsMap: Map< - number, - { id: string; name: string; arguments: string } - > = new Map(); - - // eslint-disable-next-line no-constant-condition - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - buffer += decoder.decode(value, { stream: true }); - const lines = buffer.split("\n"); - buffer = lines.pop() || ""; - - for (const line of lines) { - const trimmedLine = line.trim(); - if (!trimmedLine || !trimmedLine.startsWith("data:")) continue; - - const data = trimmedLine.slice(5).trim(); - if (data === "[DONE]") continue; - - try { - const parsed = JSON.parse(data); - - if (parsed.error) { - throw new Error(parsed.error.message || "API error"); - } - - const choice = parsed.choices?.[0]; - const delta = choice?.delta; - - // Track finish reason - if (choice?.finish_reason) { - finishReason = choice.finish_reason; - } - - // Handle regular content - const deltaContent = delta?.content || ""; - accumulatedContent += deltaContent; - - // Handle thinking/reasoning content (for models like DeepSeek-R1) - const deltaThinking = delta?.reasoning_content || ""; - accumulatedThinking += deltaThinking; - - // Handle tool calls streaming - if (delta?.tool_calls) { - for (const tc of delta.tool_calls) { - const index = tc.index ?? 0; - if (!toolCallsMap.has(index)) { - toolCallsMap.set(index, { - id: tc.id || "", - name: "", - arguments: "", - }); - } - const existing = toolCallsMap.get(index)!; - if (tc.id) existing.id = tc.id; - if (tc.function?.name) existing.name = tc.function.name; - if (tc.function?.arguments) - existing.arguments += tc.function.arguments; - } - } - - if (!responseModel && parsed.model) { - responseModel = parsed.model; - } - - setMessages((prev) => - prev.map((m) => - m.id === assistantMessageId - ? { - ...m, - content: accumulatedContent, - thinking: accumulatedThinking || undefined, - model: responseModel, - } - : m, - ), - ); - } catch { - // Skip invalid JSON - } - } - } - - // Process remaining buffer - if (buffer.trim().startsWith("data:")) { - const data = buffer.trim().slice(5).trim(); - if (data !== "[DONE]") { - try { - const parsed = JSON.parse(data); - const choice = parsed.choices?.[0]; - const delta = choice?.delta; - if (choice?.finish_reason) finishReason = choice.finish_reason; - const deltaContent = delta?.content || ""; - const deltaThinking = delta?.reasoning_content || ""; - accumulatedContent += deltaContent; - accumulatedThinking += deltaThinking; - if (!responseModel && parsed.model) { - responseModel = parsed.model; - } - } catch { - // Skip invalid JSON - } - } - } - - // Convert tool calls map to array - const toolCalls: ToolCall[] = []; - for (const [, tc] of toolCallsMap) { - if (tc.id && tc.name) { - toolCalls.push({ - id: tc.id, - type: "function", - function: { name: tc.name, arguments: tc.arguments }, - }); - } - } - - return { - content: accumulatedContent, - thinking: accumulatedThinking, - model: responseModel, - toolCalls: toolCalls.length > 0 ? toolCalls : undefined, - finishReason, - }; - }, - [getHeaders], - ); - - /** - * Send a message and stream the response with tool calling support - */ - const sendMessage = useCallback( - async (content: string, model: ChatModelConfig) => { - if (!content.trim() || isStreaming) return; - - const endpoint = getEndpointUrl(model); - if (!endpoint) { - message.error("Invalid endpoint configuration"); - return; - } - - setIsStreaming(true); - - const userMessage: ChatMessage = { - id: generateMessageId(), - role: "user", - content: content.trim(), - timestamp: new Date(), - }; - - const assistantMessage: ChatMessage = { - id: generateMessageId(), - role: "assistant", - content: "", - timestamp: new Date(), - }; - - setMessages((prev) => [...prev, userMessage, assistantMessage]); - onMessageAdded?.(userMessage); - - try { - abortControllerRef.current = new AbortController(); - - const modelName = model.modelId || model.name; - - // Build messages array with system context - type ChatMessagePayload = { - role: string; - content: string | null; - tool_calls?: ToolCall[]; - tool_call_id?: string; - }; - const chatMessages: ChatMessagePayload[] = []; - - // Add system prompt with current context - if (systemContext) { - chatMessages.push({ - role: "system", - content: formatSystemPrompt(systemContext), - }); - } - - // Add conversation history - chatMessages.push( - ...messages.map((m) => ({ role: m.role, content: m.content })), - { role: "user", content: content.trim() }, - ); - - // Build the chat payload with tools - const chatPayload = { - model: modelName, - messages: chatMessages, - stream: true, - temperature: 0.7, - tools: CHAT_TOOLS, - tool_choice: "auto" as const, - }; - - // Build request body - const requestBody = isProxyRequest(model) - ? { - endpoint: model.endpoint, - api_key: model.apiKey || null, - payload: chatPayload, - } - : chatPayload; - - // Track the current message being streamed to - let currentMessageId = assistantMessage.id; - - // First streaming request - let result = await streamChatCompletion( - endpoint, - requestBody, - currentMessageId, - abortControllerRef.current.signal, - ); - - // Tool calling loop - continue until no more tool calls - let iterationCount = 0; - const maxIterations = 10; // Prevent infinite loops - - while ( - result.toolCalls && - result.toolCalls.length > 0 && - iterationCount < maxIterations - ) { - iterationCount++; - - // Check if any tool requires confirmation - const toolsNeedingConfirmation = result.toolCalls.filter((tc) => - requiresConfirmation(tc.function.name), - ); - - if (toolsNeedingConfirmation.length > 0) { - // Prepare pending tools for confirmation - const pending: PendingToolExecution[] = - toolsNeedingConfirmation.map((tc) => { - let parsedArgs: Record = {}; - try { - parsedArgs = JSON.parse(tc.function.arguments); - } catch { - parsedArgs = { raw: tc.function.arguments }; - } - return { - toolCall: tc, - parsedArgs, - meta: getToolMeta(tc.function.name), - }; - }); - - setPendingTools(pending); - setShowConfirmModal(true); - - // Update message to show waiting for confirmation - setMessages((prev) => - prev.map((m) => - m.id === currentMessageId - ? { - ...m, - content: result.content || "", - toolCalls: result.toolCalls, - } - : m, - ), - ); - - // Wait for user confirmation - const confirmed = await new Promise((resolve) => { - pendingToolResolveRef.current = resolve; - }); - - setPendingTools([]); - setShowConfirmModal(false); - pendingToolResolveRef.current = null; - - if (!confirmed) { - // User cancelled - stop tool execution and inform LLM - const cancelledResults: ToolResult[] = - toolsNeedingConfirmation.map((tc) => ({ - tool_call_id: tc.id, - role: "tool" as const, - content: JSON.stringify({ - success: false, - message: "User cancelled the operation", - }), - })); - - // Execute query tools that don't need confirmation - const queryTools = result.toolCalls.filter( - (tc) => !requiresConfirmation(tc.function.name), - ); - for (const tc of queryTools) { - const queryResult = await executeTool(tc, model); - cancelledResults.push(queryResult); - } - - // Update message - setMessages((prev) => - prev.map((m) => - m.id === currentMessageId - ? { - ...m, - content: result.content || "", - toolCalls: result.toolCalls, - } - : m, - ), - ); - - // Build new messages with cancelled results - const newMessages: ChatMessagePayload[] = [ - ...chatMessages, - { - role: "assistant", - content: result.content || null, - tool_calls: result.toolCalls, - }, - ...cancelledResults.map((tr) => ({ - role: "tool", - content: tr.content, - tool_call_id: tr.tool_call_id, - })), - ]; - - // Continue to let LLM know about cancellation - const continuationMessage: ChatMessage = { - id: generateMessageId(), - role: "assistant", - content: "", - timestamp: new Date(), - }; - - setMessages((prev) => - prev - .map((m) => - m.id === currentMessageId - ? { - ...m, - content: result.content || "", - toolCalls: result.toolCalls, - } - : m, - ) - .concat(continuationMessage), - ); - - // Update current message ID to the continuation - currentMessageId = continuationMessage.id; - - const continuationPayload = { - model: modelName, - messages: newMessages, - stream: true, - temperature: 0.7, - tools: CHAT_TOOLS, - tool_choice: "auto" as const, - }; - - const continuationRequestBody = isProxyRequest(model) - ? { - endpoint: model.endpoint, - api_key: model.apiKey || null, - payload: continuationPayload, - } - : continuationPayload; - - result = await streamChatCompletion( - endpoint, - continuationRequestBody, - currentMessageId, - abortControllerRef.current.signal, - ); - - chatMessages.length = 0; - chatMessages.push(...newMessages); - continue; - } - } - - // Execute all tool calls (confirmed or query-only) - setIsExecutingTool(true); - const toolResults: ToolResult[] = []; - - for (const toolCall of result.toolCalls) { - setCurrentToolName(toolCall.function.name); - - // Update message to show tool execution - setMessages((prev) => - prev.map((m) => - m.id === currentMessageId - ? { - ...m, - content: result.content || "", - toolCalls: result.toolCalls, - } - : m, - ), - ); - - const toolResult = await executeTool(toolCall, model); - toolResults.push(toolResult); - - // Refresh system context after tool execution (data may have changed) - await refreshContext(); - } - - setIsExecutingTool(false); - setCurrentToolName(null); - - // Build new messages array with tool calls and results - const newMessages: ChatMessagePayload[] = [ - ...chatMessages, - // Assistant message with tool calls - { - role: "assistant", - content: result.content || null, - tool_calls: result.toolCalls, - }, - // Tool results - ...toolResults.map((tr) => ({ - role: "tool", - content: tr.content, - tool_call_id: tr.tool_call_id, - })), - ]; - - // Create new assistant message for continued response - const continuationMessage: ChatMessage = { - id: generateMessageId(), - role: "assistant", - content: "", - timestamp: new Date(), - }; - - setMessages((prev) => { - // Update the current assistant message and add continuation - return prev - .map((m) => - m.id === currentMessageId - ? { - ...m, - content: result.content || "", - toolCalls: result.toolCalls, - } - : m, - ) - .concat(continuationMessage); - }); - - // Update current message ID to the continuation - currentMessageId = continuationMessage.id; - - // Build new chat payload with tool results - const continuationPayload = { - model: modelName, - messages: newMessages, - stream: true, - temperature: 0.7, - tools: CHAT_TOOLS, - tool_choice: "auto" as const, - }; - - const continuationRequestBody = isProxyRequest(model) - ? { - endpoint: model.endpoint, - api_key: model.apiKey || null, - payload: continuationPayload, - } - : continuationPayload; - - // Continue streaming - result = await streamChatCompletion( - endpoint, - continuationRequestBody, - currentMessageId, - abortControllerRef.current.signal, - ); - - // Update chat messages for next iteration if needed - chatMessages.length = 0; - chatMessages.push(...newMessages); - } - - // Final update - only update the current (last) message - setMessages((prev) => - prev.map((m) => { - if (m.id === currentMessageId) { - return { - ...m, - content: result.content, - thinking: result.thinking || undefined, - model: result.model, - // Only set toolCalls if there are any (don't overwrite with undefined) - ...(result.toolCalls ? { toolCalls: result.toolCalls } : {}), - }; - } - return m; - }), - ); - - // Get the final message for callback - const finalAssistantMsg: ChatMessage = { - id: currentMessageId, - role: "assistant", - content: result.content, - thinking: result.thinking || undefined, - model: result.model, - toolCalls: result.toolCalls, - timestamp: new Date(), - }; - - onStreamComplete?.(userMessage, finalAssistantMsg); - } catch (error: unknown) { - const err = error as Error; - if (err.name === "AbortError") { - message.info("Generation stopped"); - } else { - message.error(`Error: ${err.message}`); - // Remove the initial assistant message on error - setMessages((prev) => - prev.filter((m) => m.id !== assistantMessage.id), - ); - } - } finally { - setIsStreaming(false); - setIsExecutingTool(false); - setCurrentToolName(null); - abortControllerRef.current = null; - } - }, - [ - messages, - isStreaming, - systemContext, - getEndpointUrl, - isProxyRequest, - getHeaders, - onMessageAdded, - onStreamComplete, - streamChatCompletion, - refreshContext, - ], - ); - - /** - * Stop the current streaming response - */ - const stopStreaming = useCallback(() => { - abortControllerRef.current?.abort(); - }, []); - - /** - * Clear all messages - */ - const clearMessages = useCallback(() => { - setMessages([]); - }, []); - - /** - * Confirm pending tool execution - */ - const confirmToolExecution = useCallback(() => { - if (pendingToolResolveRef.current) { - pendingToolResolveRef.current(true); - } - }, []); - - /** - * Cancel pending tool execution - */ - const cancelToolExecution = useCallback(() => { - if (pendingToolResolveRef.current) { - pendingToolResolveRef.current(false); - } - }, []); - - return { - messages, - isStreaming, - isExecutingTool, - currentToolName, - pendingTools, - showConfirmModal, - systemContext, - refreshContext, - sendMessage, - stopStreaming, - clearMessages, - setMessages, - confirmToolExecution, - cancelToolExecution, - }; -} diff --git a/frontend/src/components/logos/index.tsx b/frontend/src/components/logos/index.tsx index 2fba673..ff49e6b 100644 --- a/frontend/src/components/logos/index.tsx +++ b/frontend/src/components/logos/index.tsx @@ -71,6 +71,62 @@ export function HuggingFaceLogo({ ); } +/** + * MLX Logo - Apple's ML framework for Apple Silicon + */ +export function MLXLogo({ height = 16, style }: Omit) { + // Use Apple-style gradient colors + const gradientId = `mlx-gradient-${Math.random().toString(36).substr(2, 9)}`; + return ( + + + + + + + + + + MLX + + + ); +} + +/** + * Llama.cpp Logo + */ +export function LlamaCppLogo({ + height = 16, + isDark = false, + style, +}: LogoProps) { + const textColor = isDark ? "#ffffff" : "#333333"; + return ( + + + llama.cpp + + + ); +} + // ============================================================================= // Icons // ============================================================================= @@ -129,5 +185,15 @@ export function getBackendConfig( color: tagColor, icon: , }, + mlx: { + label: "MLX", + color: tagColor, + icon: , + }, + llama_cpp: { + label: "llama.cpp", + color: tagColor, + icon: , + }, }; } diff --git a/frontend/src/pages/AutoTuning.tsx b/frontend/src/pages/AutoTuning.tsx index e1bdfd5..0bdedfb 100644 --- a/frontend/src/pages/AutoTuning.tsx +++ b/frontend/src/pages/AutoTuning.tsx @@ -920,7 +920,7 @@ export default function AutoTuning() { } + prefix={} /> @@ -931,9 +931,9 @@ export default function AutoTuning() { value={runningJobs} prefix={ runningJobs > 0 ? ( - + ) : ( - + ) } /> @@ -944,7 +944,7 @@ export default function AutoTuning() { } + prefix={} /> @@ -1214,7 +1214,18 @@ export default function AutoTuning() { - + {logModal && + [ + "pending", + "analyzing", + "querying_kb", + "exploring", + "benchmarking", + ].includes(logModal.status) ? ( + + ) : ( + + )} Live Logs {logModal && ( <> diff --git a/frontend/src/pages/Deployments.tsx b/frontend/src/pages/Deployments.tsx index bfe0b5f..922e529 100644 --- a/frontend/src/pages/Deployments.tsx +++ b/frontend/src/pages/Deployments.tsx @@ -1,6 +1,7 @@ import { useEffect, useState, useCallback, useRef } from "react"; import { useSearchParams } from "react-router-dom"; import { + Alert, Button, Card, Form, @@ -30,9 +31,7 @@ import { } from "@ant-design/icons"; import { useAppTheme } from "../hooks/useTheme"; import { - VllmLogo, OllamaLogo, - SGLangLogo, HuggingFaceLogo, DockerIcon, getBackendConfig, @@ -99,14 +98,41 @@ export default function Deployments() { // Get the selected model const selectedModel = models.find((m) => m.id === selectedModelId); - // Determine available backends based on model source - const availableBackends = - selectedModel?.source === "ollama" - ? (["ollama"] as const) - : (["vllm", "sglang"] as const); - // Get the selected worker's GPU info const selectedWorker = workers.find((w) => w.id === selectedWorkerId); + + // Determine available backends based on model source and worker capabilities + const availableBackends = (() => { + // Start with model-based restrictions + if (selectedModel?.source === "ollama") { + return ["ollama"] as const; + } + + // If no worker selected, show all HuggingFace-compatible backends + if (!selectedWorker) { + return ["vllm", "sglang"] as const; + } + + // macOS workers only support Ollama (vLLM/SGLang require NVIDIA GPU) + if (selectedWorker.os_type === "darwin") { + return ["ollama"] as const; + } + + // Use worker's available_backends if provided + if ( + selectedWorker.available_backends && + selectedWorker.available_backends.length > 0 + ) { + // Filter for HuggingFace-compatible backends from worker's list + const hfBackends = selectedWorker.available_backends.filter((b) => + ["vllm", "sglang", "ollama"].includes(b), + ); + return hfBackends.length > 0 ? hfBackends : (["vllm", "sglang"] as const); + } + + // Default fallback for Linux workers + return ["vllm", "sglang"] as const; + })(); const workerGpus = selectedWorker?.gpu_info || []; // Calculate total available GPU memory for selected GPUs @@ -351,11 +377,7 @@ export default function Deployments() { {record.model?.source !== "ollama" && renderSourceTag(record.model?.source, "small")} - {backend === "vllm" && } - {backend === "sglang" && } - {backend === "ollama" && ( - - )} + {config.icon} {config.label} {record.model?.name} @ {record.worker?.name} @@ -543,11 +565,7 @@ export default function Deployments() { {record.model?.source !== "ollama" && renderSourceTag(record.model?.source, "small")} - {backend === "vllm" && } - {backend === "sglang" && } - {backend === "ollama" && ( - - )} + {config.icon} {config.label} @@ -836,7 +854,9 @@ export default function Deployments() { extra={ selectedModel?.source === "ollama" ? "Ollama models can only use Ollama backend" - : "HuggingFace models can use vLLM or SGLang" + : selectedWorker?.os_type === "darwin" + ? "macOS workers only support Ollama backend" + : "HuggingFace models can use vLLM or SGLang" } >