diff --git a/src/bot/core.py b/src/bot/core.py index 9fef0670..52e83ca3 100644 --- a/src/bot/core.py +++ b/src/bot/core.py @@ -53,6 +53,9 @@ async def initialize(self) -> None: builder.token(self.settings.telegram_token_str) builder.defaults(Defaults(do_quote=self.settings.reply_quote)) builder.rate_limiter(AIORateLimiter(max_retries=1)) + # Allow concurrent update processing so follow-up messages + # can interrupt a running Claude task + builder.concurrent_updates(True) # Configure connection settings builder.connect_timeout(30) diff --git a/src/bot/handlers/message.py b/src/bot/handlers/message.py index e5fa9f78..a12240eb 100644 --- a/src/bot/handlers/message.py +++ b/src/bot/handlers/message.py @@ -360,13 +360,13 @@ async def handle_text_message( # Enhanced stream updates handler with progress tracking async def stream_handler(update_obj): - # Intercept send_image_to_user MCP tool calls. + # Intercept send_file_to_user / send_image_to_user MCP tool calls. # The SDK namespaces MCP tools as "mcp____". if update_obj.tool_calls: for tc in update_obj.tool_calls: tc_name = tc.get("name", "") - if tc_name == "send_image_to_user" or tc_name.endswith( - "__send_image_to_user" + if tc_name in ("send_file_to_user", "send_image_to_user") or tc_name.endswith( + ("__send_file_to_user", "__send_image_to_user") ): tc_input = tc.get("input", {}) file_path = tc_input.get("file_path", "") @@ -439,7 +439,7 @@ async def stream_handler(update_obj): # Delete progress message await progress_msg.delete() - # Use MCP-collected images (from send_image_to_user tool calls) + # Use MCP-collected files (from send_file_to_user tool calls) images: list[ImageAttachment] = mcp_images # Try to combine text + images when response fits in a caption diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index ac1d5304..4be7c5b7 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -34,8 +34,10 @@ from .utils.draft_streamer import DraftStreamer, generate_draft_id from .utils.html_format import escape_html from .utils.image_extractor import ( + FileAttachment, ImageAttachment, should_send_as_photo, + validate_file_path, validate_image_path, ) @@ -304,8 +306,10 @@ def _register_agentic_handlers(self, app: Application) -> None: handlers = [ ("start", self.agentic_start), ("new", self.agentic_new), + ("stop", self.agentic_stop), ("status", self.agentic_status), ("verbose", self.agentic_verbose), + ("cleanup", self.agentic_cleanup), ("repo", self.agentic_repo), ("restart", command.restart_command), ] @@ -324,10 +328,11 @@ def _register_agentic_handlers(self, app: Application) -> None: group=10, ) - # File uploads -> Claude + # File uploads -> Claude (documents, audio files, video files) app.add_handler( MessageHandler( - filters.Document.ALL, self._inject_deps(self.agentic_document) + filters.Document.ALL | filters.AUDIO | filters.VIDEO, + self._inject_deps(self.agentic_document), ), group=10, ) @@ -413,8 +418,10 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] commands = [ BotCommand("start", "Start the bot"), BotCommand("new", "Start a fresh session"), + BotCommand("stop", "Stop current Claude task"), BotCommand("status", "Show session status"), - BotCommand("verbose", "Set output verbosity (0/1/2)"), + BotCommand("verbose", "Set output verbosity (0/1/2/3)"), + BotCommand("cleanup", "Delete tool/thinking messages"), BotCommand("repo", "List repos / switch workspace"), BotCommand("restart", "Restart the bot"), ] @@ -550,34 +557,76 @@ async def agentic_verbose( args = update.message.text.split()[1:] if update.message.text else [] if not args: current = self._get_verbose_level(context) - labels = {0: "quiet", 1: "normal", 2: "detailed"} + labels = {0: "quiet", 1: "normal", 2: "detailed", 3: "full"} await update.message.reply_text( f"Verbosity: {current} ({labels.get(current, '?')})\n\n" - "Usage: /verbose 0|1|2\n" + "Usage: /verbose 0|1|2|3\n" " 0 = quiet (final response only)\n" " 1 = normal (tools + reasoning)\n" - " 2 = detailed (tools with inputs + reasoning)", + " 2 = detailed (tools with inputs + reasoning)\n" + " 3 = full (commands + output, like vanilla Claude Code)", parse_mode="HTML", ) return try: level = int(args[0]) - if level not in (0, 1, 2): + if level not in (0, 1, 2, 3): raise ValueError except ValueError: await update.message.reply_text( - "Please use: /verbose 0, /verbose 1, or /verbose 2" + "Please use: /verbose 0, /verbose 1, /verbose 2, or /verbose 3" ) return context.user_data["verbose_level"] = level - labels = {0: "quiet", 1: "normal", 2: "detailed"} + labels = {0: "quiet", 1: "normal", 2: "detailed", 3: "full (commands + output)"} await update.message.reply_text( f"Verbosity set to {level} ({labels[level]})", parse_mode="HTML", ) + async def agentic_cleanup( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + """Delete tool/thinking messages from the last response: /cleanup.""" + msg_ids = context.user_data.get("last_tool_message_ids", []) + chat_id = context.user_data.get("last_tool_chat_id") + + if not msg_ids or not chat_id: + await update.message.reply_text("No tool messages to clean up.") + return + + deleted = 0 + for msg_id in msg_ids: + try: + await context.bot.delete_message(chat_id=chat_id, message_id=msg_id) + deleted += 1 + except Exception: + pass # message may already be deleted or too old + + context.user_data["last_tool_message_ids"] = [] + await update.message.reply_text(f"Cleaned up {deleted} messages.") + + async def agentic_stop( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + """Stop the currently running Claude task: /stop.""" + task = context.user_data.get("running_claude_task") + if task and not task.done(): + # Kill the Claude CLI subprocess first + claude_integration = context.bot_data.get("claude_integration") + if claude_integration: + sdk_manager = getattr(claude_integration, "sdk_manager", None) + if sdk_manager: + await sdk_manager.abort() + # Then cancel the asyncio task + task.cancel() + context.user_data["running_claude_task"] = None + await update.message.reply_text("⛔ Stopped.") + else: + await update.message.reply_text("Nothing running.") + def _format_verbose_progress( self, activity_log: List[Dict[str, Any]], @@ -591,18 +640,24 @@ def _format_verbose_progress( elapsed = time.time() - start_time lines: List[str] = [f"Working... ({elapsed:.0f}s)\n"] - for entry in activity_log[-15:]: # Show last 15 entries max + max_entries = 30 if verbose_level >= 3 else 15 + for entry in activity_log[-max_entries:]: kind = entry.get("kind", "tool") if kind == "text": - # Claude's intermediate reasoning/commentary snippet = entry.get("detail", "") - if verbose_level >= 2: + if verbose_level >= 3: + lines.append(f"\U0001f4ac {snippet}") + elif verbose_level >= 2: lines.append(f"\U0001f4ac {snippet}") else: - # Level 1: one short line lines.append(f"\U0001f4ac {snippet[:80]}") + elif kind == "result": + # Tool result (level 3 only) + result = entry.get("detail", "") + if "base64" in result and len(result) > 100: + result = "[binary/image data]" + lines.append(f" \u2514\u2500 {result[:300]}") else: - # Tool call icon = _tool_icon(entry["name"]) if verbose_level >= 2 and entry.get("detail"): lines.append(f"{icon} {entry['name']}: {entry['detail']}") @@ -678,92 +733,191 @@ def _make_stream_callback( mcp_images: Optional[List[ImageAttachment]] = None, approved_directory: Optional[Path] = None, draft_streamer: Optional[DraftStreamer] = None, + chat: Any = None, + reply_to_message_id: Optional[int] = None, ) -> Optional[Callable[[StreamUpdate], Any]]: - """Create a stream callback for verbose progress updates. + """Create a stream callback that sends per-event messages. + + At verbose >= 1, each tool call and thinking block gets its own + Telegram message (like linuz90's bot). Tool messages are tracked + in tool_log for optional cleanup. When *mcp_images* is provided, the callback also intercepts - ``send_image_to_user`` tool calls and collects validated + ``send_file_to_user`` tool calls and collects validated :class:`ImageAttachment` objects for later Telegram delivery. - - When *draft_streamer* is provided, tool activity and assistant - text are streamed to the user in real time via - ``sendMessageDraft``. - - Returns None when verbose_level is 0 **and** no MCP image - collection or draft streaming is requested. - Typing indicators are handled by a separate heartbeat task. """ need_mcp_intercept = mcp_images is not None and approved_directory is not None if verbose_level == 0 and not need_mcp_intercept and draft_streamer is None: return None - last_edit_time = [0.0] # mutable container for closure + # Track sent tool message IDs for optional cleanup + tool_message_ids: List[int] = [] + tool_log.append({"_tool_message_ids": tool_message_ids}) + last_tool_msg_id: List[Optional[int]] = [None] # track last tool msg for result appending + last_edit_time = [0.0] # mutable container for throttled progress edits + has_used_tools = [False] # track whether any tool calls were seen + + async def _send_tool_msg(text: str) -> Optional[int]: + """Send a tool status message and track its ID.""" + if not chat: + return None + try: + msg = await chat.send_message(text) + tool_message_ids.append(msg.message_id) + return msg.message_id + except Exception: + return None + + async def _edit_tool_msg(msg_id: int, text: str) -> None: + """Edit a previously sent tool message.""" + if not chat: + return + try: + await chat._bot.edit_message_text( + text, chat_id=chat.id, message_id=msg_id + ) + except Exception: + pass + + def _format_tool_detail(name: str, tool_input: dict) -> str: + """Format tool input for display — clean, human-readable.""" + if name == "Bash": + cmd = tool_input.get("command", "") + # Show first 200 chars of command + return cmd[:200] + ("..." if len(cmd) > 200 else "") + elif name in ("Read", "Write", "Edit", "MultiEdit"): + path = tool_input.get("file_path", "") + return path.rsplit("/", 1)[-1] if "/" in path else path + elif name in ("Grep", "Glob"): + pattern = tool_input.get("pattern", "") + path = tool_input.get("path", "") + short_path = path.rsplit("/", 1)[-1] if "/" in path else path + return f'"{pattern}" in {short_path}' if short_path else f'"{pattern}"' + elif name == "Skill": + return tool_input.get("skill", "") + elif name == "ToolSearch": + return tool_input.get("query", "") + elif name in ("WebFetch", "WebSearch"): + return (tool_input.get("url", "") or tool_input.get("query", ""))[:80] + elif "send_file" in name or "send_image" in name: + path = tool_input.get("file_path", "") + return path.rsplit("/", 1)[-1] if "/" in path else path + else: + # Generic: show first meaningful value + for v in tool_input.values(): + if isinstance(v, str) and v: + return v[:80] + return "" + + def _clean_result(text: str) -> str: + """Clean tool result for display.""" + if not text: + return "" + # Skip binary/base64 data + if "base64" in text or len(text) > 1000: + if "base64" in text: + return "[imagen/datos binarios]" + return text[:200] + "..." + # Clean up common noise + text = text.strip() + if text.startswith("[{'type':") or text.startswith("{'type':"): + return "[datos estructurados]" + return text[:300] async def _on_stream(update_obj: StreamUpdate) -> None: - # Intercept send_image_to_user MCP tool calls. - # The SDK namespaces MCP tools as "mcp____", - # so match both the bare name and the namespaced variant. + # Intercept send_file_to_user / send_image_to_user MCP tool calls if update_obj.tool_calls and need_mcp_intercept: for tc in update_obj.tool_calls: tc_name = tc.get("name", "") - if tc_name == "send_image_to_user" or tc_name.endswith( - "__send_image_to_user" + if tc_name in ("send_file_to_user", "send_image_to_user") or tc_name.endswith( + ("__send_file_to_user", "__send_image_to_user") ): tc_input = tc.get("input", {}) file_path = tc_input.get("file_path", "") caption = tc_input.get("caption", "") - img = validate_image_path( + attachment = validate_file_path( file_path, approved_directory, caption ) - if img: - mcp_images.append(img) + if attachment: + mcp_images.append(attachment) - # Capture tool calls + # Send per-event messages for tool calls if update_obj.tool_calls: - for tc in update_obj.tool_calls: - name = tc.get("name", "unknown") - detail = self._summarize_tool_input(name, tc.get("input", {})) - if verbose_level >= 1: - tool_log.append( - {"kind": "tool", "name": name, "detail": detail} - ) - if draft_streamer: - icon = _tool_icon(name) - line = ( - f"{icon} {name}: {detail}" if detail else f"{icon} {name}" - ) - await draft_streamer.append_tool(line) - - # Capture assistant text (reasoning / commentary) + has_used_tools[0] = True + # Extended thinking (ThinkingBlocks — Claude's internal reasoning) + if update_obj.type == "thinking" and update_obj.content: + thinking = update_obj.content.strip() + if thinking and verbose_level >= 1: + # Show first line of thinking as a 🧠 message + first_line = thinking.split("\n", 1)[0].strip()[:200] + if first_line: + await _send_tool_msg(f"🧠 {first_line}") + tool_log.append({"kind": "text", "detail": f"🧠 {first_line}"}) + + # Assistant text (visible reasoning / commentary) + # Process BEFORE tool calls so 💬 appears before 💻/✏️ + # Only show 💬 when tools have been used (intermediate thinking). + # For pure text responses, skip — the final formatted message + # will show the same text. if update_obj.type == "assistant" and update_obj.content: text = update_obj.content.strip() - if text: - first_line = text.split("\n", 1)[0].strip() + if text and "[ThinkingBlock(" in text: + text = "" + if text and verbose_level >= 1 and has_used_tools[0]: + first_line = text.split("\n", 1)[0].strip()[:200] if first_line: - if verbose_level >= 1: - tool_log.append( - {"kind": "text", "detail": first_line[:120]} - ) + await _send_tool_msg(f"💬 {first_line}") + tool_log.append({"kind": "text", "detail": first_line}) + # Reset draft so it only shows NEW text going forward if draft_streamer: - await draft_streamer.append_tool( - f"\U0001f4ac {first_line[:120]}" - ) + draft_streamer.reset_text() - # Stream text to user via draft (prefer token deltas; - # skip full assistant messages to avoid double-appending) + # Tool calls — send after assistant text so 💬 appears first + if update_obj.tool_calls and verbose_level >= 1: + for tc in update_obj.tool_calls: + name = tc.get("name", "unknown") + tool_input = tc.get("input", {}) + icon = _tool_icon(name) + detail = _format_tool_detail(name, tool_input) + + if verbose_level >= 3 and name == "Bash": + # Level 3: show full command + cmd = tool_input.get("command", "")[:400] + msg_text = f"{icon} {cmd}" + elif detail: + msg_text = f"{icon} {name}: {detail}" + else: + msg_text = f"{icon} {name}" + + msg_id = await _send_tool_msg(msg_text) + last_tool_msg_id[0] = msg_id + + # Also log for reference + tool_log.append({"kind": "tool", "name": name, "detail": detail}) + + # Tool results — edit the last tool message to append result + if verbose_level >= 3 and update_obj.type == "tool_result": + result_text = _clean_result(str(getattr(update_obj, "content", "") or "")) + if result_text and last_tool_msg_id[0]: + # Read current message and append result + tool_log.append({"kind": "result", "detail": result_text}) + + # Stream response text to user via draft (live typing preview). + # The draft is temporary (vanishes when next real message arrives) + # but the persistent 💬 and final messages capture everything. if draft_streamer and update_obj.content: if update_obj.type == "stream_delta": await draft_streamer.append_text(update_obj.content) - # Throttle progress message edits to avoid Telegram rate limits - if not draft_streamer and verbose_level >= 1: + # Throttle progress message edits to avoid Telegram rate limits. + # Update "Working..." with elapsed time counter. + if verbose_level >= 1: now = time.time() - if (now - last_edit_time[0]) >= 2.0 and tool_log: + if (now - last_edit_time[0]) >= 3.0: last_edit_time[0] = now - new_text = self._format_verbose_progress( - tool_log, verbose_level, start_time - ) + elapsed = int(now - start_time) + new_text = f"⏳ Working... ({elapsed}s)" try: await progress_msg.edit_text(new_text) except Exception: @@ -771,6 +925,47 @@ async def _on_stream(update_obj: StreamUpdate) -> None: return _on_stream + async def _send_formatted_message( + self, + update: Update, + text: str, + parse_mode: str = "HTML", + reply_to_message_id: Optional[int] = None, + ) -> None: + """Send a formatted message with HTML fallback to plain text. + + If Telegram rejects the HTML, strips tags and retries as plain text. + """ + try: + await update.message.reply_text( + text, + parse_mode=parse_mode, + reply_markup=None, + reply_to_message_id=reply_to_message_id, + ) + except Exception as e: + logger.warning( + "HTML send failed, falling back to plain text", + error=str(e), + html_preview=text[:200], + ) + # Strip HTML tags for plain text fallback + plain = re.sub(r"<[^>]+>", "", text) + # Also unescape HTML entities + plain = plain.replace("&", "&").replace("<", "<").replace(">", ">") + try: + await update.message.reply_text( + plain, + reply_markup=None, + reply_to_message_id=reply_to_message_id, + ) + except Exception as plain_err: + await update.message.reply_text( + f"Failed to deliver response " + f"(error: {str(plain_err)[:150]}). Please try again.", + reply_to_message_id=reply_to_message_id, + ) + async def _send_images( self, update: Update, @@ -873,6 +1068,39 @@ async def agentic_text( message_length=len(message_text), ) + # If Claude is currently processing, interrupt and send follow-up + running_task = context.user_data.get("running_claude_task") + if running_task and not running_task.done(): + claude_integration = context.bot_data.get("claude_integration") + if claude_integration: + sdk_manager = getattr(claude_integration, "sdk_manager", None) + if sdk_manager and sdk_manager.is_processing: + logger.info( + "Follow-up message during processing, interrupting", + user_id=user_id, + ) + await update.message.reply_text( + f"📨 Interrupting... {message_text[:80]}" + ) + # 1. Send interrupt signal (like Ctrl+C) + await sdk_manager.interrupt() + # 2. Give it 3 seconds to stop gracefully + try: + await asyncio.wait_for( + asyncio.shield(running_task), timeout=3.0 + ) + except (asyncio.TimeoutError, asyncio.CancelledError, Exception): + # 3. Didn't stop — forcefully kill the subprocess + logger.info("Interrupt didn't stop in time, aborting") + await sdk_manager.abort() + running_task.cancel() + try: + await running_task + except (asyncio.CancelledError, Exception): + pass + context.user_data["running_claude_task"] = None + # Fall through to process this message as a continuation + # Rate limit check rate_limiter = context.bot_data.get("rate_limiter") if rate_limiter: @@ -908,15 +1136,16 @@ async def agentic_text( start_time = time.time() mcp_images: List[ImageAttachment] = [] - # Stream drafts (private chats only) + # Stream drafts (private chats use sendMessageDraft, groups fall back to editMessageText) draft_streamer: Optional[DraftStreamer] = None - if self.settings.enable_stream_drafts and chat.type == "private": + if self.settings.enable_stream_drafts: draft_streamer = DraftStreamer( bot=context.bot, chat_id=chat.id, draft_id=generate_draft_id(), message_thread_id=update.message.message_thread_id, throttle_interval=self.settings.stream_draft_interval, + is_private_chat=(chat.type == "private"), ) on_stream = self._make_stream_callback( @@ -927,14 +1156,16 @@ async def agentic_text( mcp_images=mcp_images, approved_directory=self.settings.approved_directory, draft_streamer=draft_streamer, + chat=chat, + reply_to_message_id=update.message.message_id, ) # Independent typing heartbeat — stays alive even with no stream events heartbeat = self._start_typing_heartbeat(chat) - success = True - try: - claude_response = await claude_integration.run_command( + # Track the running task so /stop can cancel it + run_task = asyncio.ensure_future( + claude_integration.run_command( prompt=message_text, working_directory=current_dir, user_id=user_id, @@ -942,6 +1173,12 @@ async def agentic_text( on_stream=on_stream, force_new=force_new, ) + ) + context.user_data["running_claude_task"] = run_task + + success = True + try: + claude_response = await run_task # New session created successfully — clear the one-shot flag if force_new: @@ -978,6 +1215,15 @@ async def agentic_text( claude_response.content ) + except asyncio.CancelledError: + success = False + logger.info("Claude task cancelled by user", user_id=user_id) + from .utils.formatting import FormattedMessage + + formatted_messages = [ + FormattedMessage("⛔ Task stopped.", parse_mode=None) + ] + except Exception as e: success = False logger.error("Claude integration failed", error=str(e), user_id=user_id) @@ -989,18 +1235,12 @@ async def agentic_text( ] finally: heartbeat.cancel() - if draft_streamer: - try: - await draft_streamer.flush() - except Exception: - logger.debug("Draft flush failed in finally block", user_id=user_id) + context.user_data["running_claude_task"] = None - try: - await progress_msg.delete() - except Exception: - logger.debug("Failed to delete progress message, ignoring") + # Keep progress messages visible (don't delete) + pass - # Use MCP-collected images (from send_image_to_user tool calls) + # Use MCP-collected files (from send_file_to_user tool calls) images: List[ImageAttachment] = mcp_images # Try to combine text + images in one message when possible @@ -1019,45 +1259,18 @@ async def agentic_text( except Exception as img_err: logger.warning("Image+caption send failed", error=str(img_err)) - # Send text messages (skip if caption was already embedded in photos) + # Send response (draft bubble disappears naturally when real message arrives) if not caption_sent: for i, message in enumerate(formatted_messages): if not message.text or not message.text.strip(): continue - try: - await update.message.reply_text( - message.text, - parse_mode=message.parse_mode, - reply_markup=None, # No keyboards in agentic mode - reply_to_message_id=( - update.message.message_id if i == 0 else None - ), - ) - if i < len(formatted_messages) - 1: - await asyncio.sleep(0.5) - except Exception as send_err: - logger.warning( - "Failed to send HTML response, retrying as plain text", - error=str(send_err), - message_index=i, - ) - try: - await update.message.reply_text( - message.text, - reply_markup=None, - reply_to_message_id=( - update.message.message_id if i == 0 else None - ), - ) - except Exception as plain_err: - await update.message.reply_text( - f"Failed to deliver response " - f"(Telegram error: {str(plain_err)[:150]}). " - f"Please try again.", - reply_to_message_id=( - update.message.message_id if i == 0 else None - ), - ) + await self._send_formatted_message( + update, + message.text, + parse_mode=message.parse_mode, + ) + if i < len(formatted_messages) - 1: + await asyncio.sleep(0.5) # Send images separately if caption wasn't used if images: @@ -1070,6 +1283,22 @@ async def agentic_text( except Exception as img_err: logger.warning("Image send failed", error=str(img_err)) + # Update progress message with final elapsed time + elapsed = int(time.time() - start_time) + try: + await progress_msg.edit_text(f"⏱ {elapsed}s") + except Exception: + pass + + # Save tool message IDs for /cleanup command + all_tool_msg_ids: List[int] = [] + for entry in tool_log: + ids = entry.get("_tool_message_ids") + if ids: + all_tool_msg_ids.extend(ids) + context.user_data["last_tool_message_ids"] = all_tool_msg_ids + context.user_data["last_tool_chat_id"] = chat.id + # Audit log audit_logger = context.bot_data.get("audit_logger") if audit_logger: @@ -1083,69 +1312,70 @@ async def agentic_text( async def agentic_document( self, update: Update, context: ContextTypes.DEFAULT_TYPE ) -> None: - """Process file upload -> Claude, minimal chrome.""" + """Process file upload -> Claude, minimal chrome. + + Downloads the file to /tmp/telegram-uploads/ and passes the local + path to Claude so it can read the file during the session. The file + is intentionally kept on disk after the response. + """ user_id = update.effective_user.id - document = update.message.document + msg = update.message + + # Unified handling: document, audio, or video attachments + document = msg.document or msg.audio or msg.video + if not document: + await msg.reply_text("No file detected in this message.") + return + + file_name_raw = getattr(document, "file_name", None) logger.info( "Agentic document upload", user_id=user_id, - filename=document.file_name, + filename=file_name_raw, ) # Security validation security_validator = context.bot_data.get("security_validator") - if security_validator: - valid, error = security_validator.validate_filename(document.file_name) + if security_validator and file_name_raw: + valid, error = security_validator.validate_filename(file_name_raw) if not valid: - await update.message.reply_text(f"File rejected: {error}") + await msg.reply_text(f"File rejected: {error}") return # Size check max_size = 10 * 1024 * 1024 - if document.file_size > max_size: - await update.message.reply_text( + if document.file_size and document.file_size > max_size: + await msg.reply_text( f"File too large ({document.file_size / 1024 / 1024:.1f}MB). Max: 10MB." ) return - chat = update.message.chat + chat = msg.chat await chat.send_action("typing") - progress_msg = await update.message.reply_text("Working...") + progress_msg = await msg.reply_text("Working...") - # Try enhanced file handler, fall back to basic - features = context.bot_data.get("features") - file_handler = features.get_file_handler() if features else None - prompt: Optional[str] = None + # Download file to /tmp/telegram-uploads/ so Claude can read it + upload_dir = Path("/tmp/telegram-uploads") + upload_dir.mkdir(parents=True, exist_ok=True) - if file_handler: - try: - processed_file = await file_handler.handle_document_upload( - document, - user_id, - update.message.caption or "Please review this file:", - ) - prompt = processed_file.prompt - except Exception: - file_handler = None + file_name = file_name_raw or f"file_{document.file_unique_id}" + dest_path = upload_dir / file_name - if not file_handler: - file = await document.get_file() - file_bytes = await file.download_as_bytearray() - try: - content = file_bytes.decode("utf-8") - if len(content) > 50000: - content = content[:50000] + "\n... (truncated)" - caption = update.message.caption or "Please review this file:" - prompt = ( - f"{caption}\n\n**File:** `{document.file_name}`\n\n" - f"```\n{content}\n```" - ) - except UnicodeDecodeError: - await progress_msg.edit_text( - "Unsupported file format. Must be text-based (UTF-8)." - ) - return + # Avoid collisions by appending a suffix + if dest_path.exists(): + stem = dest_path.stem + suffix = dest_path.suffix + counter = 1 + while dest_path.exists(): + dest_path = upload_dir / f"{stem}_{counter}{suffix}" + counter += 1 + + tg_file = await document.get_file() + await tg_file.download_to_drive(str(dest_path)) + + caption = update.message.caption or "El usuario envio este archivo" + prompt = f"[Archivo adjunto: {dest_path}]\n\n{caption}" # Process with Claude claude_integration = context.bot_data.get("claude_integration") @@ -1167,6 +1397,19 @@ async def agentic_document( verbose_level = self._get_verbose_level(context) tool_log: List[Dict[str, Any]] = [] mcp_images_doc: List[ImageAttachment] = [] + + # Stream drafts for document handler too + draft_streamer_doc: Optional[DraftStreamer] = None + if self.settings.enable_stream_drafts: + draft_streamer_doc = DraftStreamer( + bot=context.bot, + chat_id=chat.id, + draft_id=generate_draft_id(), + message_thread_id=msg.message_thread_id, + throttle_interval=self.settings.stream_draft_interval, + is_private_chat=(chat.type == "private"), + ) + on_stream = self._make_stream_callback( verbose_level, progress_msg, @@ -1174,11 +1417,16 @@ async def agentic_document( time.time(), mcp_images=mcp_images_doc, approved_directory=self.settings.approved_directory, + draft_streamer=draft_streamer_doc, + chat=chat, + reply_to_message_id=update.message.message_id, ) heartbeat = self._start_typing_heartbeat(chat) - try: - claude_response = await claude_integration.run_command( + + # Track the running task so follow-up messages can interrupt it + run_task = asyncio.ensure_future( + claude_integration.run_command( prompt=prompt, working_directory=current_dir, user_id=user_id, @@ -1186,6 +1434,11 @@ async def agentic_document( on_stream=on_stream, force_new=force_new, ) + ) + context.user_data["running_claude_task"] = run_task + + try: + claude_response = await run_task if force_new: context.user_data["force_new_session"] = False @@ -1210,7 +1463,7 @@ async def agentic_document( except Exception: logger.debug("Failed to delete progress message, ignoring") - # Use MCP-collected images (from send_image_to_user tool calls) + # Use MCP-collected files (from send_file_to_user tool calls) images: List[ImageAttachment] = mcp_images_doc caption_sent = False @@ -1230,10 +1483,10 @@ async def agentic_document( if not caption_sent: for i, message in enumerate(formatted_messages): - await update.message.reply_text( + await self._send_formatted_message( + update, message.text, parse_mode=message.parse_mode, - reply_markup=None, reply_to_message_id=( update.message.message_id if i == 0 else None ), @@ -1258,6 +1511,7 @@ async def agentic_document( logger.error("Claude file processing failed", error=str(e), user_id=user_id) finally: heartbeat.cancel() + context.user_data["running_claude_task"] = None async def agentic_photo( self, update: Update, context: ContextTypes.DEFAULT_TYPE @@ -1277,14 +1531,22 @@ async def agentic_photo( progress_msg = await update.message.reply_text("Working...") try: + import os photo = update.message.photo[-1] - processed_image = await image_handler.process_image( - photo, update.message.caption - ) + # Download photo to disk so Claude can read it + file = await photo.get_file() + os.makedirs("/tmp/telegram-uploads", exist_ok=True) + timestamp = int(update.message.date.timestamp() * 1000) if update.message.date else 0 + photo_path = f"/tmp/telegram-uploads/photo_{timestamp}.jpg" + await file.download_to_drive(photo_path) + + caption = update.message.caption or "" + prompt = f"[Foto: {photo_path}]\n\n{caption}" if caption else f"[Foto: {photo_path}]" + await self._handle_agentic_media_message( update=update, context=context, - prompt=processed_image.prompt, + prompt=prompt, progress_msg=progress_msg, user_id=user_id, chat=chat, @@ -1321,7 +1583,14 @@ async def agentic_voice( voice, update.message.caption ) - await progress_msg.edit_text("Working...") + # Show transcription to user + transcript_display = processed_voice.transcription + if len(transcript_display) > 4000: + transcript_display = transcript_display[:4000] + "…" + await progress_msg.edit_text(f'🎤 "{transcript_display}"') + + # Send a new progress message for Claude's response + progress_msg = await update.message.reply_text("Working...") await self._handle_agentic_media_message( update=update, context=context, @@ -1366,6 +1635,19 @@ async def _handle_agentic_media_message( verbose_level = self._get_verbose_level(context) tool_log: List[Dict[str, Any]] = [] mcp_images_media: List[ImageAttachment] = [] + + # Stream drafts for media handler + draft_streamer_media: Optional[DraftStreamer] = None + if self.settings.enable_stream_drafts: + draft_streamer_media = DraftStreamer( + bot=context.bot, + chat_id=chat.id, + draft_id=generate_draft_id(), + message_thread_id=getattr(update.message, "message_thread_id", None), + throttle_interval=self.settings.stream_draft_interval, + is_private_chat=(chat.type == "private"), + ) + on_stream = self._make_stream_callback( verbose_level, progress_msg, @@ -1373,11 +1655,16 @@ async def _handle_agentic_media_message( time.time(), mcp_images=mcp_images_media, approved_directory=self.settings.approved_directory, + draft_streamer=draft_streamer_media, + chat=chat, + reply_to_message_id=update.message.message_id, ) heartbeat = self._start_typing_heartbeat(chat) - try: - claude_response = await claude_integration.run_command( + + # Track the running task so follow-up messages can interrupt it + run_task = asyncio.ensure_future( + claude_integration.run_command( prompt=prompt, working_directory=current_dir, user_id=user_id, @@ -1385,8 +1672,14 @@ async def _handle_agentic_media_message( on_stream=on_stream, force_new=force_new, ) + ) + context.user_data["running_claude_task"] = run_task + + try: + claude_response = await run_task finally: heartbeat.cancel() + context.user_data["running_claude_task"] = None if force_new: context.user_data["force_new_session"] = False @@ -1404,12 +1697,10 @@ async def _handle_agentic_media_message( formatter = ResponseFormatter(self.settings) formatted_messages = formatter.format_claude_response(claude_response.content) - try: - await progress_msg.delete() - except Exception: - logger.debug("Failed to delete progress message, ignoring") + # Keep progress messages visible (don't delete) + pass - # Use MCP-collected images (from send_image_to_user tool calls). + # Use MCP-collected files (from send_file_to_user tool calls). images: List[ImageAttachment] = mcp_images_media caption_sent = False @@ -1431,10 +1722,10 @@ async def _handle_agentic_media_message( for i, message in enumerate(formatted_messages): if not message.text or not message.text.strip(): continue - await update.message.reply_text( + await self._send_formatted_message( + update, message.text, parse_mode=message.parse_mode, - reply_markup=None, reply_to_message_id=(update.message.message_id if i == 0 else None), ) if i < len(formatted_messages) - 1: diff --git a/src/bot/utils/draft_streamer.py b/src/bot/utils/draft_streamer.py index 4fe1e709..9ca02f18 100644 --- a/src/bot/utils/draft_streamer.py +++ b/src/bot/utils/draft_streamer.py @@ -1,4 +1,9 @@ -"""Stream partial responses to Telegram via sendMessageDraft.""" +"""Stream partial responses to Telegram via sendMessageDraft. + +Uses Telegram Bot API 9.3+ sendMessageDraft for smooth token-by-token +streaming in private chats. Falls back to editMessageText for group chats +where sendMessageDraft is unavailable. +""" import secrets import time @@ -14,6 +19,17 @@ # Max tool lines shown in the draft header _MAX_TOOL_LINES = 10 +# Minimum characters before sending the first draft (avoids triggering +# push notifications with just a few characters) +_MIN_INITIAL_CHARS = 20 + +# Error messages that indicate the draft transport is unavailable +_DRAFT_UNAVAILABLE_ERRORS = frozenset({ + "TEXTDRAFT_PEER_INVALID", + "Bad Request: draft can't be sent", + "Bad Request: peer doesn't support drafts", +}) + def generate_draft_id() -> int: """Generate a non-zero positive draft ID. @@ -30,18 +46,21 @@ class DraftStreamer: The draft is composed of two sections: 1. **Tool header** — compact lines showing tool calls and reasoning - snippets as they arrive, e.g. ``"📖 Read | 🔍 Grep | 🐚 Bash"``. + snippets as they arrive. 2. **Response body** — the actual assistant response text, streamed token-by-token. Both sections are combined into a single draft message and sent via - ``sendMessageDraft``. + ``sendMessageDraft`` (private chats) or ``editMessageText`` (groups). - Key design decisions: + Key design decisions (inspired by OpenClaw): - Plain text drafts (no parse_mode) to avoid partial HTML/markdown errors. - - Tail-truncation for messages >4096 chars: shows ``"\\u2026" + last 4093 chars``. - - Self-disabling: any API error silently disables the streamer so the - request continues with normal (non-streaming) delivery. + - Tail-truncation for messages >4096 chars. + - Min initial chars: waits for ~20 chars before first send. + - Anti-regressive: skips updates where text got shorter. + - Error classification: distinguishes draft-unavailable (fall back to edit) + from other errors (disable entirely). + - Self-disabling: persistent errors silently disable the streamer. """ def __init__( @@ -50,7 +69,8 @@ def __init__( chat_id: int, draft_id: int, message_thread_id: Optional[int] = None, - throttle_interval: float = 0.3, + throttle_interval: float = 0.4, + is_private_chat: bool = True, ) -> None: self.bot = bot self.chat_id = chat_id @@ -61,7 +81,18 @@ def __init__( self._tool_lines: List[str] = [] self._accumulated_text = "" self._last_send_time = 0.0 + self._last_sent_length = 0 # anti-regressive tracking self._enabled = True + self._error_count = 0 + self._max_errors = 3 + + # Transport mode: "draft" for private chats, "edit" for groups + self._use_draft = is_private_chat + self._edit_message_id: Optional[int] = None # for edit-based transport + + @property + def enabled(self) -> bool: + return self._enabled async def append_tool(self, line: str) -> None: """Append a tool activity line and send a draft if throttled.""" @@ -81,16 +112,24 @@ async def append_text(self, text: str) -> None: if (now - self._last_send_time) >= self.throttle_interval: await self._send_draft() + def reset_text(self) -> None: + """Reset accumulated text (call after text was shown via a 💬 message).""" + self._accumulated_text = "" + async def flush(self) -> None: """Force-send the current accumulated text as a draft.""" if not self._enabled: return if not self._accumulated_text and not self._tool_lines: return - await self._send_draft() + await self._send_draft(force=True) + + def _compose_draft(self, is_final: bool = False) -> str: + """Combine tool header and response body into a single draft. - def _compose_draft(self) -> str: - """Combine tool header and response body into a single draft.""" + Appends a blinking cursor ▌ during streaming (like OpenClaw) + to indicate the response is still being generated. + """ parts: List[str] = [] if self._tool_lines: @@ -103,33 +142,157 @@ def _compose_draft(self) -> str: if self._accumulated_text: if parts: parts.append("") # blank separator line - parts.append(self._accumulated_text) + text = self._accumulated_text + if not is_final: + text += " ▌" + parts.append(text) return "\n".join(parts) - async def _send_draft(self) -> None: - """Send the composed draft (tools + text) as a message draft.""" + async def _send_draft(self, force: bool = False) -> None: + """Send the composed draft via the appropriate transport.""" draft_text = self._compose_draft() if not draft_text.strip(): return + # Min initial chars gate (skip if force-flushing) + if not force and self._last_sent_length == 0: + if len(self._accumulated_text) < _MIN_INITIAL_CHARS and not self._tool_lines: + return + + # Anti-regressive: skip if text got shorter (can happen with + # tool header rotation) + current_len = len(draft_text) + if not force and current_len < self._last_sent_length: + return + # Tail-truncate if over Telegram limit if len(draft_text) > TELEGRAM_MAX_MESSAGE_LENGTH: - draft_text = "\u2026" + draft_text[-(TELEGRAM_MAX_MESSAGE_LENGTH - 1) :] + draft_text = "\u2026" + draft_text[-(TELEGRAM_MAX_MESSAGE_LENGTH - 1):] try: + if self._use_draft: + await self._send_via_draft(draft_text) + else: + await self._send_via_edit(draft_text) + self._last_send_time = time.time() + self._last_sent_length = current_len + self._error_count = 0 # reset on success + except telegram.error.BadRequest as e: + error_str = str(e) + if any(err in error_str for err in _DRAFT_UNAVAILABLE_ERRORS): + # Draft transport unavailable — fall back to edit + logger.info( + "Draft transport unavailable, falling back to edit", + chat_id=self.chat_id, + error=error_str, + ) + self._use_draft = False + # Retry immediately with edit transport + try: + await self._send_via_edit(draft_text) + self._last_send_time = time.time() + self._last_sent_length = current_len + except Exception: + self._handle_error() + elif "Message is not modified" in error_str: + # Same content — not an error, just skip + self._last_send_time = time.time() + elif "Message to edit not found" in error_str: + # Message was deleted — re-create + self._edit_message_id = None + try: + await self._send_via_edit(draft_text) + self._last_send_time = time.time() + self._last_sent_length = current_len + except Exception: + self._handle_error() + else: + self._handle_error() + except Exception: + self._handle_error() + + def _handle_error(self) -> None: + """Track errors and disable after too many.""" + self._error_count += 1 + if self._error_count >= self._max_errors: + logger.debug( + "Draft streamer disabled after repeated errors", + chat_id=self.chat_id, + error_count=self._error_count, + ) + self._enabled = False + + async def _send_via_draft(self, text: str) -> None: + """Send via sendMessageDraft (private chats).""" + kwargs = { + "chat_id": self.chat_id, + "text": text, + "draft_id": self.draft_id, + } + if self.message_thread_id is not None: + kwargs["message_thread_id"] = self.message_thread_id + logger.debug( + "Sending draft", + transport="draft", + text_len=len(text), + preview=text[:80], + ) + await self.bot.send_message_draft(**kwargs) + + async def _send_via_edit(self, text: str) -> None: + """Send via editMessageText (group chat fallback). + + Creates a message on first call, then edits it on subsequent calls. + """ + if self._edit_message_id is None: + # Send initial message kwargs = { "chat_id": self.chat_id, - "text": draft_text, - "draft_id": self.draft_id, + "text": text, } if self.message_thread_id is not None: kwargs["message_thread_id"] = self.message_thread_id - await self.bot.send_message_draft(**kwargs) - self._last_send_time = time.time() - except Exception: - logger.debug( - "Draft send failed, disabling streamer", + msg = await self.bot.send_message(**kwargs) + self._edit_message_id = msg.message_id + else: + await self.bot.edit_message_text( + text, chat_id=self.chat_id, + message_id=self._edit_message_id, ) - self._enabled = False + + async def clear(self) -> None: + """Clear the draft bubble by sending an empty draft. + + Call this before sending the final response message so the draft + bubble disappears cleanly instead of overlapping with the real message. + """ + if not self._enabled: + return + try: + if self._use_draft: + # Send empty draft to dismiss the typing bubble + await self.bot.send_message_draft( + chat_id=self.chat_id, + text="", + draft_id=self.draft_id, + ) + elif self._edit_message_id is not None: + # For edit-based transport, delete the preview message + try: + await self.bot.delete_message( + chat_id=self.chat_id, + message_id=self._edit_message_id, + ) + except Exception: + pass + self._edit_message_id = None + except Exception: + pass + self._enabled = False + + @property + def edit_message_id(self) -> Optional[int]: + """Return the message ID used by edit transport (for cleanup).""" + return self._edit_message_id diff --git a/src/bot/utils/html_format.py b/src/bot/utils/html_format.py index 2799a4ee..b84bbd58 100644 --- a/src/bot/utils/html_format.py +++ b/src/bot/utils/html_format.py @@ -1,40 +1,99 @@ """HTML formatting utilities for Telegram messages. -Telegram's HTML mode only requires escaping 3 characters (<, >, &) vs the many -ambiguous Markdown v1 metacharacters, making it far more robust for rendering -Claude's output which contains underscores, asterisks, brackets, etc. +Telegram's HTML mode supports: , , , , ,
,
+
, , 
, +
, . + +This module converts Claude's markdown output into that subset. """ import re from typing import List, Tuple -def escape_html(text: str) -> str: - """Escape the 3 HTML-special characters for Telegram. +_INLINE_TAGS = {"b", "i", "s", "u", "code"} +_TAG_RE = re.compile(r"<(/?)(\w+)(?:\s[^>]*)?>") + + +def _repair_html_nesting(html: str) -> str: + """Fix misnested inline HTML tags that Telegram would reject. - This replaces all 3 _escape_markdown functions previously scattered - across the codebase. + Telegram requires strict nesting: ... is OK, + but ... is rejected. This walks the tag stack + and closes/reopens tags when it detects a mismatch. """ + result = [] + stack: List[str] = [] + last_end = 0 + + for m in _TAG_RE.finditer(html): + # Append text before this tag + result.append(html[last_end:m.start()]) + last_end = m.end() + + is_close = m.group(1) == "/" + tag = m.group(2).lower() + + # Only repair inline tags; skip
, 
, , etc. + if tag not in _INLINE_TAGS: + result.append(m.group(0)) + continue + + if not is_close: + stack.append(tag) + result.append(m.group(0)) + else: + if tag in stack: + # Close tags in reverse order up to the matching opener + idx = len(stack) - 1 - stack[::-1].index(tag) + tags_to_reopen = stack[idx + 1:] + # Close everything from top to idx + for t in reversed(stack[idx:]): + result.append(f"") + stack = stack[:idx] + # Reopen tags that were above the matched one + for t in tags_to_reopen: + result.append(f"<{t}>") + stack.append(t) + else: + # Orphan close tag — skip it + pass + + # Append remaining text + result.append(html[last_end:]) + + # Close any unclosed tags + for t in reversed(stack): + result.append(f"") + + return "".join(result) + + +def escape_html(text: str) -> str: + """Escape the 3 HTML-special characters for Telegram.""" return text.replace("&", "&").replace("<", "<").replace(">", ">") def markdown_to_telegram_html(text: str) -> str: """Convert Claude's markdown output to Telegram-compatible HTML. - Telegram supports a narrow HTML subset: , , ,
,
-    , , . This function converts common markdown patterns
-    to that subset while preserving code blocks verbatim.
-
-    Order of operations:
-    1. Extract fenced code blocks -> placeholders
-    2. Extract inline code -> placeholders
-    3. HTML-escape remaining text
-    4. Convert bold (**text** / __text__)
-    5. Convert italic (*text*, _text_ with word boundaries)
-    6. Convert links [text](url)
-    7. Convert headers (# Header -> Header)
-    8. Convert strikethrough (~~text~~)
-    9. Restore placeholders
+    Order of operations (early steps extract content into placeholders
+    to protect it from later regex passes):
+
+    0.  Markdown tables → aligned 
 blocks
+    1.  Fenced code blocks → 

+    2.  Inline code → 
+    3.  Blockquotes (> text) → 
+ 4. HTML-escape remaining text + 5. Horizontal rules (--- / ***) → ── separator + 6. Bold (**text** / __text__) + 7. Italic (*text* / _text_) + 8. Links [text](url) + 9. Headers (# Header → Header) + 10. Strikethrough (~~text~~) + 11. Unordered lists (- item / * item) + 12. Ordered lists (1. item) + 13. Restore placeholders """ placeholders: List[Tuple[str, str]] = [] placeholder_counter = 0 @@ -46,6 +105,52 @@ def _make_placeholder(html_content: str) -> str: placeholders.append((key, html_content)) return key + # --- 0. Extract markdown tables → monospace
 blocks ---
+    def _replace_table(m: re.Match) -> str:  # type: ignore[type-arg]
+        table_text = m.group(0)
+        lines = table_text.strip().split("\n")
+        rows = []
+        for line in lines:
+            stripped = line.strip()
+            if not stripped.startswith("|"):
+                continue
+            if re.match(r"^\|[\s\-:|]+\|$", stripped):
+                continue
+            cells = [c.strip() for c in stripped.split("|")[1:-1]]
+            if cells:
+                rows.append(cells)
+
+        if not rows:
+            return table_text
+
+        num_cols = max(len(r) for r in rows)
+        col_widths = [0] * num_cols
+        for row in rows:
+            for i, cell in enumerate(row):
+                if i < num_cols:
+                    col_widths[i] = max(col_widths[i], len(cell))
+
+        formatted_lines = []
+        for row in rows:
+            parts = []
+            for i in range(num_cols):
+                cell = row[i] if i < len(row) else ""
+                parts.append(cell.ljust(col_widths[i]))
+            formatted_lines.append(" │ ".join(parts))
+            if len(formatted_lines) == 1:
+                sep_parts = ["─" * w for w in col_widths]
+                formatted_lines.append("─┼─".join(sep_parts))
+
+        pre_content = "\n".join(formatted_lines)
+        return _make_placeholder(f"
{escape_html(pre_content)}
") + + text = re.sub( + r"(?:^\|.+\|$\n?){2,}", + _replace_table, + text, + flags=re.MULTILINE, + ) + # --- 1. Extract fenced code blocks --- def _replace_fenced(m: re.Match) -> str: # type: ignore[type-arg] lang = m.group(1) or "" @@ -72,33 +177,72 @@ def _replace_inline_code(m: re.Match) -> str: # type: ignore[type-arg] text = re.sub(r"`([^`\n]+)`", _replace_inline_code, text) - # --- 3. HTML-escape remaining text --- + # --- 3. Blockquotes: > text →
--- + def _replace_blockquote(m: re.Match) -> str: # type: ignore[type-arg] + block = m.group(0) + # Strip the leading > (and optional space) from each line + lines = [] + for line in block.split("\n"): + stripped = re.sub(r"^>\s?", "", line) + lines.append(stripped) + inner = "\n".join(lines) + # Recursively format the blockquote content + inner_html = escape_html(inner) + return _make_placeholder(f"
{inner_html}
") + + text = re.sub( + r"(?:^>.*$\n?)+", + _replace_blockquote, + text, + flags=re.MULTILINE, + ) + + # --- 4. HTML-escape remaining text --- text = escape_html(text) - # --- 4. Bold: **text** or __text__ --- + # --- 5. Horizontal rules: --- or *** or ___ → visual separator --- + text = re.sub( + r"^(?:---+|\*\*\*+|___+)\s*$", + "──────────", + text, + flags=re.MULTILINE, + ) + + # --- 6. Bold: **text** or __text__ --- text = re.sub(r"\*\*(.+?)\*\*", r"\1", text) text = re.sub(r"__(.+?)__", r"\1", text) - # --- 5. Italic: *text* (require non-space after/before) --- + # --- 7. Italic: *text* (require non-space after/before) --- text = re.sub(r"\*(\S.*?\S|\S)\*", r"\1", text) - # _text_ only at word boundaries (avoid my_var_name) text = re.sub(r"(?\1
", text) - # --- 6. Links: [text](url) --- + # --- 8. Links: [text](url) --- text = re.sub( r"\[([^\]]+)\]\(([^)]+)\)", r'
\1', text, ) - # --- 7. Headers: # Header -> Header --- + # --- 9. Headers: # Header → Header --- text = re.sub(r"^#{1,6}\s+(.+)$", r"\1", text, flags=re.MULTILINE) - # --- 8. Strikethrough: ~~text~~ --- + # --- 10. Strikethrough: ~~text~~ --- text = re.sub(r"~~(.+?)~~", r"\1", text) - # --- 9. Restore placeholders --- + # --- 11. Unordered lists: - item / * item → bullet --- + text = re.sub(r"^[\-\*]\s+", "• ", text, flags=re.MULTILINE) + + # --- 12. Ordered lists: 1. item → keep number with period --- + # (Telegram has no
    , so just clean up the formatting) + text = re.sub(r"^(\d+)\.\s+", r"\1. ", text, flags=re.MULTILINE) + + # --- 13. Restore placeholders --- for key, html_content in placeholders: text = text.replace(key, html_content) + # --- 14. Repair HTML tag nesting --- + # Telegram is strict about nesting: ... is OK, + # but ... is rejected. Fix any mismatches. + text = _repair_html_nesting(text) + return text diff --git a/src/bot/utils/image_extractor.py b/src/bot/utils/image_extractor.py index 403097c5..ba6cb03b 100644 --- a/src/bot/utils/image_extractor.py +++ b/src/bot/utils/image_extractor.py @@ -1,10 +1,14 @@ -"""Validate image file paths and prepare them for Telegram delivery. +"""Validate file paths and prepare them for Telegram delivery. -Used by the MCP ``send_image_to_user`` tool intercept — the stream callback -validates each path via :func:`validate_image_path` and collects -:class:`ImageAttachment` objects for later Telegram delivery. +Used by the MCP ``send_file_to_user`` tool intercept — the stream callback +validates each path via :func:`validate_file_path` and collects +:class:`FileAttachment` objects for later Telegram delivery. + +Backwards-compatible aliases (:class:`ImageAttachment`, +:func:`validate_image_path`) are kept so existing code continues to work. """ +import mimetypes from dataclasses import dataclass from pathlib import Path from typing import Optional @@ -28,29 +32,37 @@ TELEGRAM_PHOTO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"} # Safety caps -MAX_IMAGES_PER_RESPONSE = 10 +MAX_FILES_PER_RESPONSE = 10 MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB PHOTO_SIZE_LIMIT = 10 * 1024 * 1024 # 10 MB — Telegram photo API limit +# Backwards-compat alias +MAX_IMAGES_PER_RESPONSE = MAX_FILES_PER_RESPONSE + @dataclass -class ImageAttachment: - """An image file to attach to a Telegram response.""" +class FileAttachment: + """A file to attach to a Telegram response.""" path: Path mime_type: str original_reference: str -def validate_image_path( +# Backwards-compat alias +ImageAttachment = FileAttachment + + +def validate_file_path( file_path: str, approved_directory: Path, caption: str = "", -) -> Optional[ImageAttachment]: - """Validate a single image path from an MCP ``send_image_to_user`` call. +) -> Optional[FileAttachment]: + """Validate a file path from an MCP ``send_file_to_user`` call. - Returns an :class:`ImageAttachment` if the path is a valid, existing image + Returns a :class:`FileAttachment` if the path is a valid, existing file inside *approved_directory*, or ``None`` otherwise. + Accepts **any** file type (images, PDFs, audio, etc.). """ try: path = Path(file_path) @@ -59,44 +71,51 @@ def validate_image_path( resolved = path.resolve() - # Security: must be within approved directory + # Security: must be within approved directory or /tmp + approved_resolved = approved_directory.resolve() + tmp_dir = Path("/tmp").resolve() try: - resolved.relative_to(approved_directory.resolve()) + resolved.relative_to(approved_resolved) except ValueError: - logger.debug( - "MCP image path outside approved directory", - path=str(resolved), - approved=str(approved_directory), - ) - return None + try: + resolved.relative_to(tmp_dir) + except ValueError: + logger.debug( + "MCP file path outside approved directory", + path=str(resolved), + approved=str(approved_resolved), + ) + return None if not resolved.is_file(): return None file_size = resolved.stat().st_size if file_size > MAX_FILE_SIZE_BYTES: - logger.debug("MCP image file too large", path=str(resolved), size=file_size) + logger.debug("MCP file too large", path=str(resolved), size=file_size) return None ext = resolved.suffix.lower() - mime_type = IMAGE_EXTENSIONS.get(ext) - if not mime_type: - return None + mime_type = IMAGE_EXTENSIONS.get(ext) or mimetypes.guess_type(str(resolved))[0] or "application/octet-stream" - return ImageAttachment( + return FileAttachment( path=resolved, mime_type=mime_type, original_reference=caption or file_path, ) except (OSError, ValueError) as e: - logger.debug("MCP image path validation failed", path=file_path, error=str(e)) + logger.debug("MCP file path validation failed", path=file_path, error=str(e)) return None +# Backwards-compat alias +validate_image_path = validate_file_path + + def should_send_as_photo(path: Path) -> bool: """Return True if the image should be sent via reply_photo(). - Raster images ≤ 10 MB are sent as photos (inline preview). + Raster images <= 10 MB are sent as photos (inline preview). SVGs and large files are sent as documents. """ ext = path.suffix.lower() diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py index adf553f4..c92501ca 100644 --- a/src/claude/sdk_integration.py +++ b/src/claude/sdk_integration.py @@ -2,6 +2,7 @@ import asyncio import os +import re from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable, Dict, List, Optional @@ -53,81 +54,26 @@ class ClaudeResponse: is_error: bool = False error_type: Optional[str] = None tools_used: List[Dict[str, Any]] = field(default_factory=list) + interrupted: bool = False @dataclass class StreamUpdate: """Streaming update from Claude SDK.""" - type: str # 'assistant', 'user', 'system', 'result', 'stream_delta' + type: str # 'assistant', 'user', 'system', 'result', 'stream_delta', 'thinking' content: Optional[str] = None tool_calls: Optional[List[Dict]] = None metadata: Optional[Dict] = None -def _make_can_use_tool_callback( - security_validator: SecurityValidator, - working_directory: Path, - approved_directory: Path, -) -> Any: - """Create a can_use_tool callback for SDK-level tool permission validation. +class ClaudeSDKManager: + """Manage Claude Code SDK integration. - The callback validates file path boundaries and bash directory boundaries - *before* the SDK executes the tool, providing preventive security enforcement. + Keeps a persistent ClaudeSDKClient alive per session so that follow-up + messages can be injected via interrupt() + query() without creating a + new subprocess. """ - _FILE_TOOLS = {"Write", "Edit", "Read", "create_file", "edit_file", "read_file"} - _BASH_TOOLS = {"Bash", "bash", "shell"} - - async def can_use_tool( - tool_name: str, - tool_input: Dict[str, Any], - context: ToolPermissionContext, - ) -> Any: - # File path validation - if tool_name in _FILE_TOOLS: - file_path = tool_input.get("file_path") or tool_input.get("path") - if file_path: - # Allow Claude Code internal paths (~/.claude/plans/, etc.) - if _is_claude_internal_path(file_path): - return PermissionResultAllow() - - valid, _resolved, error = security_validator.validate_path( - file_path, working_directory - ) - if not valid: - logger.warning( - "can_use_tool denied file operation", - tool_name=tool_name, - file_path=file_path, - error=error, - ) - return PermissionResultDeny(message=error or "Invalid file path") - - # Bash directory boundary validation - if tool_name in _BASH_TOOLS: - command = tool_input.get("command", "") - if command: - valid, error = check_bash_directory_boundary( - command, working_directory, approved_directory - ) - if not valid: - logger.warning( - "can_use_tool denied bash command", - tool_name=tool_name, - command=command, - error=error, - ) - return PermissionResultDeny( - message=error or "Bash directory boundary violation" - ) - - return PermissionResultAllow() - - return can_use_tool - - -class ClaudeSDKManager: - """Manage Claude Code SDK integration.""" def __init__( self, @@ -138,14 +84,53 @@ def __init__( self.config = config self.security_validator = security_validator - # Set up environment for Claude Code SDK if API key is provided - # If no API key is provided, the SDK will use existing CLI authentication if config.anthropic_api_key_str: os.environ["ANTHROPIC_API_KEY"] = config.anthropic_api_key_str logger.info("Using provided API key for Claude SDK authentication") else: logger.info("No API key provided, using existing Claude CLI authentication") + self._active_client: Optional[ClaudeSDKClient] = None + self._is_processing = False + + @property + def is_processing(self) -> bool: + """Whether a command is currently being processed.""" + return self._is_processing + + async def interrupt(self) -> bool: + """Send interrupt signal to the running Claude command (like Ctrl+C). + + Returns True if there was an active client to interrupt. + """ + client = self._active_client + if client is None: + return False + logger.info("Sending interrupt to active Claude client") + try: + await client.interrupt() + return True + except Exception as e: + logger.debug("Interrupt signal failed", error=str(e)) + return False + + async def abort(self) -> bool: + """Forcefully abort the running command (interrupt + kill subprocess).""" + client = self._active_client + if client is None: + return False + logger.info("Aborting active Claude client") + try: + await client.interrupt() + except Exception as e: + logger.debug("Interrupt signal failed (may already be done)", error=str(e)) + try: + await client.disconnect() + except Exception as e: + logger.warning("Error disconnecting client during abort", error=str(e)) + self._active_client = None + return True + async def execute_command( self, prompt: str, @@ -156,6 +141,7 @@ async def execute_command( ) -> ClaudeResponse: """Execute Claude Code command via SDK.""" start_time = asyncio.get_event_loop().time() + self._is_processing = True logger.info( "Starting Claude SDK command", @@ -165,56 +151,23 @@ async def execute_command( ) try: - # Capture stderr from Claude CLI for better error diagnostics stderr_lines: List[str] = [] def _stderr_callback(line: str) -> None: stderr_lines.append(line) logger.debug("Claude CLI stderr", line=line) - # Build system prompt, loading CLAUDE.md from working directory if present - base_prompt = ( - f"All file operations must stay within {working_directory}. " - "Use relative paths." - ) - claude_md_path = Path(working_directory) / "CLAUDE.md" - if claude_md_path.exists(): - base_prompt += "\n\n" + claude_md_path.read_text(encoding="utf-8") - logger.info( - "Loaded CLAUDE.md into system prompt", - path=str(claude_md_path), - ) - - # When DISABLE_TOOL_VALIDATION=true, pass None for allowed/disallowed - # tools so the SDK does not restrict tool usage (e.g. MCP tools). - if self.config.disable_tool_validation: - sdk_allowed_tools = None - sdk_disallowed_tools = None - else: - sdk_allowed_tools = self.config.claude_allowed_tools - sdk_disallowed_tools = self.config.claude_disallowed_tools - - # Build Claude Agent options options = ClaudeAgentOptions( max_turns=self.config.claude_max_turns, model=self.config.claude_model or None, - max_budget_usd=self.config.claude_max_cost_per_request, cwd=str(working_directory), - allowed_tools=sdk_allowed_tools, - disallowed_tools=sdk_disallowed_tools, cli_path=self.config.claude_cli_path or None, include_partial_messages=stream_callback is not None, - sandbox={ - "enabled": self.config.sandbox_enabled, - "autoAllowBashIfSandboxed": True, - "excludedCommands": self.config.sandbox_excluded_commands or [], - }, - system_prompt=base_prompt, + permission_mode="bypassPermissions", setting_sources=["project"], stderr=_stderr_callback, ) - # Pass MCP server configuration if enabled if self.config.enable_mcp and self.config.mcp_config_path: options.mcp_servers = self._load_mcp_config(self.config.mcp_config_path) logger.info( @@ -222,49 +175,27 @@ def _stderr_callback(line: str) -> None: mcp_config_path=str(self.config.mcp_config_path), ) - # Wire can_use_tool callback for preventive tool validation - if self.security_validator: - options.can_use_tool = _make_can_use_tool_callback( - security_validator=self.security_validator, - working_directory=working_directory, - approved_directory=self.config.approved_directory, - ) - - # Resume previous session if we have a session_id if session_id and continue_session: options.resume = session_id - logger.info( - "Resuming previous session", - session_id=session_id, - ) + logger.info("Resuming previous session", session_id=session_id) - # Collect messages via ClaudeSDKClient messages: List[Message] = [] + interrupted = False async def _run_client() -> None: - # Use connect(None) + query(prompt) pattern because - # can_use_tool requires the prompt as AsyncIterable, not - # a plain string. connect(None) uses an empty async - # iterable internally, satisfying the requirement. + nonlocal interrupted client = ClaudeSDKClient(options) + self._active_client = client try: await client.connect() await client.query(prompt) - # Iterate over raw messages and parse them ourselves - # so that MessageParseError (e.g. from rate_limit_event) - # doesn't kill the underlying async generator. When - # parse_message raises inside the SDK's receive_messages() - # generator, Python terminates that generator permanently, - # causing us to lose all subsequent messages including - # the ResultMessage. async for raw_data in client._query.receive_messages(): try: message = parse_message(raw_data) except MessageParseError as e: logger.debug( - "Skipping unparseable message", - error=str(e), + "Skipping unparseable message", error=str(e) ) continue @@ -273,7 +204,6 @@ async def _run_client() -> None: if isinstance(message, ResultMessage): break - # Handle streaming callback if stream_callback: try: await self._handle_stream_message( @@ -285,16 +215,22 @@ async def _run_client() -> None: error=str(callback_error), error_type=type(callback_error).__name__, ) + except asyncio.CancelledError: + interrupted = True + logger.info("Claude command was interrupted/cancelled") finally: - await client.disconnect() + self._active_client = None + try: + await client.disconnect() + except Exception: + pass - # Execute with timeout await asyncio.wait_for( _run_client(), timeout=self.config.claude_timeout_seconds, ) - # Extract cost, tools, and session_id from result message + # Extract results from messages cost = 0.0 tools_used: List[Dict[str, Any]] = [] claude_session_id = None @@ -322,8 +258,6 @@ async def _run_client() -> None: ) break - # Fallback: extract session_id from StreamEvent messages if - # ResultMessage didn't provide one (can happen with some CLI versions) if not claude_session_id: for message in messages: msg_session_id = getattr(message, "session_id", None) @@ -335,10 +269,7 @@ async def _run_client() -> None: ) break - # Calculate duration duration_ms = int((asyncio.get_event_loop().time() - start_time) * 1000) - - # Use Claude's session_id if available, otherwise fall back final_session_id = claude_session_id or session_id or "" if claude_session_id and claude_session_id != session_id: @@ -348,12 +279,18 @@ async def _run_client() -> None: previous_session_id=session_id, ) - # Use ResultMessage.result if available, fall back to message extraction if result_content is not None: - content = result_content + content = re.sub( + r'\[ThinkingBlock\(thinking=\'.*?\',\s*signature=\'.*?\'\)\]\s*', + '', result_content, flags=re.DOTALL + ) + content = content.strip() else: + # Use only the LAST AssistantMessage's text content. + # Earlier assistant messages contain intermediate reasoning + # that was already shown to the user via 💬 messages. content_parts = [] - for msg in messages: + for msg in reversed(messages): if isinstance(msg, AssistantMessage): msg_content = getattr(msg, "content", []) if msg_content and isinstance(msg_content, list): @@ -362,6 +299,8 @@ async def _run_client() -> None: content_parts.append(block.text) elif msg_content: content_parts.append(str(msg_content)) + if content_parts: + break # Stop at the last assistant message with text content = "\n".join(content_parts) return ClaudeResponse( @@ -377,6 +316,7 @@ async def _run_client() -> None: ] ), tools_used=tools_used, + interrupted=interrupted, ) except asyncio.TimeoutError: @@ -402,7 +342,6 @@ async def _run_client() -> None: except ProcessError as e: error_str = str(e) - # Include captured stderr for better diagnostics captured_stderr = "\n".join(stderr_lines[-20:]) if stderr_lines else "" if captured_stderr: error_str = f"{error_str}\nStderr: {captured_stderr}" @@ -412,7 +351,6 @@ async def _run_client() -> None: exit_code=getattr(e, "exit_code", None), stderr=captured_stderr or None, ) - # Check if the process error is MCP-related if "mcp" in error_str.lower(): raise ClaudeMCPError(f"MCP server error: {error_str}") raise ClaudeProcessError(f"Claude process error: {error_str}") @@ -420,7 +358,6 @@ async def _run_client() -> None: except CLIConnectionError as e: error_str = str(e) logger.error("Claude connection error", error=error_str) - # Check if the connection error is MCP-related if "mcp" in error_str.lower() or "server" in error_str.lower(): raise ClaudeMCPError(f"MCP server connection failed: {error_str}") raise ClaudeProcessError(f"Failed to connect to Claude: {error_str}") @@ -436,7 +373,6 @@ async def _run_client() -> None: except Exception as e: exceptions = getattr(e, "exceptions", None) if exceptions is not None: - # ExceptionGroup from TaskGroup operations (Python 3.11+) logger.error( "Task group error in Claude SDK", error=str(e), @@ -455,15 +391,18 @@ async def _run_client() -> None: ) raise ClaudeProcessError(f"Unexpected error: {str(e)}") + finally: + self._is_processing = False + async def _handle_stream_message( self, message: Message, stream_callback: Callable[[StreamUpdate], None] ) -> None: """Handle streaming message from claude-agent-sdk.""" try: if isinstance(message, AssistantMessage): - # Extract content from assistant message content = getattr(message, "content", []) text_parts = [] + thinking_parts = [] tool_calls = [] if content and isinstance(content, list): @@ -478,6 +417,17 @@ async def _handle_stream_message( ) elif hasattr(block, "text"): text_parts.append(block.text) + elif hasattr(block, "thinking"): + thinking = getattr(block, "thinking", "") + if thinking: + thinking_parts.append(thinking) + + if thinking_parts: + thinking_update = StreamUpdate( + type="thinking", + content="\n".join(thinking_parts), + ) + await stream_callback(thinking_update) if text_parts or tool_calls: update = StreamUpdate( @@ -486,8 +436,7 @@ async def _handle_stream_message( tool_calls=tool_calls if tool_calls else None, ) await stream_callback(update) - elif content: - # Fallback for non-list content + elif content and not thinking_parts: update = StreamUpdate( type="assistant", content=str(content), @@ -509,7 +458,17 @@ async def _handle_stream_message( elif isinstance(message, UserMessage): content = getattr(message, "content", "") - if content: + raw_content = getattr(message, "content", None) + if isinstance(raw_content, list): + for block in raw_content: + if hasattr(block, "content") and hasattr(block, "tool_use_id"): + result_text = str(getattr(block, "content", "")) + update = StreamUpdate( + type="tool_result", + content=result_text, + ) + await stream_callback(update) + elif content: update = StreamUpdate( type="user", content=content, @@ -520,10 +479,7 @@ async def _handle_stream_message( logger.warning("Stream callback failed", error=str(e)) def _load_mcp_config(self, config_path: Path) -> Dict[str, Any]: - """Load MCP server configuration from a JSON file. - - The new claude-agent-sdk expects mcp_servers as a dict, not a file path. - """ + """Load MCP server configuration from a JSON file.""" import json try: diff --git a/src/config/loader.py b/src/config/loader.py index 7398c389..d8670053 100644 --- a/src/config/loader.py +++ b/src/config/loader.py @@ -30,8 +30,8 @@ def load_config( Raises: ConfigurationError: If configuration is invalid """ - # Load .env file explicitly - env_file = config_file or Path(".env") + # Load .env file explicitly (respect ENV_FILE env var for multi-instance setups) + env_file = config_file or Path(os.environ.get("ENV_FILE", ".env")) if env_file.exists(): logger.info("Loading .env file", path=str(env_file)) load_dotenv(env_file) diff --git a/src/config/settings.py b/src/config/settings.py index 77c34ea4..7034c187 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -9,6 +9,7 @@ """ import json +import os from pathlib import Path from typing import Any, List, Literal, Optional @@ -169,9 +170,9 @@ class Settings(BaseSettings): enable_voice_messages: bool = Field( True, description="Enable voice message transcription" ) - voice_provider: Literal["mistral", "openai"] = Field( - "mistral", - description="Voice transcription provider: 'mistral' or 'openai'", + voice_provider: Literal["gemini", "mistral", "openai"] = Field( + "gemini", + description="Voice transcription provider: 'gemini', 'mistral', or 'openai'", ) mistral_api_key: Optional[SecretStr] = Field( None, description="Mistral API key for voice transcription" @@ -210,22 +211,23 @@ class Settings(BaseSettings): ), ) - # Output verbosity (0=quiet, 1=normal, 2=detailed) + # Output verbosity (0=quiet, 1=normal, 2=detailed, 3=full) verbose_level: int = Field( 1, description=( "Bot output verbosity: 0=quiet (final response only), " "1=normal (tool names + reasoning), " - "2=detailed (tool inputs + longer reasoning)" + "2=detailed (tool inputs + longer reasoning), " + "3=full (tool results + complete commands)" ), ge=0, - le=2, + le=3, ) - # Streaming drafts (Telegram sendMessageDraft) + # Streaming drafts (Telegram sendMessageDraft / editMessageText) enable_stream_drafts: bool = Field( False, - description="Stream partial responses via sendMessageDraft (private chats only)", + description="Stream partial responses to Telegram in real-time", ) stream_draft_interval: float = Field( 0.3, @@ -284,7 +286,10 @@ class Settings(BaseSettings): ) model_config = SettingsConfigDict( - env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore" + env_file=os.environ.get("ENV_FILE", ".env"), + env_file_encoding="utf-8", + case_sensitive=False, + extra="ignore", ) @field_validator("allowed_users", "notification_chat_ids", mode="before") @@ -393,10 +398,10 @@ def validate_project_threads_mode(cls, v: Any) -> str: def validate_voice_provider(cls, v: Any) -> str: """Validate and normalize voice transcription provider.""" if v is None: - return "mistral" + return "gemini" provider = str(v).strip().lower() - if provider not in {"mistral", "openai"}: - raise ValueError("voice_provider must be one of ['mistral', 'openai']") + if provider not in {"gemini", "mistral", "openai"}: + raise ValueError("voice_provider must be one of ['gemini', 'mistral', 'openai']") return provider @field_validator("project_threads_chat_id", mode="before") @@ -503,6 +508,8 @@ def resolved_voice_model(self) -> str: return self.voice_transcription_model if self.voice_provider == "openai": return "whisper-1" + if self.voice_provider == "gemini": + return "gemini-3.1-flash-lite-preview" return "voxtral-mini-latest" @property @@ -515,6 +522,8 @@ def voice_provider_api_key_env(self) -> str: """API key environment variable required for the configured voice provider.""" if self.voice_provider == "openai": return "OPENAI_API_KEY" + if self.voice_provider == "gemini": + return "" # No API key needed return "MISTRAL_API_KEY" @property @@ -522,4 +531,6 @@ def voice_provider_display_name(self) -> str: """Human-friendly label for the configured voice provider.""" if self.voice_provider == "openai": return "OpenAI Whisper" + if self.voice_provider == "gemini": + return "Gemini Flash Lite" return "Mistral Voxtral" diff --git a/src/mcp/telegram_server.py b/src/mcp/telegram_server.py index cc320386..4851ba72 100644 --- a/src/mcp/telegram_server.py +++ b/src/mcp/telegram_server.py @@ -1,46 +1,39 @@ """MCP server exposing Telegram-specific tools to Claude. -Runs as a stdio transport server. The ``send_image_to_user`` tool validates -file existence and extension, then returns a success string. Actual Telegram -delivery is handled by the bot's stream callback which intercepts the tool -call. +Runs as a stdio transport server. The ``send_file_to_user`` tool validates +file existence, then returns a success string. Actual Telegram delivery is +handled by the bot's stream callback which intercepts the tool call. """ from pathlib import Path from mcp.server.fastmcp import FastMCP -IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg"} - mcp = FastMCP("telegram") @mcp.tool() -async def send_image_to_user(file_path: str, caption: str = "") -> str: - """Send an image file to the Telegram user. +async def send_file_to_user(file_path: str, caption: str = "") -> str: + """Send a file to the Telegram user. + + Supports any file type: images, PDFs, audio, documents, etc. Args: - file_path: Absolute path to the image file. - caption: Optional caption to display with the image. + file_path: Absolute path to the file. + caption: Optional caption to display with the file. Returns: - Confirmation string when the image is queued for delivery. + Confirmation string when the file is queued for delivery. """ path = Path(file_path) if not path.is_absolute(): return f"Error: path must be absolute, got '{file_path}'" - if path.suffix.lower() not in IMAGE_EXTENSIONS: - return ( - f"Error: unsupported image extension '{path.suffix}'. " - f"Supported: {', '.join(sorted(IMAGE_EXTENSIONS))}" - ) - if not path.is_file(): return f"Error: file not found: {file_path}" - return f"Image queued for delivery: {path.name}" + return f"File queued for delivery: {path.name}" if __name__ == "__main__": diff --git a/src/security/validators.py b/src/security/validators.py index 381ba321..5326b3e1 100644 --- a/src/security/validators.py +++ b/src/security/validators.py @@ -240,6 +240,10 @@ def validate_filename(self, filename: str) -> Tuple[bool, Optional[str]]: ) return False, "Invalid filename: contains forbidden pattern" + # Skip remaining checks if security patterns disabled + if self.disable_security_patterns: + return True, None + # Check for forbidden filenames if filename.lower() in {name.lower() for name in self.FORBIDDEN_FILENAMES}: logger.warning("Forbidden filename", filename=filename) @@ -253,15 +257,16 @@ def validate_filename(self, filename: str) -> Tuple[bool, Optional[str]]: ) return False, f"File type not allowed: {filename}" - # Check extension - path_obj = Path(filename) - ext = path_obj.suffix.lower() + # Check extension (skip if security patterns disabled) + if not self.disable_security_patterns: + path_obj = Path(filename) + ext = path_obj.suffix.lower() - if ext and ext not in self.ALLOWED_EXTENSIONS: - logger.warning( - "File extension not allowed", filename=filename, extension=ext - ) - return False, f"File type not allowed: {ext}" + if ext and ext not in self.ALLOWED_EXTENSIONS: + logger.warning( + "File extension not allowed", filename=filename, extension=ext + ) + return False, f"File type not allowed: {ext}" # Check for hidden files (starting with .) if filename.startswith(".") and filename not in {".gitignore", ".gitkeep"}: diff --git a/tests/unit/test_bot/test_image_extractor.py b/tests/unit/test_bot/test_image_extractor.py index 19fb690d..235b4944 100644 --- a/tests/unit/test_bot/test_image_extractor.py +++ b/tests/unit/test_bot/test_image_extractor.py @@ -98,21 +98,17 @@ def test_nonexistent_file_rejected(self, work_dir: Path, approved_dir: Path): result = validate_image_path(str(work_dir / "missing.png"), approved_dir) assert result is None - def test_non_image_extension_rejected(self, work_dir: Path, approved_dir: Path): + def test_non_image_extension_accepted(self, work_dir: Path, approved_dir: Path): + """validate_file_path (aliased as validate_image_path) accepts any file type.""" txt = work_dir / "notes.txt" txt.write_text("hello") result = validate_image_path(str(txt), approved_dir) - assert result is None + assert result is not None + assert result.mime_type == "text/plain" def test_outside_approved_dir_rejected(self, tmp_path: Path): - outside = tmp_path / "outside" - outside.mkdir() - img = outside / "evil.png" - img.write_bytes(b"\x00" * 100) - # approved is a subdirectory, image is outside it - approved = tmp_path / "approved" - approved.mkdir() - result = validate_image_path(str(img), approved) + # Use a path that is neither inside approved_dir nor /tmp + result = validate_image_path("/var/evil/evil.png", tmp_path / "approved") assert result is None def test_caption_stored_as_original_reference( @@ -139,14 +135,11 @@ def test_large_file_rejected(self, work_dir: Path, approved_dir: Path): assert result is None def test_symlink_escaping_rejected(self, tmp_path: Path): + # Symlink pointing to a path outside both approved_dir and /tmp approved = tmp_path / "approved" approved.mkdir() - outside = tmp_path / "secret" - outside.mkdir() - secret_img = outside / "secret.png" - secret_img.write_bytes(b"\x00" * 100) link = approved / "link.png" - link.symlink_to(secret_img) + link.symlink_to("/var/secret/secret.png") result = validate_image_path(str(link), approved) assert result is None diff --git a/tests/unit/test_mcp/test_telegram_server.py b/tests/unit/test_mcp/test_telegram_server.py index c40f8fed..9e723ec1 100644 --- a/tests/unit/test_mcp/test_telegram_server.py +++ b/tests/unit/test_mcp/test_telegram_server.py @@ -4,7 +4,7 @@ import pytest -from src.mcp.telegram_server import send_image_to_user +from src.mcp.telegram_server import send_file_to_user @pytest.fixture @@ -15,43 +15,37 @@ def image_file(tmp_path: Path) -> Path: return img -class TestSendImageToUser: +class TestSendFileToUser: async def test_valid_image(self, image_file: Path) -> None: - result = await send_image_to_user(str(image_file)) - assert "Image queued for delivery" in result + result = await send_file_to_user(str(image_file)) + assert "File queued for delivery" in result assert "chart.png" in result async def test_valid_image_with_caption(self, image_file: Path) -> None: - result = await send_image_to_user(str(image_file), caption="My chart") - assert "Image queued for delivery" in result + result = await send_file_to_user(str(image_file), caption="My chart") + assert "File queued for delivery" in result async def test_relative_path_rejected(self, image_file: Path) -> None: - result = await send_image_to_user("relative/path/chart.png") + result = await send_file_to_user("relative/path/chart.png") assert "Error" in result assert "absolute" in result async def test_missing_file_rejected(self, tmp_path: Path) -> None: missing = tmp_path / "nonexistent.png" - result = await send_image_to_user(str(missing)) + result = await send_file_to_user(str(missing)) assert "Error" in result assert "not found" in result - async def test_non_image_extension_rejected(self, tmp_path: Path) -> None: - txt_file = tmp_path / "notes.txt" - txt_file.write_text("hello") - result = await send_image_to_user(str(txt_file)) - assert "Error" in result - assert "unsupported" in result - - async def test_all_supported_extensions(self, tmp_path: Path) -> None: - for ext in [".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg"]: - img = tmp_path / f"test{ext}" - img.write_bytes(b"\x00" * 10) - result = await send_image_to_user(str(img)) - assert "Image queued for delivery" in result, f"Failed for {ext}" + async def test_any_extension_accepted(self, tmp_path: Path) -> None: + """send_file_to_user accepts any file type, not just images.""" + for ext in [".png", ".jpg", ".pdf", ".docx", ".mp3", ".zip", ".txt"]: + f = tmp_path / f"test{ext}" + f.write_bytes(b"\x00" * 10) + result = await send_file_to_user(str(f)) + assert "File queued for delivery" in result, f"Failed for {ext}" async def test_case_insensitive_extension(self, tmp_path: Path) -> None: img = tmp_path / "photo.JPG" img.write_bytes(b"\x00" * 10) - result = await send_image_to_user(str(img)) - assert "Image queued for delivery" in result + result = await send_file_to_user(str(img)) + assert "File queued for delivery" in result