From 55e556f18466102c0a7ee1e21f9159b3c37d2d71 Mon Sep 17 00:00:00 2001 From: Caitlyn Byrne Date: Sun, 8 Feb 2026 20:03:54 -0500 Subject: [PATCH 1/2] feat: add per-task-type model assignment with three-tier system Allows assigning different AI models to different task types (initializer, coding, testing, spec creation, expansion, assistant, log review). Each provider defines three tiers (high/mid/low) with smart defaults -- e.g., Opus for initialization, Sonnet for coding, Haiku for testing. Users can override per-role in Settings, and the assistant chat panel gets an on-the-fly tier switcher. - registry.py: ROLE_TIER_MAP, provider tiers, get_model_for_role() - parallel_orchestrator.py: per-role model resolution in spawn methods - Chat sessions: use get_model_for_role() per session type - Settings API: read/write per-role overrides, expose tiers - UI: collapsible "Model per Task" section, assistant tier toggle Co-Authored-By: Claude Opus 4.6 --- parallel_orchestrator.py | 32 +++++-- registry.py | 104 +++++++++++++++++++++- server/routers/assistant_chat.py | 22 +++++ server/routers/settings.py | 35 +++++++- server/schemas.py | 14 +++ server/services/assistant_chat_session.py | 76 +++++++++++++++- server/services/expand_chat_session.py | 7 +- server/services/spec_chat_session.py | 7 +- ui/src/components/AssistantChat.tsx | 42 ++++++++- ui/src/components/SettingsModal.tsx | 62 ++++++++++++- ui/src/hooks/useAssistantChat.ts | 44 ++++++++- ui/src/hooks/useProjects.ts | 21 +++-- ui/src/lib/types.ts | 13 +++ 13 files changed, 444 insertions(+), 35 deletions(-) diff --git a/parallel_orchestrator.py b/parallel_orchestrator.py index 856e33cb..4f46de84 100644 --- a/parallel_orchestrator.py +++ b/parallel_orchestrator.py @@ -228,6 +228,18 @@ def get_session(self): """Get a new database session.""" return self._session_maker() + def _get_model_for_role(self, role: str) -> str | None: + """Resolve the model for a given agent role. + + Uses the per-role model resolution from registry, with self.model + as the CLI override (highest priority). + + Returns: + Model ID string, or None if the default should be used. + """ + from registry import get_model_for_role + return get_model_for_role(role, cli_override=self.model) + def _get_random_passing_feature(self) -> int | None: """Get a random passing feature for regression testing (no claim needed). @@ -829,8 +841,9 @@ def _spawn_coding_agent(self, feature_id: int) -> tuple[bool, str]: "--agent-type", "coding", "--feature-id", str(feature_id), ] - if self.model: - cmd.extend(["--model", self.model]) + coding_model = self._get_model_for_role("coding") + if coding_model: + cmd.extend(["--model", coding_model]) if self.yolo_mode: cmd.append("--yolo") @@ -895,8 +908,9 @@ def _spawn_coding_agent_batch(self, feature_ids: list[int]) -> tuple[bool, str]: "--agent-type", "coding", "--feature-ids", ",".join(str(fid) for fid in feature_ids), ] - if self.model: - cmd.extend(["--model", self.model]) + coding_model = self._get_model_for_role("coding") + if coding_model: + cmd.extend(["--model", coding_model]) if self.yolo_mode: cmd.append("--yolo") @@ -998,8 +1012,9 @@ def _spawn_testing_agent(self) -> tuple[bool, str]: "--agent-type", "testing", "--testing-feature-ids", batch_str, ] - if self.model: - cmd.extend(["--model", self.model]) + testing_model = self._get_model_for_role("testing") + if testing_model: + cmd.extend(["--model", testing_model]) try: # CREATE_NO_WINDOW on Windows prevents console window pop-ups @@ -1058,8 +1073,9 @@ async def _run_initializer(self) -> bool: "--agent-type", "initializer", "--max-iterations", "1", ] - if self.model: - cmd.extend(["--model", self.model]) + init_model = self._get_model_for_role("initializer") + if init_model: + cmd.extend(["--model", init_model]) print("Running initializer agent...", flush=True) diff --git a/registry.py b/registry.py index 30765198..0ca7cc8d 100644 --- a/registry.py +++ b/registry.py @@ -48,6 +48,7 @@ def _migrate_registry_dir() -> None: AVAILABLE_MODELS = [ {"id": "claude-opus-4-6", "name": "Claude Opus"}, {"id": "claude-sonnet-4-5-20250929", "name": "Claude Sonnet"}, + {"id": "claude-haiku-4-5-20251001", "name": "Claude Haiku"}, ] # Map legacy model IDs to their current replacements. @@ -73,6 +74,23 @@ def _migrate_registry_dir() -> None: VALID_MODELS.append(DEFAULT_MODEL) DEFAULT_YOLO_MODE = False +# ============================================================================= +# Role-to-Tier Mapping +# ============================================================================= + +# Each role maps to a tier (high/mid/low) which resolves to a provider-specific model +ROLE_TIER_MAP: dict[str, str] = { + "initializer": "high", + "coding": "mid", + "testing": "low", + "spec_creation": "high", + "expand": "high", + "assistant": "mid", + "log_review": "low", +} +VALID_ROLES = list(ROLE_TIER_MAP.keys()) +ROLE_MODEL_KEY_PREFIX = "model_" + # SQLite connection settings SQLITE_TIMEOUT = 30 # seconds to wait for database lock SQLITE_MAX_RETRIES = 3 # number of retry attempts on busy database @@ -617,9 +635,12 @@ def get_all_settings() -> dict[str, str]: settings = session.query(Settings).all() result = {s.key: s.value for s in settings} - # Auto-migrate legacy model IDs + # Auto-migrate legacy model IDs (including per-role model_* keys) migrated = False - for key in ("model", "api_model"): + migration_keys = ["model", "api_model"] + [ + f"{ROLE_MODEL_KEY_PREFIX}{role}" for role in VALID_ROLES + ] + for key in migration_keys: old_id = result.get(key) if old_id and old_id in LEGACY_MODEL_MAP: new_id = LEGACY_MODEL_MAP[old_id] @@ -642,6 +663,75 @@ def get_all_settings() -> dict[str, str]: return {} +def get_model_for_role(role: str, cli_override: str | None = None) -> str: + """Resolve the model to use for a given role. + + Resolution chain: + 1. cli_override (--model flag) — universal override + 2. settings["model_{role}"] — per-role override from Settings UI + 3. Provider tier default — provider.tiers[ROLE_TIER_MAP[role]] + 4. settings["api_model"] — global model override + 5. DEFAULT_MODEL — absolute fallback + + Args: + role: One of VALID_ROLES (e.g. "coding", "testing", "assistant") + cli_override: Model passed via --model CLI flag + + Returns: + Model ID string + """ + if cli_override: + return cli_override + + all_settings = get_all_settings() + + # Per-role setting + role_model = all_settings.get(f"{ROLE_MODEL_KEY_PREFIX}{role}") + if role_model: + return role_model + + # Provider tier default + provider_id = all_settings.get("api_provider", "claude") + provider = API_PROVIDERS.get(provider_id, {}) + tier = ROLE_TIER_MAP.get(role, "mid") + tier_model = provider.get("tiers", {}).get(tier) + if tier_model: + return tier_model + + # Global fallback + return all_settings.get("api_model") or all_settings.get("model") or DEFAULT_MODEL + + +def get_model_for_tier(tier: str) -> str: + """Resolve the model for a specific tier (high/mid/low). + + Used by the assistant chat session for dynamic tier switching. + + Args: + tier: One of "high", "mid", "low" + + Returns: + Model ID string + """ + all_settings = get_all_settings() + provider_id = all_settings.get("api_provider", "claude") + provider = API_PROVIDERS.get(provider_id, {}) + tier_model = provider.get("tiers", {}).get(tier) + if tier_model: + return tier_model + return all_settings.get("api_model") or all_settings.get("model") or DEFAULT_MODEL + + +def get_all_role_models() -> dict[str, str | None]: + """Get all per-role model overrides from settings. + + Returns: + Dict mapping role names to model IDs (None if no override set). + """ + all_settings = get_all_settings() + return {role: all_settings.get(f"{ROLE_MODEL_KEY_PREFIX}{role}") or None for role in VALID_ROLES} + + # ============================================================================= # API Provider Definitions # ============================================================================= @@ -654,8 +744,14 @@ def get_all_settings() -> dict[str, str]: "models": [ {"id": "claude-opus-4-6", "name": "Claude Opus"}, {"id": "claude-sonnet-4-5-20250929", "name": "Claude Sonnet"}, + {"id": "claude-haiku-4-5-20251001", "name": "Claude Haiku"}, ], "default_model": "claude-opus-4-6", + "tiers": { + "high": "claude-opus-4-6", + "mid": "claude-sonnet-4-5-20250929", + "low": "claude-haiku-4-5-20251001", + }, }, "kimi": { "name": "Kimi K2.5 (Moonshot)", @@ -664,6 +760,7 @@ def get_all_settings() -> dict[str, str]: "auth_env_var": "ANTHROPIC_API_KEY", "models": [{"id": "kimi-k2.5", "name": "Kimi K2.5"}], "default_model": "kimi-k2.5", + "tiers": {"high": "kimi-k2.5", "mid": "kimi-k2.5", "low": "kimi-k2.5"}, }, "glm": { "name": "GLM (Zhipu AI)", @@ -675,6 +772,7 @@ def get_all_settings() -> dict[str, str]: {"id": "glm-4.5-air", "name": "GLM 4.5 Air"}, ], "default_model": "glm-4.7", + "tiers": {"high": "glm-4.7", "mid": "glm-4.5-air", "low": "glm-4.5-air"}, }, "ollama": { "name": "Ollama (Local)", @@ -685,6 +783,7 @@ def get_all_settings() -> dict[str, str]: {"id": "deepseek-coder-v2", "name": "DeepSeek Coder V2"}, ], "default_model": "qwen3-coder", + "tiers": {"high": "qwen3-coder", "mid": "qwen3-coder", "low": "qwen3-coder"}, }, "custom": { "name": "Custom Provider", @@ -693,6 +792,7 @@ def get_all_settings() -> dict[str, str]: "auth_env_var": "ANTHROPIC_AUTH_TOKEN", "models": [], "default_model": "", + "tiers": {"high": "", "mid": "", "low": ""}, }, } diff --git a/server/routers/assistant_chat.py b/server/routers/assistant_chat.py index 1c3ece5c..e324248e 100644 --- a/server/routers/assistant_chat.py +++ b/server/routers/assistant_chat.py @@ -255,6 +255,22 @@ async def assistant_chat_websocket(websocket: WebSocket, project_name: str): await websocket.send_json({"type": "pong"}) continue + elif msg_type == "set_tier": + # Dynamic tier switching for assistant model quality + tier = message.get("tier", "mid") + if tier in ("high", "mid", "low") and session: + session.set_tier(tier) + await websocket.send_json({ + "type": "tier_changed", + "tier": tier, + }) + elif not session: + await websocket.send_json({ + "type": "error", + "content": "No active session. Send 'start' first." + }) + continue + elif msg_type == "start": # Get optional conversation_id to resume conversation_id = message.get("conversation_id") @@ -268,6 +284,12 @@ async def assistant_chat_websocket(websocket: WebSocket, project_name: str): project_dir, conversation_id=conversation_id, ) + + # Apply initial tier from start message if provided + initial_tier = message.get("tier") + if initial_tier and initial_tier in ("high", "mid", "low"): + session.set_tier(initial_tier) + logger.debug("Session created, starting...") # Stream the initial greeting diff --git a/server/routers/settings.py b/server/routers/settings.py index 6137c63c..0928a7b0 100644 --- a/server/routers/settings.py +++ b/server/routers/settings.py @@ -11,7 +11,15 @@ from fastapi import APIRouter -from ..schemas import ModelInfo, ModelsResponse, ProviderInfo, ProvidersResponse, SettingsResponse, SettingsUpdate +from ..schemas import ( + ModelInfo, + ModelsResponse, + ProviderInfo, + ProvidersResponse, + RoleModelAssignment, + SettingsResponse, + SettingsUpdate, +) from ..services.chat_constants import ROOT_DIR # Mimetype fix for Windows - must run before StaticFiles is mounted @@ -25,6 +33,9 @@ API_PROVIDERS, AVAILABLE_MODELS, DEFAULT_MODEL, + ROLE_MODEL_KEY_PREFIX, + VALID_ROLES, + get_all_role_models, get_all_settings, get_setting, set_setting, @@ -51,6 +62,7 @@ async def get_available_providers(): models=[ModelInfo(id=m["id"], name=m["name"]) for m in pdata.get("models", [])], default_model=pdata.get("default_model", ""), requires_auth=pdata.get("requires_auth", False), + tiers=pdata.get("tiers"), )) return ProvidersResponse(providers=providers, current=current) @@ -105,6 +117,9 @@ async def get_settings(): glm_mode = api_provider == "glm" ollama_mode = api_provider == "ollama" + role_models_dict = get_all_role_models() + role_models = RoleModelAssignment(**role_models_dict) + return SettingsResponse( yolo_mode=_parse_yolo_mode(all_settings.get("yolo_mode")), model=all_settings.get("model", DEFAULT_MODEL), @@ -117,6 +132,7 @@ async def get_settings(): api_base_url=all_settings.get("api_base_url"), api_has_auth_token=bool(all_settings.get("api_auth_token")), api_model=all_settings.get("api_model"), + role_models=role_models, ) @@ -163,12 +179,28 @@ async def update_settings(update: SettingsUpdate): if update.api_model is not None: set_setting("api_model", update.api_model) + # Per-role model overrides + if update.role_models is not None: + role_dict = update.role_models.model_dump() + for role in VALID_ROLES: + value = role_dict.get(role) + if value is not None: + key = f"{ROLE_MODEL_KEY_PREFIX}{role}" + if value == "": + # Empty string clears the override (delete the setting) + set_setting(key, "") + else: + set_setting(key, value) + # Return updated settings all_settings = get_all_settings() api_provider = all_settings.get("api_provider", "claude") glm_mode = api_provider == "glm" ollama_mode = api_provider == "ollama" + updated_role_models_dict = get_all_role_models() + updated_role_models = RoleModelAssignment(**updated_role_models_dict) + return SettingsResponse( yolo_mode=_parse_yolo_mode(all_settings.get("yolo_mode")), model=all_settings.get("model", DEFAULT_MODEL), @@ -181,4 +213,5 @@ async def update_settings(update: SettingsUpdate): api_base_url=all_settings.get("api_base_url"), api_has_auth_token=bool(all_settings.get("api_auth_token")), api_model=all_settings.get("api_model"), + role_models=updated_role_models, ) diff --git a/server/schemas.py b/server/schemas.py index 5f546e2b..0a54b65a 100644 --- a/server/schemas.py +++ b/server/schemas.py @@ -402,6 +402,7 @@ class ProviderInfo(BaseModel): models: list[ModelInfo] default_model: str requires_auth: bool = False + tiers: dict[str, str] | None = None class ProvidersResponse(BaseModel): @@ -410,6 +411,17 @@ class ProvidersResponse(BaseModel): current: str +class RoleModelAssignment(BaseModel): + """Per-role model overrides. None means use provider tier default.""" + initializer: str | None = None + coding: str | None = None + testing: str | None = None + spec_creation: str | None = None + expand: str | None = None + assistant: str | None = None + log_review: str | None = None + + class SettingsResponse(BaseModel): """Response schema for global settings.""" yolo_mode: bool = False @@ -423,6 +435,7 @@ class SettingsResponse(BaseModel): api_base_url: str | None = None api_has_auth_token: bool = False # Never expose actual token api_model: str | None = None + role_models: RoleModelAssignment | None = None class ModelsResponse(BaseModel): @@ -442,6 +455,7 @@ class SettingsUpdate(BaseModel): api_base_url: str | None = Field(None, max_length=500) api_auth_token: str | None = Field(None, max_length=500) # Write-only, never returned api_model: str | None = Field(None, max_length=200) + role_models: RoleModelAssignment | None = None @field_validator('api_base_url') @classmethod diff --git a/server/services/assistant_chat_session.py b/server/services/assistant_chat_session.py index f030aa4b..45e063f7 100755 --- a/server/services/assistant_chat_session.py +++ b/server/services/assistant_chat_session.py @@ -9,7 +9,6 @@ import json import logging -import os import shutil import sys import threading @@ -185,6 +184,8 @@ def __init__(self, project_name: str, project_dir: Path, conversation_id: Option self._client_entered: bool = False self.created_at = datetime.now() self._history_loaded: bool = False # Track if we've loaded history for resumed conversations + self.tier_override: Optional[str] = None # Dynamic tier override (high/mid/low) + self._needs_client_refresh: bool = False # Flag to recreate client on tier change async def close(self) -> None: """Clean up resources and close the Claude client.""" @@ -266,11 +267,11 @@ async def start(self) -> AsyncGenerator[dict, None]: system_cli = shutil.which("claude") # Build environment overrides for API configuration - from registry import DEFAULT_MODEL, get_effective_sdk_env + from registry import get_effective_sdk_env sdk_env = get_effective_sdk_env() - # Determine model from SDK env (provider-aware) or fallback to env/default - model = sdk_env.get("ANTHROPIC_DEFAULT_OPUS_MODEL") or os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", DEFAULT_MODEL) + # Determine model using per-role resolution (assistant -> mid tier by default) + model = self._resolve_model() try: logger.info("Creating ClaudeSDKClient...") @@ -322,6 +323,23 @@ async def start(self) -> AsyncGenerator[dict, None]: # _history_loaded stays False so send_message() will include history yield {"type": "response_done"} + def set_tier(self, tier: str) -> None: + """Set model tier override for this session. Takes effect on next message.""" + if tier in ("high", "mid", "low"): + self.tier_override = tier + self._needs_client_refresh = True + logger.info(f"Assistant tier set to '{tier}' for {self.project_name}") + + def _resolve_model(self) -> str: + """Resolve model with tier override support. + + Resolution: tier_override -> per-role setting -> provider tier default -> api_model -> DEFAULT_MODEL + """ + from registry import get_model_for_role, get_model_for_tier + if self.tier_override: + return get_model_for_tier(self.tier_override) + return get_model_for_role("assistant") + async def send_message(self, user_message: str) -> AsyncGenerator[dict, None]: """ Send user message and stream Claude's response. @@ -384,10 +402,60 @@ async def _query_claude(self, message: str) -> AsyncGenerator[dict, None]: Internal method to query Claude and stream responses. Handles tool calls and text responses. + Recreates the client if a tier change was requested. """ if not self.client: return + # Recreate client if tier was changed + if self._needs_client_refresh: + self._needs_client_refresh = False + new_model = self._resolve_model() + logger.info(f"Refreshing Claude client with model: {new_model}") + # Close existing client + if self._client_entered: + try: + await self.client.__aexit__(None, None, None) + except Exception as e: + logger.warning(f"Error closing old client during refresh: {e}") + self._client_entered = False + + # Recreate with new model + from registry import get_effective_sdk_env + system_cli = shutil.which("claude") + sdk_env = get_effective_sdk_env() + + from autoforge_paths import get_claude_assistant_settings_path + settings_file = get_claude_assistant_settings_path(self.project_dir) + + from claude_agent_sdk import ClaudeAgentOptions + self.client = ClaudeSDKClient( + options=ClaudeAgentOptions( + model=new_model, + cli_path=system_cli, + setting_sources=["project"], + allowed_tools=[*READONLY_BUILTIN_TOOLS, *ASSISTANT_FEATURE_TOOLS], + mcp_servers={ + "features": { + "command": sys.executable, + "args": ["-m", "mcp_server.feature_mcp"], + "env": { + "PROJECT_DIR": str(self.project_dir.resolve()), + "PYTHONPATH": str(ROOT_DIR.resolve()), + }, + }, + }, # type: ignore[arg-type] + permission_mode="bypassPermissions", + max_turns=100, + cwd=str(self.project_dir.resolve()), + settings=str(settings_file.resolve()), + env=sdk_env, + ) + ) + await self.client.__aenter__() + self._client_entered = True + logger.info(f"Client refreshed with model: {new_model}") + # Send message to Claude await self.client.query(message) diff --git a/server/services/expand_chat_session.py b/server/services/expand_chat_session.py index b06e9d85..b0c851e6 100644 --- a/server/services/expand_chat_session.py +++ b/server/services/expand_chat_session.py @@ -9,7 +9,6 @@ import asyncio import json import logging -import os import shutil import sys import threading @@ -154,11 +153,11 @@ async def start(self) -> AsyncGenerator[dict, None]: system_prompt = skill_content.replace("$ARGUMENTS", project_path) # Build environment overrides for API configuration - from registry import DEFAULT_MODEL, get_effective_sdk_env + from registry import get_effective_sdk_env, get_model_for_role sdk_env = get_effective_sdk_env() - # Determine model from SDK env (provider-aware) or fallback to env/default - model = sdk_env.get("ANTHROPIC_DEFAULT_OPUS_MODEL") or os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", DEFAULT_MODEL) + # Determine model using per-role resolution (expand -> high tier) + model = get_model_for_role("expand") # Build MCP servers config for feature creation mcp_servers = { diff --git a/server/services/spec_chat_session.py b/server/services/spec_chat_session.py index d3556173..9f014bc0 100644 --- a/server/services/spec_chat_session.py +++ b/server/services/spec_chat_session.py @@ -8,7 +8,6 @@ import json import logging -import os import shutil import threading from datetime import datetime @@ -140,11 +139,11 @@ async def start(self) -> AsyncGenerator[dict, None]: system_cli = shutil.which("claude") # Build environment overrides for API configuration - from registry import DEFAULT_MODEL, get_effective_sdk_env + from registry import get_effective_sdk_env, get_model_for_role sdk_env = get_effective_sdk_env() - # Determine model from SDK env (provider-aware) or fallback to env/default - model = sdk_env.get("ANTHROPIC_DEFAULT_OPUS_MODEL") or os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", DEFAULT_MODEL) + # Determine model using per-role resolution (spec_creation -> high tier) + model = get_model_for_role("spec_creation") try: self.client = ClaudeSDKClient( diff --git a/ui/src/components/AssistantChat.tsx b/ui/src/components/AssistantChat.tsx index 0592644e..d89c5d57 100644 --- a/ui/src/components/AssistantChat.tsx +++ b/ui/src/components/AssistantChat.tsx @@ -8,7 +8,8 @@ import { useState, useRef, useEffect, useCallback, useMemo } from 'react' import { Send, Loader2, Wifi, WifiOff, Plus, History } from 'lucide-react' -import { useAssistantChat } from '../hooks/useAssistantChat' +import { useAssistantChat, type AssistantTier } from '../hooks/useAssistantChat' +import { useAvailableProviders, useSettings } from '../hooks/useProjects' import { ChatMessage as ChatMessageComponent } from './ChatMessage' import { ConversationHistory } from './ConversationHistory' import { QuestionOptions } from './QuestionOptions' @@ -54,15 +55,26 @@ export function AssistantChat({ connectionStatus, conversationId: activeConversationId, currentQuestions, + currentTier, start, sendMessage, sendAnswer, + setTier, clearMessages, } = useAssistantChat({ projectName, onError: handleError, }) + // Get provider info for tier model labels + const { data: providersData } = useAvailableProviders() + const { data: settings } = useSettings() + const currentProvider = providersData?.providers?.find(p => p.id === (settings?.api_provider ?? 'claude')) + const providerTiers = currentProvider?.tiers + const hasMultipleTiers = providerTiers + ? new Set(Object.values(providerTiers)).size > 1 + : false + // Notify parent when a NEW conversation is created (not when switching to existing) // Track activeConversationId to fire callback only once when it transitions from null to a value const previousActiveConversationIdRef = useRef(activeConversationId) @@ -206,8 +218,34 @@ export function AssistantChat({ /> - {/* Connection status */} + {/* Tier toggle + Connection status */}
+ {/* Tier toggle - only show when provider has multiple distinct tier models */} + {hasMultipleTiers && providerTiers && ( +
+ {(['low', 'mid', 'high'] as AssistantTier[]).map((tier) => { + const modelId = providerTiers[tier] + const modelName = currentProvider?.models?.find(m => m.id === modelId)?.name + // Short label: take last word (e.g., "Claude Haiku" -> "Haiku") + const shortLabel = modelName?.split(' ').pop() ?? tier.toUpperCase() + return ( + + ) + })} +
+ )} + {connectionStatus === 'connected' ? ( <> diff --git a/ui/src/components/SettingsModal.tsx b/ui/src/components/SettingsModal.tsx index 0a2b9eec..d1fc4b2f 100644 --- a/ui/src/components/SettingsModal.tsx +++ b/ui/src/components/SettingsModal.tsx @@ -1,8 +1,8 @@ import { useState } from 'react' -import { Loader2, AlertCircle, Check, Moon, Sun, Eye, EyeOff, ShieldCheck } from 'lucide-react' +import { Loader2, AlertCircle, Check, Moon, Sun, Eye, EyeOff, ShieldCheck, ChevronDown, ChevronRight } from 'lucide-react' import { useSettings, useUpdateSettings, useAvailableModels, useAvailableProviders } from '../hooks/useProjects' import { useTheme, THEMES } from '../hooks/useTheme' -import type { ProviderInfo } from '../lib/types' +import type { ProviderInfo, RoleModelAssignment } from '../lib/types' import { Dialog, DialogContent, @@ -38,6 +38,7 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { const [authTokenInput, setAuthTokenInput] = useState('') const [customModelInput, setCustomModelInput] = useState('') const [customBaseUrlInput, setCustomBaseUrlInput] = useState('') + const [showRoleModels, setShowRoleModels] = useState(false) const handleYoloToggle = () => { if (settings && !updateSettings.isPending) { @@ -95,6 +96,12 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { } } + const handleRoleModelChange = (role: keyof RoleModelAssignment, modelId: string) => { + if (!updateSettings.isPending) { + updateSettings.mutate({ role_models: { [role]: modelId } }) + } + } + const providers = providersData?.providers ?? [] const models = modelsData?.models ?? [] const isSaving = updateSettings.isPending @@ -107,7 +114,7 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { return ( !open && onClose()}> - + Settings @@ -353,6 +360,55 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { )}
+ {/* Per-Role Model Assignment */} + {models.length > 1 && ( +
+ + {showRoleModels && ( +
+

+ Override the model used for each task type. Leave as "Default" to use the recommended tier-based defaults. +

+ {([ + { role: 'initializer' as const, label: 'Initialization', tier: 'high' }, + { role: 'coding' as const, label: 'Coding', tier: 'mid' }, + { role: 'testing' as const, label: 'Testing', tier: 'low' }, + { role: 'spec_creation' as const, label: 'Spec Creation', tier: 'high' }, + { role: 'expand' as const, label: 'Expansion', tier: 'high' }, + { role: 'assistant' as const, label: 'Assistant', tier: 'mid' }, + { role: 'log_review' as const, label: 'Log Review', tier: 'low' }, + ]).map(({ role, label, tier }) => { + const currentValue = settings.role_models?.[role] ?? '' + const tierDefault = currentProviderInfo?.tiers?.[tier as keyof typeof currentProviderInfo.tiers] + const tierModelName = models.find(m => m.id === tierDefault)?.name ?? tierDefault ?? 'Default' + return ( +
+ {label} + +
+ ) + })} +
+ )} +
+ )} +
{/* YOLO Mode Toggle */} diff --git a/ui/src/hooks/useAssistantChat.ts b/ui/src/hooks/useAssistantChat.ts index cb660f60..8bf6ade0 100755 --- a/ui/src/hooks/useAssistantChat.ts +++ b/ui/src/hooks/useAssistantChat.ts @@ -12,15 +12,19 @@ interface UseAssistantChatOptions { onError?: (error: string) => void; } +export type AssistantTier = "high" | "mid" | "low"; + interface UseAssistantChatReturn { messages: ChatMessage[]; isLoading: boolean; connectionStatus: ConnectionStatus; conversationId: number | null; currentQuestions: SpecQuestion[] | null; + currentTier: AssistantTier; start: (conversationId?: number | null) => void; sendMessage: (content: string) => void; sendAnswer: (answers: Record) => void; + setTier: (tier: AssistantTier) => void; disconnect: () => void; clearMessages: () => void; } @@ -39,6 +43,15 @@ export function useAssistantChat({ useState("disconnected"); const [conversationId, setConversationId] = useState(null); const [currentQuestions, setCurrentQuestions] = useState(null); + const [currentTier, setCurrentTier] = useState(() => { + try { + const stored = localStorage.getItem(`assistant-tier-${projectName}`); + if (stored && (stored === "high" || stored === "mid" || stored === "low")) { + return stored as AssistantTier; + } + } catch { /* ignore */ } + return "mid"; + }); const wsRef = useRef(null); const currentAssistantMessageRef = useRef(null); @@ -273,6 +286,12 @@ export function useAssistantChat({ // Keep-alive response, nothing to do break; } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + case "tier_changed" as any: { + // Server confirmed tier change, nothing more to do + break; + } } } catch (e) { console.error("Failed to parse WebSocket message:", e); @@ -295,8 +314,9 @@ export function useAssistantChat({ if (wsRef.current?.readyState === WebSocket.OPEN) { checkAndSendTimeoutRef.current = null; setIsLoading(true); - const payload: { type: string; conversation_id?: number } = { + const payload: { type: string; conversation_id?: number; tier?: string } = { type: "start", + tier: currentTier, }; if (existingConversationId) { payload.conversation_id = existingConversationId; @@ -315,7 +335,7 @@ export function useAssistantChat({ checkAndSendTimeoutRef.current = window.setTimeout(checkAndSend, 100); }, - [connect], + [connect, currentTier], ); const sendMessage = useCallback( @@ -392,6 +412,24 @@ export function useAssistantChat({ [onError], ); + const setTier = useCallback( + (tier: AssistantTier) => { + setCurrentTier(tier); + localStorage.setItem(`assistant-tier-${projectName}`, tier); + + // Send to server if connected + if (wsRef.current?.readyState === WebSocket.OPEN) { + wsRef.current.send( + JSON.stringify({ + type: "set_tier", + tier, + }), + ); + } + }, + [projectName], + ); + const disconnect = useCallback(() => { reconnectAttempts.current = maxReconnectAttempts; // Prevent reconnection if (pingIntervalRef.current) { @@ -416,9 +454,11 @@ export function useAssistantChat({ connectionStatus, conversationId, currentQuestions, + currentTier, start, sendMessage, sendAnswer, + setTier, disconnect, clearMessages, }; diff --git a/ui/src/hooks/useProjects.ts b/ui/src/hooks/useProjects.ts index f69d90f9..e46de561 100644 --- a/ui/src/hooks/useProjects.ts +++ b/ui/src/hooks/useProjects.ts @@ -272,6 +272,7 @@ const DEFAULT_SETTINGS: Settings = { api_base_url: null, api_has_auth_token: false, api_model: null, + role_models: null, } const DEFAULT_PROVIDERS: ProvidersResponse = { @@ -324,11 +325,21 @@ export function useUpdateSettings() { const previous = queryClient.getQueryData(['settings']) // Optimistically update - queryClient.setQueryData(['settings'], (old) => ({ - ...DEFAULT_SETTINGS, - ...old, - ...newSettings, - })) + queryClient.setQueryData(['settings'], (old) => { + const base = { ...DEFAULT_SETTINGS, ...old } + const { role_models: newRoleModels, ...rest } = newSettings + const merged = { ...base, ...rest } + // Merge role_models carefully (partial update into full object) + if (newRoleModels) { + merged.role_models = { + initializer: null, coding: null, testing: null, + spec_creation: null, expand: null, assistant: null, log_review: null, + ...base.role_models, + ...newRoleModels, + } + } + return merged + }) return { previous } }, diff --git a/ui/src/lib/types.ts b/ui/src/lib/types.ts index ba8eab94..86661c31 100644 --- a/ui/src/lib/types.ts +++ b/ui/src/lib/types.ts @@ -538,6 +538,7 @@ export interface ProviderInfo { models: ModelInfo[] default_model: string requires_auth: boolean + tiers?: { high: string; mid: string; low: string } } export interface ProvidersResponse { @@ -545,6 +546,16 @@ export interface ProvidersResponse { current: string } +export interface RoleModelAssignment { + initializer: string | null + coding: string | null + testing: string | null + spec_creation: string | null + expand: string | null + assistant: string | null + log_review: string | null +} + export interface Settings { yolo_mode: boolean model: string @@ -557,6 +568,7 @@ export interface Settings { api_base_url: string | null api_has_auth_token: boolean api_model: string | null + role_models: RoleModelAssignment | null } export interface SettingsUpdate { @@ -569,6 +581,7 @@ export interface SettingsUpdate { api_base_url?: string api_auth_token?: string api_model?: string + role_models?: Partial } export interface ProjectSettingsUpdate { From 304fec5ced7ef2e1ce5ad5671f91ac4bd014bc5a Mon Sep 17 00:00:00 2001 From: Caitlyn Byrne Date: Sun, 8 Feb 2026 20:30:35 -0500 Subject: [PATCH 2/2] ui: global model selector sets all role assignments with toast feedback Clicking a model in the global Model selector now sets all 7 per-role model assignments to that model, shows an inline toast confirming the change, and auto-expands the per-task section for further customization. The global selector only highlights a model when all roles match it. Co-Authored-By: Claude Opus 4.6 --- ui/src/components/SettingsModal.tsx | 42 ++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/ui/src/components/SettingsModal.tsx b/ui/src/components/SettingsModal.tsx index d1fc4b2f..536b7201 100644 --- a/ui/src/components/SettingsModal.tsx +++ b/ui/src/components/SettingsModal.tsx @@ -1,4 +1,4 @@ -import { useState } from 'react' +import { useState, useEffect } from 'react' import { Loader2, AlertCircle, Check, Moon, Sun, Eye, EyeOff, ShieldCheck, ChevronDown, ChevronRight } from 'lucide-react' import { useSettings, useUpdateSettings, useAvailableModels, useAvailableProviders } from '../hooks/useProjects' import { useTheme, THEMES } from '../hooks/useTheme' @@ -39,6 +39,22 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { const [customModelInput, setCustomModelInput] = useState('') const [customBaseUrlInput, setCustomBaseUrlInput] = useState('') const [showRoleModels, setShowRoleModels] = useState(false) + const [toast, setToast] = useState(null) + + useEffect(() => { + if (toast) { + const timer = setTimeout(() => setToast(null), 4000) + return () => clearTimeout(timer) + } + }, [toast]) + + const ALL_ROLES = ['initializer', 'coding', 'testing', 'spec_creation', 'expand', 'assistant', 'log_review'] as const + + const allRolesMatch = (modelId: string): boolean => { + const rm = settings?.role_models + if (!rm) return false + return ALL_ROLES.every(r => rm[r] === modelId) + } const handleYoloToggle = () => { if (settings && !updateSettings.isPending) { @@ -48,7 +64,21 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { const handleModelChange = (modelId: string) => { if (!updateSettings.isPending) { - updateSettings.mutate({ api_model: modelId }) + updateSettings.mutate({ + api_model: modelId, + role_models: { + initializer: modelId, + coding: modelId, + testing: modelId, + spec_creation: modelId, + expand: modelId, + assistant: modelId, + log_review: modelId, + }, + }) + const modelName = models.find(m => m.id === modelId)?.name ?? modelId + setToast(`All task roles updated to ${modelName}. Customize per-task below.`) + setShowRoleModels(true) } } @@ -327,7 +357,7 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { onClick={() => handleModelChange(model.id)} disabled={isSaving} className={`flex-1 py-2 px-3 text-sm font-medium transition-colors ${ - (settings.api_model ?? settings.model) === model.id + allRolesMatch(model.id) ? 'bg-primary text-primary-foreground' : 'bg-background text-foreground hover:bg-muted' } ${isSaving ? 'opacity-50 cursor-not-allowed' : ''}`} @@ -338,6 +368,12 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { ))} )} + {toast && ( +
+ + {toast} +
+ )} {/* Custom model input for Ollama/Custom */} {showCustomModelInput && (