diff --git a/src/bot/handlers/message.py b/src/bot/handlers/message.py index e5fa9f78..1641a19a 100644 --- a/src/bot/handlers/message.py +++ b/src/bot/handlers/message.py @@ -590,22 +590,16 @@ async def stream_handler(update_obj): if conversation_enhancer and claude_response: try: # Update conversation context - conversation_context = conversation_enhancer.update_context( - session_id=claude_response.session_id, - user_id=user_id, - working_directory=str(current_dir), - tools_used=claude_response.tools_used or [], - response_content=claude_response.content, + conversation_enhancer.update_context(user_id, claude_response) + conversation_context = conversation_enhancer.get_or_create_context( + user_id ) # Check if we should show follow-up suggestions - if conversation_enhancer.should_show_suggestions( - claude_response.tools_used or [], claude_response.content - ): + if conversation_enhancer.should_show_suggestions(claude_response): # Generate follow-up suggestions suggestions = conversation_enhancer.generate_follow_up_suggestions( - claude_response.content, - claude_response.tools_used or [], + claude_response, conversation_context, ) diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py index adf553f4..a92ce40b 100644 --- a/src/claude/sdk_integration.py +++ b/src/claude/sdk_integration.py @@ -40,6 +40,9 @@ logger = structlog.get_logger() +# Fallback message when Claude produces no text but did use tools. +TASK_COMPLETED_MSG = "✅ Task completed. Tools used: {tools_summary}" + @dataclass class ClaudeResponse: @@ -61,8 +64,114 @@ class StreamUpdate: type: str # 'assistant', 'user', 'system', 'result', 'stream_delta' content: Optional[str] = None - tool_calls: Optional[List[Dict]] = None - metadata: Optional[Dict] = None + tool_calls: Optional[List[Dict[str, Any]]] = None + metadata: Optional[Dict[str, Any]] = None + progress: Optional[Dict[str, Any]] = None + + def get_tool_names(self) -> List[str]: + """Return tool names from the stream payload.""" + names: List[str] = [] + + if self.tool_calls: + for tool_call in self.tool_calls: + name = tool_call.get("name") if isinstance(tool_call, dict) else None + if isinstance(name, str) and name: + names.append(name) + + if self.metadata: + tool_name = self.metadata.get("tool_name") + if isinstance(tool_name, str) and tool_name: + names.append(tool_name) + + metadata_tools = self.metadata.get("tools") + if isinstance(metadata_tools, list): + for tool in metadata_tools: + if isinstance(tool, dict): + name = tool.get("name") + elif isinstance(tool, str): + name = tool + else: + name = None + + if isinstance(name, str) and name: + names.append(name) + + # Preserve insertion order while de-duplicating. + return list(dict.fromkeys(names)) + + def is_error(self) -> bool: + """Check whether this stream update represents an error.""" + if self.type == "error": + return True + + if self.metadata: + if self.metadata.get("is_error") is True: + return True + status = self.metadata.get("status") + if isinstance(status, str) and status.lower() == "error": + return True + error_val = self.metadata.get("error") + if isinstance(error_val, str) and error_val: + return True + error_msg_val = self.metadata.get("error_message") + if isinstance(error_msg_val, str) and error_msg_val: + return True + + if self.progress: + status = self.progress.get("status") + if isinstance(status, str) and status.lower() == "error": + return True + + return False + + def get_error_message(self) -> str: + """Get the best available error message from the stream payload.""" + if self.metadata: + for key in ("error_message", "error", "message"): + value = self.metadata.get(key) + if isinstance(value, str) and value.strip(): + return value + + if isinstance(self.content, str) and self.content.strip(): + return self.content + + if self.progress: + value = self.progress.get("error") + if isinstance(value, str) and value.strip(): + return value + + return "Unknown error" + + def get_progress_percentage(self) -> Optional[int]: + """Extract progress percentage if present.""" + + def _to_int(value: Any) -> Optional[int]: + if isinstance(value, (int, float)): + return int(value) + if isinstance(value, str) and value.strip(): + try: + return int(float(value)) + except ValueError: + return None + return None + + if self.progress: + for key in ("percentage", "percent", "progress"): + percentage = _to_int(self.progress.get(key)) + if percentage is not None: + return max(0, min(100, percentage)) + + step = _to_int(self.progress.get("step")) + total_steps = _to_int(self.progress.get("total_steps")) + if step is not None and total_steps and total_steps > 0: + return max(0, min(100, int((step / total_steps) * 100))) + + if self.metadata: + percentage = _to_int(self.metadata.get("progress_percentage")) + if percentage is not None: + return max(0, min(100, percentage)) + + return None def _make_can_use_tool_callback( @@ -350,7 +459,7 @@ async def _run_client() -> None: # Use ResultMessage.result if available, fall back to message extraction if result_content is not None: - content = result_content + content = str(result_content).strip() else: content_parts = [] for msg in messages: @@ -362,7 +471,17 @@ async def _run_client() -> None: content_parts.append(block.text) elif msg_content: content_parts.append(str(msg_content)) - content = "\n".join(content_parts) + content = "\n".join(content_parts).strip() + + if not content and tools_used: + tool_names = [ + tool.get("name", "") + for tool in tools_used + if isinstance(tool.get("name"), str) and tool.get("name") + ] + unique_tool_names = list(dict.fromkeys(tool_names)) + tools_summary = ", ".join(unique_tool_names) or "unknown" + content = TASK_COMPLETED_MSG.format(tools_summary=tools_summary) return ClaudeResponse( content=content,