Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions src/bot/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def _register_agentic_handlers(self, app: Application) -> None:
("new", self.agentic_new),
("status", self.agentic_status),
("verbose", self.agentic_verbose),
("model", self.agentic_model),
("repo", self.agentic_repo),
("restart", command.restart_command),
]
Expand Down Expand Up @@ -415,6 +416,7 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg]
BotCommand("new", "Start a fresh session"),
BotCommand("status", "Show session status"),
BotCommand("verbose", "Set output verbosity (0/1/2)"),
BotCommand("model", "Switch Claude model"),
BotCommand("repo", "List repos / switch workspace"),
BotCommand("restart", "Restart the bot"),
]
Expand Down Expand Up @@ -578,6 +580,81 @@ async def agentic_verbose(
parse_mode="HTML",
)

def _get_model_override(self, context: ContextTypes.DEFAULT_TYPE) -> Optional[str]:
    """Look up the model override stored for the current user.

    Reads the ``"model_override"`` key from ``context.user_data``.

    Returns:
        The stored override, or ``None`` when no override is stored, in
        which case callers fall back to the default model.
    """
    override = context.user_data.get("model_override")
    return override

@staticmethod
def _resolve_model_display(
    user_override: Optional[str],
    config_model: Optional[str],
    last_model: Optional[str] = None,
) -> str:
    """Pick the model name to show to the user.

    Candidates are tried in priority order: the per-user override, the
    configured model, then the model reported by the last response. The
    first non-empty candidate wins.

    Returns:
        The chosen model name, or a placeholder string when every
        candidate is empty.
    """
    return (
        user_override
        or config_model
        or last_model
        or "unknown (send a message first to detect)"
    )

async def agentic_model(
    self, update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
    """Handle ``/model [model_name]``: show, set, or reset the model override.

    * ``/model`` replies with the model currently in effect and where it
      comes from (user override, server config, or Claude Code default).
    * ``/model default`` (any letter case) removes the per-user override.
    * ``/model <name>`` stores ``<name>`` as the per-user override.

    Successful changes are recorded through the ``audit_logger`` found in
    ``context.bot_data`` when one is present.
    """
    # effective_message also covers edited messages, for which
    # update.message is None.
    message = update.effective_message
    if message is None:
        return

    # str.split() never yields empty or whitespace-padded tokens, so the
    # argument needs no extra stripping or emptiness check.
    args = message.text.split()[1:] if message.text else []
    user_override = self._get_model_override(context)
    config_model = self.settings.claude_model

    if not args:
        current = self._resolve_model_display(
            user_override, config_model, context.user_data.get("last_model")
        )
        if user_override:
            source = "user override"
        elif config_model:
            source = "server config"
        else:
            source = "Claude Code default"
        await message.reply_text(
            f"Model: <b>{escape_html(current)}</b> ({source})\n\n"
            "Usage: <code>/model model_name</code>\n"
            "Reset: <code>/model default</code>",
            parse_mode="HTML",
        )
        return

    model_name = args[0]
    if len(model_name) > 100:
        await message.reply_text("Invalid model name.")
        return

    # Case-insensitive so "/model Default" resets instead of storing a
    # bogus override literally named "Default".
    if model_name.lower() == "default":
        context.user_data.pop("model_override", None)
        if user_override:
            # The remembered model may have been observed while the
            # override was active; keeping it would make a later "/model"
            # present it as the default.
            context.user_data.pop("last_model", None)
        shown = self._resolve_model_display(None, config_model)
        reply = f"Model reset to <b>{escape_html(shown)}</b>"
        audit_command = "model_reset"
        audit_args = []
    else:
        context.user_data["model_override"] = model_name
        reply = f"Model set to <b>{escape_html(model_name)}</b>"
        audit_command = "model"
        audit_args = [model_name]

    await message.reply_text(reply, parse_mode="HTML")

    audit_logger = context.bot_data.get("audit_logger")
    if audit_logger:
        await audit_logger.log_command(
            user_id=update.effective_user.id,
            command=audit_command,
            args=audit_args,
            success=True,
        )

def _format_verbose_progress(
self,
activity_log: List[Dict[str, Any]],
Expand Down Expand Up @@ -941,13 +1018,16 @@ async def agentic_text(
session_id=session_id,
on_stream=on_stream,
force_new=force_new,
model_override=self._get_model_override(context),
)

# New session created successfully — clear the one-shot flag
if force_new:
context.user_data["force_new_session"] = False

context.user_data["claude_session_id"] = claude_response.session_id
if claude_response.model:
context.user_data["last_model"] = claude_response.model

# Track directory changes
from .handlers.message import _update_working_directory_from_claude_response
Expand Down Expand Up @@ -1185,12 +1265,15 @@ async def agentic_document(
session_id=session_id,
on_stream=on_stream,
force_new=force_new,
model_override=self._get_model_override(context),
)

if force_new:
context.user_data["force_new_session"] = False

context.user_data["claude_session_id"] = claude_response.session_id
if claude_response.model:
context.user_data["last_model"] = claude_response.model

from .handlers.message import _update_working_directory_from_claude_response

Expand Down Expand Up @@ -1384,6 +1467,7 @@ async def _handle_agentic_media_message(
session_id=session_id,
on_stream=on_stream,
force_new=force_new,
model_override=self._get_model_override(context),
)
finally:
heartbeat.cancel()
Expand All @@ -1392,6 +1476,7 @@ async def _handle_agentic_media_message(
context.user_data["force_new_session"] = False

context.user_data["claude_session_id"] = claude_response.session_id
context.user_data["last_model"] = claude_response.model

from .handlers.message import _update_working_directory_from_claude_response

Expand Down
5 changes: 5 additions & 0 deletions src/claude/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async def run_command(
session_id: Optional[str] = None,
on_stream: Optional[Callable[[StreamUpdate], None]] = None,
force_new: bool = False,
model_override: Optional[str] = None,
) -> ClaudeResponse:
"""Run Claude Code command with full integration."""
logger.info(
Expand Down Expand Up @@ -85,6 +86,7 @@ async def run_command(
session_id=claude_session_id,
continue_session=should_continue,
stream_callback=on_stream,
model_override=model_override,
)
except Exception as resume_error:
# If resume failed (e.g., session expired/missing on Claude's side),
Expand All @@ -109,6 +111,7 @@ async def run_command(
session_id=None,
continue_session=False,
stream_callback=on_stream,
model_override=model_override,
)
else:
raise
Expand Down Expand Up @@ -152,6 +155,7 @@ async def _execute(
session_id: Optional[str] = None,
continue_session: bool = False,
stream_callback: Optional[Callable] = None,
model_override: Optional[str] = None,
) -> ClaudeResponse:
"""Execute command via SDK."""
return await self.sdk_manager.execute_command(
Expand All @@ -160,6 +164,7 @@ async def _execute(
session_id=session_id,
continue_session=continue_session,
stream_callback=stream_callback,
model_override=model_override,
)

async def _find_resumable_session(
Expand Down
10 changes: 8 additions & 2 deletions src/claude/sdk_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ClaudeResponse:
is_error: bool = False
error_type: Optional[str] = None
tools_used: List[Dict[str, Any]] = field(default_factory=list)
model: Optional[str] = None


@dataclass
Expand Down Expand Up @@ -153,6 +154,7 @@ async def execute_command(
session_id: Optional[str] = None,
continue_session: bool = False,
stream_callback: Optional[Callable[[StreamUpdate], None]] = None,
model_override: Optional[str] = None,
) -> ClaudeResponse:
"""Execute Claude Code command via SDK."""
start_time = asyncio.get_event_loop().time()
Expand Down Expand Up @@ -197,7 +199,7 @@ def _stderr_callback(line: str) -> None:
# Build Claude Agent options
options = ClaudeAgentOptions(
max_turns=self.config.claude_max_turns,
model=self.config.claude_model or None,
model=model_override or self.config.claude_model or None,
max_budget_usd=self.config.claude_max_cost_per_request,
cwd=str(working_directory),
allowed_tools=sdk_allowed_tools,
Expand Down Expand Up @@ -294,11 +296,12 @@ async def _run_client() -> None:
timeout=self.config.claude_timeout_seconds,
)

# Extract cost, tools, and session_id from result message
# Extract cost, tools, session_id, and model from result message
cost = 0.0
tools_used: List[Dict[str, Any]] = []
claude_session_id = None
result_content = None
response_model: Optional[str] = None
for message in messages:
if isinstance(message, ResultMessage):
cost = getattr(message, "total_cost_usd", 0.0) or 0.0
Expand All @@ -307,6 +310,8 @@ async def _run_client() -> None:
current_time = asyncio.get_event_loop().time()
for msg in messages:
if isinstance(msg, AssistantMessage):
if not response_model:
response_model = getattr(msg, "model", None)
msg_content = getattr(msg, "content", [])
if msg_content and isinstance(msg_content, list):
for block in msg_content:
Expand Down Expand Up @@ -377,6 +382,7 @@ async def _run_client() -> None:
]
),
tools_used=tools_used,
model=response_model,
)

except asyncio.TimeoutError:
Expand Down
79 changes: 79 additions & 0 deletions tests/unit/test_claude/test_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,85 @@ async def test_retry_after_failure_still_skips_auto_resume(
assert user_data["force_new_session"] is False


class TestModelOverride:
    """Verify model_override is passed through to _execute."""

    async def test_model_override_forwarded_to_execute(self, facade, session_manager):
        """run_command passes model_override through to _execute."""
        project = Path("/test/project")
        user_id = 123

        with patch.object(
            facade,
            "_execute",
            return_value=_make_mock_response(),
        ) as mock_execute:
            await facade.run_command(
                prompt="hello",
                working_directory=project,
                user_id=user_id,
                model_override="opus",
            )

        mock_execute.assert_called_once()
        assert mock_execute.call_args.kwargs["model_override"] == "opus"

    async def test_model_override_none_by_default(self, facade, session_manager):
        """run_command passes model_override=None when not specified."""
        project = Path("/test/project")
        user_id = 123

        with patch.object(
            facade,
            "_execute",
            return_value=_make_mock_response(),
        ) as mock_execute:
            await facade.run_command(
                prompt="hello",
                working_directory=project,
                user_id=user_id,
            )

        mock_execute.assert_called_once()
        assert mock_execute.call_args.kwargs["model_override"] is None

    async def test_model_override_survives_session_retry(self, facade, session_manager):
        """model_override is preserved when session resume fails and retries."""
        project = Path("/test/project")
        user_id = 123

        # Seed an existing session so resume is attempted
        existing = ClaudeSession(
            session_id="old-session",
            user_id=user_id,
            project_path=project,
            created_at=datetime.utcnow(),
            last_used=datetime.utcnow(),
        )
        await session_manager.storage.save_session(existing)
        session_manager.active_sessions[existing.session_id] = existing

        # model_override seen by each _execute call, in call order.
        seen_overrides = []

        async def _execute_side_effect(**kwargs):
            seen_overrides.append(kwargs.get("model_override"))
            # Fail only the first (resume) attempt so the facade retries.
            if len(seen_overrides) == 1:
                raise RuntimeError("session expired")
            return _make_mock_response()

        with patch.object(facade, "_execute", side_effect=_execute_side_effect):
            await facade.run_command(
                prompt="hello",
                working_directory=project,
                user_id=user_id,
                session_id="old-session",
                model_override="haiku",
            )

        # Both the initial call and the retry must carry model_override="haiku";
        # checking only the call count would miss a retry that drops it.
        assert seen_overrides == ["haiku", "haiku"]


class TestEmptySessionIdWarning:
"""Verify facade warns when final session_id is empty."""

Expand Down
60 changes: 60 additions & 0 deletions tests/unit/test_claude/test_sdk_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,66 @@ async def test_claude_model_none_when_unset(self, tmp_path):
assert len(captured_options) == 1
assert captured_options[0].model is None

async def test_model_override_takes_priority(self, tmp_path):
    """An explicit model_override wins over the configured claude_model."""
    settings = Settings(
        telegram_bot_token="test:token",
        telegram_bot_username="testbot",
        approved_directory=tmp_path,
        claude_timeout_seconds=2,
        claude_model="claude-sonnet-4-6",
    )
    sdk = ClaudeSDKManager(settings)

    # Each options object handed to the mocked client is appended here.
    seen_options = []
    client_factory = _mock_client_factory(
        _make_assistant_message("Test response"),
        _make_result_message(total_cost_usd=0.01),
        capture_options=seen_options,
    )

    client_patch = patch(
        "src.claude.sdk_integration.ClaudeSDKClient", side_effect=client_factory
    )
    with client_patch:
        await sdk.execute_command(
            prompt="Test prompt",
            working_directory=tmp_path,
            model_override="claude-opus-4-6",
        )

    # Exactly one client was built, and it used the override.
    assert [opts.model for opts in seen_options] == ["claude-opus-4-6"]

async def test_model_override_none_uses_config(self, tmp_path):
    """With model_override=None the configured claude_model is used."""
    settings = Settings(
        telegram_bot_token="test:token",
        telegram_bot_username="testbot",
        approved_directory=tmp_path,
        claude_timeout_seconds=2,
        claude_model="claude-haiku-4-5-20251001",
    )
    sdk = ClaudeSDKManager(settings)

    # Each options object handed to the mocked client is appended here.
    seen_options = []
    client_factory = _mock_client_factory(
        _make_assistant_message("Test response"),
        _make_result_message(total_cost_usd=0.01),
        capture_options=seen_options,
    )

    client_patch = patch(
        "src.claude.sdk_integration.ClaudeSDKClient", side_effect=client_factory
    )
    with client_patch:
        await sdk.execute_command(
            prompt="Test prompt",
            working_directory=tmp_path,
            model_override=None,
        )

    # Exactly one client was built, and it fell back to the config model.
    assert [opts.model for opts in seen_options] == ["claude-haiku-4-5-20251001"]


class TestClaudeMCPErrors:
"""Test MCP-specific error handling."""
Expand Down
Loading