diff --git a/examples/hitl-demo/README.md b/examples/hitl-demo/README.md index c297703..22125db 100644 --- a/examples/hitl-demo/README.md +++ b/examples/hitl-demo/README.md @@ -27,7 +27,7 @@ cd /path/to/jaf-py 2. **Python Dependencies**: Install required dependencies: ```bash pip install -e . -pip install fastapi uvicorn # For API demo +pip install uvicorn # For JAF server demo ``` 3. **Model Provider Configuration**: @@ -59,84 +59,123 @@ This runs the interactive file system demo where you can: - See approval context flow to tool execution - Experience persistent approval storage across sessions -#### 🌐 API Demo with HTTP Endpoints +#### 🌐 JAF Server Demo with HTTP Endpoints ```bash python examples/hitl-demo/api_demo.py ``` -This runs both terminal interaction AND HTTP endpoints for approvals. +This runs the JAF server with standard HTTP endpoints for chat and approval management. -## 🌐 API Demo Usage +## 🌐 JAF Server Demo Usage -When running `api_demo.py`, you get both terminal interaction AND HTTP endpoints: +When running `api_demo.py`, you get a full JAF server with standard endpoints: -### API Endpoints +### JAF Server Endpoints | Method | Endpoint | Description | |--------|----------|-------------| -| `GET` | `/pending` | List all pending tool approvals | -| `POST` | `/approve/{sessionId}/{toolCallId}` | Approve a specific tool call | -| `POST` | `/reject/{sessionId}/{toolCallId}` | Reject a specific tool call | -| `GET` | `/health` | Health check and pending count | -| `GET` | `/approvals/stream?conversationId=...` | SSE stream for real-time updates | +| `GET` | `/health` | Health check | +| `GET` | `/agents` | List available agents | +| `POST` | `/chat` | Send chat messages to agents | +| `GET` | `/approvals/pending?conversationId=...` | List pending tool approvals | +| `GET` | `/approvals/stream?conversationId=...` | SSE stream for real-time approval updates | +| `POST` | `/approvals/approve` | Approve a tool call with optional additional context | +| `POST` | `/approvals/reject` | Reject a tool call with optional reason and context | ### Example Workflow -1. **Start the API demo:** +1. **Start the JAF server:** ```bash python examples/hitl-demo/api_demo.py ``` -2. **Check pending approvals via curl:** +2. **Check available agents:** ```bash - curl http://localhost:3001/pending + curl http://localhost:3001/agents ``` -3. **Approve via curl (simple):** +3. **Send a chat message to an agent:** ```bash - curl -X POST http://localhost:3001/approve/SESSION_ID/TOOL_CALL_ID - ``` - -4. **Approve with additional context:** - ```bash - curl -X POST http://localhost:3001/approve/SESSION_ID/TOOL_CALL_ID \ + curl -X POST http://localhost:3001/chat \ -H "Content-Type: application/json" \ -d '{ - "additionalContext": { - "message": "your-additional-context" - } + "agent_name": "FileSystemAgent", + "messages": [ + { + "role": "user", + "content": "list files in the current directory" + } + ] }' ``` -5. **Approve with image context (base64 data):** +4. **Send a message that requires approval:** ```bash - curl -X POST http://localhost:3001/approve/SESSION_ID/TOOL_CALL_ID \ + curl -X POST http://localhost:3001/chat \ -H "Content-Type: application/json" \ -d '{ - "additionalContext": { - "messages": [ - { - "role": "user", - "content": "Analyze this image and make your decision based on it", - "attachments": [ - { - "kind": "image", - "mime_type": "image/png", - "name": "test-pixel.png", - "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" - } - ] - } - ] - } + "agent_name": "FileSystemAgent", + "messages": [{"role": "user", "content": "delete the config.json file"}], + "conversation_id": "YOUR_CONVERSATION_ID" }' ``` -6. **Approve with image context (URL):** +5. **Check pending approvals:** + ```bash + curl "http://localhost:3001/approvals/pending?conversation_id=YOUR_CONVERSATION_ID" + ``` + +### Additional Context Support + +The JAF server supports providing additional context when approving tool calls, including image attachments and custom metadata. Use the standard JAF approval endpoints with additional context payload: + +```bash +# Approve with additional context +curl -X POST "http://localhost:3001/approvals/approve" \ + -H "Content-Type: application/json" \ + -d '{ + "conversationId": "YOUR_CONVERSATION_ID", + "toolCallId": "TOOL_CALL_ID", + "additionalContext": { + "message": "Approved after review" + } + }' +``` + +```bash +# Approve with image context +curl -X POST "http://localhost:3001/approvals/approve" \ + -H "Content-Type: application/json" \ + -d '{ + "conversationId": "YOUR_CONVERSATION_ID", + "toolCallId": "TOOL_CALL_ID", + "additionalContext": { + "messages": [ + { + "role": "user", + "content": "Please review this image before proceeding", + "attachments": [ + { + "kind": "image", + "mime_type": "image/png", + "name": "context.png", + "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + } + ] + } + ] + } + }' +``` + ``` + +7. **Approve with image context (URL):** ```bash - curl -X POST http://localhost:3001/approve/SESSION_ID/TOOL_CALL_ID \ + curl -X POST "http://localhost:3001/approvals/approve" \ -H "Content-Type: application/json" \ -d '{ + "conversationId": "YOUR_CONVERSATION_ID", + "toolCallId": "TOOL_CALL_ID", "additionalContext": { "messages": [ { @@ -156,16 +195,23 @@ When running `api_demo.py`, you get both terminal interaction AND HTTP endpoints }' ``` -7. **Reject via curl (simple):** +8. **Reject via curl (simple):** ```bash - curl -X POST http://localhost:3001/reject/SESSION_ID/TOOL_CALL_ID + curl -X POST "http://localhost:3001/approvals/reject" \ + -H "Content-Type: application/json" \ + -d '{ + "conversationId": "YOUR_CONVERSATION_ID", + "toolCallId": "TOOL_CALL_ID" + }' ``` -8. **Reject with additional context:** +9. **Reject with additional context:** ```bash - curl -X POST http://localhost:3001/reject/SESSION_ID/TOOL_CALL_ID \ + curl -X POST "http://localhost:3001/approvals/reject" \ -H "Content-Type: application/json" \ -d '{ + "conversationId": "YOUR_CONVERSATION_ID", + "toolCallId": "TOOL_CALL_ID", "reason": "not authorized", "additionalContext": { "rejectedBy": "your-name" diff --git a/examples/hitl-demo/api_demo.py b/examples/hitl-demo/api_demo.py index de0b9dc..3f464b4 100644 --- a/examples/hitl-demo/api_demo.py +++ b/examples/hitl-demo/api_demo.py @@ -1,42 +1,32 @@ #!/usr/bin/env python3 """ -File System HITL API Demo - With HTTP endpoints for approval +File System HITL JAF Server Demo -This demo extends the file system HITL demo with HTTP API endpoints -for remote approval/rejection via curl commands: +This demo showcases the file system HITL functionality using JAF server: - All file operations from the main demo -- HTTP API server for approval management -- curl-based approval/rejection support -- Real-time coordination between terminal and API +- JAF server with standard HTTP API endpoints +- Built-in approval management via JAF endpoints +- Real-time coordination between clients and server Usage: python examples/hitl-demo/api_demo.py """ import asyncio -import json import os import sys import time from pathlib import Path -from typing import Dict, Any, List, Optional -import uuid -import concurrent.futures -import threading # Add the project root to the path sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from fastapi import FastAPI, HTTPException, Body -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel import uvicorn -from jaf.core.types import RunState, RunConfig, create_run_id, create_trace_id, Message, ContentRole -from jaf.core.engine import run -from jaf.core.state import approve, reject +from jaf.server.server import create_jaf_server +from jaf.server.types import ServerConfig +from jaf.core.types import RunConfig from jaf.providers.model import make_litellm_provider -from jaf.memory.approval_storage import create_in_memory_approval_storage from shared.agent import file_system_agent, LITELLM_BASE_URL, LITELLM_API_KEY, LITELLM_MODEL from shared.tools import FileSystemContext, DEMO_DIR @@ -46,33 +36,20 @@ # Configuration API_PORT = int(os.getenv('API_PORT', '3001')) -# Global state for pending approvals -pending_approvals: Dict[str, Dict[str, Any]] = {} - - -# Pydantic models for API requests -class ApprovalRequest(BaseModel): - additionalContext: Optional[Dict[str, Any]] = None - - -class RejectionRequest(BaseModel): - reason: Optional[str] = "Rejected via API" - additionalContext: Optional[Dict[str, Any]] = None - def create_model_provider(): """Create model provider - requires LiteLLM configuration.""" # Check if we have environment variables set (not using defaults) has_env_config = os.getenv('LITELLM_BASE_URL') or os.getenv('LITELLM_URL') has_api_key = os.getenv('LITELLM_API_KEY') - + if not has_env_config or not has_api_key: print(Colors.yellow('āŒ No LiteLLM configuration found')) print(Colors.yellow(' Please set LITELLM_BASE_URL and LITELLM_API_KEY environment variables')) print(Colors.yellow(' Example: LITELLM_BASE_URL=http://localhost:4000 LITELLM_API_KEY=your-key python examples/hitl-demo/api_demo.py')) print(Colors.dim(' Or copy examples/hitl-demo/.env.example to .env and configure your LiteLLM server')) sys.exit(1) - + print(Colors.green(f'šŸ¤– Using LiteLLM: {LITELLM_BASE_URL} ({LITELLM_MODEL})')) return make_litellm_provider(LITELLM_BASE_URL, LITELLM_API_KEY) @@ -81,29 +58,29 @@ def setup_sandbox(): """Setup demo sandbox directory.""" try: DEMO_DIR.mkdir(parents=True, exist_ok=True) - + demo_files = [ { 'name': 'README.txt', - 'content': 'Welcome to the File System HITL API Demo!\\nThis is a sample file for testing.' + 'content': 'Welcome to the File System HITL JAF Server Demo!\nThis is a sample file for testing.' }, { 'name': 'config.json', - 'content': '{\\n \"app\": \"filesystem-api-demo\",\\n \"version\": \"1.0.0\",\\n \"api\": true\\n}' + 'content': '{\n "app": "filesystem-jaf-server-demo",\n "version": "1.0.0",\n "server": "JAF"\n}' }, { 'name': 'notes.md', - 'content': '# API Demo Notes\\n\\n- This is a markdown file\\n- You can edit or delete it via terminal or API\\n- Operations require approval' + 'content': '# JAF Server Demo Notes\n\n- This is a markdown file\n- You can edit or delete it via JAF endpoints\n- Operations require approval' } ] - + for file_info in demo_files: file_path = DEMO_DIR / file_info['name'] if not file_path.exists(): file_path.write_text(file_info['content'], encoding='utf-8') - + print(Colors.green(f'šŸ“ Sandbox directory ready: {DEMO_DIR}')) - + except Exception as e: print(Colors.yellow(f'Failed to setup sandbox: {e}')) sys.exit(1) @@ -112,353 +89,64 @@ def setup_sandbox(): def display_welcome(): """Display welcome message.""" os.system('clear' if os.name == 'posix' else 'cls') - print(Colors.cyan('🌐 JAF File System HITL API Demo')) - print(Colors.cyan('====================================')) + print(Colors.cyan('🌐 JAF File System HITL Server Demo')) + print(Colors.cyan('===================================')) print() - - print(Colors.green('This demo showcases HITL with curl-based approval only:')) + + print(Colors.green('This demo showcases HITL with JAF server endpoints:')) print(Colors.green('• Safe operations: listFiles, readFile (no approval)')) print(Colors.green('• Dangerous operations: deleteFile, editFile (require approval)')) - print(Colors.green('• Approve/reject ONLY via curl commands')) - print(Colors.green('• No terminal approval - must use API endpoints')) + print(Colors.green('• Chat via JAF server endpoints')) + print(Colors.green('• Approval management via JAF server endpoints')) + print(Colors.green('• Integrated approval storage in memory provider')) print() - - print(Colors.cyan('Try these commands:')) - print('• \"list files in the current directory\"') - print('• \"read the README file\"') - print('• \"edit the config file to add api: true\"') - print('• \"delete the notes file\"') + + print(Colors.cyan('Example requests:')) + print('• "list files in the current directory"') + print('• "read the README file"') + print('• "edit the config file to add server: JAF"') + print('• "delete the notes file"') print() - - print(Colors.yellow('API Endpoints:')) - print(f'• GET http://localhost:{API_PORT}/pending - List pending approvals') - print(f'• POST http://localhost:{API_PORT}/approve/:sessionId/:toolCallId - Approve') - print(f'• POST http://localhost:{API_PORT}/reject/:sessionId/:toolCallId - Reject') + + print(Colors.yellow('JAF Server Endpoints:')) + print(f'• Health: GET http://localhost:{API_PORT}/health') + print(f'• Agents: GET http://localhost:{API_PORT}/agents') + print(f'• Chat: POST http://localhost:{API_PORT}/chat') + print(f'• Pending Approvals: GET http://localhost:{API_PORT}/approvals/pending?conversationId=...') + print(f'• Approvals SSE Stream: GET http://localhost:{API_PORT}/approvals/stream?conversationId=...') print() - - print(Colors.dim('Commands: type \"exit\" to quit, \"clear\" to clear screen')) + + print(Colors.dim('Use the JAF server endpoints to interact with the agent')) print() -def get_additional_context(tool_name: str) -> Dict[str, Any]: - """Get additional context based on tool.""" - if tool_name == 'deleteFile': - return { - 'deletion_confirmed': { - 'confirmed_by': 'demo-user', - 'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S.%fZ'), - 'backup_created': True - } - } - elif tool_name == 'editFile': - return { - 'editing_approved': { - 'approved_by': 'demo-user', - 'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S.%fZ'), - 'safety_level': 'standard' - } - } - return {} - - -def setup_api_server(): - """Setup HTTP API server.""" - app = FastAPI(title="JAF HITL API Demo", description="File System HITL Demo with FastAPI", version="1.0.0") - - # Add CORS middleware - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) +def setup_api_server(config: RunConfig[FileSystemContext]): + """Setup JAF HTTP API server.""" - @app.get('/health') - def health_check(): - return { - 'status': 'healthy', - 'pending_approvals': len(pending_approvals), - 'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') - } - - @app.get('/pending') - def list_pending_approvals(): - pending_list = [] - for key, data in pending_approvals.items(): - pending_list.append({ - 'key': key, - **data['metadata'] - }) - return pending_list - - @app.post('/approve/{session_id}/{tool_call_id}') - def approve_tool_call(session_id: str, tool_call_id: str, request: ApprovalRequest = Body(default=ApprovalRequest())): - approval_key = f"{session_id}-{tool_call_id}" - - pending = pending_approvals.get(approval_key) - if not pending: - raise HTTPException(status_code=404, detail='Approval request not found') - - additional_context = request.additionalContext or {} - - result = { - 'approved': True, - 'source': 'API', - 'additional_context': { - **get_additional_context(pending['metadata']['tool_name']), - **additional_context, - 'approved_via_api': True - } - } - - # Use the concurrent.futures approach - thread-safe - future = pending['future'] - if not future.done(): - future.set_result(result) - print(f"[API] Approval set for {approval_key}") - else: - print(f"[API] Future already done for {approval_key}") - - return {'message': 'Approval recorded', 'session_id': session_id, 'tool_call_id': tool_call_id} - - @app.post('/reject/{session_id}/{tool_call_id}') - def reject_tool_call(session_id: str, tool_call_id: str, request: RejectionRequest = Body(default=RejectionRequest())): - approval_key = f"{session_id}-{tool_call_id}" - - pending = pending_approvals.get(approval_key) - if not pending: - raise HTTPException(status_code=404, detail='Approval request not found') - - result = { - 'approved': False, - 'source': 'API', - 'additional_context': { - 'rejection_reason': request.reason, - 'rejected_by': 'api-user', - 'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S.%fZ'), - 'rejected_via_api': True, - **(request.additionalContext or {}) - } - } - - # Use the concurrent.futures approach - thread-safe - future = pending['future'] - if not future.done(): - future.set_result(result) - print(f"[API] Rejection set for {approval_key}") - else: - print(f"[API] Future already done for {approval_key}") - - return {'message': 'Rejection recorded', 'session_id': session_id, 'tool_call_id': tool_call_id} + # Create agent registry + agent_registry = {'FileSystemAgent': file_system_agent} - return app + # Server configuration + server_config = ServerConfig( + agent_registry=agent_registry, + run_config=config, + default_memory_provider=config.memory.provider, + cors=True + ) + # Create JAF server + app = create_jaf_server(server_config) -async def handle_approval(interruption: Any) -> Dict[str, Any]: - """Handle approval request (curl-only).""" - tool_call = interruption.tool_call - - # Parse arguments safely - try: - args = json.loads(tool_call.function.arguments) - except json.JSONDecodeError: - args = {"arguments": tool_call.function.arguments} - - approval_key = f"{interruption.session_id}-{tool_call.id}" - - print(Colors.yellow('šŸ›‘ APPROVAL REQUIRED')) - print() - print(Colors.yellow(f'Tool: {tool_call.function.name}')) - print(Colors.yellow('Arguments:')) - for key, value in args.items(): - print(Colors.yellow(f' {key}: {value}')) - print(Colors.yellow(f'Session ID: {interruption.session_id}')) - print(Colors.yellow(f'Tool Call ID: {tool_call.id}')) - print() - - print(Colors.cyan('šŸ’” Use curl to approve/reject:')) - print(f' Approve: curl -X POST http://localhost:{API_PORT}/approve/{interruption.session_id}/{tool_call.id}') - print() - print(f' Approve with context: curl -X POST http://localhost:{API_PORT}/approve/{interruption.session_id}/{tool_call.id} \\') - print(' -H "Content-Type: application/json" \\') - print(' -d \'{"additionalContext": {"message": "your-additional-context"}}\'') - print() - print(Colors.green(' šŸ“Ž Approve with image (base64):')) - print(f' curl -X POST http://localhost:{API_PORT}/approve/{interruption.session_id}/{tool_call.id} \\') - print(' -H "Content-Type: application/json" \\') - print(' -d \'{"additionalContext": {"messages": [{"role": "user", "content": "Here is visual context", "attachments": [{"kind": "image", "mime_type": "image/png", "name": "test.png", "data": "iVBORw0KGgoAAAANSUhEUgAAAAE..."}]}]}}\'') - print() - print(Colors.green(' šŸ“Ž Approve with image (URL):')) - print(f' curl -X POST http://localhost:{API_PORT}/approve/{interruption.session_id}/{tool_call.id} \\') - print(' -H "Content-Type: application/json" \\') - print(' -d \'{"additionalContext": {"messages": [{"role": "user", "content": "Image for context", "attachments": [{"kind": "image", "mime_type": "image/jpeg", "name": "photo.jpg", "url": "https://example.com/image.jpg"}]}]}}\'') - print() - print(f' Reject: curl -X POST http://localhost:{API_PORT}/reject/{interruption.session_id}/{tool_call.id}') - print() - print(f' Reject with context: curl -X POST http://localhost:{API_PORT}/reject/{interruption.session_id}/{tool_call.id} \\') - print(' -H "Content-Type: application/json" \\') - print(' -d \'{"reason": "not authorized", "additionalContext": {"rejectedBy": "your-name"}}\'') - print() - print(f' Check: curl http://localhost:{API_PORT}/pending') - print() - - # Store pending approval for API access only - use concurrent.futures for thread safety - future = concurrent.futures.Future() - pending_approvals[approval_key] = { - 'interruption': interruption, - 'future': future, - 'metadata': { - 'session_id': interruption.session_id, - 'tool_call_id': tool_call.id, - 'tool_name': tool_call.function.name, - 'arguments': args, - 'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') - } - } - - print(Colors.dim('ā³ Waiting for curl approval/rejection...')) - print() - - # Wait for API call only - use asyncio.run_in_executor for blocking concurrent.futures.Future - print(f"[DEBUG] Waiting for future {approval_key}...") - loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, future.result) - print(f"[DEBUG] Future resolved with result: {result}") - - # Clean up pending approval - pending_approvals.pop(approval_key, None) - - if result['approved']: - print(Colors.green('\\nāœ… Approved via curl! Providing additional context...\\n')) - else: - print(Colors.yellow('\\nāŒ Rejected via curl!\\n')) - - return result - - -async def process_conversation( - user_input: str, - conversation_history: List[Dict[str, str]], - config: RunConfig[FileSystemContext] -) -> tuple[List[Dict[str, str]], bool]: - """Process a single conversation turn.""" - - # Add user message to conversation - new_history = conversation_history + [{'role': 'user', 'content': user_input}] - - context = FileSystemContext( - user_id='api-demo-user', - working_directory=str(DEMO_DIR), - permissions=['read', 'write', 'delete'] - ) - - # Convert history to Message objects - messages = [ - Message(role=ContentRole(msg['role']), content=msg['content']) - for msg in new_history - ] - - state = RunState( - run_id=create_run_id('filesystem-api-demo'), - trace_id=create_trace_id('fs-api-trace'), - messages=messages, - current_agent_name='FileSystemAgent', - context=context, - turn_count=0, - approvals={} - ) - - print(Colors.dim('ā³ Processing...\\n')) - - # Process with the engine - while True: - result = await run(state, config) - - if result.outcome.status == 'interrupted': - interruption = result.outcome.interruptions[0] - - if interruption.type == 'tool_approval': - approval_result = await handle_approval(interruption) - - if approval_result['approved']: - state = await approve(state, interruption, approval_result.get('additional_context'), config) - else: - state = await reject(state, interruption, approval_result.get('additional_context'), config) - - # Continue processing with the approval decision - continue - - elif result.outcome.status == 'completed': - # Add assistant response to conversation history - final_history = new_history + [{'role': 'assistant', 'content': result.outcome.output}] - - print(Colors.cyan('Assistant: ') + str(result.outcome.output) + '\\n') - return final_history, True - - elif result.outcome.status == 'error': - print(Colors.yellow(f'āŒ Error: {result.outcome.error}\\n')) - return new_history, True - - -async def conversation_loop( - conversation_history: List[Dict[str, str]], - config: RunConfig[FileSystemContext] -): - """Main conversation loop (recursive pattern).""" - try: - user_input = input(Colors.green('You: ')).strip() - - if user_input.lower() == 'exit': - print(Colors.cyan('šŸ‘‹ Goodbye!')) - return - - if user_input.lower() == 'clear': - display_welcome() - return await conversation_loop(conversation_history, config) - - if not user_input: - return await conversation_loop(conversation_history, config) - - # Process the conversation turn - conversation_history, should_continue = await process_conversation( - user_input, conversation_history, config - ) - - if should_continue: - # Recursive call to continue the conversation - return await conversation_loop(conversation_history, config) - - except KeyboardInterrupt: - print(Colors.cyan('\\nšŸ‘‹ Goodbye!')) - return - except EOFError: - print(Colors.cyan('\\nšŸ‘‹ Goodbye!')) - return + return app async def main(): """Main demo function.""" display_welcome() setup_sandbox() - - # Setup API server - app = setup_api_server() - - # Generate session ID for this demo run - session_id = f"api-demo-{int(time.time() * 1000)}" - print(Colors.cyan(f'šŸ”— Session ID: {session_id}')) - print() - - model_provider = create_model_provider() - - # Setup memory and approval storage + + # Setup memory provider (now includes approval storage automatically) memory_provider = await setup_memory_provider() - - print(Colors.cyan('šŸ” Setting up approval storage...')) - approval_storage = create_in_memory_approval_storage() - print(Colors.green('āœ… Approval storage initialized')) - print() from jaf.memory.types import MemoryConfig memory_config = MemoryConfig( @@ -467,39 +155,39 @@ async def main(): max_messages=50, store_on_completion=True ) - + config = RunConfig( agent_registry={'FileSystemAgent': file_system_agent}, - model_provider=model_provider, + model_provider=create_model_provider(), memory=memory_config, - conversation_id=f'filesystem-api-demo-{int(time.time() * 1000)}', - approval_storage=approval_storage + conversation_id=f'filesystem-jaf-server-demo-{int(time.time() * 1000)}' ) - - # Start API server in background - server_thread = threading.Thread(target=lambda: uvicorn.run(app, host='127.0.0.1', port=API_PORT, log_level="error")) - server_thread.daemon = True - server_thread.start() - - print(Colors.green(f'🌐 API server running on http://localhost:{API_PORT}')) + + # Setup JAF API server + app = setup_api_server(config) + + # Generate session ID for this demo run + session_id = f"jaf-server-demo-{int(time.time() * 1000)}" + print(Colors.cyan(f'šŸ”— Session ID: {session_id}')) + print() + + print(Colors.green(f'🌐 JAF server running on http://localhost:{API_PORT}')) print(Colors.dim(f' Health: http://localhost:{API_PORT}/health')) - print(Colors.dim(f' Pending: http://localhost:{API_PORT}/pending')) + print(Colors.dim(f' Agents: http://localhost:{API_PORT}/agents')) + print(Colors.dim(f' Chat: http://localhost:{API_PORT}/chat')) print() - - try: - # Start the recursive conversation loop - await conversation_loop([], config) - except Exception as e: - print(Colors.yellow(f'Error: {e}')) - import traceback - traceback.print_exc() + + # Start JAF server + config_uvicorn = uvicorn.Config(app, host='127.0.0.1', port=API_PORT, log_level="info") + server = uvicorn.Server(config_uvicorn) + await server.serve() if __name__ == '__main__': try: asyncio.run(main()) except KeyboardInterrupt: - print(Colors.cyan('\\nšŸ‘‹ Goodbye!')) + print(Colors.cyan('\nšŸ‘‹ Goodbye!')) except Exception as e: print(Colors.yellow(f'Error: {e}')) import traceback diff --git a/examples/hitl-demo/demo.py b/examples/hitl-demo/demo.py deleted file mode 100644 index 81397f0..0000000 --- a/examples/hitl-demo/demo.py +++ /dev/null @@ -1,320 +0,0 @@ -#!/usr/bin/env python3 - -""" -File System HITL Demo - Recursive conversation pattern - -This demo showcases the HITL (Human-in-the-Loop) system with file operations: -- listFiles, readFile: No approval required -- deleteFile, editFile: Require approval -- Uses memory providers from environment -- Uses approval storage for persistence -- Recursive conversation pattern (no while loops) - -Usage: python examples/hitl-demo/demo.py -""" - -import asyncio -import os -import sys -import time -from pathlib import Path -from typing import Dict, Any, List - -# Add the project root to the path -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) - -from jaf.core.types import RunState, RunConfig, create_run_id, create_trace_id, Message, ContentRole -from jaf.core.engine import run -from jaf.core.state import approve, reject -from jaf.providers.model import make_litellm_provider -from jaf.memory.approval_storage import create_in_memory_approval_storage -from jaf.core.tracing import create_composite_trace_collector, ConsoleTraceCollector - -from shared.agent import file_system_agent, LITELLM_BASE_URL, LITELLM_API_KEY, LITELLM_MODEL -from shared.tools import FileSystemContext, DEMO_DIR -from shared.memory import setup_memory_provider, Colors - - -def setup_sandbox(): - """Setup demo sandbox directory.""" - try: - DEMO_DIR.mkdir(parents=True, exist_ok=True) - - # Create some demo files - demo_files = [ - { - 'name': 'README.txt', - 'content': 'Welcome to the File System HITL Demo!\nThis is a sample file for testing.' - }, - { - 'name': 'config.json', - 'content': '{\n "app": "filesystem-demo",\n "version": "1.0.0"\n}' - }, - { - 'name': 'notes.md', - 'content': '# Demo Notes\n\n- This is a markdown file\n- You can edit or delete it\n- Operations require approval' - } - ] - - for file_info in demo_files: - file_path = DEMO_DIR / file_info['name'] - if not file_path.exists(): - file_path.write_text(file_info['content'], encoding='utf-8') - - print(Colors.green(f'šŸ“ Sandbox directory ready: {DEMO_DIR}')) - - except Exception as e: - print(Colors.yellow(f'Failed to setup sandbox: {e}')) - sys.exit(1) - - -def display_welcome(): - """Display welcome message.""" - os.system('clear' if os.name == 'posix' else 'cls') - print(Colors.cyan('šŸ—‚ļø JAF File System Human-in-the-Loop Demo')) - print(Colors.cyan('=' * 48)) - print() - - print(Colors.green('This demo showcases HITL approval for file operations:')) - print(Colors.green('• Safe operations: listFiles, readFile (no approval)')) - print(Colors.green('• Dangerous operations: deleteFile, editFile (require approval)')) - print(Colors.green('• Approval state persists using memory providers')) - print(Colors.green('• Conversation history is maintained across sessions')) - print() - - print(Colors.cyan('Try these commands:')) - print('• "list files in the current directory"') - print('• "read the README file"') - print('• "edit the config file to add a new field"') - print('• "delete the notes file"') - print() - - print(Colors.dim('Commands: type "exit" to quit, "clear" to clear screen')) - print() - - -def create_model_provider(): - """Create model provider - requires LiteLLM configuration.""" - # Check if we have environment variables set (not using defaults) - has_env_config = os.getenv('LITELLM_BASE_URL') or os.getenv('LITELLM_URL') - has_api_key = os.getenv('LITELLM_API_KEY') - - if not has_env_config or not has_api_key: - print(Colors.yellow('āŒ No LiteLLM configuration found')) - print(Colors.yellow(' Please set LITELLM_BASE_URL and LITELLM_API_KEY environment variables')) - print(Colors.yellow(' Example: LITELLM_BASE_URL=http://localhost:4000 LITELLM_API_KEY=your-key python examples/hitl-demo/demo.py')) - print(Colors.dim(' Or copy examples/hitl-demo/.env.example to .env and configure your LiteLLM server')) - sys.exit(1) - - print(Colors.green(f'šŸ¤– Using LiteLLM: {LITELLM_BASE_URL} ({LITELLM_MODEL})')) - return make_litellm_provider(LITELLM_BASE_URL, LITELLM_API_KEY) - - -async def handle_approval(interruption: Any) -> Dict[str, Any]: - """Handle approval request interactively.""" - tool_call = interruption.tool_call - - # Parse arguments safely - try: - import json - args = json.loads(tool_call.function.arguments) - except json.JSONDecodeError: - args = {"arguments": tool_call.function.arguments} - - print(Colors.yellow('šŸ›‘ APPROVAL REQUIRED')) - print() - print(Colors.yellow(f'Tool: {tool_call.function.name}')) - print(Colors.yellow('Arguments:')) - for key, value in args.items(): - print(Colors.yellow(f' {key}: {value}')) - print(Colors.yellow(f'Session ID: {interruption.session_id}')) - print() - - while True: - approval = input(Colors.cyan('Do you approve this action? (y/n): ')).strip().lower() - - if approval in ['y', 'yes']: - print(Colors.green('\nāœ… Approved! Providing additional context...\n')) - - # Provide additional context based on the tool - additional_context = {} - - if tool_call.function.name == 'deleteFile': - additional_context = { - 'deletion_confirmed': { - 'confirmed_by': 'demo-user', - 'timestamp': int(time.time() * 1000), - 'backup_created': True - } - } - elif tool_call.function.name == 'editFile': - additional_context = { - 'editing_approved': { - 'approved_by': 'demo-user', - 'timestamp': int(time.time() * 1000), - 'safety_level': 'standard' - } - } - - return {'approved': True, 'additional_context': additional_context} - - elif approval in ['n', 'no']: - print(Colors.yellow('\nāŒ Rejected!\n')) - return { - 'approved': False, - 'additional_context': { - 'rejection_reason': 'User declined the action', - 'rejected_by': 'demo-user', - 'timestamp': int(time.time() * 1000) - } - } - else: - print(Colors.yellow('Please enter "y" for yes or "n" for no.')) - - -async def process_conversation( - user_input: str, - conversation_history: List[Dict[str, str]], - config: RunConfig[FileSystemContext] -) -> tuple[List[Dict[str, str]], bool]: - """Process a single conversation turn.""" - - # Add user message to conversation - new_history = conversation_history + [{'role': 'user', 'content': user_input}] - - context = FileSystemContext( - user_id='demo-user', - working_directory=str(DEMO_DIR), - permissions=['read', 'write', 'delete'] - ) - - # Convert history to Message objects - messages = [ - Message(role=ContentRole(msg['role']), content=msg['content']) - for msg in new_history - ] - - state = RunState( - run_id=create_run_id('filesystem-demo'), - trace_id=create_trace_id('fs-trace'), - messages=messages, - current_agent_name='FileSystemAgent', - context=context, - turn_count=0, - approvals={} - ) - - print(Colors.dim('ā³ Processing...\n')) - - # Process with the engine - while True: - result = await run(state, config) - - if result.outcome.status == 'interrupted': - interruption = result.outcome.interruptions[0] - - if interruption.type == 'tool_approval': - approval_result = await handle_approval(interruption) - - if approval_result['approved']: - state = await approve(state, interruption, approval_result.get('additional_context'), config) - else: - state = await reject(state, interruption, approval_result.get('additional_context'), config) - - # Continue processing with the approval decision - continue - - elif result.outcome.status == 'completed': - # Add assistant response to conversation history - final_history = new_history + [{'role': 'assistant', 'content': result.outcome.output}] - - print(Colors.cyan('Assistant: ') + str(result.outcome.output) + '\n') - return final_history, True - - elif result.outcome.status == 'error': - print(Colors.yellow(f'āŒ Error: {result.outcome.error}\n')) - return new_history, True - - -async def conversation_loop( - conversation_history: List[Dict[str, str]], - config: RunConfig[FileSystemContext] -): - """Main conversation loop (recursive pattern).""" - try: - user_input = input(Colors.green('You: ')).strip() - - if user_input.lower() == 'exit': - print(Colors.cyan('šŸ‘‹ Goodbye!')) - return - - if user_input.lower() == 'clear': - display_welcome() - return await conversation_loop(conversation_history, config) - - if not user_input: - return await conversation_loop(conversation_history, config) - - # Process the conversation turn - conversation_history, should_continue = await process_conversation( - user_input, conversation_history, config - ) - - if should_continue: - # Recursive call to continue the conversation - return await conversation_loop(conversation_history, config) - - except KeyboardInterrupt: - print(Colors.cyan('\nšŸ‘‹ Goodbye!')) - return - except EOFError: - print(Colors.cyan('\nšŸ‘‹ Goodbye!')) - return - - -async def main(): - """Main demo function.""" - display_welcome() - setup_sandbox() - - model_provider = create_model_provider() - - # Set up memory provider from environment - memory_provider = await setup_memory_provider() - - # Set up approval storage - print(Colors.cyan('šŸ” Setting up approval storage...')) - approval_storage = create_in_memory_approval_storage() - print(Colors.green('āœ… Approval storage initialized')) - - # Set up tracing - trace_collector = create_composite_trace_collector(ConsoleTraceCollector()) - - from jaf.memory.types import MemoryConfig - memory_config = MemoryConfig( - provider=memory_provider, - auto_store=True, - max_messages=50, - store_on_completion=True - ) - - config = RunConfig( - agent_registry={'FileSystemAgent': file_system_agent}, - model_provider=model_provider, - memory=memory_config, - conversation_id=f'filesystem-demo-{int(time.time() * 1000)}', - approval_storage=approval_storage, - on_event=trace_collector.collect - ) - - try: - # Start the recursive conversation loop - await conversation_loop([], config) - except Exception as e: - print(Colors.yellow(f'Error: {e}')) - import traceback - traceback.print_exc() - - -if __name__ == '__main__': - asyncio.run(main()) \ No newline at end of file diff --git a/examples/hitl-demo/requirements.txt b/examples/hitl-demo/requirements.txt index dc08121..8c96fa4 100644 --- a/examples/hitl-demo/requirements.txt +++ b/examples/hitl-demo/requirements.txt @@ -1,9 +1,7 @@ # Core JAF dependencies (should be installed with `pip install -e .` from project root) # Demo-specific dependencies -fastapi>=0.104.0 uvicorn[standard]>=0.35.0 -pydantic>=2.0.0 openai>=1.0.0 # Required by JAF's LiteLLM provider dotenv>=0.9.9 @@ -12,7 +10,4 @@ redis>=5.0.0 # Optional: For PostgreSQL memory provider psycopg2-binary>=2.9.0 -# OR asyncpg>=0.28.0 - -# Optional: For enhanced features -python-multipart>=0.0.6 # For form data parsing \ No newline at end of file +# OR asyncpg>=0.28.0 \ No newline at end of file diff --git a/examples/hitl-demo/run_server.py b/examples/hitl-demo/run_server.py deleted file mode 100644 index 573cd86..0000000 --- a/examples/hitl-demo/run_server.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python3 - -""" -HITL Demo Server - HTTP server for JAF HITL functionality - -This module provides a server showcasing JAF's Human-in-the-Loop -capabilities with approval-requiring tools. -""" - -import os -import sys -import time -from pathlib import Path -from typing import Any, Dict - -# Add the project root to the path -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) - -from jaf.server.server import create_jaf_server -from jaf.server.types import ServerConfig -from jaf.providers.model import make_litellm_provider -from jaf.core.types import RunConfig, Agent, Tool, ToolSchema -from jaf.memory.approval_storage import create_in_memory_approval_storage -from pydantic import BaseModel - -from shared.memory import setup_memory_provider, Colors - - -class RedirectParams(BaseModel): - url: str - reason: str = None - - -class SendDataParams(BaseModel): - data: str - recipient: str - - -def create_redirect_tool() -> Tool[RedirectParams, Any]: - """Tool that requires approval - redirect user to different screen/page.""" - - class RedirectTool: - @property - def schema(self) -> ToolSchema[RedirectParams]: - return ToolSchema( - name='redirectUser', - description='Redirect user to a different screen/page', - parameters=RedirectParams - ) - - @property - def needs_approval(self) -> bool: - return True - - async def execute(self, args: RedirectParams, context: Any) -> str: - # Simulate using context provided through approval - prev = f" from {context.get('currentScreen', '')}" if context.get('currentScreen') else "" - reason = args.reason or "n/a" - return f"Redirected user{prev} to {args.url}. Reason: {reason}" - - return RedirectTool() - - -def create_send_data_tool() -> Tool[SendDataParams, Any]: - """Tool that requires approval - send sensitive data.""" - - class SendDataTool: - @property - def schema(self) -> ToolSchema[SendDataParams]: - return ToolSchema( - name='sendSensitiveData', - description='Send sensitive data to a recipient', - parameters=SendDataParams - ) - - @property - def needs_approval(self) -> bool: - return True - - async def execute(self, args: SendDataParams, context: Any) -> str: - level = context.get('encryptionLevel', 'none') - return f"Sent data to {args.recipient} with encryption={level}." - - return SendDataTool() - - -def create_hitl_agent() -> Agent[Any, str]: - """Create the HITL demo agent.""" - return Agent( - name='HITL Demo Agent', - instructions=lambda state: """You are a helpful assistant. Use tools when appropriate. -Tools: -- redirectUser (requires approval) -- sendSensitiveData (requires approval) -""", - tools=[create_redirect_tool(), create_send_data_tool()], - model_config={'name': os.getenv('LITELLM_MODEL', 'gpt-3.5-turbo'), 'temperature': 0.1} - ) - - -def create_model_provider(): - """Create model provider with environment configuration.""" - base_url = os.getenv('LITELLM_URL', 'http://localhost:4000') - api_key = os.getenv('LITELLM_API_KEY', 'sk-demo') - - print(Colors.green(f'šŸ¤– Using LiteLLM: {base_url}')) - return make_litellm_provider(base_url, api_key) - - -async def main(): - """Main server function.""" - - # Configuration - host = os.getenv('HOST', '127.0.0.1') - port = int(os.getenv('PORT', '3000')) - - # Model provider - model_provider = create_model_provider() - - # Memory provider from env - memory_provider = await setup_memory_provider() - - # Approval storage - approval_storage = create_in_memory_approval_storage() - - # Create agent - hitl_agent = create_hitl_agent() - agent_registry = {'HITL Demo Agent': hitl_agent} - - # Run configuration - from jaf.memory.types import MemoryConfig - memory_config = MemoryConfig( - provider=memory_provider, - auto_store=True, - max_messages=200 - ) - - run_config = RunConfig( - agent_registry=agent_registry, - model_provider=model_provider, - max_turns=6, - memory=memory_config, - approval_storage=approval_storage - ) - - # Server configuration - server_config = ServerConfig( - agent_registry=agent_registry, - run_config=run_config, - default_memory_provider=memory_provider, - cors=True - ) - - # Create server - app = create_jaf_server(server_config) - - # Usage hints - print(Colors.green('\\nāœ… HITL Server Running')) - print(f'Base URL: http://{host}:{port}') - print() - - print('Endpoints:') - print(f'• Health: GET /health') - print(f'• Agents: GET /agents') - print(f'• Chat: POST /chat') - print(f'• Pending Approvals: GET /approvals/pending?conversationId=...') - print(f'• Approvals SSE Stream: GET /approvals/stream?conversationId=...') - print() - - # Start server - import uvicorn - config = uvicorn.Config(app, host=host, port=port, log_level="info") - server = uvicorn.Server(config) - await server.serve() - - -if __name__ == '__main__': - import asyncio - try: - asyncio.run(main()) - except KeyboardInterrupt: - print(Colors.cyan('\\nšŸ‘‹ Server stopped')) - except Exception as e: - print(Colors.yellow(f'Error: {e}')) - import traceback - traceback.print_exc() \ No newline at end of file diff --git a/examples/hitl-demo/shared/tools.py b/examples/hitl-demo/shared/tools.py index 3806c2d..26d1341 100644 --- a/examples/hitl-demo/shared/tools.py +++ b/examples/hitl-demo/shared/tools.py @@ -24,36 +24,49 @@ class FileSystemContext: permissions: List[str] +def get_context_attr(context, attr: str, default=None): + """Helper to get attribute from context whether it's a dict or dataclass.""" + if hasattr(context, attr): + return getattr(context, attr) + elif isinstance(context, dict): + return context.get(attr, default) + else: + return default + + # Get demo directory path DEMO_DIR = Path(__file__).parent.parent / "sandbox" class ListFilesArgs(BaseModel): """Arguments for listing files.""" - directory: Optional[str] = Field(None, description="Directory to list (relative to working directory)") + directory: Optional[str] = Field(None, description="Directory to list (defaults to current directory)") class ReadFileArgs(BaseModel): - """Arguments for reading files.""" - filepath: str = Field(description="Path to the file to read (relative to working directory)") + """Arguments for reading a file.""" + filepath: str = Field(description="Path to the file to read") class DeleteFileArgs(BaseModel): - """Arguments for deleting files.""" - filepath: str = Field(description="Path to the file to delete (relative to working directory)") + """Arguments for deleting a file.""" + filepath: str = Field(description="Path to the file to delete") reason: Optional[str] = Field(None, description="Reason for deletion") class EditFileArgs(BaseModel): - """Arguments for editing files.""" - filepath: str = Field(description="Path to the file to edit (relative to working directory)") + """Arguments for editing a file.""" + filepath: str = Field(description="Path to the file to edit or create") content: str = Field(description="New content for the file") backup: Optional[bool] = Field(False, description="Whether to create a backup before editing") class ListFilesTool: """Tool for listing files and directories.""" - + + def __init__(self, sandbox_dir: Path = DEMO_DIR): + self.sandbox_dir = sandbox_dir + @property def schema(self) -> ToolSchema[ListFilesArgs]: return ToolSchema( @@ -61,48 +74,51 @@ def schema(self) -> ToolSchema[ListFilesArgs]: description="List files and directories in the specified directory", parameters=ListFilesArgs ) - + @property def needs_approval(self) -> bool: return False - - async def execute(self, args: ListFilesArgs, context: FileSystemContext) -> str: + + async def execute(self, args: ListFilesArgs, context: Any = None) -> str: try: - target_dir = DEMO_DIR + target_dir = self.sandbox_dir if args.directory: - target_dir = (Path(context.working_directory) / args.directory).resolve() - + target_dir = (self.sandbox_dir / args.directory).resolve() + # Security check - ensure we stay within sandbox - if not str(target_dir).startswith(str(DEMO_DIR)): + if not str(target_dir).startswith(str(self.sandbox_dir)): return f"Error: Access denied. Directory outside of sandbox: {target_dir}" - + if not target_dir.exists(): return f"Error: Directory does not exist: {target_dir}" - + items = [] for item in target_dir.iterdir(): item_type = "directory" if item.is_dir() else "file" - relative_path = item.relative_to(Path(context.working_directory)) + relative_path = item.relative_to(self.sandbox_dir) items.append({ "name": item.name, "type": item_type, "path": str(relative_path) }) - + file_list = "\n".join([ f" {'šŸ“' if item['type'] == 'directory' else 'šŸ“„'} {item['name']}" for item in items ]) - + return f"Directory listing for {target_dir}:\n{file_list}" - + except Exception as e: return f"Error listing directory: {str(e)}" class ReadFileTool: """Tool for reading file contents.""" - + + def __init__(self, sandbox_dir: Path = DEMO_DIR): + self.sandbox_dir = sandbox_dir + @property def schema(self) -> ToolSchema[ReadFileArgs]: return ToolSchema( @@ -110,32 +126,38 @@ def schema(self) -> ToolSchema[ReadFileArgs]: description="Read the contents of a file", parameters=ReadFileArgs ) - + @property def needs_approval(self) -> bool: return False - - async def execute(self, args: ReadFileArgs, context: FileSystemContext) -> str: + + async def execute(self, args: ReadFileArgs, context: Any = None) -> str: try: - target_path = (Path(context.working_directory) / args.filepath).resolve() - + target_path = (self.sandbox_dir / args.filepath).resolve() + # Security check - ensure we stay within sandbox - if not str(target_path).startswith(str(DEMO_DIR)): + if not str(target_path).startswith(str(self.sandbox_dir)): return f"Error: Access denied. File outside of sandbox: {target_path}" - + if not target_path.exists(): return f"Error: File does not exist: {args.filepath}" - + + if not target_path.is_file(): + return f"Error: {args.filepath} is not a file" + content = target_path.read_text(encoding='utf-8') - return f"Contents of {args.filepath}:\n```\n{content}\n```" - + return f"Contents of {args.filepath}:\n{content}" + except Exception as e: return f"Error reading file: {str(e)}" class DeleteFileTool: - """Tool for deleting files (requires approval).""" - + """Tool for deleting files.""" + + def __init__(self, sandbox_dir: Path = DEMO_DIR): + self.sandbox_dir = sandbox_dir + @property def schema(self) -> ToolSchema[DeleteFileArgs]: return ToolSchema( @@ -143,88 +165,93 @@ def schema(self) -> ToolSchema[DeleteFileArgs]: description="Delete a file (requires approval)", parameters=DeleteFileArgs ) - + @property def needs_approval(self) -> bool: return True - - async def execute(self, args: DeleteFileArgs, context: FileSystemContext) -> str: + + async def execute(self, args: DeleteFileArgs, context: Any = None) -> str: try: - target_path = (Path(context.working_directory) / args.filepath).resolve() - + target_path = (self.sandbox_dir / args.filepath).resolve() + # Security check - ensure we stay within sandbox - if not str(target_path).startswith(str(DEMO_DIR)): + if not str(target_path).startswith(str(self.sandbox_dir)): return f"Error: Access denied. File outside of sandbox: {target_path}" - + if not target_path.exists(): return f"Error: File does not exist: {args.filepath}" - + target_path.unlink() print(f"šŸ—‘ļø File deleted: {args.filepath}") if args.reason: print(f" Reason: {args.reason}") - + # Check for approval context - if hasattr(context, 'deletion_confirmed') and context.deletion_confirmed: - print(f" Confirmed by: {context.deletion_confirmed.get('confirmed_by')}") - print(f" Backup created: {context.deletion_confirmed.get('backup_created')}") - + deletion_confirmed = get_context_attr(context, 'deletion_confirmed') + if deletion_confirmed: + print(f" Confirmed by: {deletion_confirmed.get('confirmed_by')}") + print(f" Backup created: {deletion_confirmed.get('backup_created')}") + reason_text = f" (Reason: {args.reason})" if args.reason else "" return f"Successfully deleted file: {args.filepath}{reason_text}" - + except Exception as e: return f"Error deleting file: {str(e)}" class EditFileTool: - """Tool for editing files (requires approval).""" - + """Tool for editing or creating files.""" + + def __init__(self, sandbox_dir: Path = DEMO_DIR): + self.sandbox_dir = sandbox_dir + @property def schema(self) -> ToolSchema[EditFileArgs]: return ToolSchema( - name="editFile", + name="editFile", description="Edit or create a file with new content (requires approval)", parameters=EditFileArgs ) - + @property def needs_approval(self) -> bool: return True - - async def execute(self, args: EditFileArgs, context: FileSystemContext) -> str: + + async def execute(self, args: EditFileArgs, context: Any = None) -> str: try: - target_path = (Path(context.working_directory) / args.filepath).resolve() - + target_path = (self.sandbox_dir / args.filepath).resolve() + # Security check - ensure we stay within sandbox - if not str(target_path).startswith(str(DEMO_DIR)): + if not str(target_path).startswith(str(self.sandbox_dir)): return f"Error: Access denied. File outside of sandbox: {target_path}" - + backup_path = "" if args.backup and target_path.exists(): backup_path = f"{target_path}.backup.{int(asyncio.get_event_loop().time() * 1000)}" target_path.rename(backup_path) print(f"šŸ’¾ Backup created: {backup_path}") - + # Ensure parent directory exists target_path.parent.mkdir(parents=True, exist_ok=True) - + target_path.write_text(args.content, encoding='utf-8') print(f"āœļø File edited: {args.filepath}") - + # Check for approval context - if hasattr(context, 'editing_approved') and context.editing_approved: - print(f" Approved by: {context.editing_approved.get('approved_by')}") - print(f" Safety level: {context.editing_approved.get('safety_level')}") - + editing_approved = get_context_attr(context, 'editing_approved') + if editing_approved: + print(f" Approved by: {editing_approved.get('approved_by')}") + print(f" Safety level: {editing_approved.get('safety_level')}") + backup_text = f" (Backup: {Path(backup_path).name})" if backup_path else "" return f"Successfully edited file: {args.filepath}{backup_text}" - + except Exception as e: return f"Error editing file: {str(e)}" -# Create tool instances +# Create tool instances with default sandbox configuration list_files_tool = ListFilesTool() -read_file_tool = ReadFileTool() +read_file_tool = ReadFileTool() delete_file_tool = DeleteFileTool() edit_file_tool = EditFileTool() \ No newline at end of file diff --git a/jaf/core/engine.py b/jaf/core/engine.py index 698a4d2..0ce492b 100644 --- a/jaf/core/engine.py +++ b/jaf/core/engine.py @@ -158,12 +158,23 @@ async def try_resume_pending_tool_calls( outcome=InterruptedOutcome(interruptions=interruptions) ) + # Collect enhanced contexts from tool executions + enhanced_contexts = [r.get('enhanced_context') for r in tool_results if r.get('enhanced_context')] + + # Merge enhanced contexts into state context if any were provided + final_context = state.context + if enhanced_contexts: + print(f'[JAF:APPROVAL] Merging {len(enhanced_contexts)} enhanced contexts into state for resume') + # Take the most recent enhanced context (last tool that provided enhancement) + final_context = enhanced_contexts[-1] + # Continue with normal execution next_state = replace( state, messages=list(state.messages) + [r['message'] for r in tool_results], turn_count=state.turn_count, - approvals=state.approvals + approvals=state.approvals, + context=final_context ) return await _run_internal(next_state, config) @@ -197,9 +208,11 @@ async def run( messages=state_with_memory.messages, # Now includes full conversation history agent_name=state_with_memory.current_agent_name )))) + + print(f'[JAF:ENGINE] Loaded context for runId {state_with_memory.run_id}: {state_with_memory.context}') # Load approvals from storage if configured - if config.approval_storage: + if config.memory and config.memory.provider: print(f'[JAF:ENGINE] Loading approvals for runId {state_with_memory.run_id}') from .state import load_approvals_into_state state_with_memory = await load_approvals_into_state(state_with_memory, config) @@ -269,9 +282,11 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx all_memory_messages = conversation_data.messages[-max_messages:] # Filter out halted messages - they're for audit/database only, not for LLM context + # Also extract approval context from approved_and_executed messages for LLM visibility memory_messages = [] filtered_count = 0 - + extracted_approval_contexts = [] + for msg in all_memory_messages: if msg.role not in (ContentRole.TOOL, 'tool'): memory_messages.append(msg) @@ -280,28 +295,50 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx content = json.loads(msg.content) status = content.get('status') hitl_status = content.get('hitl_status') + # Filter out ALL halted/pending approval messages (they're for audit only) if status == 'halted' or hitl_status == 'pending_approval': filtered_count += 1 continue # Skip this halted message - else: - memory_messages.append(msg) + + # Extract approval context from approved_and_executed messages + if hitl_status == 'approved_and_executed': + approval_context = content.get('approval_context') + if approval_context: + # Extract useful context information + context_message = approval_context.get('message', '') + if context_message: + extracted_approval_contexts.append({ + 'tool_name': content.get('tool_name', 'unknown'), + 'context_message': context_message, + 'approval_context': approval_context + }) + + memory_messages.append(msg) except (json.JSONDecodeError, TypeError): # Keep non-JSON tool messages memory_messages.append(msg) + # Inject extracted approval contexts as system messages for LLM visibility + if extracted_approval_contexts: + approval_context_messages = [] + for ctx in extracted_approval_contexts: + context_text = f"Previous approval context for {ctx['tool_name']}: {ctx['context_message']}" + + # Create a system message with the approval context + approval_msg = Message( + role=ContentRole.SYSTEM, + content=context_text + ) + approval_context_messages.append(approval_msg) + + # Insert approval context messages at the beginning for visibility + memory_messages = approval_context_messages + memory_messages + # For HITL scenarios, append new messages to memory messages # This prevents duplication when resuming from interruptions if memory_messages: - combined_messages = memory_messages + [ - msg for msg in state.messages - if not any( - mem_msg.role == msg.role and - mem_msg.content == msg.content and - getattr(mem_msg, 'tool_calls', None) == getattr(msg, 'tool_calls', None) - for mem_msg in memory_messages - ) - ] + combined_messages = memory_messages + list(state.messages) else: combined_messages = list(state.messages) @@ -332,6 +369,9 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx else: print(f'[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory') + if extracted_approval_contexts: + print(f'[JAF:APPROVAL] Extracted {len(extracted_approval_contexts)} approval contexts for LLM visibility: {[ctx["tool_name"] for ctx in extracted_approval_contexts]}') + return replace( state, messages=combined_messages, @@ -363,8 +403,8 @@ async def _store_conversation_history(state: RunState[Ctx], config: RunConfig[Ct if state.approvals: approval_metadata = { "approval_count": len(state.approvals), - "approved_tools": [tool_id for tool_id, approval in state.approvals.items() if approval.approved], - "rejected_tools": [tool_id for tool_id, approval in state.approvals.items() if not approval.approved], + "approved_tools": [tool_id for tool_id, approval in state.approvals.items() if getattr(approval, 'approved', approval.get('approved') if isinstance(approval, dict) else False)], + "rejected_tools": [tool_id for tool_id, approval in state.approvals.items() if not getattr(approval, 'approved', approval.get('approved') if isinstance(approval, dict) else False)], "has_approvals": True } @@ -775,13 +815,24 @@ async def _run_internal( except (json.JSONDecodeError, TypeError): cleaned_new_messages.append(msg) + # Collect enhanced contexts from tool executions + enhanced_contexts = [r.get('enhanced_context') for r in tool_results if r.get('enhanced_context')] + + # Merge enhanced contexts into state context if any were provided + final_context = state.context + if enhanced_contexts: + print(f'[JAF:APPROVAL] Merging {len(enhanced_contexts)} enhanced contexts into state for handoff') + # Take the most recent enhanced context (last tool that provided enhancement) + final_context = enhanced_contexts[-1] + # Continue with new agent next_state = replace( state, messages=cleaned_new_messages + [r['message'] for r in tool_results], current_agent_name=target_agent, turn_count=state.turn_count + 1, - approvals=state.approvals + approvals=state.approvals, + context=final_context ) return await _run_internal(next_state, config) @@ -803,12 +854,23 @@ async def _run_internal( except (json.JSONDecodeError, TypeError): cleaned_new_messages.append(msg) + # Collect enhanced contexts from tool executions + enhanced_contexts = [r.get('enhanced_context') for r in tool_results if r.get('enhanced_context')] + + # Merge enhanced contexts into state context if any were provided + final_context = state.context + if enhanced_contexts: + print(f'[JAF:APPROVAL] Merging {len(enhanced_contexts)} enhanced contexts into state') + # Take the most recent enhanced context (last tool that provided enhancement) + final_context = enhanced_contexts[-1] + # Continue with tool results next_state = replace( state, messages=cleaned_new_messages + [r['message'] for r in tool_results], turn_count=state.turn_count + 1, - approvals=state.approvals + approvals=state.approvals, + context=final_context ) return await _run_internal(next_state, config) @@ -1084,23 +1146,28 @@ async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]: if not approval_status: signature = f"{tool_call.function.name}:{tool_call.function.arguments}" for _, approval in state.approvals.items(): - if approval.additional_context and approval.additional_context.get('signature') == signature: + additional_context = getattr(approval, 'additional_context', approval.get('additional_context') if isinstance(approval, dict) else None) + if additional_context and additional_context.get('signature') == signature: approval_status = approval break derived_status = None if approval_status: # Use explicit status if available - if approval_status.status: - derived_status = approval_status.status + status = getattr(approval_status, 'status', approval_status.get('status') if isinstance(approval_status, dict) else None) + if status: + derived_status = status # Fall back to approved boolean if status not set - elif approval_status.approved is True: - derived_status = 'approved' - elif approval_status.approved is False: - if approval_status.additional_context and approval_status.additional_context.get('status') == 'pending': - derived_status = 'pending' - else: - derived_status = 'rejected' + else: + approved = getattr(approval_status, 'approved', approval_status.get('approved') if isinstance(approval_status, dict) else None) + if approved is True: + derived_status = 'approved' + elif approved is False: + additional_context = getattr(approval_status, 'additional_context', approval_status.get('additional_context') if isinstance(approval_status, dict) else None) + if additional_context and additional_context.get('status') == 'pending': + derived_status = 'pending' + else: + derived_status = 'rejected' is_pending = derived_status == 'pending' @@ -1130,13 +1197,19 @@ async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]: # If approval was explicitly rejected, return rejection message if derived_status == 'rejected': - rejection_reason = approval_status.additional_context.get('rejection_reason', 'User declined the action') if approval_status.additional_context else 'User declined the action' + approval_context = getattr(approval_status, 'additional_context', approval_status.get('additional_context') if isinstance(approval_status, dict) else None) if approval_status else None + rejection_reason = approval_context.get('rejection_reason', 'User declined the action') if approval_context else 'User declined the action' + context_message = approval_context.get('message', '') if approval_context else '' + + base_message = f'Action was not approved. {rejection_reason}. Please ask if you can help with something else or suggest an alternative approach.' + full_message = f'{base_message} IMPORTANT: {context_message}' if context_message else base_message + rejection_result = json.dumps({ 'hitl_status': 'rejected', # HITL workflow status: user rejected the action - 'message': f'Action was not approved. {rejection_reason}. Please ask if you can help with something else or suggest an alternative approach.', + 'message': full_message, 'tool_name': tool_call.function.name, 'rejection_reason': rejection_reason, - 'additional_context': approval_status.additional_context if approval_status else None + 'additional_context': approval_context }) return { @@ -1157,8 +1230,9 @@ async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]: timeout = config.default_tool_timeout if config.default_tool_timeout is not None else 300.0 # Merge additional context if provided through approval - additional_context = approval_status.additional_context if approval_status else None + additional_context = getattr(approval_status, 'additional_context', approval_status.get('additional_context') if isinstance(approval_status, dict) else None) if approval_status else None context_with_additional = state.context + context_was_enhanced = False if additional_context: # Create a copy of context with additional fields from approval if hasattr(state.context, '__dict__'): @@ -1172,6 +1246,8 @@ async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]: else: # For dict contexts, merge normally context_with_additional = {**state.context, **additional_context} + context_was_enhanced = True + print(f'[JAF:APPROVAL] Enhanced context for tool {tool_call.function.name} with additional context: {additional_context}') print(f'[JAF:ENGINE] About to execute tool: {tool_call.function.name}') print(f'[JAF:ENGINE] Tool args:', validated_args) @@ -1221,13 +1297,15 @@ async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]: print(f'[JAF:ENGINE] Converted to string:', result_string) # Wrap tool result with status information for approval context - if approval_status and approval_status.additional_context: + if approval_status and getattr(approval_status, 'additional_context', approval_status.get('additional_context') if isinstance(approval_status, dict) else None): + approval_context = getattr(approval_status, 'additional_context', approval_status.get('additional_context') if isinstance(approval_status, dict) else None) + context_message = approval_context.get('message', '') if approval_context else '' final_content = json.dumps({ 'hitl_status': 'approved_and_executed', # HITL workflow status: approved by user and executed 'result': result_string, 'tool_name': tool_call.function.name, - 'approval_context': approval_status.additional_context, - 'message': 'Tool was approved and executed successfully with additional context.' + 'approval_context': approval_context, + 'message': f'Tool was approved and executed successfully. IMPORTANT: {context_message}' if context_message else 'Tool was approved and executed successfully with additional context.' }) elif needs_approval: final_content = json.dumps({ @@ -1268,7 +1346,7 @@ async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]: 'target_agent': handoff_check['handoff_to'] } - return { + result_dict = { 'message': Message( role=ContentRole.TOOL, content=final_content, @@ -1276,6 +1354,13 @@ async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]: ) } + # If context was enhanced with additional approval context, pass it back + if context_was_enhanced: + result_dict['enhanced_context'] = context_with_additional + print(f'[JAF:APPROVAL] Propagating enhanced context back from tool {tool_call.function.name}') + + return result_dict + except Exception as error: error_result = json.dumps({ 'hitl_status': 'execution_error', # HITL workflow status diff --git a/jaf/core/state.py b/jaf/core/state.py index a9818be..1613435 100644 --- a/jaf/core/state.py +++ b/jaf/core/state.py @@ -133,25 +133,31 @@ async def approve( Updated run state with approval recorded """ if interruption.type == 'tool_approval': + print(f"[JAF:APPROVAL] approve() called with additional_context: {additional_context}") + + # Ensure user additional context is preserved and internal status is non-conflicting + merged_context = {**(additional_context or {})} + if 'status' not in merged_context: # Only add status if user didn't provide it + merged_context['approval_status'] = 'approved' # Use non-conflicting key + approval_value = ApprovalValue( status='approved', approved=True, - additional_context={ - **(additional_context or {}), - 'status': 'approved' - } + additional_context=merged_context ) + + print(f"[JAF:APPROVAL] Created approval_value with additional_context: {approval_value.additional_context}") # Store in approval storage if available - if config and config.approval_storage: + if config and config.memory and config.memory.provider: try: print(f"[JAF:APPROVAL] Storing approval for tool_call_id {interruption.tool_call.id}: {approval_value}") - result = await config.approval_storage.store_approval( + result = await config.memory.provider.store_approval( state.run_id, interruption.tool_call.id, approval_value ) - if not result.success: + if hasattr(result, 'error'): print(f"[JAF:APPROVAL] Failed to store approval: {result.error}") # Continue with in-memory fallback else: @@ -192,25 +198,31 @@ async def reject( Updated run state with rejection recorded """ if interruption.type == 'tool_approval': + print(f"[JAF:APPROVAL] reject() called with additional_context: {additional_context}") + + # Ensure user additional context is preserved and internal status is non-conflicting + merged_context = {**(additional_context or {})} + if 'status' not in merged_context: # Only add status if user didn't provide it + merged_context['approval_status'] = 'rejected' # Use non-conflicting key + approval_value = ApprovalValue( status='rejected', approved=False, - additional_context={ - **(additional_context or {}), - 'status': 'rejected' - } + additional_context=merged_context ) + + print(f"[JAF:APPROVAL] Created approval_value with additional_context: {approval_value.additional_context}") # Store in approval storage if available - if config and config.approval_storage: + if config and config.memory and config.memory.provider: try: print(f"[JAF:APPROVAL] Storing approval for tool_call_id {interruption.tool_call.id}: {approval_value}") - result = await config.approval_storage.store_approval( + result = await config.memory.provider.store_approval( state.run_id, interruption.tool_call.id, approval_value ) - if not result.success: + if hasattr(result, 'error'): print(f"[JAF:APPROVAL] Failed to store approval: {result.error}") # Continue with in-memory fallback else: @@ -246,18 +258,18 @@ async def load_approvals_into_state( Returns: Updated run state with loaded approvals """ - if not config or not config.approval_storage: + if not config or not config.memory and config.memory.provider: print(f"[JAF:APPROVAL] No approval storage configured, using existing approvals: {state.approvals}") return state try: print(f"[JAF:APPROVAL] Loading approvals from storage for run_id: {state.run_id}") - result = await config.approval_storage.get_run_approvals(state.run_id) - if result.success and result.data: + result = await config.memory.provider.get_run_approvals(state.run_id) + if hasattr(result, 'data') and result.data: print(f"[JAF:APPROVAL] Loaded {len(result.data)} approvals from storage: {result.data}") return replace(state, approvals=result.data) else: - if not result.success: + if hasattr(result, 'error'): print(f"[JAF:APPROVAL] Failed to load approvals: {result.error}") else: print(f"[JAF:APPROVAL] No approvals found in storage for run_id: {state.run_id}") diff --git a/jaf/core/types.py b/jaf/core/types.py index 3b70b02..b8179ad 100644 --- a/jaf/core/types.py +++ b/jaf/core/types.py @@ -29,7 +29,6 @@ if TYPE_CHECKING: from .tool_results import ToolResult - from ..memory.approval_storage import ApprovalStorage from ..memory.types import MemoryConfig @@ -823,4 +822,3 @@ class RunConfig(Generic[Ctx]): conversation_id: Optional[str] = None default_fast_model: Optional[str] = None # Default model for fast operations like guardrails default_tool_timeout: Optional[float] = 300.0 # Default timeout for tool execution in seconds - approval_storage: Optional['ApprovalStorage'] = None # Storage for approval decisions diff --git a/jaf/memory/approval_storage.py b/jaf/memory/approval_storage.py deleted file mode 100644 index 9b66b99..0000000 --- a/jaf/memory/approval_storage.py +++ /dev/null @@ -1,306 +0,0 @@ -""" -Approval storage interface and implementations for Human-in-the-Loop (HITL) functionality. - -This module provides persistent storage for tool approval decisions, enabling -the framework to maintain approval states across conversation sessions and -handle interruptions gracefully. -""" - -from abc import ABC, abstractmethod -from typing import Dict, Any, Optional -import asyncio - -from ..core.types import RunId, ApprovalValue - - -class ApprovalStorageResult: - """Result wrapper for approval storage operations.""" - - def __init__(self, success: bool, data: Any = None, error: Optional[str] = None): - self.success = success - self.data = data - self.error = error - - @classmethod - def success_result(cls, data: Any = None) -> 'ApprovalStorageResult': - """Create a successful result.""" - return cls(success=True, data=data) - - @classmethod - def error_result(cls, error: str) -> 'ApprovalStorageResult': - """Create an error result.""" - return cls(success=False, error=error) - - -class ApprovalStorage(ABC): - """Abstract interface for approval storage implementations.""" - - @abstractmethod - async def store_approval( - self, - run_id: RunId, - tool_call_id: str, - approval: ApprovalValue, - metadata: Optional[Dict[str, Any]] = None - ) -> ApprovalStorageResult: - """Store an approval decision for a tool call.""" - pass - - @abstractmethod - async def get_approval( - self, - run_id: RunId, - tool_call_id: str - ) -> ApprovalStorageResult: - """Retrieve approval for a specific tool call. Returns None if not found.""" - pass - - @abstractmethod - async def get_run_approvals( - self, - run_id: RunId - ) -> ApprovalStorageResult: - """Get all approvals for a run as a Dict[str, ApprovalValue].""" - pass - - @abstractmethod - async def update_approval( - self, - run_id: RunId, - tool_call_id: str, - updates: Dict[str, Any] - ) -> ApprovalStorageResult: - """Update existing approval with additional context.""" - pass - - @abstractmethod - async def delete_approval( - self, - run_id: RunId, - tool_call_id: str - ) -> ApprovalStorageResult: - """Delete approval for a tool call. Returns success status.""" - pass - - @abstractmethod - async def clear_run_approvals(self, run_id: RunId) -> ApprovalStorageResult: - """Clear all approvals for a run. Returns count of deleted approvals.""" - pass - - @abstractmethod - async def get_stats(self) -> ApprovalStorageResult: - """Get approval statistics.""" - pass - - @abstractmethod - async def health_check(self) -> ApprovalStorageResult: - """Health check for the approval storage.""" - pass - - @abstractmethod - async def close(self) -> ApprovalStorageResult: - """Close/cleanup the storage.""" - pass - - -class InMemoryApprovalStorage(ApprovalStorage): - """In-memory implementation of ApprovalStorage for development and testing.""" - - def __init__(self): - self._approvals: Dict[str, Dict[str, ApprovalValue]] = {} - self._lock = asyncio.Lock() - - def _get_run_key(self, run_id: RunId) -> str: - """Generate a consistent key for a run.""" - return f"run:{run_id}" - - async def store_approval( - self, - run_id: RunId, - tool_call_id: str, - approval: ApprovalValue, - metadata: Optional[Dict[str, Any]] = None - ) -> ApprovalStorageResult: - """Store an approval decision.""" - try: - async with self._lock: - run_key = self._get_run_key(run_id) - - if run_key not in self._approvals: - self._approvals[run_key] = {} - - # Enhance approval with metadata if provided - enhanced_approval = approval - if metadata: - additional_context = {**(approval.additional_context or {}), **metadata} - enhanced_approval = ApprovalValue( - status=approval.status, - approved=approval.approved, - additional_context=additional_context - ) - - self._approvals[run_key][tool_call_id] = enhanced_approval - - return ApprovalStorageResult.success_result() - except Exception as e: - return ApprovalStorageResult.error_result(f"Failed to store approval: {e}") - - async def get_approval( - self, - run_id: RunId, - tool_call_id: str - ) -> ApprovalStorageResult: - """Retrieve approval for a specific tool call.""" - try: - async with self._lock: - run_key = self._get_run_key(run_id) - run_approvals = self._approvals.get(run_key, {}) - approval = run_approvals.get(tool_call_id) - - return ApprovalStorageResult.success_result(approval) - except Exception as e: - return ApprovalStorageResult.error_result(f"Failed to get approval: {e}") - - async def get_run_approvals(self, run_id: RunId) -> ApprovalStorageResult: - """Get all approvals for a run.""" - try: - async with self._lock: - run_key = self._get_run_key(run_id) - run_approvals = self._approvals.get(run_key, {}).copy() - - return ApprovalStorageResult.success_result(run_approvals) - except Exception as e: - return ApprovalStorageResult.error_result(f"Failed to get run approvals: {e}") - - async def update_approval( - self, - run_id: RunId, - tool_call_id: str, - updates: Dict[str, Any] - ) -> ApprovalStorageResult: - """Update existing approval.""" - try: - async with self._lock: - run_key = self._get_run_key(run_id) - - if run_key not in self._approvals or tool_call_id not in self._approvals[run_key]: - return ApprovalStorageResult.error_result( - f"Approval not found for tool call {tool_call_id} in run {run_id}" - ) - - existing = self._approvals[run_key][tool_call_id] - - # Merge additional context - merged_context = {**(existing.additional_context or {}), **(updates.get('additional_context', {}))} - - updated_approval = ApprovalValue( - status=updates.get('status', existing.status), - approved=updates.get('approved', existing.approved), - additional_context=merged_context if merged_context else existing.additional_context - ) - - self._approvals[run_key][tool_call_id] = updated_approval - - return ApprovalStorageResult.success_result() - except Exception as e: - return ApprovalStorageResult.error_result(f"Failed to update approval: {e}") - - async def delete_approval( - self, - run_id: RunId, - tool_call_id: str - ) -> ApprovalStorageResult: - """Delete approval for a tool call.""" - try: - async with self._lock: - run_key = self._get_run_key(run_id) - - if run_key not in self._approvals: - return ApprovalStorageResult.success_result(False) - - deleted = self._approvals[run_key].pop(tool_call_id, None) is not None - - # Clean up empty run maps - if not self._approvals[run_key]: - del self._approvals[run_key] - - return ApprovalStorageResult.success_result(deleted) - except Exception as e: - return ApprovalStorageResult.error_result(f"Failed to delete approval: {e}") - - async def clear_run_approvals(self, run_id: RunId) -> ApprovalStorageResult: - """Clear all approvals for a run.""" - try: - async with self._lock: - run_key = self._get_run_key(run_id) - - if run_key not in self._approvals: - return ApprovalStorageResult.success_result(0) - - count = len(self._approvals[run_key]) - del self._approvals[run_key] - - return ApprovalStorageResult.success_result(count) - except Exception as e: - return ApprovalStorageResult.error_result(f"Failed to clear run approvals: {e}") - - async def get_stats(self) -> ApprovalStorageResult: - """Get approval statistics.""" - try: - async with self._lock: - total_approvals = 0 - approved_count = 0 - rejected_count = 0 - runs_with_approvals = len(self._approvals) - - for run_approvals in self._approvals.values(): - for approval in run_approvals.values(): - total_approvals += 1 - if approval.approved: - approved_count += 1 - else: - rejected_count += 1 - - stats = { - 'total_approvals': total_approvals, - 'approved_count': approved_count, - 'rejected_count': rejected_count, - 'runs_with_approvals': runs_with_approvals - } - - return ApprovalStorageResult.success_result(stats) - except Exception as e: - return ApprovalStorageResult.error_result(f"Failed to get stats: {e}") - - async def health_check(self) -> ApprovalStorageResult: - """Health check for the storage.""" - try: - # Simple operation to test functionality - await asyncio.sleep(0.001) # Minimal async operation - - health_data = { - 'healthy': True, - 'latency_ms': 1.0 # Approximate for in-memory - } - - return ApprovalStorageResult.success_result(health_data) - except Exception as e: - health_data = { - 'healthy': False, - 'error': str(e) - } - return ApprovalStorageResult.success_result(health_data) - - async def close(self) -> ApprovalStorageResult: - """Close/cleanup the storage.""" - try: - async with self._lock: - self._approvals.clear() - return ApprovalStorageResult.success_result() - except Exception as e: - return ApprovalStorageResult.error_result(f"Failed to close storage: {e}") - - -def create_in_memory_approval_storage() -> InMemoryApprovalStorage: - """Create an in-memory approval storage instance.""" - return InMemoryApprovalStorage() \ No newline at end of file diff --git a/jaf/memory/providers/in_memory.py b/jaf/memory/providers/in_memory.py index 3dcaf0a..7d52e77 100644 --- a/jaf/memory/providers/in_memory.py +++ b/jaf/memory/providers/in_memory.py @@ -11,7 +11,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Union -from ...core.types import Message +from ...core.types import Message, RunId, ApprovalValue from ..types import ( ConversationMemory, Failure, @@ -37,6 +37,7 @@ class InMemoryProvider(MemoryProvider): def __init__(self, config: InMemoryConfig): self.config = config self._conversations: OrderedDict[str, ConversationMemory] = OrderedDict() + self._approvals: Dict[str, Dict[str, ApprovalValue]] = {} # run_id -> {tool_call_id: approval} self._lock = asyncio.Lock() print(f"[MEMORY:InMemory] Initialized with max {config.max_conversations} conversations, {config.max_messages_per_conversation} messages each") @@ -307,11 +308,144 @@ async def health_check(self) -> Result[Dict[str, Any], MemoryConnectionError]: cause=e )) + # Approval storage methods + def _get_run_key(self, run_id: RunId) -> str: + """Convert run_id to string key.""" + return str(run_id) + + async def store_approval( + self, + run_id: RunId, + tool_call_id: str, + approval: ApprovalValue, + metadata: Optional[Dict[str, Any]] = None + ) -> Result[None, MemoryStorageError]: + """Store an approval decision for a tool call.""" + try: + async with self._lock: + run_key = self._get_run_key(run_id) + + if run_key not in self._approvals: + self._approvals[run_key] = {} + + self._approvals[run_key][tool_call_id] = approval + + return Success(None) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to store approval: {e}")) + + async def get_approval( + self, + run_id: RunId, + tool_call_id: str + ) -> Result[Optional[ApprovalValue], MemoryStorageError]: + """Retrieve approval for a specific tool call. Returns None if not found.""" + try: + async with self._lock: + run_key = self._get_run_key(run_id) + + if run_key not in self._approvals: + return Success(None) + + approval = self._approvals[run_key].get(tool_call_id) + return Success(approval) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to get approval: {e}")) + + async def get_run_approvals( + self, + run_id: RunId + ) -> Result[Dict[str, ApprovalValue], MemoryStorageError]: + """Get all approvals for a run as a Dict[str, ApprovalValue].""" + try: + async with self._lock: + run_key = self._get_run_key(run_id) + run_approvals = self._approvals.get(run_key, {}).copy() + return Success(run_approvals) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to get run approvals: {e}")) + + async def update_approval( + self, + run_id: RunId, + tool_call_id: str, + updates: Dict[str, Any] + ) -> Result[None, MemoryStorageError]: + """Update approval with new data.""" + try: + async with self._lock: + run_key = self._get_run_key(run_id) + + if run_key not in self._approvals or tool_call_id not in self._approvals[run_key]: + return Failure(MemoryStorageError(f"Approval not found for tool_call_id: {tool_call_id}")) + + # Update approval fields + current_approval = self._approvals[run_key][tool_call_id] + + # Create updated approval with new values + updated_approval = ApprovalValue( + status=updates.get('status', current_approval.status), + approved=updates.get('approved', current_approval.approved), + additional_context={ + **current_approval.additional_context, + **updates.get('additional_context', {}) + } + ) + + self._approvals[run_key][tool_call_id] = updated_approval + + return Success(None) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to update approval: {e}")) + + async def delete_approval( + self, + run_id: RunId, + tool_call_id: str + ) -> Result[bool, MemoryStorageError]: + """Delete approval for a tool call. Returns True if it existed.""" + try: + async with self._lock: + run_key = self._get_run_key(run_id) + + if run_key not in self._approvals: + return Success(False) + + deleted = self._approvals[run_key].pop(tool_call_id, None) is not None + + # Clean up empty run maps + if not self._approvals[run_key]: + del self._approvals[run_key] + + return Success(deleted) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to delete approval: {e}")) + + async def clear_run_approvals( + self, + run_id: RunId + ) -> Result[int, MemoryStorageError]: + """Clear all approvals for a run. Returns count of deleted approvals.""" + try: + async with self._lock: + run_key = self._get_run_key(run_id) + + if run_key not in self._approvals: + return Success(0) + + count = len(self._approvals[run_key]) + del self._approvals[run_key] + + return Success(count) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to clear run approvals: {e}")) + async def close(self) -> Result[None, MemoryConnectionError]: """Close/cleanup the provider.""" async with self._lock: self._conversations.clear() - print("[MEMORY:InMemory] Closed provider, cleared all conversations") + self._approvals.clear() + print("[MEMORY:InMemory] Closed provider, cleared all conversations and approvals") return Success(None) def create_in_memory_provider(config: Optional[InMemoryConfig] = None) -> InMemoryProvider: diff --git a/jaf/memory/providers/postgres.py b/jaf/memory/providers/postgres.py index 7adde04..857ad54 100644 --- a/jaf/memory/providers/postgres.py +++ b/jaf/memory/providers/postgres.py @@ -9,7 +9,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Union -from ...core.types import Message +from ...core.types import Message, RunId, ApprovalValue from ..types import ( ConversationMemory, Failure, @@ -239,6 +239,198 @@ async def health_check(self) -> Result[Dict[str, Any], MemoryConnectionError]: except Exception as e: return Failure(MemoryConnectionError(provider="Postgres", message="Postgres health check failed", cause=e)) + # Approval storage methods + async def _ensure_approval_table_exists(self): + """Ensure the approval table exists.""" + query = f""" + CREATE TABLE IF NOT EXISTS {self.config.approval_table_name} ( + id SERIAL PRIMARY KEY, + run_id VARCHAR(255) NOT NULL, + tool_call_id VARCHAR(255) NOT NULL, + status VARCHAR(50), + approved BOOLEAN NOT NULL, + additional_context JSONB, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(run_id, tool_call_id) + ); + """ + await self._db_execute(query) + + async def store_approval( + self, + run_id: RunId, + tool_call_id: str, + approval: ApprovalValue, + metadata: Optional[Dict[str, Any]] = None + ) -> Result[None, MemoryStorageError]: + """Store an approval decision for a tool call.""" + try: + await self._ensure_approval_table_exists() + + query = f""" + INSERT INTO {self.config.approval_table_name} + (run_id, tool_call_id, status, approved, additional_context, updated_at) + VALUES ($1, $2, $3, $4, $5, CURRENT_TIMESTAMP) + ON CONFLICT (run_id, tool_call_id) + DO UPDATE SET + status = EXCLUDED.status, + approved = EXCLUDED.approved, + additional_context = EXCLUDED.additional_context, + updated_at = CURRENT_TIMESTAMP + """ + + await self._db_execute( + query, + str(run_id), + tool_call_id, + approval.status, + approval.approved, + json.dumps(approval.additional_context) + ) + + return Success(None) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to store approval: {e}")) + + async def get_approval( + self, + run_id: RunId, + tool_call_id: str + ) -> Result[Optional[ApprovalValue], MemoryStorageError]: + """Retrieve approval for a specific tool call.""" + try: + await self._ensure_approval_table_exists() + + query = f""" + SELECT status, approved, additional_context + FROM {self.config.approval_table_name} + WHERE run_id = $1 AND tool_call_id = $2 + """ + + row = await self._db_fetchrow(query, str(run_id), tool_call_id) + + if row is None: + return Success(None) + + approval = ApprovalValue( + status=row['status'], + approved=row['approved'], + additional_context=json.loads(row['additional_context']) if row['additional_context'] else {} + ) + + return Success(approval) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to get approval: {e}")) + + async def get_run_approvals( + self, + run_id: RunId + ) -> Result[Dict[str, ApprovalValue], MemoryStorageError]: + """Get all approvals for a run.""" + try: + await self._ensure_approval_table_exists() + + query = f""" + SELECT tool_call_id, status, approved, additional_context + FROM {self.config.approval_table_name} + WHERE run_id = $1 + """ + + rows = await self._db_fetch(query, str(run_id)) + + approvals = {} + for row in rows: + approval = ApprovalValue( + status=row['status'], + approved=row['approved'], + additional_context=json.loads(row['additional_context']) if row['additional_context'] else {} + ) + approvals[row['tool_call_id']] = approval + + return Success(approvals) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to get run approvals: {e}")) + + async def update_approval( + self, + run_id: RunId, + tool_call_id: str, + updates: Dict[str, Any] + ) -> Result[None, MemoryStorageError]: + """Update approval with new data.""" + try: + await self._ensure_approval_table_exists() + + # Get current approval + current_result = await self.get_approval(run_id, tool_call_id) + if hasattr(current_result, 'error'): + return current_result + if current_result.data is None: + return Failure(MemoryStorageError(f"Approval not found for tool_call_id: {tool_call_id}")) + + current_approval = current_result.data + + # Create updated approval + updated_approval = ApprovalValue( + status=updates.get('status', current_approval.status), + approved=updates.get('approved', current_approval.approved), + additional_context={ + **current_approval.additional_context, + **updates.get('additional_context', {}) + } + ) + + # Store updated approval + return await self.store_approval(run_id, tool_call_id, updated_approval) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to update approval: {e}")) + + async def delete_approval( + self, + run_id: RunId, + tool_call_id: str + ) -> Result[bool, MemoryStorageError]: + """Delete approval for a tool call.""" + try: + await self._ensure_approval_table_exists() + + query = f""" + DELETE FROM {self.config.approval_table_name} + WHERE run_id = $1 AND tool_call_id = $2 + """ + + result = await self._db_execute(query, str(run_id), tool_call_id) + # PostgreSQL returns the number of affected rows + return Success(result is not None and result != 0) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to delete approval: {e}")) + + async def clear_run_approvals( + self, + run_id: RunId + ) -> Result[int, MemoryStorageError]: + """Clear all approvals for a run.""" + try: + await self._ensure_approval_table_exists() + + count_query = f""" + SELECT COUNT(*) FROM {self.config.approval_table_name} + WHERE run_id = $1 + """ + count_row = await self._db_fetchrow(count_query, str(run_id)) + count = count_row['count'] if count_row else 0 + + delete_query = f""" + DELETE FROM {self.config.approval_table_name} + WHERE run_id = $1 + """ + await self._db_execute(delete_query, str(run_id)) + + return Success(count) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to clear run approvals: {e}")) + async def close(self) -> Result[None, MemoryConnectionError]: try: if hasattr(self.client, 'close'): diff --git a/jaf/memory/providers/redis.py b/jaf/memory/providers/redis.py index ee57822..d91c5c3 100644 --- a/jaf/memory/providers/redis.py +++ b/jaf/memory/providers/redis.py @@ -8,7 +8,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Union -from ...core.types import Message +from ...core.types import Message, RunId, ApprovalValue from ..types import ( ConversationMemory, Failure, @@ -205,6 +205,142 @@ async def health_check(self) -> Result[Dict[str, Any], MemoryConnectionError]: except Exception as e: return Failure(MemoryConnectionError(provider="Redis", message="Redis health check failed", cause=e)) + # Approval storage methods + def _get_approval_key(self, run_id: RunId) -> str: + """Get Redis key for approval storage.""" + return f"{self.config.key_prefix}approvals:{run_id}" + + def _serialize_approval(self, approval: ApprovalValue) -> str: + """Serialize approval to JSON.""" + import json + return json.dumps({ + 'status': approval.status, + 'approved': approval.approved, + 'additional_context': approval.additional_context + }) + + def _deserialize_approval(self, data: str) -> ApprovalValue: + """Deserialize approval from JSON.""" + import json + approval_dict = json.loads(data) + return ApprovalValue( + status=approval_dict['status'], + approved=approval_dict['approved'], + additional_context=approval_dict.get('additional_context', {}) + ) + + async def store_approval( + self, + run_id: RunId, + tool_call_id: str, + approval: ApprovalValue, + metadata: Optional[Dict[str, Any]] = None + ) -> Result[None, MemoryStorageError]: + """Store an approval decision for a tool call.""" + try: + key = self._get_approval_key(run_id) + serialized_approval = self._serialize_approval(approval) + await self.redis_client.hset(key, tool_call_id, serialized_approval) + return Success(None) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to store approval: {e}")) + + async def get_approval( + self, + run_id: RunId, + tool_call_id: str + ) -> Result[Optional[ApprovalValue], MemoryStorageError]: + """Retrieve approval for a specific tool call.""" + try: + key = self._get_approval_key(run_id) + data = await self.redis_client.hget(key, tool_call_id) + + if data is None: + return Success(None) + + approval = self._deserialize_approval(data) + return Success(approval) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to get approval: {e}")) + + async def get_run_approvals( + self, + run_id: RunId + ) -> Result[Dict[str, ApprovalValue], MemoryStorageError]: + """Get all approvals for a run.""" + try: + key = self._get_approval_key(run_id) + approval_data = await self.redis_client.hgetall(key) + + approvals = {} + for tool_call_id, data in approval_data.items(): + approvals[tool_call_id] = self._deserialize_approval(data) + + return Success(approvals) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to get run approvals: {e}")) + + async def update_approval( + self, + run_id: RunId, + tool_call_id: str, + updates: Dict[str, Any] + ) -> Result[None, MemoryStorageError]: + """Update approval with new data.""" + try: + key = self._get_approval_key(run_id) + + # Get current approval + current_data = await self.redis_client.hget(key, tool_call_id) + if current_data is None: + return Failure(MemoryStorageError(f"Approval not found for tool_call_id: {tool_call_id}")) + + current_approval = self._deserialize_approval(current_data) + + # Create updated approval + updated_approval = ApprovalValue( + status=updates.get('status', current_approval.status), + approved=updates.get('approved', current_approval.approved), + additional_context={ + **current_approval.additional_context, + **updates.get('additional_context', {}) + } + ) + + # Store updated approval + serialized_approval = self._serialize_approval(updated_approval) + await self.redis_client.hset(key, tool_call_id, serialized_approval) + + return Success(None) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to update approval: {e}")) + + async def delete_approval( + self, + run_id: RunId, + tool_call_id: str + ) -> Result[bool, MemoryStorageError]: + """Delete approval for a tool call.""" + try: + key = self._get_approval_key(run_id) + deleted = await self.redis_client.hdel(key, tool_call_id) + return Success(deleted > 0) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to delete approval: {e}")) + + async def clear_run_approvals( + self, + run_id: RunId + ) -> Result[int, MemoryStorageError]: + """Clear all approvals for a run.""" + try: + key = self._get_approval_key(run_id) + count = await self.redis_client.hlen(key) + await self.redis_client.delete(key) + return Success(count) + except Exception as e: + return Failure(MemoryStorageError(f"Failed to clear run approvals: {e}")) + async def close(self) -> Result[None, MemoryConnectionError]: try: await self.redis_client.aclose() diff --git a/jaf/memory/types.py b/jaf/memory/types.py index 4a80c46..8b3c868 100644 --- a/jaf/memory/types.py +++ b/jaf/memory/types.py @@ -11,7 +11,7 @@ from pydantic import BaseModel, Field -from ..core.types import Message, TraceId +from ..core.types import Message, TraceId, RunId, ApprovalValue # Generic Result type for functional error handling T = TypeVar('T') @@ -135,6 +135,56 @@ async def close(self) -> Result[None, 'MemoryConnectionError']: """Close/cleanup the provider.""" ... + # Approval storage methods + async def store_approval( + self, + run_id: RunId, + tool_call_id: str, + approval: ApprovalValue, + metadata: Optional[Dict[str, Any]] = None + ) -> Result[None, 'MemoryStorageError']: + """Store an approval decision for a tool call.""" + ... + + async def get_approval( + self, + run_id: RunId, + tool_call_id: str + ) -> Result[Optional[ApprovalValue], 'MemoryStorageError']: + """Retrieve approval for a specific tool call. Returns None if not found.""" + ... + + async def get_run_approvals( + self, + run_id: RunId + ) -> Result[Dict[str, ApprovalValue], 'MemoryStorageError']: + """Get all approvals for a run as a Dict[str, ApprovalValue].""" + ... + + async def update_approval( + self, + run_id: RunId, + tool_call_id: str, + updates: Dict[str, Any] + ) -> Result[None, 'MemoryStorageError']: + """Update approval with new data.""" + ... + + async def delete_approval( + self, + run_id: RunId, + tool_call_id: str + ) -> Result[bool, 'MemoryStorageError']: + """Delete approval for a tool call. Returns True if it existed.""" + ... + + async def clear_run_approvals( + self, + run_id: RunId + ) -> Result[int, 'MemoryStorageError']: + """Clear all approvals for a run. Returns count of deleted approvals.""" + ... + # Configuration models using Pydantic for validation class InMemoryConfig(BaseModel): @@ -165,6 +215,7 @@ class PostgresConfig(BaseModel): password: Optional[str] = None ssl: bool = Field(default=False) table_name: str = Field(default="conversations") + approval_table_name: str = Field(default="jaf_approvals") max_connections: int = Field(default=10, ge=1) # Union type for all provider configurations diff --git a/jaf/server/server.py b/jaf/server/server.py index a0cf342..16cf86b 100644 --- a/jaf/server/server.py +++ b/jaf/server/server.py @@ -9,6 +9,7 @@ import uuid import asyncio import json +from datetime import datetime from dataclasses import asdict, replace from typing import TypeVar, Dict, Set @@ -36,6 +37,8 @@ AgentListData, AgentListResponse, ApprovalMessage, + ApprovalRequest, + ApprovalResponse, BaseOutcomeData, ChatRequest, ChatResponse, @@ -52,6 +55,7 @@ PendingApprovalData, PendingApprovalsData, PendingApprovalsResponse, + RejectRequest, ServerConfig, ToolCallInterruption, ) @@ -202,12 +206,27 @@ def _convert_core_message_to_http(core_msg: Message) -> HttpMessage: for att in core_msg.attachments ] + # Convert ToolCall dataclasses to dictionaries for HttpMessage + http_tool_calls = None + if core_msg.tool_calls: + http_tool_calls = [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments + } + } + for tc in core_msg.tool_calls + ] + return HttpMessage( role=core_msg.role, content=content, attachments=attachments, tool_call_id=core_msg.tool_call_id, - tool_calls=core_msg.tool_calls + tool_calls=http_tool_calls ) def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI: @@ -346,20 +365,29 @@ async def chat_completion(request: ChatRequest): approvals_list = validated_request.approvals or [] - async def persist_approval(conv_id: str, appr: ApprovalMessage): + async def persist_approval(conv_id: str, appr): """Persist approval to memory provider with metadata (matching TypeScript).""" if not config.default_memory_provider: return - + provider = config.default_memory_provider + + # Handle both ApprovalMessage objects and dict representations + session_id = getattr(appr, 'session_id', None) or appr.get('session_id') + tool_call_id = getattr(appr, 'tool_call_id', None) or appr.get('tool_call_id') + approved = getattr(appr, 'approved', None) + if approved is None: + approved = appr.get('approved') + additional_context = getattr(appr, 'additional_context', None) or appr.get('additional_context') + # Keyed by previous run/session id + toolCallId for uniqueness (matching TypeScript) - approval_key = f"{appr.session_id}:{appr.tool_call_id}" + approval_key = f"{session_id}:{tool_call_id}" base_entry = { - 'approved': appr.approved, - 'status': 'approved' if appr.approved else 'rejected', - 'additionalContext': appr.additional_context, - 'sessionId': appr.session_id, - 'toolCallId': appr.tool_call_id, + 'approved': approved, + 'status': 'approved' if approved else 'rejected', + 'additionalContext': additional_context, + 'sessionId': session_id, + 'toolCallId': tool_call_id, } try: @@ -371,7 +399,7 @@ async def persist_approval(conv_id: str, appr: ApprovalMessage): for i in range(len(msgs) - 1, -1, -1): m = msgs[i] if m.role == 'assistant' and hasattr(m, 'tool_calls') and m.tool_calls: - match = next((tc for tc in m.tool_calls if tc.id == appr.tool_call_id), None) + match = next((tc for tc in m.tool_calls if tc.id == tool_call_id), None) if match: base_entry['toolName'] = match.function.name base_entry['signature'] = compute_tool_call_signature(match) @@ -426,22 +454,43 @@ async def persist_approval(conv_id: str, appr: ApprovalMessage): try: broadcast_approval_decision({ 'conversationId': conv_id, - 'sessionId': appr.session_id, - 'toolCallId': appr.tool_call_id, - 'status': 'approved' if appr.approved else 'rejected', - 'additionalContext': appr.additional_context + 'sessionId': session_id, + 'toolCallId': tool_call_id, + 'status': 'approved' if approved else 'rejected', + 'additionalContext': additional_context }) except Exception: pass # ignore if len(approvals_list) > 0: for approval in approvals_list: - if approval.session_id: # Matching TypeScript condition - initial_approvals[approval.tool_call_id] = { - 'status': 'approved' if approval.approved else 'rejected', - 'approved': approval.approved, - 'additionalContext': approval.additional_context - } + # Handle both ApprovalMessage objects and dict representations + session_id = getattr(approval, 'session_id', None) or (approval.get('session_id') if hasattr(approval, 'get') else None) + tool_call_id = getattr(approval, 'tool_call_id', None) or (approval.get('tool_call_id') if hasattr(approval, 'get') else None) + approved = getattr(approval, 'approved', None) + if approved is None: + approved = approval.get('approved') if hasattr(approval, 'get') else None + additional_context = getattr(approval, 'additional_context', None) or (approval.get('additional_context') if hasattr(approval, 'get') else None) + + if session_id: # Matching TypeScript condition + approval_value = ApprovalValue( + status='approved' if approved else 'rejected', + approved=approved, + additional_context=additional_context or {} + ) + initial_approvals[tool_call_id] = approval_value + + # Store approval in memory provider for persistence + if config.default_memory_provider: + try: + await config.default_memory_provider.store_approval( + session_id, # run_id + tool_call_id, + approval_value + ) + print(f"[JAF:SERVER] Stored approval in memory provider: {tool_call_id}") + except Exception as e: + print(f"[JAF:SERVER] Failed to store approval in memory provider: {e}") await persist_approval(conversation_id, approval) # Seed approvals from persisted conversation metadata @@ -493,7 +542,7 @@ async def persist_approval(conv_id: str, appr: ApprovalMessage): initial_approvals[target_id] = ApprovalValue( status=status, approved=approval_entry.get('approved', False), - additional_context=approval_entry.get('additional_context') + additional_context=approval_entry.get('additionalContext') ) except Exception as e: @@ -740,33 +789,49 @@ async def get_pending_approvals(conversation_id: str = None): data=PendingApprovalsData(pending=[]) ) - # Check which tool calls have already been executed + # Check which tool calls have already been executed (not just responded to) tool_ids = {tc.id for tc in assistant_msg.tool_calls} executed = set() + pending_approval = set() + for j in range(assistant_index + 1, len(messages)): msg = messages[j] if hasattr(msg, 'role') and msg.role == 'tool' and hasattr(msg, 'tool_call_id'): if msg.tool_call_id in tool_ids: - executed.add(msg.tool_call_id) + # Check if tool response indicates pending approval + try: + import json + content = json.loads(msg.content) if isinstance(msg.content, str) else msg.content + if isinstance(content, dict) and content.get('hitl_status') == 'pending_approval': + pending_approval.add(msg.tool_call_id) + else: + executed.add(msg.tool_call_id) + except (json.JSONDecodeError, TypeError): + # If we can't parse the content, assume it's executed + executed.add(msg.tool_call_id) # Build pending approvals list pending_approvals = [] for tc in assistant_msg.tool_calls: if tc.id in executed: continue # Already executed - - # Check approval status - approval_key = f"{conversation.conversation_id}:{tc.id}" - approval_entry = approvals_meta.get(approval_key) - - status = 'pending' - if approval_entry: - status = approval_entry.get('status', 'pending') - if approval_entry.get('approved') is True: - status = 'approved' - elif approval_entry.get('approved') is False: - status = 'rejected' - + + # If tool is waiting for approval, include it regardless of metadata + if tc.id in pending_approval: + status = 'pending' + else: + # Check approval status from metadata + approval_key = f"{conversation.conversation_id}:{tc.id}" + approval_entry = approvals_meta.get(approval_key) + + status = 'pending' + if approval_entry: + status = approval_entry.get('status', 'pending') + if approval_entry.get('approved') is True: + status = 'approved' + elif approval_entry.get('approved') is False: + status = 'rejected' + if status == 'pending': pending_approvals.append(PendingApprovalData( conversation_id=conversation_id, @@ -854,4 +919,150 @@ async def event_stream(): } ) + @app.post("/approvals/approve", response_model=ApprovalResponse) + async def approve_tool_call(request: ApprovalRequest): + """Approve a tool call with optional additional context.""" + try: + # Use the existing chat endpoint with approval data + if not config.default_memory_provider: + return ApprovalResponse( + success=False, + error="Memory provider not configured" + ) + + conv_result = await config.default_memory_provider.get_conversation(request.conversationId) + if not (hasattr(conv_result, 'data') and conv_result.data): + return ApprovalResponse( + success=False, + error="Conversation not found" + ) + + conversation = conv_result.data + agent_name = conversation.metadata.get('agent_name') if conversation.metadata else None + + if not agent_name: + return ApprovalResponse( + success=False, + error="Could not determine agent name from conversation" + ) + + # Continue conversation with approval via existing chat endpoint + from .types import ApprovalMessage + + approval_message = ApprovalMessage( + type="approval", + session_id=conversation.metadata.get('run_id') if conversation.metadata else request.conversationId, + tool_call_id=request.toolCallId, + approved=True, + additional_context=request.additionalContext or {} + ) + + chat_request = ChatRequest( + agent_name=agent_name, + conversation_id=request.conversationId, + messages=[], # Empty messages - just continue the conversation + approvals=[approval_message] + ) + + chat_response = await chat_completion(chat_request) + + if chat_response.success: + return ApprovalResponse( + success=True, + data={ + "message": "Tool call approved and executed", + "toolCallId": request.toolCallId, + "conversationId": request.conversationId, + "chat_response": chat_response.data + } + ) + else: + return ApprovalResponse( + success=False, + error=chat_response.error + ) + + except Exception as e: + import traceback + print(f"[JAF:APPROVAL] Error in approve_tool_call: {traceback.format_exc()}") + return ApprovalResponse( + success=False, + error=f"Failed to approve tool call: {str(e)}" + ) + + @app.post("/approvals/reject", response_model=ApprovalResponse) + async def reject_tool_call(request: RejectRequest): + """Reject a tool call with optional reason.""" + try: + # Use the existing chat endpoint with rejection data + if not config.default_memory_provider: + return ApprovalResponse( + success=False, + error="Memory provider not configured" + ) + + conv_result = await config.default_memory_provider.get_conversation(request.conversationId) + if not (hasattr(conv_result, 'data') and conv_result.data): + return ApprovalResponse( + success=False, + error="Conversation not found" + ) + + conversation = conv_result.data + agent_name = conversation.metadata.get('agent_name') if conversation.metadata else None + + if not agent_name: + return ApprovalResponse( + success=False, + error="Could not determine agent name from conversation" + ) + + # Continue conversation with rejection via existing chat endpoint + from .types import ApprovalMessage + + approval_message = ApprovalMessage( + type="approval", + session_id=conversation.metadata.get('run_id') if conversation.metadata else request.conversationId, + tool_call_id=request.toolCallId, + approved=False, + additional_context={ + **(request.additionalContext or {}), + "rejection_reason": request.reason or "User declined the action" + } + ) + + chat_request = ChatRequest( + agent_name=agent_name, + conversation_id=request.conversationId, + messages=[], # Empty messages - just continue the conversation + approvals=[approval_message] + ) + + chat_response = await chat_completion(chat_request) + + if chat_response.success: + return ApprovalResponse( + success=True, + data={ + "message": "Tool call rejected", + "toolCallId": request.toolCallId, + "conversationId": request.conversationId, + "reason": request.reason, + "chat_response": chat_response.data + } + ) + else: + return ApprovalResponse( + success=False, + error=chat_response.error + ) + + except Exception as e: + import traceback + print(f"[JAF:APPROVAL] Error in reject_tool_call: {traceback.format_exc()}") + return ApprovalResponse( + success=False, + error=f"Failed to reject tool call: {str(e)}" + ) + return app diff --git a/jaf/server/types.py b/jaf/server/types.py index c5aa2ae..cb3941e 100644 --- a/jaf/server/types.py +++ b/jaf/server/types.py @@ -268,6 +268,36 @@ class PendingApprovalsResponse(BaseModel): error: Optional[str] = None +class ApprovalRequest(BaseModel): + """Request format for approval endpoints.""" + conversationId: str = Field(alias="conversationId") + toolCallId: str = Field(alias="toolCallId") + additionalContext: Optional[Dict[str, Any]] = Field(None, alias="additionalContext") + + class Config: + populate_by_name = True + allow_population_by_field_name = True + + +class RejectRequest(BaseModel): + """Request format for rejection endpoints.""" + conversationId: str = Field(alias="conversationId") + toolCallId: str = Field(alias="toolCallId") + reason: Optional[str] = None + additionalContext: Optional[Dict[str, Any]] = Field(None, alias="additionalContext") + + class Config: + populate_by_name = True + allow_population_by_field_name = True + + +class ApprovalResponse(BaseModel): + """Response format for approval/rejection endpoints.""" + success: bool + data: Optional[Dict[str, Any]] = None + error: Optional[str] = None + + # Validation schemas def validate_chat_request(data: Dict[str, Any]) -> ChatRequest: """Validate and parse a chat request."""