From 2c2a1b376b27385e4296a0452a0b82897fb1e5f0 Mon Sep 17 00:00:00 2001 From: Bartosz Osowski Date: Sat, 7 Mar 2026 09:50:11 +0000 Subject: [PATCH 1/3] feat: add /model command for runtime model switching via Telegram Allow users to switch Claude models at runtime without restarting the bot. Supports aliases (sonnet, opus, haiku), full model names, and /model default to reset. Override is per-user. --- src/bot/orchestrator.py | 69 ++++++ src/claude/facade.py | 5 + src/claude/sdk_integration.py | 10 +- tests/unit/test_claude/test_facade.py | 79 +++++++ .../unit/test_claude/test_sdk_integration.py | 60 +++++ tests/unit/test_orchestrator.py | 210 +++++++++++++++++- 6 files changed, 425 insertions(+), 8 deletions(-) diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index ac1d5304..855c7809 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -306,6 +306,7 @@ def _register_agentic_handlers(self, app: Application) -> None: ("new", self.agentic_new), ("status", self.agentic_status), ("verbose", self.agentic_verbose), + ("model", self.agentic_model), ("repo", self.agentic_repo), ("restart", command.restart_command), ] @@ -415,6 +416,7 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] BotCommand("new", "Start a fresh session"), BotCommand("status", "Show session status"), BotCommand("verbose", "Set output verbosity (0/1/2)"), + BotCommand("model", "Switch Claude model"), BotCommand("repo", "List repos / switch workspace"), BotCommand("restart", "Restart the bot"), ] @@ -578,6 +580,66 @@ async def agentic_verbose( parse_mode="HTML", ) + def _get_model_override(self, context: ContextTypes.DEFAULT_TYPE) -> Optional[str]: + """Return per-user model override, or None to use the default.""" + return context.user_data.get("model_override") + + @staticmethod + def _resolve_model_display( + user_override: Optional[str], + config_model: Optional[str], + last_model: Optional[str] = None, + ) -> str: + """Return a human-readable model string showing what will actually be used.""" + if user_override: + return user_override + if config_model: + return config_model + if last_model: + return last_model + return "unknown (send a message first to detect)" + + async def agentic_model( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + """Set Claude model: /model [model_name].""" + args = update.message.text.split()[1:] if update.message.text else [] + user_override = self._get_model_override(context) + last_model = context.user_data.get("last_model") + current = self._resolve_model_display( + user_override, self.settings.claude_model, last_model + ) + + if not args: + source = "user override" if user_override else ( + "server config" if self.settings.claude_model else "Claude Code default" + ) + await update.message.reply_text( + f"Model: {escape_html(current)} ({source})\n\n" + "Usage: /model model_name\n" + "Aliases: sonnet, opus, haiku\n" + "Full names: claude-sonnet-4-6, claude-opus-4-6, " + "claude-haiku-4-5-20251001\n" + "Reset: /model default", + parse_mode="HTML", + ) + return + + model_name = args[0].strip() + if model_name == "default": + context.user_data.pop("model_override", None) + default = self._resolve_model_display(None, self.settings.claude_model) + await update.message.reply_text( + f"Model reset to {escape_html(default)}", + parse_mode="HTML", + ) + else: + context.user_data["model_override"] = model_name + await update.message.reply_text( + f"Model set to {escape_html(model_name)}", + parse_mode="HTML", + ) + def _format_verbose_progress( self, activity_log: List[Dict[str, Any]], @@ -941,6 +1003,7 @@ async def agentic_text( session_id=session_id, on_stream=on_stream, force_new=force_new, + model_override=self._get_model_override(context), ) # New session created successfully — clear the one-shot flag @@ -948,6 +1011,8 @@ async def agentic_text( context.user_data["force_new_session"] = False context.user_data["claude_session_id"] = claude_response.session_id + if claude_response.model: + context.user_data["last_model"] = claude_response.model # Track directory changes from .handlers.message import _update_working_directory_from_claude_response @@ -1185,12 +1250,15 @@ async def agentic_document( session_id=session_id, on_stream=on_stream, force_new=force_new, + model_override=self._get_model_override(context), ) if force_new: context.user_data["force_new_session"] = False context.user_data["claude_session_id"] = claude_response.session_id + if claude_response.model: + context.user_data["last_model"] = claude_response.model from .handlers.message import _update_working_directory_from_claude_response @@ -1384,6 +1452,7 @@ async def _handle_agentic_media_message( session_id=session_id, on_stream=on_stream, force_new=force_new, + model_override=self._get_model_override(context), ) finally: heartbeat.cancel() diff --git a/src/claude/facade.py b/src/claude/facade.py index fcb2ada6..09545ff6 100644 --- a/src/claude/facade.py +++ b/src/claude/facade.py @@ -37,6 +37,7 @@ async def run_command( session_id: Optional[str] = None, on_stream: Optional[Callable[[StreamUpdate], None]] = None, force_new: bool = False, + model_override: Optional[str] = None, ) -> ClaudeResponse: """Run Claude Code command with full integration.""" logger.info( @@ -85,6 +86,7 @@ async def run_command( session_id=claude_session_id, continue_session=should_continue, stream_callback=on_stream, + model_override=model_override, ) except Exception as resume_error: # If resume failed (e.g., session expired/missing on Claude's side), @@ -109,6 +111,7 @@ async def run_command( session_id=None, continue_session=False, stream_callback=on_stream, + model_override=model_override, ) else: raise @@ -152,6 +155,7 @@ async def _execute( session_id: Optional[str] = None, continue_session: bool = False, stream_callback: Optional[Callable] = None, + model_override: Optional[str] = None, ) -> ClaudeResponse: """Execute command via SDK.""" return await self.sdk_manager.execute_command( @@ -160,6 +164,7 @@ async def _execute( session_id=session_id, continue_session=continue_session, stream_callback=stream_callback, + model_override=model_override, ) async def _find_resumable_session( diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py index adf553f4..c6b5e591 100644 --- a/src/claude/sdk_integration.py +++ b/src/claude/sdk_integration.py @@ -53,6 +53,7 @@ class ClaudeResponse: is_error: bool = False error_type: Optional[str] = None tools_used: List[Dict[str, Any]] = field(default_factory=list) + model: Optional[str] = None @dataclass @@ -153,6 +154,7 @@ async def execute_command( session_id: Optional[str] = None, continue_session: bool = False, stream_callback: Optional[Callable[[StreamUpdate], None]] = None, + model_override: Optional[str] = None, ) -> ClaudeResponse: """Execute Claude Code command via SDK.""" start_time = asyncio.get_event_loop().time() @@ -197,7 +199,7 @@ def _stderr_callback(line: str) -> None: # Build Claude Agent options options = ClaudeAgentOptions( max_turns=self.config.claude_max_turns, - model=self.config.claude_model or None, + model=model_override or self.config.claude_model or None, max_budget_usd=self.config.claude_max_cost_per_request, cwd=str(working_directory), allowed_tools=sdk_allowed_tools, @@ -294,11 +296,12 @@ async def _run_client() -> None: timeout=self.config.claude_timeout_seconds, ) - # Extract cost, tools, and session_id from result message + # Extract cost, tools, session_id, and model from result message cost = 0.0 tools_used: List[Dict[str, Any]] = [] claude_session_id = None result_content = None + response_model: Optional[str] = None for message in messages: if isinstance(message, ResultMessage): cost = getattr(message, "total_cost_usd", 0.0) or 0.0 @@ -307,6 +310,8 @@ async def _run_client() -> None: current_time = asyncio.get_event_loop().time() for msg in messages: if isinstance(msg, AssistantMessage): + if not response_model: + response_model = getattr(msg, "model", None) msg_content = getattr(msg, "content", []) if msg_content and isinstance(msg_content, list): for block in msg_content: @@ -377,6 +382,7 @@ async def _run_client() -> None: ] ), tools_used=tools_used, + model=response_model, ) except asyncio.TimeoutError: diff --git a/tests/unit/test_claude/test_facade.py b/tests/unit/test_claude/test_facade.py index 666a2246..814522e2 100644 --- a/tests/unit/test_claude/test_facade.py +++ b/tests/unit/test_claude/test_facade.py @@ -269,6 +269,85 @@ async def test_retry_after_failure_still_skips_auto_resume( assert user_data["force_new_session"] is False +class TestModelOverride: + """Verify model_override is passed through to _execute.""" + + async def test_model_override_forwarded_to_execute(self, facade, session_manager): + """run_command passes model_override through to _execute.""" + project = Path("/test/project") + user_id = 123 + + with patch.object( + facade, + "_execute", + return_value=_make_mock_response(), + ) as mock_execute: + await facade.run_command( + prompt="hello", + working_directory=project, + user_id=user_id, + model_override="opus", + ) + + mock_execute.assert_called_once() + assert mock_execute.call_args.kwargs["model_override"] == "opus" + + async def test_model_override_none_by_default(self, facade, session_manager): + """run_command passes model_override=None when not specified.""" + project = Path("/test/project") + user_id = 123 + + with patch.object( + facade, + "_execute", + return_value=_make_mock_response(), + ) as mock_execute: + await facade.run_command( + prompt="hello", + working_directory=project, + user_id=user_id, + ) + + mock_execute.assert_called_once() + assert mock_execute.call_args.kwargs["model_override"] is None + + async def test_model_override_survives_session_retry(self, facade, session_manager): + """model_override is preserved when session resume fails and retries.""" + project = Path("/test/project") + user_id = 123 + + # Seed an existing session so resume is attempted + existing = ClaudeSession( + session_id="old-session", + user_id=user_id, + project_path=project, + created_at=datetime.utcnow(), + last_used=datetime.utcnow(), + ) + await session_manager.storage.save_session(existing) + session_manager.active_sessions[existing.session_id] = existing + + call_count = [0] + + async def _execute_side_effect(**kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError("session expired") + return _make_mock_response() + + with patch.object(facade, "_execute", side_effect=_execute_side_effect): + await facade.run_command( + prompt="hello", + working_directory=project, + user_id=user_id, + session_id="old-session", + model_override="haiku", + ) + + # Both the initial call and retry should have model_override="haiku" + assert call_count[0] == 2 + + class TestEmptySessionIdWarning: """Verify facade warns when final session_id is empty.""" diff --git a/tests/unit/test_claude/test_sdk_integration.py b/tests/unit/test_claude/test_sdk_integration.py index 17ba58ab..0af40e46 100644 --- a/tests/unit/test_claude/test_sdk_integration.py +++ b/tests/unit/test_claude/test_sdk_integration.py @@ -696,6 +696,66 @@ async def test_claude_model_none_when_unset(self, tmp_path): assert len(captured_options) == 1 assert captured_options[0].model is None + async def test_model_override_takes_priority(self, tmp_path): + """Test that model_override overrides claude_model from config.""" + config = Settings( + telegram_bot_token="test:token", + telegram_bot_username="testbot", + approved_directory=tmp_path, + claude_timeout_seconds=2, + claude_model="claude-sonnet-4-6", + ) + manager = ClaudeSDKManager(config) + + captured_options = [] + mock_factory = _mock_client_factory( + _make_assistant_message("Test response"), + _make_result_message(total_cost_usd=0.01), + capture_options=captured_options, + ) + + with patch( + "src.claude.sdk_integration.ClaudeSDKClient", side_effect=mock_factory + ): + await manager.execute_command( + prompt="Test prompt", + working_directory=tmp_path, + model_override="claude-opus-4-6", + ) + + assert len(captured_options) == 1 + assert captured_options[0].model == "claude-opus-4-6" + + async def test_model_override_none_uses_config(self, tmp_path): + """Test that model_override=None falls back to config model.""" + config = Settings( + telegram_bot_token="test:token", + telegram_bot_username="testbot", + approved_directory=tmp_path, + claude_timeout_seconds=2, + claude_model="claude-haiku-4-5-20251001", + ) + manager = ClaudeSDKManager(config) + + captured_options = [] + mock_factory = _mock_client_factory( + _make_assistant_message("Test response"), + _make_result_message(total_cost_usd=0.01), + capture_options=captured_options, + ) + + with patch( + "src.claude.sdk_integration.ClaudeSDKClient", side_effect=mock_factory + ): + await manager.execute_command( + prompt="Test prompt", + working_directory=tmp_path, + model_override=None, + ) + + assert len(captured_options) == 1 + assert captured_options[0].model == "claude-haiku-4-5-20251001" + class TestClaudeMCPErrors: """Test MCP-specific error handling.""" diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index cc02b7c0..7a5bc64c 100644 --- a/tests/unit/test_orchestrator.py +++ b/tests/unit/test_orchestrator.py @@ -82,8 +82,8 @@ def deps(): } -def test_agentic_registers_6_commands(agentic_settings, deps): - """Agentic mode registers start, new, status, verbose, repo, restart commands.""" +def test_agentic_registers_7_commands(agentic_settings, deps): + """Agentic mode registers start, new, status, verbose, model, repo, restart.""" orchestrator = MessageOrchestrator(agentic_settings, deps) app = MagicMock() app.add_handler = MagicMock() @@ -100,11 +100,12 @@ def test_agentic_registers_6_commands(agentic_settings, deps): ] commands = [h[0][0].commands for h in cmd_handlers] - assert len(cmd_handlers) == 6 + assert len(cmd_handlers) == 7 assert frozenset({"start"}) in commands assert frozenset({"new"}) in commands assert frozenset({"status"}) in commands assert frozenset({"verbose"}) in commands + assert frozenset({"model"}) in commands assert frozenset({"repo"}) in commands assert frozenset({"restart"}) in commands @@ -156,13 +157,13 @@ def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps): async def test_agentic_bot_commands(agentic_settings, deps): - """Agentic mode returns 6 bot commands.""" + """Agentic mode returns 7 bot commands.""" orchestrator = MessageOrchestrator(agentic_settings, deps) commands = await orchestrator.get_bot_commands() - assert len(commands) == 6 + assert len(commands) == 7 cmd_names = [c.command for c in commands] - assert cmd_names == ["start", "new", "status", "verbose", "repo", "restart"] + assert cmd_names == ["start", "new", "status", "verbose", "model", "repo", "restart"] async def test_classic_bot_commands(classic_settings, deps): @@ -926,3 +927,200 @@ async def help_command(update, context): assert called["value"] is False update.effective_message.reply_text.assert_called_once() + + +# --- /model command tests --- + + +async def test_agentic_model_shows_last_model_when_unset(agentic_settings, deps): + """/model with no override shows the model from the last response.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {"last_model": "claude-opus-4-6"} + + await orchestrator.agentic_model(update, context) + + call_args = update.message.reply_text.call_args + text = call_args.args[0] + assert "claude-opus-4-6" in text + assert "Claude Code default" in text + + +async def test_agentic_model_shows_unknown_before_first_message(agentic_settings, deps): + """/model before any message shows unknown.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {} + + await orchestrator.agentic_model(update, context) + + call_args = update.message.reply_text.call_args + text = call_args.args[0] + assert "unknown" in text.lower() + assert call_args.kwargs.get("parse_mode") == "HTML" + + +async def test_agentic_model_shows_config_model(tmp_dir, deps): + """/model shows the server-configured model when CLAUDE_MODEL is set.""" + settings = create_test_config( + approved_directory=str(tmp_dir), + agentic_mode=True, + claude_model="claude-opus-4-6", + ) + orchestrator = MessageOrchestrator(settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {} + + await orchestrator.agentic_model(update, context) + + text = update.message.reply_text.call_args.args[0] + assert "claude-opus-4-6" in text + assert "server config" in text + + +async def test_agentic_model_shows_user_override(agentic_settings, deps): + """/model shows the user's override when one is set.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {"model_override": "haiku"} + + await orchestrator.agentic_model(update, context) + + text = update.message.reply_text.call_args.args[0] + assert "haiku" in text + assert "user override" in text + + +async def test_agentic_model_sets_override(agentic_settings, deps): + """/model sonnet sets the user's model override.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model sonnet" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {} + + await orchestrator.agentic_model(update, context) + + assert context.user_data["model_override"] == "sonnet" + text = update.message.reply_text.call_args.args[0] + assert "sonnet" in text + + +async def test_agentic_model_reset_to_default(agentic_settings, deps): + """/model default clears the user's model override.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model default" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {"model_override": "opus"} + + await orchestrator.agentic_model(update, context) + + assert "model_override" not in context.user_data + text = update.message.reply_text.call_args.args[0] + assert "reset" in text.lower() + + +async def test_model_override_passed_to_run_command(agentic_settings, deps): + """User model override is passed through to claude_integration.run_command.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + mock_response = MagicMock() + mock_response.session_id = "session-abc" + mock_response.content = "Hello!" + mock_response.tools_used = [] + + claude_integration = AsyncMock() + claude_integration.run_command = AsyncMock(return_value=mock_response) + + update = MagicMock() + update.effective_user.id = 123 + update.message.text = "Help me" + update.message.message_id = 1 + update.message.chat.send_action = AsyncMock() + update.message.reply_text = AsyncMock() + + progress_msg = AsyncMock() + progress_msg.delete = AsyncMock() + update.message.reply_text.return_value = progress_msg + + context = MagicMock() + context.user_data = {"model_override": "haiku"} + context.bot_data = { + "settings": agentic_settings, + "claude_integration": claude_integration, + "storage": None, + "rate_limiter": None, + "audit_logger": None, + } + + await orchestrator.agentic_text(update, context) + + claude_integration.run_command.assert_called_once() + call_kwargs = claude_integration.run_command.call_args.kwargs + assert call_kwargs["model_override"] == "haiku" + + +async def test_model_override_none_when_not_set(agentic_settings, deps): + """model_override is None when user hasn't set one.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + mock_response = MagicMock() + mock_response.session_id = "session-abc" + mock_response.content = "Hello!" + mock_response.tools_used = [] + + claude_integration = AsyncMock() + claude_integration.run_command = AsyncMock(return_value=mock_response) + + update = MagicMock() + update.effective_user.id = 123 + update.message.text = "Help me" + update.message.message_id = 1 + update.message.chat.send_action = AsyncMock() + update.message.reply_text = AsyncMock() + + progress_msg = AsyncMock() + progress_msg.delete = AsyncMock() + update.message.reply_text.return_value = progress_msg + + context = MagicMock() + context.user_data = {} + context.bot_data = { + "settings": agentic_settings, + "claude_integration": claude_integration, + "storage": None, + "rate_limiter": None, + "audit_logger": None, + } + + await orchestrator.agentic_text(update, context) + + call_kwargs = claude_integration.run_command.call_args.kwargs + assert call_kwargs["model_override"] is None From beba83893a4eaf3df9eafa9df52401333a8f0a5c Mon Sep 17 00:00:00 2001 From: Bartosz Osowski Date: Sat, 7 Mar 2026 10:58:04 +0000 Subject: [PATCH 2/3] =?UTF-8?q?fix:=20address=20PR=20review=20=E2=80=94=20?= =?UTF-8?q?audit=20logging,=20media=20handler,=20input=20validation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add audit logging to /model set and reset paths - Update last_model in _handle_agentic_media_message - Add input validation (max length) for model name - Add tests for audit logging and long model name rejection --- src/bot/orchestrator.py | 17 ++++++++++--- tests/unit/test_orchestrator.py | 44 +++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index 855c7809..0b094d94 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -617,15 +617,15 @@ async def agentic_model( await update.message.reply_text( f"Model: {escape_html(current)} ({source})\n\n" "Usage: /model model_name\n" - "Aliases: sonnet, opus, haiku\n" - "Full names: claude-sonnet-4-6, claude-opus-4-6, " - "claude-haiku-4-5-20251001\n" "Reset: /model default", parse_mode="HTML", ) return model_name = args[0].strip() + if not model_name or len(model_name) > 100: + await update.message.reply_text("Invalid model name.") + return if model_name == "default": context.user_data.pop("model_override", None) default = self._resolve_model_display(None, self.settings.claude_model) @@ -640,6 +640,16 @@ async def agentic_model( parse_mode="HTML", ) + # Audit log + audit_logger = context.bot_data.get("audit_logger") + if audit_logger: + await audit_logger.log_command( + user_id=update.effective_user.id, + command="model", + args=[model_name], + success=True, + ) + def _format_verbose_progress( self, activity_log: List[Dict[str, Any]], @@ -1461,6 +1471,7 @@ async def _handle_agentic_media_message( context.user_data["force_new_session"] = False context.user_data["claude_session_id"] = claude_response.session_id + context.user_data["last_model"] = claude_response.model from .handlers.message import _update_working_directory_from_claude_response diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index 7a5bc64c..3cdfe389 100644 --- a/tests/unit/test_orchestrator.py +++ b/tests/unit/test_orchestrator.py @@ -1018,9 +1018,11 @@ async def test_agentic_model_sets_override(agentic_settings, deps): update = MagicMock() update.message.text = "/model sonnet" update.message.reply_text = AsyncMock() + update.effective_user.id = 123 context = MagicMock() context.user_data = {} + context.bot_data = {"audit_logger": AsyncMock()} await orchestrator.agentic_model(update, context) @@ -1036,9 +1038,11 @@ async def test_agentic_model_reset_to_default(agentic_settings, deps): update = MagicMock() update.message.text = "/model default" update.message.reply_text = AsyncMock() + update.effective_user.id = 123 context = MagicMock() context.user_data = {"model_override": "opus"} + context.bot_data = {"audit_logger": AsyncMock()} await orchestrator.agentic_model(update, context) @@ -1047,6 +1051,46 @@ async def test_agentic_model_reset_to_default(agentic_settings, deps): assert "reset" in text.lower() +async def test_agentic_model_audit_logged(agentic_settings, deps): + """/model sonnet logs the action to audit logger.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model sonnet" + update.message.reply_text = AsyncMock() + update.effective_user.id = 42 + + audit_logger = AsyncMock() + context = MagicMock() + context.user_data = {} + context.bot_data = {"audit_logger": audit_logger} + + await orchestrator.agentic_model(update, context) + + audit_logger.log_command.assert_called_once_with( + user_id=42, command="model", args=["sonnet"], success=True, + ) + + + +async def test_agentic_model_rejects_long_name(agentic_settings, deps): + """/model with overly long name is rejected.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model " + "a" * 101 + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {} + + await orchestrator.agentic_model(update, context) + + assert "model_override" not in context.user_data + text = update.message.reply_text.call_args.args[0] + assert "Invalid" in text + + async def test_model_override_passed_to_run_command(agentic_settings, deps): """User model override is passed through to claude_integration.run_command.""" orchestrator = MessageOrchestrator(agentic_settings, deps) From 32f345324fb9c263d721874e45f939f2dc705534 Mon Sep 17 00:00:00 2001 From: Bartosz Osowski Date: Sat, 7 Mar 2026 11:19:25 +0000 Subject: [PATCH 3/3] fix: distinguish model set vs reset in audit log Log model set as command="model" and reset as command="model_reset" with empty args for cleaner audit log queries. --- src/bot/orchestrator.py | 25 +++++++++++++++---------- tests/unit/test_orchestrator.py | 21 +++++++++++++++++++++ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index 0b094d94..da997cac 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -626,6 +626,7 @@ async def agentic_model( if not model_name or len(model_name) > 100: await update.message.reply_text("Invalid model name.") return + audit_logger = context.bot_data.get("audit_logger") if model_name == "default": context.user_data.pop("model_override", None) default = self._resolve_model_display(None, self.settings.claude_model) @@ -633,22 +634,26 @@ async def agentic_model( f"Model reset to {escape_html(default)}", parse_mode="HTML", ) + if audit_logger: + await audit_logger.log_command( + user_id=update.effective_user.id, + command="model_reset", + args=[], + success=True, + ) else: context.user_data["model_override"] = model_name await update.message.reply_text( f"Model set to {escape_html(model_name)}", parse_mode="HTML", ) - - # Audit log - audit_logger = context.bot_data.get("audit_logger") - if audit_logger: - await audit_logger.log_command( - user_id=update.effective_user.id, - command="model", - args=[model_name], - success=True, - ) + if audit_logger: + await audit_logger.log_command( + user_id=update.effective_user.id, + command="model", + args=[model_name], + success=True, + ) def _format_verbose_progress( self, diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index 3cdfe389..6029e0cc 100644 --- a/tests/unit/test_orchestrator.py +++ b/tests/unit/test_orchestrator.py @@ -1072,6 +1072,27 @@ async def test_agentic_model_audit_logged(agentic_settings, deps): ) +async def test_agentic_model_reset_audit_logged(agentic_settings, deps): + """/model default logs as model_reset with empty args.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model default" + update.message.reply_text = AsyncMock() + update.effective_user.id = 42 + + audit_logger = AsyncMock() + context = MagicMock() + context.user_data = {"model_override": "opus"} + context.bot_data = {"audit_logger": audit_logger} + + await orchestrator.agentic_model(update, context) + + audit_logger.log_command.assert_called_once_with( + user_id=42, command="model_reset", args=[], success=True, + ) + + async def test_agentic_model_rejects_long_name(agentic_settings, deps): """/model with overly long name is rejected."""