diff --git a/examples/hitl-demo/hitl-clarification-stream-demo.ts b/examples/hitl-demo/hitl-clarification-stream-demo.ts index c5ce9ce..7ec9cc2 100644 --- a/examples/hitl-demo/hitl-clarification-stream-demo.ts +++ b/examples/hitl-demo/hitl-clarification-stream-demo.ts @@ -23,7 +23,9 @@ import { RunConfig, generateRunId, generateTraceId, - makeLiteLLMProvider + makeLiteLLMProvider, + createInMemoryClarificationStorage, + provideClarification } from '@xynehq/jaf'; import { ClarificationInterruption } from '../../src/core/types'; @@ -142,11 +144,11 @@ Remember: MULTIPLE distinct people = MUST use request_user_clarification tool!`, }; // Helper function to simulate user input -function simulateUserSelection(options: readonly { id: string; label: string }[], preferredIndex: number = 0): string { +function simulateUserSelection(options: readonly string[], preferredIndex: number = 0): string { console.log('\nπŸ€” Simulating user selection...'); const selected = options[preferredIndex]; - console.log(`πŸ‘€ User selects: ${selected.label} (ID: ${selected.id})\n`); - return selected.id; + console.log(`πŸ‘€ User selects: ${selected}\n`); + return selected; } // Main demo function using runStream @@ -160,11 +162,15 @@ async function streamingDemo() { const litellmApiKey = process.env.LITELLM_API_KEY; const modelProvider = makeLiteLLMProvider(litellmUrl, litellmApiKey); + // Create clarification storage + const clarificationStorage = createInMemoryClarificationStorage(); + const config: RunConfig = { agentRegistry, modelProvider: modelProvider as any, modelOverride: process.env.LITELLM_MODEL || 'gpt-4o-mini', allowClarificationRequests: true, + clarificationStorage, }; // Test Case 1: Ambiguous query with streaming @@ -184,12 +190,12 @@ async function streamingDemo() { }; let clarificationNeeded: ClarificationInterruption | null = null; - + let currentState = state1; console.log('🌊 Starting event stream...\n'); // Stream events in real-time - for await (const event of runStream(state1, config)) { + for await (const event of runStream(currentState, config)) { handleEvent(event); // Check if this is a run_end event @@ -205,17 +211,13 @@ async function streamingDemo() { if (clarification) { clarificationNeeded = clarification; - // The stream is complete, we can access the final state - // by creating a new run that will immediately detect and handle the clarification console.log('\n⏸️ Stream paused - clarification required'); console.log(`Question: ${clarification.question}`); console.log('Options:'); clarification.options.forEach((opt, idx) => { - console.log(` ${idx + 1}. ${opt.label} (ID: ${opt.id})`); + console.log(` ${idx + 1}. ${opt}`); }); - // We need to get the final state from the outcome - // For now, we'll simulate getting it from the event data break; } } else if (outcome.status === 'completed') { @@ -229,18 +231,16 @@ async function streamingDemo() { // If clarification was needed, resume with user selection if (clarificationNeeded) { - const selectedId = simulateUserSelection(clarificationNeeded.options, 0); + const selectedOption = simulateUserSelection(clarificationNeeded.options, 0); console.log('πŸ”„ Resuming stream with user selection...\n'); - // Create a new state with the clarification response - const stateWithClarification = { - ...state1, - clarifications: new Map([[clarificationNeeded.clarificationId, selectedId]]) - }; + // Use provideClarification helper to store the selection + // This stores it in clarificationStorage AND updates the state + currentState = await provideClarification(currentState, clarificationNeeded, selectedOption, undefined, config); - // Resume streaming with the clarification - for await (const event of runStream(stateWithClarification, config)) { + // Resume streaming - engine will auto-load clarifications from storage + for await (const event of runStream(currentState, config)) { handleEvent(event); if (event.type === 'run_end') { @@ -348,13 +348,13 @@ function handleEvent(event: TraceEvent | any) { console.log('\nπŸ”” Clarification Requested!'); console.log(` Question: ${event.data.question}`); console.log(' Options:'); - event.data.options.forEach((opt: any, idx: number) => { - console.log(` ${idx + 1}. ${opt.label} (ID: ${opt.id})`); + event.data.options.forEach((opt: string, idx: number) => { + console.log(` ${idx + 1}. ${opt}`); }); break; case 'clarification_provided': - console.log(`βœ… Clarification provided: ${event.data.selectedId}`); + console.log(`βœ… Clarification provided: ${event.data.selectedOption}`); break; case 'turn_end': diff --git a/src/core/engine.ts b/src/core/engine.ts index 7effb89..0181a54 100644 --- a/src/core/engine.ts +++ b/src/core/engine.ts @@ -41,10 +41,7 @@ function createClarificationTool(config: RunConfig): Tool<{ description, parameters: z.object({ question: z.string().describe('The clarifying question to ask the user'), - options: z.array(z.object({ - id: z.string().describe('Unique identifier for this option'), - label: z.string().describe('Human-readable label shown to the user') - })).min(2).describe('clear and meaningful options that user can choose from (minimum 2 options)') + options: z.array(z.string()).min(2).describe('clear and meaningful options that user can choose from (minimum 2 options)') }) }, execute: async (args, _context): Promise => { @@ -66,8 +63,8 @@ export async function run( try { config.onEvent?.({ type: 'run_start', - data: { - runId: initialState.runId, + data: { + runId: initialState.runId, traceId: initialState.traceId, context: initialState.context, userId: (initialState.context as any)?.userId, @@ -90,6 +87,12 @@ export async function run( stateWithMemory = await loadApprovalsIntoState(stateWithMemory, config); } + if (config.clarificationStorage) { + safeConsole.log(`[JAF:ENGINE] Loading clarifications for runId ${stateWithMemory.runId}`); + const { loadClarificationsIntoState } = await import('./state'); + stateWithMemory = await loadClarificationsIntoState(stateWithMemory, config); + } + const result = await runInternal(stateWithMemory, config); if (config.memory?.autoStore && config.conversationId && result.outcome.status === 'completed' && config.memory.storeOnCompletion) { @@ -260,6 +263,20 @@ export async function* runStream( } } +function removeOldInterruptedMessages(messages: readonly Message[], toolCallIds: Set): Message[] { + return messages.filter(msg => { + if (msg.role === 'tool' && msg.tool_call_id && toolCallIds.has(msg.tool_call_id)) { + try { + const content = JSON.parse(getTextContent(msg.content)); + if (content.status === InterruptionStatus.Halted || content.status === InterruptionStatus.AwaitingClarification) { + return false; // Remove old halted/awaiting messages + } + } catch { /* ignore */ } + } + return true; // Keep all other messages + }); +} + async function tryResumePendingToolCalls( state: RunState, config: RunConfig @@ -275,11 +292,53 @@ async function tryResumePendingToolCalls( for (let j = i + 1; j < messages.length; j++) { const m = messages[j]; if (m.role === 'tool' && m.tool_call_id && ids.has(m.tool_call_id)) { + // Don't count "halted" or "awaiting_clarification" as executed - they need to be retried + try { + const content = JSON.parse(getTextContent(m.content)); + if (content.status === InterruptionStatus.Halted || content.status === InterruptionStatus.AwaitingClarification) { + continue; // Skip this, it's still pending + } + } catch { /* ignore */ } executed.add(m.tool_call_id); } } - const pendingToolCalls = msg.tool_calls.filter(tc => !executed.has(tc.id)); + let pendingToolCalls = msg.tool_calls.filter(tc => !executed.has(tc.id)); + + // Handle clarification tool calls that already have a stored clarification + const clarificationResultMessages: Message[] = []; + if (state.clarifications && state.clarifications.size > 0) { + const clarifications = state.clarifications; // Capture for closure + const filteredPending = pendingToolCalls.filter(tc => { + // If this is a clarification tool call, check if we have a clarification for it + if (tc.function.name === 'request_user_clarification') { + const clarificationId = `clarify_${tc.id}`; + const selectedOption = clarifications.get(clarificationId); + if (selectedOption) { + safeConsole.log(`[JAF:ENGINE] Found clarification for tool call ${tc.id} - creating result message`); + // Create the clarification_provided tool result message + clarificationResultMessages.push({ + role: 'tool', + tool_call_id: tc.id, + content: JSON.stringify({ + status: InterruptionStatus.ClarificationProvided, + message: `User selected option: ${selectedOption}` + }) + }); + return false; // Skip re-executing this tool call + } + } + return true; + }); + pendingToolCalls = filteredPending; + } + + + if (clarificationResultMessages.length > 0 && pendingToolCalls.length === 0) { + safeConsole.log(`[JAF:ENGINE] Added ${clarificationResultMessages.length} clarification result message(s) to state`); + return null; + } + if (pendingToolCalls.length === 0) { return null; // Nothing to resume } @@ -297,6 +356,17 @@ async function tryResumePendingToolCalls( } } as RunResult; } + const effectiveTools = [ + ...(currentAgent.tools || []) + ]; + + if (config.allowClarificationRequests) { + effectiveTools.push(createClarificationTool(config)); + } + const effectiveAgent: Agent = { + ...currentAgent, + tools: effectiveTools + }; try { const requests = pendingToolCalls.map(tc => ({ @@ -307,7 +377,7 @@ async function tryResumePendingToolCalls( config.onEvent?.({ type: 'tool_requests', data: { toolCalls: requests } }); } catch { /* ignore */ } - const toolResults = await executeToolCalls(pendingToolCalls, currentAgent, state, config); + const toolResults = await executeToolCalls(pendingToolCalls, effectiveAgent, state, config); const interruptions = toolResults .map(r => r.interruption) @@ -332,9 +402,13 @@ async function tryResumePendingToolCalls( data: { results: toolResults.map(r => r.message) } }); + // Remove old halted/awaiting_clarification messages for the tools we just re-executed + const pendingToolCallIds = new Set(pendingToolCalls.map(tc => tc.id)); + const cleanedMessages = removeOldInterruptedMessages(state.messages, pendingToolCallIds); + const nextState: RunState = { ...state, - messages: [...state.messages, ...toolResults.map(r => r.message)], + messages: [...cleanedMessages, ...toolResults.map(r => r.message)], turnCount: state.turnCount, approvals: state.approvals ?? new Map(), }; @@ -355,50 +429,58 @@ async function runInternal( if (resumed) return resumed; // Check if we're resuming from a clarification + // Search through ALL messages for any awaiting_clarification that now have clarifications if (state.clarifications && state.clarifications.size > 0) { - const lastMessage = state.messages[state.messages.length - 1]; - if (lastMessage?.role === 'tool') { - try { - const content = JSON.parse(getTextContent(lastMessage.content)); - if (content.status === InterruptionStatus.AwaitingClarification) { - const clarificationId = content.clarification_id; - const selectedId = state.clarifications.get(clarificationId); - - if (selectedId) { - safeConsole.log(`[JAF:ENGINE] Resuming with clarification: ${clarificationId}, selected option: ${selectedId}`); - - // Find the selected option to include in the event - const updatedMessages = [...state.messages]; - updatedMessages[updatedMessages.length - 1] = { - ...lastMessage, - content: JSON.stringify({ - status: InterruptionStatus.ClarificationProvided, - message: `User selected option: ${selectedId}` - }) - }; + let hasUpdates = false; + const updatedMessages = [...state.messages]; - config.onEvent?.({ - type: 'clarification_provided', - data: { - clarificationId, - selectedId, - selectedOption: { id: selectedId, label: selectedId } - } - }); + for (let i = 0; i < updatedMessages.length; i++) { + const message = updatedMessages[i]; + if (message?.role === 'tool') { + try { + const content = JSON.parse(getTextContent(message.content)); + if (content.status === InterruptionStatus.AwaitingClarification) { + const clarificationId = content.clarification_id; + const selectedOption = state.clarifications.get(clarificationId); + + if (selectedOption) { + safeConsole.log(`[JAF:ENGINE] Resuming with clarification: ${clarificationId}, selected option: ${selectedOption}`); + + // Update this message to clarification_provided + updatedMessages[i] = { + ...message, + content: JSON.stringify({ + status: InterruptionStatus.ClarificationProvided, + message: `User selected option: ${selectedOption}` + }) + }; - // Continue execution with updated messages - const stateWithClarification: RunState = { - ...state, - messages: updatedMessages - }; + config.onEvent?.({ + type: 'clarification_provided', + data: { + clarificationId, + selectedOption + } + }); - return runInternal(stateWithClarification, config); + hasUpdates = true; + } } + } catch (e) { + // Ignore parse errors } - } catch (e) { - safeConsole.log(`[JAF:ENGINE] Error checking for clarification resume:`, e); } } + + // If we updated any messages, continue with the updated state + if (hasUpdates) { + const stateWithClarification: RunState = { + ...state, + messages: updatedMessages + }; + + return runInternal(stateWithClarification, config); + } } const maxTurns = config.maxTurns ?? 50; @@ -431,8 +513,8 @@ async function runInternal( const hasAdvancedGuardrails = !!(currentAgent.advancedConfig?.guardrails && (currentAgent.advancedConfig.guardrails.inputPrompt || - currentAgent.advancedConfig.guardrails.outputPrompt || - currentAgent.advancedConfig.guardrails.requireCitations)); + currentAgent.advancedConfig.guardrails.outputPrompt || + currentAgent.advancedConfig.guardrails.requireCitations)); safeConsole.log('[JAF:ENGINE] Debug guardrails setup:', { agentName: currentAgent.name, @@ -444,7 +526,7 @@ async function runInternal( let effectiveInputGuardrails: Guardrail[] = []; let effectiveOutputGuardrails: Guardrail[] = []; - + if (hasAdvancedGuardrails) { const result = await buildEffectiveGuardrails(currentAgent, config); effectiveInputGuardrails = result.inputGuardrails; @@ -469,7 +551,7 @@ async function runInternal( ...(currentAgent.tools || []) ]; - if(config.allowClarificationRequests){ + if (config.allowClarificationRequests) { effectiveTools.push(createClarificationTool(config)); } const effectiveAgent: Agent = { @@ -523,7 +605,7 @@ async function runInternal( } }; } - + const turnNumber = state.turnCount + 1; config.onEvent?.({ type: 'turn_start', data: { turn: turnNumber, agentName: currentAgent.name } }); @@ -554,69 +636,66 @@ async function runInternal( let llmResponse: any; let streamingUsed = false; let assistantEventStreamed = false; - + if (inputGuardrailsToRun.length > 0 && state.turnCount === 0) { const firstUserMessage = state.messages.find(m => m.role === 'user'); if (firstUserMessage) { if (hasAdvancedGuardrails) { const executionMode = currentAgent.advancedConfig?.guardrails?.executionMode || 'parallel'; - - if (executionMode === 'sequential') { - const guardrailResult = await executeInputGuardrailsSequential(inputGuardrailsToRun, firstUserMessage, config); - if (!guardrailResult.isValid) { - await runTurnEndHooks(config, { - turn: turnNumber, - agentName: currentAgent.name, - state, - lastAssistantMessage: undefined - }); - return { - finalState: state, - outcome: { - status: 'error', - error: { - _tag: 'InputGuardrailTripwire', - reason: guardrailResult.errorMessage + + if (executionMode === 'sequential') { + const guardrailResult = await executeInputGuardrailsSequential(inputGuardrailsToRun, firstUserMessage, config); + if (!guardrailResult.isValid) { + await runTurnEndHooks(config, { + turn: turnNumber, + agentName: currentAgent.name, + state, + lastAssistantMessage: undefined + }); + return { + finalState: state, + outcome: { + status: 'error', + error: { + _tag: 'InputGuardrailTripwire', + reason: guardrailResult.errorMessage + } } } - }; - } - - safeConsole.log(`βœ… All input guardrails passed. Starting LLM call.`); - llmResponse = await config.modelProvider.getCompletion(state, effectiveAgent, config); - } else { - const guardrailPromise = executeInputGuardrailsParallel(inputGuardrailsToRun, firstUserMessage, config); - const llmPromise = config.modelProvider.getCompletion(state, effectiveAgent, config); - - const [guardrailResult, llmResult] = await Promise.all([ - guardrailPromise, - llmPromise - ]); - - llmResponse = llmResult; - - if (!guardrailResult.isValid) { - safeConsole.log(`🚨 Input guardrail violation: ${guardrailResult.errorMessage}`); - safeConsole.log(`[JAF:GUARDRAILS] Discarding LLM response due to input guardrail violation`); - await runTurnEndHooks(config, { - turn: turnNumber, - agentName: currentAgent.name, - state, - lastAssistantMessage: undefined - }); - return { - finalState: state, - outcome: { - status: 'error', - error: { - _tag: 'InputGuardrailTripwire', - reason: guardrailResult.errorMessage + } + safeConsole.log(`βœ… All input guardrails passed. Starting LLM call.`); + llmResponse = await config.modelProvider.getCompletion(state, effectiveAgent, config); + } else { + const guardrailPromise = executeInputGuardrailsParallel(inputGuardrailsToRun, firstUserMessage, config); + const llmPromise = config.modelProvider.getCompletion(state, effectiveAgent, config); + const [guardrailResult, llmResult] = await Promise.all([ + guardrailPromise, + llmPromise + ]); + + llmResponse = llmResult; + + if (!guardrailResult.isValid) { + safeConsole.log(`🚨 Input guardrail violation: ${guardrailResult.errorMessage}`); + safeConsole.log(`[JAF:GUARDRAILS] Discarding LLM response due to input guardrail violation`); + await runTurnEndHooks(config, { + turn: turnNumber, + agentName: currentAgent.name, + state, + lastAssistantMessage: undefined + }); + return { + finalState: state, + outcome: { + status: 'error', + error: { + _tag: 'InputGuardrailTripwire', + reason: guardrailResult.errorMessage + } } } - }; - } - - safeConsole.log(`βœ… All input guardrails passed. Using LLM response.`); + } + safeConsole.log(`βœ… All input guardrails passed. Using LLM response.`); } } else { safeConsole.log('[JAF:ENGINE] Using LEGACY guardrails path with', inputGuardrailsToRun.length, 'guardrails'); @@ -680,15 +759,15 @@ async function runInternal( content: aggregatedText, ...(toolCalls.length > 0 ? { - tool_calls: toolCalls.map((tc, i) => ({ - id: tc.id ?? `call_${i}`, - type: 'function' as const, - function: { - name: tc.function.name ?? '', - arguments: tc.function.arguments - } - })) - } + tool_calls: toolCalls.map((tc, i) => ({ + id: tc.id ?? `call_${i}`, + type: 'function' as const, + function: { + name: tc.function.name ?? '', + arguments: tc.function.arguments + } + })) + } : {}) }; try { config.onEvent?.({ type: 'assistant_message', data: { message: partialMessage } }); } catch (err) { safeConsole.error('Error in config.onEvent:', err); } @@ -700,15 +779,15 @@ async function runInternal( content: aggregatedText || undefined, ...(toolCalls.length > 0 ? { - tool_calls: toolCalls.map((tc, i) => ({ - id: tc.id ?? `call_${i}`, - type: 'function' as const, - function: { - name: tc.function.name ?? '', - arguments: tc.function.arguments - } - })) - } + tool_calls: toolCalls.map((tc, i) => ({ + id: tc.id ?? `call_${i}`, + type: 'function' as const, + function: { + name: tc.function.name ?? '', + arguments: tc.function.arguments + } + })) + } : {}) } }; @@ -753,18 +832,18 @@ async function runInternal( content: aggregatedText, ...(toolCalls.length > 0 ? { - tool_calls: toolCalls.map((tc, i) => ({ - id: tc.id ?? `call_${i}`, - type: 'function' as const, - function: { - name: tc.function.name ?? '', - arguments: tc.function.arguments - } - })) - } + tool_calls: toolCalls.map((tc, i) => ({ + id: tc.id ?? `call_${i}`, + type: 'function' as const, + function: { + name: tc.function.name ?? '', + arguments: tc.function.arguments + } + })) + } : {}) }; - try { config.onEvent?.({ type: 'assistant_message', data: { message: partialMessage } }); } catch (err) {safeConsole.error('Error in config.onEvent:', err); } + try { config.onEvent?.({ type: 'assistant_message', data: { message: partialMessage } }); } catch (err) { safeConsole.error('Error in config.onEvent:', err); } } } @@ -773,15 +852,15 @@ async function runInternal( content: aggregatedText || undefined, ...(toolCalls.length > 0 ? { - tool_calls: toolCalls.map((tc, i) => ({ - id: tc.id ?? `call_${i}`, - type: 'function' as const, - function: { - name: tc.function.name ?? '', - arguments: tc.function.arguments - } - })) - } + tool_calls: toolCalls.map((tc, i) => ({ + id: tc.id ?? `call_${i}`, + type: 'function' as const, + function: { + name: tc.function.name ?? '', + arguments: tc.function.arguments + } + })) + } : {}) } }; @@ -794,17 +873,17 @@ async function runInternal( llmResponse = await config.modelProvider.getCompletion(state, effectiveAgent, config); } } - + const usage = (llmResponse as any)?.usage; const prompt = (llmResponse as any)?.prompt; - + config.onEvent?.({ type: 'llm_call_end', - data: { + data: { choice: llmResponse, fullResponse: llmResponse, // Include complete response prompt: prompt, // Include the prompt that was sent - traceId: state.traceId, + traceId: state.traceId, runId: state.runId, agentName: currentAgent.name, model: model || 'unknown', @@ -869,7 +948,7 @@ async function runInternal( if (llmResponse.message.tool_calls && llmResponse.message.tool_calls.length > 0) { safeConsole.log(`[JAF:ENGINE] Processing ${llmResponse.message.tool_calls.length} tool calls`); safeConsole.log(`[JAF:ENGINE] Tool calls:`, llmResponse.message.tool_calls); - + try { const requests = llmResponse.message.tool_calls.map((tc: any) => ({ id: tc.id, @@ -878,7 +957,7 @@ async function runInternal( })); config.onEvent?.({ type: 'tool_requests', data: { toolCalls: requests } }); } catch { /* ignore */ } - + const toolResults = await executeToolCalls( llmResponse.message.tool_calls, effectiveAgent, @@ -917,10 +996,12 @@ async function runInternal( safeConsole.log(`[JAF:ENGINE] Clarification requested: ${interruption.question}`); } } - + + // Add ALL tool result messages, including interrupted ones + // This ensures tool_use blocks have corresponding tool_result blocks const interruptedState = { ...state, - messages: [...newMessages, ...completedToolResults.map(r => r.message)], + messages: [...newMessages, ...toolResults.map(r => r.message)], turnCount: updatedTurnCount, approvals: updatedApprovals, clarifications: updatedClarifications, @@ -962,7 +1043,7 @@ async function runInternal( const handoffResult = toolResults.find(r => r.isHandoff); if (handoffResult) { const targetAgent = handoffResult.targetAgent!; - + if (!currentAgent.handoffs?.includes(targetAgent)) { config.onEvent?.({ type: 'handoff_denied', @@ -1059,7 +1140,7 @@ async function runInternal( const parseResult = currentAgent.outputCodec.safeParse( tryParseJSON(llmResponse.message.content) ); - + if (!parseResult.success) { config.onEvent?.({ type: 'decode_error', data: { errors: parseResult.error.issues } }); await runTurnEndHooks(config, { @@ -1228,7 +1309,7 @@ async function executeToolCalls( toolCalls.map(async (toolCall): Promise => { const tool = agent.tools?.find(t => t.schema.name === toolCall.function.name); const startTime = Date.now(); - + let rawArgs = tryParseJSON(toolCall.function.arguments); // Emit before_tool_execution event - handler can return modified args @@ -1323,8 +1404,8 @@ async function executeToolCalls( config.onEvent?.({ type: 'tool_call_end', - data: { - toolName: toolCall.function.name, + data: { + toolName: toolCall.function.name, result: errorResult, traceId: state.traceId, runId: state.runId, @@ -1443,12 +1524,12 @@ async function executeToolCalls( // Not a clarification trigger, continue with normal processing } } - + // Apply onAfterToolExecution callback if configured if (config.onAfterToolExecution) { try { const toolResultStatus = typeof toolResult === 'string' ? 'success' : (toolResult?.status || 'success'); - + const modifiedResult = await config.onAfterToolExecution( toolCall.function.name, toolResult, @@ -1471,10 +1552,10 @@ async function executeToolCalls( } let resultString: string; let toolResultObj: any = null; - + if (typeof toolResult === 'string') { resultString = toolResult; - safeConsole.log(`[JAF:ENGINE] Tool ${toolCall.function.name}` ); + safeConsole.log(`[JAF:ENGINE] Tool ${toolCall.function.name}`); } else { toolResultObj = toolResult; const { toolResultToString } = await import('./tool-results'); @@ -1484,8 +1565,8 @@ async function executeToolCalls( config.onEvent?.({ type: 'tool_call_end', - data: { - toolName: toolCall.function.name, + data: { + toolName: toolCall.function.name, result: resultString, traceId: state.traceId, runId: state.runId, @@ -1556,16 +1637,16 @@ async function executeToolCalls( config.onEvent?.({ type: 'tool_call_end', - data: { - toolName: toolCall.function.name, + data: { + toolName: toolCall.function.name, result: errorResult, traceId: state.traceId, runId: state.runId, status: 'error', toolResult: { error: 'execution_error', detail: error instanceof Error ? error.message : String(error) }, executionTime: Date.now() - startTime, - error: { - type: 'execution_error', + error: { + type: 'execution_error', message: error instanceof Error ? error.message : String(error), stack: error instanceof Error ? error.stack : undefined } @@ -1618,7 +1699,7 @@ async function loadConversationHistory( const maxMessages = config.memory.maxMessages || result.data.messages.length; const allMemoryMessages = result.data.messages.slice(-maxMessages); - + const memoryMessages = allMemoryMessages.filter(msg => { if (msg.role !== 'tool') return true; try { @@ -1628,19 +1709,19 @@ async function loadConversationHistory( return true; // Keep non-JSON tool messages } }); - - const combinedMessages = memoryMessages.length > 0 - ? [...memoryMessages, ...initialState.messages.filter(msg => - !memoryMessages.some(memMsg => - memMsg.role === msg.role && - memMsg.content === msg.content && - JSON.stringify(memMsg.tool_calls) === JSON.stringify(msg.tool_calls) - ) - )] + + const combinedMessages = memoryMessages.length > 0 + ? [...memoryMessages, ...initialState.messages.filter(msg => + !memoryMessages.some(memMsg => + memMsg.role === msg.role && + memMsg.content === msg.content && + JSON.stringify(memMsg.tool_calls) === JSON.stringify(msg.tool_calls) + ) + )] : initialState.messages; - + const storedApprovals = result.data.metadata?.approvals; - const approvalsMap = storedApprovals + const approvalsMap = storedApprovals ? new Map(Object.entries(storedApprovals) as [string, any][]) : (initialState.approvals ?? new Map()); @@ -1651,7 +1732,7 @@ async function loadConversationHistory( safeConsole.log(`[JAF:MEMORY] Memory messages:`, memoryMessages.map(m => ({ role: m.role, content: getTextContent(m.content)?.substring(0, 100) + '...' }))); safeConsole.log(`[JAF:MEMORY] New messages:`, initialState.messages.map(m => ({ role: m.role, content: getTextContent(m.content)?.substring(0, 100) + '...' }))); safeConsole.log(`[JAF:MEMORY] Combined messages (${combinedMessages.length} total):`, combinedMessages.map(m => ({ role: m.role, content: getTextContent(m.content)?.substring(0, 100) + '...' }))); - + return { ...initialState, messages: combinedMessages, @@ -1674,7 +1755,7 @@ async function storeConversationHistory( if (config.memory.compressionThreshold && messagesToStore.length > config.memory.compressionThreshold) { const keepFirst = Math.floor(config.memory.compressionThreshold * 0.2); const keepRecent = config.memory.compressionThreshold - keepFirst; - + messagesToStore = [ ...messagesToStore.slice(0, keepFirst), ...messagesToStore.slice(-keepRecent) diff --git a/src/core/state.ts b/src/core/state.ts index f2f687f..6de3536 100644 --- a/src/core/state.ts +++ b/src/core/state.ts @@ -1,4 +1,4 @@ -import { RunState, Interruption, RunConfig } from './types'; +import { RunState, Interruption, RunConfig, ClarificationInterruption } from './types'; import { safeConsole } from '../utils/logger.js'; export async function approve( @@ -95,3 +95,70 @@ export async function loadApprovalsIntoState( return state; } } + +/** + * Provide clarification selection for a clarification request + */ +export async function provideClarification( + state: RunState, + interruption: ClarificationInterruption, + selectedOption: string, + additionalContext?: Record, + config?: RunConfig +): Promise> { + if (interruption.type === 'clarification_required') { + const clarificationValue = { + selectedOption, + additionalContext: { + timestamp: new Date().toISOString(), + ...(additionalContext || {}) + } + }; + + // Store in clarification storage if available + if (config?.clarificationStorage) { + const result = await config.clarificationStorage.storeClarification( + state.runId, + interruption.clarificationId, + clarificationValue + ); + if (!result.success) { + safeConsole.warn('Failed to store clarification:', result.error); + // Continue with in-memory fallback + } + } + + // Update in-memory state + const newClarifications = new Map(state.clarifications ?? []); + newClarifications.set(interruption.clarificationId, selectedOption); + return { + ...state, + clarifications: newClarifications, + }; + } + return state; +} + +/** + * Helper function to load clarifications from storage into state + * This is called automatically by the engine when resuming a run + */ +export async function loadClarificationsIntoState( + state: RunState, + config?: RunConfig +): Promise> { + if (!config?.clarificationStorage) { + return state; + } + + const result = await config.clarificationStorage.getRunClarifications(state.runId); + if (result.success) { + return { + ...state, + clarifications: result.data, + }; + } else { + safeConsole.warn('Failed to load clarifications:', result.error); + return state; + } +} diff --git a/src/core/types.ts b/src/core/types.ts index 0d51980..55860f6 100644 --- a/src/core/types.ts +++ b/src/core/types.ts @@ -1,6 +1,7 @@ import { z } from 'zod'; import { MemoryConfig } from '../memory/types'; import type { ApprovalStorage } from '../memory/approval-storage'; +import type { ClarificationStorage } from '../memory/clarification-storage'; export type TraceId = string & { readonly _brand: 'TraceId' }; export type RunId = string & { readonly _brand: 'RunId' }; @@ -31,7 +32,7 @@ export type Attachment = { readonly useLiteLLMFormat?: boolean; // Use LiteLLM native file format instead of text extraction }; -export type MessageContentPart = +export type MessageContentPart = | { readonly type: 'text'; readonly text: string } | { readonly type: 'image_url'; readonly image_url: { readonly url: string; readonly detail?: 'low' | 'high' | 'auto' } } | { readonly type: 'file'; readonly file: { readonly file_id: string; readonly format?: string } }; @@ -48,18 +49,18 @@ export function getTextContent(content: string | readonly MessageContentPart[] | if (typeof content === 'string') { return content; } - + if (Array.isArray(content)) { return content .filter(item => item && typeof item === 'object' && item.type === 'text') .map(item => item.text || '') .join(' '); } - + if (content && typeof content === 'object') { return content.text || content.content || ''; } - + return String(content || ''); } @@ -80,11 +81,11 @@ export type Tool = { context: Readonly, ) => Promise; readonly needsApproval?: - | boolean - | (( - context: Readonly, - params: Readonly, - ) => Promise | boolean); + | boolean + | (( + context: Readonly, + params: Readonly, + ) => Promise | boolean); }; export type AdvancedGuardrailsConfig = { @@ -179,11 +180,7 @@ export type JAFError = | { readonly _tag: "HandoffError"; readonly detail: string } | { readonly _tag: "AgentNotFound"; readonly agentName: string }; -export type ClarificationOption = { - readonly id: string; - readonly label: string; - readonly value?: any; -}; +export type ClarificationOption = string; export type ToolApprovalInterruption = { readonly type: 'tool_approval'; @@ -206,12 +203,12 @@ export type Interruption = ToolApprovalInterruption | ClarificationInt export type RunResult = { readonly finalState: RunState; readonly outcome: - | { readonly status: 'completed'; readonly output: Out } - | { readonly status: 'error'; readonly error: JAFError } - | { - readonly status: 'interrupted'; - readonly interruptions: readonly Interruption[]; - }; + | { readonly status: 'completed'; readonly output: Out } + | { readonly status: 'error'; readonly error: JAFError } + | { + readonly status: 'interrupted'; + readonly interruptions: readonly Interruption[]; + }; }; /** @@ -242,7 +239,7 @@ export type TraceEvent = | { type: 'turn_end'; data: { turn: number; agentName: string } } | { type: 'run_end'; data: { outcome: RunResult['outcome']; finalState: RunState; traceId: TraceId; runId: RunId; } } | { type: 'clarification_requested'; data: { clarificationId: string; question: string; options: readonly ClarificationOption[]; context?: any; } } - | { type: 'clarification_provided'; data: { clarificationId: string; selectedOption: ClarificationOption; selectedId: string; } }; + | { type: 'clarification_provided'; data: { clarificationId: string; selectedOption: ClarificationOption; } }; /** * Helper type to extract event data by event type @@ -457,10 +454,11 @@ export type RunConfig = { executionTime: number; status: string | import('./tool-results').ToolResult; } - ) => Promise ; + ) => Promise; readonly memory?: MemoryConfig; readonly conversationId?: string; readonly approvalStorage?: ApprovalStorage; + readonly clarificationStorage?: ClarificationStorage; readonly defaultFastModel?: string; readonly allowClarificationRequests?: boolean; readonly clarificationDescription?: string; diff --git a/src/index.ts b/src/index.ts index 9eed3d1..23f7aee 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,5 +1,6 @@ export * from './core/types'; export * from './core/engine'; +export * from './core/state'; export * from './core/tracing'; export * from './core/errors'; export * from './core/tool-results'; @@ -23,6 +24,10 @@ export * from './memory/providers/in-memory'; export * from './memory/providers/redis'; export * from './memory/providers/postgres'; +// HITL Storage +export * from './memory/approval-storage'; +export * from './memory/clarification-storage'; + // A2A Protocol Support export * from './a2a'; diff --git a/src/memory/clarification-storage.ts b/src/memory/clarification-storage.ts new file mode 100644 index 0000000..0f600df --- /dev/null +++ b/src/memory/clarification-storage.ts @@ -0,0 +1,295 @@ +import { RunId, TraceId } from '../core/types'; +import { Result, createSuccess, createFailure, createMemoryStorageError } from './types'; + +/** + * Clarification value stored for a clarification request + */ +export type ClarificationValue = { + readonly selectedOption: string; + readonly additionalContext?: Record; +}; + +/** + * Clarification storage interface for managing clarification responses + * Similar to ApprovalStorage but for handling user clarification selections + */ +export interface ClarificationStorage { + /** + * Store clarification response for a clarification request + */ + readonly storeClarification: ( + runId: RunId, + clarificationId: string, + clarification: ClarificationValue, + metadata?: { traceId?: TraceId; [key: string]: any } + ) => Promise>; + + /** + * Retrieve clarification for a specific clarification request + */ + readonly getClarification: ( + runId: RunId, + clarificationId: string + ) => Promise>; + + /** + * Get all clarifications for a run + */ + readonly getRunClarifications: ( + runId: RunId + ) => Promise>>; + + /** + * Update existing clarification with additional context + */ + readonly updateClarification: ( + runId: RunId, + clarificationId: string, + updates: Partial + ) => Promise>; + + /** + * Delete clarification for a clarification request + */ + readonly deleteClarification: ( + runId: RunId, + clarificationId: string + ) => Promise>; + + /** + * Clear all clarifications for a run + */ + readonly clearRunClarifications: (runId: RunId) => Promise>; + + /** + * Get clarification statistics + */ + readonly getStats: () => Promise>; + + /** + * Health check for the clarification storage + */ + readonly healthCheck: () => Promise>; + + /** + * Close/cleanup the storage + */ + readonly close: () => Promise>; +} + +/** + * In-memory implementation of ClarificationStorage + * Non-persistent, good for development and testing + */ +export function createInMemoryClarificationStorage(): ClarificationStorage { + const clarifications = new Map>(); + + const getRunKey = (runId: RunId): string => `run:${runId}`; + + return { + storeClarification: async (runId, clarificationId, clarification) => { + try { + const runKey = getRunKey(runId); + + if (!clarifications.has(runKey)) { + clarifications.set(runKey, new Map()); + } + + const runClarifications = clarifications.get(runKey)!; + runClarifications.set(clarificationId, clarification); + + return createSuccess(undefined); + } catch (error) { + return createFailure(createMemoryStorageError( + 'store clarification', + 'InMemoryClarificationStorage', + error instanceof Error ? error : new Error(String(error)) + )); + } + }, + + getClarification: async (runId, clarificationId) => { + try { + const runKey = getRunKey(runId); + const runClarifications = clarifications.get(runKey); + + if (!runClarifications) { + return createSuccess(null); + } + + const clarification = runClarifications.get(clarificationId) || null; + return createSuccess(clarification); + } catch (error) { + return createFailure(createMemoryStorageError( + 'get clarification', + 'InMemoryClarificationStorage', + error instanceof Error ? error : new Error(String(error)) + )); + } + }, + + getRunClarifications: async (runId) => { + try { + const runKey = getRunKey(runId); + const runClarifications = clarifications.get(runKey); + + if (!runClarifications) { + return createSuccess(new Map() as ReadonlyMap); + } + + // Convert ClarificationValue map to string map (just the selectedOption) + const resultMap = new Map(); + for (const [id, value] of runClarifications.entries()) { + resultMap.set(id, value.selectedOption); + } + + return createSuccess(resultMap as ReadonlyMap); + } catch (error) { + return createFailure(createMemoryStorageError( + 'get run clarifications', + 'InMemoryClarificationStorage', + error instanceof Error ? error : new Error(String(error)) + )); + } + }, + + updateClarification: async (runId, clarificationId, updates) => { + try { + const runKey = getRunKey(runId); + const runClarifications = clarifications.get(runKey); + + if (!runClarifications || !runClarifications.has(clarificationId)) { + return createFailure(createMemoryStorageError( + 'update clarification', + 'InMemoryClarificationStorage', + new Error(`Clarification not found for ${clarificationId} in run ${runId}`) + )); + } + + const existingClarification = runClarifications.get(clarificationId)!; + const updatedClarification: ClarificationValue = { + ...existingClarification, + ...updates, + additionalContext: { + ...existingClarification.additionalContext, + ...updates.additionalContext + } + }; + + runClarifications.set(clarificationId, updatedClarification); + return createSuccess(undefined); + } catch (error) { + return createFailure(createMemoryStorageError( + 'update clarification', + 'InMemoryClarificationStorage', + error instanceof Error ? error : new Error(String(error)) + )); + } + }, + + deleteClarification: async (runId, clarificationId) => { + try { + const runKey = getRunKey(runId); + const runClarifications = clarifications.get(runKey); + + if (!runClarifications) { + return createSuccess(false); + } + + const deleted = runClarifications.delete(clarificationId); + + // Clean up empty run maps + if (runClarifications.size === 0) { + clarifications.delete(runKey); + } + + return createSuccess(deleted); + } catch (error) { + return createFailure(createMemoryStorageError( + 'delete clarification', + 'InMemoryClarificationStorage', + error instanceof Error ? error : new Error(String(error)) + )); + } + }, + + clearRunClarifications: async (runId) => { + try { + const runKey = getRunKey(runId); + const runClarifications = clarifications.get(runKey); + + if (!runClarifications) { + return createSuccess(0); + } + + const count = runClarifications.size; + clarifications.delete(runKey); + + return createSuccess(count); + } catch (error) { + return createFailure(createMemoryStorageError( + 'clear run clarifications', + 'InMemoryClarificationStorage', + error instanceof Error ? error : new Error(String(error)) + )); + } + }, + + getStats: async () => { + try { + let totalClarifications = 0; + const runsWithClarifications = clarifications.size; + + for (const [, runClarifications] of clarifications) { + totalClarifications += runClarifications.size; + } + + return createSuccess({ + totalClarifications, + runsWithClarifications + }); + } catch (error) { + return createFailure(createMemoryStorageError( + 'get stats', + 'InMemoryClarificationStorage', + error instanceof Error ? error : new Error(String(error)) + )); + } + }, + + healthCheck: async () => { + try { + const start = Date.now(); + // Simple operation to test functionality + await Promise.resolve(); + const latencyMs = Date.now() - start; + + return createSuccess({ + healthy: true, + latencyMs + }); + } catch (error) { + return createSuccess({ + healthy: false, + error: error instanceof Error ? error.message : String(error) + }); + } + }, + + close: async () => { + try { + clarifications.clear(); + return createSuccess(undefined); + } catch (error) { + return createFailure(createMemoryStorageError( + 'close', + 'InMemoryClarificationStorage', + error instanceof Error ? error : new Error(String(error)) + )); + } + } + }; +} diff --git a/src/server/types.ts b/src/server/types.ts index 8916c22..fa3c4e0 100644 --- a/src/server/types.ts +++ b/src/server/types.ts @@ -108,11 +108,7 @@ export const chatResponseSchema = z.object({ type: z.literal('clarification_required'), clarificationId: z.string(), question: z.string(), - options: z.array(z.object({ - id: z.string(), - label: z.string(), - value: z.any().optional() - })), + options: z.array(z.string()), context: z.record(z.any()).optional() }) ])).optional()