Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions parallel_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ def get_session(self):
"""Get a new database session."""
return self._session_maker()

def _get_model_for_role(self, role: str) -> str | None:
"""Resolve the model for a given agent role.

Uses the per-role model resolution from registry, with self.model
as the CLI override (highest priority).

Returns:
Model ID string, or None if the default should be used.
"""
from registry import get_model_for_role
return get_model_for_role(role, cli_override=self.model)

def _get_random_passing_feature(self) -> int | None:
"""Get a random passing feature for regression testing (no claim needed).

Expand Down Expand Up @@ -829,8 +841,9 @@ def _spawn_coding_agent(self, feature_id: int) -> tuple[bool, str]:
"--agent-type", "coding",
"--feature-id", str(feature_id),
]
if self.model:
cmd.extend(["--model", self.model])
coding_model = self._get_model_for_role("coding")
if coding_model:
cmd.extend(["--model", coding_model])
if self.yolo_mode:
cmd.append("--yolo")

Expand Down Expand Up @@ -895,8 +908,9 @@ def _spawn_coding_agent_batch(self, feature_ids: list[int]) -> tuple[bool, str]:
"--agent-type", "coding",
"--feature-ids", ",".join(str(fid) for fid in feature_ids),
]
if self.model:
cmd.extend(["--model", self.model])
coding_model = self._get_model_for_role("coding")
if coding_model:
cmd.extend(["--model", coding_model])
if self.yolo_mode:
cmd.append("--yolo")

Expand Down Expand Up @@ -998,8 +1012,9 @@ def _spawn_testing_agent(self) -> tuple[bool, str]:
"--agent-type", "testing",
"--testing-feature-ids", batch_str,
]
if self.model:
cmd.extend(["--model", self.model])
testing_model = self._get_model_for_role("testing")
if testing_model:
cmd.extend(["--model", testing_model])

try:
# CREATE_NO_WINDOW on Windows prevents console window pop-ups
Expand Down Expand Up @@ -1058,8 +1073,9 @@ async def _run_initializer(self) -> bool:
"--agent-type", "initializer",
"--max-iterations", "1",
]
if self.model:
cmd.extend(["--model", self.model])
init_model = self._get_model_for_role("initializer")
if init_model:
cmd.extend(["--model", init_model])

print("Running initializer agent...", flush=True)

Expand Down
104 changes: 102 additions & 2 deletions registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def _migrate_registry_dir() -> None:
AVAILABLE_MODELS = [
{"id": "claude-opus-4-6", "name": "Claude Opus"},
{"id": "claude-sonnet-4-5-20250929", "name": "Claude Sonnet"},
{"id": "claude-haiku-4-5-20251001", "name": "Claude Haiku"},
]

# Map legacy model IDs to their current replacements.
Expand All @@ -73,6 +74,23 @@ def _migrate_registry_dir() -> None:
VALID_MODELS.append(DEFAULT_MODEL)
DEFAULT_YOLO_MODE = False

# =============================================================================
# Role-to-Tier Mapping
# =============================================================================

# Each role maps to a tier (high/mid/low) which resolves to a provider-specific model
ROLE_TIER_MAP: dict[str, str] = {
"initializer": "high",
"coding": "mid",
"testing": "low",
"spec_creation": "high",
"expand": "high",
"assistant": "mid",
"log_review": "low",
}
VALID_ROLES = list(ROLE_TIER_MAP.keys())
ROLE_MODEL_KEY_PREFIX = "model_"

# SQLite connection settings
SQLITE_TIMEOUT = 30 # seconds to wait for database lock
SQLITE_MAX_RETRIES = 3 # number of retry attempts on busy database
Expand Down Expand Up @@ -617,9 +635,12 @@ def get_all_settings() -> dict[str, str]:
settings = session.query(Settings).all()
result = {s.key: s.value for s in settings}

# Auto-migrate legacy model IDs
# Auto-migrate legacy model IDs (including per-role model_* keys)
migrated = False
for key in ("model", "api_model"):
migration_keys = ["model", "api_model"] + [
f"{ROLE_MODEL_KEY_PREFIX}{role}" for role in VALID_ROLES
]
for key in migration_keys:
old_id = result.get(key)
if old_id and old_id in LEGACY_MODEL_MAP:
new_id = LEGACY_MODEL_MAP[old_id]
Expand All @@ -642,6 +663,75 @@ def get_all_settings() -> dict[str, str]:
return {}


def get_model_for_role(role: str, cli_override: str | None = None) -> str:
"""Resolve the model to use for a given role.

Resolution chain:
1. cli_override (--model flag) — universal override
2. settings["model_{role}"] — per-role override from Settings UI
3. Provider tier default — provider.tiers[ROLE_TIER_MAP[role]]
4. settings["api_model"] — global model override
5. DEFAULT_MODEL — absolute fallback

Args:
role: One of VALID_ROLES (e.g. "coding", "testing", "assistant")
cli_override: Model passed via --model CLI flag

Returns:
Model ID string
"""
if cli_override:
return cli_override

all_settings = get_all_settings()

# Per-role setting
role_model = all_settings.get(f"{ROLE_MODEL_KEY_PREFIX}{role}")
if role_model:
return role_model

# Provider tier default
provider_id = all_settings.get("api_provider", "claude")
provider = API_PROVIDERS.get(provider_id, {})
tier = ROLE_TIER_MAP.get(role, "mid")
tier_model = provider.get("tiers", {}).get(tier)
if tier_model:
return tier_model

# Global fallback
return all_settings.get("api_model") or all_settings.get("model") or DEFAULT_MODEL


def get_model_for_tier(tier: str) -> str:
"""Resolve the model for a specific tier (high/mid/low).

Used by the assistant chat session for dynamic tier switching.

Args:
tier: One of "high", "mid", "low"

Returns:
Model ID string
"""
all_settings = get_all_settings()
provider_id = all_settings.get("api_provider", "claude")
provider = API_PROVIDERS.get(provider_id, {})
tier_model = provider.get("tiers", {}).get(tier)
if tier_model:
return tier_model
return all_settings.get("api_model") or all_settings.get("model") or DEFAULT_MODEL


def get_all_role_models() -> dict[str, str | None]:
"""Get all per-role model overrides from settings.

Returns:
Dict mapping role names to model IDs (None if no override set).
"""
all_settings = get_all_settings()
return {role: all_settings.get(f"{ROLE_MODEL_KEY_PREFIX}{role}") or None for role in VALID_ROLES}


# =============================================================================
# API Provider Definitions
# =============================================================================
Expand All @@ -654,8 +744,14 @@ def get_all_settings() -> dict[str, str]:
"models": [
{"id": "claude-opus-4-6", "name": "Claude Opus"},
{"id": "claude-sonnet-4-5-20250929", "name": "Claude Sonnet"},
{"id": "claude-haiku-4-5-20251001", "name": "Claude Haiku"},
],
"default_model": "claude-opus-4-6",
"tiers": {
"high": "claude-opus-4-6",
"mid": "claude-sonnet-4-5-20250929",
"low": "claude-haiku-4-5-20251001",
},
},
"kimi": {
"name": "Kimi K2.5 (Moonshot)",
Expand All @@ -664,6 +760,7 @@ def get_all_settings() -> dict[str, str]:
"auth_env_var": "ANTHROPIC_API_KEY",
"models": [{"id": "kimi-k2.5", "name": "Kimi K2.5"}],
"default_model": "kimi-k2.5",
"tiers": {"high": "kimi-k2.5", "mid": "kimi-k2.5", "low": "kimi-k2.5"},
},
"glm": {
"name": "GLM (Zhipu AI)",
Expand All @@ -675,6 +772,7 @@ def get_all_settings() -> dict[str, str]:
{"id": "glm-4.5-air", "name": "GLM 4.5 Air"},
],
"default_model": "glm-4.7",
"tiers": {"high": "glm-4.7", "mid": "glm-4.5-air", "low": "glm-4.5-air"},
},
"ollama": {
"name": "Ollama (Local)",
Expand All @@ -685,6 +783,7 @@ def get_all_settings() -> dict[str, str]:
{"id": "deepseek-coder-v2", "name": "DeepSeek Coder V2"},
],
"default_model": "qwen3-coder",
"tiers": {"high": "qwen3-coder", "mid": "qwen3-coder", "low": "qwen3-coder"},
},
"custom": {
"name": "Custom Provider",
Expand All @@ -693,6 +792,7 @@ def get_all_settings() -> dict[str, str]:
"auth_env_var": "ANTHROPIC_AUTH_TOKEN",
"models": [],
"default_model": "",
"tiers": {"high": "", "mid": "", "low": ""},
},
}

Expand Down
22 changes: 22 additions & 0 deletions server/routers/assistant_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,22 @@ async def assistant_chat_websocket(websocket: WebSocket, project_name: str):
await websocket.send_json({"type": "pong"})
continue

elif msg_type == "set_tier":
# Dynamic tier switching for assistant model quality
tier = message.get("tier", "mid")
if tier in ("high", "mid", "low") and session:
session.set_tier(tier)
await websocket.send_json({
"type": "tier_changed",
"tier": tier,
})
elif not session:
await websocket.send_json({
"type": "error",
"content": "No active session. Send 'start' first."
})
continue

elif msg_type == "start":
# Get optional conversation_id to resume
conversation_id = message.get("conversation_id")
Expand All @@ -268,6 +284,12 @@ async def assistant_chat_websocket(websocket: WebSocket, project_name: str):
project_dir,
conversation_id=conversation_id,
)

# Apply initial tier from start message if provided
initial_tier = message.get("tier")
if initial_tier and initial_tier in ("high", "mid", "low"):
session.set_tier(initial_tier)

logger.debug("Session created, starting...")

# Stream the initial greeting
Expand Down
35 changes: 34 additions & 1 deletion server/routers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@

from fastapi import APIRouter

from ..schemas import ModelInfo, ModelsResponse, ProviderInfo, ProvidersResponse, SettingsResponse, SettingsUpdate
from ..schemas import (
ModelInfo,
ModelsResponse,
ProviderInfo,
ProvidersResponse,
RoleModelAssignment,
SettingsResponse,
SettingsUpdate,
)
from ..services.chat_constants import ROOT_DIR

# Mimetype fix for Windows - must run before StaticFiles is mounted
Expand All @@ -25,6 +33,9 @@
API_PROVIDERS,
AVAILABLE_MODELS,
DEFAULT_MODEL,
ROLE_MODEL_KEY_PREFIX,
VALID_ROLES,
get_all_role_models,
get_all_settings,
get_setting,
set_setting,
Expand All @@ -51,6 +62,7 @@ async def get_available_providers():
models=[ModelInfo(id=m["id"], name=m["name"]) for m in pdata.get("models", [])],
default_model=pdata.get("default_model", ""),
requires_auth=pdata.get("requires_auth", False),
tiers=pdata.get("tiers"),
))
return ProvidersResponse(providers=providers, current=current)

Expand Down Expand Up @@ -105,6 +117,9 @@ async def get_settings():
glm_mode = api_provider == "glm"
ollama_mode = api_provider == "ollama"

role_models_dict = get_all_role_models()
role_models = RoleModelAssignment(**role_models_dict)

return SettingsResponse(
yolo_mode=_parse_yolo_mode(all_settings.get("yolo_mode")),
model=all_settings.get("model", DEFAULT_MODEL),
Expand All @@ -117,6 +132,7 @@ async def get_settings():
api_base_url=all_settings.get("api_base_url"),
api_has_auth_token=bool(all_settings.get("api_auth_token")),
api_model=all_settings.get("api_model"),
role_models=role_models,
)


Expand Down Expand Up @@ -163,12 +179,28 @@ async def update_settings(update: SettingsUpdate):
if update.api_model is not None:
set_setting("api_model", update.api_model)

# Per-role model overrides
if update.role_models is not None:
role_dict = update.role_models.model_dump()
for role in VALID_ROLES:
value = role_dict.get(role)
if value is not None:
key = f"{ROLE_MODEL_KEY_PREFIX}{role}"
if value == "":
# Empty string clears the override (delete the setting)
set_setting(key, "")
else:
set_setting(key, value)

# Return updated settings
all_settings = get_all_settings()
api_provider = all_settings.get("api_provider", "claude")
glm_mode = api_provider == "glm"
ollama_mode = api_provider == "ollama"

updated_role_models_dict = get_all_role_models()
updated_role_models = RoleModelAssignment(**updated_role_models_dict)

return SettingsResponse(
yolo_mode=_parse_yolo_mode(all_settings.get("yolo_mode")),
model=all_settings.get("model", DEFAULT_MODEL),
Expand All @@ -181,4 +213,5 @@ async def update_settings(update: SettingsUpdate):
api_base_url=all_settings.get("api_base_url"),
api_has_auth_token=bool(all_settings.get("api_auth_token")),
api_model=all_settings.get("api_model"),
role_models=updated_role_models,
)
Loading