From 7813365d5f95dbad239898c386c374a760a0ed86 Mon Sep 17 00:00:00 2001 From: Son Le Date: Thu, 12 Mar 2026 14:08:09 +0100 Subject: [PATCH 01/13] feat: add state snapshot event --- docs/wayflowcore/source/core/api/events.rst | 4 + .../source/core/api/serialization.rst | 16 + docs/wayflowcore/source/core/changelog.rst | 5 + .../source/core/howtoguides/howto_tracing.rst | 34 + .../agentspec/components/__init__.py | 3 + .../agentspec/components/transforms.py | 6 + .../src/wayflowcore/agentspec/tracing.py | 210 +++++ wayflowcore/src/wayflowcore/conversation.py | 114 ++- wayflowcore/src/wayflowcore/events/event.py | 26 + .../executors/_statesnapshot_eventlistener.py | 339 ++++++++ .../executors/statesnapshotpolicy.py | 57 ++ .../src/wayflowcore/serialization/__init__.py | 31 + .../_builtins_serialization_plugin.py | 40 +- .../wayflowcore/serialization/conversation.py | 314 ++++++++ .../agentspec/test_state_snapshot_tracing.py | 694 ++++++++++++++++ .../tests/events/test_state_snapshot_event.py | 46 ++ .../events/test_state_snapshot_runtime.py | 754 ++++++++++++++++++ .../test_conversation_state_snapshot.py | 92 +++ 18 files changed, 2767 insertions(+), 18 deletions(-) create mode 100644 wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py create mode 100644 wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py create mode 100644 wayflowcore/src/wayflowcore/serialization/conversation.py create mode 100644 wayflowcore/tests/agentspec/test_state_snapshot_tracing.py create mode 100644 wayflowcore/tests/events/test_state_snapshot_event.py create mode 100644 wayflowcore/tests/events/test_state_snapshot_runtime.py create mode 100644 wayflowcore/tests/serialization/test_conversation_state_snapshot.py diff --git a/docs/wayflowcore/source/core/api/events.rst b/docs/wayflowcore/source/core/api/events.rst index cd3996709..98cb8e5b0 100644 --- a/docs/wayflowcore/source/core/api/events.rst +++ b/docs/wayflowcore/source/core/api/events.rst @@ -55,6 +55,10 @@ 
Events .. autoclass:: wayflowcore.events.event.ConversationExecutionFinishedEvent :exclude-members: to_tracing_info +.. _statesnapshotevent: +.. autoclass:: wayflowcore.events.event.StateSnapshotEvent + :exclude-members: to_tracing_info + .. _toolexecutionstartevent: .. autoclass:: wayflowcore.events.event.ToolExecutionStartEvent :exclude-members: to_tracing_info diff --git a/docs/wayflowcore/source/core/api/serialization.rst b/docs/wayflowcore/source/core/api/serialization.rst index 28d1e3c5a..1262dd377 100644 --- a/docs/wayflowcore/source/core/api/serialization.rst +++ b/docs/wayflowcore/source/core/api/serialization.rst @@ -33,6 +33,22 @@ Deserialization .. autofunction:: wayflowcore.serialization.serializer.autodeserialize +Conversation State Snapshots +---------------------------- + +.. _dumpconversationstate: +.. autofunction:: wayflowcore.serialization.dump_conversation_state + +.. _serializeconversationstate: +.. autofunction:: wayflowcore.serialization.serialize_conversation_state + +.. _deserializeconversationstate: +.. autofunction:: wayflowcore.serialization.deserialize_conversation_state + +.. _dumpvariablestate: +.. autofunction:: wayflowcore.serialization.dump_variable_state + + Plugins ------- diff --git a/docs/wayflowcore/source/core/changelog.rst b/docs/wayflowcore/source/core/changelog.rst index 337935eb3..8f9f7fa8f 100644 --- a/docs/wayflowcore/source/core/changelog.rst +++ b/docs/wayflowcore/source/core/changelog.rst @@ -7,6 +7,11 @@ WayFlow |current_version| New features ^^^^^^^^^^^^ +* **State snapshot tracing events:** + + Added ``StateSnapshotPolicy``, ``StateSnapshotEvent``, and conversation snapshot serialization helpers. + Snapshot emission can now be enabled per ``conversation.execute()`` / ``execute_async()`` turn, and is bridged to Agent Spec ``StateSnapshotEmitted`` events via the ``AgentSpecEventListener``. 
WayFlow-specific ``variable_state`` remains part of ``StateSnapshotEvent`` only, is not forwarded to Agent Spec, and requires JSON-serializable variable values. Snapshots are emitted only from direct execution boundary events; raised or interrupted turns do not synthesize extra unwind snapshots. + * **OAuth support for MCP Clients:** MCP Clients now support OAuth-based authorization. diff --git a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst index 8a1de2286..81866574a 100644 --- a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst +++ b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst @@ -152,6 +152,40 @@ Here's an example of how to use it in your code. :start-after: .. start-##_Enable_Agent_Spec_Tracing :end-before: .. end-##_Enable_Agent_Spec_Tracing +State snapshot events +--------------------- + +WayFlow can also emit ``StateSnapshotEvent`` payloads at conversation, step, and tool +boundaries by passing a ``StateSnapshotPolicy`` to ``conversation.execute()`` or +``conversation.execute_async()``. When ``AgentSpecEventListener`` is registered and +the installed ``pyagentspec`` version exposes ``StateSnapshotEmitted``, these runtime +snapshots are bridged into Agent Spec ``StateSnapshotEmitted`` events on the span that owns +the snapshot's conversation (falling back to the currently active span). The WayFlow-only ``variable_state`` payload is not bridged, because Agent Spec +does not define WayFlow variable semantics. When ``include_variable_state=True``, +variable values must already be JSON-serializable. +Snapshots are emitted only when the corresponding boundary event occurs. If a turn +raises or is interrupted before its matching closing event, WayFlow does not +synthesize an extra unwind snapshot. For step and tool intervals, the latest +already-emitted start snapshot is the recovery point. +For flows, ``NODE_TURNS`` snapshots are emitted around each step. For agents, the same +policy emits snapshots around each decision-loop iteration. 
Tool start/end snapshots +are emitted only for ``TOOL_TURNS`` and ``ALL_INTERNAL_TURNS``. + +.. code-block:: python + + from wayflowcore.executors.statesnapshotpolicy import ( + StateSnapshotInterval, + StateSnapshotPolicy, + ) + + status = conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + include_variable_state=True, + extra_state_builder=lambda conv: {"ui": {"active_tab": "plan"}}, + ) + ) + Agent Spec Exporting/Loading ============================ diff --git a/wayflowcore/src/wayflowcore/agentspec/components/__init__.py b/wayflowcore/src/wayflowcore/agentspec/components/__init__.py index dfa345296..5cc2e53b1 100644 --- a/wayflowcore/src/wayflowcore/agentspec/components/__init__.py +++ b/wayflowcore/src/wayflowcore/agentspec/components/__init__.py @@ -115,6 +115,7 @@ from .transforms import ( PluginAppendTrailingSystemMessageToUserMessageTransform, PluginCoalesceSystemMessagesTransform, + PluginManagerWorkersToolRequestAndCallsTransform, PluginReactMergeToolRequestAndCallsTransform, PluginRemoveEmptyNonUserMessageTransform, PluginSwarmToolRequestAndCallsTransform, @@ -225,6 +226,7 @@ "contextprovider_deserialization_plugin", "PluginAppendTrailingSystemMessageToUserMessageTransform", "PluginCoalesceSystemMessagesTransform", + "PluginManagerWorkersToolRequestAndCallsTransform", "PluginRemoveEmptyNonUserMessageTransform", "PluginReactMergeToolRequestAndCallsTransform", "PluginSwarmToolRequestAndCallsTransform", @@ -247,6 +249,7 @@ "PluginManagerWorkers", "PluginAppendTrailingSystemMessageToUserMessageTransform", "PluginCoalesceSystemMessagesTransform", + "PluginManagerWorkersToolRequestAndCallsTransform", "PluginReactMergeToolRequestAndCallsTransform", "PluginRemoveEmptyNonUserMessageTransform", "messagetransform_deserialization_plugin", diff --git a/wayflowcore/src/wayflowcore/agentspec/components/transforms.py b/wayflowcore/src/wayflowcore/agentspec/components/transforms.py 
index 2e2bc0bc3..4b55c1c7a 100644 --- a/wayflowcore/src/wayflowcore/agentspec/components/transforms.py +++ b/wayflowcore/src/wayflowcore/agentspec/components/transforms.py @@ -58,6 +58,10 @@ class PluginSwarmToolRequestAndCallsTransform(MessageTransform): sequence of messages.""" +class PluginManagerWorkersToolRequestAndCallsTransform(MessageTransform): + """Format Tool requests as Agent messages and Tool results as User messages for manager-workers prompts.""" + + class PluginCanonicalizationMessageTransform(MessageTransform): """ Produce a conversation shaped like: @@ -98,6 +102,7 @@ class PluginSplitPromptOnMarkerMessageTransform(MessageTransform): PluginAppendTrailingSystemMessageToUserMessageTransform.__name__: PluginAppendTrailingSystemMessageToUserMessageTransform, PluginLlamaMergeToolRequestAndCallsTransform.__name__: PluginLlamaMergeToolRequestAndCallsTransform, PluginReactMergeToolRequestAndCallsTransform.__name__: PluginReactMergeToolRequestAndCallsTransform, + PluginManagerWorkersToolRequestAndCallsTransform.__name__: PluginManagerWorkersToolRequestAndCallsTransform, PluginSwarmToolRequestAndCallsTransform.__name__: PluginSwarmToolRequestAndCallsTransform, PluginCanonicalizationMessageTransform.__name__: PluginCanonicalizationMessageTransform, PluginSplitPromptOnMarkerMessageTransform.__name__: PluginSplitPromptOnMarkerMessageTransform, @@ -111,6 +116,7 @@ class PluginSplitPromptOnMarkerMessageTransform(MessageTransform): PluginAppendTrailingSystemMessageToUserMessageTransform.__name__: PluginAppendTrailingSystemMessageToUserMessageTransform, PluginLlamaMergeToolRequestAndCallsTransform.__name__: PluginLlamaMergeToolRequestAndCallsTransform, PluginReactMergeToolRequestAndCallsTransform.__name__: PluginReactMergeToolRequestAndCallsTransform, + PluginManagerWorkersToolRequestAndCallsTransform.__name__: PluginManagerWorkersToolRequestAndCallsTransform, PluginSwarmToolRequestAndCallsTransform.__name__: PluginSwarmToolRequestAndCallsTransform, 
PluginCanonicalizationMessageTransform.__name__: PluginCanonicalizationMessageTransform, PluginSplitPromptOnMarkerMessageTransform.__name__: PluginSplitPromptOnMarkerMessageTransform, diff --git a/wayflowcore/src/wayflowcore/agentspec/tracing.py b/wayflowcore/src/wayflowcore/agentspec/tracing.py index e42147179..a66ab754c 100644 --- a/wayflowcore/src/wayflowcore/agentspec/tracing.py +++ b/wayflowcore/src/wayflowcore/agentspec/tracing.py @@ -12,6 +12,8 @@ from pyagentspec.flows.node import Node as AgentSpecNode from pyagentspec.llms import LlmConfig as AgentSpecLlmConfig from pyagentspec.llms import LlmGenerationConfig +from pyagentspec.managerworkers import ManagerWorkers as AgentSpecManagerWorkers +from pyagentspec.swarm import Swarm as AgentSpecSwarm from pyagentspec.tools import Tool as AgentSpecTool from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart @@ -23,8 +25,17 @@ ) from pyagentspec.tracing.events import LlmGenerationRequest as AgentSpecLlmGenerationRequest from pyagentspec.tracing.events import LlmGenerationResponse as AgentSpecLlmGenerationResponse +from pyagentspec.tracing.events import ( + ManagerWorkersExecutionEnd as AgentSpecManagerWorkersExecutionEnd, +) +from pyagentspec.tracing.events import ( + ManagerWorkersExecutionStart as AgentSpecManagerWorkersExecutionStart, +) from pyagentspec.tracing.events import NodeExecutionEnd as AgentSpecNodeExecutionEnd from pyagentspec.tracing.events import NodeExecutionStart as AgentSpecNodeExecutionStart +from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted +from pyagentspec.tracing.events import SwarmExecutionEnd as AgentSpecSwarmExecutionEnd +from pyagentspec.tracing.events import SwarmExecutionStart as AgentSpecSwarmExecutionStart from pyagentspec.tracing.events import ToolExecutionRequest as AgentSpecToolExecutionRequest from 
pyagentspec.tracing.events import ToolExecutionResponse as AgentSpecToolExecutionResponse from pyagentspec.tracing.events.llmgeneration import ToolCall as AgentSpecToolCall @@ -32,23 +43,30 @@ from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan from pyagentspec.tracing.spans import LlmGenerationSpan as AgentSpecLlmGenerationSpan +from pyagentspec.tracing.spans import ( + ManagerWorkersExecutionSpan as AgentSpecManagerWorkersExecutionSpan, +) from pyagentspec.tracing.spans import NodeExecutionSpan as AgentSpecNodeExecutionSpan from pyagentspec.tracing.spans import Span as AgentSpecSpan +from pyagentspec.tracing.spans import SwarmExecutionSpan as AgentSpecSwarmExecutionSpan from pyagentspec.tracing.spans import ToolExecutionSpan as AgentSpecToolExecutionSpan from wayflowcore._utils.formatting import stringify from wayflowcore.agentspec import AgentSpecExporter from wayflowcore.component import Component +from wayflowcore.conversation import Conversation, _get_active_conversations from wayflowcore.events.event import ( AgentExecutionFinishedEvent, AgentExecutionStartedEvent, ConversationMessageStreamChunkEvent, + EndSpanEvent, Event, ExceptionRaisedEvent, FlowExecutionFinishedEvent, FlowExecutionStartedEvent, LlmGenerationRequestEvent, LlmGenerationResponseEvent, + StateSnapshotEvent, StepInvocationResultEvent, StepInvocationStartEvent, ToolExecutionResultEvent, @@ -56,6 +74,9 @@ ) from wayflowcore.events.eventlistener import EventListener from wayflowcore.executors.executionstatus import FinishedStatus +from wayflowcore.managerworkers import ManagerWorkers as RuntimeManagerWorkers +from wayflowcore.steps.agentexecutionstep import AgentExecutionStep as RuntimeAgentExecutionStep +from wayflowcore.swarm import Swarm as RuntimeSwarm from wayflowcore.tracing.span import LlmGenerationSpan, get_active_span_stack, get_current_span @@ -70,6 +91,12 @@ def 
__init__(self) -> None: self.agentspec_exporter: AgentSpecExporter = AgentSpecExporter() # We keep a registry of conversions, so that we do not repeat the conversion for the same object twice self.agentspec_components_registry: Dict[str, AgentSpecComponent] = {} + # State snapshots belong to the span that owns their conversation id, not + # necessarily to the runtime span that was active when the snapshot event + # was emitted. + self._conversation_spans_registry: Dict[str, AgentSpecSpan] = {} + self._pending_multi_agent_spans_by_component_id: Dict[str, AgentSpecSpan] = {} + self._multi_agent_spans_by_step_span_id: Dict[str, AgentSpecSpan] = {} # Track last assistant message id and a robust mapping tool_request_id -> assistant message id. # Some providers may emit tool events before final assistant message id is known; we allow # temporarily missing ids and backfill on LLM response. @@ -83,6 +110,168 @@ def _convert_to_agentspec(self, component: Component) -> AgentSpecComponent: ) return self.agentspec_components_registry[component.id] + def _get_active_wayflow_conversation(self) -> Conversation | None: + active_conversations = _get_active_conversations(return_copy=False) + if not active_conversations: + return None + return active_conversations[-1] + + def _register_current_conversation_span(self, agentspec_span: AgentSpecSpan) -> None: + active_conversation = self._get_active_wayflow_conversation() + if active_conversation is not None: + self._conversation_spans_registry[active_conversation.conversation_id] = agentspec_span + + def _start_multi_agent_span_if_needed( + self, + current_span_id: str, + event_name: str, + event: StepInvocationStartEvent, + ) -> None: + if not isinstance(event.step, RuntimeAgentExecutionStep): + return + + if isinstance(event.step.agent, RuntimeManagerWorkers): + agentspec_managerworkers = cast( + AgentSpecManagerWorkers, self._convert_to_agentspec(event.step.agent) + ) + multi_agent_span: AgentSpecManagerWorkersExecutionSpan | 
AgentSpecSwarmExecutionSpan = ( + AgentSpecManagerWorkersExecutionSpan( + id=f"{current_span_id}:managerworkers", + name=f"ManagerWorkersExecution[{event.step.agent._get_display_name()}]", + managerworkers=agentspec_managerworkers, + ) + ) + multi_agent_span.start() + multi_agent_span.add_event( + AgentSpecManagerWorkersExecutionStart( + id=event.event_id, + name=event_name, + managerworkers=agentspec_managerworkers, + inputs={ + input_name: input_value for input_name, input_value in event.inputs.items() + }, + ) + ) + self._multi_agent_spans_by_step_span_id[current_span_id] = multi_agent_span + self._pending_multi_agent_spans_by_component_id[event.step.agent.id] = multi_agent_span + elif isinstance(event.step.agent, RuntimeSwarm): + agentspec_swarm = cast(AgentSpecSwarm, self._convert_to_agentspec(event.step.agent)) + multi_agent_span = AgentSpecSwarmExecutionSpan( + id=f"{current_span_id}:swarm", + name=f"SwarmExecution[{event.step.agent._get_display_name()}]", + swarm=agentspec_swarm, + ) + multi_agent_span.start() + multi_agent_span.add_event( + AgentSpecSwarmExecutionStart( + id=event.event_id, + name=event_name, + swarm=agentspec_swarm, + inputs={ + input_name: input_value for input_name, input_value in event.inputs.items() + }, + ) + ) + self._multi_agent_spans_by_step_span_id[current_span_id] = multi_agent_span + self._pending_multi_agent_spans_by_component_id[event.step.agent.id] = multi_agent_span + + def _end_multi_agent_span_if_needed( + self, + current_span_id: str, + event_name: str, + event: StepInvocationResultEvent, + ) -> None: + if not isinstance(event.step, RuntimeAgentExecutionStep): + return + + multi_agent_span = self._multi_agent_spans_by_step_span_id.pop(current_span_id, None) + if multi_agent_span is None: + return + + outputs = { + output_name: output_value + for output_name, output_value in event.step_result.outputs.items() + if output_name != "__execution_status__" + } + if isinstance(multi_agent_span, 
AgentSpecManagerWorkersExecutionSpan): + multi_agent_span.add_event( + AgentSpecManagerWorkersExecutionEnd( + id=event.event_id, + name=event_name, + managerworkers=multi_agent_span.managerworkers, + outputs=outputs, + ) + ) + elif isinstance(multi_agent_span, AgentSpecSwarmExecutionSpan): + multi_agent_span.add_event( + AgentSpecSwarmExecutionEnd( + id=event.event_id, + name=event_name, + swarm=multi_agent_span.swarm, + outputs=outputs, + ) + ) + + multi_agent_span.end() + self._pending_multi_agent_spans_by_component_id.pop(event.step.agent.id, None) + + def _get_snapshot_owner_span( + self, + event: StateSnapshotEvent, + current_agentspec_span: AgentSpecSpan | None, + ) -> AgentSpecSpan | None: + if event.conversation_id in self._conversation_spans_registry: + return self._conversation_spans_registry[event.conversation_id] + + active_conversations = _get_active_conversations(return_copy=False) + matching_conversation = next( + ( + conversation + for conversation in reversed(active_conversations) + if conversation.conversation_id == event.conversation_id + ), + None, + ) + if matching_conversation is None: + return current_agentspec_span + + pending_multi_agent_span = self._pending_multi_agent_spans_by_component_id.get( + matching_conversation.component.id + ) + if pending_multi_agent_span is not None: + self._conversation_spans_registry[event.conversation_id] = pending_multi_agent_span + return pending_multi_agent_span + + return current_agentspec_span + + @staticmethod + def _move_snapshot_before_terminal_event( + agentspec_span: AgentSpecSpan, + snapshot_event: AgentSpecStateSnapshotEmitted, + ) -> None: + if len(agentspec_span.events) < 2 or not isinstance( + agentspec_span.events[-2], + ( + AgentSpecAgentExecutionEnd, + AgentSpecExceptionRaised, + AgentSpecFlowExecutionEnd, + AgentSpecManagerWorkersExecutionEnd, + AgentSpecSwarmExecutionEnd, + ), + ): + return + + terminal_event = agentspec_span.events[-2] + if agentspec_span.end_time is not None: + 
snapshot_event.timestamp = min(snapshot_event.timestamp, agentspec_span.end_time) + else: + snapshot_event.timestamp = min(snapshot_event.timestamp, terminal_event.timestamp) + + agentspec_span.events[-2], agentspec_span.events[-1] = ( + agentspec_span.events[-1], + agentspec_span.events[-2], + ) + def __call__(self, event: Event) -> None: # We intercept the wayflow events, and based on the type of event: # - if it corresponds to a span start event, we create the corresponding agent spec span, and we start it @@ -294,10 +483,12 @@ def __call__(self, event: Event) -> None: }, ) ) + self._start_multi_agent_span_if_needed(current_span.span_id, event_name, event) case StepInvocationResultEvent(): # Step execution ends. Add the event to the agent spec span and close the span if not current_agentspec_span: return + self._end_multi_agent_span_if_needed(current_span.span_id, event_name, event) agentspec_node = cast(AgentSpecNode, self._convert_to_agentspec(event.step)) current_agentspec_span.add_event( AgentSpecNodeExecutionEnd( @@ -312,6 +503,19 @@ def __call__(self, event: Event) -> None: ) ) current_agentspec_span.end() + case StateSnapshotEvent(): + owner_span = self._get_snapshot_owner_span(event, current_agentspec_span) + if not owner_span: + return + snapshot_event = AgentSpecStateSnapshotEmitted( + id=event.event_id, + name=event_name, + conversation_id=event.conversation_id, + state_snapshot=event.state_snapshot, + extra_state=event.extra_state, + ) + owner_span.add_event(snapshot_event) + self._move_snapshot_before_terminal_event(owner_span, snapshot_event) case FlowExecutionStartedEvent(): # Flow execution starts. Create the new agent spec span, start it, add the event agentspec_flow = cast( @@ -332,6 +536,7 @@ def __call__(self, event: Event) -> None: inputs={}, ) ) + self._register_current_conversation_span(current_agentspec_span) case FlowExecutionFinishedEvent(): # Flow execution ends. 
Add the event to the agent spec span and close the span if not current_agentspec_span: @@ -375,6 +580,7 @@ def __call__(self, event: Event) -> None: inputs={}, ) ) + self._register_current_conversation_span(current_agentspec_span) case AgentExecutionFinishedEvent(): # Agent execution ends. Add the event to the agent spec span and close the span if not current_agentspec_span: @@ -408,3 +614,7 @@ def __call__(self, event: Event) -> None: exception_stacktrace=str(event.exception.__traceback__), ) ) + case EndSpanEvent(): + if not current_agentspec_span or current_agentspec_span.end_time is not None: + return + current_agentspec_span.end() diff --git a/wayflowcore/src/wayflowcore/conversation.py b/wayflowcore/src/wayflowcore/conversation.py index 6d347d1b9..c6213318f 100644 --- a/wayflowcore/src/wayflowcore/conversation.py +++ b/wayflowcore/src/wayflowcore/conversation.py @@ -3,6 +3,7 @@ # This software is under the Apache License 2.0 # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+import json import logging import warnings from abc import abstractmethod @@ -20,6 +21,7 @@ Optional, Sequence, Union, + cast, ) from wayflowcore._utils.async_helpers import run_async_in_sync @@ -32,6 +34,7 @@ ToolRequestStatus, UserMessageRequestStatus, ) +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotPolicy from wayflowcore.messagelist import Message, MessageContent, MessageList from wayflowcore.planning import ExecutionPlan from wayflowcore.tokenusage import TokenUsage @@ -97,6 +100,10 @@ class Conversation(DataclassComponent): """Whether the current status associated to this conversation was already handled or not (messages/tool results were added to the conversation)""" + _state_snapshot_policy: Optional[StateSnapshotPolicy] = field( + default=None, init=False, repr=False, compare=False + ) + def __post_init__(self) -> None: if self.inputs is None: self.inputs = {} @@ -111,9 +118,99 @@ def _get_interrupts(self) -> Optional[List["ExecutionInterrupt"]]: def _register_event(self, event: Event) -> None: self.state._register_event(event) + def _get_parent_state_snapshot_policy(self) -> Optional[StateSnapshotPolicy]: + active_conversations = _get_active_conversations(return_copy=True) + if not active_conversations or active_conversations[-1] is self: + return None + return active_conversations[-1]._get_state_snapshot_policy() + + def _get_state_snapshot_policy(self) -> Optional[StateSnapshotPolicy]: + return self._state_snapshot_policy + + def _build_extra_state(self) -> Optional[Dict[str, Any]]: + state_snapshot_policy = self._get_state_snapshot_policy() + if state_snapshot_policy is None or state_snapshot_policy.extra_state_builder is None: + return None + + try: + extra_state = state_snapshot_policy.extra_state_builder(self) + except Exception: + logger.warning( + "Failed to build extra snapshot state for conversation '%s'", + self.conversation_id, + exc_info=True, + ) + return None + + if extra_state is None: + return None + if not 
isinstance(extra_state, dict): + logger.warning( + "Expected extra snapshot state to be a dictionary for conversation '%s'", + self.conversation_id, + ) + return None + + try: + return cast(Dict[str, Any], json.loads(json.dumps(extra_state))) + except Exception: + logger.warning( + "Extra snapshot state is not JSON serializable for conversation '%s'", + self.conversation_id, + exc_info=True, + ) + return None + + @contextmanager + def _use_state_snapshot( + self, state_snapshot_policy: Optional[StateSnapshotPolicy] + ) -> Generator[None, Any, None]: + """ + Activate the effective state snapshot policy for this execution turn. + + Child conversations inherit the parent's policy unless they explicitly + override it. When snapshots are enabled, listener registration happens + here in the same order the runtime depends on: + 1. pre-interrupt snapshot listener + 2. interrupts listener + 3. post-interrupt snapshot listener + """ + active_state_snapshot_policy = ( + state_snapshot_policy + if state_snapshot_policy is not None + else self._get_parent_state_snapshot_policy() + ) + previous_policy = self._state_snapshot_policy + self._state_snapshot_policy = active_state_snapshot_policy + try: + if active_state_snapshot_policy is None: + yield + else: + from wayflowcore.executors._interrupts_eventlistener import ( + get_interrupts_event_listener_context_for_conversation, + ) + from wayflowcore.executors._statesnapshot_eventlistener import ( + StateSnapshotListenerPhase, + get_state_snapshot_event_listener_context_for_conversation, + ) + + with get_state_snapshot_event_listener_context_for_conversation( + self, + phase=StateSnapshotListenerPhase.PRE_INTERRUPTS, + ): + with get_interrupts_event_listener_context_for_conversation(self): + with get_state_snapshot_event_listener_context_for_conversation( + self, + phase=StateSnapshotListenerPhase.POST_INTERRUPTS, + ): + yield + finally: + self._state_snapshot_policy = previous_policy + def execute( self, execution_interrupts: 
Optional[Sequence["ExecutionInterrupt"]] = None, + state_snapshot_policy: Optional[StateSnapshotPolicy] = None, ) -> "ExecutionStatus": """ Execute the conversation and get its ``ExecutionStatus`` based on the outcome. @@ -122,12 +219,16 @@ def execute( finished the conversation. """ return run_async_in_sync( - self.execute_async, execution_interrupts, method_name="execute_async" + self.execute_async, + execution_interrupts, + state_snapshot_policy, + method_name="execute_async", ) async def execute_async( self, execution_interrupts: Optional[Sequence["ExecutionInterrupt"]] = None, + state_snapshot_policy: Optional[StateSnapshotPolicy] = None, ) -> "ExecutionStatus": """ Execute the conversation and get its ``ExecutionStatus`` based on the outcome. @@ -138,12 +239,13 @@ async def execute_async( if self.status_handled is False: self._update_conversation_with_status() - with _register_conversation(self): - new_status = await self.component.runner.execute_async(self, execution_interrupts) + with self._use_state_snapshot(state_snapshot_policy): + with _register_conversation(self): + new_status = await self.component.runner.execute_async(self, execution_interrupts) - self.status = new_status - self.status_handled = False - return self.status + self.status = new_status + self.status_handled = False + return self.status @property @abstractmethod diff --git a/wayflowcore/src/wayflowcore/events/event.py b/wayflowcore/src/wayflowcore/events/event.py index 4c9c39f1c..2e632177a 100644 --- a/wayflowcore/src/wayflowcore/events/event.py +++ b/wayflowcore/src/wayflowcore/events/event.py @@ -794,6 +794,32 @@ def to_tracing_info(self, mask_sensitive_information: bool = True) -> Dict[str, } +@dataclass(frozen=True) +class StateSnapshotEvent(Event): + """Event emitted by WayFlow when a conversation state snapshot is recorded.""" + + conversation_id: str = field(default_factory=_required_attribute("conversation_id", str)) + state_snapshot: Optional[Dict[str, Any]] = None + extra_state: 
Optional[Dict[str, Any]] = None + variable_state: Optional[Dict[str, Any]] = None + + def to_tracing_info(self, mask_sensitive_information: bool = True) -> Dict[str, Any]: + def _masked(value: Optional[Dict[str, Any]]) -> Any: + if value is None: + return None + if mask_sensitive_information: + return _PII_TEXT_MASK + return value + + return { + **super().to_tracing_info(mask_sensitive_information=mask_sensitive_information), + "conversation_id": self.conversation_id, + "state_snapshot": _masked(self.state_snapshot), + "extra_state": _masked(self.extra_state), + "variable_state": _masked(self.variable_state), + } + + @dataclass(frozen=True) class AgentNextActionDecisionStartEvent(Event): """ diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py new file mode 100644 index 000000000..561e0e455 --- /dev/null +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -0,0 +1,339 @@ +# Copyright © 2025 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +from __future__ import annotations + +import logging +from contextlib import contextmanager +from enum import Enum +from typing import Iterator, Optional + +from wayflowcore.conversation import Conversation, _get_active_conversations +from wayflowcore.events import Event, EventListener +from wayflowcore.events.event import ( + AgentExecutionFinishedEvent, + AgentExecutionIterationFinishedEvent, + AgentExecutionIterationStartedEvent, + AgentExecutionStartedEvent, + ExceptionRaisedEvent, + FlowExecutionFinishedEvent, + FlowExecutionIterationFinishedEvent, + FlowExecutionStartedEvent, + StateSnapshotEvent, + StepInvocationStartEvent, + ToolExecutionResultEvent, + ToolExecutionStartEvent, +) +from wayflowcore.events.eventlistener import ( + get_event_listeners, + record_event, + register_event_listeners, +) +from wayflowcore.executors._events.event import EventType as ExecutionEventType +from wayflowcore.executors._executor import ExecutionInterruptedException +from wayflowcore.executors.executionstatus import ExecutionStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.serialization.conversation import dump_conversation_state, dump_variable_state +from wayflowcore.tracing.span import AgentExecutionSpan, FlowExecutionSpan, get_current_span + +logger = logging.getLogger(__name__) + + +class StateSnapshotBoundary(str, Enum): + """ + Concrete runtime boundaries at which a state snapshot may be recorded. + + `TURN_START` + The opening boundary of a single `conversation.execute(...)` call. This + captures the turn's initial resume point before execution work begins. + + `TURN_END` + The closing boundary of a single `conversation.execute(...)` call. This + is the stable resume point after the turn's final status is known. + + `TOOL_START` + Right before a tool invocation begins. + + `TOOL_END` + Right after a tool invocation completes and its result is available. 
+ + `NODE_START` + Right before a flow step starts executing. + + `NODE_END` + Right after a flow step finishes executing. + + `AGENT_LOOP_START` + Right before an agent reasoning/decision-loop iteration starts. + + `AGENT_LOOP_END` + Right after an agent reasoning/decision-loop iteration finishes. + """ + + TURN_START = "turn_start" + TURN_END = "turn_end" + TOOL_START = "tool_start" + TOOL_END = "tool_end" + NODE_START = "node_start" + NODE_END = "node_end" + AGENT_LOOP_START = "agent_loop_start" + AGENT_LOOP_END = "agent_loop_end" + + +class StateSnapshotListenerPhase(str, Enum): + PRE_INTERRUPTS = "pre_interrupts" + POST_INTERRUPTS = "post_interrupts" + + +def should_emit_state_snapshot( + conversation: Conversation, + boundary: StateSnapshotBoundary, +) -> bool: + state_snapshot_policy = conversation._get_state_snapshot_policy() + if state_snapshot_policy is None: + return False + + snapshot_interval = state_snapshot_policy.state_snapshot_interval + if snapshot_interval == StateSnapshotInterval.OFF: + should_emit = False + elif boundary == StateSnapshotBoundary.TURN_START: + should_emit = snapshot_interval == StateSnapshotInterval.CONVERSATION_TURNS + elif boundary == StateSnapshotBoundary.TURN_END: + should_emit = True + elif boundary in {StateSnapshotBoundary.TOOL_START, StateSnapshotBoundary.TOOL_END}: + should_emit = snapshot_interval in { + StateSnapshotInterval.TOOL_TURNS, + StateSnapshotInterval.ALL_INTERNAL_TURNS, + } + elif boundary in { + StateSnapshotBoundary.NODE_START, + StateSnapshotBoundary.NODE_END, + StateSnapshotBoundary.AGENT_LOOP_START, + StateSnapshotBoundary.AGENT_LOOP_END, + }: + # Agents do not expose node execution events, so NODE_TURNS maps to + # per-step boundaries for flows and per-iteration boundaries for agents. 
+ should_emit = snapshot_interval in { + StateSnapshotInterval.NODE_TURNS, + StateSnapshotInterval.ALL_INTERNAL_TURNS, + } + else: + should_emit = False + + return should_emit + + +def record_state_snapshot( + conversation: Conversation, + boundary: StateSnapshotBoundary, + *, + execution_status: ExecutionStatus | None, + status_handled: bool, +) -> bool: + state_snapshot_policy = conversation._get_state_snapshot_policy() + if state_snapshot_policy is None or not should_emit_state_snapshot(conversation, boundary): + return False + + previous_status = conversation.status + previous_status_handled = conversation.status_handled + conversation.status = execution_status + conversation.status_handled = status_handled + + try: + record_event( + StateSnapshotEvent( + conversation_id=conversation.conversation_id, + state_snapshot=dump_conversation_state(conversation), + extra_state=conversation._build_extra_state(), + variable_state=( + dump_variable_state(conversation) + if state_snapshot_policy.include_variable_state + else None + ), + ) + ) + return True + except Exception: + logger.warning( + "Failed to emit state snapshot for conversation '%s'", + conversation.conversation_id, + exc_info=True, + ) + return False + finally: + conversation.status = previous_status + conversation.status_handled = previous_status_handled + + +def _get_current_active_conversation() -> Optional[Conversation]: + active_conversations = _get_active_conversations(return_copy=False) + if not active_conversations: + return None + return active_conversations[-1] + + +def _is_multi_agent_conversation(conversation: Conversation) -> bool: + from wayflowcore.executors._managerworkersconversation import ManagerWorkersConversation + from wayflowcore.executors._swarmconversation import SwarmConversation + + return isinstance(conversation, (ManagerWorkersConversation, SwarmConversation)) + + +def _get_nearest_parent_multi_agent_conversation() -> Optional[Conversation]: + active_conversations = 
_get_active_conversations(return_copy=False) + if len(active_conversations) < 2: + return None + + for conversation in reversed(active_conversations[:-1]): + if _is_multi_agent_conversation(conversation): + return conversation + + return None + + +class StateSnapshotEventListener(EventListener): + """Emit state snapshots for the active conversation.""" + + def __init__( + self, + conversation: Conversation, + phase: StateSnapshotListenerPhase, + ) -> None: + self.conversation = conversation + self.phase = phase + + def _record_snapshot(self, boundary: StateSnapshotBoundary) -> None: + record_state_snapshot( + self.conversation, + boundary, + execution_status=None, + status_handled=False, + ) + + def _handle_pre_interrupt_event(self, event: Event) -> None: + match event: + case FlowExecutionStartedEvent(): + self._record_snapshot(StateSnapshotBoundary.TURN_START) + case AgentExecutionStartedEvent(): + self._record_snapshot(StateSnapshotBoundary.TURN_START) + case ToolExecutionStartEvent(): + self._record_snapshot(StateSnapshotBoundary.TOOL_START) + case ToolExecutionResultEvent(): + self._record_snapshot(StateSnapshotBoundary.TOOL_END) + case StepInvocationStartEvent(): + self._record_snapshot(StateSnapshotBoundary.NODE_START) + case FlowExecutionIterationFinishedEvent(): + self._record_snapshot(StateSnapshotBoundary.NODE_END) + case AgentExecutionIterationStartedEvent(): + self._record_snapshot(StateSnapshotBoundary.AGENT_LOOP_START) + case AgentExecutionIterationFinishedEvent(): + self._record_snapshot(StateSnapshotBoundary.AGENT_LOOP_END) + + def _handle_pre_interrupt_event_for_parent_multi_agent(self, event: Event) -> None: + match event: + case AgentExecutionStartedEvent() | FlowExecutionStartedEvent(): + self._record_snapshot(StateSnapshotBoundary.TURN_START) + + def _record_turn_end_snapshot( + self, + execution_status: ExecutionStatus | None = None, + ) -> None: + record_state_snapshot( + self.conversation, + StateSnapshotBoundary.TURN_END, + 
execution_status=execution_status, + status_handled=False, + ) + + def _latest_execution_event_is_turn_end(self) -> bool: + if not self.conversation.state.events: + return False + return self.conversation.state.events[-1].type == ExecutionEventType.EXECUTION_END + + def _should_record_interrupted_turn_end_snapshot( + self, + ) -> bool: + if not self._latest_execution_event_is_turn_end(): + should_record = False + elif not isinstance(get_current_span(), (FlowExecutionSpan, AgentExecutionSpan)): + should_record = False + else: + should_record = True + + return should_record + + def _handle_post_interrupt_event(self, event: Event) -> None: + match event: + case FlowExecutionFinishedEvent( + execution_status=execution_status + ) | AgentExecutionFinishedEvent(execution_status=execution_status): + self._record_turn_end_snapshot(execution_status) + case ExceptionRaisedEvent(exception=ExecutionInterruptedException() as exception): + if self._should_record_interrupted_turn_end_snapshot(): + self._record_turn_end_snapshot(exception.execution_status) + + def _handle_post_interrupt_event_for_parent_multi_agent(self, event: Event) -> None: + match event: + case FlowExecutionFinishedEvent( + execution_status=execution_status + ) | AgentExecutionFinishedEvent(execution_status=execution_status): + self._record_turn_end_snapshot(execution_status) + case ExceptionRaisedEvent(exception=ExecutionInterruptedException() as exception): + self._record_turn_end_snapshot(exception.execution_status) + + def __call__(self, event: Event) -> None: + if isinstance(event, StateSnapshotEvent): + return + + current_conversation = _get_current_active_conversation() + if current_conversation is None: + return + + is_current_conversation = current_conversation.id == self.conversation.id + parent_multi_agent_conversation = _get_nearest_parent_multi_agent_conversation() + is_parent_multi_agent_conversation = ( + parent_multi_agent_conversation is not None + and parent_multi_agent_conversation.id == 
self.conversation.id + ) + + if not is_current_conversation and not is_parent_multi_agent_conversation: + return + + if self.phase == StateSnapshotListenerPhase.PRE_INTERRUPTS: + if is_current_conversation: + self._handle_pre_interrupt_event(event) + else: + self._handle_pre_interrupt_event_for_parent_multi_agent(event) + else: + if is_current_conversation: + self._handle_post_interrupt_event(event) + else: + self._handle_post_interrupt_event_for_parent_multi_agent(event) + + +@contextmanager +def get_state_snapshot_event_listener_context_for_conversation( + conversation: Conversation, + *, + phase: StateSnapshotListenerPhase, +) -> Iterator[StateSnapshotEventListener]: + current_listener = next( + ( + event_listener + for event_listener in get_event_listeners() + if isinstance(event_listener, StateSnapshotEventListener) + and event_listener.conversation.id == conversation.id + and event_listener.phase == phase + ), + None, + ) + + if current_listener is not None: + yield current_listener + else: + listener = StateSnapshotEventListener(conversation, phase=phase) + with register_event_listeners([listener]): + yield listener diff --git a/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py b/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py new file mode 100644 index 000000000..f8150092a --- /dev/null +++ b/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py @@ -0,0 +1,57 @@ +# Copyright © 2025 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional + +if TYPE_CHECKING: + from wayflowcore.conversation import Conversation + + +class StateSnapshotInterval(str, Enum): + """ + Configure which execution boundaries emit state snapshots. + + `CONVERSATION_TURNS` + Emit an opening turn snapshot before execution starts and a closing + turn snapshot when the turn finishes or is interrupted at execution + end. This is the default policy because it gives a stable resume point + without emitting snapshots for every internal step. + + `TOOL_TURNS` + Emit the standard closing turn snapshot plus snapshots around each tool + invocation (`TOOL_START` and `TOOL_END`). + + `NODE_TURNS` + Emit the standard closing turn snapshot plus snapshots around each + internal node boundary. For flows this means per-step snapshots; for + agents it maps to decision-loop iteration boundaries. + + `ALL_INTERNAL_TURNS` + Emit the standard closing turn snapshot plus all tool and node + snapshots. + + `OFF` + Disable state snapshot emission entirely. 
+ """ + + CONVERSATION_TURNS = "conversation_turns" + TOOL_TURNS = "tool_turns" + NODE_TURNS = "node_turns" + ALL_INTERNAL_TURNS = "all_internal_turns" + OFF = "off" + + +@dataclass(frozen=True) +class StateSnapshotPolicy: + """Execution-time policy controlling WayFlow state snapshot emission.""" + + state_snapshot_interval: StateSnapshotInterval = StateSnapshotInterval.CONVERSATION_TURNS + include_variable_state: bool = True + extra_state_builder: Optional[Callable[["Conversation"], Optional[Dict[str, Any]]]] = None diff --git a/wayflowcore/src/wayflowcore/serialization/__init__.py b/wayflowcore/src/wayflowcore/serialization/__init__.py index 5d90ab78e..8089ebfdf 100644 --- a/wayflowcore/src/wayflowcore/serialization/__init__.py +++ b/wayflowcore/src/wayflowcore/serialization/__init__.py @@ -4,6 +4,8 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. +from typing import Any + from .serializer import ( autodeserialize, deserialize, @@ -12,10 +14,39 @@ serialize_to_dict, ) + +def dump_conversation_state(*args: Any, **kwargs: Any) -> Any: + from .conversation import dump_conversation_state as _dump_conversation_state + + return _dump_conversation_state(*args, **kwargs) + + +def serialize_conversation_state(*args: Any, **kwargs: Any) -> Any: + from .conversation import serialize_conversation_state as _serialize_conversation_state + + return _serialize_conversation_state(*args, **kwargs) + + +def deserialize_conversation_state(*args: Any, **kwargs: Any) -> Any: + from .conversation import deserialize_conversation_state as _deserialize_conversation_state + + return _deserialize_conversation_state(*args, **kwargs) + + +def dump_variable_state(*args: Any, **kwargs: Any) -> Any: + from .conversation import dump_variable_state as _dump_variable_state + + return _dump_variable_state(*args, **kwargs) + + __all__ = [ "autodeserialize", 
"deserialize", + "deserialize_conversation_state", "deserialize_from_dict", + "dump_conversation_state", + "dump_variable_state", "serialize", + "serialize_conversation_state", "serialize_to_dict", ] diff --git a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py index e10eea2a7..d7e95dc72 100644 --- a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py +++ b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py @@ -134,7 +134,9 @@ from wayflowcore.agentspec.components import ( PluginVllmEmbeddingConfig as AgentSpecPluginVllmEmbeddingConfig, ) -from wayflowcore.agentspec.components import all_serialization_plugin +from wayflowcore.agentspec.components import ( + all_serialization_plugin, +) from wayflowcore.agentspec.components.agent import ExtendedAgent as AgentSpecExtendedAgent from wayflowcore.agentspec.components.contextprovider import ( PluginConstantContextProvider as AgentSpecPluginConstantContextProvider, @@ -278,6 +280,9 @@ from wayflowcore.agentspec.components.transforms import ( PluginLlamaMergeToolRequestAndCallsTransform as AgentSpecPluginLlamaMergeToolRequestAndCallsTransform, ) +from wayflowcore.agentspec.components.transforms import ( + PluginManagerWorkersToolRequestAndCallsTransform as AgentSpecPluginManagerWorkersToolRequestAndCallsTransform, +) from wayflowcore.agentspec.components.transforms import ( PluginReactMergeToolRequestAndCallsTransform as AgentSpecPluginReactMergeToolRequestAndCallsTransform, ) @@ -377,7 +382,9 @@ from wayflowcore.models.ociclientconfig import ( OCIClientConfigWithUserAuthentication as RuntimeOCIClientConfigWithUserAuthentication, ) -from wayflowcore.models.openaicompatiblemodel import EMPTY_API_KEY +from wayflowcore.models.openaicompatiblemodel import ( + EMPTY_API_KEY, +) from wayflowcore.models.openaicompatiblemodel import ( OpenAICompatibleModel as 
RuntimeOpenAICompatibleModel, ) @@ -435,6 +442,9 @@ ) from wayflowcore.swarm import Swarm as RuntimeSwarm from wayflowcore.templates import PromptTemplate as RuntimePromptTemplate +from wayflowcore.templates._managerworkerstemplate import ( + _ToolRequestAndCallsTransform as RuntimeManagerWorkersToolRequestAndCallsTransform, +) from wayflowcore.templates._swarmtemplate import ( _ToolRequestAndCallsTransform as RuntimeSwarmToolRequestAndCallsTransform, ) @@ -1594,6 +1604,15 @@ def _messagetransform_convert_to_agentspec( runtime_messagetransform ), ) + elif isinstance( + runtime_messagetransform, RuntimeManagerWorkersToolRequestAndCallsTransform + ): + return AgentSpecPluginManagerWorkersToolRequestAndCallsTransform( + name="managerworkerstoolrequestandcalls_messagetransform", + metadata=_create_agentspec_metadata_from_runtime_component( + runtime_messagetransform + ), + ) elif isinstance(runtime_messagetransform, RuntimeSwarmToolRequestAndCallsTransform): return AgentSpecPluginSwarmToolRequestAndCallsTransform( name="swarmtoolrequestandcalls_messagetransform", @@ -2500,6 +2519,11 @@ def _managerworkers_convert_to_agentspec( referenced_objects: Optional[Dict[str, Any]] = None, ) -> AgentSpecManagerWorkers: metadata = _create_agentspec_metadata_from_runtime_component(runtime_managerworkers) + group_manager = ( + runtime_managerworkers.manager_agent + if isinstance(runtime_managerworkers.group_manager, RuntimeLlmModel) + else runtime_managerworkers.group_manager + ) return AgentSpecManagerWorkers( name=runtime_managerworkers.name @@ -2507,17 +2531,9 @@ def _managerworkers_convert_to_agentspec( description=runtime_managerworkers.description or runtime_managerworkers.__metadata_info__.get("description", ""), id=runtime_managerworkers.id, - group_manager=cast( - Union[AgentSpecAgent, AgentSpecLlmConfig], - conversion_context.convert( - runtime_managerworkers.group_manager, referenced_objects - ), - ), + group_manager=conversion_context.convert(group_manager, 
referenced_objects), workers=[ - cast( - AgentSpecAgent, - conversion_context.convert(worker, referenced_objects), - ) + conversion_context.convert(worker, referenced_objects) for worker in runtime_managerworkers.workers ], inputs=[ diff --git a/wayflowcore/src/wayflowcore/serialization/conversation.py b/wayflowcore/src/wayflowcore/serialization/conversation.py new file mode 100644 index 000000000..becb4a4e6 --- /dev/null +++ b/wayflowcore/src/wayflowcore/serialization/conversation.py @@ -0,0 +1,314 @@ +# Copyright © 2025 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from __future__ import annotations + +import json +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING, Any, Optional, cast + +from wayflowcore._utils.formatting import stringify +from wayflowcore.executors.executionstatus import ( + AuthChallengeRequestStatus, + ExecutionStatus, + FinishedStatus, + ToolExecutionConfirmationStatus, + ToolRequestStatus, + UserMessageRequestStatus, +) +from wayflowcore.messagelist import ImageContent, Message, MessageContent, TextContent +from wayflowcore.serialization.context import SerializationContext +from wayflowcore.serialization.serializer import serialize_any_to_dict_or_stringify +from wayflowcore.tools.tools import ToolRequest, ToolResult + +if TYPE_CHECKING: + from wayflowcore.conversation import Conversation + from wayflowcore.executors._agentconversation import AgentConversation + from wayflowcore.executors._flowconversation import FlowConversation + + +def _dump_json_compatible_value(value: Any) -> Any: + from wayflowcore.component import Component + from wayflowcore.conversation import Conversation + + dumped_value: Any + if value is None or isinstance(value, (bool, int, float, str)): + dumped_value = value 
+ elif isinstance(value, bytes): + dumped_value = value.decode("utf-8", errors="replace") + elif isinstance(value, datetime): + dumped_value = value.isoformat() + elif isinstance(value, Enum): + dumped_value = _dump_json_compatible_value(value.value) + elif isinstance(value, Conversation): + dumped_value = { + "conversation_id": value.conversation_id, + "conversation_type": value.__class__.__name__, + } + elif isinstance(value, Component): + dumped_value = { + "component_id": value.id, + "component_type": value.__class__.__name__, + } + elif isinstance(value, dict): + dumped_value = { + str(key): _dump_json_compatible_value(inner_value) for key, inner_value in value.items() + } + elif isinstance(value, (list, tuple, set)): + dumped_value = [_dump_json_compatible_value(inner_value) for inner_value in value] + else: + serialized_value = serialize_any_to_dict_or_stringify( + value, SerializationContext(root=value) + ) + if serialized_value is value: + dumped_value = stringify(value) + else: + dumped_value = _dump_json_compatible_value(serialized_value) + return dumped_value + + +def _dump_variable_value(variable_name: str, value: Any) -> Any: + try: + serialized_value = json.dumps(value, sort_keys=True, allow_nan=False) + except (TypeError, ValueError) as e: + raise TypeError( + f"Variable '{variable_name}' contains a non-JSON-serializable value of type {type(value).__name__}" + ) from e + return cast(Any, json.loads(serialized_value)) + + +def _dump_string_keyed_mapping(values: dict[Any, Any]) -> dict[str, Any]: + return {str(key): _dump_json_compatible_value(value) for key, value in values.items()} + + +def _dump_flow_input_output_key_values(values: dict[Any, Any]) -> dict[str, Any]: + return { + f"{step_name}.{value_name}": _dump_json_compatible_value(value) + for (step_name, value_name), value in values.items() + } + + +def _dump_variable_store(variable_store: dict[str, Any]) -> dict[str, Any]: + return { + variable_name: _dump_variable_value(variable_name, 
variable_value) + for variable_name, variable_value in variable_store.items() + } + + +def _dump_tool_request(tool_request: Optional[ToolRequest]) -> Optional[dict[str, Any]]: + if tool_request is None: + dumped_tool_request = None + else: + dumped_tool_request = { + "name": tool_request.name, + "tool_request_id": tool_request.tool_request_id, + "args": _dump_json_compatible_value(tool_request.args), + "requires_confirmation": tool_request._requires_confirmation, + "tool_execution_confirmed": tool_request._tool_execution_confirmed, + "tool_rejection_reason": tool_request._tool_rejection_reason, + } + return dumped_tool_request + + +def _dump_tool_result(tool_result: Optional[ToolResult]) -> Optional[dict[str, Any]]: + if tool_result is None: + dumped_tool_result = None + else: + dumped_tool_result = { + "tool_request_id": tool_result.tool_request_id, + "content": _dump_json_compatible_value(tool_result.content), + } + return dumped_tool_result + + +def _dump_message_content(content: MessageContent) -> dict[str, Any]: + content_type = getattr(content, "type", content.__class__.__name__) + + if isinstance(content, TextContent): + dumped_content = { + "type": content.type, + "content": content.content, + } + elif isinstance(content, ImageContent): + dumped_content = { + "type": content.type, + "base64_content": content.base64_content, + } + else: + serialized_content = _dump_json_compatible_value(content) + if isinstance(serialized_content, dict): + if "type" in serialized_content: + dumped_content = serialized_content + else: + dumped_content = {"type": content_type, **serialized_content} + else: + dumped_content = { + "type": content_type, + "content": serialized_content, + } + return dumped_content + + +def _dump_message(message: Message) -> dict[str, Any]: + dumped_message: dict[str, Any] = { + "role": message.role, + "message_type": message.message_type.value if message.message_type else None, + "sender": message.sender, + "recipients": 
sorted(message.recipients), + "time_created": message.time_created.isoformat(), + "time_updated": message.time_updated.isoformat(), + "content": message.content, + "contents": [_dump_message_content(content) for content in message.contents], + } + + tool_requests = [ + _dump_tool_request(tool_request) for tool_request in message.tool_requests or [] + ] + dumped_tool_result = _dump_tool_result(message.tool_result) + + if tool_requests: + dumped_message["tool_requests"] = tool_requests + if dumped_tool_result is not None: + dumped_message["tool_result"] = dumped_tool_result + return dumped_message + + +def _dump_execution_status(execution_status: Optional[ExecutionStatus]) -> Optional[dict[str, Any]]: + dumped_status: dict[str, Any] | None + if execution_status is None: + dumped_status = None + else: + dumped_status = { + "type": execution_status.__class__.__name__, + "conversation_id": execution_status._conversation_id, + } + + if isinstance(execution_status, FinishedStatus): + dumped_status["output_values"] = _dump_json_compatible_value( + execution_status.output_values + ) + dumped_status["complete_step_name"] = execution_status.complete_step_name + elif isinstance(execution_status, UserMessageRequestStatus): + dumped_status["message"] = _dump_message(execution_status.message) + elif isinstance(execution_status, ToolRequestStatus): + dumped_status["tool_requests"] = [ + _dump_tool_request(tool_request) for tool_request in execution_status.tool_requests + ] + dumped_status["tool_results"] = [ + _dump_tool_result(tool_result) + for tool_result in execution_status._tool_results or [] + ] + elif isinstance(execution_status, ToolExecutionConfirmationStatus): + dumped_status["tool_requests"] = [ + _dump_tool_request(tool_request) for tool_request in execution_status.tool_requests + ] + elif isinstance(execution_status, AuthChallengeRequestStatus): + dumped_status["client_transport_id"] = execution_status.client_transport_id + return dumped_status + + +def 
_dump_conversation_info(conversation: "Conversation") -> dict[str, Any]: + return { + "conversation_id": conversation.conversation_id, + "conversation_type": conversation.__class__.__name__, + "component_type": conversation.component.__class__.__name__, + "name": conversation.name, + "inputs": _dump_json_compatible_value(conversation.inputs), + "messages": [_dump_message(message) for message in conversation.get_messages()], + } + + +def _dump_execution_info(conversation: "Conversation") -> dict[str, Any]: + return { + "current_step_name": conversation.current_step_name, + "status": _dump_execution_status(conversation.status), + "status_handled": conversation.status_handled, + "token_usage": _dump_json_compatible_value(conversation.token_usage), + } + + +def _dump_flow_execution_info(conversation: "FlowConversation") -> dict[str, Any]: + return { + "step_history": list(conversation.state.step_history), + "nesting_level": conversation.state.nesting_level, + "input_output_key_values": _dump_flow_input_output_key_values( + conversation.state.input_output_key_values + ), + "flow_output_values": _dump_json_compatible_value( + conversation.state._flow_output_value_dict + ), + "context_key_values": _dump_json_compatible_value(conversation.state.context_key_values), + "internal_context_key_values": _dump_string_keyed_mapping( + conversation.state.internal_context_key_values + ), + } + + +def _dump_agent_execution_info(conversation: "AgentConversation") -> dict[str, Any]: + return { + "curr_iter": conversation.state.curr_iter, + "has_confirmed_conversation_exit": conversation.state.has_confirmed_conversation_exit, + "tool_call_queue": [ + _dump_tool_request(tool_request) for tool_request in conversation.state.tool_call_queue + ], + "current_tool_request": _dump_tool_request(conversation.state.current_tool_request), + "current_flow_conversation": _dump_json_compatible_value( + conversation.state.current_flow_conversation + ), + "current_sub_component_conversations": 
_dump_string_keyed_mapping( + conversation.state.current_sub_component_conversations + ), + } + + +def dump_conversation_state(conversation: "Conversation") -> dict[str, Any]: + from wayflowcore.executors._agentconversation import AgentConversation + from wayflowcore.executors._flowconversation import FlowConversation + + if isinstance(conversation, FlowConversation): + execution_info = { + **_dump_execution_info(conversation), + **_dump_flow_execution_info(conversation), + } + elif isinstance(conversation, AgentConversation): + execution_info = { + **_dump_execution_info(conversation), + **_dump_agent_execution_info(conversation), + } + else: + execution_info = _dump_execution_info(conversation) + + return { + "conversation": _dump_conversation_info(conversation), + "execution": execution_info, + } + + +def serialize_conversation_state(conversation: "Conversation") -> str: + return json.dumps(dump_conversation_state(conversation), sort_keys=True) + + +def deserialize_conversation_state(state: str) -> dict[str, Any]: + return cast(dict[str, Any], json.loads(state)) + + +def dump_variable_state(conversation: "Conversation") -> Optional[dict[str, Any]]: + from wayflowcore.executors._flowconversation import FlowConversation + + if not isinstance(conversation, FlowConversation): + variable_state = None + else: + variable_state = _dump_variable_store(conversation.state.variable_store) + return variable_state + + +__all__ = [ + "deserialize_conversation_state", + "dump_conversation_state", + "dump_variable_state", + "serialize_conversation_state", +] diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py new file mode 100644 index 000000000..17ec4b7e2 --- /dev/null +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py @@ -0,0 +1,694 @@ +# Copyright © 2025 Oracle and/or its affiliates. 
+# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from dataclasses import dataclass +from typing import Any, cast + +import pytest +from pyagentspec.adapters.wayflow import AgentSpecLoader +from pyagentspec.agent import Agent as AgentSpecAgent +from pyagentspec.llms import VllmConfig +from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd +from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart +from pyagentspec.tracing.events import Event as AgentSpecEvent +from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd +from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart +from pyagentspec.tracing.events import ( + ManagerWorkersExecutionEnd as AgentSpecManagerWorkersExecutionEnd, +) +from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted +from pyagentspec.tracing.events import SwarmExecutionEnd as AgentSpecSwarmExecutionEnd +from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor +from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan +from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan +from pyagentspec.tracing.spans import ( + ManagerWorkersExecutionSpan as AgentSpecManagerWorkersExecutionSpan, +) +from pyagentspec.tracing.spans import Span as AgentSpecSpan +from pyagentspec.tracing.spans import SwarmExecutionSpan as AgentSpecSwarmExecutionSpan +from pyagentspec.tracing.spans import ToolExecutionSpan as AgentSpecToolExecutionSpan +from pyagentspec.tracing.trace import Trace as AgentSpecTrace + +from wayflowcore import Agent as WayflowAgent +from wayflowcore.agentspec.tracing import AgentSpecEventListener +from 
wayflowcore.events.eventlistener import register_event_listeners +from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus +from wayflowcore.executors.statesnapshotpolicy import ( + StateSnapshotInterval, + StateSnapshotPolicy, +) +from wayflowcore.flow import Flow +from wayflowcore.managerworkers import ManagerWorkers +from wayflowcore.messagelist import Message, MessageType +from wayflowcore.models.vllmmodel import VllmModel +from wayflowcore.serialization import dump_conversation_state +from wayflowcore.steps import AgentExecutionStep, CompleteStep, OutputMessageStep, ToolExecutionStep +from wayflowcore.swarm import Swarm +from wayflowcore.tools import ServerTool, ToolRequest + +from ..testhelpers.patching import patch_llm + +pytestmark = pytest.mark.skipif( + AgentSpecStateSnapshotEmitted is None, + reason="Installed pyagentspec does not expose StateSnapshotEmitted", +) + + +class SnapshotSpanRecorder(AgentSpecSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.started_spans: list[AgentSpecSpan] = [] + self.ended_spans: list[AgentSpecSpan] = [] + + def on_start(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + async def on_start_async(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + def on_end(self, span: AgentSpecSpan) -> None: + self.ended_spans.append(span) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + self.ended_spans.append(span) + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + +@dataclass(frozen=True) +class ExportedAGUIStateSnapshot: + conversation_id: str + snapshot: dict[str, Any] + + 
+class AGUIStateSnapshotExporter(AgentSpecSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.exported_snapshots: list[ExportedAGUIStateSnapshot] = [] + + def on_start(self, span: AgentSpecSpan) -> None: + return None + + async def on_start_async(self, span: AgentSpecSpan) -> None: + return None + + def on_end(self, span: AgentSpecSpan) -> None: + return None + + async def on_end_async(self, span: AgentSpecSpan) -> None: + return None + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + if not isinstance(event, AgentSpecStateSnapshotEmitted): + return + + conversation_snapshot = (event.state_snapshot or {}).get("conversation", {}) + self.exported_snapshots.append( + ExportedAGUIStateSnapshot( + conversation_id=event.conversation_id, + snapshot={ + "messages": conversation_snapshot.get("messages", []), + "input": conversation_snapshot.get("inputs", {}).get("input"), + "agent_state": (event.extra_state or {}).get("agent_state", {}), + }, + ) + ) + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + self.on_event(event, span) + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + +_RETRIEVAL_INPUTS = { + "input": "How many orders last week?", + "thread_id": "thread-123", + "agent_type": "planner", + "llm_model_name": "gpt-5-mini", + "default_schema": "sales", + "input_document": "Only use the sales schema and weekly order metrics.", +} + +_RETRIEVAL_UI_STATE = { + "preplan": { + "summary": "Inspect weekly sales orders and answer concisely.", + "entries": [ + "Inspect the active schema", + "Aggregate last week's orders", + "Return the final answer", + ], + "ready_to_proceed": True, + }, + "assumptions": [ + {"text": "Use the sales schema only", "status": "approved"}, + {"text": "Week boundaries follow UTC", "status": "auto_approved"}, 
+ ], +} + + +def _create_retrieval_like_wayflow_agent() -> WayflowAgent: + agentspec_agent = AgentSpecAgent( + name="retrieval_agent", + llm_config=VllmConfig(name="llm", url="http://mock.url", model_id="mock.model"), + system_prompt="You are a helpful retrieval agent.", + ) + return cast(WayflowAgent, AgentSpecLoader().load_component(agentspec_agent)) + + +def _build_retrieval_agent_state( + *, + conversation_inputs: dict[str, Any], + message_count: int, + last_response: str, +) -> dict[str, Any]: + return { + "thread_id": conversation_inputs["thread_id"], + "agent_type": conversation_inputs["agent_type"], + "llm_model_name": conversation_inputs["llm_model_name"], + "default_schema": conversation_inputs["default_schema"], + "input_document": conversation_inputs["input_document"], + "message_count": message_count, + "last_response": last_response, + "ui": _RETRIEVAL_UI_STATE, + } + + +def _build_retrieval_like_extra_state(conversation) -> dict[str, Any]: + conversation_snapshot = dump_conversation_state(conversation)["conversation"] + messages = conversation_snapshot["messages"] + last_response = next( + ( + message.get("content") + for message in reversed(messages) + if message.get("role") == "assistant" and message.get("content") + ), + "", + ) + return { + "agent_state": _build_retrieval_agent_state( + conversation_inputs=conversation.inputs, + message_count=len(messages), + last_response=last_response, + ) + } + + +def _create_mock_vllm_model(name: str) -> VllmModel: + return VllmModel(model_id="mock.model", host_port="http://mock.url", name=name) + + +def _create_send_message_request(recipient_name: str, message: str) -> Message: + return Message( + content="", + message_type=MessageType.TOOL_REQUEST, + tool_requests=[ + ToolRequest( + name="send_message", + args={"recipient": recipient_name, "message": message}, + ) + ], + ) + + +def _build_managerworkers_state_snapshot_flow() -> tuple[ + Flow, + VllmModel, + list[Message | str], + VllmModel, + list[Message | 
str], + type[AgentSpecSpan], + str, + str, + type[AgentSpecEvent], +]: + manager_llm = _create_mock_vllm_model("manager") + worker_llm = _create_mock_vllm_model("worker") + worker = WayflowAgent(llm=worker_llm, name="worker", description="worker") + managerworkers = ManagerWorkers(group_manager=manager_llm, workers=[worker], name="team") + + return ( + Flow.from_steps([AgentExecutionStep(agent=managerworkers), CompleteStep(name="end")]), + manager_llm, + [_create_send_message_request("worker", "Do it"), "manager final answer"], + worker_llm, + ["worker answer"], + AgentSpecManagerWorkersExecutionSpan, + "worker answer", + "manager final answer", + AgentSpecManagerWorkersExecutionEnd, + ) + + +def _build_swarm_state_snapshot_flow() -> tuple[ + Flow, + VllmModel, + list[Message | str], + VllmModel, + list[Message | str], + type[AgentSpecSpan], + str, + str, + type[AgentSpecEvent], +]: + first_agent_llm = _create_mock_vllm_model("agent1") + second_agent_llm = _create_mock_vllm_model("agent2") + first_agent = WayflowAgent(llm=first_agent_llm, name="agent1", description="agent1") + second_agent = WayflowAgent(llm=second_agent_llm, name="agent2", description="agent2") + swarm = Swarm( + first_agent=first_agent, + relationships=[(first_agent, second_agent), (second_agent, first_agent)], + name="swarm", + ) + + return ( + Flow.from_steps([AgentExecutionStep(agent=swarm), CompleteStep(name="end")]), + first_agent_llm, + [_create_send_message_request("agent2", "Do it"), "agent1 final answer"], + second_agent_llm, + ["agent2 answer"], + AgentSpecSwarmExecutionSpan, + "agent2 answer", + "agent1 final answer", + AgentSpecSwarmExecutionEnd, + ) + + +def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> None: + flow = Flow.from_steps( + [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], + step_names=["single_step", "end"], + ) + conversation = flow.start_conversation() + listener = AgentSpecEventListener() + span_recorder = 
SnapshotSpanRecorder() + + with AgentSpecTrace(span_processors=[span_recorder]): + with register_event_listeners([listener]): + status = conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, + ) + ) + + assert isinstance(status, FinishedStatus) + + flow_spans = [ + span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) + ] + assert len(flow_spans) == 1 + flow_span = flow_spans[0] + flow_events = flow_span.events + + assert any(isinstance(event, AgentSpecFlowExecutionStart) for event in flow_events) + flow_end_event = next( + event for event in flow_events if isinstance(event, AgentSpecFlowExecutionEnd) + ) + state_snapshot_events = [ + event for event in flow_events if isinstance(event, AgentSpecStateSnapshotEmitted) + ] + + # From an Agent Spec consumer point of view, the flow span should expose the + # opening and closing checkpoints, and the closing checkpoint must still + # appear before the terminal flow-end event. 
+ assert len(state_snapshot_events) == 2 + final_snapshot_event = state_snapshot_events[-1] + assert final_snapshot_event.conversation_id == conversation.conversation_id + assert final_snapshot_event.state_snapshot["conversation"]["messages"][-1]["content"] == "Hello" + assert final_snapshot_event.extra_state == {"ui": {"active_tab": "plan"}} + assert flow_events.index(final_snapshot_event) < flow_events.index(flow_end_event) + assert flow_span.end_time is not None + assert final_snapshot_event.timestamp <= flow_span.end_time + assert "variable_state" not in final_snapshot_event.model_dump(mask_sensitive_information=False) + assert flow_span in span_recorder.ended_spans + + +def test_off_policy_does_not_bridge_state_snapshots_into_agent_spec_spans() -> None: + flow = Flow.from_steps( + [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], + step_names=["single_step", "end"], + ) + conversation = flow.start_conversation() + listener = AgentSpecEventListener() + span_recorder = SnapshotSpanRecorder() + + with AgentSpecTrace(span_processors=[span_recorder]): + with register_event_listeners([listener]): + status = conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.OFF + ) + ) + + assert isinstance(status, FinishedStatus) + + flow_spans = [ + span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) + ] + assert len(flow_spans) == 1 + flow_events = flow_spans[0].events + + assert any(isinstance(event, AgentSpecFlowExecutionStart) for event in flow_events) + assert any(isinstance(event, AgentSpecFlowExecutionEnd) for event in flow_events) + assert not any(isinstance(event, AgentSpecStateSnapshotEmitted) for event in flow_events) + + +def test_agent_state_snapshots_support_the_agui_retrieval_export_flow() -> None: + assistant_message = "I checked the warehouse and found 42 orders last week." 
+ wayflow_agent = _create_retrieval_like_wayflow_agent() + conversation = wayflow_agent.start_conversation(inputs=_RETRIEVAL_INPUTS) + conversation.append_user_message(_RETRIEVAL_INPUTS["input"]) + + listener = AgentSpecEventListener() + span_recorder = SnapshotSpanRecorder() + agui_exporter = AGUIStateSnapshotExporter() + + with patch_llm(wayflow_agent.llm, [assistant_message], patch_internal=True): + with AgentSpecTrace(span_processors=[span_recorder, agui_exporter]): + with register_event_listeners([listener]): + status = conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=_build_retrieval_like_extra_state, + ) + ) + + assert isinstance(status, UserMessageRequestStatus) + + agent_spans = [ + span + for span in span_recorder.started_spans + if isinstance(span, AgentSpecAgentExecutionSpan) + ] + assert len(agent_spans) == 1 + agent_span = agent_spans[0] + agent_events = agent_span.events + + assert any(isinstance(event, AgentSpecAgentExecutionStart) for event in agent_events) + agent_end_event = next( + event for event in agent_events if isinstance(event, AgentSpecAgentExecutionEnd) + ) + state_snapshot_events = [ + event for event in agent_events if isinstance(event, AgentSpecStateSnapshotEmitted) + ] + assert len(state_snapshot_events) == 2 + + final_snapshot_event = state_snapshot_events[-1] + runtime_messages = final_snapshot_event.state_snapshot["conversation"]["messages"] + expected_agent_state = _build_retrieval_agent_state( + conversation_inputs=_RETRIEVAL_INPUTS, + message_count=len(runtime_messages), + last_response=assistant_message, + ) + + # This retrieval example is the main product use-case: a downstream AG-UI + # style exporter should be able to reconstruct the latest UI-facing state + # directly from the final snapshot event on the agent execution span. 
+ assert final_snapshot_event.conversation_id == conversation.conversation_id + assert ( + final_snapshot_event.state_snapshot["conversation"]["inputs"]["input"] + == _RETRIEVAL_INPUTS["input"] + ) + assert runtime_messages[-1]["content"] == assistant_message + assert final_snapshot_event.extra_state == {"agent_state": expected_agent_state} + assert agent_events.index(final_snapshot_event) < agent_events.index(agent_end_event) + + assert len(agui_exporter.exported_snapshots) == 2 + assert agui_exporter.exported_snapshots[-1] == ExportedAGUIStateSnapshot( + conversation_id=conversation.conversation_id, + snapshot={ + "messages": runtime_messages, + "input": _RETRIEVAL_INPUTS["input"], + "agent_state": expected_agent_state, + }, + ) + + +@pytest.mark.parametrize( + "flow_builder", + [ + pytest.param(_build_managerworkers_state_snapshot_flow, id="managerworkers"), + pytest.param(_build_swarm_state_snapshot_flow, id="swarm"), + ], +) +def test_nested_multi_agent_state_snapshots_follow_conversation_ownership_boundaries( + flow_builder, +) -> None: + ( + flow, + primary_llm, + primary_outputs, + secondary_llm, + secondary_outputs, + expected_multi_agent_span_class, + expected_child_message, + expected_parent_message, + expected_multi_agent_end_event_class, + ) = flow_builder() + conversation = flow.start_conversation() + conversation.append_user_message("dummy") + + listener = AgentSpecEventListener() + span_recorder = SnapshotSpanRecorder() + + with patch_llm(primary_llm, primary_outputs, patch_internal=True): + with patch_llm(secondary_llm, secondary_outputs, patch_internal=True): + with AgentSpecTrace(span_processors=[span_recorder]): + with register_event_listeners([listener]): + status = conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) + ) + + assert isinstance(status, UserMessageRequestStatus) + + flow_spans = [ + span for span in span_recorder.started_spans if isinstance(span, 
AgentSpecFlowExecutionSpan) + ] + assert len(flow_spans) == 1 + flow_span = flow_spans[0] + flow_snapshot_events = [ + event for event in flow_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) + ] + assert len(flow_snapshot_events) == 2 + assert [event.conversation_id for event in flow_snapshot_events] == [ + conversation.conversation_id, + conversation.conversation_id, + ] + assert ( + flow_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] + == expected_parent_message + ) + + multi_agent_spans = [ + span + for span in span_recorder.started_spans + if isinstance(span, expected_multi_agent_span_class) + ] + assert len(multi_agent_spans) == 1 + multi_agent_span = multi_agent_spans[0] + multi_agent_snapshot_events = [ + event + for event in multi_agent_span.events + if isinstance(event, AgentSpecStateSnapshotEmitted) + ] + multi_agent_end_event = next( + event + for event in multi_agent_span.events + if isinstance(event, expected_multi_agent_end_event_class) + ) + parent_multi_agent_conversation_id = multi_agent_snapshot_events[0].conversation_id + + # The parent multi-agent conversation brackets both child turns. It keeps a + # single conversation id while the manager/main-thread agent and the + # delegated child each emit snapshots on their own agent execution spans. 
+ assert [event.conversation_id for event in multi_agent_snapshot_events] == [ + parent_multi_agent_conversation_id, + parent_multi_agent_conversation_id, + parent_multi_agent_conversation_id, + parent_multi_agent_conversation_id, + parent_multi_agent_conversation_id, + parent_multi_agent_conversation_id, + ] + assert [ + ( + event.state_snapshot["execution"]["status"]["type"] + if event.state_snapshot["execution"]["status"] is not None + else None + ) + for event in multi_agent_snapshot_events + ] == [ + None, + "ToolRequestStatus", + None, + "UserMessageRequestStatus", + None, + "UserMessageRequestStatus", + ] + assert ( + multi_agent_snapshot_events[4].state_snapshot["conversation"]["messages"][-1]["content"] + == expected_child_message + ) + assert ( + multi_agent_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] + == expected_parent_message + ) + assert multi_agent_span.events.index( + multi_agent_snapshot_events[-1] + ) < multi_agent_span.events.index(multi_agent_end_event) + + agent_snapshot_spans = [ + span + for span in span_recorder.started_spans + if isinstance(span, AgentSpecAgentExecutionSpan) + and any(isinstance(event, AgentSpecStateSnapshotEmitted) for event in span.events) + ] + assert len(agent_snapshot_spans) == 3 + agent_snapshot_events_by_conversation_id: dict[str, list[AgentSpecStateSnapshotEmitted]] = {} + for agent_span in agent_snapshot_spans: + snapshot_events = [ + event for event in agent_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) + ] + agent_snapshot_events_by_conversation_id.setdefault( + snapshot_events[0].conversation_id, + [], + ).extend(snapshot_events) + + assert len(agent_snapshot_events_by_conversation_id) == 2 + manager_thread_snapshot_events = next( + snapshot_events + for snapshot_events in agent_snapshot_events_by_conversation_id.values() + if len(snapshot_events) == 4 + ) + delegated_agent_snapshot_events = next( + snapshot_events + for snapshot_events in 
agent_snapshot_events_by_conversation_id.values() + if len(snapshot_events) == 2 + ) + + assert manager_thread_snapshot_events[0].conversation_id != conversation.conversation_id + assert manager_thread_snapshot_events[0].conversation_id != parent_multi_agent_conversation_id + assert delegated_agent_snapshot_events[0].conversation_id not in { + conversation.conversation_id, + parent_multi_agent_conversation_id, + manager_thread_snapshot_events[0].conversation_id, + } + assert [ + ( + event.state_snapshot["execution"]["status"]["type"] + if event.state_snapshot["execution"]["status"] is not None + else None + ) + for event in manager_thread_snapshot_events + ] == [ + None, + "ToolRequestStatus", + None, + "UserMessageRequestStatus", + ] + assert ( + manager_thread_snapshot_events[2].state_snapshot["conversation"]["messages"][-1]["content"] + == expected_child_message + ) + assert ( + manager_thread_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] + == expected_parent_message + ) + assert [ + ( + event.state_snapshot["execution"]["status"]["type"] + if event.state_snapshot["execution"]["status"] is not None + else None + ) + for event in delegated_agent_snapshot_events + ] == [None, "UserMessageRequestStatus"] + assert ( + delegated_agent_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1][ + "content" + ] + == expected_child_message + ) + + tool_spans = [ + span for span in span_recorder.started_spans if isinstance(span, AgentSpecToolExecutionSpan) + ] + assert tool_spans + assert not any( + isinstance(event, AgentSpecStateSnapshotEmitted) + for span in tool_spans + for event in span.events + ) + + assert flow_span in span_recorder.ended_spans + assert multi_agent_span in span_recorder.ended_spans + + +def test_only_the_opening_state_snapshot_is_exported_when_a_turn_raises() -> None: + flow = Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name="explode", + description="Raise an error", + func=lambda: (_ for 
_ in ()).throw(RuntimeError("boom")), + input_descriptors=[], + ) + ), + CompleteStep(name="end"), + ] + ) + conversation = flow.start_conversation() + listener = AgentSpecEventListener() + span_recorder = SnapshotSpanRecorder() + + with AgentSpecTrace(span_processors=[span_recorder]): + with register_event_listeners([listener]): + with pytest.raises(RuntimeError, match="boom"): + conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) + ) + + flow_spans = [ + span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) + ] + assert len(flow_spans) == 1 + flow_span = flow_spans[0] + state_snapshot_events = [ + event for event in flow_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) + ] + + assert len(state_snapshot_events) == 1 + assert state_snapshot_events[0].state_snapshot["execution"]["status"] is None + assert flow_span in span_recorder.ended_spans diff --git a/wayflowcore/tests/events/test_state_snapshot_event.py b/wayflowcore/tests/events/test_state_snapshot_event.py new file mode 100644 index 000000000..ee3970e48 --- /dev/null +++ b/wayflowcore/tests/events/test_state_snapshot_event.py @@ -0,0 +1,46 @@ +# Copyright © 2025 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +import pytest + +from wayflowcore.events.event import _PII_TEXT_MASK, StateSnapshotEvent + + +def test_state_snapshot_event_requires_conversation_id() -> None: + with pytest.raises(ValueError, match="conversation_id"): + StateSnapshotEvent() + + +@pytest.mark.parametrize("mask_sensitive_information", [True, False]) +def test_state_snapshot_event_serialization( + mask_sensitive_information: bool, +) -> None: + event = StateSnapshotEvent( + conversation_id="conversation-123", + state_snapshot={"conversation": {"messages": []}}, + extra_state={"ui": {"active_tab": "plan"}}, + variable_state={"count": 2}, + name="snapshot", + event_id="evt-1", + timestamp=12, + ) + + serialized_event = event.to_tracing_info(mask_sensitive_information=mask_sensitive_information) + + assert serialized_event["event_type"] == "StateSnapshotEvent" + assert serialized_event["conversation_id"] == "conversation-123" + assert serialized_event["name"] == "snapshot" + assert serialized_event["event_id"] == "evt-1" + assert serialized_event["timestamp"] == 12 + + if mask_sensitive_information: + assert serialized_event["state_snapshot"] == _PII_TEXT_MASK + assert serialized_event["extra_state"] == _PII_TEXT_MASK + assert serialized_event["variable_state"] == _PII_TEXT_MASK + else: + assert serialized_event["state_snapshot"] == {"conversation": {"messages": []}} + assert serialized_event["extra_state"] == {"ui": {"active_tab": "plan"}} + assert serialized_event["variable_state"] == {"count": 2} diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime.py b/wayflowcore/tests/events/test_state_snapshot_runtime.py new file mode 100644 index 000000000..f056a0e93 --- /dev/null +++ b/wayflowcore/tests/events/test_state_snapshot_runtime.py @@ -0,0 +1,754 @@ +# Copyright © 2025 Oracle and/or its affiliates. 
+# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +import threading +from contextlib import nullcontext +from typing import Sequence + +import pytest + +from wayflowcore.agent import Agent +from wayflowcore.conversation import Conversation +from wayflowcore.events.event import Event, StateSnapshotEvent +from wayflowcore.events.eventlistener import EventListener, register_event_listeners +from wayflowcore.executors._events.event import EventType +from wayflowcore.executors._executionstate import ConversationExecutionState +from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus +from wayflowcore.executors.interrupts.executioninterrupt import ( + ExecutionInterrupt, + InterruptedExecutionStatus, + _NullExecutionInterrupt, +) +from wayflowcore.executors.statesnapshotpolicy import ( + StateSnapshotInterval, + StateSnapshotPolicy, +) +from wayflowcore.flow import Flow +from wayflowcore.managerworkers import ManagerWorkers +from wayflowcore.messagelist import Message, MessageType +from wayflowcore.serialization.serializer import SerializableNeedToBeImplementedMixin +from wayflowcore.steps import ( + AgentExecutionStep, + CompleteStep, + FlowExecutionStep, + OutputMessageStep, + ToolExecutionStep, +) +from wayflowcore.swarm import Swarm +from wayflowcore.tools import ServerTool, ToolRequest, tool + +from ..conftest import disable_streaming +from ..test_interrupts import OnEventExecutionInterrupt +from ..testhelpers.dummy import DummyModel + +# Runtime snapshot tests stay focused on emission semantics. Event payload +# mapping and serialization details live in dedicated tracing/serialization +# suites. 
+ + +class SnapshotCollector(EventListener): + def __init__(self) -> None: + self.state_snapshot_events: list[StateSnapshotEvent] = [] + + def __call__(self, event: Event) -> None: + if isinstance(event, StateSnapshotEvent): + self.state_snapshot_events.append(event) + + +class MutatingExecutionEndInterrupt(SerializableNeedToBeImplementedMixin, _NullExecutionInterrupt): + def __init__(self) -> None: + self.lock = threading.Lock() + self.count = 0 + super().__init__() + + def _on_execution_end( + self, + state: ConversationExecutionState, + conversation: Conversation, + ) -> InterruptedExecutionStatus | None: + conversation.inputs["preview_count"] = conversation.inputs.get("preview_count", 0) + 1 + self.count += 1 + return None + + +def _create_output_flow_conversation(message: str = "Hello") -> Conversation: + flow = Flow.from_steps( + [ + OutputMessageStep(message_template=message), + CompleteStep(name="end"), + ] + ) + return flow.start_conversation() + + +def _create_agent_conversation(message: str = "Hello from agent") -> Conversation: + llm = DummyModel() + llm.set_next_output(message) + conversation = Agent(llm=llm).start_conversation() + conversation.append_user_message("Hi") + return conversation + + +def _create_tool_calling_agent_conversation() -> Conversation: + @tool + def do_nothing_tool() -> str: + """Do nothing tool.""" + return "Tool called successfully" + + llm = DummyModel() + llm.set_next_output( + { + "Please use the do_nothing_tool": Message( + message_type=MessageType.TOOL_REQUEST, + content="I am calling the do nothing tool", + tool_requests=[ToolRequest("do_nothing_tool", {}, "tc1")], + ) + } + ) + conversation = Agent(llm=llm, tools=[do_nothing_tool], max_iterations=10).start_conversation() + conversation.append_user_message("Please use the do_nothing_tool") + return conversation + + +def _create_send_message_request(recipient_name: str, message: str) -> Message: + return Message( + content="", + message_type=MessageType.TOOL_REQUEST, + 
tool_requests=[ + ToolRequest( + name="send_message", + args={"recipient": recipient_name, "message": message}, + ) + ], + ) + + +def _create_nested_agent_step_flow_conversation() -> Conversation: + llm = DummyModel() + llm.set_next_output("agent answer") + child_agent = Agent(llm=llm) + conversation = Flow.from_steps( + [AgentExecutionStep(agent=child_agent), CompleteStep(name="end")] + ).start_conversation() + conversation.append_user_message("dummy") + return conversation + + +def _create_nested_managerworkers_flow_conversation() -> Conversation: + llm = DummyModel() + worker = Agent(llm=llm, name="worker", description="worker") + group = ManagerWorkers(group_manager=llm, workers=[worker]) + llm.set_next_output( + [ + _create_send_message_request("worker", "Do it"), + "worker answer", + "manager final answer", + ] + ) + + conversation = Flow.from_steps( + [AgentExecutionStep(agent=group), CompleteStep(name="end")] + ).start_conversation() + conversation.append_user_message("dummy") + return conversation + + +def _create_nested_swarm_flow_conversation() -> Conversation: + llm = DummyModel() + first_agent = Agent(llm=llm, name="agent1", description="agent1") + second_agent = Agent(llm=llm, name="agent2", description="agent2") + swarm = Swarm( + first_agent=first_agent, + relationships=[(first_agent, second_agent), (second_agent, first_agent)], + ) + llm.set_next_output( + [ + _create_send_message_request("agent2", "Do it"), + "agent2 answer", + "agent1 final answer", + ] + ) + + conversation = Flow.from_steps( + [AgentExecutionStep(agent=swarm), CompleteStep(name="end")] + ).start_conversation() + conversation.append_user_message("dummy") + return conversation + + +def _snapshot_status_types(snapshot_events: Sequence[StateSnapshotEvent]) -> list[str | None]: + return [ + ( + snapshot_event.state_snapshot["execution"]["status"]["type"] + if snapshot_event.state_snapshot["execution"]["status"] is not None + else None + ) + for snapshot_event in snapshot_events + ] + 
+ +def _execute_with_state_snapshots( + conversation: Conversation, + *, + state_snapshot_policy: StateSnapshotPolicy, + execution_interrupts: Sequence[ExecutionInterrupt] | None = None, + use_disable_streaming: bool = False, +) -> tuple[object, list[StateSnapshotEvent]]: + collector = SnapshotCollector() + streaming_context = disable_streaming() if use_disable_streaming else nullcontext() + + with streaming_context: + with register_event_listeners([collector]): + status = conversation.execute( + execution_interrupts=execution_interrupts, + state_snapshot_policy=state_snapshot_policy, + ) + + return status, collector.state_snapshot_events + + +@pytest.mark.parametrize( + ( + "conversation_factory", + "expected_status_class", + "expected_status_type", + "expected_message", + ), + [ + pytest.param( + _create_output_flow_conversation, + FinishedStatus, + "FinishedStatus", + "Hello", + id="flow", + ), + pytest.param( + _create_agent_conversation, + UserMessageRequestStatus, + "UserMessageRequestStatus", + "Hello from agent", + id="agent", + ), + ], +) +def test_conversation_turn_policy_records_opening_and_closing_checkpoints( + conversation_factory, + expected_status_class, + expected_status_type: str, + expected_message: str, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, expected_status_class) + assert _snapshot_status_types(state_snapshot_events) == [None, expected_status_type] + assert ( + state_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] + == expected_message + ) + assert state_snapshot_events[-1].state_snapshot["execution"]["status_handled"] is False + + +@pytest.mark.parametrize( + ("conversation_factory", "expected_status_class"), + [ + pytest.param(_create_output_flow_conversation, 
FinishedStatus, id="flow"), + pytest.param(_create_agent_conversation, UserMessageRequestStatus, id="agent"), + ], +) +def test_off_policy_disables_state_snapshot_emission( + conversation_factory, + expected_status_class, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.OFF + ), + ) + + assert isinstance(status, expected_status_class) + assert state_snapshot_events == [] + + +@pytest.mark.parametrize( + ("conversation_factory", "expected_message"), + [ + pytest.param(_create_output_flow_conversation, "Hello", id="flow"), + pytest.param(_create_agent_conversation, "Hello from agent", id="agent"), + ], +) +def test_conversation_turn_policy_records_interrupted_turn_end_checkpoints( + conversation_factory, + expected_message: str, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + execution_interrupts=[OnEventExecutionInterrupt(EventType.EXECUTION_END)], + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, InterruptedExecutionStatus) + assert _snapshot_status_types(state_snapshot_events) == [None, "InterruptedExecutionStatus"] + assert ( + state_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] + == expected_message + ) + assert state_snapshot_events[-1].state_snapshot["execution"]["status_handled"] is False + + +def test_conversation_turn_policy_keeps_only_the_opening_checkpoint_when_turn_raises() -> None: + def explode() -> str: + raise RuntimeError("boom") + + conversation = Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name="explode", + description="Raise an error", + func=explode, + input_descriptors=[], + ) + ), + CompleteStep(name="end"), + ] + 
).start_conversation() + collector = SnapshotCollector() + + with register_event_listeners([collector]): + with pytest.raises(RuntimeError, match="boom"): + conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) + ) + + assert len(collector.state_snapshot_events) == 1 + assert collector.state_snapshot_events[0].state_snapshot["execution"]["status"] is None + + +def test_conversation_turn_policy_reflects_real_interrupt_side_effects_once() -> None: + conversation = _create_output_flow_conversation() + interrupt = MutatingExecutionEndInterrupt() + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + execution_interrupts=[interrupt], + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert interrupt.count == 1 + assert conversation.inputs["preview_count"] == 1 + assert _snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + assert state_snapshot_events[-1].state_snapshot["conversation"]["inputs"]["preview_count"] == 1 + + +@pytest.mark.parametrize( + ( + "conversation_factory", + "expected_status_class", + "expected_status_types", + "expected_snapshot_count", + "expected_curr_iters", + ), + [ + pytest.param( + _create_output_flow_conversation, + FinishedStatus, + [None, None, None, None, None, None, "FinishedStatus"], + 7, + None, + id="flow", + ), + pytest.param( + _create_agent_conversation, + UserMessageRequestStatus, + [None, None, "UserMessageRequestStatus"], + 3, + [0, 1], + id="agent", + ), + ], +) +def test_node_turn_policy_tracks_flow_steps_and_agent_iterations( + conversation_factory, + expected_status_class, + expected_status_types: list[str | None], + expected_snapshot_count: int, + expected_curr_iters: list[int] | None, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = 
_execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + ) + + # NODE_TURNS means step start/end checkpoints for flows and iteration + # start/end checkpoints for agents, plus the final turn-end checkpoint. + # Flow.from_steps(...) also inserts an internal StartStep, so the flow case + # includes start/end checkpoints for that step too. + assert isinstance(status, expected_status_class) + assert len(state_snapshot_events) == expected_snapshot_count + assert _snapshot_status_types(state_snapshot_events) == expected_status_types + if expected_curr_iters is not None: + assert [ + state_snapshot_events[0].state_snapshot["execution"]["curr_iter"], + state_snapshot_events[1].state_snapshot["execution"]["curr_iter"], + ] == expected_curr_iters + + +@pytest.mark.parametrize( + ("conversation_factory", "interrupt_event"), + [ + pytest.param(_create_output_flow_conversation, EventType.STEP_EXECUTION_START, id="flow"), + pytest.param(_create_agent_conversation, EventType.GENERATION_START, id="agent"), + ], +) +def test_node_turn_policy_keeps_partial_progress_when_interrupted_mid_turn( + conversation_factory, + interrupt_event: EventType, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + execution_interrupts=[OnEventExecutionInterrupt(interrupt_event)], + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + ) + + assert isinstance(status, InterruptedExecutionStatus) + assert _snapshot_status_types(state_snapshot_events) == [None] + + +def test_internal_snapshots_do_not_reuse_the_previous_turn_status() -> None: + llm = DummyModel() + llm.set_next_output(["Hello from agent", "Hello again"]) + conversation = Agent(llm=llm).start_conversation() + conversation.append_user_message("Hi") + collector = SnapshotCollector() + policy = 
StateSnapshotPolicy(state_snapshot_interval=StateSnapshotInterval.NODE_TURNS) + + with register_event_listeners([collector]): + first_status = conversation.execute(state_snapshot_policy=policy) + assert isinstance(first_status, UserMessageRequestStatus) + + first_status.submit_user_response("Continue") + second_status = conversation.execute(state_snapshot_policy=policy) + + assert isinstance(second_status, UserMessageRequestStatus) + assert len(collector.state_snapshot_events) == 6 + + second_turn_internal_snapshots = collector.state_snapshot_events[3:5] + assert _snapshot_status_types(second_turn_internal_snapshots) == [None, None] + assert all( + snapshot_event.state_snapshot["execution"]["status_handled"] is False + for snapshot_event in second_turn_internal_snapshots + ) + + +def test_state_snapshot_policy_is_inherited_by_nested_sub_conversations() -> None: + child_flow = Flow.from_steps( + [ + OutputMessageStep(message_template="child"), + CompleteStep(name="end"), + ] + ) + parent_flow = Flow.from_steps( + [ + FlowExecutionStep(flow=child_flow), + CompleteStep(name="end"), + ] + ) + conversation = parent_flow.start_conversation() + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 4 + assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { + conversation.conversation_id + } + + +def test_state_snapshot_policy_is_inherited_by_nested_agent_steps() -> None: + conversation = _create_nested_agent_step_flow_conversation() + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + nested_conversation_id = state_snapshot_events[1].conversation_id + + # A 
parent flow keeps its own opening/closing checkpoints, while the nested + # agent contributes its own opening/closing pair under the child + # conversation id. + assert isinstance(status, UserMessageRequestStatus) + assert _snapshot_status_types(state_snapshot_events) == [ + None, + None, + "UserMessageRequestStatus", + "UserMessageRequestStatus", + ] + assert [snapshot_event.conversation_id for snapshot_event in state_snapshot_events] == [ + conversation.conversation_id, + nested_conversation_id, + nested_conversation_id, + conversation.conversation_id, + ] + assert nested_conversation_id != conversation.conversation_id + assert ( + state_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] + == "agent answer" + ) + + +@pytest.mark.parametrize( + ( + "conversation_factory", + "expected_multi_agent_component_type", + "expected_child_message", + "expected_parent_message", + ), + [ + pytest.param( + _create_nested_managerworkers_flow_conversation, + "ManagerWorkers", + "worker answer", + "manager final answer", + id="managerworkers", + ), + pytest.param( + _create_nested_swarm_flow_conversation, + "Swarm", + "agent2 answer", + "agent1 final answer", + id="swarm", + ), + ], +) +def test_nested_multi_agent_components_emit_snapshots_for_the_active_conversation( + conversation_factory, + expected_multi_agent_component_type: str, + expected_child_message: str, + expected_parent_message: str, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, UserMessageRequestStatus) + assert len(state_snapshot_events) == 14 + + snapshot_events_by_conversation_id: dict[str, list[StateSnapshotEvent]] = {} + for snapshot_event in state_snapshot_events: + snapshot_events_by_conversation_id.setdefault( + snapshot_event.conversation_id, + 
[], + ).append(snapshot_event) + + # A nested multi-agent turn has four independent snapshot streams: + # the outer flow, the parent multi-agent conversation, the manager/main + # thread agent conversation (which runs twice), and the delegated child. + assert len(snapshot_events_by_conversation_id) == 4 + + flow_snapshot_events = snapshot_events_by_conversation_id[conversation.conversation_id] + parent_multi_agent_snapshot_events = next( + snapshot_events + for snapshot_events in snapshot_events_by_conversation_id.values() + if snapshot_events[0].state_snapshot["conversation"]["component_type"] + == expected_multi_agent_component_type + ) + agent_snapshot_event_groups = [ + snapshot_events + for conversation_id, snapshot_events in snapshot_events_by_conversation_id.items() + if conversation_id + not in { + conversation.conversation_id, + parent_multi_agent_snapshot_events[0].conversation_id, + } + ] + manager_thread_snapshot_events = next( + snapshot_events + for snapshot_events in agent_snapshot_event_groups + if len(snapshot_events) == 4 + ) + delegated_agent_snapshot_events = next( + snapshot_events + for snapshot_events in agent_snapshot_event_groups + if len(snapshot_events) == 2 + ) + + assert _snapshot_status_types(flow_snapshot_events) == [None, "UserMessageRequestStatus"] + assert ( + flow_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] + == expected_parent_message + ) + + # The parent multi-agent conversation records checkpoints each time control + # enters or returns from a child turn, which is what lets a UI reconstruct + # the parent-level progress independently from the child conversations. 
+ assert _snapshot_status_types(parent_multi_agent_snapshot_events) == [ + None, + "ToolRequestStatus", + None, + "UserMessageRequestStatus", + None, + "UserMessageRequestStatus", + ] + assert ( + parent_multi_agent_snapshot_events[4].state_snapshot["conversation"]["messages"][-1][ + "content" + ] + == expected_child_message + ) + assert ( + parent_multi_agent_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1][ + "content" + ] + == expected_parent_message + ) + + # The manager/main thread agent conversation spans two execution turns: + # one that delegates and one that resumes after the child reply. + assert _snapshot_status_types(manager_thread_snapshot_events) == [ + None, + "ToolRequestStatus", + None, + "UserMessageRequestStatus", + ] + assert ( + manager_thread_snapshot_events[2].state_snapshot["conversation"]["messages"][-1]["content"] + == expected_child_message + ) + assert ( + manager_thread_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] + == expected_parent_message + ) + + assert _snapshot_status_types(delegated_agent_snapshot_events) == [ + None, + "UserMessageRequestStatus", + ] + assert ( + delegated_agent_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1][ + "content" + ] + == expected_child_message + ) + + +def test_state_snapshot_emission_survives_broken_extra_state_builder() -> None: + def broken_builder(_conversation: Conversation) -> dict[str, object]: + raise RuntimeError("boom") + + conversation = _create_output_flow_conversation() + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=broken_builder, + ), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 2 + assert all(snapshot_event.extra_state is None for snapshot_event in state_snapshot_events) + + +@pytest.mark.parametrize( + ( + 
"conversation_factory", + "execution_interrupts", + "use_disable_streaming", + "expected_status_class", + "expected_status_types", + ), + [ + pytest.param( + lambda: Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name="say_hi", + description="Say hi", + func=lambda: "hi", + input_descriptors=[], + ) + ), + CompleteStep(name="end"), + ] + ).start_conversation(), + None, + False, + FinishedStatus, + [None, None, "FinishedStatus"], + id="flow-success", + ), + pytest.param( + lambda: _create_tool_calling_agent_conversation(), + [OnEventExecutionInterrupt(EventType.TOOL_CALL_END)], + True, + InterruptedExecutionStatus, + [None, None], + id="agent-tool-end-interrupt", + ), + ], +) +def test_tool_turn_policy_records_real_tool_boundaries( + conversation_factory, + execution_interrupts: Sequence[ExecutionInterrupt] | None, + use_disable_streaming: bool, + expected_status_class, + expected_status_types: list[str | None], +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + execution_interrupts=execution_interrupts, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.TOOL_TURNS + ), + use_disable_streaming=use_disable_streaming, + ) + + assert isinstance(status, expected_status_class) + assert _snapshot_status_types(state_snapshot_events) == expected_status_types diff --git a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py new file mode 100644 index 000000000..168a848c3 --- /dev/null +++ b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py @@ -0,0 +1,92 @@ +# Copyright © 2025 Oracle and/or its affiliates. 
+# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +import json +from typing import Any + +import pytest + +from wayflowcore.conversation import Conversation +from wayflowcore.flow import Flow +from wayflowcore.property import AnyProperty, StringProperty +from wayflowcore.serialization import ( + deserialize_conversation_state, + dump_conversation_state, + dump_variable_state, + serialize_conversation_state, +) +from wayflowcore.steps import OutputMessageStep, VariableWriteStep +from wayflowcore.variable import Variable + + +class _UnserializableValue: + def __str__(self) -> str: + return "custom-value" + + +def _build_snapshot_flow(custom_variable: Variable) -> Flow: + return Flow.from_steps( + steps=[ + VariableWriteStep( + variable=custom_variable, + input_mapping={VariableWriteStep.VALUE: custom_variable.name}, + ), + OutputMessageStep(message_template="Hello there"), + ], + variables=[custom_variable], + name="snapshot_flow", + ) + + +def _walk_scalars(value: Any): + if isinstance(value, dict): + for inner_value in value.values(): + yield from _walk_scalars(inner_value) + return + if isinstance(value, list): + for inner_value in value: + yield from _walk_scalars(inner_value) + return + yield value + + +def test_dump_conversation_state_is_json_serializable_and_lightweight() -> None: + custom_variable = Variable( + name="custom", + type=StringProperty(), + description="Custom variable used for snapshot serialization tests", + ) + flow = _build_snapshot_flow(custom_variable) + conversation = flow.start_conversation(inputs={custom_variable.name: "custom-value"}) + conversation.execute() + + snapshot = dump_conversation_state(conversation) + variable_state = dump_variable_state(conversation) + serialized_snapshot = serialize_conversation_state(conversation) + + assert 
json.loads(json.dumps(snapshot)) == deserialize_conversation_state(serialized_snapshot) + assert variable_state == {"custom": "custom-value"} + assert snapshot["conversation"]["component_type"] == "Flow" + assert snapshot["conversation"]["messages"][-1]["content"] == "Hello there" + + assert all( + not isinstance(scalar, (Conversation, Flow, OutputMessageStep)) + for scalar in _walk_scalars(snapshot) + ) + + +def test_dump_variable_state_rejects_non_json_serializable_values() -> None: + custom_variable = Variable( + name="custom", + type=AnyProperty(), + description="Custom variable used for snapshot serialization tests", + ) + flow = _build_snapshot_flow(custom_variable) + conversation = flow.start_conversation(inputs={custom_variable.name: _UnserializableValue()}) + conversation.execute() + + with pytest.raises(TypeError, match="Variable 'custom' contains a non-JSON-serializable"): + dump_variable_state(conversation) From 58541326f3295bd9eb4e07c3bd690a709ba3ef40 Mon Sep 17 00:00:00 2001 From: Son Le Date: Mon, 16 Mar 2026 14:47:02 +0100 Subject: [PATCH 02/13] fix: avoid synthetic turn-end snapshots for interrupted parent multi-agent turns --- .../executors/_statesnapshot_eventlistener.py | 3 +- .../events/test_state_snapshot_runtime.py | 74 +++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py index 561e0e455..9b15de6f5 100644 --- a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -282,7 +282,8 @@ def _handle_post_interrupt_event_for_parent_multi_agent(self, event: Event) -> N ) | AgentExecutionFinishedEvent(execution_status=execution_status): self._record_turn_end_snapshot(execution_status) case ExceptionRaisedEvent(exception=ExecutionInterruptedException() as exception): - 
self._record_turn_end_snapshot(exception.execution_status) + if self._should_record_interrupted_turn_end_snapshot(): + self._record_turn_end_snapshot(exception.execution_status) def __call__(self, event: Event) -> None: if isinstance(event, StateSnapshotEvent): diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime.py b/wayflowcore/tests/events/test_state_snapshot_runtime.py index f056a0e93..df5650a56 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime.py @@ -74,6 +74,29 @@ def _on_execution_end( return None +class WorkerExecutionEndInterrupt(SerializableNeedToBeImplementedMixin, _NullExecutionInterrupt): + def __init__(self) -> None: + self.triggered = False + super().__init__() + + def _on_execution_end( + self, + state: ConversationExecutionState, + conversation: Conversation, + ) -> InterruptedExecutionStatus | None: + if self.triggered: + return None + if getattr(conversation.component, "name", None) != "worker": + return None + + self.triggered = True + return InterruptedExecutionStatus( + interrupter=self, + reason="worker execution end", + _conversation_id=conversation.id, + ) + + def _create_output_flow_conversation(message: str = "Hello") -> Conversation: flow = Flow.from_steps( [ @@ -156,6 +179,23 @@ def _create_nested_managerworkers_flow_conversation() -> Conversation: return conversation +def _create_managerworkers_conversation() -> Conversation: + llm = DummyModel() + worker = Agent(llm=llm, name="worker", description="worker") + group = ManagerWorkers(group_manager=llm, workers=[worker]) + llm.set_next_output( + [ + _create_send_message_request("worker", "Do it"), + "worker answer", + "manager final answer", + ] + ) + + conversation = group.start_conversation() + conversation.append_user_message("dummy") + return conversation + + def _create_nested_swarm_flow_conversation() -> Conversation: llm = DummyModel() first_agent = Agent(llm=llm, name="agent1", 
description="agent1") @@ -362,6 +402,40 @@ def test_conversation_turn_policy_reflects_real_interrupt_side_effects_once() -> assert state_snapshot_events[-1].state_snapshot["conversation"]["inputs"]["preview_count"] == 1 +def test_parent_multi_agent_does_not_emit_turn_end_snapshot_when_child_turn_is_interrupted() -> ( + None +): + conversation = _create_managerworkers_conversation() + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + execution_interrupts=[WorkerExecutionEndInterrupt()], + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + use_disable_streaming=True, + ) + + assert isinstance(status, InterruptedExecutionStatus) + + snapshot_events_by_conversation_id: dict[str, list[StateSnapshotEvent]] = {} + for snapshot_event in state_snapshot_events: + snapshot_events_by_conversation_id.setdefault( + snapshot_event.conversation_id, + [], + ).append(snapshot_event) + + parent_multi_agent_snapshot_events = next( + snapshot_events + for snapshot_events in snapshot_events_by_conversation_id.values() + if snapshot_events[0].state_snapshot["conversation"]["component_type"] == "ManagerWorkers" + ) + + assert "InterruptedExecutionStatus" not in _snapshot_status_types( + parent_multi_agent_snapshot_events + ) + + @pytest.mark.parametrize( ( "conversation_factory", From 22796df61405ad9cde683573bf929b94cd7db427 Mon Sep 17 00:00:00 2001 From: Son Le Date: Mon, 16 Mar 2026 15:24:09 +0100 Subject: [PATCH 03/13] fix state snapshot emission robustness --- .../agentspec/components/__init__.py | 3 -- .../agentspec/components/transforms.py | 6 --- .../src/wayflowcore/agentspec/tracing.py | 49 +------------------ .../executors/_statesnapshot_eventlistener.py | 17 +++++-- .../_builtins_serialization_plugin.py | 32 ++++-------- .../agentspec/test_state_snapshot_tracing.py | 37 -------------- .../events/test_state_snapshot_runtime.py | 39 +++++++++++++++ 7 files changed, 63 insertions(+), 
120 deletions(-) diff --git a/wayflowcore/src/wayflowcore/agentspec/components/__init__.py b/wayflowcore/src/wayflowcore/agentspec/components/__init__.py index 5cc2e53b1..dfa345296 100644 --- a/wayflowcore/src/wayflowcore/agentspec/components/__init__.py +++ b/wayflowcore/src/wayflowcore/agentspec/components/__init__.py @@ -115,7 +115,6 @@ from .transforms import ( PluginAppendTrailingSystemMessageToUserMessageTransform, PluginCoalesceSystemMessagesTransform, - PluginManagerWorkersToolRequestAndCallsTransform, PluginReactMergeToolRequestAndCallsTransform, PluginRemoveEmptyNonUserMessageTransform, PluginSwarmToolRequestAndCallsTransform, @@ -226,7 +225,6 @@ "contextprovider_deserialization_plugin", "PluginAppendTrailingSystemMessageToUserMessageTransform", "PluginCoalesceSystemMessagesTransform", - "PluginManagerWorkersToolRequestAndCallsTransform", "PluginRemoveEmptyNonUserMessageTransform", "PluginReactMergeToolRequestAndCallsTransform", "PluginSwarmToolRequestAndCallsTransform", @@ -249,7 +247,6 @@ "PluginManagerWorkers", "PluginAppendTrailingSystemMessageToUserMessageTransform", "PluginCoalesceSystemMessagesTransform", - "PluginManagerWorkersToolRequestAndCallsTransform", "PluginReactMergeToolRequestAndCallsTransform", "PluginRemoveEmptyNonUserMessageTransform", "messagetransform_deserialization_plugin", diff --git a/wayflowcore/src/wayflowcore/agentspec/components/transforms.py b/wayflowcore/src/wayflowcore/agentspec/components/transforms.py index 4b55c1c7a..2e2bc0bc3 100644 --- a/wayflowcore/src/wayflowcore/agentspec/components/transforms.py +++ b/wayflowcore/src/wayflowcore/agentspec/components/transforms.py @@ -58,10 +58,6 @@ class PluginSwarmToolRequestAndCallsTransform(MessageTransform): sequence of messages.""" -class PluginManagerWorkersToolRequestAndCallsTransform(MessageTransform): - """Format Tool requests as Agent messages and Tool results as User messages for manager-workers prompts.""" - - class 
PluginCanonicalizationMessageTransform(MessageTransform): """ Produce a conversation shaped like: @@ -102,7 +98,6 @@ class PluginSplitPromptOnMarkerMessageTransform(MessageTransform): PluginAppendTrailingSystemMessageToUserMessageTransform.__name__: PluginAppendTrailingSystemMessageToUserMessageTransform, PluginLlamaMergeToolRequestAndCallsTransform.__name__: PluginLlamaMergeToolRequestAndCallsTransform, PluginReactMergeToolRequestAndCallsTransform.__name__: PluginReactMergeToolRequestAndCallsTransform, - PluginManagerWorkersToolRequestAndCallsTransform.__name__: PluginManagerWorkersToolRequestAndCallsTransform, PluginSwarmToolRequestAndCallsTransform.__name__: PluginSwarmToolRequestAndCallsTransform, PluginCanonicalizationMessageTransform.__name__: PluginCanonicalizationMessageTransform, PluginSplitPromptOnMarkerMessageTransform.__name__: PluginSplitPromptOnMarkerMessageTransform, @@ -116,7 +111,6 @@ class PluginSplitPromptOnMarkerMessageTransform(MessageTransform): PluginAppendTrailingSystemMessageToUserMessageTransform.__name__: PluginAppendTrailingSystemMessageToUserMessageTransform, PluginLlamaMergeToolRequestAndCallsTransform.__name__: PluginLlamaMergeToolRequestAndCallsTransform, PluginReactMergeToolRequestAndCallsTransform.__name__: PluginReactMergeToolRequestAndCallsTransform, - PluginManagerWorkersToolRequestAndCallsTransform.__name__: PluginManagerWorkersToolRequestAndCallsTransform, PluginSwarmToolRequestAndCallsTransform.__name__: PluginSwarmToolRequestAndCallsTransform, PluginCanonicalizationMessageTransform.__name__: PluginCanonicalizationMessageTransform, PluginSplitPromptOnMarkerMessageTransform.__name__: PluginSplitPromptOnMarkerMessageTransform, diff --git a/wayflowcore/src/wayflowcore/agentspec/tracing.py b/wayflowcore/src/wayflowcore/agentspec/tracing.py index a66ab754c..ae70c8b4b 100644 --- a/wayflowcore/src/wayflowcore/agentspec/tracing.py +++ b/wayflowcore/src/wayflowcore/agentspec/tracing.py @@ -12,7 +12,6 @@ from pyagentspec.flows.node 
import Node as AgentSpecNode from pyagentspec.llms import LlmConfig as AgentSpecLlmConfig from pyagentspec.llms import LlmGenerationConfig -from pyagentspec.managerworkers import ManagerWorkers as AgentSpecManagerWorkers from pyagentspec.swarm import Swarm as AgentSpecSwarm from pyagentspec.tools import Tool as AgentSpecTool from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd @@ -25,12 +24,6 @@ ) from pyagentspec.tracing.events import LlmGenerationRequest as AgentSpecLlmGenerationRequest from pyagentspec.tracing.events import LlmGenerationResponse as AgentSpecLlmGenerationResponse -from pyagentspec.tracing.events import ( - ManagerWorkersExecutionEnd as AgentSpecManagerWorkersExecutionEnd, -) -from pyagentspec.tracing.events import ( - ManagerWorkersExecutionStart as AgentSpecManagerWorkersExecutionStart, -) from pyagentspec.tracing.events import NodeExecutionEnd as AgentSpecNodeExecutionEnd from pyagentspec.tracing.events import NodeExecutionStart as AgentSpecNodeExecutionStart from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted @@ -43,9 +36,6 @@ from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan from pyagentspec.tracing.spans import LlmGenerationSpan as AgentSpecLlmGenerationSpan -from pyagentspec.tracing.spans import ( - ManagerWorkersExecutionSpan as AgentSpecManagerWorkersExecutionSpan, -) from pyagentspec.tracing.spans import NodeExecutionSpan as AgentSpecNodeExecutionSpan from pyagentspec.tracing.spans import Span as AgentSpecSpan from pyagentspec.tracing.spans import SwarmExecutionSpan as AgentSpecSwarmExecutionSpan @@ -74,7 +64,6 @@ ) from wayflowcore.events.eventlistener import EventListener from wayflowcore.executors.executionstatus import FinishedStatus -from wayflowcore.managerworkers import ManagerWorkers as RuntimeManagerWorkers from 
wayflowcore.steps.agentexecutionstep import AgentExecutionStep as RuntimeAgentExecutionStep from wayflowcore.swarm import Swarm as RuntimeSwarm from wayflowcore.tracing.span import LlmGenerationSpan, get_active_span_stack, get_current_span @@ -130,31 +119,7 @@ def _start_multi_agent_span_if_needed( if not isinstance(event.step, RuntimeAgentExecutionStep): return - if isinstance(event.step.agent, RuntimeManagerWorkers): - agentspec_managerworkers = cast( - AgentSpecManagerWorkers, self._convert_to_agentspec(event.step.agent) - ) - multi_agent_span: AgentSpecManagerWorkersExecutionSpan | AgentSpecSwarmExecutionSpan = ( - AgentSpecManagerWorkersExecutionSpan( - id=f"{current_span_id}:managerworkers", - name=f"ManagerWorkersExecution[{event.step.agent._get_display_name()}]", - managerworkers=agentspec_managerworkers, - ) - ) - multi_agent_span.start() - multi_agent_span.add_event( - AgentSpecManagerWorkersExecutionStart( - id=event.event_id, - name=event_name, - managerworkers=agentspec_managerworkers, - inputs={ - input_name: input_value for input_name, input_value in event.inputs.items() - }, - ) - ) - self._multi_agent_spans_by_step_span_id[current_span_id] = multi_agent_span - self._pending_multi_agent_spans_by_component_id[event.step.agent.id] = multi_agent_span - elif isinstance(event.step.agent, RuntimeSwarm): + if isinstance(event.step.agent, RuntimeSwarm): agentspec_swarm = cast(AgentSpecSwarm, self._convert_to_agentspec(event.step.agent)) multi_agent_span = AgentSpecSwarmExecutionSpan( id=f"{current_span_id}:swarm", @@ -193,16 +158,7 @@ def _end_multi_agent_span_if_needed( for output_name, output_value in event.step_result.outputs.items() if output_name != "__execution_status__" } - if isinstance(multi_agent_span, AgentSpecManagerWorkersExecutionSpan): - multi_agent_span.add_event( - AgentSpecManagerWorkersExecutionEnd( - id=event.event_id, - name=event_name, - managerworkers=multi_agent_span.managerworkers, - outputs=outputs, - ) - ) - elif 
isinstance(multi_agent_span, AgentSpecSwarmExecutionSpan): + if isinstance(multi_agent_span, AgentSpecSwarmExecutionSpan): multi_agent_span.add_event( AgentSpecSwarmExecutionEnd( id=event.event_id, @@ -255,7 +211,6 @@ def _move_snapshot_before_terminal_event( AgentSpecAgentExecutionEnd, AgentSpecExceptionRaised, AgentSpecFlowExecutionEnd, - AgentSpecManagerWorkersExecutionEnd, AgentSpecSwarmExecutionEnd, ), ): diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py index 9b15de6f5..b2d926075 100644 --- a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -143,16 +143,23 @@ def record_state_snapshot( conversation.status_handled = status_handled try: + variable_state = None + if state_snapshot_policy.include_variable_state: + try: + variable_state = dump_variable_state(conversation) + except Exception: + logger.warning( + "Failed to dump variable state for conversation '%s'", + conversation.conversation_id, + exc_info=True, + ) + record_event( StateSnapshotEvent( conversation_id=conversation.conversation_id, state_snapshot=dump_conversation_state(conversation), extra_state=conversation._build_extra_state(), - variable_state=( - dump_variable_state(conversation) - if state_snapshot_policy.include_variable_state - else None - ), + variable_state=variable_state, ) ) return True diff --git a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py index d7e95dc72..799fbf3b2 100644 --- a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py +++ b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py @@ -280,9 +280,6 @@ from wayflowcore.agentspec.components.transforms import ( PluginLlamaMergeToolRequestAndCallsTransform as 
AgentSpecPluginLlamaMergeToolRequestAndCallsTransform, ) -from wayflowcore.agentspec.components.transforms import ( - PluginManagerWorkersToolRequestAndCallsTransform as AgentSpecPluginManagerWorkersToolRequestAndCallsTransform, -) from wayflowcore.agentspec.components.transforms import ( PluginReactMergeToolRequestAndCallsTransform as AgentSpecPluginReactMergeToolRequestAndCallsTransform, ) @@ -442,9 +439,6 @@ ) from wayflowcore.swarm import Swarm as RuntimeSwarm from wayflowcore.templates import PromptTemplate as RuntimePromptTemplate -from wayflowcore.templates._managerworkerstemplate import ( - _ToolRequestAndCallsTransform as RuntimeManagerWorkersToolRequestAndCallsTransform, -) from wayflowcore.templates._swarmtemplate import ( _ToolRequestAndCallsTransform as RuntimeSwarmToolRequestAndCallsTransform, ) @@ -1604,15 +1598,6 @@ def _messagetransform_convert_to_agentspec( runtime_messagetransform ), ) - elif isinstance( - runtime_messagetransform, RuntimeManagerWorkersToolRequestAndCallsTransform - ): - return AgentSpecPluginManagerWorkersToolRequestAndCallsTransform( - name="managerworkerstoolrequestandcalls_messagetransform", - metadata=_create_agentspec_metadata_from_runtime_component( - runtime_messagetransform - ), - ) elif isinstance(runtime_messagetransform, RuntimeSwarmToolRequestAndCallsTransform): return AgentSpecPluginSwarmToolRequestAndCallsTransform( name="swarmtoolrequestandcalls_messagetransform", @@ -2519,11 +2504,6 @@ def _managerworkers_convert_to_agentspec( referenced_objects: Optional[Dict[str, Any]] = None, ) -> AgentSpecManagerWorkers: metadata = _create_agentspec_metadata_from_runtime_component(runtime_managerworkers) - group_manager = ( - runtime_managerworkers.manager_agent - if isinstance(runtime_managerworkers.group_manager, RuntimeLlmModel) - else runtime_managerworkers.group_manager - ) return AgentSpecManagerWorkers( name=runtime_managerworkers.name @@ -2531,9 +2511,17 @@ def _managerworkers_convert_to_agentspec( 
description=runtime_managerworkers.description or runtime_managerworkers.__metadata_info__.get("description", ""), id=runtime_managerworkers.id, - group_manager=conversion_context.convert(group_manager, referenced_objects), + group_manager=cast( + Union[AgentSpecAgent, AgentSpecLlmConfig], + conversion_context.convert( + runtime_managerworkers.group_manager, referenced_objects + ), + ), workers=[ - conversion_context.convert(worker, referenced_objects) + cast( + AgentSpecAgent, + conversion_context.convert(worker, referenced_objects), + ) for worker in runtime_managerworkers.workers ], inputs=[ diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py index 17ec4b7e2..2534e7a7e 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py @@ -16,17 +16,11 @@ from pyagentspec.tracing.events import Event as AgentSpecEvent from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart -from pyagentspec.tracing.events import ( - ManagerWorkersExecutionEnd as AgentSpecManagerWorkersExecutionEnd, -) from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted from pyagentspec.tracing.events import SwarmExecutionEnd as AgentSpecSwarmExecutionEnd from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan -from pyagentspec.tracing.spans import ( - ManagerWorkersExecutionSpan as AgentSpecManagerWorkersExecutionSpan, -) from pyagentspec.tracing.spans import Span as AgentSpecSpan from pyagentspec.tracing.spans import SwarmExecutionSpan as AgentSpecSwarmExecutionSpan from 
pyagentspec.tracing.spans import ToolExecutionSpan as AgentSpecToolExecutionSpan @@ -41,7 +35,6 @@ StateSnapshotPolicy, ) from wayflowcore.flow import Flow -from wayflowcore.managerworkers import ManagerWorkers from wayflowcore.messagelist import Message, MessageType from wayflowcore.models.vllmmodel import VllmModel from wayflowcore.serialization import dump_conversation_state @@ -239,35 +232,6 @@ def _create_send_message_request(recipient_name: str, message: str) -> Message: ) -def _build_managerworkers_state_snapshot_flow() -> tuple[ - Flow, - VllmModel, - list[Message | str], - VllmModel, - list[Message | str], - type[AgentSpecSpan], - str, - str, - type[AgentSpecEvent], -]: - manager_llm = _create_mock_vllm_model("manager") - worker_llm = _create_mock_vllm_model("worker") - worker = WayflowAgent(llm=worker_llm, name="worker", description="worker") - managerworkers = ManagerWorkers(group_manager=manager_llm, workers=[worker], name="team") - - return ( - Flow.from_steps([AgentExecutionStep(agent=managerworkers), CompleteStep(name="end")]), - manager_llm, - [_create_send_message_request("worker", "Do it"), "manager final answer"], - worker_llm, - ["worker answer"], - AgentSpecManagerWorkersExecutionSpan, - "worker answer", - "manager final answer", - AgentSpecManagerWorkersExecutionEnd, - ) - - def _build_swarm_state_snapshot_flow() -> tuple[ Flow, VllmModel, @@ -456,7 +420,6 @@ def test_agent_state_snapshots_support_the_agui_retrieval_export_flow() -> None: @pytest.mark.parametrize( "flow_builder", [ - pytest.param(_build_managerworkers_state_snapshot_flow, id="managerworkers"), pytest.param(_build_swarm_state_snapshot_flow, id="swarm"), ], ) diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime.py b/wayflowcore/tests/events/test_state_snapshot_runtime.py index df5650a56..16eeeed08 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime.py @@ -29,6 +29,7 @@ from wayflowcore.flow 
import Flow from wayflowcore.managerworkers import ManagerWorkers from wayflowcore.messagelist import Message, MessageType +from wayflowcore.property import AnyProperty from wayflowcore.serialization.serializer import SerializableNeedToBeImplementedMixin from wayflowcore.steps import ( AgentExecutionStep, @@ -36,9 +37,11 @@ FlowExecutionStep, OutputMessageStep, ToolExecutionStep, + VariableWriteStep, ) from wayflowcore.swarm import Swarm from wayflowcore.tools import ServerTool, ToolRequest, tool +from wayflowcore.variable import Variable from ..conftest import disable_streaming from ..test_interrupts import OnEventExecutionInterrupt @@ -97,6 +100,10 @@ def _on_execution_end( ) +class _UnserializableVariableValue: + pass + + def _create_output_flow_conversation(message: str = "Hello") -> Conversation: flow = Flow.from_steps( [ @@ -767,6 +774,38 @@ def broken_builder(_conversation: Conversation) -> dict[str, object]: assert all(snapshot_event.extra_state is None for snapshot_event in state_snapshot_events) +def test_state_snapshot_emission_survives_unserializable_variable_state() -> None: + custom_variable = Variable(name="custom", type=AnyProperty()) + conversation = Flow.from_steps( + [ + VariableWriteStep( + variable=custom_variable, + input_mapping={VariableWriteStep.VALUE: custom_variable.name}, + ), + OutputMessageStep(message_template="done"), + CompleteStep(name="end"), + ], + variables=[custom_variable], + ).start_conversation(inputs={custom_variable.name: _UnserializableVariableValue()}) + + status, state_snapshot_events = _execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + include_variable_state=True, + ), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 2 + assert state_snapshot_events[0].variable_state == {"custom": None} + assert state_snapshot_events[-1].variable_state is None + assert ( + 
state_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] + == "done" + ) + + @pytest.mark.parametrize( ( "conversation_factory", From fc7f7caf9c77c8a00953768b3bc146b15756908c Mon Sep 17 00:00:00 2001 From: Son Le Date: Mon, 16 Mar 2026 15:27:10 +0100 Subject: [PATCH 04/13] update copyright year for new state snapshot files --- .../src/wayflowcore/executors/_statesnapshot_eventlistener.py | 2 +- wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py | 2 +- wayflowcore/src/wayflowcore/serialization/conversation.py | 2 +- wayflowcore/tests/agentspec/test_state_snapshot_tracing.py | 2 +- wayflowcore/tests/events/test_state_snapshot_event.py | 2 +- wayflowcore/tests/events/test_state_snapshot_runtime.py | 2 +- .../tests/serialization/test_conversation_state_snapshot.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py index b2d926075..db56b3e04 100644 --- a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -1,4 +1,4 @@ -# Copyright © 2025 Oracle and/or its affiliates. +# Copyright © 2026 Oracle and/or its affiliates. # # This software is under the Apache License 2.0 # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License diff --git a/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py b/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py index f8150092a..d123a8c12 100644 --- a/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py +++ b/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py @@ -1,4 +1,4 @@ -# Copyright © 2025 Oracle and/or its affiliates. +# Copyright © 2026 Oracle and/or its affiliates. 
# # This software is under the Apache License 2.0 # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License diff --git a/wayflowcore/src/wayflowcore/serialization/conversation.py b/wayflowcore/src/wayflowcore/serialization/conversation.py index becb4a4e6..a8dcbd2a5 100644 --- a/wayflowcore/src/wayflowcore/serialization/conversation.py +++ b/wayflowcore/src/wayflowcore/serialization/conversation.py @@ -1,4 +1,4 @@ -# Copyright © 2025 Oracle and/or its affiliates. +# Copyright © 2026 Oracle and/or its affiliates. # # This software is under the Apache License 2.0 # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py index 2534e7a7e..224b1ac5f 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py @@ -1,4 +1,4 @@ -# Copyright © 2025 Oracle and/or its affiliates. +# Copyright © 2026 Oracle and/or its affiliates. # # This software is under the Apache License 2.0 # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License diff --git a/wayflowcore/tests/events/test_state_snapshot_event.py b/wayflowcore/tests/events/test_state_snapshot_event.py index ee3970e48..b59172f77 100644 --- a/wayflowcore/tests/events/test_state_snapshot_event.py +++ b/wayflowcore/tests/events/test_state_snapshot_event.py @@ -1,4 +1,4 @@ -# Copyright © 2025 Oracle and/or its affiliates. +# Copyright © 2026 Oracle and/or its affiliates. 
# # This software is under the Apache License 2.0 # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime.py b/wayflowcore/tests/events/test_state_snapshot_runtime.py index 16eeeed08..0d61edf1a 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime.py @@ -1,4 +1,4 @@ -# Copyright © 2025 Oracle and/or its affiliates. +# Copyright © 2026 Oracle and/or its affiliates. # # This software is under the Apache License 2.0 # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License diff --git a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py index 168a848c3..6a5656721 100644 --- a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py +++ b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py @@ -1,4 +1,4 @@ -# Copyright © 2025 Oracle and/or its affiliates. +# Copyright © 2026 Oracle and/or its affiliates. 
# # This software is under the Apache License 2.0 # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License From 006764c8e6485e1ce24bcf6db51f8db5fe69e1cf Mon Sep 17 00:00:00 2001 From: Son Le Date: Wed, 18 Mar 2026 13:50:29 +0100 Subject: [PATCH 05/13] refactor snapshot policy registration and split tests --- docs/wayflowcore/source/core/changelog.rst | 4 +- .../source/core/howtoguides/howto_tracing.rst | 15 +- .../src/wayflowcore/agentspec/tracing.py | 46 +- wayflowcore/src/wayflowcore/conversation.py | 101 +- wayflowcore/src/wayflowcore/events/event.py | 19 +- .../executors/_statesnapshot_eventlistener.py | 400 ++++---- .../executors/statesnapshotpolicy.py | 17 +- .../wayflowcore/serialization/conversation.py | 45 +- .../agentspec/test_state_snapshot_tracing.py | 657 ------------- .../test_state_snapshot_tracing_agent.py | 396 ++++++++ .../test_state_snapshot_tracing_flow.py | 381 ++++++++ .../test_state_snapshot_tracing_nested.py | 381 ++++++++ ...y => test_state_snapshot_event_tracing.py} | 11 +- .../test_state_snapshot_event_validation.py | 36 + .../events/test_state_snapshot_runtime.py | 867 ------------------ ...ate_snapshot_runtime_conversation_turns.py | 234 +++++ ...t_state_snapshot_runtime_internal_turns.py | 323 +++++++ .../test_state_snapshot_runtime_nested.py | 258 ++++++ .../test_state_snapshot_runtime_resilience.py | 54 ++ .../test_conversation_state_snapshot.py | 57 ++ .../tests/testhelpers/statesnapshots.py | 379 ++++++++ 21 files changed, 2849 insertions(+), 1832 deletions(-) delete mode 100644 wayflowcore/tests/agentspec/test_state_snapshot_tracing.py create mode 100644 wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py create mode 100644 wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py create mode 100644 wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py rename wayflowcore/tests/events/{test_state_snapshot_event.py => test_state_snapshot_event_tracing.py} (83%) 
create mode 100644 wayflowcore/tests/events/test_state_snapshot_event_validation.py delete mode 100644 wayflowcore/tests/events/test_state_snapshot_runtime.py create mode 100644 wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py create mode 100644 wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py create mode 100644 wayflowcore/tests/events/test_state_snapshot_runtime_nested.py create mode 100644 wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py create mode 100644 wayflowcore/tests/testhelpers/statesnapshots.py diff --git a/docs/wayflowcore/source/core/changelog.rst b/docs/wayflowcore/source/core/changelog.rst index 8f9f7fa8f..8c9ea6c3b 100644 --- a/docs/wayflowcore/source/core/changelog.rst +++ b/docs/wayflowcore/source/core/changelog.rst @@ -9,8 +9,8 @@ New features * **State snapshot tracing events:** - Added ``StateSnapshotPolicy``, ``StateSnapshotEvent``, and conversation snapshot serialization helpers. - Snapshot emission can now be enabled per ``conversation.execute()`` / ``execute_async()`` turn, and is bridged to Agent Spec ``StateSnapshotEmitted`` events via the ``AgentSpecEventListener``. WayFlow-specific ``variable_state`` remains part of ``StateSnapshotEvent`` only, is not forwarded to Agent Spec, and requires JSON-serializable variable values. Snapshots are emitted only from direct execution boundary events; raised or interrupted turns do not synthesize extra unwind snapshots. + Added ``StateSnapshotPolicy``, ``StateSnapshotEvent``, and conversation snapshot serialization helpers. State snapshots can now be enabled per ``conversation.execute()`` / ``execute_async()`` turn, emitted at conversation, node, or tool boundaries, and bridged to Agent Spec ``StateSnapshotEmitted`` events via ``AgentSpecEventListener``. + Snapshot emission is covered on both synchronous and asynchronous execution paths. 
* **OAuth support for MCP Clients:** diff --git a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst index 81866574a..734611193 100644 --- a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst +++ b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst @@ -163,13 +163,20 @@ snapshots are bridged into Agent Spec ``StateSnapshotEmitted`` events on the act span. The WayFlow-only ``variable_state`` payload is not bridged, because Agent Spec does not define WayFlow variable semantics. When ``include_variable_state=True``, variable values must already be JSON-serializable. -Snapshots are emitted only when the corresponding boundary event occurs. If a turn +``StateSnapshotEvent.conversation_id`` is the logical/public conversation id, +while ``state_snapshot["conversation"]["id"]`` identifies the concrete runtime +conversation instance that emitted the snapshot. +Each policy emits snapshots only for its own boundaries. ``CONVERSATION_TURNS`` +emits opening and closing turn snapshots. Internal policies emit only step, +iteration, and/or tool snapshots. Snapshots are emitted only when the +corresponding boundary event occurs. If a turn raises or is interrupted before its matching closing event, WayFlow does not synthesize an extra unwind snapshot. For step and tool intervals, the latest already-emitted start snapshot is the recovery point. -For flows, ``NODE_TURNS`` snapshots are emitted around each step. For agents, the same -policy emits snapshots around each decision-loop iteration. Tool start/end snapshots -are emitted only for ``TOOL_TURNS`` and ``ALL_INTERNAL_TURNS``. +For flows, ``NODE_TURNS`` uses flow-iteration start/end events, which align with +per-step execution while keeping the end snapshot on committed flow state. For +agents, the same policy emits snapshots around each decision-loop iteration. 
+Tool start/end snapshots are emitted only for ``TOOL_TURNS`` and ``ALL_INTERNAL_TURNS``. .. code-block:: python diff --git a/wayflowcore/src/wayflowcore/agentspec/tracing.py b/wayflowcore/src/wayflowcore/agentspec/tracing.py index ae70c8b4b..5add8b788 100644 --- a/wayflowcore/src/wayflowcore/agentspec/tracing.py +++ b/wayflowcore/src/wayflowcore/agentspec/tracing.py @@ -82,8 +82,10 @@ def __init__(self) -> None: self.agentspec_components_registry: Dict[str, AgentSpecComponent] = {} # State snapshots belong to the span that owns their conversation id, not # necessarily to the runtime span that was active when the snapshot event - # was emitted. - self._conversation_spans_registry: Dict[str, AgentSpecSpan] = {} + # was emitted. Nested flow sub-conversations can intentionally reuse the + # same deprecated conversation_id, so we also track the live conversation + # object id that currently owns that stream. + self._conversation_spans_registry: Dict[str, tuple[str, AgentSpecSpan]] = {} self._pending_multi_agent_spans_by_component_id: Dict[str, AgentSpecSpan] = {} self._multi_agent_spans_by_step_span_id: Dict[str, AgentSpecSpan] = {} # Track last assistant message id and a robust mapping tool_request_id -> assistant message id. 
@@ -108,7 +110,18 @@ def _get_active_wayflow_conversation(self) -> Conversation | None: def _register_current_conversation_span(self, agentspec_span: AgentSpecSpan) -> None: active_conversation = self._get_active_wayflow_conversation() if active_conversation is not None: - self._conversation_spans_registry[active_conversation.conversation_id] = agentspec_span + current_owner = self._conversation_spans_registry.get( + active_conversation.conversation_id + ) + if ( + current_owner is None + or current_owner[0] == active_conversation.id + or current_owner[1].end_time is not None + ): + self._conversation_spans_registry[active_conversation.conversation_id] = ( + active_conversation.id, + agentspec_span, + ) def _start_multi_agent_span_if_needed( self, @@ -177,25 +190,39 @@ def _get_snapshot_owner_span( current_agentspec_span: AgentSpecSpan | None, ) -> AgentSpecSpan | None: if event.conversation_id in self._conversation_spans_registry: - return self._conversation_spans_registry[event.conversation_id] + return self._conversation_spans_registry[event.conversation_id][1] + + if event.state_snapshot is None: + return None + if not isinstance(event.state_snapshot, dict): + return None + snapshot_conversation = event.state_snapshot.get("conversation") + if not isinstance(snapshot_conversation, dict): + return None + snapshot_runtime_conversation_id = snapshot_conversation.get("id") + if not isinstance(snapshot_runtime_conversation_id, str): + return None active_conversations = _get_active_conversations(return_copy=False) matching_conversation = next( ( conversation for conversation in reversed(active_conversations) - if conversation.conversation_id == event.conversation_id + if conversation.id == snapshot_runtime_conversation_id ), None, ) if matching_conversation is None: - return current_agentspec_span + return None pending_multi_agent_span = self._pending_multi_agent_spans_by_component_id.get( matching_conversation.component.id ) if pending_multi_agent_span is not None: - 
self._conversation_spans_registry[event.conversation_id] = pending_multi_agent_span + self._conversation_spans_registry[event.conversation_id] = ( + matching_conversation.id, + pending_multi_agent_span, + ) return pending_multi_agent_span return current_agentspec_span @@ -205,6 +232,11 @@ def _move_snapshot_before_terminal_event( agentspec_span: AgentSpecSpan, snapshot_event: AgentSpecStateSnapshotEmitted, ) -> None: + # State snapshots are emitted by a separate runtime listener in response + # to the turn-end event. That means this Agent Spec listener can record + # the terminal Agent/Flow/Swarm end event first and only see the derived + # snapshot event afterward. Keep the Agent Spec span readable by exposing + # the closing snapshot immediately before the terminal event. if len(agentspec_span.events) < 2 or not isinstance( agentspec_span.events[-2], ( diff --git a/wayflowcore/src/wayflowcore/conversation.py b/wayflowcore/src/wayflowcore/conversation.py index c6213318f..39dab5e32 100644 --- a/wayflowcore/src/wayflowcore/conversation.py +++ b/wayflowcore/src/wayflowcore/conversation.py @@ -3,7 +3,6 @@ # This software is under the Apache License 2.0 # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
-import json import logging import warnings from abc import abstractmethod @@ -21,7 +20,6 @@ Optional, Sequence, Union, - cast, ) from wayflowcore._utils.async_helpers import run_async_in_sync @@ -100,10 +98,6 @@ class Conversation(DataclassComponent): """Whether the current status associated to this conversation was already handled or not (messages/tool results were added to the conversation)""" - _state_snapshot_policy: Optional[StateSnapshotPolicy] = field( - default=None, init=False, repr=False, compare=False - ) - def __post_init__(self) -> None: if self.inputs is None: self.inputs = {} @@ -118,95 +112,6 @@ def _get_interrupts(self) -> Optional[List["ExecutionInterrupt"]]: def _register_event(self, event: Event) -> None: self.state._register_event(event) - def _get_parent_state_snapshot_policy(self) -> Optional[StateSnapshotPolicy]: - active_conversations = _get_active_conversations(return_copy=True) - if not active_conversations or active_conversations[-1] is self: - return None - return active_conversations[-1]._get_state_snapshot_policy() - - def _get_state_snapshot_policy(self) -> Optional[StateSnapshotPolicy]: - return self._state_snapshot_policy - - def _build_extra_state(self) -> Optional[Dict[str, Any]]: - state_snapshot_policy = self._get_state_snapshot_policy() - if state_snapshot_policy is None or state_snapshot_policy.extra_state_builder is None: - return None - - try: - extra_state = state_snapshot_policy.extra_state_builder(self) - except Exception: - logger.warning( - "Failed to build extra snapshot state for conversation '%s'", - self.conversation_id, - exc_info=True, - ) - return None - - if extra_state is None: - return None - if not isinstance(extra_state, dict): - logger.warning( - "Expected extra snapshot state to be a dictionary for conversation '%s'", - self.conversation_id, - ) - return None - - try: - return cast(Dict[str, Any], json.loads(json.dumps(extra_state))) - except Exception: - logger.warning( - "Extra snapshot state is not 
JSON serializable for conversation '%s'", - self.conversation_id, - exc_info=True, - ) - return None - - @contextmanager - def _use_state_snapshot( - self, state_snapshot_policy: Optional[StateSnapshotPolicy] - ) -> Generator[None, Any, None]: - """ - Activate the effective state snapshot policy for this execution turn. - - Child conversations inherit the parent's policy unless they explicitly - override it. When snapshots are enabled, listener registration happens - here in the same order the runtime depends on: - 1. pre-interrupt snapshot listener - 2. interrupts listener - 3. post-interrupt snapshot listener - """ - active_state_snapshot_policy = ( - state_snapshot_policy - if state_snapshot_policy is not None - else self._get_parent_state_snapshot_policy() - ) - previous_policy = self._state_snapshot_policy - self._state_snapshot_policy = active_state_snapshot_policy - try: - if active_state_snapshot_policy is None: - yield - else: - from wayflowcore.executors._interrupts_eventlistener import ( - get_interrupts_event_listener_context_for_conversation, - ) - from wayflowcore.executors._statesnapshot_eventlistener import ( - StateSnapshotListenerPhase, - get_state_snapshot_event_listener_context_for_conversation, - ) - - with get_state_snapshot_event_listener_context_for_conversation( - self, - phase=StateSnapshotListenerPhase.PRE_INTERRUPTS, - ): - with get_interrupts_event_listener_context_for_conversation(self): - with get_state_snapshot_event_listener_context_for_conversation( - self, - phase=StateSnapshotListenerPhase.POST_INTERRUPTS, - ): - yield - finally: - self._state_snapshot_policy = previous_policy - def execute( self, execution_interrupts: Optional[Sequence["ExecutionInterrupt"]] = None, @@ -239,7 +144,11 @@ async def execute_async( if self.status_handled is False: self._update_conversation_with_status() - with self._use_state_snapshot(state_snapshot_policy): + from wayflowcore.executors._statesnapshot_eventlistener import ( + 
get_state_snapshot_execution_context_for_conversation, + ) + + with get_state_snapshot_execution_context_for_conversation(self, state_snapshot_policy): with _register_conversation(self): new_status = await self.component.runner.execute_async(self, execution_interrupts) diff --git a/wayflowcore/src/wayflowcore/events/event.py b/wayflowcore/src/wayflowcore/events/event.py index 2e632177a..11c2a24cf 100644 --- a/wayflowcore/src/wayflowcore/events/event.py +++ b/wayflowcore/src/wayflowcore/events/event.py @@ -796,13 +796,30 @@ def to_tracing_info(self, mask_sensitive_information: bool = True) -> Dict[str, @dataclass(frozen=True) class StateSnapshotEvent(Event): - """Event emitted by WayFlow when a conversation state snapshot is recorded.""" + """Event emitted by WayFlow when a conversation state snapshot is recorded. + + ``conversation_id`` is the logical/public conversation id. When a snapshot is + present, the emitting runtime conversation instance is identified by + ``state_snapshot["conversation"]["id"]``. 
+ """ conversation_id: str = field(default_factory=_required_attribute("conversation_id", str)) state_snapshot: Optional[Dict[str, Any]] = None extra_state: Optional[Dict[str, Any]] = None variable_state: Optional[Dict[str, Any]] = None + def __post_init__(self) -> None: + if self.state_snapshot is None: + return + if not isinstance(self.state_snapshot, dict): + raise ValueError("state_snapshot must be a dictionary") + + snapshot_conversation = self.state_snapshot.get("conversation") + if not isinstance(snapshot_conversation, dict): + raise ValueError("state_snapshot must contain a 'conversation' object") + if not isinstance(snapshot_conversation.get("id"), str): + raise ValueError("state_snapshot['conversation']['id'] must be a string") + def to_tracing_info(self, mask_sensitive_information: bool = True) -> Dict[str, Any]: def _masked(value: Optional[Dict[str, Any]]) -> Any: if value is None: diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py index db56b3e04..e608186fa 100644 --- a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -6,10 +6,11 @@ from __future__ import annotations +import json import logging from contextlib import contextmanager -from enum import Enum -from typing import Iterator, Optional +from contextvars import ContextVar +from typing import Any, Dict, Iterator, Optional, cast from wayflowcore.conversation import Conversation, _get_active_conversations from wayflowcore.events import Event, EventListener @@ -21,9 +22,9 @@ ExceptionRaisedEvent, FlowExecutionFinishedEvent, FlowExecutionIterationFinishedEvent, + FlowExecutionIterationStartedEvent, FlowExecutionStartedEvent, StateSnapshotEvent, - StepInvocationStartEvent, ToolExecutionResultEvent, ToolExecutionStartEvent, ) @@ -35,144 +36,173 @@ from wayflowcore.executors._events.event import 
EventType as ExecutionEventType from wayflowcore.executors._executor import ExecutionInterruptedException from wayflowcore.executors.executionstatus import ExecutionStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.serialization.conversation import dump_conversation_state, dump_variable_state from wayflowcore.tracing.span import AgentExecutionSpan, FlowExecutionSpan, get_current_span logger = logging.getLogger(__name__) -class StateSnapshotBoundary(str, Enum): - """ - Concrete runtime boundaries at which a state snapshot may be recorded. +_STATE_SNAPSHOT_POLICIES: ContextVar[Dict[str, StateSnapshotPolicy]] = ContextVar( + "_STATE_SNAPSHOT_POLICIES", + default={}, +) +"""Execution-local mapping of active conversations to their effective snapshot policy.""" - `TURN_START` - The opening boundary of a single `conversation.execute(...)` call. This - captures the turn's initial resume point before execution work begins. - `TURN_END` - The closing boundary of a single `conversation.execute(...)` call. This - is the stable resume point after the turn's final status is known. +def _get_state_snapshot_policies( + return_copy: bool = True, +) -> Dict[str, StateSnapshotPolicy]: + state_snapshot_policies = _STATE_SNAPSHOT_POLICIES.get() + return state_snapshot_policies.copy() if return_copy else state_snapshot_policies - `TOOL_START` - Right before a tool invocation begins. - `TOOL_END` - Right after a tool invocation completes and its result is available. 
+def _get_parent_state_snapshot_policy( + conversation: Conversation, +) -> Optional[StateSnapshotPolicy]: + active_conversations = _get_active_conversations(return_copy=False) + if not active_conversations or active_conversations[-1] is conversation: + return None + return _get_state_snapshot_policy(active_conversations[-1]) - `NODE_START` - Right before a flow step starts executing. - `NODE_END` - Right after a flow step finishes executing. +def _get_state_snapshot_policy( + conversation: Conversation, +) -> Optional[StateSnapshotPolicy]: + return _get_state_snapshot_policies(return_copy=False).get(conversation.id) - `AGENT_LOOP_START` - Right before an agent reasoning/decision-loop iteration starts. - `AGENT_LOOP_END` - Right after an agent reasoning/decision-loop iteration finishes. - """ +@contextmanager +def _use_state_snapshot_policy( + conversation: Conversation, + state_snapshot_policy: Optional[StateSnapshotPolicy], +) -> Iterator[None]: + # Copy-on-write is needed here because child anyio tasks inherit the current + # context, including references to mutable ContextVar values. 
+ state_snapshot_policies = _get_state_snapshot_policies(return_copy=True) + if state_snapshot_policy is None: + state_snapshot_policies.pop(conversation.id, None) + else: + state_snapshot_policies[conversation.id] = state_snapshot_policy + + token = _STATE_SNAPSHOT_POLICIES.set(state_snapshot_policies) + try: + yield + finally: + _STATE_SNAPSHOT_POLICIES.reset(token) - TURN_START = "turn_start" - TURN_END = "turn_end" - TOOL_START = "tool_start" - TOOL_END = "tool_end" - NODE_START = "node_start" - NODE_END = "node_end" - AGENT_LOOP_START = "agent_loop_start" - AGENT_LOOP_END = "agent_loop_end" +def _build_extra_state( + conversation: Conversation, + state_snapshot_policy: StateSnapshotPolicy, +) -> Optional[Dict[str, Any]]: + if state_snapshot_policy.extra_state_builder is None: + return None -class StateSnapshotListenerPhase(str, Enum): - PRE_INTERRUPTS = "pre_interrupts" - POST_INTERRUPTS = "post_interrupts" + try: + extra_state = state_snapshot_policy.extra_state_builder(conversation) + except Exception: + logger.warning( + "Failed to build extra snapshot state for conversation '%s'", + conversation.conversation_id, + exc_info=True, + ) + return None + if extra_state is None: + return None + if not isinstance(extra_state, dict): + logger.warning( + "Expected extra snapshot state to be a dictionary for conversation '%s'", + conversation.conversation_id, + ) + return None -def should_emit_state_snapshot( + try: + return cast(Dict[str, Any], json.loads(json.dumps(extra_state))) + except Exception: + logger.warning( + "Extra snapshot state is not JSON serializable for conversation '%s'", + conversation.conversation_id, + exc_info=True, + ) + return None + + +def _get_snapshot_policy_for_interval( conversation: Conversation, - boundary: StateSnapshotBoundary, -) -> bool: - state_snapshot_policy = conversation._get_state_snapshot_policy() + required_snapshot_interval: StateSnapshotInterval, +) -> Optional[StateSnapshotPolicy]: + state_snapshot_policy = 
_get_state_snapshot_policy(conversation) if state_snapshot_policy is None: - return False + return None snapshot_interval = state_snapshot_policy.state_snapshot_interval if snapshot_interval == StateSnapshotInterval.OFF: - should_emit = False - elif boundary == StateSnapshotBoundary.TURN_START: - should_emit = snapshot_interval == StateSnapshotInterval.CONVERSATION_TURNS - elif boundary == StateSnapshotBoundary.TURN_END: - should_emit = True - elif boundary in {StateSnapshotBoundary.TOOL_START, StateSnapshotBoundary.TOOL_END}: - should_emit = snapshot_interval in { - StateSnapshotInterval.TOOL_TURNS, - StateSnapshotInterval.ALL_INTERNAL_TURNS, - } - elif boundary in { - StateSnapshotBoundary.NODE_START, - StateSnapshotBoundary.NODE_END, - StateSnapshotBoundary.AGENT_LOOP_START, - StateSnapshotBoundary.AGENT_LOOP_END, - }: - # Agents do not expose node execution events, so NODE_TURNS maps to - # per-step boundaries for flows and per-iteration boundaries for agents. - should_emit = snapshot_interval in { - StateSnapshotInterval.NODE_TURNS, - StateSnapshotInterval.ALL_INTERNAL_TURNS, - } - else: - should_emit = False + return None + + if snapshot_interval == required_snapshot_interval: + return state_snapshot_policy - return should_emit + if required_snapshot_interval == StateSnapshotInterval.CONVERSATION_TURNS: + return None + + if snapshot_interval == StateSnapshotInterval.ALL_INTERNAL_TURNS: + return state_snapshot_policy + + return None + + +def _build_variable_state( + conversation: Conversation, + state_snapshot_policy: StateSnapshotPolicy, +) -> Optional[dict[str, Any]]: + if not state_snapshot_policy.include_variable_state: + return None + + try: + return dump_variable_state(conversation) + except Exception: + logger.warning( + "Failed to dump variable state for conversation '%s'", + conversation.conversation_id, + exc_info=True, + ) + return None def record_state_snapshot( conversation: Conversation, - boundary: StateSnapshotBoundary, + 
required_snapshot_interval: StateSnapshotInterval, *, execution_status: ExecutionStatus | None, status_handled: bool, -) -> bool: - state_snapshot_policy = conversation._get_state_snapshot_policy() - if state_snapshot_policy is None or not should_emit_state_snapshot(conversation, boundary): - return False - - previous_status = conversation.status - previous_status_handled = conversation.status_handled - conversation.status = execution_status - conversation.status_handled = status_handled +) -> None: + state_snapshot_policy = _get_snapshot_policy_for_interval( + conversation, required_snapshot_interval + ) + if state_snapshot_policy is None: + return try: - variable_state = None - if state_snapshot_policy.include_variable_state: - try: - variable_state = dump_variable_state(conversation) - except Exception: - logger.warning( - "Failed to dump variable state for conversation '%s'", - conversation.conversation_id, - exc_info=True, - ) - record_event( StateSnapshotEvent( conversation_id=conversation.conversation_id, - state_snapshot=dump_conversation_state(conversation), - extra_state=conversation._build_extra_state(), - variable_state=variable_state, + state_snapshot=dump_conversation_state( + conversation, + status=execution_status, + status_handled=status_handled, + ), + extra_state=_build_extra_state(conversation, state_snapshot_policy), + variable_state=_build_variable_state(conversation, state_snapshot_policy), ) ) - return True except Exception: logger.warning( "Failed to emit state snapshot for conversation '%s'", conversation.conversation_id, exc_info=True, ) - return False - finally: - conversation.status = previous_status - conversation.status_handled = previous_status_handled def _get_current_active_conversation() -> Optional[Conversation]: @@ -207,90 +237,77 @@ class StateSnapshotEventListener(EventListener): def __init__( self, conversation: Conversation, - phase: StateSnapshotListenerPhase, + post_interrupts: bool, ) -> None: self.conversation = 
conversation - self.phase = phase - - def _record_snapshot(self, boundary: StateSnapshotBoundary) -> None: - record_state_snapshot( - self.conversation, - boundary, - execution_status=None, - status_handled=False, - ) + self.post_interrupts = post_interrupts - def _handle_pre_interrupt_event(self, event: Event) -> None: - match event: - case FlowExecutionStartedEvent(): - self._record_snapshot(StateSnapshotBoundary.TURN_START) - case AgentExecutionStartedEvent(): - self._record_snapshot(StateSnapshotBoundary.TURN_START) - case ToolExecutionStartEvent(): - self._record_snapshot(StateSnapshotBoundary.TOOL_START) - case ToolExecutionResultEvent(): - self._record_snapshot(StateSnapshotBoundary.TOOL_END) - case StepInvocationStartEvent(): - self._record_snapshot(StateSnapshotBoundary.NODE_START) - case FlowExecutionIterationFinishedEvent(): - self._record_snapshot(StateSnapshotBoundary.NODE_END) - case AgentExecutionIterationStartedEvent(): - self._record_snapshot(StateSnapshotBoundary.AGENT_LOOP_START) - case AgentExecutionIterationFinishedEvent(): - self._record_snapshot(StateSnapshotBoundary.AGENT_LOOP_END) - - def _handle_pre_interrupt_event_for_parent_multi_agent(self, event: Event) -> None: - match event: - case AgentExecutionStartedEvent() | FlowExecutionStartedEvent(): - self._record_snapshot(StateSnapshotBoundary.TURN_START) - - def _record_turn_end_snapshot( + def _record_snapshot( self, + required_snapshot_interval: StateSnapshotInterval, execution_status: ExecutionStatus | None = None, ) -> None: record_state_snapshot( self.conversation, - StateSnapshotBoundary.TURN_END, + required_snapshot_interval, execution_status=execution_status, status_handled=False, ) - def _latest_execution_event_is_turn_end(self) -> bool: - if not self.conversation.state.events: - return False - return self.conversation.state.events[-1].type == ExecutionEventType.EXECUTION_END + def _handle_pre_interrupt_event( + self, + event: Event, + *, + is_current_conversation: bool, + ) -> 
None: + # Agents do not expose node execution events, so NODE_TURNS maps to + # flow iteration boundaries and agent iteration boundaries. + match event: + case AgentExecutionStartedEvent() | FlowExecutionStartedEvent(): + self._record_snapshot(StateSnapshotInterval.CONVERSATION_TURNS) + case _ if not is_current_conversation: + return + case ToolExecutionStartEvent() | ToolExecutionResultEvent(): + self._record_snapshot(StateSnapshotInterval.TOOL_TURNS) + case ( + FlowExecutionIterationStartedEvent() + | FlowExecutionIterationFinishedEvent() + | AgentExecutionIterationStartedEvent() + | AgentExecutionIterationFinishedEvent() + ): + self._record_snapshot(StateSnapshotInterval.NODE_TURNS) def _should_record_interrupted_turn_end_snapshot( self, ) -> bool: - if not self._latest_execution_event_is_turn_end(): - should_record = False - elif not isinstance(get_current_span(), (FlowExecutionSpan, AgentExecutionSpan)): - should_record = False - else: - should_record = True + return ( + bool(self.conversation.state.events) + and (self.conversation.state.events[-1].type == ExecutionEventType.EXECUTION_END) + and isinstance(get_current_span(), (FlowExecutionSpan, AgentExecutionSpan)) + ) - return should_record + def _is_parent_multi_agent_conversation(self) -> bool: + parent_multi_agent_conversation = _get_nearest_parent_multi_agent_conversation() + return ( + parent_multi_agent_conversation is not None + and parent_multi_agent_conversation.id == self.conversation.id + ) def _handle_post_interrupt_event(self, event: Event) -> None: match event: case FlowExecutionFinishedEvent( execution_status=execution_status ) | AgentExecutionFinishedEvent(execution_status=execution_status): - self._record_turn_end_snapshot(execution_status) - case ExceptionRaisedEvent(exception=ExecutionInterruptedException() as exception): - if self._should_record_interrupted_turn_end_snapshot(): - self._record_turn_end_snapshot(exception.execution_status) - - def 
_handle_post_interrupt_event_for_parent_multi_agent(self, event: Event) -> None: - match event: - case FlowExecutionFinishedEvent( - execution_status=execution_status - ) | AgentExecutionFinishedEvent(execution_status=execution_status): - self._record_turn_end_snapshot(execution_status) + self._record_snapshot( + StateSnapshotInterval.CONVERSATION_TURNS, + execution_status, + ) case ExceptionRaisedEvent(exception=ExecutionInterruptedException() as exception): if self._should_record_interrupted_turn_end_snapshot(): - self._record_turn_end_snapshot(exception.execution_status) + self._record_snapshot( + StateSnapshotInterval.CONVERSATION_TURNS, + exception.execution_status, + ) def __call__(self, event: Event) -> None: if isinstance(event, StateSnapshotEvent): @@ -301,32 +318,23 @@ def __call__(self, event: Event) -> None: return is_current_conversation = current_conversation.id == self.conversation.id - parent_multi_agent_conversation = _get_nearest_parent_multi_agent_conversation() - is_parent_multi_agent_conversation = ( - parent_multi_agent_conversation is not None - and parent_multi_agent_conversation.id == self.conversation.id - ) - - if not is_current_conversation and not is_parent_multi_agent_conversation: + if not is_current_conversation and not self._is_parent_multi_agent_conversation(): return - if self.phase == StateSnapshotListenerPhase.PRE_INTERRUPTS: - if is_current_conversation: - self._handle_pre_interrupt_event(event) - else: - self._handle_pre_interrupt_event_for_parent_multi_agent(event) + if self.post_interrupts: + self._handle_post_interrupt_event(event) else: - if is_current_conversation: - self._handle_post_interrupt_event(event) - else: - self._handle_post_interrupt_event_for_parent_multi_agent(event) + self._handle_pre_interrupt_event( + event, + is_current_conversation=is_current_conversation, + ) @contextmanager def get_state_snapshot_event_listener_context_for_conversation( conversation: Conversation, *, - phase: 
StateSnapshotListenerPhase, + post_interrupts: bool, ) -> Iterator[StateSnapshotEventListener]: current_listener = next( ( @@ -334,7 +342,7 @@ def get_state_snapshot_event_listener_context_for_conversation( for event_listener in get_event_listeners() if isinstance(event_listener, StateSnapshotEventListener) and event_listener.conversation.id == conversation.id - and event_listener.phase == phase + and event_listener.post_interrupts == post_interrupts ), None, ) @@ -342,6 +350,50 @@ def get_state_snapshot_event_listener_context_for_conversation( if current_listener is not None: yield current_listener else: - listener = StateSnapshotEventListener(conversation, phase=phase) + listener = StateSnapshotEventListener(conversation, post_interrupts=post_interrupts) with register_event_listeners([listener]): yield listener + + +@contextmanager +def get_state_snapshot_execution_context_for_conversation( + conversation: Conversation, + state_snapshot_policy: Optional[StateSnapshotPolicy], +) -> Iterator[None]: + """ + Activate the effective snapshot policy for one `conversation.execute(...)` turn. + + Child conversations inherit the currently active parent policy unless they + explicitly override it. When snapshots are enabled, listener registration + happens here in the runtime order the execution model depends on: + 1. pre-interrupt snapshot listener + 2. interrupts listener + 3. 
post-interrupt snapshot listener + """ + active_state_snapshot_policy = ( + state_snapshot_policy + if state_snapshot_policy is not None + else _get_parent_state_snapshot_policy(conversation) + ) + + with _use_state_snapshot_policy(conversation, active_state_snapshot_policy): + if active_state_snapshot_policy is None: + yield + return + + from wayflowcore.executors._interrupts_eventlistener import ( + get_interrupts_event_listener_context_for_conversation, + ) + + with ( + get_state_snapshot_event_listener_context_for_conversation( + conversation, + post_interrupts=False, + ), + get_interrupts_event_listener_context_for_conversation(conversation), + get_state_snapshot_event_listener_context_for_conversation( + conversation, + post_interrupts=True, + ), + ): + yield diff --git a/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py b/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py index d123a8c12..a047700ae 100644 --- a/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py +++ b/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py @@ -21,21 +21,20 @@ class StateSnapshotInterval(str, Enum): `CONVERSATION_TURNS` Emit an opening turn snapshot before execution starts and a closing turn snapshot when the turn finishes or is interrupted at execution - end. This is the default policy because it gives a stable resume point - without emitting snapshots for every internal step. + end. This is the default policy because it gives a stable turn-level + checkpoint without emitting snapshots for every internal step. `TOOL_TURNS` - Emit the standard closing turn snapshot plus snapshots around each tool - invocation (`TOOL_START` and `TOOL_END`). + Emit snapshots around each tool invocation (`TOOL_START` and + `TOOL_END`) only. `NODE_TURNS` - Emit the standard closing turn snapshot plus snapshots around each - internal node boundary. For flows this means per-step snapshots; for - agents it maps to decision-loop iteration boundaries. 
+ Emit snapshots around each internal node boundary only. For flows this + means per-step snapshots; for agents it maps to decision-loop + iteration boundaries. `ALL_INTERNAL_TURNS` - Emit the standard closing turn snapshot plus all tool and node - snapshots. + Emit all tool and node snapshots, without the broader turn snapshots. `OFF` Disable state snapshot emission entirely. diff --git a/wayflowcore/src/wayflowcore/serialization/conversation.py b/wayflowcore/src/wayflowcore/serialization/conversation.py index a8dcbd2a5..502320460 100644 --- a/wayflowcore/src/wayflowcore/serialization/conversation.py +++ b/wayflowcore/src/wayflowcore/serialization/conversation.py @@ -30,6 +30,8 @@ from wayflowcore.executors._agentconversation import AgentConversation from wayflowcore.executors._flowconversation import FlowConversation +_UNSET = object() + def _dump_json_compatible_value(value: Any) -> Any: from wayflowcore.component import Component @@ -46,6 +48,7 @@ def _dump_json_compatible_value(value: Any) -> Any: dumped_value = _dump_json_compatible_value(value.value) elif isinstance(value, Conversation): dumped_value = { + "id": value.id, "conversation_id": value.conversation_id, "conversation_type": value.__class__.__name__, } @@ -184,7 +187,6 @@ def _dump_execution_status(execution_status: Optional[ExecutionStatus]) -> Optio else: dumped_status = { "type": execution_status.__class__.__name__, - "conversation_id": execution_status._conversation_id, } if isinstance(execution_status, FinishedStatus): @@ -213,6 +215,7 @@ def _dump_execution_status(execution_status: Optional[ExecutionStatus]) -> Optio def _dump_conversation_info(conversation: "Conversation") -> dict[str, Any]: return { + "id": conversation.id, "conversation_id": conversation.conversation_id, "conversation_type": conversation.__class__.__name__, "component_type": conversation.component.__class__.__name__, @@ -222,11 +225,20 @@ def _dump_conversation_info(conversation: "Conversation") -> dict[str, Any]: } -def 
_dump_execution_info(conversation: "Conversation") -> dict[str, Any]: +def _dump_execution_info( + conversation: "Conversation", + *, + status: object = _UNSET, + status_handled: object = _UNSET, +) -> dict[str, Any]: return { "current_step_name": conversation.current_step_name, - "status": _dump_execution_status(conversation.status), - "status_handled": conversation.status_handled, + "status": _dump_execution_status( + conversation.status if status is _UNSET else cast(Optional[ExecutionStatus], status) + ), + "status_handled": ( + conversation.status_handled if status_handled is _UNSET else cast(bool, status_handled) + ), "token_usage": _dump_json_compatible_value(conversation.token_usage), } @@ -265,22 +277,39 @@ def _dump_agent_execution_info(conversation: "AgentConversation") -> dict[str, A } -def dump_conversation_state(conversation: "Conversation") -> dict[str, Any]: +def dump_conversation_state( + conversation: "Conversation", + *, + status: object = _UNSET, + status_handled: object = _UNSET, +) -> dict[str, Any]: from wayflowcore.executors._agentconversation import AgentConversation from wayflowcore.executors._flowconversation import FlowConversation if isinstance(conversation, FlowConversation): execution_info = { - **_dump_execution_info(conversation), + **_dump_execution_info( + conversation, + status=status, + status_handled=status_handled, + ), **_dump_flow_execution_info(conversation), } elif isinstance(conversation, AgentConversation): execution_info = { - **_dump_execution_info(conversation), + **_dump_execution_info( + conversation, + status=status, + status_handled=status_handled, + ), **_dump_agent_execution_info(conversation), } else: - execution_info = _dump_execution_info(conversation) + execution_info = _dump_execution_info( + conversation, + status=status, + status_handled=status_handled, + ) return { "conversation": _dump_conversation_info(conversation), diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py 
b/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py deleted file mode 100644 index 224b1ac5f..000000000 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing.py +++ /dev/null @@ -1,657 +0,0 @@ -# Copyright © 2026 Oracle and/or its affiliates. -# -# This software is under the Apache License 2.0 -# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License -# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. - -from dataclasses import dataclass -from typing import Any, cast - -import pytest -from pyagentspec.adapters.wayflow import AgentSpecLoader -from pyagentspec.agent import Agent as AgentSpecAgent -from pyagentspec.llms import VllmConfig -from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd -from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart -from pyagentspec.tracing.events import Event as AgentSpecEvent -from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd -from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart -from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted -from pyagentspec.tracing.events import SwarmExecutionEnd as AgentSpecSwarmExecutionEnd -from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor -from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan -from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan -from pyagentspec.tracing.spans import Span as AgentSpecSpan -from pyagentspec.tracing.spans import SwarmExecutionSpan as AgentSpecSwarmExecutionSpan -from pyagentspec.tracing.spans import ToolExecutionSpan as AgentSpecToolExecutionSpan -from pyagentspec.tracing.trace import Trace as AgentSpecTrace - -from wayflowcore import Agent as WayflowAgent -from wayflowcore.agentspec.tracing import 
AgentSpecEventListener -from wayflowcore.events.eventlistener import register_event_listeners -from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus -from wayflowcore.executors.statesnapshotpolicy import ( - StateSnapshotInterval, - StateSnapshotPolicy, -) -from wayflowcore.flow import Flow -from wayflowcore.messagelist import Message, MessageType -from wayflowcore.models.vllmmodel import VllmModel -from wayflowcore.serialization import dump_conversation_state -from wayflowcore.steps import AgentExecutionStep, CompleteStep, OutputMessageStep, ToolExecutionStep -from wayflowcore.swarm import Swarm -from wayflowcore.tools import ServerTool, ToolRequest - -from ..testhelpers.patching import patch_llm - -pytestmark = pytest.mark.skipif( - AgentSpecStateSnapshotEmitted is None, - reason="Installed pyagentspec does not expose StateSnapshotEmitted", -) - - -class SnapshotSpanRecorder(AgentSpecSpanProcessor): - def __init__(self) -> None: - super().__init__() - self.started_spans: list[AgentSpecSpan] = [] - self.ended_spans: list[AgentSpecSpan] = [] - - def on_start(self, span: AgentSpecSpan) -> None: - self.started_spans.append(span) - - async def on_start_async(self, span: AgentSpecSpan) -> None: - self.started_spans.append(span) - - def on_end(self, span: AgentSpecSpan) -> None: - self.ended_spans.append(span) - - async def on_end_async(self, span: AgentSpecSpan) -> None: - self.ended_spans.append(span) - - def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - def startup(self) -> None: - return None - - def shutdown(self) -> None: - return None - - async def startup_async(self) -> None: - return None - - async def shutdown_async(self) -> None: - return None - - -@dataclass(frozen=True) -class ExportedAGUIStateSnapshot: - conversation_id: str - snapshot: dict[str, Any] - - -class 
AGUIStateSnapshotExporter(AgentSpecSpanProcessor): - def __init__(self) -> None: - super().__init__() - self.exported_snapshots: list[ExportedAGUIStateSnapshot] = [] - - def on_start(self, span: AgentSpecSpan) -> None: - return None - - async def on_start_async(self, span: AgentSpecSpan) -> None: - return None - - def on_end(self, span: AgentSpecSpan) -> None: - return None - - async def on_end_async(self, span: AgentSpecSpan) -> None: - return None - - def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - if not isinstance(event, AgentSpecStateSnapshotEmitted): - return - - conversation_snapshot = (event.state_snapshot or {}).get("conversation", {}) - self.exported_snapshots.append( - ExportedAGUIStateSnapshot( - conversation_id=event.conversation_id, - snapshot={ - "messages": conversation_snapshot.get("messages", []), - "input": conversation_snapshot.get("inputs", {}).get("input"), - "agent_state": (event.extra_state or {}).get("agent_state", {}), - }, - ) - ) - - async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - self.on_event(event, span) - - def startup(self) -> None: - return None - - def shutdown(self) -> None: - return None - - async def startup_async(self) -> None: - return None - - async def shutdown_async(self) -> None: - return None - - -_RETRIEVAL_INPUTS = { - "input": "How many orders last week?", - "thread_id": "thread-123", - "agent_type": "planner", - "llm_model_name": "gpt-5-mini", - "default_schema": "sales", - "input_document": "Only use the sales schema and weekly order metrics.", -} - -_RETRIEVAL_UI_STATE = { - "preplan": { - "summary": "Inspect weekly sales orders and answer concisely.", - "entries": [ - "Inspect the active schema", - "Aggregate last week's orders", - "Return the final answer", - ], - "ready_to_proceed": True, - }, - "assumptions": [ - {"text": "Use the sales schema only", "status": "approved"}, - {"text": "Week boundaries follow UTC", "status": "auto_approved"}, - ], -} 
- - -def _create_retrieval_like_wayflow_agent() -> WayflowAgent: - agentspec_agent = AgentSpecAgent( - name="retrieval_agent", - llm_config=VllmConfig(name="llm", url="http://mock.url", model_id="mock.model"), - system_prompt="You are a helpful retrieval agent.", - ) - return cast(WayflowAgent, AgentSpecLoader().load_component(agentspec_agent)) - - -def _build_retrieval_agent_state( - *, - conversation_inputs: dict[str, Any], - message_count: int, - last_response: str, -) -> dict[str, Any]: - return { - "thread_id": conversation_inputs["thread_id"], - "agent_type": conversation_inputs["agent_type"], - "llm_model_name": conversation_inputs["llm_model_name"], - "default_schema": conversation_inputs["default_schema"], - "input_document": conversation_inputs["input_document"], - "message_count": message_count, - "last_response": last_response, - "ui": _RETRIEVAL_UI_STATE, - } - - -def _build_retrieval_like_extra_state(conversation) -> dict[str, Any]: - conversation_snapshot = dump_conversation_state(conversation)["conversation"] - messages = conversation_snapshot["messages"] - last_response = next( - ( - message.get("content") - for message in reversed(messages) - if message.get("role") == "assistant" and message.get("content") - ), - "", - ) - return { - "agent_state": _build_retrieval_agent_state( - conversation_inputs=conversation.inputs, - message_count=len(messages), - last_response=last_response, - ) - } - - -def _create_mock_vllm_model(name: str) -> VllmModel: - return VllmModel(model_id="mock.model", host_port="http://mock.url", name=name) - - -def _create_send_message_request(recipient_name: str, message: str) -> Message: - return Message( - content="", - message_type=MessageType.TOOL_REQUEST, - tool_requests=[ - ToolRequest( - name="send_message", - args={"recipient": recipient_name, "message": message}, - ) - ], - ) - - -def _build_swarm_state_snapshot_flow() -> tuple[ - Flow, - VllmModel, - list[Message | str], - VllmModel, - list[Message | str], - 
type[AgentSpecSpan], - str, - str, - type[AgentSpecEvent], -]: - first_agent_llm = _create_mock_vllm_model("agent1") - second_agent_llm = _create_mock_vllm_model("agent2") - first_agent = WayflowAgent(llm=first_agent_llm, name="agent1", description="agent1") - second_agent = WayflowAgent(llm=second_agent_llm, name="agent2", description="agent2") - swarm = Swarm( - first_agent=first_agent, - relationships=[(first_agent, second_agent), (second_agent, first_agent)], - name="swarm", - ) - - return ( - Flow.from_steps([AgentExecutionStep(agent=swarm), CompleteStep(name="end")]), - first_agent_llm, - [_create_send_message_request("agent2", "Do it"), "agent1 final answer"], - second_agent_llm, - ["agent2 answer"], - AgentSpecSwarmExecutionSpan, - "agent2 answer", - "agent1 final answer", - AgentSpecSwarmExecutionEnd, - ) - - -def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> None: - flow = Flow.from_steps( - [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], - step_names=["single_step", "end"], - ) - conversation = flow.start_conversation() - listener = AgentSpecEventListener() - span_recorder = SnapshotSpanRecorder() - - with AgentSpecTrace(span_processors=[span_recorder]): - with register_event_listeners([listener]): - status = conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, - ) - ) - - assert isinstance(status, FinishedStatus) - - flow_spans = [ - span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) - ] - assert len(flow_spans) == 1 - flow_span = flow_spans[0] - flow_events = flow_span.events - - assert any(isinstance(event, AgentSpecFlowExecutionStart) for event in flow_events) - flow_end_event = next( - event for event in flow_events if isinstance(event, AgentSpecFlowExecutionEnd) - ) - state_snapshot_events = [ - 
event for event in flow_events if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - - # From an Agent Spec consumer point of view, the flow span should expose the - # opening and closing checkpoints, and the closing checkpoint must still - # appear before the terminal flow-end event. - assert len(state_snapshot_events) == 2 - final_snapshot_event = state_snapshot_events[-1] - assert final_snapshot_event.conversation_id == conversation.conversation_id - assert final_snapshot_event.state_snapshot["conversation"]["messages"][-1]["content"] == "Hello" - assert final_snapshot_event.extra_state == {"ui": {"active_tab": "plan"}} - assert flow_events.index(final_snapshot_event) < flow_events.index(flow_end_event) - assert flow_span.end_time is not None - assert final_snapshot_event.timestamp <= flow_span.end_time - assert "variable_state" not in final_snapshot_event.model_dump(mask_sensitive_information=False) - assert flow_span in span_recorder.ended_spans - - -def test_off_policy_does_not_bridge_state_snapshots_into_agent_spec_spans() -> None: - flow = Flow.from_steps( - [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], - step_names=["single_step", "end"], - ) - conversation = flow.start_conversation() - listener = AgentSpecEventListener() - span_recorder = SnapshotSpanRecorder() - - with AgentSpecTrace(span_processors=[span_recorder]): - with register_event_listeners([listener]): - status = conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.OFF - ) - ) - - assert isinstance(status, FinishedStatus) - - flow_spans = [ - span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) - ] - assert len(flow_spans) == 1 - flow_events = flow_spans[0].events - - assert any(isinstance(event, AgentSpecFlowExecutionStart) for event in flow_events) - assert any(isinstance(event, AgentSpecFlowExecutionEnd) for event in flow_events) - assert not 
any(isinstance(event, AgentSpecStateSnapshotEmitted) for event in flow_events) - - -def test_agent_state_snapshots_support_the_agui_retrieval_export_flow() -> None: - assistant_message = "I checked the warehouse and found 42 orders last week." - wayflow_agent = _create_retrieval_like_wayflow_agent() - conversation = wayflow_agent.start_conversation(inputs=_RETRIEVAL_INPUTS) - conversation.append_user_message(_RETRIEVAL_INPUTS["input"]) - - listener = AgentSpecEventListener() - span_recorder = SnapshotSpanRecorder() - agui_exporter = AGUIStateSnapshotExporter() - - with patch_llm(wayflow_agent.llm, [assistant_message], patch_internal=True): - with AgentSpecTrace(span_processors=[span_recorder, agui_exporter]): - with register_event_listeners([listener]): - status = conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - extra_state_builder=_build_retrieval_like_extra_state, - ) - ) - - assert isinstance(status, UserMessageRequestStatus) - - agent_spans = [ - span - for span in span_recorder.started_spans - if isinstance(span, AgentSpecAgentExecutionSpan) - ] - assert len(agent_spans) == 1 - agent_span = agent_spans[0] - agent_events = agent_span.events - - assert any(isinstance(event, AgentSpecAgentExecutionStart) for event in agent_events) - agent_end_event = next( - event for event in agent_events if isinstance(event, AgentSpecAgentExecutionEnd) - ) - state_snapshot_events = [ - event for event in agent_events if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - assert len(state_snapshot_events) == 2 - - final_snapshot_event = state_snapshot_events[-1] - runtime_messages = final_snapshot_event.state_snapshot["conversation"]["messages"] - expected_agent_state = _build_retrieval_agent_state( - conversation_inputs=_RETRIEVAL_INPUTS, - message_count=len(runtime_messages), - last_response=assistant_message, - ) - - # This retrieval example is the main product use-case: a downstream 
AG-UI - # style exporter should be able to reconstruct the latest UI-facing state - # directly from the final snapshot event on the agent execution span. - assert final_snapshot_event.conversation_id == conversation.conversation_id - assert ( - final_snapshot_event.state_snapshot["conversation"]["inputs"]["input"] - == _RETRIEVAL_INPUTS["input"] - ) - assert runtime_messages[-1]["content"] == assistant_message - assert final_snapshot_event.extra_state == {"agent_state": expected_agent_state} - assert agent_events.index(final_snapshot_event) < agent_events.index(agent_end_event) - - assert len(agui_exporter.exported_snapshots) == 2 - assert agui_exporter.exported_snapshots[-1] == ExportedAGUIStateSnapshot( - conversation_id=conversation.conversation_id, - snapshot={ - "messages": runtime_messages, - "input": _RETRIEVAL_INPUTS["input"], - "agent_state": expected_agent_state, - }, - ) - - -@pytest.mark.parametrize( - "flow_builder", - [ - pytest.param(_build_swarm_state_snapshot_flow, id="swarm"), - ], -) -def test_nested_multi_agent_state_snapshots_follow_conversation_ownership_boundaries( - flow_builder, -) -> None: - ( - flow, - primary_llm, - primary_outputs, - secondary_llm, - secondary_outputs, - expected_multi_agent_span_class, - expected_child_message, - expected_parent_message, - expected_multi_agent_end_event_class, - ) = flow_builder() - conversation = flow.start_conversation() - conversation.append_user_message("dummy") - - listener = AgentSpecEventListener() - span_recorder = SnapshotSpanRecorder() - - with patch_llm(primary_llm, primary_outputs, patch_internal=True): - with patch_llm(secondary_llm, secondary_outputs, patch_internal=True): - with AgentSpecTrace(span_processors=[span_recorder]): - with register_event_listeners([listener]): - status = conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ) - ) - - assert isinstance(status, UserMessageRequestStatus) - - 
flow_spans = [ - span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) - ] - assert len(flow_spans) == 1 - flow_span = flow_spans[0] - flow_snapshot_events = [ - event for event in flow_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - assert len(flow_snapshot_events) == 2 - assert [event.conversation_id for event in flow_snapshot_events] == [ - conversation.conversation_id, - conversation.conversation_id, - ] - assert ( - flow_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] - == expected_parent_message - ) - - multi_agent_spans = [ - span - for span in span_recorder.started_spans - if isinstance(span, expected_multi_agent_span_class) - ] - assert len(multi_agent_spans) == 1 - multi_agent_span = multi_agent_spans[0] - multi_agent_snapshot_events = [ - event - for event in multi_agent_span.events - if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - multi_agent_end_event = next( - event - for event in multi_agent_span.events - if isinstance(event, expected_multi_agent_end_event_class) - ) - parent_multi_agent_conversation_id = multi_agent_snapshot_events[0].conversation_id - - # The parent multi-agent conversation brackets both child turns. It keeps a - # single conversation id while the manager/main-thread agent and the - # delegated child each emit snapshots on their own agent execution spans. 
- assert [event.conversation_id for event in multi_agent_snapshot_events] == [ - parent_multi_agent_conversation_id, - parent_multi_agent_conversation_id, - parent_multi_agent_conversation_id, - parent_multi_agent_conversation_id, - parent_multi_agent_conversation_id, - parent_multi_agent_conversation_id, - ] - assert [ - ( - event.state_snapshot["execution"]["status"]["type"] - if event.state_snapshot["execution"]["status"] is not None - else None - ) - for event in multi_agent_snapshot_events - ] == [ - None, - "ToolRequestStatus", - None, - "UserMessageRequestStatus", - None, - "UserMessageRequestStatus", - ] - assert ( - multi_agent_snapshot_events[4].state_snapshot["conversation"]["messages"][-1]["content"] - == expected_child_message - ) - assert ( - multi_agent_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] - == expected_parent_message - ) - assert multi_agent_span.events.index( - multi_agent_snapshot_events[-1] - ) < multi_agent_span.events.index(multi_agent_end_event) - - agent_snapshot_spans = [ - span - for span in span_recorder.started_spans - if isinstance(span, AgentSpecAgentExecutionSpan) - and any(isinstance(event, AgentSpecStateSnapshotEmitted) for event in span.events) - ] - assert len(agent_snapshot_spans) == 3 - agent_snapshot_events_by_conversation_id: dict[str, list[AgentSpecStateSnapshotEmitted]] = {} - for agent_span in agent_snapshot_spans: - snapshot_events = [ - event for event in agent_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - agent_snapshot_events_by_conversation_id.setdefault( - snapshot_events[0].conversation_id, - [], - ).extend(snapshot_events) - - assert len(agent_snapshot_events_by_conversation_id) == 2 - manager_thread_snapshot_events = next( - snapshot_events - for snapshot_events in agent_snapshot_events_by_conversation_id.values() - if len(snapshot_events) == 4 - ) - delegated_agent_snapshot_events = next( - snapshot_events - for snapshot_events in 
agent_snapshot_events_by_conversation_id.values() - if len(snapshot_events) == 2 - ) - - assert manager_thread_snapshot_events[0].conversation_id != conversation.conversation_id - assert manager_thread_snapshot_events[0].conversation_id != parent_multi_agent_conversation_id - assert delegated_agent_snapshot_events[0].conversation_id not in { - conversation.conversation_id, - parent_multi_agent_conversation_id, - manager_thread_snapshot_events[0].conversation_id, - } - assert [ - ( - event.state_snapshot["execution"]["status"]["type"] - if event.state_snapshot["execution"]["status"] is not None - else None - ) - for event in manager_thread_snapshot_events - ] == [ - None, - "ToolRequestStatus", - None, - "UserMessageRequestStatus", - ] - assert ( - manager_thread_snapshot_events[2].state_snapshot["conversation"]["messages"][-1]["content"] - == expected_child_message - ) - assert ( - manager_thread_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] - == expected_parent_message - ) - assert [ - ( - event.state_snapshot["execution"]["status"]["type"] - if event.state_snapshot["execution"]["status"] is not None - else None - ) - for event in delegated_agent_snapshot_events - ] == [None, "UserMessageRequestStatus"] - assert ( - delegated_agent_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1][ - "content" - ] - == expected_child_message - ) - - tool_spans = [ - span for span in span_recorder.started_spans if isinstance(span, AgentSpecToolExecutionSpan) - ] - assert tool_spans - assert not any( - isinstance(event, AgentSpecStateSnapshotEmitted) - for span in tool_spans - for event in span.events - ) - - assert flow_span in span_recorder.ended_spans - assert multi_agent_span in span_recorder.ended_spans - - -def test_only_the_opening_state_snapshot_is_exported_when_a_turn_raises() -> None: - flow = Flow.from_steps( - [ - ToolExecutionStep( - tool=ServerTool( - name="explode", - description="Raise an error", - func=lambda: (_ for 
_ in ()).throw(RuntimeError("boom")), - input_descriptors=[], - ) - ), - CompleteStep(name="end"), - ] - ) - conversation = flow.start_conversation() - listener = AgentSpecEventListener() - span_recorder = SnapshotSpanRecorder() - - with AgentSpecTrace(span_processors=[span_recorder]): - with register_event_listeners([listener]): - with pytest.raises(RuntimeError, match="boom"): - conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ) - ) - - flow_spans = [ - span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) - ] - assert len(flow_spans) == 1 - flow_span = flow_spans[0] - state_snapshot_events = [ - event for event in flow_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - - assert len(state_snapshot_events) == 1 - assert state_snapshot_events[0].state_snapshot["execution"]["status"] is None - assert flow_span in span_recorder.ended_spans diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py new file mode 100644 index 000000000..bc0c865d1 --- /dev/null +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py @@ -0,0 +1,396 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +from contextlib import AbstractContextManager, ExitStack +from dataclasses import asdict, dataclass +from typing import Any, Sequence, cast + +from pyagentspec.adapters.wayflow import AgentSpecLoader +from pyagentspec.agent import Agent as AgentSpecAgent +from pyagentspec.llms import VllmConfig +from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd +from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart +from pyagentspec.tracing.events import Event as AgentSpecEvent +from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted +from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor +from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan +from pyagentspec.tracing.spans import LlmGenerationSpan as AgentSpecLlmGenerationSpan +from pyagentspec.tracing.spans import Span as AgentSpecSpan +from pyagentspec.tracing.trace import Trace as AgentSpecTrace + +from wayflowcore import Agent as WayflowAgent +from wayflowcore.agentspec.tracing import AgentSpecEventListener +from wayflowcore.events.eventlistener import register_event_listeners +from wayflowcore.executors.executionstatus import UserMessageRequestStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.models.vllmmodel import VllmModel +from wayflowcore.serialization import dump_conversation_state + +from ..testhelpers.patching import patch_llm +from ..testhelpers.statesnapshots import ( + build_state_snapshot_policy, + snapshot_message, + snapshot_status_types, +) + + +@dataclass(frozen=True) +class ExportedAGUIStateSnapshot: + conversation_id: str + snapshot: dict[str, Any] + + +@dataclass(frozen=True) +class RetrievalPreplan: + summary: str + entries: list[str] + ready_to_proceed: bool + + +@dataclass(frozen=True) +class RetrievalAssumption: + text: str + status: str + + +@dataclass(frozen=True) +class 
RetrievalUIState: + preplan: RetrievalPreplan + assumptions: list[RetrievalAssumption] + + +@dataclass(frozen=True) +class RetrievalAgentState: + thread_id: str + agent_type: str + llm_model_name: str + default_schema: str + input_document: str + message_count: int + last_response: str + ui: RetrievalUIState + + +class AGUIStateSnapshotExporter(AgentSpecSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.exported_snapshots: list[ExportedAGUIStateSnapshot] = [] + + def on_start(self, span: AgentSpecSpan) -> None: + return None + + async def on_start_async(self, span: AgentSpecSpan) -> None: + return None + + def on_end(self, span: AgentSpecSpan) -> None: + return None + + async def on_end_async(self, span: AgentSpecSpan) -> None: + return None + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + if not isinstance(event, AgentSpecStateSnapshotEmitted): + return + + conversation_snapshot = (event.state_snapshot or {}).get("conversation", {}) + self.exported_snapshots.append( + ExportedAGUIStateSnapshot( + conversation_id=event.conversation_id, + snapshot={ + "messages": conversation_snapshot.get("messages", []), + "input": conversation_snapshot.get("inputs", {}).get("input"), + "agent_state": (event.extra_state or {}).get("agent_state", {}), + }, + ) + ) + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + self.on_event(event, span) + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + +class SnapshotSpanRecorder(AgentSpecSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.started_spans: list[AgentSpecSpan] = [] + self.ended_spans: list[AgentSpecSpan] = [] + + def on_start(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + async def on_start_async(self, span: AgentSpecSpan) -> None: + 
self.started_spans.append(span) + + def on_end(self, span: AgentSpecSpan) -> None: + self.ended_spans.append(span) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + self.ended_spans.append(span) + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + +_RETRIEVAL_INPUTS = { + "input": "How many orders last week?", + "thread_id": "thread-123", + "agent_type": "planner", + "llm_model_name": "gpt-5-mini", + "default_schema": "sales", + "input_document": "Only use the sales schema and weekly order metrics.", +} + +_RETRIEVAL_UI_STATE = RetrievalUIState( + preplan=RetrievalPreplan( + summary="Inspect weekly sales orders and answer concisely.", + entries=[ + "Inspect the active schema", + "Aggregate last week's orders", + "Return the final answer", + ], + ready_to_proceed=True, + ), + assumptions=[ + RetrievalAssumption(text="Use the sales schema only", status="approved"), + RetrievalAssumption(text="Week boundaries follow UTC", status="auto_approved"), + ], +) + + +def _create_retrieval_like_wayflow_agent() -> WayflowAgent: + agentspec_agent = AgentSpecAgent( + name="retrieval_agent", + llm_config=VllmConfig(name="llm", url="http://mock.url", model_id="mock.model"), + system_prompt="You are a helpful retrieval agent.", + ) + return cast(WayflowAgent, AgentSpecLoader().load_component(agentspec_agent)) + + +def _build_retrieval_agent_state( + *, + conversation_inputs: dict[str, Any], + message_count: int, + last_response: str, +) -> RetrievalAgentState: + return RetrievalAgentState( + thread_id=conversation_inputs["thread_id"], + agent_type=conversation_inputs["agent_type"], + 
llm_model_name=conversation_inputs["llm_model_name"], + default_schema=conversation_inputs["default_schema"], + input_document=conversation_inputs["input_document"], + message_count=message_count, + last_response=last_response, + ui=_RETRIEVAL_UI_STATE, + ) + + +def _build_retrieval_like_extra_state(conversation) -> dict[str, Any]: + conversation_snapshot = dump_conversation_state(conversation)["conversation"] + messages = conversation_snapshot["messages"] + last_response = next( + ( + message.get("content") + for message in reversed(messages) + if message.get("role") == "assistant" and message.get("content") + ), + "", + ) + return { + "agent_state": asdict( + _build_retrieval_agent_state( + conversation_inputs=conversation.inputs, + message_count=len(messages), + last_response=last_response, + ) + ) + } + + +def _create_mock_vllm_model(name: str) -> VllmModel: + return VllmModel(model_id="mock.model", host_port="http://mock.url", name=name) + + +def _policy( + interval: StateSnapshotInterval, + **kwargs: Any, +): + return build_state_snapshot_policy(interval, **kwargs) + + +def _execute_with_trace( + conversation, + *, + state_snapshot_policy, + span_processors: Sequence[AgentSpecSpanProcessor] = (), + contexts: Sequence[AbstractContextManager[Any]] = (), +): + span_recorder = SnapshotSpanRecorder() + listener = AgentSpecEventListener() + + with ExitStack() as stack: + for context in contexts: + stack.enter_context(context) + stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, *span_processors])) + stack.enter_context(register_event_listeners([listener])) + status = conversation.execute(state_snapshot_policy=state_snapshot_policy) + + return status, span_recorder + + +def _spans( + span_recorder: SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> list[AgentSpecSpan]: + return [span for span in span_recorder.started_spans if isinstance(span, span_type)] + + +def _single_span( + span_recorder: SnapshotSpanRecorder, + span_type: 
type[AgentSpecSpan], +) -> AgentSpecSpan: + matching_spans = _spans(span_recorder, span_type) + assert len(matching_spans) == 1 + return matching_spans[0] + + +def _events( + span: AgentSpecSpan, + event_type: type[AgentSpecEvent], +) -> list[AgentSpecEvent]: + return [event for event in span.events if isinstance(event, event_type)] + + +def _single_event( + span: AgentSpecSpan, + event_type: type[AgentSpecEvent], +) -> AgentSpecEvent: + return next(event for event in span.events if isinstance(event, event_type)) + + +def _assert_snapshot_precedes_terminal_event( + span: AgentSpecSpan, + snapshot_events: Sequence[AgentSpecStateSnapshotEmitted], + terminal_event: AgentSpecEvent, +) -> None: + assert span.events.index(snapshot_events[-1]) < span.events.index(terminal_event) + + +def test_agent_state_snapshots_support_the_agui_retrieval_export_flow() -> None: + assistant_message = "I checked the warehouse and found 42 orders last week." + wayflow_agent = _create_retrieval_like_wayflow_agent() + conversation = wayflow_agent.start_conversation(inputs=_RETRIEVAL_INPUTS) + conversation.append_user_message(_RETRIEVAL_INPUTS["input"]) + + agui_exporter = AGUIStateSnapshotExporter() + + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=_policy( + StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=_build_retrieval_like_extra_state, + ), + span_processors=[agui_exporter], + contexts=[patch_llm(wayflow_agent.llm, [assistant_message], patch_internal=True)], + ) + + assert isinstance(status, UserMessageRequestStatus) + + agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + assert _events(agent_span, AgentSpecAgentExecutionStart) + agent_end_event = _single_event(agent_span, AgentSpecAgentExecutionEnd) + state_snapshot_events = _events(agent_span, AgentSpecStateSnapshotEmitted) + assert len(state_snapshot_events) == 2 + + final_snapshot_event = state_snapshot_events[-1] + runtime_messages = 
final_snapshot_event.state_snapshot["conversation"]["messages"] + expected_agent_state = asdict( + _build_retrieval_agent_state( + conversation_inputs=_RETRIEVAL_INPUTS, + message_count=len(runtime_messages), + last_response=assistant_message, + ) + ) + + assert final_snapshot_event.conversation_id == conversation.conversation_id + assert ( + final_snapshot_event.state_snapshot["conversation"]["inputs"]["input"] + == _RETRIEVAL_INPUTS["input"] + ) + assert runtime_messages[-1]["content"] == assistant_message + assert final_snapshot_event.extra_state == {"agent_state": expected_agent_state} + _assert_snapshot_precedes_terminal_event(agent_span, state_snapshot_events, agent_end_event) + + assert len(agui_exporter.exported_snapshots) == 2 + assert agui_exporter.exported_snapshots[-1] == ExportedAGUIStateSnapshot( + conversation_id=conversation.conversation_id, + snapshot={ + "messages": runtime_messages, + "input": _RETRIEVAL_INPUTS["input"], + "agent_state": expected_agent_state, + }, + ) + + +def test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_spans() -> None: + assistant_message = "Hello from agent" + llm = _create_mock_vllm_model("agent") + agent = WayflowAgent(llm=llm) + conversation = agent.start_conversation() + conversation.append_user_message("Hi") + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=_policy(StateSnapshotInterval.NODE_TURNS), + contexts=[patch_llm(llm, [assistant_message], patch_internal=True)], + ) + + assert isinstance(status, UserMessageRequestStatus) + + agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + agent_end_event = _single_event(agent_span, AgentSpecAgentExecutionEnd) + state_snapshot_events = _events(agent_span, AgentSpecStateSnapshotEmitted) + + assert len(state_snapshot_events) == 2 + assert [event.state_snapshot["execution"]["curr_iter"] for event in state_snapshot_events] == [ + 0, + 1, + ] + assert snapshot_status_types(state_snapshot_events) 
== [None, None] + assert snapshot_message(state_snapshot_events[-1]) == assistant_message + _assert_snapshot_precedes_terminal_event(agent_span, state_snapshot_events, agent_end_event) + + llm_spans = _spans(span_recorder, AgentSpecLlmGenerationSpan) + assert llm_spans + assert not any( + isinstance(event, AgentSpecStateSnapshotEmitted) + for span in llm_spans + for event in span.events + ) diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py new file mode 100644 index 000000000..b8c15da51 --- /dev/null +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py @@ -0,0 +1,381 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Sequence + +import pytest +from pyagentspec.tracing.events import Event as AgentSpecEvent +from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd +from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart +from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted +from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor +from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan +from pyagentspec.tracing.spans import NodeExecutionSpan as AgentSpecNodeExecutionSpan +from pyagentspec.tracing.spans import Span as AgentSpecSpan +from pyagentspec.tracing.spans import ToolExecutionSpan as AgentSpecToolExecutionSpan +from pyagentspec.tracing.trace import Trace as AgentSpecTrace + +from wayflowcore.agentspec.tracing import AgentSpecEventListener +from wayflowcore.events.eventlistener 
import register_event_listeners +from wayflowcore.executors.executionstatus import FinishedStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.flow import Flow +from wayflowcore.steps import CompleteStep, OutputMessageStep, ToolExecutionStep +from wayflowcore.tools import ServerTool + +from ..testhelpers.statesnapshots import ( + build_state_snapshot_policy, + snapshot_message, + snapshot_step_histories, +) + + +class SnapshotSpanRecorder(AgentSpecSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.started_spans: list[AgentSpecSpan] = [] + self.ended_spans: list[AgentSpecSpan] = [] + + def on_start(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + async def on_start_async(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + def on_end(self, span: AgentSpecSpan) -> None: + self.ended_spans.append(span) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + self.ended_spans.append(span) + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + +def _policy( + interval: StateSnapshotInterval, + **kwargs: Any, +): + return build_state_snapshot_policy(interval, **kwargs) + + +def _execute_with_trace( + conversation, + *, + state_snapshot_policy, + span_processors: Sequence[AgentSpecSpanProcessor] = (), + contexts: Sequence[AbstractContextManager[Any]] = (), +): + span_recorder = SnapshotSpanRecorder() + listener = AgentSpecEventListener() + + with ExitStack() as stack: + for context in contexts: + stack.enter_context(context) + stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, 
*span_processors])) + stack.enter_context(register_event_listeners([listener])) + status = conversation.execute(state_snapshot_policy=state_snapshot_policy) + + return status, span_recorder + + +async def _execute_with_trace_async( + conversation, + *, + state_snapshot_policy, + span_processors: Sequence[AgentSpecSpanProcessor] = (), + contexts: Sequence[AbstractContextManager[Any]] = (), +): + span_recorder = SnapshotSpanRecorder() + listener = AgentSpecEventListener() + + async with AgentSpecTrace(span_processors=[span_recorder, *span_processors]): + with ExitStack() as stack: + for context in contexts: + stack.enter_context(context) + stack.enter_context(register_event_listeners([listener])) + status = await conversation.execute_async(state_snapshot_policy=state_snapshot_policy) + + return status, span_recorder + + +def _spans( + span_recorder: SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> list[AgentSpecSpan]: + return [span for span in span_recorder.started_spans if isinstance(span, span_type)] + + +def _single_span( + span_recorder: SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> AgentSpecSpan: + matching_spans = _spans(span_recorder, span_type) + assert len(matching_spans) == 1 + return matching_spans[0] + + +def _events( + span: AgentSpecSpan, + event_type: type[AgentSpecEvent], +) -> list[AgentSpecEvent]: + return [event for event in span.events if isinstance(event, event_type)] + + +def _single_event( + span: AgentSpecSpan, + event_type: type[AgentSpecEvent], +) -> AgentSpecEvent: + return next(event for event in span.events if isinstance(event, event_type)) + + +def _assert_snapshot_precedes_terminal_event( + span: AgentSpecSpan, + snapshot_events: Sequence[AgentSpecStateSnapshotEmitted], + terminal_event: AgentSpecEvent, +) -> None: + assert span.events.index(snapshot_events[-1]) < span.events.index(terminal_event) + + +def _build_tool_state_snapshot_flow() -> Flow: + return Flow.from_steps( + [ + ToolExecutionStep( + 
tool=ServerTool( + name="say_hi", + description="Say hi", + func=lambda: "hi", + input_descriptors=[], + ) + ), + CompleteStep(name="end"), + ] + ) + + +def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> None: + flow = Flow.from_steps( + [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], + step_names=["single_step", "end"], + ) + conversation = flow.start_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=_policy( + StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, + ), + ) + + assert isinstance(status, FinishedStatus) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + assert _events(flow_span, AgentSpecFlowExecutionStart) + flow_end_event = _single_event(flow_span, AgentSpecFlowExecutionEnd) + state_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) + + assert len(state_snapshot_events) == 2 + final_snapshot_event = state_snapshot_events[-1] + assert final_snapshot_event.conversation_id == conversation.conversation_id + assert snapshot_message(final_snapshot_event) == "Hello" + assert final_snapshot_event.extra_state == {"ui": {"active_tab": "plan"}} + _assert_snapshot_precedes_terminal_event(flow_span, state_snapshot_events, flow_end_event) + assert flow_span.end_time is not None + assert final_snapshot_event.timestamp <= flow_span.end_time + assert "variable_state" not in final_snapshot_event.model_dump(mask_sensitive_information=False) + assert flow_span in span_recorder.ended_spans + + +@pytest.mark.anyio +async def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end_async() -> None: + flow = Flow.from_steps( + [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], + step_names=["single_step", "end"], + ) + conversation = flow.start_conversation() + status, span_recorder = await _execute_with_trace_async( + 
conversation, + state_snapshot_policy=_policy( + StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, + ), + ) + + assert isinstance(status, FinishedStatus) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + flow_end_event = _single_event(flow_span, AgentSpecFlowExecutionEnd) + state_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) + + assert len(state_snapshot_events) == 2 + final_snapshot_event = state_snapshot_events[-1] + assert final_snapshot_event.conversation_id == conversation.conversation_id + assert snapshot_message(final_snapshot_event) == "Hello" + assert final_snapshot_event.extra_state == {"ui": {"active_tab": "plan"}} + _assert_snapshot_precedes_terminal_event(flow_span, state_snapshot_events, flow_end_event) + assert flow_span.end_time is not None + assert final_snapshot_event.timestamp <= flow_span.end_time + assert "variable_state" not in final_snapshot_event.model_dump(mask_sensitive_information=False) + assert flow_span in span_recorder.ended_spans + + +def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans() -> None: + flow = Flow.from_steps( + [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], + step_names=["single_step", "end"], + ) + conversation = flow.start_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=_policy(StateSnapshotInterval.NODE_TURNS), + ) + + assert isinstance(status, FinishedStatus) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + flow_end_event = _single_event(flow_span, AgentSpecFlowExecutionEnd) + flow_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) + + assert len(flow_snapshot_events) == 6 + assert snapshot_step_histories(flow_snapshot_events) == [ + [], + ["__StartStep__"], + ["__StartStep__"], + ["__StartStep__", "single_step"], + ["__StartStep__", "single_step"], + 
["__StartStep__", "single_step", "end"], + ] + _assert_snapshot_precedes_terminal_event(flow_span, flow_snapshot_events, flow_end_event) + + node_spans = _spans(span_recorder, AgentSpecNodeExecutionSpan) + assert node_spans + assert not any( + isinstance(event, AgentSpecStateSnapshotEmitted) + for span in node_spans + for event in span.events + ) + + +@pytest.mark.parametrize( + ("interval", "expected_step_histories"), + [ + pytest.param( + StateSnapshotInterval.TOOL_TURNS, + [ + ["__StartStep__", "step_0"], + ["__StartStep__", "step_0"], + ], + id="tool_turns", + ), + pytest.param( + StateSnapshotInterval.ALL_INTERNAL_TURNS, + [ + [], + ["__StartStep__"], + ["__StartStep__"], + ["__StartStep__", "step_0"], + ["__StartStep__", "step_0"], + ["__StartStep__", "step_0"], + ["__StartStep__", "step_0"], + ["__StartStep__", "step_0", "end"], + ], + id="all_internal_turns", + ), + ], +) +def test_internal_flow_state_snapshots_follow_conversation_ownership_for_agent_spec( + interval: StateSnapshotInterval, + expected_step_histories: list[list[str]], +) -> None: + flow = _build_tool_state_snapshot_flow() + conversation = flow.start_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=_policy(interval), + ) + + assert isinstance(status, FinishedStatus) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + flow_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) + assert snapshot_step_histories(flow_snapshot_events) == expected_step_histories + + tool_spans = _spans(span_recorder, AgentSpecToolExecutionSpan) + node_spans = _spans(span_recorder, AgentSpecNodeExecutionSpan) + assert tool_spans + assert node_spans + assert not any( + isinstance(event, AgentSpecStateSnapshotEmitted) + for span in [*tool_spans, *node_spans] + for event in span.events + ) + + +def test_off_policy_does_not_bridge_state_snapshots_into_agent_spec_spans() -> None: + flow = Flow.from_steps( + 
[OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], + step_names=["single_step", "end"], + ) + conversation = flow.start_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=_policy(StateSnapshotInterval.OFF), + ) + + assert isinstance(status, FinishedStatus) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + assert _events(flow_span, AgentSpecFlowExecutionStart) + assert _events(flow_span, AgentSpecFlowExecutionEnd) + assert not _events(flow_span, AgentSpecStateSnapshotEmitted) + + +def test_only_the_opening_state_snapshot_is_exported_when_a_turn_raises() -> None: + flow = Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name="explode", + description="Raise an error", + func=lambda: (_ for _ in ()).throw(RuntimeError("boom")), + input_descriptors=[], + ) + ), + CompleteStep(name="end"), + ] + ) + conversation = flow.start_conversation() + span_recorder = SnapshotSpanRecorder() + + with AgentSpecTrace(span_processors=[span_recorder]): + with register_event_listeners([AgentSpecEventListener()]): + with pytest.raises(RuntimeError, match="boom"): + conversation.execute( + state_snapshot_policy=_policy(StateSnapshotInterval.CONVERSATION_TURNS) + ) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + state_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) + + assert len(state_snapshot_events) == 1 + assert state_snapshot_events[0].state_snapshot["execution"]["status"] is None + assert flow_span in span_recorder.ended_spans diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py new file mode 100644 index 000000000..58434b91f --- /dev/null +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py @@ -0,0 +1,381 @@ +# Copyright © 2026 Oracle and/or its affiliates. 
+# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from contextlib import AbstractContextManager, ExitStack +from dataclasses import dataclass +from typing import Any, Sequence + +import pytest +from pyagentspec.tracing.events import Event as AgentSpecEvent +from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd +from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart +from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted +from pyagentspec.tracing.events import SwarmExecutionEnd as AgentSpecSwarmExecutionEnd +from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor +from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan +from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan +from pyagentspec.tracing.spans import Span as AgentSpecSpan +from pyagentspec.tracing.spans import SwarmExecutionSpan as AgentSpecSwarmExecutionSpan +from pyagentspec.tracing.spans import ToolExecutionSpan as AgentSpecToolExecutionSpan +from pyagentspec.tracing.trace import Trace as AgentSpecTrace + +from wayflowcore import Agent as WayflowAgent +from wayflowcore.agentspec.tracing import AgentSpecEventListener +from wayflowcore.events.eventlistener import register_event_listeners +from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.flow import Flow +from wayflowcore.messagelist import Message, MessageType +from wayflowcore.models.vllmmodel import VllmModel +from wayflowcore.steps import AgentExecutionStep, CompleteStep, FlowExecutionStep, OutputMessageStep +from wayflowcore.swarm import 
Swarm +from wayflowcore.tools import ToolRequest + +from ..testhelpers.patching import patch_llm +from ..testhelpers.statesnapshots import ( + build_state_snapshot_policy, + snapshot_message, + snapshot_runtime_conversation_ids, + snapshot_status_types, +) + + +@dataclass(frozen=True) +class SwarmStateSnapshotScenario: + flow: Flow + primary_llm: VllmModel + primary_outputs: list[Message | str] + secondary_llm: VllmModel + secondary_outputs: list[Message | str] + multi_agent_span_class: type[AgentSpecSpan] + child_message: str + parent_message: str + multi_agent_end_event_class: type[AgentSpecEvent] + + +class SnapshotSpanRecorder(AgentSpecSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.started_spans: list[AgentSpecSpan] = [] + self.ended_spans: list[AgentSpecSpan] = [] + + def on_start(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + async def on_start_async(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + def on_end(self, span: AgentSpecSpan) -> None: + self.ended_spans.append(span) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + self.ended_spans.append(span) + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + +def _create_mock_vllm_model(name: str) -> VllmModel: + return VllmModel(model_id="mock.model", host_port="http://mock.url", name=name) + + +def _create_send_message_request(recipient_name: str, message: str) -> Message: + return Message( + content="", + message_type=MessageType.TOOL_REQUEST, + tool_requests=[ + ToolRequest( + name="send_message", + args={"recipient": recipient_name, "message": message}, + ) + ], + ) + + 
+def _build_swarm_state_snapshot_flow() -> SwarmStateSnapshotScenario: + first_agent_llm = _create_mock_vllm_model("agent1") + second_agent_llm = _create_mock_vllm_model("agent2") + first_agent = WayflowAgent(llm=first_agent_llm, name="agent1", description="agent1") + second_agent = WayflowAgent(llm=second_agent_llm, name="agent2", description="agent2") + swarm = Swarm( + first_agent=first_agent, + relationships=[(first_agent, second_agent), (second_agent, first_agent)], + name="swarm", + ) + + return SwarmStateSnapshotScenario( + flow=Flow.from_steps([AgentExecutionStep(agent=swarm), CompleteStep(name="end")]), + primary_llm=first_agent_llm, + primary_outputs=[ + _create_send_message_request("agent2", "Do it"), + "agent1 final answer", + ], + secondary_llm=second_agent_llm, + secondary_outputs=["agent2 answer"], + multi_agent_span_class=AgentSpecSwarmExecutionSpan, + child_message="agent2 answer", + parent_message="agent1 final answer", + multi_agent_end_event_class=AgentSpecSwarmExecutionEnd, + ) + + +def _policy( + interval: StateSnapshotInterval, + **kwargs: Any, +): + return build_state_snapshot_policy(interval, **kwargs) + + +def _execute_with_trace( + conversation, + *, + state_snapshot_policy, + span_processors: Sequence[AgentSpecSpanProcessor] = (), + contexts: Sequence[AbstractContextManager[Any]] = (), +): + span_recorder = SnapshotSpanRecorder() + listener = AgentSpecEventListener() + + with ExitStack() as stack: + for context in contexts: + stack.enter_context(context) + stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, *span_processors])) + stack.enter_context(register_event_listeners([listener])) + status = conversation.execute(state_snapshot_policy=state_snapshot_policy) + + return status, span_recorder + + +def _spans( + span_recorder: SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> list[AgentSpecSpan]: + return [span for span in span_recorder.started_spans if isinstance(span, span_type)] + + +def _single_span( + 
span_recorder: SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> AgentSpecSpan: + matching_spans = _spans(span_recorder, span_type) + assert len(matching_spans) == 1 + return matching_spans[0] + + +def _events( + span: AgentSpecSpan, + event_type: type[AgentSpecEvent], +) -> list[AgentSpecEvent]: + return [event for event in span.events if isinstance(event, event_type)] + + +def _single_event( + span: AgentSpecSpan, + event_type: type[AgentSpecEvent], +) -> AgentSpecEvent: + return next(event for event in span.events if isinstance(event, event_type)) + + +def _assert_snapshot_precedes_terminal_event( + span: AgentSpecSpan, + snapshot_events: Sequence[AgentSpecStateSnapshotEmitted], + terminal_event: AgentSpecEvent, +) -> None: + assert span.events.index(snapshot_events[-1]) < span.events.index(terminal_event) + + +@pytest.mark.parametrize( + "flow_builder", + [ + pytest.param(_build_swarm_state_snapshot_flow, id="swarm"), + ], +) +def test_nested_multi_agent_state_snapshots_follow_conversation_ownership_boundaries( + flow_builder, +) -> None: + scenario = flow_builder() + conversation = scenario.flow.start_conversation() + conversation.append_user_message("dummy") + + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=_policy(StateSnapshotInterval.CONVERSATION_TURNS), + contexts=[ + patch_llm(scenario.primary_llm, scenario.primary_outputs, patch_internal=True), + patch_llm(scenario.secondary_llm, scenario.secondary_outputs, patch_internal=True), + ], + ) + + assert isinstance(status, UserMessageRequestStatus) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + flow_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) + assert len(flow_snapshot_events) == 2 + assert [event.conversation_id for event in flow_snapshot_events] == [ + conversation.conversation_id, + conversation.conversation_id, + ] + assert snapshot_message(flow_snapshot_events[-1]) == scenario.parent_message + + 
multi_agent_span = _single_span(span_recorder, scenario.multi_agent_span_class) + multi_agent_snapshot_events = _events(multi_agent_span, AgentSpecStateSnapshotEmitted) + multi_agent_end_event = _single_event(multi_agent_span, scenario.multi_agent_end_event_class) + parent_multi_agent_conversation_id = multi_agent_snapshot_events[0].conversation_id + + assert [event.conversation_id for event in multi_agent_snapshot_events] == [ + parent_multi_agent_conversation_id, + parent_multi_agent_conversation_id, + parent_multi_agent_conversation_id, + parent_multi_agent_conversation_id, + parent_multi_agent_conversation_id, + parent_multi_agent_conversation_id, + ] + assert snapshot_status_types(multi_agent_snapshot_events) == [ + None, + "ToolRequestStatus", + None, + "UserMessageRequestStatus", + None, + "UserMessageRequestStatus", + ] + assert snapshot_message(multi_agent_snapshot_events[4]) == scenario.child_message + assert snapshot_message(multi_agent_snapshot_events[-1]) == scenario.parent_message + _assert_snapshot_precedes_terminal_event( + multi_agent_span, + multi_agent_snapshot_events, + multi_agent_end_event, + ) + + agent_snapshot_spans = [ + span + for span in span_recorder.started_spans + if isinstance(span, AgentSpecAgentExecutionSpan) + and any(isinstance(event, AgentSpecStateSnapshotEmitted) for event in span.events) + ] + assert len(agent_snapshot_spans) == 3 + agent_snapshot_events_by_conversation_id: dict[str, list] = {} + for agent_span in agent_snapshot_spans: + snapshot_events = _events(agent_span, AgentSpecStateSnapshotEmitted) + agent_snapshot_events_by_conversation_id.setdefault( + snapshot_events[0].conversation_id, + [], + ).extend(snapshot_events) + + assert len(agent_snapshot_events_by_conversation_id) == 2 + manager_thread_snapshot_events = next( + snapshot_events + for snapshot_events in agent_snapshot_events_by_conversation_id.values() + if len(snapshot_events) == 4 + ) + delegated_agent_snapshot_events = next( + snapshot_events + for 
snapshot_events in agent_snapshot_events_by_conversation_id.values() + if len(snapshot_events) == 2 + ) + + assert manager_thread_snapshot_events[0].conversation_id != conversation.conversation_id + assert manager_thread_snapshot_events[0].conversation_id != parent_multi_agent_conversation_id + assert delegated_agent_snapshot_events[0].conversation_id not in { + conversation.conversation_id, + parent_multi_agent_conversation_id, + manager_thread_snapshot_events[0].conversation_id, + } + assert snapshot_status_types(manager_thread_snapshot_events) == [ + None, + "ToolRequestStatus", + None, + "UserMessageRequestStatus", + ] + assert snapshot_message(manager_thread_snapshot_events[2]) == scenario.child_message + assert snapshot_message(manager_thread_snapshot_events[-1]) == scenario.parent_message + assert snapshot_status_types(delegated_agent_snapshot_events) == [ + None, + "UserMessageRequestStatus", + ] + assert snapshot_message(delegated_agent_snapshot_events[-1]) == scenario.child_message + + tool_spans = _spans(span_recorder, AgentSpecToolExecutionSpan) + assert tool_spans + assert not any( + isinstance(event, AgentSpecStateSnapshotEmitted) + for span in tool_spans + for event in span.events + ) + + assert flow_span in span_recorder.ended_spans + assert multi_agent_span in span_recorder.ended_spans + + +def test_nested_flow_state_snapshots_stay_on_the_root_flow_span_for_shared_conversations() -> None: + child_flow = Flow.from_steps( + [OutputMessageStep(message_template="child"), CompleteStep(name="end")], + step_names=["child_message", "end"], + name="child_flow", + ) + parent_flow = Flow.from_steps( + [ + FlowExecutionStep(flow=child_flow), + OutputMessageStep(message_template="parent"), + CompleteStep(name="end"), + ], + step_names=["child_flow_step", "parent_message", "end"], + name="parent_flow", + ) + conversation = parent_flow.start_conversation() + + status, span_recorder = _execute_with_trace( + conversation, + 
state_snapshot_policy=_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, FinishedStatus) + + flow_spans = _spans(span_recorder, AgentSpecFlowExecutionSpan) + assert len(flow_spans) == 2 + + flow_spans_by_name = { + _single_event(span, AgentSpecFlowExecutionStart).flow.name: span for span in flow_spans + } + parent_span = flow_spans_by_name["parent_flow"] + child_span = flow_spans_by_name["child_flow"] + + parent_snapshot_events = _events(parent_span, AgentSpecStateSnapshotEmitted) + child_snapshot_events = _events(child_span, AgentSpecStateSnapshotEmitted) + parent_end_event = _single_event(parent_span, AgentSpecFlowExecutionEnd) + + assert [event.conversation_id for event in parent_snapshot_events] == [ + conversation.conversation_id, + conversation.conversation_id, + conversation.conversation_id, + conversation.conversation_id, + ] + child_runtime_conversation_id = parent_snapshot_events[1].state_snapshot["conversation"]["id"] + assert snapshot_runtime_conversation_ids(parent_snapshot_events) == [ + conversation.id, + child_runtime_conversation_id, + child_runtime_conversation_id, + conversation.id, + ] + assert child_runtime_conversation_id != conversation.id + assert snapshot_message(parent_snapshot_events[2]) == "child" + assert snapshot_message(parent_snapshot_events[-1]) == "parent" + assert not child_snapshot_events + _assert_snapshot_precedes_terminal_event(parent_span, parent_snapshot_events, parent_end_event) diff --git a/wayflowcore/tests/events/test_state_snapshot_event.py b/wayflowcore/tests/events/test_state_snapshot_event_tracing.py similarity index 83% rename from wayflowcore/tests/events/test_state_snapshot_event.py rename to wayflowcore/tests/events/test_state_snapshot_event_tracing.py index b59172f77..0b12281ba 100644 --- a/wayflowcore/tests/events/test_state_snapshot_event.py +++ b/wayflowcore/tests/events/test_state_snapshot_event_tracing.py @@ -9,18 +9,13 @@ from wayflowcore.events.event import 
_PII_TEXT_MASK, StateSnapshotEvent -def test_state_snapshot_event_requires_conversation_id() -> None: - with pytest.raises(ValueError, match="conversation_id"): - StateSnapshotEvent() - - @pytest.mark.parametrize("mask_sensitive_information", [True, False]) def test_state_snapshot_event_serialization( mask_sensitive_information: bool, ) -> None: event = StateSnapshotEvent( conversation_id="conversation-123", - state_snapshot={"conversation": {"messages": []}}, + state_snapshot={"conversation": {"id": "conversation-runtime-123", "messages": []}}, extra_state={"ui": {"active_tab": "plan"}}, variable_state={"count": 2}, name="snapshot", @@ -41,6 +36,8 @@ def test_state_snapshot_event_serialization( assert serialized_event["extra_state"] == _PII_TEXT_MASK assert serialized_event["variable_state"] == _PII_TEXT_MASK else: - assert serialized_event["state_snapshot"] == {"conversation": {"messages": []}} + assert serialized_event["state_snapshot"] == { + "conversation": {"id": "conversation-runtime-123", "messages": []} + } assert serialized_event["extra_state"] == {"ui": {"active_tab": "plan"}} assert serialized_event["variable_state"] == {"count": 2} diff --git a/wayflowcore/tests/events/test_state_snapshot_event_validation.py b/wayflowcore/tests/events/test_state_snapshot_event_validation.py new file mode 100644 index 000000000..0fb30ed68 --- /dev/null +++ b/wayflowcore/tests/events/test_state_snapshot_event_validation.py @@ -0,0 +1,36 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +import pytest + +from wayflowcore.events.event import StateSnapshotEvent + + +def test_state_snapshot_event_requires_conversation_id() -> None: + with pytest.raises(ValueError, match="conversation_id"): + StateSnapshotEvent() + + +def test_state_snapshot_event_allows_missing_state_snapshot() -> None: + event = StateSnapshotEvent(conversation_id="conversation-123") + + assert event.state_snapshot is None + + +def test_state_snapshot_event_requires_state_snapshot_to_be_a_dictionary() -> None: + with pytest.raises(ValueError, match="state_snapshot must be a dictionary"): + StateSnapshotEvent( + conversation_id="conversation-123", + state_snapshot="not-a-dictionary", + ) + + +def test_state_snapshot_event_requires_runtime_conversation_id() -> None: + with pytest.raises(ValueError, match=r"state_snapshot\['conversation'\]\['id'\]"): + StateSnapshotEvent( + conversation_id="conversation-123", + state_snapshot={"conversation": {"messages": []}}, + ) diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime.py b/wayflowcore/tests/events/test_state_snapshot_runtime.py deleted file mode 100644 index 0d61edf1a..000000000 --- a/wayflowcore/tests/events/test_state_snapshot_runtime.py +++ /dev/null @@ -1,867 +0,0 @@ -# Copyright © 2026 Oracle and/or its affiliates. -# -# This software is under the Apache License 2.0 -# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License -# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
- -import threading -from contextlib import nullcontext -from typing import Sequence - -import pytest - -from wayflowcore.agent import Agent -from wayflowcore.conversation import Conversation -from wayflowcore.events.event import Event, StateSnapshotEvent -from wayflowcore.events.eventlistener import EventListener, register_event_listeners -from wayflowcore.executors._events.event import EventType -from wayflowcore.executors._executionstate import ConversationExecutionState -from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus -from wayflowcore.executors.interrupts.executioninterrupt import ( - ExecutionInterrupt, - InterruptedExecutionStatus, - _NullExecutionInterrupt, -) -from wayflowcore.executors.statesnapshotpolicy import ( - StateSnapshotInterval, - StateSnapshotPolicy, -) -from wayflowcore.flow import Flow -from wayflowcore.managerworkers import ManagerWorkers -from wayflowcore.messagelist import Message, MessageType -from wayflowcore.property import AnyProperty -from wayflowcore.serialization.serializer import SerializableNeedToBeImplementedMixin -from wayflowcore.steps import ( - AgentExecutionStep, - CompleteStep, - FlowExecutionStep, - OutputMessageStep, - ToolExecutionStep, - VariableWriteStep, -) -from wayflowcore.swarm import Swarm -from wayflowcore.tools import ServerTool, ToolRequest, tool -from wayflowcore.variable import Variable - -from ..conftest import disable_streaming -from ..test_interrupts import OnEventExecutionInterrupt -from ..testhelpers.dummy import DummyModel - -# Runtime snapshot tests stay focused on emission semantics. Event payload -# mapping and serialization details live in dedicated tracing/serialization -# suites. 
- - -class SnapshotCollector(EventListener): - def __init__(self) -> None: - self.state_snapshot_events: list[StateSnapshotEvent] = [] - - def __call__(self, event: Event) -> None: - if isinstance(event, StateSnapshotEvent): - self.state_snapshot_events.append(event) - - -class MutatingExecutionEndInterrupt(SerializableNeedToBeImplementedMixin, _NullExecutionInterrupt): - def __init__(self) -> None: - self.lock = threading.Lock() - self.count = 0 - super().__init__() - - def _on_execution_end( - self, - state: ConversationExecutionState, - conversation: Conversation, - ) -> InterruptedExecutionStatus | None: - conversation.inputs["preview_count"] = conversation.inputs.get("preview_count", 0) + 1 - self.count += 1 - return None - - -class WorkerExecutionEndInterrupt(SerializableNeedToBeImplementedMixin, _NullExecutionInterrupt): - def __init__(self) -> None: - self.triggered = False - super().__init__() - - def _on_execution_end( - self, - state: ConversationExecutionState, - conversation: Conversation, - ) -> InterruptedExecutionStatus | None: - if self.triggered: - return None - if getattr(conversation.component, "name", None) != "worker": - return None - - self.triggered = True - return InterruptedExecutionStatus( - interrupter=self, - reason="worker execution end", - _conversation_id=conversation.id, - ) - - -class _UnserializableVariableValue: - pass - - -def _create_output_flow_conversation(message: str = "Hello") -> Conversation: - flow = Flow.from_steps( - [ - OutputMessageStep(message_template=message), - CompleteStep(name="end"), - ] - ) - return flow.start_conversation() - - -def _create_agent_conversation(message: str = "Hello from agent") -> Conversation: - llm = DummyModel() - llm.set_next_output(message) - conversation = Agent(llm=llm).start_conversation() - conversation.append_user_message("Hi") - return conversation - - -def _create_tool_calling_agent_conversation() -> Conversation: - @tool - def do_nothing_tool() -> str: - """Do nothing tool.""" - 
return "Tool called successfully" - - llm = DummyModel() - llm.set_next_output( - { - "Please use the do_nothing_tool": Message( - message_type=MessageType.TOOL_REQUEST, - content="I am calling the do nothing tool", - tool_requests=[ToolRequest("do_nothing_tool", {}, "tc1")], - ) - } - ) - conversation = Agent(llm=llm, tools=[do_nothing_tool], max_iterations=10).start_conversation() - conversation.append_user_message("Please use the do_nothing_tool") - return conversation - - -def _create_send_message_request(recipient_name: str, message: str) -> Message: - return Message( - content="", - message_type=MessageType.TOOL_REQUEST, - tool_requests=[ - ToolRequest( - name="send_message", - args={"recipient": recipient_name, "message": message}, - ) - ], - ) - - -def _create_nested_agent_step_flow_conversation() -> Conversation: - llm = DummyModel() - llm.set_next_output("agent answer") - child_agent = Agent(llm=llm) - conversation = Flow.from_steps( - [AgentExecutionStep(agent=child_agent), CompleteStep(name="end")] - ).start_conversation() - conversation.append_user_message("dummy") - return conversation - - -def _create_nested_managerworkers_flow_conversation() -> Conversation: - llm = DummyModel() - worker = Agent(llm=llm, name="worker", description="worker") - group = ManagerWorkers(group_manager=llm, workers=[worker]) - llm.set_next_output( - [ - _create_send_message_request("worker", "Do it"), - "worker answer", - "manager final answer", - ] - ) - - conversation = Flow.from_steps( - [AgentExecutionStep(agent=group), CompleteStep(name="end")] - ).start_conversation() - conversation.append_user_message("dummy") - return conversation - - -def _create_managerworkers_conversation() -> Conversation: - llm = DummyModel() - worker = Agent(llm=llm, name="worker", description="worker") - group = ManagerWorkers(group_manager=llm, workers=[worker]) - llm.set_next_output( - [ - _create_send_message_request("worker", "Do it"), - "worker answer", - "manager final answer", - ] - ) 
- - conversation = group.start_conversation() - conversation.append_user_message("dummy") - return conversation - - -def _create_nested_swarm_flow_conversation() -> Conversation: - llm = DummyModel() - first_agent = Agent(llm=llm, name="agent1", description="agent1") - second_agent = Agent(llm=llm, name="agent2", description="agent2") - swarm = Swarm( - first_agent=first_agent, - relationships=[(first_agent, second_agent), (second_agent, first_agent)], - ) - llm.set_next_output( - [ - _create_send_message_request("agent2", "Do it"), - "agent2 answer", - "agent1 final answer", - ] - ) - - conversation = Flow.from_steps( - [AgentExecutionStep(agent=swarm), CompleteStep(name="end")] - ).start_conversation() - conversation.append_user_message("dummy") - return conversation - - -def _snapshot_status_types(snapshot_events: Sequence[StateSnapshotEvent]) -> list[str | None]: - return [ - ( - snapshot_event.state_snapshot["execution"]["status"]["type"] - if snapshot_event.state_snapshot["execution"]["status"] is not None - else None - ) - for snapshot_event in snapshot_events - ] - - -def _execute_with_state_snapshots( - conversation: Conversation, - *, - state_snapshot_policy: StateSnapshotPolicy, - execution_interrupts: Sequence[ExecutionInterrupt] | None = None, - use_disable_streaming: bool = False, -) -> tuple[object, list[StateSnapshotEvent]]: - collector = SnapshotCollector() - streaming_context = disable_streaming() if use_disable_streaming else nullcontext() - - with streaming_context: - with register_event_listeners([collector]): - status = conversation.execute( - execution_interrupts=execution_interrupts, - state_snapshot_policy=state_snapshot_policy, - ) - - return status, collector.state_snapshot_events - - -@pytest.mark.parametrize( - ( - "conversation_factory", - "expected_status_class", - "expected_status_type", - "expected_message", - ), - [ - pytest.param( - _create_output_flow_conversation, - FinishedStatus, - "FinishedStatus", - "Hello", - id="flow", - 
), - pytest.param( - _create_agent_conversation, - UserMessageRequestStatus, - "UserMessageRequestStatus", - "Hello from agent", - id="agent", - ), - ], -) -def test_conversation_turn_policy_records_opening_and_closing_checkpoints( - conversation_factory, - expected_status_class, - expected_status_type: str, - expected_message: str, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, expected_status_class) - assert _snapshot_status_types(state_snapshot_events) == [None, expected_status_type] - assert ( - state_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] - == expected_message - ) - assert state_snapshot_events[-1].state_snapshot["execution"]["status_handled"] is False - - -@pytest.mark.parametrize( - ("conversation_factory", "expected_status_class"), - [ - pytest.param(_create_output_flow_conversation, FinishedStatus, id="flow"), - pytest.param(_create_agent_conversation, UserMessageRequestStatus, id="agent"), - ], -) -def test_off_policy_disables_state_snapshot_emission( - conversation_factory, - expected_status_class, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.OFF - ), - ) - - assert isinstance(status, expected_status_class) - assert state_snapshot_events == [] - - -@pytest.mark.parametrize( - ("conversation_factory", "expected_message"), - [ - pytest.param(_create_output_flow_conversation, "Hello", id="flow"), - pytest.param(_create_agent_conversation, "Hello from agent", id="agent"), - ], -) -def test_conversation_turn_policy_records_interrupted_turn_end_checkpoints( - conversation_factory, - expected_message: str, -) 
-> None: - conversation = conversation_factory() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - execution_interrupts=[OnEventExecutionInterrupt(EventType.EXECUTION_END)], - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, InterruptedExecutionStatus) - assert _snapshot_status_types(state_snapshot_events) == [None, "InterruptedExecutionStatus"] - assert ( - state_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] - == expected_message - ) - assert state_snapshot_events[-1].state_snapshot["execution"]["status_handled"] is False - - -def test_conversation_turn_policy_keeps_only_the_opening_checkpoint_when_turn_raises() -> None: - def explode() -> str: - raise RuntimeError("boom") - - conversation = Flow.from_steps( - [ - ToolExecutionStep( - tool=ServerTool( - name="explode", - description="Raise an error", - func=explode, - input_descriptors=[], - ) - ), - CompleteStep(name="end"), - ] - ).start_conversation() - collector = SnapshotCollector() - - with register_event_listeners([collector]): - with pytest.raises(RuntimeError, match="boom"): - conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ) - ) - - assert len(collector.state_snapshot_events) == 1 - assert collector.state_snapshot_events[0].state_snapshot["execution"]["status"] is None - - -def test_conversation_turn_policy_reflects_real_interrupt_side_effects_once() -> None: - conversation = _create_output_flow_conversation() - interrupt = MutatingExecutionEndInterrupt() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - execution_interrupts=[interrupt], - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert 
interrupt.count == 1 - assert conversation.inputs["preview_count"] == 1 - assert _snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] - assert state_snapshot_events[-1].state_snapshot["conversation"]["inputs"]["preview_count"] == 1 - - -def test_parent_multi_agent_does_not_emit_turn_end_snapshot_when_child_turn_is_interrupted() -> ( - None -): - conversation = _create_managerworkers_conversation() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - execution_interrupts=[WorkerExecutionEndInterrupt()], - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - use_disable_streaming=True, - ) - - assert isinstance(status, InterruptedExecutionStatus) - - snapshot_events_by_conversation_id: dict[str, list[StateSnapshotEvent]] = {} - for snapshot_event in state_snapshot_events: - snapshot_events_by_conversation_id.setdefault( - snapshot_event.conversation_id, - [], - ).append(snapshot_event) - - parent_multi_agent_snapshot_events = next( - snapshot_events - for snapshot_events in snapshot_events_by_conversation_id.values() - if snapshot_events[0].state_snapshot["conversation"]["component_type"] == "ManagerWorkers" - ) - - assert "InterruptedExecutionStatus" not in _snapshot_status_types( - parent_multi_agent_snapshot_events - ) - - -@pytest.mark.parametrize( - ( - "conversation_factory", - "expected_status_class", - "expected_status_types", - "expected_snapshot_count", - "expected_curr_iters", - ), - [ - pytest.param( - _create_output_flow_conversation, - FinishedStatus, - [None, None, None, None, None, None, "FinishedStatus"], - 7, - None, - id="flow", - ), - pytest.param( - _create_agent_conversation, - UserMessageRequestStatus, - [None, None, "UserMessageRequestStatus"], - 3, - [0, 1], - id="agent", - ), - ], -) -def test_node_turn_policy_tracks_flow_steps_and_agent_iterations( - conversation_factory, - expected_status_class, - expected_status_types: 
list[str | None], - expected_snapshot_count: int, - expected_curr_iters: list[int] | None, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ), - ) - - # NODE_TURNS means step start/end checkpoints for flows and iteration - # start/end checkpoints for agents, plus the final turn-end checkpoint. - # Flow.from_steps(...) also inserts an internal StartStep, so the flow case - # includes start/end checkpoints for that step too. - assert isinstance(status, expected_status_class) - assert len(state_snapshot_events) == expected_snapshot_count - assert _snapshot_status_types(state_snapshot_events) == expected_status_types - if expected_curr_iters is not None: - assert [ - state_snapshot_events[0].state_snapshot["execution"]["curr_iter"], - state_snapshot_events[1].state_snapshot["execution"]["curr_iter"], - ] == expected_curr_iters - - -@pytest.mark.parametrize( - ("conversation_factory", "interrupt_event"), - [ - pytest.param(_create_output_flow_conversation, EventType.STEP_EXECUTION_START, id="flow"), - pytest.param(_create_agent_conversation, EventType.GENERATION_START, id="agent"), - ], -) -def test_node_turn_policy_keeps_partial_progress_when_interrupted_mid_turn( - conversation_factory, - interrupt_event: EventType, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - execution_interrupts=[OnEventExecutionInterrupt(interrupt_event)], - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ), - ) - - assert isinstance(status, InterruptedExecutionStatus) - assert _snapshot_status_types(state_snapshot_events) == [None] - - -def test_internal_snapshots_do_not_reuse_the_previous_turn_status() -> None: - llm = DummyModel() - llm.set_next_output(["Hello 
from agent", "Hello again"]) - conversation = Agent(llm=llm).start_conversation() - conversation.append_user_message("Hi") - collector = SnapshotCollector() - policy = StateSnapshotPolicy(state_snapshot_interval=StateSnapshotInterval.NODE_TURNS) - - with register_event_listeners([collector]): - first_status = conversation.execute(state_snapshot_policy=policy) - assert isinstance(first_status, UserMessageRequestStatus) - - first_status.submit_user_response("Continue") - second_status = conversation.execute(state_snapshot_policy=policy) - - assert isinstance(second_status, UserMessageRequestStatus) - assert len(collector.state_snapshot_events) == 6 - - second_turn_internal_snapshots = collector.state_snapshot_events[3:5] - assert _snapshot_status_types(second_turn_internal_snapshots) == [None, None] - assert all( - snapshot_event.state_snapshot["execution"]["status_handled"] is False - for snapshot_event in second_turn_internal_snapshots - ) - - -def test_state_snapshot_policy_is_inherited_by_nested_sub_conversations() -> None: - child_flow = Flow.from_steps( - [ - OutputMessageStep(message_template="child"), - CompleteStep(name="end"), - ] - ) - parent_flow = Flow.from_steps( - [ - FlowExecutionStep(flow=child_flow), - CompleteStep(name="end"), - ] - ) - conversation = parent_flow.start_conversation() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 4 - assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { - conversation.conversation_id - } - - -def test_state_snapshot_policy_is_inherited_by_nested_agent_steps() -> None: - conversation = _create_nested_agent_step_flow_conversation() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - 
state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - nested_conversation_id = state_snapshot_events[1].conversation_id - - # A parent flow keeps its own opening/closing checkpoints, while the nested - # agent contributes its own opening/closing pair under the child - # conversation id. - assert isinstance(status, UserMessageRequestStatus) - assert _snapshot_status_types(state_snapshot_events) == [ - None, - None, - "UserMessageRequestStatus", - "UserMessageRequestStatus", - ] - assert [snapshot_event.conversation_id for snapshot_event in state_snapshot_events] == [ - conversation.conversation_id, - nested_conversation_id, - nested_conversation_id, - conversation.conversation_id, - ] - assert nested_conversation_id != conversation.conversation_id - assert ( - state_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] - == "agent answer" - ) - - -@pytest.mark.parametrize( - ( - "conversation_factory", - "expected_multi_agent_component_type", - "expected_child_message", - "expected_parent_message", - ), - [ - pytest.param( - _create_nested_managerworkers_flow_conversation, - "ManagerWorkers", - "worker answer", - "manager final answer", - id="managerworkers", - ), - pytest.param( - _create_nested_swarm_flow_conversation, - "Swarm", - "agent2 answer", - "agent1 final answer", - id="swarm", - ), - ], -) -def test_nested_multi_agent_components_emit_snapshots_for_the_active_conversation( - conversation_factory, - expected_multi_agent_component_type: str, - expected_child_message: str, - expected_parent_message: str, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, UserMessageRequestStatus) - assert len(state_snapshot_events) == 14 - - 
snapshot_events_by_conversation_id: dict[str, list[StateSnapshotEvent]] = {} - for snapshot_event in state_snapshot_events: - snapshot_events_by_conversation_id.setdefault( - snapshot_event.conversation_id, - [], - ).append(snapshot_event) - - # A nested multi-agent turn has four independent snapshot streams: - # the outer flow, the parent multi-agent conversation, the manager/main - # thread agent conversation (which runs twice), and the delegated child. - assert len(snapshot_events_by_conversation_id) == 4 - - flow_snapshot_events = snapshot_events_by_conversation_id[conversation.conversation_id] - parent_multi_agent_snapshot_events = next( - snapshot_events - for snapshot_events in snapshot_events_by_conversation_id.values() - if snapshot_events[0].state_snapshot["conversation"]["component_type"] - == expected_multi_agent_component_type - ) - agent_snapshot_event_groups = [ - snapshot_events - for conversation_id, snapshot_events in snapshot_events_by_conversation_id.items() - if conversation_id - not in { - conversation.conversation_id, - parent_multi_agent_snapshot_events[0].conversation_id, - } - ] - manager_thread_snapshot_events = next( - snapshot_events - for snapshot_events in agent_snapshot_event_groups - if len(snapshot_events) == 4 - ) - delegated_agent_snapshot_events = next( - snapshot_events - for snapshot_events in agent_snapshot_event_groups - if len(snapshot_events) == 2 - ) - - assert _snapshot_status_types(flow_snapshot_events) == [None, "UserMessageRequestStatus"] - assert ( - flow_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] - == expected_parent_message - ) - - # The parent multi-agent conversation records checkpoints each time control - # enters or returns from a child turn, which is what lets a UI reconstruct - # the parent-level progress independently from the child conversations. 
- assert _snapshot_status_types(parent_multi_agent_snapshot_events) == [ - None, - "ToolRequestStatus", - None, - "UserMessageRequestStatus", - None, - "UserMessageRequestStatus", - ] - assert ( - parent_multi_agent_snapshot_events[4].state_snapshot["conversation"]["messages"][-1][ - "content" - ] - == expected_child_message - ) - assert ( - parent_multi_agent_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1][ - "content" - ] - == expected_parent_message - ) - - # The manager/main thread agent conversation spans two execution turns: - # one that delegates and one that resumes after the child reply. - assert _snapshot_status_types(manager_thread_snapshot_events) == [ - None, - "ToolRequestStatus", - None, - "UserMessageRequestStatus", - ] - assert ( - manager_thread_snapshot_events[2].state_snapshot["conversation"]["messages"][-1]["content"] - == expected_child_message - ) - assert ( - manager_thread_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] - == expected_parent_message - ) - - assert _snapshot_status_types(delegated_agent_snapshot_events) == [ - None, - "UserMessageRequestStatus", - ] - assert ( - delegated_agent_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1][ - "content" - ] - == expected_child_message - ) - - -def test_state_snapshot_emission_survives_broken_extra_state_builder() -> None: - def broken_builder(_conversation: Conversation) -> dict[str, object]: - raise RuntimeError("boom") - - conversation = _create_output_flow_conversation() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - extra_state_builder=broken_builder, - ), - ) - - assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 2 - assert all(snapshot_event.extra_state is None for snapshot_event in state_snapshot_events) - - -def 
test_state_snapshot_emission_survives_unserializable_variable_state() -> None: - custom_variable = Variable(name="custom", type=AnyProperty()) - conversation = Flow.from_steps( - [ - VariableWriteStep( - variable=custom_variable, - input_mapping={VariableWriteStep.VALUE: custom_variable.name}, - ), - OutputMessageStep(message_template="done"), - CompleteStep(name="end"), - ], - variables=[custom_variable], - ).start_conversation(inputs={custom_variable.name: _UnserializableVariableValue()}) - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - include_variable_state=True, - ), - ) - - assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 2 - assert state_snapshot_events[0].variable_state == {"custom": None} - assert state_snapshot_events[-1].variable_state is None - assert ( - state_snapshot_events[-1].state_snapshot["conversation"]["messages"][-1]["content"] - == "done" - ) - - -@pytest.mark.parametrize( - ( - "conversation_factory", - "execution_interrupts", - "use_disable_streaming", - "expected_status_class", - "expected_status_types", - ), - [ - pytest.param( - lambda: Flow.from_steps( - [ - ToolExecutionStep( - tool=ServerTool( - name="say_hi", - description="Say hi", - func=lambda: "hi", - input_descriptors=[], - ) - ), - CompleteStep(name="end"), - ] - ).start_conversation(), - None, - False, - FinishedStatus, - [None, None, "FinishedStatus"], - id="flow-success", - ), - pytest.param( - lambda: _create_tool_calling_agent_conversation(), - [OnEventExecutionInterrupt(EventType.TOOL_CALL_END)], - True, - InterruptedExecutionStatus, - [None, None], - id="agent-tool-end-interrupt", - ), - ], -) -def test_tool_turn_policy_records_real_tool_boundaries( - conversation_factory, - execution_interrupts: Sequence[ExecutionInterrupt] | None, - use_disable_streaming: bool, - expected_status_class, - 
expected_status_types: list[str | None], -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = _execute_with_state_snapshots( - conversation, - execution_interrupts=execution_interrupts, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.TOOL_TURNS - ), - use_disable_streaming=use_disable_streaming, - ) - - assert isinstance(status, expected_status_class) - assert _snapshot_status_types(state_snapshot_events) == expected_status_types diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py new file mode 100644 index 000000000..b5822238f --- /dev/null +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py @@ -0,0 +1,234 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +import pytest + +from wayflowcore.events.eventlistener import register_event_listeners +from wayflowcore.executors._events.event import EventType +from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus +from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval + +from ..conftest import disable_streaming +from ..test_interrupts import OnEventExecutionInterrupt +from ..testhelpers.statesnapshots import ( + MutatingExecutionEndInterrupt, + SnapshotCollector, + WorkerExecutionEndInterrupt, + assert_terminal_snapshot, + build_policy, + create_agent_conversation, + create_managerworkers_conversation, + create_output_flow_conversation, + create_tool_flow_conversation, + execute_with_state_snapshots, + execute_with_state_snapshots_async, + find_snapshot_events_by_component_type, + snapshot_status_types, +) + + +@pytest.mark.parametrize( + ( + "conversation_factory", + "expected_status_class", + "expected_status_type", + "expected_message", + ), + [ + pytest.param( + create_output_flow_conversation, + FinishedStatus, + "FinishedStatus", + "Hello", + id="flow", + ), + pytest.param( + create_agent_conversation, + UserMessageRequestStatus, + "UserMessageRequestStatus", + "Hello from agent", + id="agent", + ), + ], +) +def test_conversation_turn_policy_records_opening_and_closing_checkpoints( + conversation_factory, + expected_status_class, + expected_status_type: str, + expected_message: str, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, expected_status_class) + assert snapshot_status_types(state_snapshot_events) == [None, expected_status_type] + assert_terminal_snapshot( + state_snapshot_events, + 
expected_status_type=expected_status_type, + expected_message=expected_message, + ) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ( + "conversation_factory", + "expected_status_class", + "expected_status_type", + "expected_message", + ), + [ + pytest.param( + create_output_flow_conversation, + FinishedStatus, + "FinishedStatus", + "Hello", + id="flow", + ), + pytest.param( + create_agent_conversation, + UserMessageRequestStatus, + "UserMessageRequestStatus", + "Hello from agent", + id="agent", + ), + ], +) +async def test_conversation_turn_policy_records_opening_and_closing_checkpoints_async( + conversation_factory, + expected_status_class, + expected_status_type: str, + expected_message: str, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = await execute_with_state_snapshots_async( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, expected_status_class) + assert snapshot_status_types(state_snapshot_events) == [None, expected_status_type] + assert_terminal_snapshot( + state_snapshot_events, + expected_status_type=expected_status_type, + expected_message=expected_message, + ) + + +@pytest.mark.parametrize( + ("conversation_factory", "expected_status_class"), + [ + pytest.param(create_output_flow_conversation, FinishedStatus, id="flow"), + pytest.param(create_agent_conversation, UserMessageRequestStatus, id="agent"), + ], +) +def test_off_policy_disables_state_snapshot_emission( + conversation_factory, + expected_status_class, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.OFF), + ) + + assert isinstance(status, expected_status_class) + assert state_snapshot_events == [] + + +@pytest.mark.parametrize( + ("conversation_factory", "expected_message"), + [ + pytest.param(create_output_flow_conversation, 
"Hello", id="flow"), + pytest.param(create_agent_conversation, "Hello from agent", id="agent"), + ], +) +def test_conversation_turn_policy_records_interrupted_turn_end_checkpoints( + conversation_factory, + expected_message: str, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + execution_interrupts=[OnEventExecutionInterrupt(EventType.EXECUTION_END)], + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, InterruptedExecutionStatus) + assert snapshot_status_types(state_snapshot_events) == [None, "InterruptedExecutionStatus"] + assert_terminal_snapshot( + state_snapshot_events, + expected_status_type="InterruptedExecutionStatus", + expected_message=expected_message, + ) + + +def test_conversation_turn_policy_keeps_only_the_opening_checkpoint_when_turn_raises() -> None: + def explode() -> str: + raise RuntimeError("boom") + + conversation = create_tool_flow_conversation( + explode, + name="explode", + description="Raise an error", + ) + collector = SnapshotCollector() + + with register_event_listeners([collector]): + with pytest.raises(RuntimeError, match="boom"): + conversation.execute( + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS) + ) + + assert len(collector.state_snapshot_events) == 1 + assert collector.state_snapshot_events[0].state_snapshot["execution"]["status"] is None + + +def test_conversation_turn_policy_reflects_real_interrupt_side_effects_once() -> None: + conversation = create_output_flow_conversation() + interrupt = MutatingExecutionEndInterrupt() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + execution_interrupts=[interrupt], + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, FinishedStatus) + assert interrupt.count == 1 + assert conversation.inputs["preview_count"] == 1 + 
assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + assert state_snapshot_events[-1].state_snapshot["conversation"]["inputs"]["preview_count"] == 1 + + +def test_parent_multi_agent_does_not_emit_turn_end_snapshot_when_child_turn_is_interrupted() -> ( + None +): + conversation = create_managerworkers_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + execution_interrupts=[WorkerExecutionEndInterrupt()], + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + execution_context=disable_streaming(), + ) + + assert isinstance(status, InterruptedExecutionStatus) + parent_multi_agent_snapshot_events = find_snapshot_events_by_component_type( + state_snapshot_events, + "ManagerWorkers", + ) + assert "InterruptedExecutionStatus" not in snapshot_status_types( + parent_multi_agent_snapshot_events + ) diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py new file mode 100644 index 000000000..663154d54 --- /dev/null +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py @@ -0,0 +1,323 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +import pytest + +from wayflowcore.agent import Agent +from wayflowcore.events.eventlistener import register_event_listeners +from wayflowcore.executors._events.event import EventType +from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus +from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval + +from ..conftest import disable_streaming +from ..test_interrupts import OnEventExecutionInterrupt +from ..testhelpers.dummy import DummyModel +from ..testhelpers.statesnapshots import ( + SnapshotCollector, + build_policy, + create_agent_conversation, + create_output_flow_conversation, + create_tool_calling_agent_conversation, + create_tool_flow_conversation, + execute_with_state_snapshots, + execute_with_state_snapshots_async, + snapshot_status_types, + snapshot_step_histories, +) + + +@pytest.mark.parametrize( + ( + "conversation_factory", + "expected_status_class", + "expected_status_types", + "expected_snapshot_count", + "expected_curr_iters", + ), + [ + pytest.param( + create_output_flow_conversation, + FinishedStatus, + [None, None, None, None, None, None], + 6, + None, + id="flow", + ), + pytest.param( + create_agent_conversation, + UserMessageRequestStatus, + [None, None], + 2, + [0, 1], + id="agent", + ), + ], +) +def test_node_turn_policy_tracks_flow_steps_and_agent_iterations( + conversation_factory, + expected_status_class, + expected_status_types: list[str | None], + expected_snapshot_count: int, + expected_curr_iters: list[int] | None, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), + ) + + assert isinstance(status, expected_status_class) + assert len(state_snapshot_events) == expected_snapshot_count + assert snapshot_status_types(state_snapshot_events) 
== expected_status_types + if expected_curr_iters is not None: + assert [ + state_snapshot_events[0].state_snapshot["execution"]["curr_iter"], + state_snapshot_events[1].state_snapshot["execution"]["curr_iter"], + ] == expected_curr_iters + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ( + "conversation_factory", + "expected_status_class", + "expected_status_types", + "expected_snapshot_count", + "expected_curr_iters", + ), + [ + pytest.param( + create_output_flow_conversation, + FinishedStatus, + [None, None, None, None, None, None], + 6, + None, + id="flow", + ), + pytest.param( + create_agent_conversation, + UserMessageRequestStatus, + [None, None], + 2, + [0, 1], + id="agent", + ), + ], +) +async def test_node_turn_policy_tracks_flow_steps_and_agent_iterations_async( + conversation_factory, + expected_status_class, + expected_status_types: list[str | None], + expected_snapshot_count: int, + expected_curr_iters: list[int] | None, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = await execute_with_state_snapshots_async( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), + ) + + assert isinstance(status, expected_status_class) + assert len(state_snapshot_events) == expected_snapshot_count + assert snapshot_status_types(state_snapshot_events) == expected_status_types + if expected_curr_iters is not None: + assert [ + state_snapshot_events[0].state_snapshot["execution"]["curr_iter"], + state_snapshot_events[1].state_snapshot["execution"]["curr_iter"], + ] == expected_curr_iters + + +@pytest.mark.parametrize( + ("conversation_factory", "interrupt_event"), + [ + pytest.param(create_output_flow_conversation, EventType.STEP_EXECUTION_START, id="flow"), + pytest.param(create_agent_conversation, EventType.GENERATION_START, id="agent"), + ], +) +def test_node_turn_policy_keeps_partial_progress_when_interrupted_mid_turn( + conversation_factory, + interrupt_event: EventType, +) -> None: + 
conversation = conversation_factory() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + execution_interrupts=[OnEventExecutionInterrupt(interrupt_event)], + state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), + ) + + assert isinstance(status, InterruptedExecutionStatus) + assert snapshot_status_types(state_snapshot_events) == [None] + + +def test_flow_node_turn_policy_uses_iteration_start_and_end_boundaries() -> None: + conversation = create_output_flow_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), + ) + + assert isinstance(status, FinishedStatus) + assert snapshot_step_histories(state_snapshot_events) == [ + [], + ["__StartStep__"], + ["__StartStep__"], + ["__StartStep__", "step_0"], + ["__StartStep__", "step_0"], + ["__StartStep__", "step_0", "end"], + ] + + +def test_internal_snapshots_do_not_reuse_the_previous_turn_status() -> None: + llm = DummyModel() + llm.set_next_output(["Hello from agent", "Hello again"]) + conversation = Agent(llm=llm).start_conversation() + conversation.append_user_message("Hi") + collector = SnapshotCollector() + policy = build_policy(StateSnapshotInterval.NODE_TURNS) + + with register_event_listeners([collector]): + first_status = conversation.execute(state_snapshot_policy=policy) + assert isinstance(first_status, UserMessageRequestStatus) + + first_status.submit_user_response("Continue") + second_status = conversation.execute(state_snapshot_policy=policy) + + assert isinstance(second_status, UserMessageRequestStatus) + assert len(collector.state_snapshot_events) == 4 + + second_turn_internal_snapshots = collector.state_snapshot_events[2:4] + assert snapshot_status_types(second_turn_internal_snapshots) == [None, None] + assert all( + snapshot_event.state_snapshot["execution"]["status_handled"] is False + for snapshot_event in second_turn_internal_snapshots + ) + 
+ +@pytest.mark.parametrize( + ( + "conversation_factory", + "execution_interrupts", + "execution_context", + "expected_status_class", + "expected_status_types", + ), + [ + pytest.param( + lambda: create_tool_flow_conversation(lambda: "hi"), + None, + None, + FinishedStatus, + [None, None], + id="flow-success", + ), + pytest.param( + create_tool_calling_agent_conversation, + [OnEventExecutionInterrupt(EventType.TOOL_CALL_END)], + disable_streaming(), + InterruptedExecutionStatus, + [None, None], + id="agent-tool-end-interrupt", + ), + ], +) +def test_tool_turn_policy_records_real_tool_boundaries( + conversation_factory, + execution_interrupts, + execution_context, + expected_status_class, + expected_status_types: list[str | None], +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + execution_interrupts=execution_interrupts, + state_snapshot_policy=build_policy(StateSnapshotInterval.TOOL_TURNS), + execution_context=execution_context, + ) + + assert isinstance(status, expected_status_class) + assert snapshot_status_types(state_snapshot_events) == expected_status_types + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ( + "conversation_factory", + "execution_interrupts", + "execution_context", + "expected_status_class", + "expected_status_types", + ), + [ + pytest.param( + lambda: create_tool_flow_conversation(lambda: "hi"), + None, + None, + FinishedStatus, + [None, None], + id="flow-success", + ), + pytest.param( + create_tool_calling_agent_conversation, + [OnEventExecutionInterrupt(EventType.TOOL_CALL_END)], + disable_streaming(), + InterruptedExecutionStatus, + [None, None], + id="agent-tool-end-interrupt", + ), + ], +) +async def test_tool_turn_policy_records_real_tool_boundaries_async( + conversation_factory, + execution_interrupts, + execution_context, + expected_status_class, + expected_status_types: list[str | None], +) -> None: + conversation = conversation_factory() + + 
status, state_snapshot_events = await execute_with_state_snapshots_async( + conversation, + execution_interrupts=execution_interrupts, + state_snapshot_policy=build_policy(StateSnapshotInterval.TOOL_TURNS), + execution_context=execution_context, + ) + + assert isinstance(status, expected_status_class) + assert snapshot_status_types(state_snapshot_events) == expected_status_types + + +def test_all_internal_turn_policy_combines_node_and_tool_boundaries() -> None: + conversation = create_tool_flow_conversation(lambda: "hi") + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.ALL_INTERNAL_TURNS), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 8 + assert snapshot_status_types(state_snapshot_events) == [None] * 8 + + +@pytest.mark.anyio +async def test_all_internal_turn_policy_combines_node_and_tool_boundaries_async() -> None: + conversation = create_tool_flow_conversation(lambda: "hi") + + status, state_snapshot_events = await execute_with_state_snapshots_async( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.ALL_INTERNAL_TURNS), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 8 + assert snapshot_status_types(state_snapshot_events) == [None] * 8 diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py b/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py new file mode 100644 index 000000000..d307971aa --- /dev/null +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py @@ -0,0 +1,258 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +import pytest + +from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.flow import Flow +from wayflowcore.steps import ( + CompleteStep, + FlowExecutionStep, + OutputMessageStep, + ParallelFlowExecutionStep, +) + +from ..testhelpers.statesnapshots import ( + build_policy, + create_nested_agent_step_flow_conversation, + create_nested_managerworkers_flow_conversation, + create_nested_swarm_flow_conversation, + create_parallel_child_flow, + execute_with_state_snapshots, + execute_with_state_snapshots_async, + find_snapshot_events_by_component_type, + group_snapshot_events_by_conversation_id, + snapshot_message, + snapshot_runtime_conversation_ids, + snapshot_status_types, +) + + +def test_state_snapshot_policy_is_inherited_by_nested_sub_conversations() -> None: + child_flow = Flow.from_steps( + [ + OutputMessageStep(message_template="child"), + CompleteStep(name="end"), + ] + ) + parent_flow = Flow.from_steps( + [ + FlowExecutionStep(flow=child_flow), + CompleteStep(name="end"), + ] + ) + conversation = parent_flow.start_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 4 + assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { + conversation.conversation_id + } + child_runtime_conversation_id = state_snapshot_events[1].state_snapshot["conversation"]["id"] + assert snapshot_runtime_conversation_ids(state_snapshot_events) == [ + conversation.id, + child_runtime_conversation_id, + child_runtime_conversation_id, + conversation.id, + ] + assert child_runtime_conversation_id != conversation.id + + +@pytest.mark.anyio +async def 
test_state_snapshot_policy_is_inherited_by_nested_sub_conversations_async() -> None: + child_flow = Flow.from_steps( + [ + OutputMessageStep(message_template="child"), + CompleteStep(name="end"), + ] + ) + parent_flow = Flow.from_steps( + [ + FlowExecutionStep(flow=child_flow), + CompleteStep(name="end"), + ] + ) + conversation = parent_flow.start_conversation() + + status, state_snapshot_events = await execute_with_state_snapshots_async( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 4 + assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { + conversation.conversation_id + } + child_runtime_conversation_id = state_snapshot_events[1].state_snapshot["conversation"]["id"] + assert snapshot_runtime_conversation_ids(state_snapshot_events) == [ + conversation.id, + child_runtime_conversation_id, + child_runtime_conversation_id, + conversation.id, + ] + assert child_runtime_conversation_id != conversation.id + + +def test_state_snapshot_policy_is_inherited_by_parallel_sub_conversations() -> None: + conversation = Flow.from_steps( + [ + ParallelFlowExecutionStep( + flows=[ + create_parallel_child_flow("left_output", "left"), + create_parallel_child_flow("right_output", "right"), + ] + ), + CompleteStep(name="end"), + ] + ).start_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 6 + assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { + conversation.conversation_id + } + assert snapshot_status_types(state_snapshot_events).count(None) == 3 + assert snapshot_status_types(state_snapshot_events).count("FinishedStatus") == 3 + + +def 
test_state_snapshot_policy_is_inherited_by_nested_agent_steps() -> None: + conversation = create_nested_agent_step_flow_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + nested_conversation_id = state_snapshot_events[1].conversation_id + + assert isinstance(status, UserMessageRequestStatus) + assert snapshot_status_types(state_snapshot_events) == [ + None, + None, + "UserMessageRequestStatus", + "UserMessageRequestStatus", + ] + assert [snapshot_event.conversation_id for snapshot_event in state_snapshot_events] == [ + conversation.conversation_id, + nested_conversation_id, + nested_conversation_id, + conversation.conversation_id, + ] + assert nested_conversation_id != conversation.conversation_id + assert snapshot_message(state_snapshot_events[-1]) == "agent answer" + + +@pytest.mark.parametrize( + ( + "conversation_factory", + "expected_multi_agent_component_type", + "expected_child_message", + "expected_parent_message", + ), + [ + pytest.param( + create_nested_managerworkers_flow_conversation, + "ManagerWorkers", + "worker answer", + "manager final answer", + id="managerworkers", + ), + pytest.param( + create_nested_swarm_flow_conversation, + "Swarm", + "agent2 answer", + "agent1 final answer", + id="swarm", + ), + ], +) +def test_nested_multi_agent_components_emit_snapshots_for_the_active_conversation( + conversation_factory, + expected_multi_agent_component_type: str, + expected_child_message: str, + expected_parent_message: str, +) -> None: + conversation = conversation_factory() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, UserMessageRequestStatus) + assert len(state_snapshot_events) == 14 + + snapshot_events_by_conversation_id = group_snapshot_events_by_conversation_id( + 
state_snapshot_events + ) + + assert len(snapshot_events_by_conversation_id) == 4 + + flow_snapshot_events = snapshot_events_by_conversation_id[conversation.conversation_id] + parent_multi_agent_snapshot_events = find_snapshot_events_by_component_type( + state_snapshot_events, + expected_multi_agent_component_type, + ) + agent_snapshot_event_groups = [ + snapshot_events + for conversation_id, snapshot_events in snapshot_events_by_conversation_id.items() + if conversation_id + not in { + conversation.conversation_id, + parent_multi_agent_snapshot_events[0].conversation_id, + } + ] + manager_thread_snapshot_events = next( + snapshot_events + for snapshot_events in agent_snapshot_event_groups + if len(snapshot_events) == 4 + ) + delegated_agent_snapshot_events = next( + snapshot_events + for snapshot_events in agent_snapshot_event_groups + if len(snapshot_events) == 2 + ) + + assert snapshot_status_types(flow_snapshot_events) == [None, "UserMessageRequestStatus"] + assert snapshot_message(flow_snapshot_events[-1]) == expected_parent_message + + assert snapshot_status_types(parent_multi_agent_snapshot_events) == [ + None, + "ToolRequestStatus", + None, + "UserMessageRequestStatus", + None, + "UserMessageRequestStatus", + ] + assert snapshot_message(parent_multi_agent_snapshot_events[4]) == expected_child_message + assert snapshot_message(parent_multi_agent_snapshot_events[-1]) == expected_parent_message + + assert snapshot_status_types(manager_thread_snapshot_events) == [ + None, + "ToolRequestStatus", + None, + "UserMessageRequestStatus", + ] + assert snapshot_message(manager_thread_snapshot_events[2]) == expected_child_message + assert snapshot_message(manager_thread_snapshot_events[-1]) == expected_parent_message + + assert snapshot_status_types(delegated_agent_snapshot_events) == [ + None, + "UserMessageRequestStatus", + ] + assert snapshot_message(delegated_agent_snapshot_events[-1]) == expected_child_message diff --git 
a/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py b/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py new file mode 100644 index 000000000..6d8cfdb44 --- /dev/null +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py @@ -0,0 +1,54 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from wayflowcore.conversation import Conversation +from wayflowcore.executors.executionstatus import FinishedStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval + +from ..testhelpers.statesnapshots import ( + build_policy, + create_output_flow_conversation, + create_unserializable_variable_conversation, + execute_with_state_snapshots, + snapshot_message, +) + + +def test_state_snapshot_emission_survives_broken_extra_state_builder() -> None: + def broken_builder(_conversation: Conversation) -> dict[str, object]: + raise RuntimeError("boom") + + conversation = create_output_flow_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy( + StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=broken_builder, + ), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 2 + assert all(snapshot_event.extra_state is None for snapshot_event in state_snapshot_events) + + +def test_state_snapshot_emission_survives_unserializable_variable_state() -> None: + conversation = create_unserializable_variable_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy( + StateSnapshotInterval.CONVERSATION_TURNS, + include_variable_state=True, + ), + ) + + assert isinstance(status, FinishedStatus) 
+ assert len(state_snapshot_events) == 2 + assert state_snapshot_events[0].variable_state == {"custom": None} + assert state_snapshot_events[-1].variable_state is None + assert snapshot_message(state_snapshot_events[-1]) == "done" diff --git a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py index 6a5656721..c6a943b61 100644 --- a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py +++ b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py @@ -78,6 +78,63 @@ def test_dump_conversation_state_is_json_serializable_and_lightweight() -> None: ) +def test_dump_conversation_state_overrides_execution_fields_without_mutating_conversation() -> None: + custom_variable = Variable( + name="custom", + type=StringProperty(), + description="Custom variable used for snapshot serialization tests", + ) + flow = _build_snapshot_flow(custom_variable) + conversation = flow.start_conversation(inputs={custom_variable.name: "custom-value"}) + conversation.execute() + + previous_status = conversation.status + previous_status_handled = conversation.status_handled + + snapshot = dump_conversation_state( + conversation, + status=None, + status_handled=True, + ) + + assert snapshot["execution"]["status"] is None + assert snapshot["execution"]["status_handled"] is True + assert conversation.status is previous_status + assert conversation.status_handled is previous_status_handled + + +def test_dump_conversation_state_includes_runtime_conversation_id() -> None: + custom_variable = Variable( + name="custom", + type=StringProperty(), + description="Custom variable used for snapshot serialization tests", + ) + flow = _build_snapshot_flow(custom_variable) + conversation = flow.start_conversation(inputs={custom_variable.name: "custom-value"}) + conversation.execute() + + snapshot = dump_conversation_state(conversation) + + assert snapshot["conversation"]["id"] == conversation.id + 
assert snapshot["conversation"]["conversation_id"] == conversation.conversation_id + + +def test_dump_conversation_state_does_not_overload_status_conversation_identity() -> None: + custom_variable = Variable( + name="custom", + type=StringProperty(), + description="Custom variable used for snapshot serialization tests", + ) + flow = _build_snapshot_flow(custom_variable) + conversation = flow.start_conversation(inputs={custom_variable.name: "custom-value"}) + conversation.execute() + + snapshot = dump_conversation_state(conversation) + + assert snapshot["execution"]["status"]["type"] == "FinishedStatus" + assert "conversation_id" not in snapshot["execution"]["status"] + + def test_dump_variable_state_rejects_non_json_serializable_values() -> None: custom_variable = Variable( name="custom", diff --git a/wayflowcore/tests/testhelpers/statesnapshots.py b/wayflowcore/tests/testhelpers/statesnapshots.py new file mode 100644 index 000000000..f613fae20 --- /dev/null +++ b/wayflowcore/tests/testhelpers/statesnapshots.py @@ -0,0 +1,379 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +import threading +from collections import defaultdict +from contextlib import AbstractContextManager, nullcontext +from typing import Any, Callable, Sequence + +from wayflowcore.agent import Agent +from wayflowcore.conversation import Conversation +from wayflowcore.events.event import Event, StateSnapshotEvent +from wayflowcore.events.eventlistener import EventListener, register_event_listeners +from wayflowcore.executors._executionstate import ConversationExecutionState +from wayflowcore.executors.interrupts.executioninterrupt import ( + ExecutionInterrupt, + InterruptedExecutionStatus, + _NullExecutionInterrupt, +) +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy +from wayflowcore.flow import Flow +from wayflowcore.managerworkers import ManagerWorkers +from wayflowcore.messagelist import Message, MessageType +from wayflowcore.property import AnyProperty, StringProperty +from wayflowcore.serialization.serializer import SerializableNeedToBeImplementedMixin +from wayflowcore.steps import ( + AgentExecutionStep, + CompleteStep, + OutputMessageStep, + ToolExecutionStep, + VariableWriteStep, +) +from wayflowcore.swarm import Swarm +from wayflowcore.tools import ServerTool, ToolRequest, tool +from wayflowcore.variable import Variable + +from .dummy import DummyModel + + +class SnapshotCollector(EventListener): + def __init__(self) -> None: + self.state_snapshot_events: list[StateSnapshotEvent] = [] + + def __call__(self, event: Event) -> None: + if isinstance(event, StateSnapshotEvent): + self.state_snapshot_events.append(event) + + +def build_state_snapshot_policy( + interval: StateSnapshotInterval, + **kwargs: Any, +) -> StateSnapshotPolicy: + return StateSnapshotPolicy(state_snapshot_interval=interval, **kwargs) + + +def snapshot_status_type(snapshot_event: Any) -> str | None: + status = snapshot_event.state_snapshot["execution"]["status"] + return status["type"] if status is not None else None + + +def 
snapshot_status_types(snapshot_events: Sequence[Any]) -> list[str | None]: + return [snapshot_status_type(snapshot_event) for snapshot_event in snapshot_events] + + +def snapshot_message(snapshot_event: Any) -> str | None: + messages = snapshot_event.state_snapshot["conversation"]["messages"] + if not messages: + return None + return messages[-1].get("content") + + +def snapshot_runtime_conversation_id(snapshot_event: Any) -> str: + return snapshot_event.state_snapshot["conversation"]["id"] + + +def snapshot_runtime_conversation_ids(snapshot_events: Sequence[Any]) -> list[str]: + return [snapshot_runtime_conversation_id(snapshot_event) for snapshot_event in snapshot_events] + + +def snapshot_step_histories(snapshot_events: Sequence[Any]) -> list[list[str]]: + return [ + snapshot_event.state_snapshot["execution"]["step_history"] + for snapshot_event in snapshot_events + ] + + +def group_snapshot_events_by_conversation_id( + snapshot_events: Sequence[Any], +) -> dict[str, list[Any]]: + grouped_snapshot_events: dict[str, list[Any]] = defaultdict(list) + for snapshot_event in snapshot_events: + grouped_snapshot_events[snapshot_event.conversation_id].append(snapshot_event) + return dict(grouped_snapshot_events) + + +def find_snapshot_events_by_component_type( + snapshot_events: Sequence[Any], + component_type: str, +) -> list[Any]: + return next( + grouped_events + for grouped_events in group_snapshot_events_by_conversation_id(snapshot_events).values() + if grouped_events[0].state_snapshot["conversation"]["component_type"] == component_type + ) + + +def execute_with_state_snapshots( + conversation: Conversation, + *, + state_snapshot_policy: StateSnapshotPolicy, + execution_interrupts: Sequence[ExecutionInterrupt] | None = None, + execution_context: AbstractContextManager[Any] | None = None, +) -> tuple[object, list[StateSnapshotEvent]]: + collector = SnapshotCollector() + + with execution_context or nullcontext(): + with register_event_listeners([collector]): + status 
= conversation.execute( + execution_interrupts=execution_interrupts, + state_snapshot_policy=state_snapshot_policy, + ) + + return status, collector.state_snapshot_events + + +async def execute_with_state_snapshots_async( + conversation: Conversation, + *, + state_snapshot_policy: StateSnapshotPolicy, + execution_interrupts: Sequence[ExecutionInterrupt] | None = None, + execution_context: AbstractContextManager[Any] | None = None, +) -> tuple[object, list[StateSnapshotEvent]]: + collector = SnapshotCollector() + + with execution_context or nullcontext(): + with register_event_listeners([collector]): + status = await conversation.execute_async( + execution_interrupts=execution_interrupts, + state_snapshot_policy=state_snapshot_policy, + ) + + return status, collector.state_snapshot_events + + +class MutatingExecutionEndInterrupt(SerializableNeedToBeImplementedMixin, _NullExecutionInterrupt): + def __init__(self) -> None: + self.lock = threading.Lock() + self.count = 0 + super().__init__() + + def _on_execution_end( + self, + state: ConversationExecutionState, + conversation: Conversation, + ) -> InterruptedExecutionStatus | None: + conversation.inputs["preview_count"] = conversation.inputs.get("preview_count", 0) + 1 + self.count += 1 + return None + + +class WorkerExecutionEndInterrupt(SerializableNeedToBeImplementedMixin, _NullExecutionInterrupt): + def __init__(self) -> None: + self.triggered = False + super().__init__() + + def _on_execution_end( + self, + state: ConversationExecutionState, + conversation: Conversation, + ) -> InterruptedExecutionStatus | None: + if self.triggered: + return None + if getattr(conversation.component, "name", None) != "worker": + return None + + self.triggered = True + return InterruptedExecutionStatus( + interrupter=self, + reason="worker execution end", + _conversation_id=conversation.id, + ) + + +class _UnserializableVariableValue: + pass + + +def create_tool_flow_conversation( + func: Callable[[], object], + *, + name: str = 
"say_hi", + description: str = "Say hi", +) -> Conversation: + return Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name=name, + description=description, + func=func, + input_descriptors=[], + ) + ), + CompleteStep(name="end"), + ] + ).start_conversation() + + +def create_output_flow_conversation(message: str = "Hello") -> Conversation: + return Flow.from_steps( + [ + OutputMessageStep(message_template=message), + CompleteStep(name="end"), + ] + ).start_conversation() + + +def create_agent_conversation(message: str = "Hello from agent") -> Conversation: + llm = DummyModel() + llm.set_next_output(message) + conversation = Agent(llm=llm).start_conversation() + conversation.append_user_message("Hi") + return conversation + + +def create_tool_calling_agent_conversation() -> Conversation: + @tool + def do_nothing_tool() -> str: + """Do nothing tool.""" + return "Tool called successfully" + + llm = DummyModel() + llm.set_next_output( + { + "Please use the do_nothing_tool": Message( + message_type=MessageType.TOOL_REQUEST, + content="I am calling the do nothing tool", + tool_requests=[ToolRequest("do_nothing_tool", {}, "tc1")], + ) + } + ) + conversation = Agent(llm=llm, tools=[do_nothing_tool], max_iterations=10).start_conversation() + conversation.append_user_message("Please use the do_nothing_tool") + return conversation + + +def _create_send_message_request(recipient_name: str, message: str) -> Message: + return Message( + content="", + message_type=MessageType.TOOL_REQUEST, + tool_requests=[ + ToolRequest( + name="send_message", + args={"recipient": recipient_name, "message": message}, + ) + ], + ) + + +def create_nested_agent_step_flow_conversation() -> Conversation: + llm = DummyModel() + llm.set_next_output("agent answer") + child_agent = Agent(llm=llm) + conversation = Flow.from_steps( + [AgentExecutionStep(agent=child_agent), CompleteStep(name="end")] + ).start_conversation() + conversation.append_user_message("dummy") + return conversation + + 
+def create_nested_managerworkers_flow_conversation() -> Conversation: + llm = DummyModel() + worker = Agent(llm=llm, name="worker", description="worker") + group = ManagerWorkers(group_manager=llm, workers=[worker]) + llm.set_next_output( + [ + _create_send_message_request("worker", "Do it"), + "worker answer", + "manager final answer", + ] + ) + + conversation = Flow.from_steps( + [AgentExecutionStep(agent=group), CompleteStep(name="end")] + ).start_conversation() + conversation.append_user_message("dummy") + return conversation + + +def create_managerworkers_conversation() -> Conversation: + llm = DummyModel() + worker = Agent(llm=llm, name="worker", description="worker") + group = ManagerWorkers(group_manager=llm, workers=[worker]) + llm.set_next_output( + [ + _create_send_message_request("worker", "Do it"), + "worker answer", + "manager final answer", + ] + ) + + conversation = group.start_conversation() + conversation.append_user_message("dummy") + return conversation + + +def create_nested_swarm_flow_conversation() -> Conversation: + llm = DummyModel() + first_agent = Agent(llm=llm, name="agent1", description="agent1") + second_agent = Agent(llm=llm, name="agent2", description="agent2") + swarm = Swarm( + first_agent=first_agent, + relationships=[(first_agent, second_agent), (second_agent, first_agent)], + ) + llm.set_next_output( + [ + _create_send_message_request("agent2", "Do it"), + "agent2 answer", + "agent1 final answer", + ] + ) + + conversation = Flow.from_steps( + [AgentExecutionStep(agent=swarm), CompleteStep(name="end")] + ).start_conversation() + conversation.append_user_message("dummy") + return conversation + + +def create_parallel_child_flow(output_name: str, output_value: str) -> Flow: + return Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name=f"tool_{output_name}", + description=f"Return {output_name}", + input_descriptors=[], + output_descriptors=[StringProperty(name=output_name)], + func=lambda: output_value, + ) + ), + 
CompleteStep(name="end"), + ] + ) + + +def create_unserializable_variable_conversation() -> Conversation: + custom_variable = Variable(name="custom", type=AnyProperty()) + return Flow.from_steps( + [ + VariableWriteStep( + variable=custom_variable, + input_mapping={VariableWriteStep.VALUE: custom_variable.name}, + ), + OutputMessageStep(message_template="done"), + CompleteStep(name="end"), + ], + variables=[custom_variable], + ).start_conversation(inputs={custom_variable.name: _UnserializableVariableValue()}) + + +def build_policy( + interval: StateSnapshotInterval, + **kwargs: object, +) -> StateSnapshotPolicy: + return build_state_snapshot_policy(interval, **kwargs) + + +def assert_terminal_snapshot( + snapshot_events: Sequence[object], + *, + expected_status_type: str, + expected_message: str, +) -> None: + assert snapshot_status_types(snapshot_events)[-1] == expected_status_type + assert snapshot_message(snapshot_events[-1]) == expected_message + assert snapshot_events[-1].state_snapshot["execution"]["status_handled"] is False From aca9a35ad8496d3c065ce415a1b35b33e22b8eb2 Mon Sep 17 00:00:00 2001 From: Son Le Date: Wed, 18 Mar 2026 17:13:05 +0100 Subject: [PATCH 06/13] remove snapshot tracing for multi-agent --- docs/wayflowcore/source/core/changelog.rst | 2 +- .../src/wayflowcore/agentspec/tracing.py | 152 +-------- .../executors/_statesnapshot_eventlistener.py | 43 +-- .../test_state_snapshot_tracing_agent.py | 13 - .../test_state_snapshot_tracing_flow.py | 17 - .../test_state_snapshot_tracing_nested.py | 291 ++---------------- ...ate_snapshot_runtime_conversation_turns.py | 26 -- .../test_state_snapshot_runtime_nested.py | 105 ------- .../tests/testhelpers/statesnapshots.py | 118 ------- 9 files changed, 44 insertions(+), 723 deletions(-) diff --git a/docs/wayflowcore/source/core/changelog.rst b/docs/wayflowcore/source/core/changelog.rst index 8c9ea6c3b..d13a9c518 100644 --- a/docs/wayflowcore/source/core/changelog.rst +++ 
b/docs/wayflowcore/source/core/changelog.rst @@ -10,7 +10,7 @@ New features * **State snapshot tracing events:** Added ``StateSnapshotPolicy``, ``StateSnapshotEvent``, and conversation snapshot serialization helpers. State snapshots can now be enabled per ``conversation.execute()`` / ``execute_async()`` turn, emitted at conversation, node, or tool boundaries, and bridged to Agent Spec ``StateSnapshotEmitted`` events via ``AgentSpecEventListener``. - Snapshot emission is covered on both synchronous and asynchronous execution paths. + Snapshot emission is covered on both synchronous and asynchronous execution paths, with snapshot ownership currently scoped to the active conversation. * **OAuth support for MCP Clients:** diff --git a/wayflowcore/src/wayflowcore/agentspec/tracing.py b/wayflowcore/src/wayflowcore/agentspec/tracing.py index 5add8b788..78314be0c 100644 --- a/wayflowcore/src/wayflowcore/agentspec/tracing.py +++ b/wayflowcore/src/wayflowcore/agentspec/tracing.py @@ -12,7 +12,6 @@ from pyagentspec.flows.node import Node as AgentSpecNode from pyagentspec.llms import LlmConfig as AgentSpecLlmConfig from pyagentspec.llms import LlmGenerationConfig -from pyagentspec.swarm import Swarm as AgentSpecSwarm from pyagentspec.tools import Tool as AgentSpecTool from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart @@ -27,8 +26,6 @@ from pyagentspec.tracing.events import NodeExecutionEnd as AgentSpecNodeExecutionEnd from pyagentspec.tracing.events import NodeExecutionStart as AgentSpecNodeExecutionStart from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted -from pyagentspec.tracing.events import SwarmExecutionEnd as AgentSpecSwarmExecutionEnd -from pyagentspec.tracing.events import SwarmExecutionStart as AgentSpecSwarmExecutionStart from pyagentspec.tracing.events import ToolExecutionRequest as 
AgentSpecToolExecutionRequest from pyagentspec.tracing.events import ToolExecutionResponse as AgentSpecToolExecutionResponse from pyagentspec.tracing.events.llmgeneration import ToolCall as AgentSpecToolCall @@ -38,7 +35,6 @@ from pyagentspec.tracing.spans import LlmGenerationSpan as AgentSpecLlmGenerationSpan from pyagentspec.tracing.spans import NodeExecutionSpan as AgentSpecNodeExecutionSpan from pyagentspec.tracing.spans import Span as AgentSpecSpan -from pyagentspec.tracing.spans import SwarmExecutionSpan as AgentSpecSwarmExecutionSpan from pyagentspec.tracing.spans import ToolExecutionSpan as AgentSpecToolExecutionSpan from wayflowcore._utils.formatting import stringify @@ -64,8 +60,6 @@ ) from wayflowcore.events.eventlistener import EventListener from wayflowcore.executors.executionstatus import FinishedStatus -from wayflowcore.steps.agentexecutionstep import AgentExecutionStep as RuntimeAgentExecutionStep -from wayflowcore.swarm import Swarm as RuntimeSwarm from wayflowcore.tracing.span import LlmGenerationSpan, get_active_span_stack, get_current_span @@ -80,14 +74,12 @@ def __init__(self) -> None: self.agentspec_exporter: AgentSpecExporter = AgentSpecExporter() # We keep a registry of conversions, so that we do not repeat the conversion for the same object twice self.agentspec_components_registry: Dict[str, AgentSpecComponent] = {} - # State snapshots belong to the span that owns their conversation id, not - # necessarily to the runtime span that was active when the snapshot event - # was emitted. Nested flow sub-conversations can intentionally reuse the - # same deprecated conversation_id, so we also track the live conversation - # object id that currently owns that stream. + # State snapshots belong to the span that owns their logical + # conversation_id, not necessarily to the runtime span that was active + # when the snapshot event was emitted. 
Nested flow sub-conversations can + # intentionally reuse the same logical conversation_id, so we also track + # the live runtime conversation id that currently owns that stream. self._conversation_spans_registry: Dict[str, tuple[str, AgentSpecSpan]] = {} - self._pending_multi_agent_spans_by_component_id: Dict[str, AgentSpecSpan] = {} - self._multi_agent_spans_by_step_span_id: Dict[str, AgentSpecSpan] = {} # Track last assistant message id and a robust mapping tool_request_id -> assistant message id. # Some providers may emit tool events before final assistant message id is known; we allow # temporarily missing ids and backfill on LLM response. @@ -123,142 +115,21 @@ def _register_current_conversation_span(self, agentspec_span: AgentSpecSpan) -> agentspec_span, ) - def _start_multi_agent_span_if_needed( - self, - current_span_id: str, - event_name: str, - event: StepInvocationStartEvent, - ) -> None: - if not isinstance(event.step, RuntimeAgentExecutionStep): - return - - if isinstance(event.step.agent, RuntimeSwarm): - agentspec_swarm = cast(AgentSpecSwarm, self._convert_to_agentspec(event.step.agent)) - multi_agent_span = AgentSpecSwarmExecutionSpan( - id=f"{current_span_id}:swarm", - name=f"SwarmExecution[{event.step.agent._get_display_name()}]", - swarm=agentspec_swarm, - ) - multi_agent_span.start() - multi_agent_span.add_event( - AgentSpecSwarmExecutionStart( - id=event.event_id, - name=event_name, - swarm=agentspec_swarm, - inputs={ - input_name: input_value for input_name, input_value in event.inputs.items() - }, - ) - ) - self._multi_agent_spans_by_step_span_id[current_span_id] = multi_agent_span - self._pending_multi_agent_spans_by_component_id[event.step.agent.id] = multi_agent_span - - def _end_multi_agent_span_if_needed( - self, - current_span_id: str, - event_name: str, - event: StepInvocationResultEvent, - ) -> None: - if not isinstance(event.step, RuntimeAgentExecutionStep): - return - - multi_agent_span = 
self._multi_agent_spans_by_step_span_id.pop(current_span_id, None) - if multi_agent_span is None: - return - - outputs = { - output_name: output_value - for output_name, output_value in event.step_result.outputs.items() - if output_name != "__execution_status__" - } - if isinstance(multi_agent_span, AgentSpecSwarmExecutionSpan): - multi_agent_span.add_event( - AgentSpecSwarmExecutionEnd( - id=event.event_id, - name=event_name, - swarm=multi_agent_span.swarm, - outputs=outputs, - ) - ) - - multi_agent_span.end() - self._pending_multi_agent_spans_by_component_id.pop(event.step.agent.id, None) - def _get_snapshot_owner_span( self, event: StateSnapshotEvent, current_agentspec_span: AgentSpecSpan | None, ) -> AgentSpecSpan | None: + # Keep snapshot ownership resolution centralized here. Today we only + # support direct span ownership plus shared-conversation routing for + # nested flows. Future multi-agent tracing can extend this method to + # route snapshots to a dedicated Swarm/ManagerWorkers wrapper span + # without changing the StateSnapshotEvent handling below. 
if event.conversation_id in self._conversation_spans_registry: return self._conversation_spans_registry[event.conversation_id][1] - if event.state_snapshot is None: - return None - if not isinstance(event.state_snapshot, dict): - return None - snapshot_conversation = event.state_snapshot.get("conversation") - if not isinstance(snapshot_conversation, dict): - return None - snapshot_runtime_conversation_id = snapshot_conversation.get("id") - if not isinstance(snapshot_runtime_conversation_id, str): - return None - - active_conversations = _get_active_conversations(return_copy=False) - matching_conversation = next( - ( - conversation - for conversation in reversed(active_conversations) - if conversation.id == snapshot_runtime_conversation_id - ), - None, - ) - if matching_conversation is None: - return None - - pending_multi_agent_span = self._pending_multi_agent_spans_by_component_id.get( - matching_conversation.component.id - ) - if pending_multi_agent_span is not None: - self._conversation_spans_registry[event.conversation_id] = ( - matching_conversation.id, - pending_multi_agent_span, - ) - return pending_multi_agent_span - return current_agentspec_span - @staticmethod - def _move_snapshot_before_terminal_event( - agentspec_span: AgentSpecSpan, - snapshot_event: AgentSpecStateSnapshotEmitted, - ) -> None: - # State snapshots are emitted by a separate runtime listener in response - # to the turn-end event. That means this Agent Spec listener can record - # the terminal Agent/Flow/Swarm end event first and only see the derived - # snapshot event afterward. Keep the Agent Spec span readable by exposing - # the closing snapshot immediately before the terminal event. 
- if len(agentspec_span.events) < 2 or not isinstance( - agentspec_span.events[-2], - ( - AgentSpecAgentExecutionEnd, - AgentSpecExceptionRaised, - AgentSpecFlowExecutionEnd, - AgentSpecSwarmExecutionEnd, - ), - ): - return - - terminal_event = agentspec_span.events[-2] - if agentspec_span.end_time is not None: - snapshot_event.timestamp = min(snapshot_event.timestamp, agentspec_span.end_time) - else: - snapshot_event.timestamp = min(snapshot_event.timestamp, terminal_event.timestamp) - - agentspec_span.events[-2], agentspec_span.events[-1] = ( - agentspec_span.events[-1], - agentspec_span.events[-2], - ) - def __call__(self, event: Event) -> None: # We intercept the wayflow events, and based on the type of event: # - if it corresponds to a span start event, we create the corresponding agent spec span, and we start it @@ -470,12 +341,10 @@ def __call__(self, event: Event) -> None: }, ) ) - self._start_multi_agent_span_if_needed(current_span.span_id, event_name, event) case StepInvocationResultEvent(): # Step execution ends. Add the event to the agent spec span and close the span if not current_agentspec_span: return - self._end_multi_agent_span_if_needed(current_span.span_id, event_name, event) agentspec_node = cast(AgentSpecNode, self._convert_to_agentspec(event.step)) current_agentspec_span.add_event( AgentSpecNodeExecutionEnd( @@ -502,7 +371,6 @@ def __call__(self, event: Event) -> None: extra_state=event.extra_state, ) owner_span.add_event(snapshot_event) - self._move_snapshot_before_terminal_event(owner_span, snapshot_event) case FlowExecutionStartedEvent(): # Flow execution starts. 
Create the new agent spec span, start it, add the event agentspec_flow = cast( diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py index e608186fa..41a862018 100644 --- a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -212,25 +212,6 @@ def _get_current_active_conversation() -> Optional[Conversation]: return active_conversations[-1] -def _is_multi_agent_conversation(conversation: Conversation) -> bool: - from wayflowcore.executors._managerworkersconversation import ManagerWorkersConversation - from wayflowcore.executors._swarmconversation import SwarmConversation - - return isinstance(conversation, (ManagerWorkersConversation, SwarmConversation)) - - -def _get_nearest_parent_multi_agent_conversation() -> Optional[Conversation]: - active_conversations = _get_active_conversations(return_copy=False) - if len(active_conversations) < 2: - return None - - for conversation in reversed(active_conversations[:-1]): - if _is_multi_agent_conversation(conversation): - return conversation - - return None - - class StateSnapshotEventListener(EventListener): """Emit state snapshots for the active conversation.""" @@ -257,16 +238,12 @@ def _record_snapshot( def _handle_pre_interrupt_event( self, event: Event, - *, - is_current_conversation: bool, ) -> None: # Agents do not expose node execution events, so NODE_TURNS maps to # flow iteration boundaries and agent iteration boundaries. 
match event: case AgentExecutionStartedEvent() | FlowExecutionStartedEvent(): self._record_snapshot(StateSnapshotInterval.CONVERSATION_TURNS) - case _ if not is_current_conversation: - return case ToolExecutionStartEvent() | ToolExecutionResultEvent(): self._record_snapshot(StateSnapshotInterval.TOOL_TURNS) case ( @@ -286,12 +263,12 @@ def _should_record_interrupted_turn_end_snapshot( and isinstance(get_current_span(), (FlowExecutionSpan, AgentExecutionSpan)) ) - def _is_parent_multi_agent_conversation(self) -> bool: - parent_multi_agent_conversation = _get_nearest_parent_multi_agent_conversation() - return ( - parent_multi_agent_conversation is not None - and parent_multi_agent_conversation.id == self.conversation.id - ) + def _owns_current_conversation(self, current_conversation: Conversation) -> bool: + # Keep snapshot ownership resolution centralized here. Today a listener + # only reacts for its own active conversation. Follow-up PRs can widen + # this to parent multi-agent wrapper conversations without touching the + # snapshot emission logic. 
+ return current_conversation.id == self.conversation.id def _handle_post_interrupt_event(self, event: Event) -> None: match event: @@ -317,17 +294,13 @@ def __call__(self, event: Event) -> None: if current_conversation is None: return - is_current_conversation = current_conversation.id == self.conversation.id - if not is_current_conversation and not self._is_parent_multi_agent_conversation(): + if not self._owns_current_conversation(current_conversation): return if self.post_interrupts: self._handle_post_interrupt_event(event) else: - self._handle_pre_interrupt_event( - event, - is_current_conversation=is_current_conversation, - ) + self._handle_pre_interrupt_event(event) @contextmanager diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py index bc0c865d1..d4fad2933 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py @@ -11,7 +11,6 @@ from pyagentspec.adapters.wayflow import AgentSpecLoader from pyagentspec.agent import Agent as AgentSpecAgent from pyagentspec.llms import VllmConfig -from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart from pyagentspec.tracing.events import Event as AgentSpecEvent from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted @@ -296,14 +295,6 @@ def _single_event( return next(event for event in span.events if isinstance(event, event_type)) -def _assert_snapshot_precedes_terminal_event( - span: AgentSpecSpan, - snapshot_events: Sequence[AgentSpecStateSnapshotEmitted], - terminal_event: AgentSpecEvent, -) -> None: - assert span.events.index(snapshot_events[-1]) < span.events.index(terminal_event) - - def test_agent_state_snapshots_support_the_agui_retrieval_export_flow() -> None: 
assistant_message = "I checked the warehouse and found 42 orders last week." wayflow_agent = _create_retrieval_like_wayflow_agent() @@ -326,7 +317,6 @@ def test_agent_state_snapshots_support_the_agui_retrieval_export_flow() -> None: agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) assert _events(agent_span, AgentSpecAgentExecutionStart) - agent_end_event = _single_event(agent_span, AgentSpecAgentExecutionEnd) state_snapshot_events = _events(agent_span, AgentSpecStateSnapshotEmitted) assert len(state_snapshot_events) == 2 @@ -347,7 +337,6 @@ def test_agent_state_snapshots_support_the_agui_retrieval_export_flow() -> None: ) assert runtime_messages[-1]["content"] == assistant_message assert final_snapshot_event.extra_state == {"agent_state": expected_agent_state} - _assert_snapshot_precedes_terminal_event(agent_span, state_snapshot_events, agent_end_event) assert len(agui_exporter.exported_snapshots) == 2 assert agui_exporter.exported_snapshots[-1] == ExportedAGUIStateSnapshot( @@ -375,7 +364,6 @@ def test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_ assert isinstance(status, UserMessageRequestStatus) agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) - agent_end_event = _single_event(agent_span, AgentSpecAgentExecutionEnd) state_snapshot_events = _events(agent_span, AgentSpecStateSnapshotEmitted) assert len(state_snapshot_events) == 2 @@ -385,7 +373,6 @@ def test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_ ] assert snapshot_status_types(state_snapshot_events) == [None, None] assert snapshot_message(state_snapshot_events[-1]) == assistant_message - _assert_snapshot_precedes_terminal_event(agent_span, state_snapshot_events, agent_end_event) llm_spans = _spans(span_recorder, AgentSpecLlmGenerationSpan) assert llm_spans diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py index 
b8c15da51..6bf1ac636 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py @@ -148,14 +148,6 @@ def _single_event( return next(event for event in span.events if isinstance(event, event_type)) -def _assert_snapshot_precedes_terminal_event( - span: AgentSpecSpan, - snapshot_events: Sequence[AgentSpecStateSnapshotEmitted], - terminal_event: AgentSpecEvent, -) -> None: - assert span.events.index(snapshot_events[-1]) < span.events.index(terminal_event) - - def _build_tool_state_snapshot_flow() -> Flow: return Flow.from_steps( [ @@ -190,7 +182,6 @@ def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) assert _events(flow_span, AgentSpecFlowExecutionStart) - flow_end_event = _single_event(flow_span, AgentSpecFlowExecutionEnd) state_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) assert len(state_snapshot_events) == 2 @@ -198,9 +189,7 @@ def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> assert final_snapshot_event.conversation_id == conversation.conversation_id assert snapshot_message(final_snapshot_event) == "Hello" assert final_snapshot_event.extra_state == {"ui": {"active_tab": "plan"}} - _assert_snapshot_precedes_terminal_event(flow_span, state_snapshot_events, flow_end_event) assert flow_span.end_time is not None - assert final_snapshot_event.timestamp <= flow_span.end_time assert "variable_state" not in final_snapshot_event.model_dump(mask_sensitive_information=False) assert flow_span in span_recorder.ended_spans @@ -223,7 +212,6 @@ async def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_en assert isinstance(status, FinishedStatus) flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) - flow_end_event = _single_event(flow_span, AgentSpecFlowExecutionEnd) state_snapshot_events = _events(flow_span, 
AgentSpecStateSnapshotEmitted) assert len(state_snapshot_events) == 2 @@ -231,9 +219,7 @@ async def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_en assert final_snapshot_event.conversation_id == conversation.conversation_id assert snapshot_message(final_snapshot_event) == "Hello" assert final_snapshot_event.extra_state == {"ui": {"active_tab": "plan"}} - _assert_snapshot_precedes_terminal_event(flow_span, state_snapshot_events, flow_end_event) assert flow_span.end_time is not None - assert final_snapshot_event.timestamp <= flow_span.end_time assert "variable_state" not in final_snapshot_event.model_dump(mask_sensitive_information=False) assert flow_span in span_recorder.ended_spans @@ -252,7 +238,6 @@ def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans( assert isinstance(status, FinishedStatus) flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) - flow_end_event = _single_event(flow_span, AgentSpecFlowExecutionEnd) flow_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) assert len(flow_snapshot_events) == 6 @@ -264,8 +249,6 @@ def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans( ["__StartStep__", "single_step"], ["__StartStep__", "single_step", "end"], ] - _assert_snapshot_precedes_terminal_event(flow_span, flow_snapshot_events, flow_end_event) - node_spans = _spans(span_recorder, AgentSpecNodeExecutionSpan) assert node_spans assert not any( diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py index 58434b91f..b1eb87237 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py @@ -4,58 +4,30 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
-from contextlib import AbstractContextManager, ExitStack -from dataclasses import dataclass -from typing import Any, Sequence +from contextlib import ExitStack -import pytest from pyagentspec.tracing.events import Event as AgentSpecEvent -from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted -from pyagentspec.tracing.events import SwarmExecutionEnd as AgentSpecSwarmExecutionEnd from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor -from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan from pyagentspec.tracing.spans import Span as AgentSpecSpan -from pyagentspec.tracing.spans import SwarmExecutionSpan as AgentSpecSwarmExecutionSpan -from pyagentspec.tracing.spans import ToolExecutionSpan as AgentSpecToolExecutionSpan from pyagentspec.tracing.trace import Trace as AgentSpecTrace -from wayflowcore import Agent as WayflowAgent from wayflowcore.agentspec.tracing import AgentSpecEventListener from wayflowcore.events.eventlistener import register_event_listeners -from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus +from wayflowcore.executors.executionstatus import FinishedStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval from wayflowcore.flow import Flow -from wayflowcore.messagelist import Message, MessageType -from wayflowcore.models.vllmmodel import VllmModel -from wayflowcore.steps import AgentExecutionStep, CompleteStep, FlowExecutionStep, OutputMessageStep -from wayflowcore.swarm import Swarm -from wayflowcore.tools import ToolRequest +from wayflowcore.steps import CompleteStep, FlowExecutionStep, OutputMessageStep -from 
..testhelpers.patching import patch_llm from ..testhelpers.statesnapshots import ( build_state_snapshot_policy, snapshot_message, snapshot_runtime_conversation_ids, - snapshot_status_types, ) -@dataclass(frozen=True) -class SwarmStateSnapshotScenario: - flow: Flow - primary_llm: VllmModel - primary_outputs: list[Message | str] - secondary_llm: VllmModel - secondary_outputs: list[Message | str] - multi_agent_span_class: type[AgentSpecSpan] - child_message: str - parent_message: str - multi_agent_end_event_class: type[AgentSpecEvent] - - class SnapshotSpanRecorder(AgentSpecSpanProcessor): def __init__(self) -> None: super().__init__() @@ -93,237 +65,23 @@ async def shutdown_async(self) -> None: return None -def _create_mock_vllm_model(name: str) -> VllmModel: - return VllmModel(model_id="mock.model", host_port="http://mock.url", name=name) - - -def _create_send_message_request(recipient_name: str, message: str) -> Message: - return Message( - content="", - message_type=MessageType.TOOL_REQUEST, - tool_requests=[ - ToolRequest( - name="send_message", - args={"recipient": recipient_name, "message": message}, - ) - ], - ) - - -def _build_swarm_state_snapshot_flow() -> SwarmStateSnapshotScenario: - first_agent_llm = _create_mock_vllm_model("agent1") - second_agent_llm = _create_mock_vllm_model("agent2") - first_agent = WayflowAgent(llm=first_agent_llm, name="agent1", description="agent1") - second_agent = WayflowAgent(llm=second_agent_llm, name="agent2", description="agent2") - swarm = Swarm( - first_agent=first_agent, - relationships=[(first_agent, second_agent), (second_agent, first_agent)], - name="swarm", - ) - - return SwarmStateSnapshotScenario( - flow=Flow.from_steps([AgentExecutionStep(agent=swarm), CompleteStep(name="end")]), - primary_llm=first_agent_llm, - primary_outputs=[ - _create_send_message_request("agent2", "Do it"), - "agent1 final answer", - ], - secondary_llm=second_agent_llm, - secondary_outputs=["agent2 answer"], - 
multi_agent_span_class=AgentSpecSwarmExecutionSpan, - child_message="agent2 answer", - parent_message="agent1 final answer", - multi_agent_end_event_class=AgentSpecSwarmExecutionEnd, - ) - - -def _policy( - interval: StateSnapshotInterval, - **kwargs: Any, -): - return build_state_snapshot_policy(interval, **kwargs) - - -def _execute_with_trace( - conversation, - *, - state_snapshot_policy, - span_processors: Sequence[AgentSpecSpanProcessor] = (), - contexts: Sequence[AbstractContextManager[Any]] = (), -): +def _execute_with_trace(conversation) -> tuple[FinishedStatus, SnapshotSpanRecorder]: span_recorder = SnapshotSpanRecorder() listener = AgentSpecEventListener() with ExitStack() as stack: - for context in contexts: - stack.enter_context(context) - stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, *span_processors])) + stack.enter_context(AgentSpecTrace(span_processors=[span_recorder])) stack.enter_context(register_event_listeners([listener])) - status = conversation.execute(state_snapshot_policy=state_snapshot_policy) + status = conversation.execute( + state_snapshot_policy=build_state_snapshot_policy( + StateSnapshotInterval.CONVERSATION_TURNS + ) + ) + assert isinstance(status, FinishedStatus) return status, span_recorder -def _spans( - span_recorder: SnapshotSpanRecorder, - span_type: type[AgentSpecSpan], -) -> list[AgentSpecSpan]: - return [span for span in span_recorder.started_spans if isinstance(span, span_type)] - - -def _single_span( - span_recorder: SnapshotSpanRecorder, - span_type: type[AgentSpecSpan], -) -> AgentSpecSpan: - matching_spans = _spans(span_recorder, span_type) - assert len(matching_spans) == 1 - return matching_spans[0] - - -def _events( - span: AgentSpecSpan, - event_type: type[AgentSpecEvent], -) -> list[AgentSpecEvent]: - return [event for event in span.events if isinstance(event, event_type)] - - -def _single_event( - span: AgentSpecSpan, - event_type: type[AgentSpecEvent], -) -> AgentSpecEvent: - return next(event 
for event in span.events if isinstance(event, event_type)) - - -def _assert_snapshot_precedes_terminal_event( - span: AgentSpecSpan, - snapshot_events: Sequence[AgentSpecStateSnapshotEmitted], - terminal_event: AgentSpecEvent, -) -> None: - assert span.events.index(snapshot_events[-1]) < span.events.index(terminal_event) - - -@pytest.mark.parametrize( - "flow_builder", - [ - pytest.param(_build_swarm_state_snapshot_flow, id="swarm"), - ], -) -def test_nested_multi_agent_state_snapshots_follow_conversation_ownership_boundaries( - flow_builder, -) -> None: - scenario = flow_builder() - conversation = scenario.flow.start_conversation() - conversation.append_user_message("dummy") - - status, span_recorder = _execute_with_trace( - conversation, - state_snapshot_policy=_policy(StateSnapshotInterval.CONVERSATION_TURNS), - contexts=[ - patch_llm(scenario.primary_llm, scenario.primary_outputs, patch_internal=True), - patch_llm(scenario.secondary_llm, scenario.secondary_outputs, patch_internal=True), - ], - ) - - assert isinstance(status, UserMessageRequestStatus) - - flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) - flow_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) - assert len(flow_snapshot_events) == 2 - assert [event.conversation_id for event in flow_snapshot_events] == [ - conversation.conversation_id, - conversation.conversation_id, - ] - assert snapshot_message(flow_snapshot_events[-1]) == scenario.parent_message - - multi_agent_span = _single_span(span_recorder, scenario.multi_agent_span_class) - multi_agent_snapshot_events = _events(multi_agent_span, AgentSpecStateSnapshotEmitted) - multi_agent_end_event = _single_event(multi_agent_span, scenario.multi_agent_end_event_class) - parent_multi_agent_conversation_id = multi_agent_snapshot_events[0].conversation_id - - assert [event.conversation_id for event in multi_agent_snapshot_events] == [ - parent_multi_agent_conversation_id, - parent_multi_agent_conversation_id, - 
parent_multi_agent_conversation_id, - parent_multi_agent_conversation_id, - parent_multi_agent_conversation_id, - parent_multi_agent_conversation_id, - ] - assert snapshot_status_types(multi_agent_snapshot_events) == [ - None, - "ToolRequestStatus", - None, - "UserMessageRequestStatus", - None, - "UserMessageRequestStatus", - ] - assert snapshot_message(multi_agent_snapshot_events[4]) == scenario.child_message - assert snapshot_message(multi_agent_snapshot_events[-1]) == scenario.parent_message - _assert_snapshot_precedes_terminal_event( - multi_agent_span, - multi_agent_snapshot_events, - multi_agent_end_event, - ) - - agent_snapshot_spans = [ - span - for span in span_recorder.started_spans - if isinstance(span, AgentSpecAgentExecutionSpan) - and any(isinstance(event, AgentSpecStateSnapshotEmitted) for event in span.events) - ] - assert len(agent_snapshot_spans) == 3 - agent_snapshot_events_by_conversation_id: dict[str, list] = {} - for agent_span in agent_snapshot_spans: - snapshot_events = _events(agent_span, AgentSpecStateSnapshotEmitted) - agent_snapshot_events_by_conversation_id.setdefault( - snapshot_events[0].conversation_id, - [], - ).extend(snapshot_events) - - assert len(agent_snapshot_events_by_conversation_id) == 2 - manager_thread_snapshot_events = next( - snapshot_events - for snapshot_events in agent_snapshot_events_by_conversation_id.values() - if len(snapshot_events) == 4 - ) - delegated_agent_snapshot_events = next( - snapshot_events - for snapshot_events in agent_snapshot_events_by_conversation_id.values() - if len(snapshot_events) == 2 - ) - - assert manager_thread_snapshot_events[0].conversation_id != conversation.conversation_id - assert manager_thread_snapshot_events[0].conversation_id != parent_multi_agent_conversation_id - assert delegated_agent_snapshot_events[0].conversation_id not in { - conversation.conversation_id, - parent_multi_agent_conversation_id, - manager_thread_snapshot_events[0].conversation_id, - } - assert 
snapshot_status_types(manager_thread_snapshot_events) == [ - None, - "ToolRequestStatus", - None, - "UserMessageRequestStatus", - ] - assert snapshot_message(manager_thread_snapshot_events[2]) == scenario.child_message - assert snapshot_message(manager_thread_snapshot_events[-1]) == scenario.parent_message - assert snapshot_status_types(delegated_agent_snapshot_events) == [ - None, - "UserMessageRequestStatus", - ] - assert snapshot_message(delegated_agent_snapshot_events[-1]) == scenario.child_message - - tool_spans = _spans(span_recorder, AgentSpecToolExecutionSpan) - assert tool_spans - assert not any( - isinstance(event, AgentSpecStateSnapshotEmitted) - for span in tool_spans - for event in span.events - ) - - assert flow_span in span_recorder.ended_spans - assert multi_agent_span in span_recorder.ended_spans - - def test_nested_flow_state_snapshots_stay_on_the_root_flow_span_for_shared_conversations() -> None: child_flow = Flow.from_steps( [OutputMessageStep(message_template="child"), CompleteStep(name="end")], @@ -341,26 +99,28 @@ def test_nested_flow_state_snapshots_stay_on_the_root_flow_span_for_shared_conve ) conversation = parent_flow.start_conversation() - status, span_recorder = _execute_with_trace( - conversation, - state_snapshot_policy=_policy(StateSnapshotInterval.CONVERSATION_TURNS), - ) - - assert isinstance(status, FinishedStatus) + _, span_recorder = _execute_with_trace(conversation) - flow_spans = _spans(span_recorder, AgentSpecFlowExecutionSpan) + flow_spans = [ + span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) + ] assert len(flow_spans) == 2 flow_spans_by_name = { - _single_event(span, AgentSpecFlowExecutionStart).flow.name: span for span in flow_spans + next( + event for event in span.events if isinstance(event, AgentSpecFlowExecutionStart) + ).flow.name: span + for span in flow_spans } parent_span = flow_spans_by_name["parent_flow"] child_span = flow_spans_by_name["child_flow"] - 
parent_snapshot_events = _events(parent_span, AgentSpecStateSnapshotEmitted) - child_snapshot_events = _events(child_span, AgentSpecStateSnapshotEmitted) - parent_end_event = _single_event(parent_span, AgentSpecFlowExecutionEnd) - + parent_snapshot_events = [ + event for event in parent_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) + ] + child_snapshot_events = [ + event for event in child_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) + ] assert [event.conversation_id for event in parent_snapshot_events] == [ conversation.conversation_id, conversation.conversation_id, @@ -378,4 +138,3 @@ def test_nested_flow_state_snapshots_stay_on_the_root_flow_span_for_shared_conve assert snapshot_message(parent_snapshot_events[2]) == "child" assert snapshot_message(parent_snapshot_events[-1]) == "parent" assert not child_snapshot_events - _assert_snapshot_precedes_terminal_event(parent_span, parent_snapshot_events, parent_end_event) diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py index b5822238f..78fbd47ee 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py @@ -12,21 +12,17 @@ from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval -from ..conftest import disable_streaming from ..test_interrupts import OnEventExecutionInterrupt from ..testhelpers.statesnapshots import ( MutatingExecutionEndInterrupt, SnapshotCollector, - WorkerExecutionEndInterrupt, assert_terminal_snapshot, build_policy, create_agent_conversation, - create_managerworkers_conversation, create_output_flow_conversation, create_tool_flow_conversation, execute_with_state_snapshots, execute_with_state_snapshots_async, - 
find_snapshot_events_by_component_type, snapshot_status_types, ) @@ -210,25 +206,3 @@ def test_conversation_turn_policy_reflects_real_interrupt_side_effects_once() -> assert conversation.inputs["preview_count"] == 1 assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] assert state_snapshot_events[-1].state_snapshot["conversation"]["inputs"]["preview_count"] == 1 - - -def test_parent_multi_agent_does_not_emit_turn_end_snapshot_when_child_turn_is_interrupted() -> ( - None -): - conversation = create_managerworkers_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - execution_interrupts=[WorkerExecutionEndInterrupt()], - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), - execution_context=disable_streaming(), - ) - - assert isinstance(status, InterruptedExecutionStatus) - parent_multi_agent_snapshot_events = find_snapshot_events_by_component_type( - state_snapshot_events, - "ManagerWorkers", - ) - assert "InterruptedExecutionStatus" not in snapshot_status_types( - parent_multi_agent_snapshot_events - ) diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py b/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py index d307971aa..442dcd973 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py @@ -19,13 +19,9 @@ from ..testhelpers.statesnapshots import ( build_policy, create_nested_agent_step_flow_conversation, - create_nested_managerworkers_flow_conversation, - create_nested_swarm_flow_conversation, create_parallel_child_flow, execute_with_state_snapshots, execute_with_state_snapshots_async, - find_snapshot_events_by_component_type, - group_snapshot_events_by_conversation_id, snapshot_message, snapshot_runtime_conversation_ids, snapshot_status_types, @@ -155,104 +151,3 @@ def test_state_snapshot_policy_is_inherited_by_nested_agent_steps() -> None: ] 
assert nested_conversation_id != conversation.conversation_id assert snapshot_message(state_snapshot_events[-1]) == "agent answer" - - -@pytest.mark.parametrize( - ( - "conversation_factory", - "expected_multi_agent_component_type", - "expected_child_message", - "expected_parent_message", - ), - [ - pytest.param( - create_nested_managerworkers_flow_conversation, - "ManagerWorkers", - "worker answer", - "manager final answer", - id="managerworkers", - ), - pytest.param( - create_nested_swarm_flow_conversation, - "Swarm", - "agent2 answer", - "agent1 final answer", - id="swarm", - ), - ], -) -def test_nested_multi_agent_components_emit_snapshots_for_the_active_conversation( - conversation_factory, - expected_multi_agent_component_type: str, - expected_child_message: str, - expected_parent_message: str, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), - ) - - assert isinstance(status, UserMessageRequestStatus) - assert len(state_snapshot_events) == 14 - - snapshot_events_by_conversation_id = group_snapshot_events_by_conversation_id( - state_snapshot_events - ) - - assert len(snapshot_events_by_conversation_id) == 4 - - flow_snapshot_events = snapshot_events_by_conversation_id[conversation.conversation_id] - parent_multi_agent_snapshot_events = find_snapshot_events_by_component_type( - state_snapshot_events, - expected_multi_agent_component_type, - ) - agent_snapshot_event_groups = [ - snapshot_events - for conversation_id, snapshot_events in snapshot_events_by_conversation_id.items() - if conversation_id - not in { - conversation.conversation_id, - parent_multi_agent_snapshot_events[0].conversation_id, - } - ] - manager_thread_snapshot_events = next( - snapshot_events - for snapshot_events in agent_snapshot_event_groups - if len(snapshot_events) == 4 - ) - delegated_agent_snapshot_events = next( - 
snapshot_events - for snapshot_events in agent_snapshot_event_groups - if len(snapshot_events) == 2 - ) - - assert snapshot_status_types(flow_snapshot_events) == [None, "UserMessageRequestStatus"] - assert snapshot_message(flow_snapshot_events[-1]) == expected_parent_message - - assert snapshot_status_types(parent_multi_agent_snapshot_events) == [ - None, - "ToolRequestStatus", - None, - "UserMessageRequestStatus", - None, - "UserMessageRequestStatus", - ] - assert snapshot_message(parent_multi_agent_snapshot_events[4]) == expected_child_message - assert snapshot_message(parent_multi_agent_snapshot_events[-1]) == expected_parent_message - - assert snapshot_status_types(manager_thread_snapshot_events) == [ - None, - "ToolRequestStatus", - None, - "UserMessageRequestStatus", - ] - assert snapshot_message(manager_thread_snapshot_events[2]) == expected_child_message - assert snapshot_message(manager_thread_snapshot_events[-1]) == expected_parent_message - - assert snapshot_status_types(delegated_agent_snapshot_events) == [ - None, - "UserMessageRequestStatus", - ] - assert snapshot_message(delegated_agent_snapshot_events[-1]) == expected_child_message diff --git a/wayflowcore/tests/testhelpers/statesnapshots.py b/wayflowcore/tests/testhelpers/statesnapshots.py index f613fae20..28e49c606 100644 --- a/wayflowcore/tests/testhelpers/statesnapshots.py +++ b/wayflowcore/tests/testhelpers/statesnapshots.py @@ -5,7 +5,6 @@ # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
import threading -from collections import defaultdict from contextlib import AbstractContextManager, nullcontext from typing import Any, Callable, Sequence @@ -21,7 +20,6 @@ ) from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow -from wayflowcore.managerworkers import ManagerWorkers from wayflowcore.messagelist import Message, MessageType from wayflowcore.property import AnyProperty, StringProperty from wayflowcore.serialization.serializer import SerializableNeedToBeImplementedMixin @@ -32,7 +30,6 @@ ToolExecutionStep, VariableWriteStep, ) -from wayflowcore.swarm import Swarm from wayflowcore.tools import ServerTool, ToolRequest, tool from wayflowcore.variable import Variable @@ -86,26 +83,6 @@ def snapshot_step_histories(snapshot_events: Sequence[Any]) -> list[list[str]]: ] -def group_snapshot_events_by_conversation_id( - snapshot_events: Sequence[Any], -) -> dict[str, list[Any]]: - grouped_snapshot_events: dict[str, list[Any]] = defaultdict(list) - for snapshot_event in snapshot_events: - grouped_snapshot_events[snapshot_event.conversation_id].append(snapshot_event) - return dict(grouped_snapshot_events) - - -def find_snapshot_events_by_component_type( - snapshot_events: Sequence[Any], - component_type: str, -) -> list[Any]: - return next( - grouped_events - for grouped_events in group_snapshot_events_by_conversation_id(snapshot_events).values() - if grouped_events[0].state_snapshot["conversation"]["component_type"] == component_type - ) - - def execute_with_state_snapshots( conversation: Conversation, *, @@ -160,29 +137,6 @@ def _on_execution_end( return None -class WorkerExecutionEndInterrupt(SerializableNeedToBeImplementedMixin, _NullExecutionInterrupt): - def __init__(self) -> None: - self.triggered = False - super().__init__() - - def _on_execution_end( - self, - state: ConversationExecutionState, - conversation: Conversation, - ) -> InterruptedExecutionStatus | None: - if 
self.triggered: - return None - if getattr(conversation.component, "name", None) != "worker": - return None - - self.triggered = True - return InterruptedExecutionStatus( - interrupter=self, - reason="worker execution end", - _conversation_id=conversation.id, - ) - - class _UnserializableVariableValue: pass @@ -246,19 +200,6 @@ def do_nothing_tool() -> str: return conversation -def _create_send_message_request(recipient_name: str, message: str) -> Message: - return Message( - content="", - message_type=MessageType.TOOL_REQUEST, - tool_requests=[ - ToolRequest( - name="send_message", - args={"recipient": recipient_name, "message": message}, - ) - ], - ) - - def create_nested_agent_step_flow_conversation() -> Conversation: llm = DummyModel() llm.set_next_output("agent answer") @@ -270,65 +211,6 @@ def create_nested_agent_step_flow_conversation() -> Conversation: return conversation -def create_nested_managerworkers_flow_conversation() -> Conversation: - llm = DummyModel() - worker = Agent(llm=llm, name="worker", description="worker") - group = ManagerWorkers(group_manager=llm, workers=[worker]) - llm.set_next_output( - [ - _create_send_message_request("worker", "Do it"), - "worker answer", - "manager final answer", - ] - ) - - conversation = Flow.from_steps( - [AgentExecutionStep(agent=group), CompleteStep(name="end")] - ).start_conversation() - conversation.append_user_message("dummy") - return conversation - - -def create_managerworkers_conversation() -> Conversation: - llm = DummyModel() - worker = Agent(llm=llm, name="worker", description="worker") - group = ManagerWorkers(group_manager=llm, workers=[worker]) - llm.set_next_output( - [ - _create_send_message_request("worker", "Do it"), - "worker answer", - "manager final answer", - ] - ) - - conversation = group.start_conversation() - conversation.append_user_message("dummy") - return conversation - - -def create_nested_swarm_flow_conversation() -> Conversation: - llm = DummyModel() - first_agent = Agent(llm=llm, 
name="agent1", description="agent1") - second_agent = Agent(llm=llm, name="agent2", description="agent2") - swarm = Swarm( - first_agent=first_agent, - relationships=[(first_agent, second_agent), (second_agent, first_agent)], - ) - llm.set_next_output( - [ - _create_send_message_request("agent2", "Do it"), - "agent2 answer", - "agent1 final answer", - ] - ) - - conversation = Flow.from_steps( - [AgentExecutionStep(agent=swarm), CompleteStep(name="end")] - ).start_conversation() - conversation.append_user_message("dummy") - return conversation - - def create_parallel_child_flow(output_name: str, output_value: str) -> Flow: return Flow.from_steps( [ From 95d3f105a63248e76965593b53dbf63b3859592d Mon Sep 17 00:00:00 2001 From: Son Le Date: Wed, 18 Mar 2026 18:52:16 +0100 Subject: [PATCH 07/13] add resumable conversation state helpers --- .../source/core/api/serialization.rst | 6 ++ docs/wayflowcore/source/core/changelog.rst | 3 +- .../src/wayflowcore/agentspec/tracing.py | 42 +++++--- wayflowcore/src/wayflowcore/conversation.py | 8 +- .../executors/_statesnapshot_eventlistener.py | 31 +++--- .../src/wayflowcore/serialization/__init__.py | 14 +++ .../wayflowcore/serialization/conversation.py | 76 ++++++++++++++- .../test_state_snapshot_tracing_agent.py | 91 +++++++---------- .../test_state_snapshot_tracing_flow.py | 93 ++++++------------ .../test_state_snapshot_tracing_nested.py | 30 ++---- .../test_conversation_state_snapshot.py | 97 ++++++++++++++++++- .../tests/testhelpers/statesnapshots.py | 27 ++---- 12 files changed, 313 insertions(+), 205 deletions(-) diff --git a/docs/wayflowcore/source/core/api/serialization.rst b/docs/wayflowcore/source/core/api/serialization.rst index 1262dd377..f58c85fb7 100644 --- a/docs/wayflowcore/source/core/api/serialization.rst +++ b/docs/wayflowcore/source/core/api/serialization.rst @@ -45,6 +45,12 @@ Conversation State Snapshots .. _deserializeconversationstate: .. 
autofunction:: wayflowcore.serialization.deserialize_conversation_state +.. _loadconversationstate: +.. autofunction:: wayflowcore.serialization.load_conversation_state + +.. _deserializeconversation: +.. autofunction:: wayflowcore.serialization.deserialize_conversation + .. _dumpvariablestate: .. autofunction:: wayflowcore.serialization.dump_variable_state diff --git a/docs/wayflowcore/source/core/changelog.rst b/docs/wayflowcore/source/core/changelog.rst index d13a9c518..7131c33d7 100644 --- a/docs/wayflowcore/source/core/changelog.rst +++ b/docs/wayflowcore/source/core/changelog.rst @@ -9,7 +9,8 @@ New features * **State snapshot tracing events:** - Added ``StateSnapshotPolicy``, ``StateSnapshotEvent``, and conversation snapshot serialization helpers. State snapshots can now be enabled per ``conversation.execute()`` / ``execute_async()`` turn, emitted at conversation, node, or tool boundaries, and bridged to Agent Spec ``StateSnapshotEmitted`` events via ``AgentSpecEventListener``. + Added configurable conversation state snapshots for tracing, with emission at conversation, node, or tool boundaries and bridging into Agent Spec state snapshot events. + Added resumable conversation state serialization so persisted conversations can be restored and continued. Snapshot emission is covered on both synchronous and asynchronous execution paths, with snapshot ownership currently scoped to the active conversation. * **OAuth support for MCP Clients:** diff --git a/wayflowcore/src/wayflowcore/agentspec/tracing.py b/wayflowcore/src/wayflowcore/agentspec/tracing.py index 78314be0c..973e0d565 100644 --- a/wayflowcore/src/wayflowcore/agentspec/tracing.py +++ b/wayflowcore/src/wayflowcore/agentspec/tracing.py @@ -4,6 +4,7 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
import json +from dataclasses import dataclass from typing import Dict, Optional, Union, cast from pyagentspec import Component as AgentSpecComponent @@ -63,6 +64,12 @@ from wayflowcore.tracing.span import LlmGenerationSpan, get_active_span_stack, get_current_span +@dataclass(frozen=True) +class _ConversationSpanOwner: + runtime_conversation_id: str + span: AgentSpecSpan + + class AgentSpecEventListener(EventListener): """Event listener that emits traces according to the Open Agent Spec Tracing standard""" @@ -79,7 +86,7 @@ def __init__(self) -> None: # when the snapshot event was emitted. Nested flow sub-conversations can # intentionally reuse the same logical conversation_id, so we also track # the live runtime conversation id that currently owns that stream. - self._conversation_spans_registry: Dict[str, tuple[str, AgentSpecSpan]] = {} + self._conversation_span_owners: Dict[str, _ConversationSpanOwner] = {} # Track last assistant message id and a robust mapping tool_request_id -> assistant message id. # Some providers may emit tool events before final assistant message id is known; we allow # temporarily missing ids and backfill on LLM response. 
@@ -101,19 +108,23 @@ def _get_active_wayflow_conversation(self) -> Conversation | None: def _register_current_conversation_span(self, agentspec_span: AgentSpecSpan) -> None: active_conversation = self._get_active_wayflow_conversation() - if active_conversation is not None: - current_owner = self._conversation_spans_registry.get( - active_conversation.conversation_id + if active_conversation is None: + return + + current_owner = self._conversation_span_owners.get(active_conversation.conversation_id) + if ( + current_owner is not None + and current_owner.runtime_conversation_id != active_conversation.id + and current_owner.span.end_time is None + ): + return + + self._conversation_span_owners[active_conversation.conversation_id] = ( + _ConversationSpanOwner( + runtime_conversation_id=active_conversation.id, + span=agentspec_span, ) - if ( - current_owner is None - or current_owner[0] == active_conversation.id - or current_owner[1].end_time is not None - ): - self._conversation_spans_registry[active_conversation.conversation_id] = ( - active_conversation.id, - agentspec_span, - ) + ) def _get_snapshot_owner_span( self, @@ -125,8 +136,9 @@ def _get_snapshot_owner_span( # nested flows. Future multi-agent tracing can extend this method to # route snapshots to a dedicated Swarm/ManagerWorkers wrapper span # without changing the StateSnapshotEvent handling below. - if event.conversation_id in self._conversation_spans_registry: - return self._conversation_spans_registry[event.conversation_id][1] + owner = self._conversation_span_owners.get(event.conversation_id) + if owner is not None: + return owner.span return current_agentspec_span diff --git a/wayflowcore/src/wayflowcore/conversation.py b/wayflowcore/src/wayflowcore/conversation.py index 39dab5e32..5da41403d 100644 --- a/wayflowcore/src/wayflowcore/conversation.py +++ b/wayflowcore/src/wayflowcore/conversation.py @@ -1,4 +1,4 @@ -# Copyright © 2025 Oracle and/or its affiliates. 
+# Copyright © 2025, 2026 Oracle and/or its affiliates. # # This software is under the Apache License 2.0 # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License @@ -152,9 +152,9 @@ async def execute_async( with _register_conversation(self): new_status = await self.component.runner.execute_async(self, execution_interrupts) - self.status = new_status - self.status_handled = False - return self.status + self.status = new_status + self.status_handled = False + return self.status @property @abstractmethod diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py index 41a862018..c02238b30 100644 --- a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -171,12 +171,11 @@ def _build_variable_state( return None -def record_state_snapshot( +def _record_state_snapshot( conversation: Conversation, required_snapshot_interval: StateSnapshotInterval, *, execution_status: ExecutionStatus | None, - status_handled: bool, ) -> None: state_snapshot_policy = _get_snapshot_policy_for_interval( conversation, required_snapshot_interval @@ -191,7 +190,9 @@ def record_state_snapshot( state_snapshot=dump_conversation_state( conversation, status=execution_status, - status_handled=status_handled, + # Snapshots should expose the canonical pre-consumption view + # of a turn, not transient runtime bookkeeping. 
+ status_handled=False, ), extra_state=_build_extra_state(conversation, state_snapshot_policy), variable_state=_build_variable_state(conversation, state_snapshot_policy), @@ -205,13 +206,6 @@ def record_state_snapshot( ) -def _get_current_active_conversation() -> Optional[Conversation]: - active_conversations = _get_active_conversations(return_copy=False) - if not active_conversations: - return None - return active_conversations[-1] - - class StateSnapshotEventListener(EventListener): """Emit state snapshots for the active conversation.""" @@ -228,11 +222,10 @@ def _record_snapshot( required_snapshot_interval: StateSnapshotInterval, execution_status: ExecutionStatus | None = None, ) -> None: - record_state_snapshot( + _record_state_snapshot( self.conversation, required_snapshot_interval, execution_status=execution_status, - status_handled=False, ) def _handle_pre_interrupt_event( @@ -264,10 +257,11 @@ def _should_record_interrupted_turn_end_snapshot( ) def _owns_current_conversation(self, current_conversation: Conversation) -> bool: - # Keep snapshot ownership resolution centralized here. Today a listener - # only reacts for its own active conversation. Follow-up PRs can widen - # this to parent multi-agent wrapper conversations without touching the - # snapshot emission logic. + # This is the intended extension point for future multi-agent snapshot + # ownership rules. Today a listener only reacts for its own active + # conversation. Follow-up PRs can widen this here to parent wrapper + # conversations (for example swarms or manager-workers) without + # changing the snapshot emission logic elsewhere in this listener. 
return current_conversation.id == self.conversation.id def _handle_post_interrupt_event(self, event: Event) -> None: @@ -290,10 +284,11 @@ def __call__(self, event: Event) -> None: if isinstance(event, StateSnapshotEvent): return - current_conversation = _get_current_active_conversation() - if current_conversation is None: + active_conversations = _get_active_conversations(return_copy=False) + if not active_conversations: return + current_conversation = active_conversations[-1] if not self._owns_current_conversation(current_conversation): return diff --git a/wayflowcore/src/wayflowcore/serialization/__init__.py b/wayflowcore/src/wayflowcore/serialization/__init__.py index 8089ebfdf..b5b9672b7 100644 --- a/wayflowcore/src/wayflowcore/serialization/__init__.py +++ b/wayflowcore/src/wayflowcore/serialization/__init__.py @@ -33,6 +33,18 @@ def deserialize_conversation_state(*args: Any, **kwargs: Any) -> Any: return _deserialize_conversation_state(*args, **kwargs) +def load_conversation_state(*args: Any, **kwargs: Any) -> Any: + from .conversation import load_conversation_state as _load_conversation_state + + return _load_conversation_state(*args, **kwargs) + + +def deserialize_conversation(*args: Any, **kwargs: Any) -> Any: + from .conversation import deserialize_conversation as _deserialize_conversation + + return _deserialize_conversation(*args, **kwargs) + + def dump_variable_state(*args: Any, **kwargs: Any) -> Any: from .conversation import dump_variable_state as _dump_variable_state @@ -42,10 +54,12 @@ def dump_variable_state(*args: Any, **kwargs: Any) -> Any: __all__ = [ "autodeserialize", "deserialize", + "deserialize_conversation", "deserialize_conversation_state", "deserialize_from_dict", "dump_conversation_state", "dump_variable_state", + "load_conversation_state", "serialize", "serialize_conversation_state", "serialize_to_dict", diff --git a/wayflowcore/src/wayflowcore/serialization/conversation.py b/wayflowcore/src/wayflowcore/serialization/conversation.py 
index 502320460..3a1539570 100644 --- a/wayflowcore/src/wayflowcore/serialization/conversation.py +++ b/wayflowcore/src/wayflowcore/serialization/conversation.py @@ -7,10 +7,13 @@ from __future__ import annotations import json +import warnings from datetime import datetime from enum import Enum from typing import TYPE_CHECKING, Any, Optional, cast +import yaml + from wayflowcore._utils.formatting import stringify from wayflowcore.executors.executionstatus import ( AuthChallengeRequestStatus, @@ -21,14 +24,22 @@ UserMessageRequestStatus, ) from wayflowcore.messagelist import ImageContent, Message, MessageContent, TextContent -from wayflowcore.serialization.context import SerializationContext -from wayflowcore.serialization.serializer import serialize_any_to_dict_or_stringify +from wayflowcore.serialization.context import DeserializationContext, SerializationContext +from wayflowcore.serialization.serializer import ( + autodeserialize_from_dict, + serialize, + serialize_any_to_dict_or_stringify, +) from wayflowcore.tools.tools import ToolRequest, ToolResult if TYPE_CHECKING: from wayflowcore.conversation import Conversation from wayflowcore.executors._agentconversation import AgentConversation from wayflowcore.executors._flowconversation import FlowConversation + from wayflowcore.serialization.plugins import ( + WayflowDeserializationPlugin, + WayflowSerializationPlugin, + ) _UNSET = object() @@ -283,6 +294,7 @@ def dump_conversation_state( status: object = _UNSET, status_handled: object = _UNSET, ) -> dict[str, Any]: + """Return a JSON-serializable runtime snapshot of the conversation state.""" from wayflowcore.executors._agentconversation import AgentConversation from wayflowcore.executors._flowconversation import FlowConversation @@ -317,15 +329,67 @@ def dump_conversation_state( } -def serialize_conversation_state(conversation: "Conversation") -> str: - return json.dumps(dump_conversation_state(conversation), sort_keys=True) +def serialize_conversation_state( + 
conversation: "Conversation", + serialization_context: Optional[SerializationContext] = None, + plugins: Optional[list["WayflowSerializationPlugin"]] = None, +) -> str: + """Serialize a full conversation state into a stable text representation.""" + return serialize( + conversation, + serialization_context=serialization_context, + plugins=plugins, + ) def deserialize_conversation_state(state: str) -> dict[str, Any]: - return cast(dict[str, Any], json.loads(state)) + """Parse a serialized conversation state string back into a dictionary.""" + loaded_state = yaml.safe_load(state) + if not isinstance(loaded_state, dict): + raise TypeError("Serialized conversation state must deserialize into a dictionary.") + return cast(dict[str, Any], loaded_state) + + +def load_conversation_state( + state: dict[str, Any], + deserialization_context: Optional[DeserializationContext] = None, + plugins: Optional[list["WayflowDeserializationPlugin"]] = None, +) -> "Conversation": + """Reconstruct a conversation from a serialized state dictionary.""" + from wayflowcore.conversation import Conversation + + if deserialization_context is None: + deserialization_context = DeserializationContext(plugins=plugins) + elif plugins is not None: + warnings.warn( + "A list of plugins was provided together with a deserialization context instance in `load_conversation_state`. " + "Do not pass the plugins to `load_conversation_state`, but create the context instance passing the list of plugins instead.", + UserWarning, + ) + + conversation = autodeserialize_from_dict(state, deserialization_context) + if not isinstance(conversation, Conversation): + raise TypeError( + f"Loaded object is of type {conversation.__class__.__name__}, not Conversation." 
+ ) + return conversation + + +def deserialize_conversation( + conversation_state: str, + deserialization_context: Optional[DeserializationContext] = None, + plugins: Optional[list["WayflowDeserializationPlugin"]] = None, +) -> "Conversation": + """Reconstruct a conversation directly from its serialized state string.""" + return load_conversation_state( + deserialize_conversation_state(conversation_state), + deserialization_context=deserialization_context, + plugins=plugins, + ) def dump_variable_state(conversation: "Conversation") -> Optional[dict[str, Any]]: + """Return the JSON-serializable runtime-owned variable state for a conversation.""" from wayflowcore.executors._flowconversation import FlowConversation if not isinstance(conversation, FlowConversation): @@ -336,8 +400,10 @@ def dump_variable_state(conversation: "Conversation") -> Optional[dict[str, Any] __all__ = [ + "deserialize_conversation", "deserialize_conversation_state", "dump_conversation_state", "dump_variable_state", + "load_conversation_state", "serialize_conversation_state", ] diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py index d4fad2933..a9698ea09 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py @@ -30,7 +30,7 @@ from ..testhelpers.patching import patch_llm from ..testhelpers.statesnapshots import ( - build_state_snapshot_policy, + build_policy, snapshot_message, snapshot_status_types, ) @@ -185,15 +185,6 @@ async def shutdown_async(self) -> None: ) -def _create_retrieval_like_wayflow_agent() -> WayflowAgent: - agentspec_agent = AgentSpecAgent( - name="retrieval_agent", - llm_config=VllmConfig(name="llm", url="http://mock.url", model_id="mock.model"), - system_prompt="You are a helpful retrieval agent.", - ) - return cast(WayflowAgent, AgentSpecLoader().load_component(agentspec_agent)) - - def 
_build_retrieval_agent_state( *, conversation_inputs: dict[str, Any], @@ -212,39 +203,6 @@ def _build_retrieval_agent_state( ) -def _build_retrieval_like_extra_state(conversation) -> dict[str, Any]: - conversation_snapshot = dump_conversation_state(conversation)["conversation"] - messages = conversation_snapshot["messages"] - last_response = next( - ( - message.get("content") - for message in reversed(messages) - if message.get("role") == "assistant" and message.get("content") - ), - "", - ) - return { - "agent_state": asdict( - _build_retrieval_agent_state( - conversation_inputs=conversation.inputs, - message_count=len(messages), - last_response=last_response, - ) - ) - } - - -def _create_mock_vllm_model(name: str) -> VllmModel: - return VllmModel(model_id="mock.model", host_port="http://mock.url", name=name) - - -def _policy( - interval: StateSnapshotInterval, - **kwargs: Any, -): - return build_state_snapshot_policy(interval, **kwargs) - - def _execute_with_trace( conversation, *, @@ -288,26 +246,49 @@ def _events( return [event for event in span.events if isinstance(event, event_type)] -def _single_event( - span: AgentSpecSpan, - event_type: type[AgentSpecEvent], -) -> AgentSpecEvent: - return next(event for event in span.events if isinstance(event, event_type)) - - def test_agent_state_snapshots_support_the_agui_retrieval_export_flow() -> None: assistant_message = "I checked the warehouse and found 42 orders last week." 
- wayflow_agent = _create_retrieval_like_wayflow_agent() + wayflow_agent = cast( + WayflowAgent, + AgentSpecLoader().load_component( + AgentSpecAgent( + name="retrieval_agent", + llm_config=VllmConfig(name="llm", url="http://mock.url", model_id="mock.model"), + system_prompt="You are a helpful retrieval agent.", + ) + ), + ) conversation = wayflow_agent.start_conversation(inputs=_RETRIEVAL_INPUTS) conversation.append_user_message(_RETRIEVAL_INPUTS["input"]) agui_exporter = AGUIStateSnapshotExporter() + def build_extra_state(conversation) -> dict[str, Any]: + conversation_snapshot = dump_conversation_state(conversation)["conversation"] + messages = conversation_snapshot["messages"] + last_response = next( + ( + message.get("content") + for message in reversed(messages) + if message.get("role") == "assistant" and message.get("content") + ), + "", + ) + return { + "agent_state": asdict( + _build_retrieval_agent_state( + conversation_inputs=conversation.inputs, + message_count=len(messages), + last_response=last_response, + ) + ) + } + status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=_policy( + state_snapshot_policy=build_policy( StateSnapshotInterval.CONVERSATION_TURNS, - extra_state_builder=_build_retrieval_like_extra_state, + extra_state_builder=build_extra_state, ), span_processors=[agui_exporter], contexts=[patch_llm(wayflow_agent.llm, [assistant_message], patch_internal=True)], @@ -351,13 +332,13 @@ def test_agent_state_snapshots_support_the_agui_retrieval_export_flow() -> None: def test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_spans() -> None: assistant_message = "Hello from agent" - llm = _create_mock_vllm_model("agent") + llm = VllmModel(model_id="mock.model", host_port="http://mock.url", name="agent") agent = WayflowAgent(llm=llm) conversation = agent.start_conversation() conversation.append_user_message("Hi") status, span_recorder = _execute_with_trace( conversation, - 
state_snapshot_policy=_policy(StateSnapshotInterval.NODE_TURNS), + state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), contexts=[patch_llm(llm, [assistant_message], patch_internal=True)], ) diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py index 6bf1ac636..5b196dcfd 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py @@ -28,7 +28,7 @@ from wayflowcore.tools import ServerTool from ..testhelpers.statesnapshots import ( - build_state_snapshot_policy, + build_policy, snapshot_message, snapshot_step_histories, ) @@ -71,13 +71,6 @@ async def shutdown_async(self) -> None: return None -def _policy( - interval: StateSnapshotInterval, - **kwargs: Any, -): - return build_state_snapshot_policy(interval, **kwargs) - - def _execute_with_trace( conversation, *, @@ -98,26 +91,6 @@ def _execute_with_trace( return status, span_recorder -async def _execute_with_trace_async( - conversation, - *, - state_snapshot_policy, - span_processors: Sequence[AgentSpecSpanProcessor] = (), - contexts: Sequence[AbstractContextManager[Any]] = (), -): - span_recorder = SnapshotSpanRecorder() - listener = AgentSpecEventListener() - - async with AgentSpecTrace(span_processors=[span_recorder, *span_processors]): - with ExitStack() as stack: - for context in contexts: - stack.enter_context(context) - stack.enter_context(register_event_listeners([listener])) - status = await conversation.execute_async(state_snapshot_policy=state_snapshot_policy) - - return status, span_recorder - - def _spans( span_recorder: SnapshotSpanRecorder, span_type: type[AgentSpecSpan], @@ -141,29 +114,6 @@ def _events( return [event for event in span.events if isinstance(event, event_type)] -def _single_event( - span: AgentSpecSpan, - event_type: type[AgentSpecEvent], -) -> AgentSpecEvent: - return next(event for event in 
span.events if isinstance(event, event_type)) - - -def _build_tool_state_snapshot_flow() -> Flow: - return Flow.from_steps( - [ - ToolExecutionStep( - tool=ServerTool( - name="say_hi", - description="Say hi", - func=lambda: "hi", - input_descriptors=[], - ) - ), - CompleteStep(name="end"), - ] - ) - - def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> None: flow = Flow.from_steps( [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], @@ -172,7 +122,7 @@ def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> conversation = flow.start_conversation() status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=_policy( + state_snapshot_policy=build_policy( StateSnapshotInterval.CONVERSATION_TURNS, extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, ), @@ -201,13 +151,16 @@ async def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_en step_names=["single_step", "end"], ) conversation = flow.start_conversation() - status, span_recorder = await _execute_with_trace_async( - conversation, - state_snapshot_policy=_policy( - StateSnapshotInterval.CONVERSATION_TURNS, - extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, - ), - ) + span_recorder = SnapshotSpanRecorder() + + async with AgentSpecTrace(span_processors=[span_recorder]): + with register_event_listeners([AgentSpecEventListener()]): + status = await conversation.execute_async( + state_snapshot_policy=build_policy( + StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, + ) + ) assert isinstance(status, FinishedStatus) @@ -232,7 +185,7 @@ def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans( conversation = flow.start_conversation() status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=_policy(StateSnapshotInterval.NODE_TURNS), + 
state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), ) assert isinstance(status, FinishedStatus) @@ -289,11 +242,23 @@ def test_internal_flow_state_snapshots_follow_conversation_ownership_for_agent_s interval: StateSnapshotInterval, expected_step_histories: list[list[str]], ) -> None: - flow = _build_tool_state_snapshot_flow() + flow = Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name="say_hi", + description="Say hi", + func=lambda: "hi", + input_descriptors=[], + ) + ), + CompleteStep(name="end"), + ] + ) conversation = flow.start_conversation() status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=_policy(interval), + state_snapshot_policy=build_policy(interval), ) assert isinstance(status, FinishedStatus) @@ -321,7 +286,7 @@ def test_off_policy_does_not_bridge_state_snapshots_into_agent_spec_spans() -> N conversation = flow.start_conversation() status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=_policy(StateSnapshotInterval.OFF), + state_snapshot_policy=build_policy(StateSnapshotInterval.OFF), ) assert isinstance(status, FinishedStatus) @@ -353,7 +318,7 @@ def test_only_the_opening_state_snapshot_is_exported_when_a_turn_raises() -> Non with register_event_listeners([AgentSpecEventListener()]): with pytest.raises(RuntimeError, match="boom"): conversation.execute( - state_snapshot_policy=_policy(StateSnapshotInterval.CONVERSATION_TURNS) + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS) ) flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py index b1eb87237..935ff6a37 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py @@ -4,8 +4,6 @@ # (LICENSE-APACHE or 
http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. -from contextlib import ExitStack - from pyagentspec.tracing.events import Event as AgentSpecEvent from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted @@ -22,7 +20,7 @@ from wayflowcore.steps import CompleteStep, FlowExecutionStep, OutputMessageStep from ..testhelpers.statesnapshots import ( - build_state_snapshot_policy, + build_policy, snapshot_message, snapshot_runtime_conversation_ids, ) @@ -65,23 +63,6 @@ async def shutdown_async(self) -> None: return None -def _execute_with_trace(conversation) -> tuple[FinishedStatus, SnapshotSpanRecorder]: - span_recorder = SnapshotSpanRecorder() - listener = AgentSpecEventListener() - - with ExitStack() as stack: - stack.enter_context(AgentSpecTrace(span_processors=[span_recorder])) - stack.enter_context(register_event_listeners([listener])) - status = conversation.execute( - state_snapshot_policy=build_state_snapshot_policy( - StateSnapshotInterval.CONVERSATION_TURNS - ) - ) - - assert isinstance(status, FinishedStatus) - return status, span_recorder - - def test_nested_flow_state_snapshots_stay_on_the_root_flow_span_for_shared_conversations() -> None: child_flow = Flow.from_steps( [OutputMessageStep(message_template="child"), CompleteStep(name="end")], @@ -98,8 +79,15 @@ def test_nested_flow_state_snapshots_stay_on_the_root_flow_span_for_shared_conve name="parent_flow", ) conversation = parent_flow.start_conversation() + span_recorder = SnapshotSpanRecorder() - _, span_recorder = _execute_with_trace(conversation) + with AgentSpecTrace(span_processors=[span_recorder]): + with register_event_listeners([AgentSpecEventListener()]): + status = conversation.execute( + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS) + 
) + + assert isinstance(status, FinishedStatus) flow_spans = [ span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) diff --git a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py index c6a943b61..d71f31b03 100644 --- a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py +++ b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py @@ -10,15 +10,27 @@ import pytest from wayflowcore.conversation import Conversation +from wayflowcore.executors._flowconversation import FlowConversation +from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus from wayflowcore.flow import Flow from wayflowcore.property import AnyProperty, StringProperty from wayflowcore.serialization import ( + deserialize_conversation, deserialize_conversation_state, dump_conversation_state, dump_variable_state, + load_conversation_state, serialize_conversation_state, ) -from wayflowcore.steps import OutputMessageStep, VariableWriteStep +from wayflowcore.serialization.context import DeserializationContext +from wayflowcore.steps import ( + CompleteStep, + InputMessageStep, + OutputMessageStep, + ToolExecutionStep, + VariableWriteStep, +) +from wayflowcore.tools import ServerTool, register_server_tool from wayflowcore.variable import Variable @@ -65,9 +77,11 @@ def test_dump_conversation_state_is_json_serializable_and_lightweight() -> None: snapshot = dump_conversation_state(conversation) variable_state = dump_variable_state(conversation) - serialized_snapshot = serialize_conversation_state(conversation) + serialized_conversation_state = serialize_conversation_state(conversation) + deserialized_conversation_state = deserialize_conversation_state(serialized_conversation_state) - assert json.loads(json.dumps(snapshot)) == deserialize_conversation_state(serialized_snapshot) + assert json.loads(json.dumps(snapshot)) == 
snapshot + assert deserialized_conversation_state["_component_type"] == conversation.__class__.__name__ assert variable_state == {"custom": "custom-value"} assert snapshot["conversation"]["component_type"] == "Flow" assert snapshot["conversation"]["messages"][-1]["content"] == "Hello there" @@ -147,3 +161,80 @@ def test_dump_variable_state_rejects_non_json_serializable_values() -> None: with pytest.raises(TypeError, match="Variable 'custom' contains a non-JSON-serializable"): dump_variable_state(conversation) + + +def test_load_conversation_state_restores_a_runnable_conversation() -> None: + flow = Flow.from_steps( + [InputMessageStep("Please answer"), OutputMessageStep("done")], + name="resume_flow", + ) + conversation = flow.start_conversation() + + status = conversation.execute() + assert isinstance(status, UserMessageRequestStatus) + + loaded_conversation = load_conversation_state( + deserialize_conversation_state(serialize_conversation_state(conversation)) + ) + + assert isinstance(loaded_conversation, FlowConversation) + loaded_conversation.append_user_message("hello") + resumed_status = loaded_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + assert [message.content for message in loaded_conversation.get_messages()] == [ + "Please answer", + "hello", + "done", + ] + + +def test_deserialize_conversation_restores_a_runnable_conversation() -> None: + flow = Flow.from_steps( + [InputMessageStep("Please answer"), OutputMessageStep("done")], + name="resume_flow", + ) + conversation = flow.start_conversation() + + status = conversation.execute() + assert isinstance(status, UserMessageRequestStatus) + + deserialized_conversation = deserialize_conversation(serialize_conversation_state(conversation)) + + assert isinstance(deserialized_conversation, FlowConversation) + deserialized_conversation.append_user_message("hello") + resumed_status = deserialized_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + assert 
[message.content for message in deserialized_conversation.get_messages()] == [ + "Please answer", + "hello", + "done", + ] + + +def test_load_conversation_state_uses_the_given_deserialization_context() -> None: + tool = ServerTool( + name="say_hi", + description="Say hi", + func=lambda: "hi", + input_descriptors=[], + ) + flow = Flow.from_steps( + [ + ToolExecutionStep(tool=tool), + CompleteStep(name="end"), + ], + name="tool_flow", + ) + + deserialization_context = DeserializationContext() + register_server_tool(tool, deserialization_context.registered_tools) + + conversation = load_conversation_state( + deserialize_conversation_state(serialize_conversation_state(flow.start_conversation())), + deserialization_context=deserialization_context, + ) + + assert isinstance(conversation, FlowConversation) + assert isinstance(conversation.execute(), FinishedStatus) diff --git a/wayflowcore/tests/testhelpers/statesnapshots.py b/wayflowcore/tests/testhelpers/statesnapshots.py index 28e49c606..9f14bb82a 100644 --- a/wayflowcore/tests/testhelpers/statesnapshots.py +++ b/wayflowcore/tests/testhelpers/statesnapshots.py @@ -45,20 +45,11 @@ def __call__(self, event: Event) -> None: self.state_snapshot_events.append(event) -def build_state_snapshot_policy( - interval: StateSnapshotInterval, - **kwargs: Any, -) -> StateSnapshotPolicy: - return StateSnapshotPolicy(state_snapshot_interval=interval, **kwargs) - - -def snapshot_status_type(snapshot_event: Any) -> str | None: - status = snapshot_event.state_snapshot["execution"]["status"] - return status["type"] if status is not None else None - - def snapshot_status_types(snapshot_events: Sequence[Any]) -> list[str | None]: - return [snapshot_status_type(snapshot_event) for snapshot_event in snapshot_events] + return [ + status["type"] if (status := snapshot_event.state_snapshot["execution"]["status"]) else None + for snapshot_event in snapshot_events + ] def snapshot_message(snapshot_event: Any) -> str | None: @@ -68,12 +59,10 @@ def 
snapshot_message(snapshot_event: Any) -> str | None: return messages[-1].get("content") -def snapshot_runtime_conversation_id(snapshot_event: Any) -> str: - return snapshot_event.state_snapshot["conversation"]["id"] - - def snapshot_runtime_conversation_ids(snapshot_events: Sequence[Any]) -> list[str]: - return [snapshot_runtime_conversation_id(snapshot_event) for snapshot_event in snapshot_events] + return [ + snapshot_event.state_snapshot["conversation"]["id"] for snapshot_event in snapshot_events + ] def snapshot_step_histories(snapshot_events: Sequence[Any]) -> list[list[str]]: @@ -247,7 +236,7 @@ def build_policy( interval: StateSnapshotInterval, **kwargs: object, ) -> StateSnapshotPolicy: - return build_state_snapshot_policy(interval, **kwargs) + return StateSnapshotPolicy(state_snapshot_interval=interval, **kwargs) def assert_terminal_snapshot( From 97ae8e5fdda59a0c0c32d5f1b3d671cad20b2fbd Mon Sep 17 00:00:00 2001 From: Son Le Date: Thu, 19 Mar 2026 10:53:38 +0100 Subject: [PATCH 08/13] clean up conversation state serialization --- docs/wayflowcore/source/core/changelog.rst | 4 + .../src/wayflowcore/serialization/__init__.py | 47 +-- .../wayflowcore/serialization/conversation.py | 352 ++++++++++++------ .../test_conversation_state_snapshot.py | 64 +++- 4 files changed, 318 insertions(+), 149 deletions(-) diff --git a/docs/wayflowcore/source/core/changelog.rst b/docs/wayflowcore/source/core/changelog.rst index 7131c33d7..76b10992b 100644 --- a/docs/wayflowcore/source/core/changelog.rst +++ b/docs/wayflowcore/source/core/changelog.rst @@ -68,6 +68,10 @@ Possibly Breaking Changes Bug fixes ^^^^^^^^^ +* **Serialization imports:** + + Reduced import coupling in conversation state serialization so the public conversation state helpers remain directly re-exported from ``wayflowcore.serialization``. 
+ WayFlow 26.1.1 -------------- diff --git a/wayflowcore/src/wayflowcore/serialization/__init__.py b/wayflowcore/src/wayflowcore/serialization/__init__.py index b5b9672b7..d4cec303e 100644 --- a/wayflowcore/src/wayflowcore/serialization/__init__.py +++ b/wayflowcore/src/wayflowcore/serialization/__init__.py @@ -4,8 +4,14 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. -from typing import Any - +from .conversation import ( + deserialize_conversation, + deserialize_conversation_state, + dump_conversation_state, + dump_variable_state, + load_conversation_state, + serialize_conversation_state, +) from .serializer import ( autodeserialize, deserialize, @@ -14,43 +20,6 @@ serialize_to_dict, ) - -def dump_conversation_state(*args: Any, **kwargs: Any) -> Any: - from .conversation import dump_conversation_state as _dump_conversation_state - - return _dump_conversation_state(*args, **kwargs) - - -def serialize_conversation_state(*args: Any, **kwargs: Any) -> Any: - from .conversation import serialize_conversation_state as _serialize_conversation_state - - return _serialize_conversation_state(*args, **kwargs) - - -def deserialize_conversation_state(*args: Any, **kwargs: Any) -> Any: - from .conversation import deserialize_conversation_state as _deserialize_conversation_state - - return _deserialize_conversation_state(*args, **kwargs) - - -def load_conversation_state(*args: Any, **kwargs: Any) -> Any: - from .conversation import load_conversation_state as _load_conversation_state - - return _load_conversation_state(*args, **kwargs) - - -def deserialize_conversation(*args: Any, **kwargs: Any) -> Any: - from .conversation import deserialize_conversation as _deserialize_conversation - - return _deserialize_conversation(*args, **kwargs) - - -def dump_variable_state(*args: Any, **kwargs: Any) -> Any: - from .conversation import dump_variable_state as 
_dump_variable_state - - return _dump_variable_state(*args, **kwargs) - - __all__ = [ "autodeserialize", "deserialize", diff --git a/wayflowcore/src/wayflowcore/serialization/conversation.py b/wayflowcore/src/wayflowcore/serialization/conversation.py index 3a1539570..f237cbfb4 100644 --- a/wayflowcore/src/wayflowcore/serialization/conversation.py +++ b/wayflowcore/src/wayflowcore/serialization/conversation.py @@ -14,7 +14,6 @@ import yaml -from wayflowcore._utils.formatting import stringify from wayflowcore.executors.executionstatus import ( AuthChallengeRequestStatus, ExecutionStatus, @@ -23,7 +22,6 @@ ToolRequestStatus, UserMessageRequestStatus, ) -from wayflowcore.messagelist import ImageContent, Message, MessageContent, TextContent from wayflowcore.serialization.context import DeserializationContext, SerializationContext from wayflowcore.serialization.serializer import ( autodeserialize_from_dict, @@ -36,6 +34,7 @@ from wayflowcore.conversation import Conversation from wayflowcore.executors._agentconversation import AgentConversation from wayflowcore.executors._flowconversation import FlowConversation + from wayflowcore.messagelist import Message, MessageContent from wayflowcore.serialization.plugins import ( WayflowDeserializationPlugin, WayflowSerializationPlugin, @@ -44,7 +43,23 @@ _UNSET = object() +def _dump_conversation_reference(conversation: "Conversation") -> dict[str, Any]: + return { + "id": conversation.id, + "conversation_id": conversation.conversation_id, + "conversation_type": conversation.__class__.__name__, + } + + +def _dump_component_reference(component: Any) -> dict[str, Any]: + return { + "component_id": component.id, + "component_type": component.__class__.__name__, + } + + def _dump_json_compatible_value(value: Any) -> Any: + from wayflowcore._utils.formatting import stringify from wayflowcore.component import Component from wayflowcore.conversation import Conversation @@ -58,16 +73,9 @@ def _dump_json_compatible_value(value: Any) -> Any: 
elif isinstance(value, Enum): dumped_value = _dump_json_compatible_value(value.value) elif isinstance(value, Conversation): - dumped_value = { - "id": value.id, - "conversation_id": value.conversation_id, - "conversation_type": value.__class__.__name__, - } + dumped_value = _dump_conversation_reference(value) elif isinstance(value, Component): - dumped_value = { - "component_id": value.id, - "component_type": value.__class__.__name__, - } + dumped_value = _dump_component_reference(value) elif isinstance(value, dict): dumped_value = { str(key): _dump_json_compatible_value(inner_value) for key, inner_value in value.items() @@ -85,16 +93,6 @@ def _dump_json_compatible_value(value: Any) -> Any: return dumped_value -def _dump_variable_value(variable_name: str, value: Any) -> Any: - try: - serialized_value = json.dumps(value, sort_keys=True, allow_nan=False) - except (TypeError, ValueError) as e: - raise TypeError( - f"Variable '{variable_name}' contains a non-JSON-serializable value of type {type(value).__name__}" - ) from e - return cast(Any, json.loads(serialized_value)) - - def _dump_string_keyed_mapping(values: dict[Any, Any]) -> dict[str, Any]: return {str(key): _dump_json_compatible_value(value) for key, value in values.items()} @@ -106,11 +104,10 @@ def _dump_flow_input_output_key_values(values: dict[Any, Any]) -> dict[str, Any] } -def _dump_variable_store(variable_store: dict[str, Any]) -> dict[str, Any]: - return { - variable_name: _dump_variable_value(variable_name, variable_value) - for variable_name, variable_value in variable_store.items() - } +def _dump_tool_requests( + tool_requests: Optional[list[ToolRequest]], +) -> list[Optional[dict[str, Any]]]: + return [_dump_tool_request(tool_request) for tool_request in tool_requests or []] def _dump_tool_request(tool_request: Optional[ToolRequest]) -> Optional[dict[str, Any]]: @@ -139,7 +136,22 @@ def _dump_tool_result(tool_result: Optional[ToolResult]) -> Optional[dict[str, A return dumped_tool_result +def 
_dump_tool_related_execution_status( + execution_status: ToolRequestStatus | ToolExecutionConfirmationStatus, +) -> dict[str, Any]: + dumped_status = { + "tool_requests": _dump_tool_requests(execution_status.tool_requests), + } + if isinstance(execution_status, ToolRequestStatus): + dumped_status["tool_results"] = [ + _dump_tool_result(tool_result) for tool_result in execution_status._tool_results or [] + ] + return dumped_status + + def _dump_message_content(content: MessageContent) -> dict[str, Any]: + from wayflowcore.messagelist import ImageContent, TextContent + content_type = getattr(content, "type", content.__class__.__name__) if isinstance(content, TextContent): @@ -179,9 +191,7 @@ def _dump_message(message: Message) -> dict[str, Any]: "contents": [_dump_message_content(content) for content in message.contents], } - tool_requests = [ - _dump_tool_request(tool_request) for tool_request in message.tool_requests or [] - ] + tool_requests = _dump_tool_requests(message.tool_requests) dumped_tool_result = _dump_tool_result(message.tool_result) if tool_requests: @@ -192,43 +202,25 @@ def _dump_message(message: Message) -> dict[str, Any]: def _dump_execution_status(execution_status: Optional[ExecutionStatus]) -> Optional[dict[str, Any]]: - dumped_status: dict[str, Any] | None if execution_status is None: - dumped_status = None - else: - dumped_status = { - "type": execution_status.__class__.__name__, - } - - if isinstance(execution_status, FinishedStatus): - dumped_status["output_values"] = _dump_json_compatible_value( - execution_status.output_values - ) - dumped_status["complete_step_name"] = execution_status.complete_step_name - elif isinstance(execution_status, UserMessageRequestStatus): - dumped_status["message"] = _dump_message(execution_status.message) - elif isinstance(execution_status, ToolRequestStatus): - dumped_status["tool_requests"] = [ - _dump_tool_request(tool_request) for tool_request in execution_status.tool_requests - ] - 
dumped_status["tool_results"] = [ - _dump_tool_result(tool_result) - for tool_result in execution_status._tool_results or [] - ] - elif isinstance(execution_status, ToolExecutionConfirmationStatus): - dumped_status["tool_requests"] = [ - _dump_tool_request(tool_request) for tool_request in execution_status.tool_requests - ] - elif isinstance(execution_status, AuthChallengeRequestStatus): - dumped_status["client_transport_id"] = execution_status.client_transport_id + return None + + dumped_status: dict[str, Any] = {"type": execution_status.__class__.__name__} + if isinstance(execution_status, FinishedStatus): + dumped_status["output_values"] = _dump_json_compatible_value(execution_status.output_values) + dumped_status["complete_step_name"] = execution_status.complete_step_name + elif isinstance(execution_status, UserMessageRequestStatus): + dumped_status["message"] = _dump_message(execution_status.message) + elif isinstance(execution_status, (ToolRequestStatus, ToolExecutionConfirmationStatus)): + dumped_status.update(_dump_tool_related_execution_status(execution_status)) + elif isinstance(execution_status, AuthChallengeRequestStatus): + dumped_status["client_transport_id"] = execution_status.client_transport_id return dumped_status def _dump_conversation_info(conversation: "Conversation") -> dict[str, Any]: return { - "id": conversation.id, - "conversation_id": conversation.conversation_id, - "conversation_type": conversation.__class__.__name__, + **_dump_conversation_reference(conversation), "component_type": conversation.component.__class__.__name__, "name": conversation.name, "inputs": _dump_json_compatible_value(conversation.inputs), @@ -236,7 +228,7 @@ def _dump_conversation_info(conversation: "Conversation") -> dict[str, Any]: } -def _dump_execution_info( +def _dump_common_execution_info( conversation: "Conversation", *, status: object = _UNSET, @@ -275,9 +267,7 @@ def _dump_agent_execution_info(conversation: "AgentConversation") -> dict[str, A return { 
"curr_iter": conversation.state.curr_iter, "has_confirmed_conversation_exit": conversation.state.has_confirmed_conversation_exit, - "tool_call_queue": [ - _dump_tool_request(tool_request) for tool_request in conversation.state.tool_call_queue - ], + "tool_call_queue": _dump_tool_requests(conversation.state.tool_call_queue), "current_tool_request": _dump_tool_request(conversation.state.current_tool_request), "current_flow_conversation": _dump_json_compatible_value( conversation.state.current_flow_conversation @@ -288,44 +278,60 @@ def _dump_agent_execution_info(conversation: "AgentConversation") -> dict[str, A } +def _dump_component_execution_info(conversation: "Conversation") -> dict[str, Any]: + from wayflowcore.executors._agentconversation import AgentConversation + from wayflowcore.executors._flowconversation import FlowConversation + + if isinstance(conversation, FlowConversation): + return _dump_flow_execution_info(conversation) + if isinstance(conversation, AgentConversation): + return _dump_agent_execution_info(conversation) + return {} + + def dump_conversation_state( conversation: "Conversation", *, status: object = _UNSET, status_handled: object = _UNSET, ) -> dict[str, Any]: - """Return a JSON-serializable runtime snapshot of the conversation state.""" - from wayflowcore.executors._agentconversation import AgentConversation - from wayflowcore.executors._flowconversation import FlowConversation - - if isinstance(conversation, FlowConversation): - execution_info = { - **_dump_execution_info( - conversation, - status=status, - status_handled=status_handled, - ), - **_dump_flow_execution_info(conversation), - } - elif isinstance(conversation, AgentConversation): - execution_info = { - **_dump_execution_info( + """ + Return a JSON-serializable runtime snapshot of a conversation. + + The returned dictionary is intended for inspection, tracing, and state snapshot + emission. 
It captures the user-visible conversation state and the runtime + execution state without embedding live component objects. Optional ``status`` + and ``status_handled`` overrides are available so callers can snapshot a + slightly adjusted view of the current runtime state without mutating the + conversation itself. + + Parameters + ---------- + conversation: + Conversation instance to snapshot. + status: + Optional execution status override to include in the dumped state instead + of ``conversation.status``. + status_handled: + Optional ``status_handled`` override to include in the dumped state + instead of ``conversation.status_handled``. + + Returns + ------- + dict[str, Any] + JSON-compatible conversation snapshot containing ``conversation`` and + ``execution`` sections. + """ + return { + "conversation": _dump_conversation_info(conversation), + "execution": { + **_dump_common_execution_info( conversation, status=status, status_handled=status_handled, ), - **_dump_agent_execution_info(conversation), - } - else: - execution_info = _dump_execution_info( - conversation, - status=status, - status_handled=status_handled, - ) - - return { - "conversation": _dump_conversation_info(conversation), - "execution": execution_info, + **_dump_component_execution_info(conversation), + }, } @@ -334,7 +340,28 @@ def serialize_conversation_state( serialization_context: Optional[SerializationContext] = None, plugins: Optional[list["WayflowSerializationPlugin"]] = None, ) -> str: - """Serialize a full conversation state into a stable text representation.""" + """ + Serialize a conversation into its stable textual state representation. + + This is the string form meant for storage or transport when the full runtime + conversation needs to be preserved for later loading. Unlike + ``dump_conversation_state()``, this serializes the actual conversation object + graph using WayFlow serialization. + + Parameters + ---------- + conversation: + Conversation instance to serialize. 
+ serialization_context: + Optional serialization context to use. + plugins: + Optional serialization plugins to use. + + Returns + ------- + str + Serialized conversation state string. + """ return serialize( conversation, serialization_context=serialization_context, @@ -343,7 +370,29 @@ def serialize_conversation_state( def deserialize_conversation_state(state: str) -> dict[str, Any]: - """Parse a serialized conversation state string back into a dictionary.""" + """ + Parse a serialized conversation state string into a dictionary. + + This is the dictionary-level counterpart of + ``serialize_conversation_state()``. It is useful when callers need to inspect + or adjust the serialized payload before loading it back into a live + ``Conversation`` object. + + Parameters + ---------- + state: + Serialized conversation state string. + + Returns + ------- + dict[str, Any] + Parsed serialized state. + + Raises + ------ + TypeError + If the serialized payload does not deserialize into a dictionary. + """ loaded_state = yaml.safe_load(state) if not isinstance(loaded_state, dict): raise TypeError("Serialized conversation state must deserialize into a dictionary.") @@ -355,17 +404,40 @@ def load_conversation_state( deserialization_context: Optional[DeserializationContext] = None, plugins: Optional[list["WayflowDeserializationPlugin"]] = None, ) -> "Conversation": - """Reconstruct a conversation from a serialized state dictionary.""" + """ + Reconstruct a live conversation from a serialized state dictionary. + + The input dictionary is expected to come from + ``deserialize_conversation_state()`` or another equivalent WayFlow + serialization source. + + Parameters + ---------- + state: + Serialized conversation state as a dictionary. + deserialization_context: + Optional deserialization context to use. This is the preferred way to + provide tool registries or plugins. + plugins: + Optional deserialization plugins. 
When a deserialization context is + already provided, plugins should be attached to that context instead. + + Returns + ------- + Conversation + Reconstructed live conversation instance. + + Raises + ------ + TypeError + If the deserialized object is not a ``Conversation``. + """ from wayflowcore.conversation import Conversation - if deserialization_context is None: - deserialization_context = DeserializationContext(plugins=plugins) - elif plugins is not None: - warnings.warn( - "A list of plugins was provided together with a deserialization context instance in `load_conversation_state`. " - "Do not pass the plugins to `load_conversation_state`, but create the context instance passing the list of plugins instead.", - UserWarning, - ) + deserialization_context = _resolve_deserialization_context( + deserialization_context=deserialization_context, + plugins=plugins, + ) conversation = autodeserialize_from_dict(state, deserialization_context) if not isinstance(conversation, Conversation): @@ -380,7 +452,26 @@ def deserialize_conversation( deserialization_context: Optional[DeserializationContext] = None, plugins: Optional[list["WayflowDeserializationPlugin"]] = None, ) -> "Conversation": - """Reconstruct a conversation directly from its serialized state string.""" + """ + Reconstruct a conversation directly from its serialized string form. + + This is a convenience wrapper around + ``deserialize_conversation_state()`` followed by ``load_conversation_state()``. + + Parameters + ---------- + conversation_state: + Serialized conversation state string. + deserialization_context: + Optional deserialization context to use. + plugins: + Optional deserialization plugins. + + Returns + ------- + Conversation + Reconstructed live conversation instance. 
+ """ return load_conversation_state( deserialize_conversation_state(conversation_state), deserialization_context=deserialization_context, @@ -389,16 +480,61 @@ def deserialize_conversation( def dump_variable_state(conversation: "Conversation") -> Optional[dict[str, Any]]: - """Return the JSON-serializable runtime-owned variable state for a conversation.""" + """ + Return the JSON-serializable runtime-owned variable state for a conversation. + + Only flow conversations expose runtime variable storage. For other + conversation types, this returns ``None``. + + Parameters + ---------- + conversation: + Conversation whose runtime-owned variable values should be dumped. + + Returns + ------- + dict[str, Any] | None + JSON-compatible mapping of variable names to values for flow + conversations, otherwise ``None``. + + Raises + ------ + TypeError + If a variable contains a value that cannot be represented as JSON. + """ from wayflowcore.executors._flowconversation import FlowConversation if not isinstance(conversation, FlowConversation): - variable_state = None - else: - variable_state = _dump_variable_store(conversation.state.variable_store) + return None + + variable_state: dict[str, Any] = {} + for variable_name, variable_value in conversation.state.variable_store.items(): + try: + serialized_value = json.dumps(variable_value, sort_keys=True, allow_nan=False) + except (TypeError, ValueError) as e: + raise TypeError( + f"Variable '{variable_name}' contains a non-JSON-serializable value of type {type(variable_value).__name__}" + ) from e + variable_state[variable_name] = cast(Any, json.loads(serialized_value)) return variable_state +def _resolve_deserialization_context( + *, + deserialization_context: Optional[DeserializationContext], + plugins: Optional[list["WayflowDeserializationPlugin"]], +) -> DeserializationContext: + if deserialization_context is None: + return DeserializationContext(plugins=plugins) + if plugins is not None: + warnings.warn( + "A list of plugins 
was provided together with a deserialization context instance in `load_conversation_state`. " + "Do not pass the plugins to `load_conversation_state`, but create the context instance passing the list of plugins instead.", + UserWarning, + ) + return deserialization_context + + __all__ = [ "deserialize_conversation", "deserialize_conversation_state", diff --git a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py index d71f31b03..29d645d20 100644 --- a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py +++ b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py @@ -11,7 +11,11 @@ from wayflowcore.conversation import Conversation from wayflowcore.executors._flowconversation import FlowConversation -from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus +from wayflowcore.executors.executionstatus import ( + FinishedStatus, + ToolRequestStatus, + UserMessageRequestStatus, +) from wayflowcore.flow import Flow from wayflowcore.property import AnyProperty, StringProperty from wayflowcore.serialization import ( @@ -30,7 +34,7 @@ ToolExecutionStep, VariableWriteStep, ) -from wayflowcore.tools import ServerTool, register_server_tool +from wayflowcore.tools import ClientTool, ServerTool, ToolResult, register_server_tool from wayflowcore.variable import Variable @@ -163,6 +167,62 @@ def test_dump_variable_state_rejects_non_json_serializable_values() -> None: dump_variable_state(conversation) +def test_conversation_state_roundtrip_preserves_pending_tool_results() -> None: + client_tool = ClientTool( + name="client_lookup", + description="Look up some data on the client side", + parameters={}, + ) + flow = Flow.from_steps( + [ + ToolExecutionStep(tool=client_tool), + CompleteStep(name="end"), + ], + name="tool_resume_flow", + ) + conversation = flow.start_conversation() + + status = conversation.execute() + assert 
isinstance(status, ToolRequestStatus) + + tool_request = status.tool_requests[0] + conversation.append_tool_result( + ToolResult(tool_request_id=tool_request.tool_request_id, content="client-result") + ) + + snapshot = dump_conversation_state(conversation) + assert snapshot["execution"]["status"]["type"] == "ToolRequestStatus" + assert snapshot["execution"]["status"]["tool_results"] == [ + { + "tool_request_id": tool_request.tool_request_id, + "content": "client-result", + } + ] + assert all(message.tool_result is None for message in conversation.get_messages()) + + loaded_conversation = load_conversation_state( + deserialize_conversation_state(serialize_conversation_state(conversation)) + ) + loaded_snapshot = dump_conversation_state(loaded_conversation) + + assert loaded_snapshot["execution"]["status"]["tool_results"] == [ + { + "tool_request_id": tool_request.tool_request_id, + "content": "client-result", + } + ] + + resumed_status = loaded_conversation.execute() + assert isinstance(resumed_status, FinishedStatus) + + tool_result_messages = [ + message.tool_result for message in loaded_conversation.get_messages() if message.tool_result + ] + assert len(tool_result_messages) == 1 + assert tool_result_messages[0].tool_request_id == tool_request.tool_request_id + assert tool_result_messages[0].content == "client-result" + + def test_load_conversation_state_restores_a_runnable_conversation() -> None: flow = Flow.from_steps( [InputMessageStep("Please answer"), OutputMessageStep("done")], From 3f99dc8b621e7ed1fe768f9dbd81cd0b2676bcaa Mon Sep 17 00:00:00 2001 From: Son Le Date: Thu, 19 Mar 2026 11:49:04 +0100 Subject: [PATCH 09/13] Add resumable state snapshot payloads --- .../source/core/howtoguides/howto_tracing.rst | 9 +- wayflowcore/src/wayflowcore/events/event.py | 6 +- .../executors/_statesnapshot_eventlistener.py | 44 +++++++-- .../wayflowcore/serialization/conversation.py | 67 +++++++++++-- .../test_state_snapshot_tracing_agent.py | 40 +++++++- 
...ate_snapshot_runtime_conversation_turns.py | 99 ++++++++++++++++++- ...t_state_snapshot_runtime_internal_turns.py | 8 +- .../test_conversation_state_snapshot.py | 78 +++++++++++++++ .../tests/testhelpers/statesnapshots.py | 12 ++- 9 files changed, 337 insertions(+), 26 deletions(-) diff --git a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst index 734611193..8adff5617 100644 --- a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst +++ b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst @@ -165,7 +165,14 @@ does not define WayFlow variable semantics. When ``include_variable_state=True`` variable values must already be JSON-serializable. ``StateSnapshotEvent.conversation_id`` is the logical/public conversation id, while ``state_snapshot["conversation"]["id"]`` identifies the concrete runtime -conversation instance that emitted the snapshot. +conversation instance that emitted the snapshot. The lightweight +``state_snapshot["conversation"]`` / ``state_snapshot["execution"]`` sections are +intended for inspection and tracing, while +``state_snapshot["conversation_state"]`` contains the authoritative serialized +WayFlow conversation blob used for resumability. To restore from that blob, use +``wayflowcore.serialization.deserialize_conversation(...)`` or +``deserialize_conversation_state(...)`` together with +``load_conversation_state(...)``. Each policy emits snapshots only for its own boundaries. ``CONVERSATION_TURNS`` emits opening and closing turn snapshots. Internal policies emit only step, iteration, and/or tool snapshots. 
Snapshots are emitted only when the diff --git a/wayflowcore/src/wayflowcore/events/event.py b/wayflowcore/src/wayflowcore/events/event.py index 11c2a24cf..c912aaaa7 100644 --- a/wayflowcore/src/wayflowcore/events/event.py +++ b/wayflowcore/src/wayflowcore/events/event.py @@ -800,7 +800,11 @@ class StateSnapshotEvent(Event): ``conversation_id`` is the logical/public conversation id. When a snapshot is present, the emitting runtime conversation instance is identified by - ``state_snapshot["conversation"]["id"]``. + ``state_snapshot["conversation"]["id"]``. WayFlow-emitted payloads also place + the authoritative resumable serialized state in + ``state_snapshot["conversation_state"]`` while keeping + ``state_snapshot["conversation"]`` and ``state_snapshot["execution"]`` as the + lightweight inspection view. """ conversation_id: str = field(default_factory=_required_attribute("conversation_id", str)) diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py index c02238b30..cd9ff8eb4 100644 --- a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -37,11 +37,18 @@ from wayflowcore.executors._executor import ExecutionInterruptedException from wayflowcore.executors.executionstatus import ExecutionStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy -from wayflowcore.serialization.conversation import dump_conversation_state, dump_variable_state +from wayflowcore.serialization.conversation import ( + _serialize_conversation_state_with_runtime_overrides, + dump_conversation_state, + dump_variable_state, +) from wayflowcore.tracing.span import AgentExecutionSpan, FlowExecutionSpan, get_current_span logger = logging.getLogger(__name__) +_STATE_SNAPSHOT_RUNTIME = "wayflow" +_STATE_SNAPSHOT_SCHEMA_VERSION = 1 + _STATE_SNAPSHOT_POLICIES: 
ContextVar[Dict[str, StateSnapshotPolicy]] = ContextVar( "_STATE_SNAPSHOT_POLICIES", @@ -171,6 +178,34 @@ def _build_variable_state( return None +def _build_state_snapshot_payload( + conversation: Conversation, + *, + execution_status: ExecutionStatus | None, +) -> dict[str, Any]: + # Snapshots should expose the canonical pre-consumption view of a turn, not + # transient runtime bookkeeping. The serialized conversation_state blob must + # match the same logical boundary as the lightweight conversation/execution + # sections so it can be restored directly. + snapshot_status_handled = False + dumped_state = dump_conversation_state( + conversation, + status=execution_status, + status_handled=snapshot_status_handled, + ) + return { + "runtime": _STATE_SNAPSHOT_RUNTIME, + "schema_version": _STATE_SNAPSHOT_SCHEMA_VERSION, + "conversation_state": _serialize_conversation_state_with_runtime_overrides( + conversation, + status=execution_status, + status_handled=snapshot_status_handled, + ), + "conversation": dumped_state["conversation"], + "execution": dumped_state["execution"], + } + + def _record_state_snapshot( conversation: Conversation, required_snapshot_interval: StateSnapshotInterval, @@ -187,12 +222,9 @@ def _record_state_snapshot( record_event( StateSnapshotEvent( conversation_id=conversation.conversation_id, - state_snapshot=dump_conversation_state( + state_snapshot=_build_state_snapshot_payload( conversation, - status=execution_status, - # Snapshots should expose the canonical pre-consumption view - # of a turn, not transient runtime bookkeeping. 
- status_handled=False, + execution_status=execution_status, ), extra_state=_build_extra_state(conversation, state_snapshot_policy), variable_state=_build_variable_state(conversation, state_snapshot_policy), diff --git a/wayflowcore/src/wayflowcore/serialization/conversation.py b/wayflowcore/src/wayflowcore/serialization/conversation.py index f237cbfb4..feafa8db2 100644 --- a/wayflowcore/src/wayflowcore/serialization/conversation.py +++ b/wayflowcore/src/wayflowcore/serialization/conversation.py @@ -8,9 +8,10 @@ import json import warnings +from contextlib import contextmanager from datetime import datetime from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, Iterator, Optional, cast import yaml @@ -242,7 +243,6 @@ def _dump_common_execution_info( "status_handled": ( conversation.status_handled if status_handled is _UNSET else cast(bool, status_handled) ), - "token_usage": _dump_json_compatible_value(conversation.token_usage), } @@ -298,10 +298,11 @@ def dump_conversation_state( """ Return a JSON-serializable runtime snapshot of a conversation. - The returned dictionary is intended for inspection, tracing, and state snapshot - emission. It captures the user-visible conversation state and the runtime - execution state without embedding live component objects. Optional ``status`` - and ``status_handled`` overrides are available so callers can snapshot a + The returned dictionary is a lightweight inspection/tracing view. It captures + the user-visible conversation state and the runtime execution state without + embedding live component objects or the authoritative serialized + conversation blob used for resumability. Optional ``status`` and + ``status_handled`` overrides are available so callers can snapshot a slightly adjusted view of the current runtime state without mutating the conversation itself. 
@@ -335,6 +336,46 @@ def dump_conversation_state( } +@contextmanager +def _use_conversation_runtime_overrides( + conversation: "Conversation", + *, + status: Optional[ExecutionStatus], + status_handled: bool, +) -> Iterator[None]: + previous_status = conversation.status + previous_status_handled = conversation.status_handled + conversation.status = status + conversation.status_handled = status_handled + try: + yield + finally: + conversation.status = previous_status + conversation.status_handled = previous_status_handled + + +def _serialize_conversation_state_with_runtime_overrides( + conversation: "Conversation", + *, + status: Optional[ExecutionStatus], + status_handled: bool, + serialization_context: Optional[SerializationContext] = None, + plugins: Optional[list["WayflowSerializationPlugin"]] = None, +) -> str: + """Serialize a conversation as if its runtime status fields already matched a snapshot.""" + + with _use_conversation_runtime_overrides( + conversation, + status=status, + status_handled=status_handled, + ): + return serialize_conversation_state( + conversation, + serialization_context=serialization_context, + plugins=plugins, + ) + + def serialize_conversation_state( conversation: "Conversation", serialization_context: Optional[SerializationContext] = None, @@ -346,7 +387,9 @@ def serialize_conversation_state( This is the string form meant for storage or transport when the full runtime conversation needs to be preserved for later loading. Unlike ``dump_conversation_state()``, this serializes the actual conversation object - graph using WayFlow serialization. + graph using WayFlow serialization. For WayFlow-emitted state snapshots, this + is the authoritative resumable blob stored under + ``state_snapshot["conversation_state"]``. Parameters ---------- @@ -376,7 +419,8 @@ def deserialize_conversation_state(state: str) -> dict[str, Any]: This is the dictionary-level counterpart of ``serialize_conversation_state()``. 
It is useful when callers need to inspect or adjust the serialized payload before loading it back into a live - ``Conversation`` object. + ``Conversation`` object, including the ``conversation_state`` string emitted + in WayFlow state snapshots. Parameters ---------- @@ -409,7 +453,10 @@ def load_conversation_state( The input dictionary is expected to come from ``deserialize_conversation_state()`` or another equivalent WayFlow - serialization source. + serialization source. When resuming from a WayFlow state snapshot payload, + first parse ``state_snapshot["conversation_state"]`` with + ``deserialize_conversation_state()`` and then pass the resulting dictionary + here. Parameters ---------- @@ -457,6 +504,8 @@ def deserialize_conversation( This is a convenience wrapper around ``deserialize_conversation_state()`` followed by ``load_conversation_state()``. + It is the simplest restore API when you already have the serialized + ``conversation_state`` string from a WayFlow snapshot payload. 
Parameters ---------- diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py index a9698ea09..3ee5e0bbb 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py @@ -26,7 +26,7 @@ from wayflowcore.executors.executionstatus import UserMessageRequestStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval from wayflowcore.models.vllmmodel import VllmModel -from wayflowcore.serialization import dump_conversation_state +from wayflowcore.serialization import deserialize_conversation, dump_conversation_state from ..testhelpers.patching import patch_llm from ..testhelpers.statesnapshots import ( @@ -302,7 +302,12 @@ def build_extra_state(conversation) -> dict[str, Any]: assert len(state_snapshot_events) == 2 final_snapshot_event = state_snapshot_events[-1] - runtime_messages = final_snapshot_event.state_snapshot["conversation"]["messages"] + assert final_snapshot_event.state_snapshot is not None + snapshot_payload = final_snapshot_event.state_snapshot + assert isinstance(snapshot_payload["conversation_state"], str) + restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) + restored_snapshot = dump_conversation_state(restored_conversation) + runtime_messages = snapshot_payload["conversation"]["messages"] expected_agent_state = asdict( _build_retrieval_agent_state( conversation_inputs=_RETRIEVAL_INPUTS, @@ -312,10 +317,37 @@ def build_extra_state(conversation) -> dict[str, Any]: ) assert final_snapshot_event.conversation_id == conversation.conversation_id + assert snapshot_payload["runtime"] == "wayflow" + assert snapshot_payload["schema_version"] == 1 + assert restored_snapshot["conversation"] == snapshot_payload["conversation"] assert ( - final_snapshot_event.state_snapshot["conversation"]["inputs"]["input"] - == 
_RETRIEVAL_INPUTS["input"] + restored_snapshot["execution"]["current_step_name"] + == snapshot_payload["execution"]["current_step_name"] ) + assert restored_snapshot["execution"]["status"] == snapshot_payload["execution"]["status"] + assert restored_snapshot["execution"]["status_handled"] is False + assert restored_snapshot["execution"]["curr_iter"] == snapshot_payload["execution"]["curr_iter"] + assert ( + restored_snapshot["execution"]["has_confirmed_conversation_exit"] + == snapshot_payload["execution"]["has_confirmed_conversation_exit"] + ) + assert ( + restored_snapshot["execution"]["tool_call_queue"] + == snapshot_payload["execution"]["tool_call_queue"] + ) + assert ( + restored_snapshot["execution"]["current_tool_request"] + == snapshot_payload["execution"]["current_tool_request"] + ) + assert ( + restored_snapshot["execution"]["current_flow_conversation"] + == snapshot_payload["execution"]["current_flow_conversation"] + ) + assert ( + restored_snapshot["execution"]["current_sub_component_conversations"] + == snapshot_payload["execution"]["current_sub_component_conversations"] + ) + assert snapshot_payload["conversation"]["inputs"]["input"] == _RETRIEVAL_INPUTS["input"] assert runtime_messages[-1]["content"] == assistant_message assert final_snapshot_event.extra_state == {"agent_state": expected_agent_state} diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py index 78fbd47ee..054b090e9 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py @@ -4,13 +4,23 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+from typing import Any + import pytest from wayflowcore.events.eventlistener import register_event_listeners from wayflowcore.executors._events.event import EventType -from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus +from wayflowcore.executors.executionstatus import ( + FinishedStatus, + ToolRequestStatus, + UserMessageRequestStatus, +) from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.flow import Flow +from wayflowcore.serialization import deserialize_conversation, dump_conversation_state +from wayflowcore.steps import CompleteStep, InputMessageStep, OutputMessageStep, ToolExecutionStep +from wayflowcore.tools import ClientTool, ToolResult from ..test_interrupts import OnEventExecutionInterrupt from ..testhelpers.statesnapshots import ( @@ -27,6 +37,19 @@ ) +def _restore_conversation_from_snapshot_payload(snapshot_payload: dict[str, Any]): + assert snapshot_payload["runtime"] == "wayflow" + assert snapshot_payload["schema_version"] == 1 + assert isinstance(snapshot_payload["conversation_state"], str) + + restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) + assert dump_conversation_state(restored_conversation) == { + "conversation": snapshot_payload["conversation"], + "execution": snapshot_payload["execution"], + } + return restored_conversation + + @pytest.mark.parametrize( ( "conversation_factory", @@ -206,3 +229,77 @@ def test_conversation_turn_policy_reflects_real_interrupt_side_effects_once() -> assert conversation.inputs["preview_count"] == 1 assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] assert state_snapshot_events[-1].state_snapshot["conversation"]["inputs"]["preview_count"] == 1 + + +def test_conversation_turn_snapshot_payload_can_resume_waiting_for_client_tool_result() -> None: + client_tool = ClientTool( + 
name="client_lookup", + description="Look up some data on the client side", + parameters={}, + ) + conversation = Flow.from_steps( + [ + ToolExecutionStep(tool=client_tool), + CompleteStep(name="end"), + ], + name="snapshot_client_tool_resume_flow", + ).start_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, ToolRequestStatus) + assert state_snapshot_events[-1].state_snapshot is not None + restored_conversation = _restore_conversation_from_snapshot_payload( + state_snapshot_events[-1].state_snapshot + ) + assert isinstance(restored_conversation.status, ToolRequestStatus) + + tool_request = restored_conversation.status.tool_requests[0] + restored_conversation.append_tool_result( + ToolResult(tool_request_id=tool_request.tool_request_id, content="client-result") + ) + resumed_status = restored_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + tool_result_messages = [ + message.tool_result + for message in restored_conversation.get_messages() + if message.tool_result + ] + assert len(tool_result_messages) == 1 + assert tool_result_messages[0].tool_request_id == tool_request.tool_request_id + assert tool_result_messages[0].content == "client-result" + + +@pytest.mark.anyio +async def test_conversation_turn_snapshot_payload_can_resume_waiting_for_user_input_async() -> None: + conversation = Flow.from_steps( + [ + InputMessageStep("Please answer"), + OutputMessageStep("done"), + ], + name="snapshot_user_resume_flow", + ).start_conversation() + + status, state_snapshot_events = await execute_with_state_snapshots_async( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, UserMessageRequestStatus) + assert state_snapshot_events[-1].state_snapshot is not None + restored_conversation = 
_restore_conversation_from_snapshot_payload( + state_snapshot_events[-1].state_snapshot + ) + restored_conversation.append_user_message("hello") + resumed_status = await restored_conversation.execute_async() + + assert isinstance(resumed_status, FinishedStatus) + assert [message.content for message in restored_conversation.get_messages()] == [ + "Please answer", + "hello", + "done", + ] diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py index 663154d54..9a2c01a50 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py @@ -30,6 +30,12 @@ ) +class _SerializableDummyModel(DummyModel): + @property + def config(self) -> dict[str, object]: + return {"model_id": self.model_id} + + @pytest.mark.parametrize( ( "conversation_factory", @@ -176,7 +182,7 @@ def test_flow_node_turn_policy_uses_iteration_start_and_end_boundaries() -> None def test_internal_snapshots_do_not_reuse_the_previous_turn_status() -> None: - llm = DummyModel() + llm = _SerializableDummyModel() llm.set_next_output(["Hello from agent", "Hello again"]) conversation = Agent(llm=llm).start_conversation() conversation.append_user_message("Hi") diff --git a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py index 29d645d20..2d3ff2433 100644 --- a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py +++ b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py @@ -9,13 +9,16 @@ import pytest +from wayflowcore.controlconnection import ControlFlowEdge from wayflowcore.conversation import Conversation +from wayflowcore.dataconnection import DataFlowEdge from wayflowcore.executors._flowconversation import FlowConversation from wayflowcore.executors.executionstatus import ( FinishedStatus, 
ToolRequestStatus, UserMessageRequestStatus, ) +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval from wayflowcore.flow import Flow from wayflowcore.property import AnyProperty, StringProperty from wayflowcore.serialization import ( @@ -32,11 +35,14 @@ InputMessageStep, OutputMessageStep, ToolExecutionStep, + VariableReadStep, VariableWriteStep, ) from wayflowcore.tools import ClientTool, ServerTool, ToolResult, register_server_tool from wayflowcore.variable import Variable +from ..testhelpers.statesnapshots import build_policy, execute_with_state_snapshots + class _UnserializableValue: def __str__(self) -> str: @@ -298,3 +304,75 @@ def test_load_conversation_state_uses_the_given_deserialization_context() -> Non assert isinstance(conversation, FlowConversation) assert isinstance(conversation.execute(), FinishedStatus) + + +def test_emitted_snapshot_conversation_state_restores_variable_dependent_continuation() -> None: + customer_name = Variable( + name="customer_name", + type=StringProperty(), + description="Customer name persisted across resumable snapshots", + ) + capture_name = VariableWriteStep( + variable=customer_name, + input_mapping={VariableWriteStep.VALUE: customer_name.name}, + name="capture_name", + ) + ask_follow_up = InputMessageStep( + message_template="How can I help {{customer_name}}?", + name="ask_follow_up", + ) + read_name = VariableReadStep(variable=customer_name, name="read_name") + final_message = OutputMessageStep( + message_template="Stored {{stored_name}}. 
Reply: {{reply}}", + name="final_message", + ) + flow = Flow( + begin_step=capture_name, + steps={ + "capture_name": capture_name, + "ask_follow_up": ask_follow_up, + "read_name": read_name, + "final_message": final_message, + }, + control_flow_edges=[ + ControlFlowEdge(capture_name, ask_follow_up), + ControlFlowEdge(ask_follow_up, read_name), + ControlFlowEdge(read_name, final_message), + ControlFlowEdge(final_message, None), + ], + data_flow_edges=[ + DataFlowEdge( + ask_follow_up, + InputMessageStep.USER_PROVIDED_INPUT, + final_message, + "reply", + ), + DataFlowEdge(read_name, VariableReadStep.VALUE, final_message, "stored_name"), + ], + variables=[customer_name], + name="snapshot_variable_resume_flow", + ) + conversation = flow.start_conversation(inputs={customer_name.name: "Alice"}) + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + ) + + assert isinstance(status, UserMessageRequestStatus) + assert state_snapshot_events[-1].state_snapshot is not None + snapshot_payload = state_snapshot_events[-1].state_snapshot + assert isinstance(snapshot_payload["conversation_state"], str) + restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) + + assert dump_variable_state(restored_conversation) == {"customer_name": "Alice"} + + restored_conversation.append_user_message("Need pricing") + resumed_status = restored_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + assert [message.content for message in restored_conversation.get_messages()] == [ + "How can I help Alice?", + "Need pricing", + "Stored Alice. 
Reply: Need pricing", + ] diff --git a/wayflowcore/tests/testhelpers/statesnapshots.py b/wayflowcore/tests/testhelpers/statesnapshots.py index 9f14bb82a..4b7290ebc 100644 --- a/wayflowcore/tests/testhelpers/statesnapshots.py +++ b/wayflowcore/tests/testhelpers/statesnapshots.py @@ -45,6 +45,12 @@ def __call__(self, event: Event) -> None: self.state_snapshot_events.append(event) +class _SnapshotSerializableDummyModel(DummyModel): + @property + def config(self) -> dict[str, Any]: + return {"model_id": self.model_id} + + def snapshot_status_types(snapshot_events: Sequence[Any]) -> list[str | None]: return [ status["type"] if (status := snapshot_event.state_snapshot["execution"]["status"]) else None @@ -161,7 +167,7 @@ def create_output_flow_conversation(message: str = "Hello") -> Conversation: def create_agent_conversation(message: str = "Hello from agent") -> Conversation: - llm = DummyModel() + llm = _SnapshotSerializableDummyModel() llm.set_next_output(message) conversation = Agent(llm=llm).start_conversation() conversation.append_user_message("Hi") @@ -174,7 +180,7 @@ def do_nothing_tool() -> str: """Do nothing tool.""" return "Tool called successfully" - llm = DummyModel() + llm = _SnapshotSerializableDummyModel() llm.set_next_output( { "Please use the do_nothing_tool": Message( @@ -190,7 +196,7 @@ def do_nothing_tool() -> str: def create_nested_agent_step_flow_conversation() -> Conversation: - llm = DummyModel() + llm = _SnapshotSerializableDummyModel() llm.set_next_output("agent answer") child_agent = Agent(llm=llm) conversation = Flow.from_steps( From 678871ec3dc8069d57a0e5b4d942ed93649a825c Mon Sep 17 00:00:00 2001 From: Son Le Date: Thu, 19 Mar 2026 12:03:50 +0100 Subject: [PATCH 10/13] Finalize resumable state snapshot tracing semantics --- .../source/core/howtoguides/howto_tracing.rst | 38 +-- .../src/wayflowcore/agentspec/tracing.py | 57 ++++- wayflowcore/src/wayflowcore/events/event.py | 14 +- .../executors/_statesnapshot_eventlistener.py | 147 
++++++++--- .../executors/statesnapshotpolicy.py | 21 +- .../wayflowcore/steps/flowexecutionstep.py | 4 +- .../steps/parallelflowexecutionstep.py | 12 +- .../test_state_snapshot_tracing_agent.py | 96 ++++++- .../test_state_snapshot_tracing_flow.py | 105 +++++++- .../test_state_snapshot_tracing_nested.py | 152 ++++++++++- ...ate_snapshot_runtime_conversation_turns.py | 124 ++++++++- ...t_state_snapshot_runtime_internal_turns.py | 121 ++++++--- .../test_state_snapshot_runtime_nested.py | 238 +++++++++++++++--- .../test_state_snapshot_runtime_resilience.py | 41 ++- .../steps/test_flow_execution_step.py | 53 ++++ .../test_parallel_flow_execution_step.py | 50 +++- .../test_conversation_state_snapshot.py | 8 +- .../tests/testhelpers/statesnapshots.py | 9 +- 18 files changed, 1093 insertions(+), 197 deletions(-) diff --git a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst index 8adff5617..426fa9222 100644 --- a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst +++ b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst @@ -159,10 +159,11 @@ WayFlow can also emit ``StateSnapshotEvent`` payloads at conversation, step, and boundaries by passing a ``StateSnapshotPolicy`` to ``conversation.execute()`` or ``conversation.execute_async()``. When ``AgentSpecEventListener`` is registered and the installed ``pyagentspec`` version exposes ``StateSnapshotEmitted``, these runtime -snapshots are bridged into Agent Spec ``StateSnapshotEmitted`` events on the active -span. The WayFlow-only ``variable_state`` payload is not bridged, because Agent Spec -does not define WayFlow variable semantics. When ``include_variable_state=True``, -variable values must already be JSON-serializable. +snapshots are bridged into Agent Spec ``StateSnapshotEmitted`` events on the +owning conversation/component span. 
The WayFlow-only ``variable_state`` payload +is not bridged, because Agent Spec does not define WayFlow variable semantics. +When ``include_variable_state=True``, variable values must already be +JSON-serializable. ``StateSnapshotEvent.conversation_id`` is the logical/public conversation id, while ``state_snapshot["conversation"]["id"]`` identifies the concrete runtime conversation instance that emitted the snapshot. The lightweight @@ -173,17 +174,26 @@ WayFlow conversation blob used for resumability. To restore from that blob, use ``wayflowcore.serialization.deserialize_conversation(...)`` or ``deserialize_conversation_state(...)`` together with ``load_conversation_state(...)``. -Each policy emits snapshots only for its own boundaries. ``CONVERSATION_TURNS`` -emits opening and closing turn snapshots. Internal policies emit only step, -iteration, and/or tool snapshots. Snapshots are emitted only when the -corresponding boundary event occurs. If a turn -raises or is interrupted before its matching closing event, WayFlow does not -synthesize an extra unwind snapshot. For step and tool intervals, the latest -already-emitted start snapshot is the recovery point. +Snapshot intervals are cumulative. ``TOOL_TURNS`` includes the +``CONVERSATION_TURNS`` checkpoints, ``NODE_TURNS`` includes the +``CONVERSATION_TURNS`` checkpoints, and ``ALL_INTERNAL_TURNS`` includes all +conversation, tool, and node boundaries. Only the conversation passed directly +to ``execute()`` / ``execute_async()`` emits the authoritative turn-level +resumability checkpoints for that run. Nested child conversations may still +emit internal tracing snapshots, but those child-runtime snapshots are tracing +checkpoints unless and until a stronger contract is introduced. +``CONVERSATION_TURNS`` emits an opening snapshot at execution start and a +closing turn snapshot at the end of the turn. 
That closing payload is emitted +before the live conversation object commits the new status, but its serialized +payload is synthesized so that after ``execute()`` returns it matches the +committed state seen by the caller. Snapshots are emitted only when the +corresponding boundary event occurs. If a turn is interrupted mid-turn, WayFlow +does not synthesize a turn-end snapshot; the latest already-emitted opening or +internal snapshot is the recovery point. For flows, ``NODE_TURNS`` uses flow-iteration start/end events, which align with -per-step execution while keeping the end snapshot on committed flow state. For -agents, the same policy emits snapshots around each decision-loop iteration. -Tool start/end snapshots are emitted only for ``TOOL_TURNS`` and ``ALL_INTERNAL_TURNS``. +per-step execution. For agents, the same policy emits snapshots around each +decision-loop iteration. Tool start/end snapshots are emitted only for +``TOOL_TURNS`` and ``ALL_INTERNAL_TURNS``. .. code-block:: python diff --git a/wayflowcore/src/wayflowcore/agentspec/tracing.py b/wayflowcore/src/wayflowcore/agentspec/tracing.py index 973e0d565..bfbfeb9df 100644 --- a/wayflowcore/src/wayflowcore/agentspec/tracing.py +++ b/wayflowcore/src/wayflowcore/agentspec/tracing.py @@ -138,6 +138,14 @@ def _get_snapshot_owner_span( # without changing the StateSnapshotEvent handling below. 
owner = self._conversation_span_owners.get(event.conversation_id) if owner is not None: + snapshot_conversation = ( + event.state_snapshot.get("conversation") if event.state_snapshot else {} + ) + snapshot_runtime_conversation_id = ( + snapshot_conversation.get("id") if isinstance(snapshot_conversation, dict) else None + ) + if snapshot_runtime_conversation_id != owner.runtime_conversation_id: + return None return owner.span return current_agentspec_span @@ -148,10 +156,10 @@ def __call__(self, event: Event) -> None: # - we map the wayflow event to the corresponding agent spec one, and we emit that # - if it corresponds to a span end event, we retrieve the corresponding agent spec span, and we close it current_span = get_current_span() - if not current_span: + if current_span is None: return - current_agentspec_span = self.agentspec_spans_registry.get(current_span.span_id, None) current_span_name = current_span.name or "" + current_agentspec_span = self.agentspec_spans_registry.get(current_span.span_id, None) event_name = event.name or "" match event: case LlmGenerationRequestEvent(): @@ -383,6 +391,11 @@ def __call__(self, event: Event) -> None: extra_state=event.extra_state, ) owner_span.add_event(snapshot_event) + if owner_span.end_time is None and any( + isinstance(span_event, (AgentSpecFlowExecutionEnd, AgentSpecAgentExecutionEnd)) + for span_event in owner_span.events + ): + owner_span.end() case FlowExecutionStartedEvent(): # Flow execution starts. Create the new agent spec span, start it, add the event agentspec_flow = cast( @@ -405,7 +418,9 @@ def __call__(self, event: Event) -> None: ) self._register_current_conversation_span(current_agentspec_span) case FlowExecutionFinishedEvent(): - # Flow execution ends. Add the event to the agent spec span and close the span + # Flow execution ends. Add the event to the agent spec span. 
If this span owns + # the logical conversation checkpoint stream, delay closing until the final + # StateSnapshotEvent is bridged so span processors still see the final snapshot. if not current_agentspec_span: return agentspec_flow = cast( @@ -426,7 +441,21 @@ def __call__(self, event: Event) -> None: branch_selected=branch_selected, ) ) - current_agentspec_span.end() + active_conversation = self._get_active_wayflow_conversation() + owner = ( + self._conversation_span_owners.get(active_conversation.conversation_id) + if active_conversation is not None + else None + ) + if ( + owner is None + or owner.span is not current_agentspec_span + or not any( + isinstance(span_event, AgentSpecStateSnapshotEmitted) + for span_event in current_agentspec_span.events + ) + ): + current_agentspec_span.end() case AgentExecutionStartedEvent(): # Agent execution starts. Create the new agent spec span, start it, add the event agentspec_agent = cast( @@ -449,7 +478,9 @@ def __call__(self, event: Event) -> None: ) self._register_current_conversation_span(current_agentspec_span) case AgentExecutionFinishedEvent(): - # Agent execution ends. Add the event to the agent spec span and close the span + # Agent execution ends. Add the event to the agent spec span. If this span owns + # the logical conversation checkpoint stream, delay closing until the final + # StateSnapshotEvent is bridged so span processors still see the final snapshot. 
if not current_agentspec_span: return agentspec_agent = cast( @@ -468,7 +499,21 @@ def __call__(self, event: Event) -> None: outputs=outputs, ) ) - current_agentspec_span.end() + active_conversation = self._get_active_wayflow_conversation() + owner = ( + self._conversation_span_owners.get(active_conversation.conversation_id) + if active_conversation is not None + else None + ) + if ( + owner is None + or owner.span is not current_agentspec_span + or not any( + isinstance(span_event, AgentSpecStateSnapshotEmitted) + for span_event in current_agentspec_span.events + ) + ): + current_agentspec_span.end() case ExceptionRaisedEvent(): if not current_agentspec_span: return diff --git a/wayflowcore/src/wayflowcore/events/event.py b/wayflowcore/src/wayflowcore/events/event.py index c912aaaa7..7d7db8348 100644 --- a/wayflowcore/src/wayflowcore/events/event.py +++ b/wayflowcore/src/wayflowcore/events/event.py @@ -799,12 +799,16 @@ class StateSnapshotEvent(Event): """Event emitted by WayFlow when a conversation state snapshot is recorded. ``conversation_id`` is the logical/public conversation id. When a snapshot is - present, the emitting runtime conversation instance is identified by - ``state_snapshot["conversation"]["id"]``. WayFlow-emitted payloads also place - the authoritative resumable serialized state in + present, ``state_snapshot["conversation"]["id"]`` identifies the runtime + conversation instance described by the payload. WayFlow-emitted payloads + also place the authoritative serialized state in ``state_snapshot["conversation_state"]`` while keeping - ``state_snapshot["conversation"]`` and ``state_snapshot["execution"]`` as the - lightweight inspection view. + ``state_snapshot["conversation"]`` and ``state_snapshot["execution"]`` as + the lightweight inspection view. 
Snapshots emitted for the conversation that + began the current ``execute()`` / ``execute_async()`` run are the + resumable checkpoints for that run; nested child snapshots are primarily + tracing checkpoints and may be filtered by downstream bridges that need a + single checkpoint owner per logical conversation. """ conversation_id: str = field(default_factory=_required_attribute("conversation_id", str)) diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py index cd9ff8eb4..6fc38aebb 100644 --- a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -38,9 +38,11 @@ from wayflowcore.executors.executionstatus import ExecutionStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.serialization.conversation import ( + _UNSET, _serialize_conversation_state_with_runtime_overrides, dump_conversation_state, dump_variable_state, + serialize_conversation_state, ) from wayflowcore.tracing.span import AgentExecutionSpan, FlowExecutionSpan, get_current_span @@ -56,6 +58,12 @@ ) """Execution-local mapping of active conversations to their effective snapshot policy.""" +_STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID: ContextVar[Optional[str]] = ContextVar( + "_STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID", + default=None, +) +"""Runtime conversation id for the root conversation of the current snapshot-enabled run.""" + def _get_state_snapshot_policies( return_copy: bool = True, @@ -73,6 +81,21 @@ def _get_parent_state_snapshot_policy( return _get_state_snapshot_policy(active_conversations[-1]) +def get_effective_state_snapshot_policy_for_conversation( + conversation: Conversation, + state_snapshot_policy: Optional[StateSnapshotPolicy], +) -> Optional[StateSnapshotPolicy]: + return ( + state_snapshot_policy + if 
state_snapshot_policy is not None + else _get_parent_state_snapshot_policy(conversation) + ) + + +def _is_state_snapshot_execution_root(conversation: Conversation) -> bool: + return _STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID.get() == conversation.id + + def _get_state_snapshot_policy( conversation: Conversation, ) -> Optional[StateSnapshotPolicy]: @@ -99,6 +122,22 @@ def _use_state_snapshot_policy( _STATE_SNAPSHOT_POLICIES.reset(token) +@contextmanager +def _use_state_snapshot_execution_root( + conversation: Conversation, +) -> Iterator[None]: + current_root_conversation_id = _STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID.get() + if current_root_conversation_id is not None: + yield + return + + token = _STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID.set(conversation.id) + try: + yield + finally: + _STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID.reset(token) + + def _build_extra_state( conversation: Conversation, state_snapshot_policy: StateSnapshotPolicy, @@ -148,15 +187,26 @@ def _get_snapshot_policy_for_interval( if snapshot_interval == StateSnapshotInterval.OFF: return None - if snapshot_interval == required_snapshot_interval: - return state_snapshot_policy - - if required_snapshot_interval == StateSnapshotInterval.CONVERSATION_TURNS: - return None - - if snapshot_interval == StateSnapshotInterval.ALL_INTERNAL_TURNS: + included_intervals = { + StateSnapshotInterval.CONVERSATION_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + }, + StateSnapshotInterval.TOOL_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + StateSnapshotInterval.TOOL_TURNS, + }, + StateSnapshotInterval.NODE_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + StateSnapshotInterval.NODE_TURNS, + }, + StateSnapshotInterval.ALL_INTERNAL_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + StateSnapshotInterval.TOOL_TURNS, + StateSnapshotInterval.NODE_TURNS, + }, + } + if required_snapshot_interval in included_intervals[snapshot_interval]: return state_snapshot_policy - return None 
@@ -181,26 +231,37 @@ def _build_variable_state( def _build_state_snapshot_payload( conversation: Conversation, *, - execution_status: ExecutionStatus | None, + status: object = _UNSET, + status_handled: object = _UNSET, ) -> dict[str, Any]: - # Snapshots should expose the canonical pre-consumption view of a turn, not - # transient runtime bookkeeping. The serialized conversation_state blob must - # match the same logical boundary as the lightweight conversation/execution - # sections so it can be restored directly. - snapshot_status_handled = False + # The snapshot payload should match the intended runtime view at the + # boundary where it is emitted. Opening/tool/node snapshots mask any + # previous turn status, and turn-end snapshots can override the runtime + # fields before the live conversation object commits that new status. dumped_state = dump_conversation_state( conversation, - status=execution_status, - status_handled=snapshot_status_handled, + status=status, + status_handled=status_handled, + ) + conversation_state = ( + serialize_conversation_state(conversation) + if status is _UNSET and status_handled is _UNSET + else _serialize_conversation_state_with_runtime_overrides( + conversation, + status=cast( + Optional[ExecutionStatus], + conversation.status if status is _UNSET else status, + ), + status_handled=cast( + bool, + conversation.status_handled if status_handled is _UNSET else status_handled, + ), + ) ) return { "runtime": _STATE_SNAPSHOT_RUNTIME, "schema_version": _STATE_SNAPSHOT_SCHEMA_VERSION, - "conversation_state": _serialize_conversation_state_with_runtime_overrides( - conversation, - status=execution_status, - status_handled=snapshot_status_handled, - ), + "conversation_state": conversation_state, "conversation": dumped_state["conversation"], "execution": dumped_state["execution"], } @@ -210,7 +271,8 @@ def _record_state_snapshot( conversation: Conversation, required_snapshot_interval: StateSnapshotInterval, *, - execution_status: 
ExecutionStatus | None, + status: object = _UNSET, + status_handled: object = _UNSET, ) -> None: state_snapshot_policy = _get_snapshot_policy_for_interval( conversation, required_snapshot_interval @@ -218,13 +280,24 @@ def _record_state_snapshot( if state_snapshot_policy is None: return + if ( + required_snapshot_interval == StateSnapshotInterval.CONVERSATION_TURNS + and not _is_state_snapshot_execution_root(conversation) + ): + # The conversation passed to `execute()` / `execute_async()` owns the + # resumable turn-level checkpoint stream. Nested child conversations may + # still emit internal tracing snapshots, but they do not emit competing + # conversation-turn checkpoints for the same run. + return + try: record_event( StateSnapshotEvent( conversation_id=conversation.conversation_id, state_snapshot=_build_state_snapshot_payload( conversation, - execution_status=execution_status, + status=status, + status_handled=status_handled, ), extra_state=_build_extra_state(conversation, state_snapshot_policy), variable_state=_build_variable_state(conversation, state_snapshot_policy), @@ -252,12 +325,15 @@ def __init__( def _record_snapshot( self, required_snapshot_interval: StateSnapshotInterval, - execution_status: ExecutionStatus | None = None, + *, + status: object = None, + status_handled: object = False, ) -> None: _record_state_snapshot( self.conversation, required_snapshot_interval, - execution_status=execution_status, + status=status, + status_handled=status_handled, ) def _handle_pre_interrupt_event( @@ -303,13 +379,15 @@ def _handle_post_interrupt_event(self, event: Event) -> None: ) | AgentExecutionFinishedEvent(execution_status=execution_status): self._record_snapshot( StateSnapshotInterval.CONVERSATION_TURNS, - execution_status, + status=execution_status, + status_handled=False, ) case ExceptionRaisedEvent(exception=ExecutionInterruptedException() as exception): if self._should_record_interrupted_turn_end_snapshot(): self._record_snapshot( 
StateSnapshotInterval.CONVERSATION_TURNS, - exception.execution_status, + status=exception.execution_status, + status_handled=False, ) def __call__(self, event: Event) -> None: @@ -364,16 +442,18 @@ def get_state_snapshot_execution_context_for_conversation( Activate the effective snapshot policy for one `conversation.execute(...)` turn. Child conversations inherit the currently active parent policy unless they - explicitly override it. When snapshots are enabled, listener registration - happens here in the runtime order the execution model depends on: + explicitly override it. Only the conversation that started the current + snapshot-enabled execution emits conversation-turn checkpoints; nested + children may still emit internal snapshots that describe their own runtime + state. When snapshots are enabled, listener registration happens here in the + runtime order the execution model depends on: 1. pre-interrupt snapshot listener 2. interrupts listener 3. post-interrupt snapshot listener """ - active_state_snapshot_policy = ( - state_snapshot_policy - if state_snapshot_policy is not None - else _get_parent_state_snapshot_policy(conversation) + active_state_snapshot_policy = get_effective_state_snapshot_policy_for_conversation( + conversation, + state_snapshot_policy, ) with _use_state_snapshot_policy(conversation, active_state_snapshot_policy): @@ -386,6 +466,7 @@ def get_state_snapshot_execution_context_for_conversation( ) with ( + _use_state_snapshot_execution_root(conversation), get_state_snapshot_event_listener_context_for_conversation( conversation, post_interrupts=False, diff --git a/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py b/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py index a047700ae..baa7cd6b3 100644 --- a/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py +++ b/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py @@ -20,21 +20,24 @@ class StateSnapshotInterval(str, Enum): `CONVERSATION_TURNS` Emit an 
opening turn snapshot before execution starts and a closing - turn snapshot when the turn finishes or is interrupted at execution - end. This is the default policy because it gives a stable turn-level - checkpoint without emitting snapshots for every internal step. + turn snapshot at the end of the turn. The closing payload is emitted + before the live conversation commits the new status, but the payload is + synthesized so it matches the post-return conversation state. This is + the default policy because it gives a stable turn-level checkpoint + without emitting snapshots for every internal step. `TOOL_TURNS` - Emit snapshots around each tool invocation (`TOOL_START` and - `TOOL_END`) only. + Emit the `CONVERSATION_TURNS` snapshots plus snapshots around each tool + invocation (`TOOL_START` and `TOOL_END`). `NODE_TURNS` - Emit snapshots around each internal node boundary only. For flows this - means per-step snapshots; for agents it maps to decision-loop - iteration boundaries. + Emit the `CONVERSATION_TURNS` snapshots plus snapshots around each + internal node boundary. For flows this means per-step snapshots; for + agents it maps to decision-loop iteration boundaries. `ALL_INTERNAL_TURNS` - Emit all tool and node snapshots, without the broader turn snapshots. + Emit the `CONVERSATION_TURNS`, `TOOL_TURNS`, and `NODE_TURNS` + snapshots. `OFF` Disable state snapshot emission entirely. 
diff --git a/wayflowcore/src/wayflowcore/steps/flowexecutionstep.py b/wayflowcore/src/wayflowcore/steps/flowexecutionstep.py index 1e5b5c0cb..8a0b66a18 100644 --- a/wayflowcore/src/wayflowcore/steps/flowexecutionstep.py +++ b/wayflowcore/src/wayflowcore/steps/flowexecutionstep.py @@ -250,7 +250,9 @@ async def _invoke_step_async( flow=self.flow, inputs=inputs, ) - status = await sub_conversation.execute_async() + status = sub_conversation.status + if not isinstance(status, FinishedStatus): + status = await sub_conversation.execute_async() if isinstance(status, InterruptedExecutionStatus): return StepResult( diff --git a/wayflowcore/src/wayflowcore/steps/parallelflowexecutionstep.py b/wayflowcore/src/wayflowcore/steps/parallelflowexecutionstep.py index 16bfc8fe1..75878abc0 100644 --- a/wayflowcore/src/wayflowcore/steps/parallelflowexecutionstep.py +++ b/wayflowcore/src/wayflowcore/steps/parallelflowexecutionstep.py @@ -12,7 +12,6 @@ from wayflowcore._metadata import MetadataType from wayflowcore._utils.async_helpers import run_async_function_in_parallel -from wayflowcore.executors._flowexecutor import FlowConversationExecutor from wayflowcore.executors.executionstatus import ExecutionStatus, FinishedStatus from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus from wayflowcore.property import Property @@ -264,13 +263,13 @@ async def _invoke_step_async( sub_conversations_to_run = [ sub_conversation for sub_conversation in sub_conversations - if sub_conversation.status is None or not isinstance(sub_conversation, FinishedStatus) + if not isinstance(sub_conversation.status, FinishedStatus) ] # We collect the statuses of the conversations that did already finish as we need them for the cleanup finished_statuses = [ sub_conversation.status for sub_conversation in sub_conversations - if sub_conversation.status is not None and isinstance(sub_conversation, FinishedStatus) + if isinstance(sub_conversation.status, FinishedStatus) ] async def 
_run_single_flow_target(sub_conv: "FlowConversation") -> ExecutionStatus: @@ -311,11 +310,8 @@ async def _run_single_flow_target(sub_conv: "FlowConversation") -> ExecutionStat f"Illegal response from a subflow: some subflow returned a non-finished status: {non_finished_status}" ) - for sub_conv in sub_conversations: - FlowConversationExecutor().cleanup_sub_conversation( - sub_conv.state, - self, - ) + for flow in self.flows: + conversation._cleanup_sub_conversation(self, flow.id) if not all(isinstance(status, FinishedStatus) for status in statuses): raise RuntimeError( diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py index 3ee5e0bbb..fe3350996 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py @@ -11,6 +11,7 @@ from pyagentspec.adapters.wayflow import AgentSpecLoader from pyagentspec.agent import Agent as AgentSpecAgent from pyagentspec.llms import VllmConfig +from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart from pyagentspec.tracing.events import Event as AgentSpecEvent from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted @@ -24,13 +25,12 @@ from wayflowcore.agentspec.tracing import AgentSpecEventListener from wayflowcore.events.eventlistener import register_event_listeners from wayflowcore.executors.executionstatus import UserMessageRequestStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.models.vllmmodel import VllmModel from wayflowcore.serialization import deserialize_conversation, dump_conversation_state from ..testhelpers.patching import patch_llm from 
..testhelpers.statesnapshots import ( - build_policy, snapshot_message, snapshot_status_types, ) @@ -159,6 +159,42 @@ async def shutdown_async(self) -> None: return None +class SnapshotEventsSeenAtSpanEndRecorder(AgentSpecSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.events_by_span_id: dict[str, list[AgentSpecEvent]] = {} + + def on_start(self, span: AgentSpecSpan) -> None: + return None + + async def on_start_async(self, span: AgentSpecSpan) -> None: + return None + + def on_end(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + _RETRIEVAL_INPUTS = { "input": "How many orders last week?", "thread_id": "thread-123", @@ -286,8 +322,8 @@ def build_extra_state(conversation) -> dict[str, Any]: status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=build_policy( - StateSnapshotInterval.CONVERSATION_TURNS, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, extra_state_builder=build_extra_state, ), span_processors=[agui_exporter], @@ -370,7 +406,9 @@ def test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_ conversation.append_user_message("Hi") status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), 
contexts=[patch_llm(llm, [assistant_message], patch_internal=True)], ) @@ -379,12 +417,19 @@ def test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_ agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) state_snapshot_events = _events(agent_span, AgentSpecStateSnapshotEmitted) - assert len(state_snapshot_events) == 2 + assert len(state_snapshot_events) == 4 assert [event.state_snapshot["execution"]["curr_iter"] for event in state_snapshot_events] == [ + 0, 0, 1, + 1, + ] + assert snapshot_status_types(state_snapshot_events) == [ + None, + None, + None, + "UserMessageRequestStatus", ] - assert snapshot_status_types(state_snapshot_events) == [None, None] assert snapshot_message(state_snapshot_events[-1]) == assistant_message llm_spans = _spans(span_recorder, AgentSpecLlmGenerationSpan) @@ -394,3 +439,40 @@ def test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_ for span in llm_spans for event in span.events ) + + +def test_agent_final_state_snapshot_is_visible_to_span_processors_inside_on_end() -> None: + assistant_message = "Hello from agent" + llm = VllmModel(model_id="mock.model", host_port="http://mock.url", name="agent") + agent = WayflowAgent(llm=llm) + conversation = agent.start_conversation() + conversation.append_user_message("Hi") + on_end_recorder = SnapshotEventsSeenAtSpanEndRecorder() + + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + span_processors=[on_end_recorder], + contexts=[patch_llm(llm, [assistant_message], patch_internal=True)], + ) + + assert isinstance(status, UserMessageRequestStatus) + + agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + events_seen_at_end = on_end_recorder.events_by_span_id[agent_span.id] + + assert any(isinstance(event, AgentSpecAgentExecutionEnd) for event in events_seen_at_end) + assert ( + len( + [ + 
event + for event in events_seen_at_end + if isinstance(event, AgentSpecStateSnapshotEmitted) + ] + ) + == 2 + ) + assert isinstance(events_seen_at_end[-1], AgentSpecStateSnapshotEmitted) + assert snapshot_message(events_seen_at_end[-1]) == assistant_message diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py index 5b196dcfd..075390f7e 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py @@ -22,13 +22,12 @@ from wayflowcore.agentspec.tracing import AgentSpecEventListener from wayflowcore.events.eventlistener import register_event_listeners from wayflowcore.executors.executionstatus import FinishedStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow from wayflowcore.steps import CompleteStep, OutputMessageStep, ToolExecutionStep from wayflowcore.tools import ServerTool from ..testhelpers.statesnapshots import ( - build_policy, snapshot_message, snapshot_step_histories, ) @@ -71,6 +70,42 @@ async def shutdown_async(self) -> None: return None +class SnapshotEventsSeenAtSpanEndRecorder(AgentSpecSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.events_by_span_id: dict[str, list[AgentSpecEvent]] = {} + + def on_start(self, span: AgentSpecSpan) -> None: + return None + + async def on_start_async(self, span: AgentSpecSpan) -> None: + return None + + def on_end(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + async def on_event_async(self, event: AgentSpecEvent, 
span: AgentSpecSpan) -> None: + return None + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + def _execute_with_trace( conversation, *, @@ -122,8 +157,8 @@ def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> conversation = flow.start_conversation() status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=build_policy( - StateSnapshotInterval.CONVERSATION_TURNS, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, ), ) @@ -144,6 +179,42 @@ def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> assert flow_span in span_recorder.ended_spans +def test_flow_final_state_snapshot_is_visible_to_span_processors_inside_on_end() -> None: + flow = Flow.from_steps( + [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], + step_names=["single_step", "end"], + ) + conversation = flow.start_conversation() + on_end_recorder = SnapshotEventsSeenAtSpanEndRecorder() + + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + span_processors=[on_end_recorder], + ) + + assert isinstance(status, FinishedStatus) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + events_seen_at_end = on_end_recorder.events_by_span_id[flow_span.id] + + assert any(isinstance(event, AgentSpecFlowExecutionEnd) for event in events_seen_at_end) + assert ( + len( + [ + event + for event in events_seen_at_end + if isinstance(event, AgentSpecStateSnapshotEmitted) + ] + ) + == 2 + ) + assert isinstance(events_seen_at_end[-1], AgentSpecStateSnapshotEmitted) + assert 
snapshot_message(events_seen_at_end[-1]) == "Hello" + + @pytest.mark.anyio async def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end_async() -> None: flow = Flow.from_steps( @@ -156,8 +227,8 @@ async def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_en async with AgentSpecTrace(span_processors=[span_recorder]): with register_event_listeners([AgentSpecEventListener()]): status = await conversation.execute_async( - state_snapshot_policy=build_policy( - StateSnapshotInterval.CONVERSATION_TURNS, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, ) ) @@ -185,7 +256,9 @@ def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans( conversation = flow.start_conversation() status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), ) assert isinstance(status, FinishedStatus) @@ -193,14 +266,16 @@ def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans( flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) flow_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) - assert len(flow_snapshot_events) == 6 + assert len(flow_snapshot_events) == 8 assert snapshot_step_histories(flow_snapshot_events) == [ + [], [], ["__StartStep__"], ["__StartStep__"], ["__StartStep__", "single_step"], ["__StartStep__", "single_step"], ["__StartStep__", "single_step", "end"], + ["__StartStep__", "single_step", "end"], ] node_spans = _spans(span_recorder, AgentSpecNodeExecutionSpan) assert node_spans @@ -217,14 +292,17 @@ def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans( pytest.param( StateSnapshotInterval.TOOL_TURNS, [ + [], 
["__StartStep__", "step_0"], ["__StartStep__", "step_0"], + ["__StartStep__", "step_0", "end"], ], id="tool_turns", ), pytest.param( StateSnapshotInterval.ALL_INTERNAL_TURNS, [ + [], [], ["__StartStep__"], ["__StartStep__"], @@ -233,6 +311,7 @@ def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans( ["__StartStep__", "step_0"], ["__StartStep__", "step_0"], ["__StartStep__", "step_0", "end"], + ["__StartStep__", "step_0", "end"], ], id="all_internal_turns", ), @@ -258,7 +337,7 @@ def test_internal_flow_state_snapshots_follow_conversation_ownership_for_agent_s conversation = flow.start_conversation() status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=build_policy(interval), + state_snapshot_policy=StateSnapshotPolicy(state_snapshot_interval=interval), ) assert isinstance(status, FinishedStatus) @@ -286,7 +365,9 @@ def test_off_policy_does_not_bridge_state_snapshots_into_agent_spec_spans() -> N conversation = flow.start_conversation() status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.OFF), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.OFF + ), ) assert isinstance(status, FinishedStatus) @@ -318,7 +399,9 @@ def test_only_the_opening_state_snapshot_is_exported_when_a_turn_raises() -> Non with register_event_listeners([AgentSpecEventListener()]): with pytest.raises(RuntimeError, match="boom"): conversation.execute( - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS) + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) ) flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py index 935ff6a37..067602000 100644 --- 
a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py @@ -15,12 +15,11 @@ from wayflowcore.agentspec.tracing import AgentSpecEventListener from wayflowcore.events.eventlistener import register_event_listeners from wayflowcore.executors.executionstatus import FinishedStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow -from wayflowcore.steps import CompleteStep, FlowExecutionStep, OutputMessageStep +from wayflowcore.steps import CompleteStep, FlowExecutionStep, OutputMessageStep, ParallelMapStep from ..testhelpers.statesnapshots import ( - build_policy, snapshot_message, snapshot_runtime_conversation_ids, ) @@ -63,6 +62,45 @@ async def shutdown_async(self) -> None: return None +class SnapshotRuntimeIdsByConversationExporter(AgentSpecSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.runtime_ids_by_conversation_id: dict[str, list[str]] = {} + + def on_start(self, span: AgentSpecSpan) -> None: + return None + + async def on_start_async(self, span: AgentSpecSpan) -> None: + return None + + def on_end(self, span: AgentSpecSpan) -> None: + return None + + async def on_end_async(self, span: AgentSpecSpan) -> None: + return None + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + if isinstance(event, AgentSpecStateSnapshotEmitted) and event.state_snapshot is not None: + self.runtime_ids_by_conversation_id.setdefault(event.conversation_id, []).append( + event.state_snapshot["conversation"]["id"] + ) + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + self.on_event(event, span) + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def 
shutdown_async(self) -> None: + return None + + def test_nested_flow_state_snapshots_stay_on_the_root_flow_span_for_shared_conversations() -> None: child_flow = Flow.from_steps( [OutputMessageStep(message_template="child"), CompleteStep(name="end")], @@ -84,7 +122,9 @@ def test_nested_flow_state_snapshots_stay_on_the_root_flow_span_for_shared_conve with AgentSpecTrace(span_processors=[span_recorder]): with register_event_listeners([AgentSpecEventListener()]): status = conversation.execute( - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS) + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) ) assert isinstance(status, FinishedStatus) @@ -109,20 +149,110 @@ def test_nested_flow_state_snapshots_stay_on_the_root_flow_span_for_shared_conve child_snapshot_events = [ event for event in child_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) ] + assert len(parent_snapshot_events) == 2 assert [event.conversation_id for event in parent_snapshot_events] == [ conversation.conversation_id, conversation.conversation_id, - conversation.conversation_id, - conversation.conversation_id, ] - child_runtime_conversation_id = parent_snapshot_events[1].state_snapshot["conversation"]["id"] assert snapshot_runtime_conversation_ids(parent_snapshot_events) == [ conversation.id, - child_runtime_conversation_id, - child_runtime_conversation_id, conversation.id, ] - assert child_runtime_conversation_id != conversation.id - assert snapshot_message(parent_snapshot_events[2]) == "child" assert snapshot_message(parent_snapshot_events[-1]) == "parent" assert not child_snapshot_events + + +def test_nested_node_turn_state_snapshots_export_only_root_runtime_conversation_to_agent_spec() -> ( + None +): + child_flow = Flow.from_steps( + [OutputMessageStep(message_template="child"), CompleteStep(name="end")], + step_names=["child_message", "end"], + name="child_flow", + ) + parent_flow = 
Flow.from_steps( + [ + FlowExecutionStep(flow=child_flow), + OutputMessageStep(message_template="parent"), + CompleteStep(name="end"), + ], + step_names=["child_flow_step", "parent_message", "end"], + name="parent_flow", + ) + conversation = parent_flow.start_conversation() + span_recorder = SnapshotSpanRecorder() + + with AgentSpecTrace(span_processors=[span_recorder]): + with register_event_listeners([AgentSpecEventListener()]): + status = conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ) + ) + + assert isinstance(status, FinishedStatus) + + flow_spans = [ + span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) + ] + flow_spans_by_name = { + next( + event for event in span.events if isinstance(event, AgentSpecFlowExecutionStart) + ).flow.name: span + for span in flow_spans + } + parent_span = flow_spans_by_name["parent_flow"] + child_span = flow_spans_by_name["child_flow"] + + parent_snapshot_events = [ + event for event in parent_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) + ] + child_snapshot_events = [ + event for event in child_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) + ] + + assert parent_snapshot_events + assert snapshot_runtime_conversation_ids(parent_snapshot_events) == [ + conversation.id for _ in parent_snapshot_events + ] + assert not child_snapshot_events + + +def test_parallel_map_snapshots_leave_agent_spec_exporters_with_root_resumable_state() -> None: + child_flow = Flow.from_steps( + [OutputMessageStep(message_template="item={{item}}"), CompleteStep(name="end")], + step_names=["child_message", "end"], + name="parallel_map_child", + ) + parent_flow = Flow.from_steps( + [ + ParallelMapStep( + flow=child_flow, + unpack_input={"item": "."}, + name="parallel_map", + ), + CompleteStep(name="end"), + ], + step_names=["parallel_map", "end"], + name="parallel_map_parent", + ) + conversation = 
parent_flow.start_conversation( + inputs={ParallelMapStep.ITERATED_INPUT: ["a", "b"]} + ) + snapshot_runtime_id_exporter = SnapshotRuntimeIdsByConversationExporter() + + with AgentSpecTrace(span_processors=[snapshot_runtime_id_exporter]): + with register_event_listeners([AgentSpecEventListener()]): + status = conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ) + ) + + assert isinstance(status, FinishedStatus) + assert snapshot_runtime_id_exporter.runtime_ids_by_conversation_id[conversation.conversation_id] + assert snapshot_runtime_id_exporter.runtime_ids_by_conversation_id[ + conversation.conversation_id + ] == [conversation.id] * len( + snapshot_runtime_id_exporter.runtime_ids_by_conversation_id[conversation.conversation_id] + ) diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py index 054b090e9..3b05dd9c7 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py @@ -8,6 +8,8 @@ import pytest +from wayflowcore.events import Event, EventListener +from wayflowcore.events.event import StateSnapshotEvent from wayflowcore.events.eventlistener import register_event_listeners from wayflowcore.executors._events.event import EventType from wayflowcore.executors.executionstatus import ( @@ -16,7 +18,7 @@ UserMessageRequestStatus, ) from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow from wayflowcore.serialization import deserialize_conversation, dump_conversation_state from wayflowcore.steps import CompleteStep, InputMessageStep, 
OutputMessageStep, ToolExecutionStep @@ -27,7 +29,6 @@ MutatingExecutionEndInterrupt, SnapshotCollector, assert_terminal_snapshot, - build_policy, create_agent_conversation, create_output_flow_conversation, create_tool_flow_conversation, @@ -50,6 +51,18 @@ def _restore_conversation_from_snapshot_payload(snapshot_payload: dict[str, Any] return restored_conversation +class _LiveConversationSnapshotObserver(EventListener): + def __init__(self, conversation) -> None: + self.conversation = conversation + self.live_snapshots: list[dict[str, Any]] = [] + + def __call__(self, event: Event) -> None: + if not isinstance(event, StateSnapshotEvent): + return + + self.live_snapshots.append(dump_conversation_state(self.conversation)) + + @pytest.mark.parametrize( ( "conversation_factory", @@ -84,7 +97,9 @@ def test_conversation_turn_policy_records_opening_and_closing_checkpoints( status, state_snapshot_events = execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), ) assert isinstance(status, expected_status_class) @@ -131,7 +146,9 @@ async def test_conversation_turn_policy_records_opening_and_closing_checkpoints_ status, state_snapshot_events = await execute_with_state_snapshots_async( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), ) assert isinstance(status, expected_status_class) @@ -158,7 +175,9 @@ def test_off_policy_disables_state_snapshot_emission( status, state_snapshot_events = execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.OFF), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.OFF + ), ) assert isinstance(status, 
expected_status_class) @@ -181,7 +200,9 @@ def test_conversation_turn_policy_records_interrupted_turn_end_checkpoints( status, state_snapshot_events = execute_with_state_snapshots( conversation, execution_interrupts=[OnEventExecutionInterrupt(EventType.EXECUTION_END)], - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), ) assert isinstance(status, InterruptedExecutionStatus) @@ -207,7 +228,9 @@ def explode() -> str: with register_event_listeners([collector]): with pytest.raises(RuntimeError, match="boom"): conversation.execute( - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS) + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) ) assert len(collector.state_snapshot_events) == 1 @@ -221,7 +244,9 @@ def test_conversation_turn_policy_reflects_real_interrupt_side_effects_once() -> status, state_snapshot_events = execute_with_state_snapshots( conversation, execution_interrupts=[interrupt], - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), ) assert isinstance(status, FinishedStatus) @@ -231,6 +256,81 @@ def test_conversation_turn_policy_reflects_real_interrupt_side_effects_once() -> assert state_snapshot_events[-1].state_snapshot["conversation"]["inputs"]["preview_count"] == 1 +@pytest.mark.parametrize( + ( + "interval", + "execution_interrupts", + "expected_status_class", + "expected_status_type", + ), + [ + pytest.param( + StateSnapshotInterval.CONVERSATION_TURNS, + None, + FinishedStatus, + "FinishedStatus", + id="conversation_turns-finished", + ), + pytest.param( + StateSnapshotInterval.TOOL_TURNS, + None, + FinishedStatus, + "FinishedStatus", + id="tool_turns-finished", + ), + 
pytest.param( + StateSnapshotInterval.NODE_TURNS, + None, + FinishedStatus, + "FinishedStatus", + id="node_turns-finished", + ), + pytest.param( + StateSnapshotInterval.ALL_INTERNAL_TURNS, + None, + FinishedStatus, + "FinishedStatus", + id="all_internal_turns-finished", + ), + pytest.param( + StateSnapshotInterval.CONVERSATION_TURNS, + [OnEventExecutionInterrupt(EventType.EXECUTION_END)], + InterruptedExecutionStatus, + "InterruptedExecutionStatus", + id="conversation_turns-interrupted", + ), + ], +) +def test_closing_turn_snapshot_is_emitted_before_live_conversation_status_commit( + interval: StateSnapshotInterval, + execution_interrupts, + expected_status_class, + expected_status_type: str, +) -> None: + conversation = create_output_flow_conversation() + collector = SnapshotCollector() + observer = _LiveConversationSnapshotObserver(conversation) + + with register_event_listeners([collector, observer]): + status = conversation.execute( + execution_interrupts=execution_interrupts, + state_snapshot_policy=StateSnapshotPolicy(state_snapshot_interval=interval), + ) + + assert isinstance(status, expected_status_class) + assert conversation.status is status + assert conversation.status_handled is False + assert observer.live_snapshots[-1]["execution"]["status"] is None + assert collector.state_snapshot_events[-1].state_snapshot is not None + assert collector.state_snapshot_events[-1].state_snapshot["execution"]["status"]["type"] == ( + expected_status_type + ) + assert observer.live_snapshots[-1] != { + "conversation": collector.state_snapshot_events[-1].state_snapshot["conversation"], + "execution": collector.state_snapshot_events[-1].state_snapshot["execution"], + } + + def test_conversation_turn_snapshot_payload_can_resume_waiting_for_client_tool_result() -> None: client_tool = ClientTool( name="client_lookup", @@ -247,7 +347,9 @@ def test_conversation_turn_snapshot_payload_can_resume_waiting_for_client_tool_r status, state_snapshot_events = 
execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), ) assert isinstance(status, ToolRequestStatus) @@ -286,7 +388,9 @@ async def test_conversation_turn_snapshot_payload_can_resume_waiting_for_user_in status, state_snapshot_events = await execute_with_state_snapshots_async( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), ) assert isinstance(status, UserMessageRequestStatus) diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py index 9a2c01a50..e12a3a6b4 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py @@ -11,14 +11,13 @@ from wayflowcore.executors._events.event import EventType from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from ..conftest import disable_streaming from ..test_interrupts import OnEventExecutionInterrupt from ..testhelpers.dummy import DummyModel from ..testhelpers.statesnapshots import ( SnapshotCollector, - build_policy, create_agent_conversation, create_output_flow_conversation, create_tool_calling_agent_conversation, @@ -48,16 +47,16 @@ def config(self) -> dict[str, object]: pytest.param( create_output_flow_conversation, FinishedStatus, - [None, None, None, None, None, None], - 6, + 
[None, None, None, None, None, None, None, "FinishedStatus"], + 8, None, id="flow", ), pytest.param( create_agent_conversation, UserMessageRequestStatus, - [None, None], - 2, + [None, None, None, "UserMessageRequestStatus"], + 4, [0, 1], id="agent", ), @@ -74,7 +73,9 @@ def test_node_turn_policy_tracks_flow_steps_and_agent_iterations( status, state_snapshot_events = execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), ) assert isinstance(status, expected_status_class) @@ -82,8 +83,8 @@ def test_node_turn_policy_tracks_flow_steps_and_agent_iterations( assert snapshot_status_types(state_snapshot_events) == expected_status_types if expected_curr_iters is not None: assert [ - state_snapshot_events[0].state_snapshot["execution"]["curr_iter"], state_snapshot_events[1].state_snapshot["execution"]["curr_iter"], + state_snapshot_events[2].state_snapshot["execution"]["curr_iter"], ] == expected_curr_iters @@ -100,16 +101,16 @@ def test_node_turn_policy_tracks_flow_steps_and_agent_iterations( pytest.param( create_output_flow_conversation, FinishedStatus, - [None, None, None, None, None, None], - 6, + [None, None, None, None, None, None, None, "FinishedStatus"], + 8, None, id="flow", ), pytest.param( create_agent_conversation, UserMessageRequestStatus, - [None, None], - 2, + [None, None, None, "UserMessageRequestStatus"], + 4, [0, 1], id="agent", ), @@ -126,7 +127,9 @@ async def test_node_turn_policy_tracks_flow_steps_and_agent_iterations_async( status, state_snapshot_events = await execute_with_state_snapshots_async( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), ) assert isinstance(status, expected_status_class) @@ -134,8 +137,8 @@ async def 
test_node_turn_policy_tracks_flow_steps_and_agent_iterations_async( assert snapshot_status_types(state_snapshot_events) == expected_status_types if expected_curr_iters is not None: assert [ - state_snapshot_events[0].state_snapshot["execution"]["curr_iter"], state_snapshot_events[1].state_snapshot["execution"]["curr_iter"], + state_snapshot_events[2].state_snapshot["execution"]["curr_iter"], ] == expected_curr_iters @@ -155,11 +158,13 @@ def test_node_turn_policy_keeps_partial_progress_when_interrupted_mid_turn( status, state_snapshot_events = execute_with_state_snapshots( conversation, execution_interrupts=[OnEventExecutionInterrupt(interrupt_event)], - state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), ) assert isinstance(status, InterruptedExecutionStatus) - assert snapshot_status_types(state_snapshot_events) == [None] + assert snapshot_status_types(state_snapshot_events) == [None, None] def test_flow_node_turn_policy_uses_iteration_start_and_end_boundaries() -> None: @@ -167,17 +172,21 @@ def test_flow_node_turn_policy_uses_iteration_start_and_end_boundaries() -> None status, state_snapshot_events = execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.NODE_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), ) assert isinstance(status, FinishedStatus) assert snapshot_step_histories(state_snapshot_events) == [ + [], [], ["__StartStep__"], ["__StartStep__"], ["__StartStep__", "step_0"], ["__StartStep__", "step_0"], ["__StartStep__", "step_0", "end"], + ["__StartStep__", "step_0", "end"], ] @@ -187,7 +196,7 @@ def test_internal_snapshots_do_not_reuse_the_previous_turn_status() -> None: conversation = Agent(llm=llm).start_conversation() conversation.append_user_message("Hi") collector = SnapshotCollector() - policy = 
build_policy(StateSnapshotInterval.NODE_TURNS) + policy = StateSnapshotPolicy(state_snapshot_interval=StateSnapshotInterval.NODE_TURNS) with register_event_listeners([collector]): first_status = conversation.execute(state_snapshot_policy=policy) @@ -197,9 +206,9 @@ def test_internal_snapshots_do_not_reuse_the_previous_turn_status() -> None: second_status = conversation.execute(state_snapshot_policy=policy) assert isinstance(second_status, UserMessageRequestStatus) - assert len(collector.state_snapshot_events) == 4 + assert len(collector.state_snapshot_events) == 8 - second_turn_internal_snapshots = collector.state_snapshot_events[2:4] + second_turn_internal_snapshots = collector.state_snapshot_events[5:7] assert snapshot_status_types(second_turn_internal_snapshots) == [None, None] assert all( snapshot_event.state_snapshot["execution"]["status_handled"] is False @@ -221,7 +230,7 @@ def test_internal_snapshots_do_not_reuse_the_previous_turn_status() -> None: None, None, FinishedStatus, - [None, None], + [None, None, None, "FinishedStatus"], id="flow-success", ), pytest.param( @@ -229,7 +238,7 @@ def test_internal_snapshots_do_not_reuse_the_previous_turn_status() -> None: [OnEventExecutionInterrupt(EventType.TOOL_CALL_END)], disable_streaming(), InterruptedExecutionStatus, - [None, None], + [None, None, None], id="agent-tool-end-interrupt", ), ], @@ -246,7 +255,9 @@ def test_tool_turn_policy_records_real_tool_boundaries( status, state_snapshot_events = execute_with_state_snapshots( conversation, execution_interrupts=execution_interrupts, - state_snapshot_policy=build_policy(StateSnapshotInterval.TOOL_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.TOOL_TURNS + ), execution_context=execution_context, ) @@ -269,7 +280,7 @@ def test_tool_turn_policy_records_real_tool_boundaries( None, None, FinishedStatus, - [None, None], + [None, None, None, "FinishedStatus"], id="flow-success", ), pytest.param( @@ -277,7 +288,7 @@ 
def test_tool_turn_policy_records_real_tool_boundaries( [OnEventExecutionInterrupt(EventType.TOOL_CALL_END)], disable_streaming(), InterruptedExecutionStatus, - [None, None], + [None, None, None], id="agent-tool-end-interrupt", ), ], @@ -294,7 +305,9 @@ async def test_tool_turn_policy_records_real_tool_boundaries_async( status, state_snapshot_events = await execute_with_state_snapshots_async( conversation, execution_interrupts=execution_interrupts, - state_snapshot_policy=build_policy(StateSnapshotInterval.TOOL_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.TOOL_TURNS + ), execution_context=execution_context, ) @@ -307,12 +320,14 @@ def test_all_internal_turn_policy_combines_node_and_tool_boundaries() -> None: status, state_snapshot_events = execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.ALL_INTERNAL_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.ALL_INTERNAL_TURNS + ), ) assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 8 - assert snapshot_status_types(state_snapshot_events) == [None] * 8 + assert len(state_snapshot_events) == 10 + assert snapshot_status_types(state_snapshot_events) == [None] * 9 + ["FinishedStatus"] @pytest.mark.anyio @@ -321,9 +336,51 @@ async def test_all_internal_turn_policy_combines_node_and_tool_boundaries_async( status, state_snapshot_events = await execute_with_state_snapshots_async( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.ALL_INTERNAL_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.ALL_INTERNAL_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 10 + assert snapshot_status_types(state_snapshot_events) == [None] * 9 + ["FinishedStatus"] + + +@pytest.mark.parametrize( + ("interval", "expected_status_types"), + [ + 
pytest.param( + StateSnapshotInterval.CONVERSATION_TURNS, + [None, "FinishedStatus"], + id="conversation_turns", + ), + pytest.param( + StateSnapshotInterval.TOOL_TURNS, + [None, None, None, "FinishedStatus"], + id="tool_turns", + ), + pytest.param( + StateSnapshotInterval.NODE_TURNS, + [None, None, None, None, None, None, None, "FinishedStatus"], + id="node_turns", + ), + pytest.param( + StateSnapshotInterval.ALL_INTERNAL_TURNS, + [None, None, None, None, None, None, None, None, None, "FinishedStatus"], + id="all_internal_turns", + ), + ], +) +def test_snapshot_interval_policies_include_conversation_turns_cumulatively( + interval: StateSnapshotInterval, + expected_status_types: list[str | None], +) -> None: + conversation = create_tool_flow_conversation(lambda: "hi") + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy(state_snapshot_interval=interval), ) assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 8 - assert snapshot_status_types(state_snapshot_events) == [None] * 8 + assert snapshot_status_types(state_snapshot_events) == expected_status_types diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py b/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py index 442dcd973..e0a5dde95 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py @@ -7,17 +7,18 @@ import pytest from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow +from wayflowcore.serialization import deserialize_conversation from wayflowcore.steps import ( CompleteStep, FlowExecutionStep, OutputMessageStep, ParallelFlowExecutionStep, + 
ParallelMapStep, ) from ..testhelpers.statesnapshots import ( - build_policy, create_nested_agent_step_flow_conversation, create_parallel_child_flow, execute_with_state_snapshots, @@ -45,22 +46,20 @@ def test_state_snapshot_policy_is_inherited_by_nested_sub_conversations() -> Non status, state_snapshot_events = execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), ) assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 4 + assert len(state_snapshot_events) == 2 assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { conversation.conversation_id } - child_runtime_conversation_id = state_snapshot_events[1].state_snapshot["conversation"]["id"] assert snapshot_runtime_conversation_ids(state_snapshot_events) == [ conversation.id, - child_runtime_conversation_id, - child_runtime_conversation_id, conversation.id, ] - assert child_runtime_conversation_id != conversation.id @pytest.mark.anyio @@ -81,22 +80,65 @@ async def test_state_snapshot_policy_is_inherited_by_nested_sub_conversations_as status, state_snapshot_events = await execute_with_state_snapshots_async( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), ) assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 4 + assert len(state_snapshot_events) == 2 assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { conversation.conversation_id } - child_runtime_conversation_id = state_snapshot_events[1].state_snapshot["conversation"]["id"] assert snapshot_runtime_conversation_ids(state_snapshot_events) == [ conversation.id, - child_runtime_conversation_id, - 
child_runtime_conversation_id, conversation.id, ] - assert child_runtime_conversation_id != conversation.id + + +def test_nested_root_turn_snapshot_payload_can_resume_the_logical_parent_conversation() -> None: + child_flow = Flow.from_steps( + [ + OutputMessageStep(message_template="child"), + CompleteStep(name="child_end"), + ], + name="child_flow", + ) + parent_flow = Flow.from_steps( + [ + FlowExecutionStep(flow=child_flow), + OutputMessageStep(message_template="parent"), + CompleteStep(name="parent_end"), + ], + name="parent_flow", + ) + conversation = parent_flow.start_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + root_turn_snapshot_event = state_snapshot_events[-1] + root_turn_snapshot = root_turn_snapshot_event.state_snapshot + assert root_turn_snapshot is not None + assert root_turn_snapshot_event.conversation_id == conversation.conversation_id + assert root_turn_snapshot["conversation"]["id"] == conversation.id + + restored_conversation = deserialize_conversation(root_turn_snapshot["conversation_state"]) + assert restored_conversation.id == conversation.id + assert restored_conversation.conversation_id == conversation.conversation_id + + resumed_status = restored_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + assert [message.content for message in restored_conversation.get_messages()] == [ + "child", + "parent", + ] def test_state_snapshot_policy_is_inherited_by_parallel_sub_conversations() -> None: @@ -114,16 +156,156 @@ def test_state_snapshot_policy_is_inherited_by_parallel_sub_conversations() -> N status, state_snapshot_events = execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + 
state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 2 + assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { + conversation.conversation_id + } + assert snapshot_runtime_conversation_ids(state_snapshot_events) == [ + conversation.id, + conversation.id, + ] + assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + + +def test_parallel_root_turn_snapshot_payloads_can_resume_the_logical_parent_conversation() -> None: + left_child_flow = Flow.from_steps( + [ + OutputMessageStep( + message_template="left", + output_mapping={OutputMessageStep.OUTPUT: "left_message"}, + ), + CompleteStep(name="left_end"), + ], + name="left_child_flow", + ) + right_child_flow = Flow.from_steps( + [ + OutputMessageStep( + message_template="right", + output_mapping={OutputMessageStep.OUTPUT: "right_message"}, + ), + CompleteStep(name="right_end"), + ], + name="right_child_flow", + ) + conversation = Flow.from_steps( + [ + ParallelFlowExecutionStep( + flows=[ + left_child_flow, + right_child_flow, + ] + ), + CompleteStep(name="end"), + ] + ).start_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + + for snapshot_event in state_snapshot_events: + snapshot_payload = snapshot_event.state_snapshot + assert snapshot_payload is not None + + restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) + resumed_status = restored_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + assert sorted(message.content for message in restored_conversation.get_messages()) == [ + "left", + "right", + ] + + +def test_parallel_map_emits_only_resumable_parent_turn_snapshots() -> 
None: + child_flow = Flow.from_steps( + [ + OutputMessageStep(message_template="item={{item}}"), + CompleteStep(name="child_end"), + ], + name="parallel_map_child", + ) + conversation = Flow.from_steps( + [ + ParallelMapStep( + flow=child_flow, + unpack_input={"item": "."}, + name="parallel_map", + ), + CompleteStep(name="end"), + ], + name="parallel_map_parent", + ).start_conversation(inputs={ParallelMapStep.ITERATED_INPUT: ["a", "b"]}) + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + + for snapshot_event in state_snapshot_events: + snapshot_payload = snapshot_event.state_snapshot + assert snapshot_payload is not None + + restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) + resumed_status = restored_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + assert sorted(message.content for message in restored_conversation.get_messages()) == [ + "item=a", + "item=b", + ] + + +def test_nested_node_turn_snapshots_keep_child_runtime_conversation_identity() -> None: + child_flow = Flow.from_steps( + [ + OutputMessageStep(message_template="child"), + CompleteStep(name="child_end"), + ], + name="child_flow", + ) + parent_flow = Flow.from_steps( + [ + FlowExecutionStep(flow=child_flow), + OutputMessageStep(message_template="parent"), + CompleteStep(name="parent_end"), + ], + name="parent_flow", + ) + conversation = parent_flow.start_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), ) assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 6 assert 
{snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { conversation.conversation_id } - assert snapshot_status_types(state_snapshot_events).count(None) == 3 - assert snapshot_status_types(state_snapshot_events).count("FinishedStatus") == 3 + assert any( + snapshot_event.state_snapshot["conversation"]["id"] != conversation.id + for snapshot_event in state_snapshot_events + ) def test_state_snapshot_policy_is_inherited_by_nested_agent_steps() -> None: @@ -131,23 +313,19 @@ def test_state_snapshot_policy_is_inherited_by_nested_agent_steps() -> None: status, state_snapshot_events = execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), ) - nested_conversation_id = state_snapshot_events[1].conversation_id - assert isinstance(status, UserMessageRequestStatus) - assert snapshot_status_types(state_snapshot_events) == [ - None, - None, - "UserMessageRequestStatus", - "UserMessageRequestStatus", - ] + assert snapshot_status_types(state_snapshot_events) == [None, "UserMessageRequestStatus"] assert [snapshot_event.conversation_id for snapshot_event in state_snapshot_events] == [ conversation.conversation_id, - nested_conversation_id, - nested_conversation_id, conversation.conversation_id, ] - assert nested_conversation_id != conversation.conversation_id + assert snapshot_runtime_conversation_ids(state_snapshot_events) == [ + conversation.id, + conversation.id, + ] assert snapshot_message(state_snapshot_events[-1]) == "agent answer" diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py b/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py index 6d8cfdb44..da31813cd 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py @@ -4,19 
+4,44 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. +from dataclasses import dataclass + from wayflowcore.conversation import Conversation from wayflowcore.executors.executionstatus import FinishedStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy +from wayflowcore.flow import Flow +from wayflowcore.property import AnyProperty +from wayflowcore.serialization.serializer import FrozenSerializableDataclass +from wayflowcore.steps import CompleteStep, OutputMessageStep, VariableWriteStep +from wayflowcore.variable import Variable from ..testhelpers.statesnapshots import ( - build_policy, create_output_flow_conversation, - create_unserializable_variable_conversation, execute_with_state_snapshots, snapshot_message, ) +@dataclass(frozen=True) +class _SerializableButNotJson(FrozenSerializableDataclass): + value: str + + +def _create_non_json_variable_state_conversation() -> Conversation: + custom_variable = Variable(name="custom", type=AnyProperty()) + return Flow.from_steps( + [ + VariableWriteStep( + variable=custom_variable, + input_mapping={VariableWriteStep.VALUE: custom_variable.name}, + ), + OutputMessageStep(message_template="done"), + CompleteStep(name="end"), + ], + variables=[custom_variable], + ).start_conversation(inputs={custom_variable.name: _SerializableButNotJson(value="x")}) + + def test_state_snapshot_emission_survives_broken_extra_state_builder() -> None: def broken_builder(_conversation: Conversation) -> dict[str, object]: raise RuntimeError("boom") @@ -25,8 +50,8 @@ def broken_builder(_conversation: Conversation) -> dict[str, object]: status, state_snapshot_events = execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy( - StateSnapshotInterval.CONVERSATION_TURNS, 
+ state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, extra_state_builder=broken_builder, ), ) @@ -37,12 +62,12 @@ def broken_builder(_conversation: Conversation) -> dict[str, object]: def test_state_snapshot_emission_survives_unserializable_variable_state() -> None: - conversation = create_unserializable_variable_conversation() + conversation = _create_non_json_variable_state_conversation() status, state_snapshot_events = execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy( - StateSnapshotInterval.CONVERSATION_TURNS, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, include_variable_state=True, ), ) diff --git a/wayflowcore/tests/integration/steps/test_flow_execution_step.py b/wayflowcore/tests/integration/steps/test_flow_execution_step.py index 94783d0b8..d03f9d7df 100644 --- a/wayflowcore/tests/integration/steps/test_flow_execution_step.py +++ b/wayflowcore/tests/integration/steps/test_flow_execution_step.py @@ -22,6 +22,7 @@ CompleteStep, FlowExecutionStep, InputMessageStep, + OutputMessageStep, ToolExecutionStep, ) from wayflowcore.tools import ClientTool, ServerTool, ToolResult @@ -210,6 +211,58 @@ def test_sub_conversation_shares_same_message_list_as_main_conversation() -> Non assert conversation.get_messages() == subconversation.get_messages() +def test_subflow_execution_reuses_finished_child_conversation_and_preserves_branch() -> None: + child_success = CompleteStep(name="child_success") + child_message = OutputMessageStep(message_template="child", name="child_message") + child_flow = Flow( + begin_step=child_message, + steps={ + "child_message": child_message, + "child_success": child_success, + }, + control_flow_edges=[ + ControlFlowEdge(child_message, child_success), + ], + ) + child_step = FlowExecutionStep(flow=child_flow, name="child_step") + success_message = 
OutputMessageStep(message_template="parent-success", name="success_message") + parent_flow = Flow( + begin_step=child_step, + steps={ + "child_step": child_step, + "success_message": success_message, + }, + control_flow_edges=[ + ControlFlowEdge(child_step, success_message, source_branch="child_success"), + ControlFlowEdge(success_message, None), + ], + ) + conversation = parent_flow.start_conversation() + + sub_conversation = conversation._get_or_create_current_sub_conversation( + step=child_step, + flow=child_flow, + inputs={}, + ) + sub_status = sub_conversation.execute() + assert isinstance(sub_status, FinishedStatus) + assert sub_status.complete_step_name == "child_success" + + async def _should_not_execute_again() -> FinishedStatus: + raise AssertionError("finished child conversation should not be executed again") + + sub_conversation.execute_async = _should_not_execute_again # type: ignore[method-assign] + + status = conversation.execute() + + assert isinstance(status, FinishedStatus) + assert status.output_values == {"output_message": "parent-success"} + assert [message.content for message in conversation.get_messages()] == [ + "child", + "parent-success", + ] + + def test_subflow_execution_might_yield_when_flow_contains_yielding_steps() -> None: step = FlowExecutionStep( flow=Flow.from_steps( diff --git a/wayflowcore/tests/integration/steps/test_parallel_flow_execution_step.py b/wayflowcore/tests/integration/steps/test_parallel_flow_execution_step.py index e1a41b75c..65a1ec5e1 100644 --- a/wayflowcore/tests/integration/steps/test_parallel_flow_execution_step.py +++ b/wayflowcore/tests/integration/steps/test_parallel_flow_execution_step.py @@ -9,10 +9,11 @@ import pytest from wayflowcore import Tool +from wayflowcore.executors.executionstatus import FinishedStatus from wayflowcore.flow import Flow from wayflowcore.flowhelpers import run_step_and_return_outputs from wayflowcore.property import IntegerProperty, ListProperty, StringProperty -from 
wayflowcore.steps import InputMessageStep, OutputMessageStep, ToolExecutionStep +from wayflowcore.steps import CompleteStep, InputMessageStep, OutputMessageStep, ToolExecutionStep from wayflowcore.steps.parallelflowexecutionstep import ParallelFlowExecutionStep from wayflowcore.tools import ServerTool @@ -225,6 +226,53 @@ def test_parallel_subflow_execution_cannot_yield(): assert step.might_yield is False +def test_parallel_subflow_execution_reuses_previously_finished_child_conversations() -> None: + left_flow = get_flow_from_tools([get_tool(output_name="left_output")]) + right_flow = get_flow_from_tools([get_tool(output_name="right_output")]) + step = ParallelFlowExecutionStep(flows=[left_flow, right_flow], name="parallel") + conversation = Flow.from_steps([step, CompleteStep(name="end")]).start_conversation() + + for sub_flow in step.flows: + sub_conversation = conversation._get_or_create_current_sub_conversation( + step=step, + flow=sub_flow, + inputs={}, + sub_conversation_id=sub_flow.id, + ) + status = sub_conversation.execute() + assert isinstance(status, FinishedStatus) + + async def _should_not_execute_again() -> FinishedStatus: + raise AssertionError("finished child conversation should not be executed again") + + sub_conversation.execute_async = _should_not_execute_again # type: ignore[method-assign] + + status = conversation.execute() + + assert isinstance(status, FinishedStatus) + assert status.output_values == { + "left_output": "{}_from_left_output", + "right_output": "{}_from_right_output", + } + + +def test_parallel_subflow_execution_cleans_up_custom_id_sub_conversations_from_parent() -> None: + left_flow = get_flow_from_tools([get_tool(output_name="left_output")]) + right_flow = get_flow_from_tools([get_tool(output_name="right_output")]) + step = ParallelFlowExecutionStep(flows=[left_flow, right_flow], name="parallel") + conversation = Flow.from_steps([step, CompleteStep(name="end")]).start_conversation() + + status = conversation.execute() + + 
assert isinstance(status, FinishedStatus) + assert ( + conversation._get_current_sub_conversation(step, sub_conversation_id=left_flow.id) is None + ) + assert ( + conversation._get_current_sub_conversation(step, sub_conversation_id=right_flow.id) is None + ) + + def test_step_raises_when_two_flows_have_inputs_with_same_name_but_different_types(): tool_1 = ServerTool( name="n", description="d", input_descriptors=[ListProperty(name="i")], func=lambda: "" diff --git a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py index 2d3ff2433..1dd89c66a 100644 --- a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py +++ b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py @@ -18,7 +18,7 @@ ToolRequestStatus, UserMessageRequestStatus, ) -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow from wayflowcore.property import AnyProperty, StringProperty from wayflowcore.serialization import ( @@ -41,7 +41,7 @@ from wayflowcore.tools import ClientTool, ServerTool, ToolResult, register_server_tool from wayflowcore.variable import Variable -from ..testhelpers.statesnapshots import build_policy, execute_with_state_snapshots +from ..testhelpers.statesnapshots import execute_with_state_snapshots class _UnserializableValue: @@ -356,7 +356,9 @@ def test_emitted_snapshot_conversation_state_restores_variable_dependent_continu status, state_snapshot_events = execute_with_state_snapshots( conversation, - state_snapshot_policy=build_policy(StateSnapshotInterval.CONVERSATION_TURNS), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), ) assert isinstance(status, UserMessageRequestStatus) diff --git a/wayflowcore/tests/testhelpers/statesnapshots.py 
b/wayflowcore/tests/testhelpers/statesnapshots.py index 4b7290ebc..f9573136e 100644 --- a/wayflowcore/tests/testhelpers/statesnapshots.py +++ b/wayflowcore/tests/testhelpers/statesnapshots.py @@ -18,7 +18,7 @@ InterruptedExecutionStatus, _NullExecutionInterrupt, ) -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotPolicy from wayflowcore.flow import Flow from wayflowcore.messagelist import Message, MessageType from wayflowcore.property import AnyProperty, StringProperty @@ -238,13 +238,6 @@ def create_unserializable_variable_conversation() -> Conversation: ).start_conversation(inputs={custom_variable.name: _UnserializableVariableValue()}) -def build_policy( - interval: StateSnapshotInterval, - **kwargs: object, -) -> StateSnapshotPolicy: - return StateSnapshotPolicy(state_snapshot_interval=interval, **kwargs) - - def assert_terminal_snapshot( snapshot_events: Sequence[object], *, From 80220e04751d4f859d9e9f592dbf73d21812e284 Mon Sep 17 00:00:00 2001 From: Son Le Date: Thu, 19 Mar 2026 18:51:12 +0100 Subject: [PATCH 11/13] Finalize snapshot tracing cleanup and coverage --- .../source/core/howtoguides/howto_tracing.rst | 9 +- .../src/wayflowcore/agentspec/tracing.py | 65 +++--- wayflowcore/src/wayflowcore/events/event.py | 15 +- .../executors/_statesnapshot_eventlistener.py | 107 +++++----- .../test_state_snapshot_tracing_agent.py | 189 +++--------------- .../test_state_snapshot_tracing_flow.py | 178 +++-------------- .../test_state_snapshot_tracing_nested.py | 93 ++------- ...ate_snapshot_runtime_conversation_turns.py | 55 +++-- ...t_state_snapshot_runtime_internal_turns.py | 116 ++++++++++- .../test_state_snapshot_runtime_nested.py | 8 +- .../test_state_snapshot_runtime_resilience.py | 30 +++ .../test_conversation_state_snapshot.py | 8 +- .../tests/testhelpers/agentspec_tracing.py | 132 ++++++++++++ .../tests/testhelpers/statesnapshots.py | 24 ++- 14 
files changed, 516 insertions(+), 513 deletions(-) create mode 100644 wayflowcore/tests/testhelpers/agentspec_tracing.py diff --git a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst index 426fa9222..f32847323 100644 --- a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst +++ b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst @@ -168,9 +168,12 @@ JSON-serializable. while ``state_snapshot["conversation"]["id"]`` identifies the concrete runtime conversation instance that emitted the snapshot. The lightweight ``state_snapshot["conversation"]`` / ``state_snapshot["execution"]`` sections are -intended for inspection and tracing, while -``state_snapshot["conversation_state"]`` contains the authoritative serialized -WayFlow conversation blob used for resumability. To restore from that blob, use +intended for inspection and tracing. Only the root conversation-turn +checkpoints emitted for the conversation passed directly to ``execute()`` / +``execute_async()`` include ``state_snapshot["conversation_state"]``, which is +the authoritative serialized WayFlow conversation blob used for resumability. +Internal tool/node snapshots intentionally omit that serialized blob to stay +lightweight. To restore from a resumable checkpoint, use ``wayflowcore.serialization.deserialize_conversation(...)`` or ``deserialize_conversation_state(...)`` together with ``load_conversation_state(...)``. 
diff --git a/wayflowcore/src/wayflowcore/agentspec/tracing.py b/wayflowcore/src/wayflowcore/agentspec/tracing.py index bfbfeb9df..a102a6b1c 100644 --- a/wayflowcore/src/wayflowcore/agentspec/tracing.py +++ b/wayflowcore/src/wayflowcore/agentspec/tracing.py @@ -150,6 +150,34 @@ def _get_snapshot_owner_span( return current_agentspec_span + def _get_current_conversation_owner(self) -> _ConversationSpanOwner | None: + active_conversation = self._get_active_wayflow_conversation() + if active_conversation is None: + return None + return self._conversation_span_owners.get(active_conversation.conversation_id) + + def _span_has_state_snapshot(self, agentspec_span: AgentSpecSpan) -> bool: + return any( + isinstance(span_event, AgentSpecStateSnapshotEmitted) + for span_event in agentspec_span.events + ) + + def _span_has_execution_end_event(self, agentspec_span: AgentSpecSpan) -> bool: + return any( + isinstance(span_event, (AgentSpecFlowExecutionEnd, AgentSpecAgentExecutionEnd)) + for span_event in agentspec_span.events + ) + + def _end_conversation_span_if_ready(self, agentspec_span: AgentSpecSpan) -> None: + owner = self._get_current_conversation_owner() + if ( + owner is not None + and owner.span is agentspec_span + and self._span_has_state_snapshot(agentspec_span) + ): + return + agentspec_span.end() + def __call__(self, event: Event) -> None: # We intercept the wayflow events, and based on the type of event: # - if it corresponds to a span start event, we create the corresponding agent spec span, and we start it @@ -391,10 +419,7 @@ def __call__(self, event: Event) -> None: extra_state=event.extra_state, ) owner_span.add_event(snapshot_event) - if owner_span.end_time is None and any( - isinstance(span_event, (AgentSpecFlowExecutionEnd, AgentSpecAgentExecutionEnd)) - for span_event in owner_span.events - ): + if owner_span.end_time is None and self._span_has_execution_end_event(owner_span): owner_span.end() case FlowExecutionStartedEvent(): # Flow execution starts. 
Create the new agent spec span, start it, add the event @@ -441,21 +466,7 @@ def __call__(self, event: Event) -> None: branch_selected=branch_selected, ) ) - active_conversation = self._get_active_wayflow_conversation() - owner = ( - self._conversation_span_owners.get(active_conversation.conversation_id) - if active_conversation is not None - else None - ) - if ( - owner is None - or owner.span is not current_agentspec_span - or not any( - isinstance(span_event, AgentSpecStateSnapshotEmitted) - for span_event in current_agentspec_span.events - ) - ): - current_agentspec_span.end() + self._end_conversation_span_if_ready(current_agentspec_span) case AgentExecutionStartedEvent(): # Agent execution starts. Create the new agent spec span, start it, add the event agentspec_agent = cast( @@ -499,21 +510,7 @@ def __call__(self, event: Event) -> None: outputs=outputs, ) ) - active_conversation = self._get_active_wayflow_conversation() - owner = ( - self._conversation_span_owners.get(active_conversation.conversation_id) - if active_conversation is not None - else None - ) - if ( - owner is None - or owner.span is not current_agentspec_span - or not any( - isinstance(span_event, AgentSpecStateSnapshotEmitted) - for span_event in current_agentspec_span.events - ) - ): - current_agentspec_span.end() + self._end_conversation_span_if_ready(current_agentspec_span) case ExceptionRaisedEvent(): if not current_agentspec_span: return diff --git a/wayflowcore/src/wayflowcore/events/event.py b/wayflowcore/src/wayflowcore/events/event.py index 7d7db8348..0f6d36653 100644 --- a/wayflowcore/src/wayflowcore/events/event.py +++ b/wayflowcore/src/wayflowcore/events/event.py @@ -801,14 +801,15 @@ class StateSnapshotEvent(Event): ``conversation_id`` is the logical/public conversation id. When a snapshot is present, ``state_snapshot["conversation"]["id"]`` identifies the runtime conversation instance described by the payload. 
WayFlow-emitted payloads - also place the authoritative serialized state in - ``state_snapshot["conversation_state"]`` while keeping + include the authoritative serialized state in + ``state_snapshot["conversation_state"]`` only for the root conversation-turn + checkpoints owned by the conversation that began the current ``execute()`` / + ``execute_async()`` run. All snapshots keep ``state_snapshot["conversation"]`` and ``state_snapshot["execution"]`` as - the lightweight inspection view. Snapshots emitted for the conversation that - began the current ``execute()`` / ``execute_async()`` run are the - resumable checkpoints for that run; nested child snapshots are primarily - tracing checkpoints and may be filtered by downstream bridges that need a - single checkpoint owner per logical conversation. + the lightweight inspection view. Nested child snapshots and internal + tool/node snapshots are primarily tracing checkpoints and may be filtered by + downstream bridges that need a single checkpoint owner per logical + conversation. 
""" conversation_id: str = field(default_factory=_required_attribute("conversation_id", str)) diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py index 6fc38aebb..dd61ee228 100644 --- a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -50,6 +50,24 @@ _STATE_SNAPSHOT_RUNTIME = "wayflow" _STATE_SNAPSHOT_SCHEMA_VERSION = 1 +_STATE_SNAPSHOT_INTERVALS_BY_POLICY = { + StateSnapshotInterval.CONVERSATION_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + }, + StateSnapshotInterval.TOOL_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + StateSnapshotInterval.TOOL_TURNS, + }, + StateSnapshotInterval.NODE_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + StateSnapshotInterval.NODE_TURNS, + }, + StateSnapshotInterval.ALL_INTERNAL_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + StateSnapshotInterval.TOOL_TURNS, + StateSnapshotInterval.NODE_TURNS, + }, +} _STATE_SNAPSHOT_POLICIES: ContextVar[Dict[str, StateSnapshotPolicy]] = ContextVar( @@ -187,25 +205,7 @@ def _get_snapshot_policy_for_interval( if snapshot_interval == StateSnapshotInterval.OFF: return None - included_intervals = { - StateSnapshotInterval.CONVERSATION_TURNS: { - StateSnapshotInterval.CONVERSATION_TURNS, - }, - StateSnapshotInterval.TOOL_TURNS: { - StateSnapshotInterval.CONVERSATION_TURNS, - StateSnapshotInterval.TOOL_TURNS, - }, - StateSnapshotInterval.NODE_TURNS: { - StateSnapshotInterval.CONVERSATION_TURNS, - StateSnapshotInterval.NODE_TURNS, - }, - StateSnapshotInterval.ALL_INTERNAL_TURNS: { - StateSnapshotInterval.CONVERSATION_TURNS, - StateSnapshotInterval.TOOL_TURNS, - StateSnapshotInterval.NODE_TURNS, - }, - } - if required_snapshot_interval in included_intervals[snapshot_interval]: + if required_snapshot_interval in _STATE_SNAPSHOT_INTERVALS_BY_POLICY[snapshot_interval]: return 
state_snapshot_policy return None @@ -231,6 +231,7 @@ def _build_variable_state( def _build_state_snapshot_payload( conversation: Conversation, *, + include_conversation_state: bool, status: object = _UNSET, status_handled: object = _UNSET, ) -> dict[str, Any]: @@ -243,28 +244,29 @@ def _build_state_snapshot_payload( status=status, status_handled=status_handled, ) - conversation_state = ( - serialize_conversation_state(conversation) - if status is _UNSET and status_handled is _UNSET - else _serialize_conversation_state_with_runtime_overrides( - conversation, - status=cast( - Optional[ExecutionStatus], - conversation.status if status is _UNSET else status, - ), - status_handled=cast( - bool, - conversation.status_handled if status_handled is _UNSET else status_handled, - ), - ) - ) - return { + payload = { "runtime": _STATE_SNAPSHOT_RUNTIME, "schema_version": _STATE_SNAPSHOT_SCHEMA_VERSION, - "conversation_state": conversation_state, "conversation": dumped_state["conversation"], "execution": dumped_state["execution"], } + if include_conversation_state: + payload["conversation_state"] = ( + serialize_conversation_state(conversation) + if status is _UNSET and status_handled is _UNSET + else _serialize_conversation_state_with_runtime_overrides( + conversation, + status=cast( + Optional[ExecutionStatus], + conversation.status if status is _UNSET else status, + ), + status_handled=cast( + bool, + conversation.status_handled if status_handled is _UNSET else status_handled, + ), + ) + ) + return payload def _record_state_snapshot( @@ -290,25 +292,24 @@ def _record_state_snapshot( # conversation-turn checkpoints for the same run. 
return - try: - record_event( - StateSnapshotEvent( - conversation_id=conversation.conversation_id, - state_snapshot=_build_state_snapshot_payload( - conversation, - status=status, - status_handled=status_handled, + # Snapshot delivery is part of the execution contract: downstream listener + # or storage failures must propagate to the caller instead of being silently + # converted into best-effort behavior. + record_event( + StateSnapshotEvent( + conversation_id=conversation.conversation_id, + state_snapshot=_build_state_snapshot_payload( + conversation, + include_conversation_state=( + required_snapshot_interval == StateSnapshotInterval.CONVERSATION_TURNS ), - extra_state=_build_extra_state(conversation, state_snapshot_policy), - variable_state=_build_variable_state(conversation, state_snapshot_policy), - ) - ) - except Exception: - logger.warning( - "Failed to emit state snapshot for conversation '%s'", - conversation.conversation_id, - exc_info=True, + status=status, + status_handled=status_handled, + ), + extra_state=_build_extra_state(conversation, state_snapshot_policy), + variable_state=_build_variable_state(conversation, state_snapshot_policy), ) + ) class StateSnapshotEventListener(EventListener): diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py index fe3350996..090a979a0 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py @@ -4,9 +4,8 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
-from contextlib import AbstractContextManager, ExitStack from dataclasses import asdict, dataclass -from typing import Any, Sequence, cast +from typing import Any, cast from pyagentspec.adapters.wayflow import AgentSpecLoader from pyagentspec.agent import Agent as AgentSpecAgent @@ -19,18 +18,23 @@ from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan from pyagentspec.tracing.spans import LlmGenerationSpan as AgentSpecLlmGenerationSpan from pyagentspec.tracing.spans import Span as AgentSpecSpan -from pyagentspec.tracing.trace import Trace as AgentSpecTrace from wayflowcore import Agent as WayflowAgent -from wayflowcore.agentspec.tracing import AgentSpecEventListener -from wayflowcore.events.eventlistener import register_event_listeners from wayflowcore.executors.executionstatus import UserMessageRequestStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.models.vllmmodel import VllmModel -from wayflowcore.serialization import deserialize_conversation, dump_conversation_state - +from wayflowcore.serialization import dump_conversation_state + +from ..testhelpers.agentspec_tracing import ( + SnapshotEventsSeenAtSpanEndRecorder, + events, + execute_with_trace, + single_span, + spans, +) from ..testhelpers.patching import patch_llm from ..testhelpers.statesnapshots import ( + restore_conversation_from_snapshot_payload, snapshot_message, snapshot_status_types, ) @@ -122,79 +126,6 @@ async def shutdown_async(self) -> None: return None -class SnapshotSpanRecorder(AgentSpecSpanProcessor): - def __init__(self) -> None: - super().__init__() - self.started_spans: list[AgentSpecSpan] = [] - self.ended_spans: list[AgentSpecSpan] = [] - - def on_start(self, span: AgentSpecSpan) -> None: - self.started_spans.append(span) - - async def on_start_async(self, span: AgentSpecSpan) -> None: - self.started_spans.append(span) - - def on_end(self, span: AgentSpecSpan) -> None: - 
self.ended_spans.append(span) - - async def on_end_async(self, span: AgentSpecSpan) -> None: - self.ended_spans.append(span) - - def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - def startup(self) -> None: - return None - - def shutdown(self) -> None: - return None - - async def startup_async(self) -> None: - return None - - async def shutdown_async(self) -> None: - return None - - -class SnapshotEventsSeenAtSpanEndRecorder(AgentSpecSpanProcessor): - def __init__(self) -> None: - super().__init__() - self.events_by_span_id: dict[str, list[AgentSpecEvent]] = {} - - def on_start(self, span: AgentSpecSpan) -> None: - return None - - async def on_start_async(self, span: AgentSpecSpan) -> None: - return None - - def on_end(self, span: AgentSpecSpan) -> None: - self.events_by_span_id[span.id] = list(span.events) - - async def on_end_async(self, span: AgentSpecSpan) -> None: - self.events_by_span_id[span.id] = list(span.events) - - def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - def startup(self) -> None: - return None - - def shutdown(self) -> None: - return None - - async def startup_async(self) -> None: - return None - - async def shutdown_async(self) -> None: - return None - - _RETRIEVAL_INPUTS = { "input": "How many orders last week?", "thread_id": "thread-123", @@ -239,49 +170,6 @@ def _build_retrieval_agent_state( ) -def _execute_with_trace( - conversation, - *, - state_snapshot_policy, - span_processors: Sequence[AgentSpecSpanProcessor] = (), - contexts: Sequence[AbstractContextManager[Any]] = (), -): - span_recorder = SnapshotSpanRecorder() - listener = AgentSpecEventListener() - - with ExitStack() as stack: - for context in contexts: - stack.enter_context(context) - 
stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, *span_processors])) - stack.enter_context(register_event_listeners([listener])) - status = conversation.execute(state_snapshot_policy=state_snapshot_policy) - - return status, span_recorder - - -def _spans( - span_recorder: SnapshotSpanRecorder, - span_type: type[AgentSpecSpan], -) -> list[AgentSpecSpan]: - return [span for span in span_recorder.started_spans if isinstance(span, span_type)] - - -def _single_span( - span_recorder: SnapshotSpanRecorder, - span_type: type[AgentSpecSpan], -) -> AgentSpecSpan: - matching_spans = _spans(span_recorder, span_type) - assert len(matching_spans) == 1 - return matching_spans[0] - - -def _events( - span: AgentSpecSpan, - event_type: type[AgentSpecEvent], -) -> list[AgentSpecEvent]: - return [event for event in span.events if isinstance(event, event_type)] - - def test_agent_state_snapshots_support_the_agui_retrieval_export_flow() -> None: assistant_message = "I checked the warehouse and found 42 orders last week." 
wayflow_agent = cast( @@ -320,7 +208,7 @@ def build_extra_state(conversation) -> dict[str, Any]: ) } - status, span_recorder = _execute_with_trace( + status, span_recorder = execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, @@ -332,17 +220,15 @@ def build_extra_state(conversation) -> dict[str, Any]: assert isinstance(status, UserMessageRequestStatus) - agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) - assert _events(agent_span, AgentSpecAgentExecutionStart) - state_snapshot_events = _events(agent_span, AgentSpecStateSnapshotEmitted) + agent_span = single_span(span_recorder, AgentSpecAgentExecutionSpan) + assert events(agent_span, AgentSpecAgentExecutionStart) + state_snapshot_events = events(agent_span, AgentSpecStateSnapshotEmitted) assert len(state_snapshot_events) == 2 final_snapshot_event = state_snapshot_events[-1] assert final_snapshot_event.state_snapshot is not None snapshot_payload = final_snapshot_event.state_snapshot - assert isinstance(snapshot_payload["conversation_state"], str) - restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) - restored_snapshot = dump_conversation_state(restored_conversation) + restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) runtime_messages = snapshot_payload["conversation"]["messages"] expected_agent_state = asdict( _build_retrieval_agent_state( @@ -353,39 +239,10 @@ def build_extra_state(conversation) -> dict[str, Any]: ) assert final_snapshot_event.conversation_id == conversation.conversation_id - assert snapshot_payload["runtime"] == "wayflow" - assert snapshot_payload["schema_version"] == 1 - assert restored_snapshot["conversation"] == snapshot_payload["conversation"] - assert ( - restored_snapshot["execution"]["current_step_name"] - == snapshot_payload["execution"]["current_step_name"] - ) - assert 
restored_snapshot["execution"]["status"] == snapshot_payload["execution"]["status"] - assert restored_snapshot["execution"]["status_handled"] is False - assert restored_snapshot["execution"]["curr_iter"] == snapshot_payload["execution"]["curr_iter"] - assert ( - restored_snapshot["execution"]["has_confirmed_conversation_exit"] - == snapshot_payload["execution"]["has_confirmed_conversation_exit"] - ) - assert ( - restored_snapshot["execution"]["tool_call_queue"] - == snapshot_payload["execution"]["tool_call_queue"] - ) - assert ( - restored_snapshot["execution"]["current_tool_request"] - == snapshot_payload["execution"]["current_tool_request"] - ) - assert ( - restored_snapshot["execution"]["current_flow_conversation"] - == snapshot_payload["execution"]["current_flow_conversation"] - ) - assert ( - restored_snapshot["execution"]["current_sub_component_conversations"] - == snapshot_payload["execution"]["current_sub_component_conversations"] - ) assert snapshot_payload["conversation"]["inputs"]["input"] == _RETRIEVAL_INPUTS["input"] assert runtime_messages[-1]["content"] == assistant_message assert final_snapshot_event.extra_state == {"agent_state": expected_agent_state} + assert dump_conversation_state(restored_conversation)["execution"]["status_handled"] is False assert len(agui_exporter.exported_snapshots) == 2 assert agui_exporter.exported_snapshots[-1] == ExportedAGUIStateSnapshot( @@ -404,7 +261,7 @@ def test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_ agent = WayflowAgent(llm=llm) conversation = agent.start_conversation() conversation.append_user_message("Hi") - status, span_recorder = _execute_with_trace( + status, span_recorder = execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( state_snapshot_interval=StateSnapshotInterval.NODE_TURNS @@ -414,8 +271,8 @@ def test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_ assert isinstance(status, UserMessageRequestStatus) - agent_span = 
_single_span(span_recorder, AgentSpecAgentExecutionSpan) - state_snapshot_events = _events(agent_span, AgentSpecStateSnapshotEmitted) + agent_span = single_span(span_recorder, AgentSpecAgentExecutionSpan) + state_snapshot_events = events(agent_span, AgentSpecStateSnapshotEmitted) assert len(state_snapshot_events) == 4 assert [event.state_snapshot["execution"]["curr_iter"] for event in state_snapshot_events] == [ @@ -432,7 +289,7 @@ def test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_ ] assert snapshot_message(state_snapshot_events[-1]) == assistant_message - llm_spans = _spans(span_recorder, AgentSpecLlmGenerationSpan) + llm_spans = spans(span_recorder, AgentSpecLlmGenerationSpan) assert llm_spans assert not any( isinstance(event, AgentSpecStateSnapshotEmitted) @@ -449,7 +306,7 @@ def test_agent_final_state_snapshot_is_visible_to_span_processors_inside_on_end( conversation.append_user_message("Hi") on_end_recorder = SnapshotEventsSeenAtSpanEndRecorder() - status, span_recorder = _execute_with_trace( + status, span_recorder = execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS @@ -460,7 +317,7 @@ def test_agent_final_state_snapshot_is_visible_to_span_processors_inside_on_end( assert isinstance(status, UserMessageRequestStatus) - agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + agent_span = single_span(span_recorder, AgentSpecAgentExecutionSpan) events_seen_at_end = on_end_recorder.events_by_span_id[agent_span.id] assert any(isinstance(event, AgentSpecAgentExecutionEnd) for event in events_seen_at_end) diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py index 075390f7e..3ab126689 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py @@ -4,18 +4,12 
@@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. -from contextlib import AbstractContextManager, ExitStack -from typing import Any, Sequence - import pytest -from pyagentspec.tracing.events import Event as AgentSpecEvent from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted -from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan from pyagentspec.tracing.spans import NodeExecutionSpan as AgentSpecNodeExecutionSpan -from pyagentspec.tracing.spans import Span as AgentSpecSpan from pyagentspec.tracing.spans import ToolExecutionSpan as AgentSpecToolExecutionSpan from pyagentspec.tracing.trace import Trace as AgentSpecTrace @@ -27,135 +21,27 @@ from wayflowcore.steps import CompleteStep, OutputMessageStep, ToolExecutionStep from wayflowcore.tools import ServerTool +from ..testhelpers.agentspec_tracing import ( + SnapshotEventsSeenAtSpanEndRecorder, + SnapshotSpanRecorder, + events, + execute_with_trace, + single_span, + spans, +) from ..testhelpers.statesnapshots import ( snapshot_message, snapshot_step_histories, ) -class SnapshotSpanRecorder(AgentSpecSpanProcessor): - def __init__(self) -> None: - super().__init__() - self.started_spans: list[AgentSpecSpan] = [] - self.ended_spans: list[AgentSpecSpan] = [] - - def on_start(self, span: AgentSpecSpan) -> None: - self.started_spans.append(span) - - async def on_start_async(self, span: AgentSpecSpan) -> None: - self.started_spans.append(span) - - def on_end(self, span: AgentSpecSpan) -> None: - self.ended_spans.append(span) - - async def on_end_async(self, span: 
AgentSpecSpan) -> None: - self.ended_spans.append(span) - - def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - def startup(self) -> None: - return None - - def shutdown(self) -> None: - return None - - async def startup_async(self) -> None: - return None - - async def shutdown_async(self) -> None: - return None - - -class SnapshotEventsSeenAtSpanEndRecorder(AgentSpecSpanProcessor): - def __init__(self) -> None: - super().__init__() - self.events_by_span_id: dict[str, list[AgentSpecEvent]] = {} - - def on_start(self, span: AgentSpecSpan) -> None: - return None - - async def on_start_async(self, span: AgentSpecSpan) -> None: - return None - - def on_end(self, span: AgentSpecSpan) -> None: - self.events_by_span_id[span.id] = list(span.events) - - async def on_end_async(self, span: AgentSpecSpan) -> None: - self.events_by_span_id[span.id] = list(span.events) - - def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - def startup(self) -> None: - return None - - def shutdown(self) -> None: - return None - - async def startup_async(self) -> None: - return None - - async def shutdown_async(self) -> None: - return None - - -def _execute_with_trace( - conversation, - *, - state_snapshot_policy, - span_processors: Sequence[AgentSpecSpanProcessor] = (), - contexts: Sequence[AbstractContextManager[Any]] = (), -): - span_recorder = SnapshotSpanRecorder() - listener = AgentSpecEventListener() - - with ExitStack() as stack: - for context in contexts: - stack.enter_context(context) - stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, *span_processors])) - stack.enter_context(register_event_listeners([listener])) - status = conversation.execute(state_snapshot_policy=state_snapshot_policy) - - 
return status, span_recorder - - -def _spans( - span_recorder: SnapshotSpanRecorder, - span_type: type[AgentSpecSpan], -) -> list[AgentSpecSpan]: - return [span for span in span_recorder.started_spans if isinstance(span, span_type)] - - -def _single_span( - span_recorder: SnapshotSpanRecorder, - span_type: type[AgentSpecSpan], -) -> AgentSpecSpan: - matching_spans = _spans(span_recorder, span_type) - assert len(matching_spans) == 1 - return matching_spans[0] - - -def _events( - span: AgentSpecSpan, - event_type: type[AgentSpecEvent], -) -> list[AgentSpecEvent]: - return [event for event in span.events if isinstance(event, event_type)] - - def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> None: flow = Flow.from_steps( [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], step_names=["single_step", "end"], ) conversation = flow.start_conversation() - status, span_recorder = _execute_with_trace( + status, span_recorder = execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, @@ -165,9 +51,9 @@ def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> assert isinstance(status, FinishedStatus) - flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) - assert _events(flow_span, AgentSpecFlowExecutionStart) - state_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) + flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) + assert events(flow_span, AgentSpecFlowExecutionStart) + state_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) assert len(state_snapshot_events) == 2 final_snapshot_event = state_snapshot_events[-1] @@ -187,7 +73,7 @@ def test_flow_final_state_snapshot_is_visible_to_span_processors_inside_on_end() conversation = flow.start_conversation() on_end_recorder = SnapshotEventsSeenAtSpanEndRecorder() - status, span_recorder = 
_execute_with_trace( + status, span_recorder = execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS @@ -197,7 +83,7 @@ def test_flow_final_state_snapshot_is_visible_to_span_processors_inside_on_end() assert isinstance(status, FinishedStatus) - flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) events_seen_at_end = on_end_recorder.events_by_span_id[flow_span.id] assert any(isinstance(event, AgentSpecFlowExecutionEnd) for event in events_seen_at_end) @@ -235,8 +121,8 @@ async def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_en assert isinstance(status, FinishedStatus) - flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) - state_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) + flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) + state_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) assert len(state_snapshot_events) == 2 final_snapshot_event = state_snapshot_events[-1] @@ -254,7 +140,7 @@ def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans( step_names=["single_step", "end"], ) conversation = flow.start_conversation() - status, span_recorder = _execute_with_trace( + status, span_recorder = execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( state_snapshot_interval=StateSnapshotInterval.NODE_TURNS @@ -263,8 +149,8 @@ def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans( assert isinstance(status, FinishedStatus) - flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) - flow_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) + flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) + flow_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) assert len(flow_snapshot_events) 
== 8 assert snapshot_step_histories(flow_snapshot_events) == [ @@ -277,7 +163,7 @@ def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans( ["__StartStep__", "single_step", "end"], ["__StartStep__", "single_step", "end"], ] - node_spans = _spans(span_recorder, AgentSpecNodeExecutionSpan) + node_spans = spans(span_recorder, AgentSpecNodeExecutionSpan) assert node_spans assert not any( isinstance(event, AgentSpecStateSnapshotEmitted) @@ -335,19 +221,19 @@ def test_internal_flow_state_snapshots_follow_conversation_ownership_for_agent_s ] ) conversation = flow.start_conversation() - status, span_recorder = _execute_with_trace( + status, span_recorder = execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy(state_snapshot_interval=interval), ) assert isinstance(status, FinishedStatus) - flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) - flow_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) + flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) + flow_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) assert snapshot_step_histories(flow_snapshot_events) == expected_step_histories - tool_spans = _spans(span_recorder, AgentSpecToolExecutionSpan) - node_spans = _spans(span_recorder, AgentSpecNodeExecutionSpan) + tool_spans = spans(span_recorder, AgentSpecToolExecutionSpan) + node_spans = spans(span_recorder, AgentSpecNodeExecutionSpan) assert tool_spans assert node_spans assert not any( @@ -363,7 +249,7 @@ def test_off_policy_does_not_bridge_state_snapshots_into_agent_spec_spans() -> N step_names=["single_step", "end"], ) conversation = flow.start_conversation() - status, span_recorder = _execute_with_trace( + status, span_recorder = execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( state_snapshot_interval=StateSnapshotInterval.OFF @@ -372,10 +258,10 @@ def test_off_policy_does_not_bridge_state_snapshots_into_agent_spec_spans() -> N 
assert isinstance(status, FinishedStatus) - flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) - assert _events(flow_span, AgentSpecFlowExecutionStart) - assert _events(flow_span, AgentSpecFlowExecutionEnd) - assert not _events(flow_span, AgentSpecStateSnapshotEmitted) + flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) + assert events(flow_span, AgentSpecFlowExecutionStart) + assert events(flow_span, AgentSpecFlowExecutionEnd) + assert not events(flow_span, AgentSpecStateSnapshotEmitted) def test_only_the_opening_state_snapshot_is_exported_when_a_turn_raises() -> None: @@ -404,8 +290,8 @@ def test_only_the_opening_state_snapshot_is_exported_when_a_turn_raises() -> Non ) ) - flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) - state_snapshot_events = _events(flow_span, AgentSpecStateSnapshotEmitted) + flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) + state_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) assert len(state_snapshot_events) == 1 assert state_snapshot_events[0].state_snapshot["execution"]["status"] is None diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py index 067602000..6f91c00ba 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py @@ -10,58 +10,19 @@ from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan from pyagentspec.tracing.spans import Span as AgentSpecSpan -from pyagentspec.tracing.trace import Trace as AgentSpecTrace -from wayflowcore.agentspec.tracing import AgentSpecEventListener -from wayflowcore.events.eventlistener import register_event_listeners from wayflowcore.executors.executionstatus import FinishedStatus from 
wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow from wayflowcore.steps import CompleteStep, FlowExecutionStep, OutputMessageStep, ParallelMapStep +from ..testhelpers.agentspec_tracing import execute_with_trace, spans from ..testhelpers.statesnapshots import ( snapshot_message, snapshot_runtime_conversation_ids, ) -class SnapshotSpanRecorder(AgentSpecSpanProcessor): - def __init__(self) -> None: - super().__init__() - self.started_spans: list[AgentSpecSpan] = [] - self.ended_spans: list[AgentSpecSpan] = [] - - def on_start(self, span: AgentSpecSpan) -> None: - self.started_spans.append(span) - - async def on_start_async(self, span: AgentSpecSpan) -> None: - self.started_spans.append(span) - - def on_end(self, span: AgentSpecSpan) -> None: - self.ended_spans.append(span) - - async def on_end_async(self, span: AgentSpecSpan) -> None: - self.ended_spans.append(span) - - def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - def startup(self) -> None: - return None - - def shutdown(self) -> None: - return None - - async def startup_async(self) -> None: - return None - - async def shutdown_async(self) -> None: - return None - - class SnapshotRuntimeIdsByConversationExporter(AgentSpecSpanProcessor): def __init__(self) -> None: super().__init__() @@ -117,21 +78,16 @@ def test_nested_flow_state_snapshots_stay_on_the_root_flow_span_for_shared_conve name="parent_flow", ) conversation = parent_flow.start_conversation() - span_recorder = SnapshotSpanRecorder() - - with AgentSpecTrace(span_processors=[span_recorder]): - with register_event_listeners([AgentSpecEventListener()]): - status = conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ) - ) + status, span_recorder = 
execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) assert isinstance(status, FinishedStatus) - flow_spans = [ - span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) - ] + flow_spans = spans(span_recorder, AgentSpecFlowExecutionSpan) assert len(flow_spans) == 2 flow_spans_by_name = { @@ -180,21 +136,16 @@ def test_nested_node_turn_state_snapshots_export_only_root_runtime_conversation_ name="parent_flow", ) conversation = parent_flow.start_conversation() - span_recorder = SnapshotSpanRecorder() - - with AgentSpecTrace(span_processors=[span_recorder]): - with register_event_listeners([AgentSpecEventListener()]): - status = conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ) - ) + status, span_recorder = execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + ) assert isinstance(status, FinishedStatus) - flow_spans = [ - span for span in span_recorder.started_spans if isinstance(span, AgentSpecFlowExecutionSpan) - ] + flow_spans = spans(span_recorder, AgentSpecFlowExecutionSpan) flow_spans_by_name = { next( event for event in span.events if isinstance(event, AgentSpecFlowExecutionStart) @@ -241,13 +192,13 @@ def test_parallel_map_snapshots_leave_agent_spec_exporters_with_root_resumable_s ) snapshot_runtime_id_exporter = SnapshotRuntimeIdsByConversationExporter() - with AgentSpecTrace(span_processors=[snapshot_runtime_id_exporter]): - with register_event_listeners([AgentSpecEventListener()]): - status = conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ) - ) + status, _ = execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + 
state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + span_processors=[snapshot_runtime_id_exporter], + ) assert isinstance(status, FinishedStatus) assert snapshot_runtime_id_exporter.runtime_ids_by_conversation_id[conversation.conversation_id] diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py index 3b05dd9c7..6d501aced 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py @@ -4,6 +4,7 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. +import json from typing import Any import pytest @@ -20,7 +21,7 @@ from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow -from wayflowcore.serialization import deserialize_conversation, dump_conversation_state +from wayflowcore.serialization import dump_conversation_state from wayflowcore.steps import CompleteStep, InputMessageStep, OutputMessageStep, ToolExecutionStep from wayflowcore.tools import ClientTool, ToolResult @@ -34,23 +35,11 @@ create_tool_flow_conversation, execute_with_state_snapshots, execute_with_state_snapshots_async, + restore_conversation_from_snapshot_payload, snapshot_status_types, ) -def _restore_conversation_from_snapshot_payload(snapshot_payload: dict[str, Any]): - assert snapshot_payload["runtime"] == "wayflow" - assert snapshot_payload["schema_version"] == 1 - assert isinstance(snapshot_payload["conversation_state"], str) - - restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) - assert dump_conversation_state(restored_conversation) == { - 
"conversation": snapshot_payload["conversation"], - "execution": snapshot_payload["execution"], - } - return restored_conversation - - class _LiveConversationSnapshotObserver(EventListener): def __init__(self, conversation) -> None: self.conversation = conversation @@ -354,7 +343,7 @@ def test_conversation_turn_snapshot_payload_can_resume_waiting_for_client_tool_r assert isinstance(status, ToolRequestStatus) assert state_snapshot_events[-1].state_snapshot is not None - restored_conversation = _restore_conversation_from_snapshot_payload( + restored_conversation = restore_conversation_from_snapshot_payload( state_snapshot_events[-1].state_snapshot ) assert isinstance(restored_conversation.status, ToolRequestStatus) @@ -376,6 +365,40 @@ def test_conversation_turn_snapshot_payload_can_resume_waiting_for_client_tool_r assert tool_result_messages[0].content == "client-result" +def test_conversation_turn_snapshot_payload_round_trips_through_json_like_run_agent_input_state() -> ( + None +): + conversation = Flow.from_steps( + [ + InputMessageStep("Please answer"), + OutputMessageStep("done"), + ], + name="snapshot_user_json_roundtrip_resume_flow", + ).start_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, UserMessageRequestStatus) + assert state_snapshot_events[-1].state_snapshot is not None + + run_agent_input_state = json.loads(json.dumps(state_snapshot_events[-1].state_snapshot)) + restored_conversation = restore_conversation_from_snapshot_payload(run_agent_input_state) + restored_conversation.append_user_message("hello") + resumed_status = restored_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + assert [message.content for message in restored_conversation.get_messages()] == [ + "Please answer", + "hello", + "done", + ] + + @pytest.mark.anyio 
async def test_conversation_turn_snapshot_payload_can_resume_waiting_for_user_input_async() -> None: conversation = Flow.from_steps( @@ -395,7 +418,7 @@ async def test_conversation_turn_snapshot_payload_can_resume_waiting_for_user_in assert isinstance(status, UserMessageRequestStatus) assert state_snapshot_events[-1].state_snapshot is not None - restored_conversation = _restore_conversation_from_snapshot_payload( + restored_conversation = restore_conversation_from_snapshot_payload( state_snapshot_events[-1].state_snapshot ) restored_conversation.append_user_message("hello") diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py index e12a3a6b4..a4f8ba902 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py @@ -4,20 +4,27 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+import json + import pytest from wayflowcore.agent import Agent +from wayflowcore.conversation import Conversation from wayflowcore.events.eventlistener import register_event_listeners from wayflowcore.executors._events.event import EventType from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy +from wayflowcore.flow import Flow +from wayflowcore.property import StringProperty +from wayflowcore.steps import CompleteStep, ToolExecutionStep +from wayflowcore.tools import ServerTool from ..conftest import disable_streaming from ..test_interrupts import OnEventExecutionInterrupt -from ..testhelpers.dummy import DummyModel from ..testhelpers.statesnapshots import ( SnapshotCollector, + SnapshotSerializableDummyModel, create_agent_conversation, create_output_flow_conversation, create_tool_calling_agent_conversation, @@ -29,10 +36,31 @@ ) -class _SerializableDummyModel(DummyModel): - @property - def config(self) -> dict[str, object]: - return {"model_id": self.model_id} +def _make_snapshot_size_stress_conversation() -> Conversation: + repeated_description = "serialized tool description " + ("D" * 1000) + tool_steps = [ + ToolExecutionStep( + tool=ServerTool( + name=f"tool_{index}", + description=f"{repeated_description}-{index}", + func=lambda index=index: f"value-{index}", + input_descriptors=[], + output_descriptors=[StringProperty(name=f"out_{index}")], + ), + name=f"step_{index}", + ) + for index in range(8) + ] + + return Flow.from_steps( + steps=[*tool_steps, CompleteStep(name="end")], + step_names=[*(f"step_{index}" for index in range(8)), "end"], + name="state_snapshot_size_stress_flow", + ).start_conversation() + + +def _snapshot_payload_num_bytes(snapshot_payload: dict[str, object]) -> int: + return len(json.dumps(snapshot_payload, 
sort_keys=True)) @pytest.mark.parametrize( @@ -190,8 +218,67 @@ def test_flow_node_turn_policy_uses_iteration_start_and_end_boundaries() -> None ] +def test_node_turn_policy_keeps_only_root_turn_checkpoints_resumable() -> None: + conversation = create_output_flow_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert [ + isinstance(snapshot_event.state_snapshot.get("conversation_state"), str) + for snapshot_event in state_snapshot_events + ] == [True, False, False, False, False, False, False, True] + + +def test_node_turn_policy_stays_lightweight_under_snapshot_size_stress() -> None: + conversation = _make_snapshot_size_stress_conversation() + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + + snapshot_payloads = [snapshot_event.state_snapshot for snapshot_event in state_snapshot_events] + assert all(payload is not None for payload in snapshot_payloads) + snapshot_payloads = [payload for payload in snapshot_payloads if payload is not None] + internal_snapshot_payloads = snapshot_payloads[1:-1] + assert internal_snapshot_payloads + assert all("conversation_state" not in payload for payload in internal_snapshot_payloads) + + largest_root_conversation_state = max( + ( + payload["conversation_state"] + for payload in snapshot_payloads + if isinstance(payload.get("conversation_state"), str) + ), + key=len, + ) + actual_total_bytes = sum(_snapshot_payload_num_bytes(payload) for payload in snapshot_payloads) + inflated_total_bytes = sum( + ( + _snapshot_payload_num_bytes(payload) + if "conversation_state" in payload + else _snapshot_payload_num_bytes( + {**payload, 
"conversation_state": largest_root_conversation_state} + ) + ) + for payload in snapshot_payloads + ) + + assert actual_total_bytes < inflated_total_bytes * 0.2 + + def test_internal_snapshots_do_not_reuse_the_previous_turn_status() -> None: - llm = _SerializableDummyModel() + llm = SnapshotSerializableDummyModel() llm.set_next_output(["Hello from agent", "Hello again"]) conversation = Agent(llm=llm).start_conversation() conversation.append_user_message("Hi") @@ -265,6 +352,23 @@ def test_tool_turn_policy_records_real_tool_boundaries( assert snapshot_status_types(state_snapshot_events) == expected_status_types +def test_tool_turn_policy_keeps_only_root_turn_checkpoints_resumable() -> None: + conversation = create_tool_flow_conversation(lambda: "hi") + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.TOOL_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert [ + isinstance(snapshot_event.state_snapshot.get("conversation_state"), str) + for snapshot_event in state_snapshot_events + ] == [True, False, False, True] + + @pytest.mark.anyio @pytest.mark.parametrize( ( diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py b/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py index e0a5dde95..657aafada 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py @@ -9,7 +9,6 @@ from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow -from wayflowcore.serialization import deserialize_conversation from wayflowcore.steps import ( CompleteStep, FlowExecutionStep, @@ -23,6 +22,7 @@ create_parallel_child_flow, execute_with_state_snapshots, 
execute_with_state_snapshots_async, + restore_conversation_from_snapshot_payload, snapshot_message, snapshot_runtime_conversation_ids, snapshot_status_types, @@ -128,7 +128,7 @@ def test_nested_root_turn_snapshot_payload_can_resume_the_logical_parent_convers assert root_turn_snapshot_event.conversation_id == conversation.conversation_id assert root_turn_snapshot["conversation"]["id"] == conversation.id - restored_conversation = deserialize_conversation(root_turn_snapshot["conversation_state"]) + restored_conversation = restore_conversation_from_snapshot_payload(root_turn_snapshot) assert restored_conversation.id == conversation.id assert restored_conversation.conversation_id == conversation.conversation_id @@ -219,7 +219,7 @@ def test_parallel_root_turn_snapshot_payloads_can_resume_the_logical_parent_conv snapshot_payload = snapshot_event.state_snapshot assert snapshot_payload is not None - restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) + restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) resumed_status = restored_conversation.execute() assert isinstance(resumed_status, FinishedStatus) @@ -263,7 +263,7 @@ def test_parallel_map_emits_only_resumable_parent_turn_snapshots() -> None: snapshot_payload = snapshot_event.state_snapshot assert snapshot_payload is not None - restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) + restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) resumed_status = restored_conversation.execute() assert isinstance(resumed_status, FinishedStatus) diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py b/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py index da31813cd..71f05ffeb 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py @@ -6,7 +6,12 @@ from dataclasses 
import dataclass +import pytest + from wayflowcore.conversation import Conversation +from wayflowcore.events import Event, EventListener +from wayflowcore.events.event import StateSnapshotEvent +from wayflowcore.events.eventlistener import register_event_listeners from wayflowcore.executors.executionstatus import FinishedStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow @@ -77,3 +82,28 @@ def test_state_snapshot_emission_survives_unserializable_variable_state() -> Non assert state_snapshot_events[0].variable_state == {"custom": None} assert state_snapshot_events[-1].variable_state is None assert snapshot_message(state_snapshot_events[-1]) == "done" + + +class _FailOnTerminalSnapshot(EventListener): + def __call__(self, event: Event) -> None: + if not isinstance(event, StateSnapshotEvent): + return + + execution_status = (event.state_snapshot or {}).get("execution", {}).get("status") + if execution_status is not None: + raise RuntimeError("snapshot sink failed") + + +def test_state_snapshot_listener_failures_propagate_to_the_caller() -> None: + conversation = create_output_flow_conversation() + + with register_event_listeners([_FailOnTerminalSnapshot()]): + with pytest.raises(RuntimeError, match="snapshot sink failed"): + conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) + ) + + assert conversation.get_last_message() is not None + assert conversation.get_last_message().content == "Hello" diff --git a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py index 1dd89c66a..a92575df0 100644 --- a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py +++ b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py @@ -41,7 +41,10 @@ from wayflowcore.tools import ClientTool, ServerTool, 
ToolResult, register_server_tool from wayflowcore.variable import Variable -from ..testhelpers.statesnapshots import execute_with_state_snapshots +from ..testhelpers.statesnapshots import ( + execute_with_state_snapshots, + restore_conversation_from_snapshot_payload, +) class _UnserializableValue: @@ -364,8 +367,7 @@ def test_emitted_snapshot_conversation_state_restores_variable_dependent_continu assert isinstance(status, UserMessageRequestStatus) assert state_snapshot_events[-1].state_snapshot is not None snapshot_payload = state_snapshot_events[-1].state_snapshot - assert isinstance(snapshot_payload["conversation_state"], str) - restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) + restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) assert dump_variable_state(restored_conversation) == {"customer_name": "Alice"} diff --git a/wayflowcore/tests/testhelpers/agentspec_tracing.py b/wayflowcore/tests/testhelpers/agentspec_tracing.py new file mode 100644 index 000000000..7ecd00e0f --- /dev/null +++ b/wayflowcore/tests/testhelpers/agentspec_tracing.py @@ -0,0 +1,132 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Sequence + +from pyagentspec.tracing.events import Event as AgentSpecEvent +from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor +from pyagentspec.tracing.spans import Span as AgentSpecSpan +from pyagentspec.tracing.trace import Trace as AgentSpecTrace + +from wayflowcore.agentspec.tracing import AgentSpecEventListener +from wayflowcore.events.eventlistener import register_event_listeners + + +class SnapshotSpanRecorder(AgentSpecSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.started_spans: list[AgentSpecSpan] = [] + self.ended_spans: list[AgentSpecSpan] = [] + + def on_start(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + async def on_start_async(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + def on_end(self, span: AgentSpecSpan) -> None: + self.ended_spans.append(span) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + self.ended_spans.append(span) + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + +class SnapshotEventsSeenAtSpanEndRecorder(AgentSpecSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.events_by_span_id: dict[str, list[AgentSpecEvent]] = {} + + def on_start(self, span: AgentSpecSpan) -> None: + return None + + async def on_start_async(self, span: AgentSpecSpan) -> None: + return None + + def on_end(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + 
self.events_by_span_id[span.id] = list(span.events) + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + +def execute_with_trace( + conversation, + *, + state_snapshot_policy, + span_processors: Sequence[AgentSpecSpanProcessor] = (), + contexts: Sequence[AbstractContextManager[Any]] = (), +): + span_recorder = SnapshotSpanRecorder() + listener = AgentSpecEventListener() + + with ExitStack() as stack: + for context in contexts: + stack.enter_context(context) + stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, *span_processors])) + stack.enter_context(register_event_listeners([listener])) + status = conversation.execute(state_snapshot_policy=state_snapshot_policy) + + return status, span_recorder + + +def spans( + span_recorder: SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> list[AgentSpecSpan]: + return [span for span in span_recorder.started_spans if isinstance(span, span_type)] + + +def single_span( + span_recorder: SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> AgentSpecSpan: + matching_spans = spans(span_recorder, span_type) + assert len(matching_spans) == 1 + return matching_spans[0] + + +def events( + span: AgentSpecSpan, + event_type: type[AgentSpecEvent], +) -> list[AgentSpecEvent]: + return [event for event in span.events if isinstance(event, event_type)] diff --git a/wayflowcore/tests/testhelpers/statesnapshots.py b/wayflowcore/tests/testhelpers/statesnapshots.py index f9573136e..3ae28040a 100644 --- a/wayflowcore/tests/testhelpers/statesnapshots.py +++ b/wayflowcore/tests/testhelpers/statesnapshots.py @@ -22,6 +22,7 @@ from wayflowcore.flow import Flow from 
wayflowcore.messagelist import Message, MessageType from wayflowcore.property import AnyProperty, StringProperty +from wayflowcore.serialization import deserialize_conversation, dump_conversation_state from wayflowcore.serialization.serializer import SerializableNeedToBeImplementedMixin from wayflowcore.steps import ( AgentExecutionStep, @@ -45,7 +46,7 @@ def __call__(self, event: Event) -> None: self.state_snapshot_events.append(event) -class _SnapshotSerializableDummyModel(DummyModel): +class SnapshotSerializableDummyModel(DummyModel): @property def config(self) -> dict[str, Any]: return {"model_id": self.model_id} @@ -167,7 +168,7 @@ def create_output_flow_conversation(message: str = "Hello") -> Conversation: def create_agent_conversation(message: str = "Hello from agent") -> Conversation: - llm = _SnapshotSerializableDummyModel() + llm = SnapshotSerializableDummyModel() llm.set_next_output(message) conversation = Agent(llm=llm).start_conversation() conversation.append_user_message("Hi") @@ -180,7 +181,7 @@ def do_nothing_tool() -> str: """Do nothing tool.""" return "Tool called successfully" - llm = _SnapshotSerializableDummyModel() + llm = SnapshotSerializableDummyModel() llm.set_next_output( { "Please use the do_nothing_tool": Message( @@ -196,7 +197,7 @@ def do_nothing_tool() -> str: def create_nested_agent_step_flow_conversation() -> Conversation: - llm = _SnapshotSerializableDummyModel() + llm = SnapshotSerializableDummyModel() llm.set_next_output("agent answer") child_agent = Agent(llm=llm) conversation = Flow.from_steps( @@ -247,3 +248,18 @@ def assert_terminal_snapshot( assert snapshot_status_types(snapshot_events)[-1] == expected_status_type assert snapshot_message(snapshot_events[-1]) == expected_message assert snapshot_events[-1].state_snapshot["execution"]["status_handled"] is False + + +def restore_conversation_from_snapshot_payload( + snapshot_payload: dict[str, Any], +) -> Conversation: + assert snapshot_payload["runtime"] == "wayflow" + assert 
snapshot_payload["schema_version"] == 1 + assert isinstance(snapshot_payload["conversation_state"], str) + + restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) + assert dump_conversation_state(restored_conversation) == { + "conversation": snapshot_payload["conversation"], + "execution": snapshot_payload["execution"], + } + return restored_conversation From a85c333d9a517c9041cdb80727d648bcf6d26fcc Mon Sep 17 00:00:00 2001 From: Son Le Date: Fri, 20 Mar 2026 09:55:54 +0100 Subject: [PATCH 12/13] Tighten state snapshot JSON contract --- .../executors/_statesnapshot_eventlistener.py | 43 +++--------- .../wayflowcore/serialization/conversation.py | 11 ++-- .../test_state_snapshot_tracing_flow.py | 43 ++++++++++++ ...ate_snapshot_runtime_conversation_turns.py | 4 +- .../test_state_snapshot_runtime_resilience.py | 66 ++++++++++--------- .../test_conversation_state_snapshot.py | 41 +++++++++++- 6 files changed, 137 insertions(+), 71 deletions(-) diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py index dd61ee228..ea1a7fb49 100644 --- a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -7,7 +7,6 @@ from __future__ import annotations import json -import logging from contextlib import contextmanager from contextvars import ContextVar from typing import Any, Dict, Iterator, Optional, cast @@ -46,8 +45,6 @@ ) from wayflowcore.tracing.span import AgentExecutionSpan, FlowExecutionSpan, get_current_span -logger = logging.getLogger(__name__) - _STATE_SNAPSHOT_RUNTIME = "wayflow" _STATE_SNAPSHOT_SCHEMA_VERSION = 1 _STATE_SNAPSHOT_INTERVALS_BY_POLICY = { @@ -163,34 +160,20 @@ def _build_extra_state( if state_snapshot_policy.extra_state_builder is None: return None - try: - extra_state = state_snapshot_policy.extra_state_builder(conversation) - 
except Exception: - logger.warning( - "Failed to build extra snapshot state for conversation '%s'", - conversation.conversation_id, - exc_info=True, - ) - return None - + extra_state = state_snapshot_policy.extra_state_builder(conversation) if extra_state is None: return None if not isinstance(extra_state, dict): - logger.warning( - "Expected extra snapshot state to be a dictionary for conversation '%s'", - conversation.conversation_id, + raise TypeError( + f"Expected extra snapshot state for conversation '{conversation.conversation_id}' to be a dictionary" ) - return None try: - return cast(Dict[str, Any], json.loads(json.dumps(extra_state))) - except Exception: - logger.warning( - "Extra snapshot state is not JSON serializable for conversation '%s'", - conversation.conversation_id, - exc_info=True, - ) - return None + return cast(Dict[str, Any], json.loads(json.dumps(extra_state, allow_nan=False))) + except (TypeError, ValueError) as exc: + raise TypeError( + f"Extra snapshot state for conversation '{conversation.conversation_id}' must be strict JSON-serializable" + ) from exc def _get_snapshot_policy_for_interval( @@ -217,15 +200,7 @@ def _build_variable_state( if not state_snapshot_policy.include_variable_state: return None - try: - return dump_variable_state(conversation) - except Exception: - logger.warning( - "Failed to dump variable state for conversation '%s'", - conversation.conversation_id, - exc_info=True, - ) - return None + return dump_variable_state(conversation) def _build_state_snapshot_payload( diff --git a/wayflowcore/src/wayflowcore/serialization/conversation.py b/wayflowcore/src/wayflowcore/serialization/conversation.py index feafa8db2..5756b3aa4 100644 --- a/wayflowcore/src/wayflowcore/serialization/conversation.py +++ b/wayflowcore/src/wayflowcore/serialization/conversation.py @@ -7,6 +7,7 @@ from __future__ import annotations import json +import math import warnings from contextlib import contextmanager from datetime import datetime @@ -65,8 
+66,10 @@ def _dump_json_compatible_value(value: Any) -> Any: from wayflowcore.conversation import Conversation dumped_value: Any - if value is None or isinstance(value, (bool, int, float, str)): + if value is None or isinstance(value, (bool, int, str)): dumped_value = value + elif isinstance(value, float): + dumped_value = value if math.isfinite(value) else stringify(value) elif isinstance(value, bytes): dumped_value = value.decode("utf-8", errors="replace") elif isinstance(value, datetime): @@ -296,7 +299,7 @@ def dump_conversation_state( status_handled: object = _UNSET, ) -> dict[str, Any]: """ - Return a JSON-serializable runtime snapshot of a conversation. + Return a strict-JSON-serializable runtime snapshot of a conversation. The returned dictionary is a lightweight inspection/tracing view. It captures the user-visible conversation state and the runtime execution state without @@ -530,7 +533,7 @@ def deserialize_conversation( def dump_variable_state(conversation: "Conversation") -> Optional[dict[str, Any]]: """ - Return the JSON-serializable runtime-owned variable state for a conversation. + Return the strict-JSON-serializable runtime-owned variable state for a conversation. Only flow conversations expose runtime variable storage. For other conversation types, this returns ``None``. @@ -543,7 +546,7 @@ def dump_variable_state(conversation: "Conversation") -> Optional[dict[str, Any] Returns ------- dict[str, Any] | None - JSON-compatible mapping of variable names to values for flow + Strict-JSON-compatible mapping of variable names to values for flow conversations, otherwise ``None``. 
Raises diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py index 3ab126689..4a3a8d05a 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py @@ -4,6 +4,8 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. +import json + import pytest from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart @@ -18,6 +20,7 @@ from wayflowcore.executors.executionstatus import FinishedStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow +from wayflowcore.property import AnyProperty from wayflowcore.steps import CompleteStep, OutputMessageStep, ToolExecutionStep from wayflowcore.tools import ServerTool @@ -101,6 +104,46 @@ def test_flow_final_state_snapshot_is_visible_to_span_processors_inside_on_end() assert snapshot_message(events_seen_at_end[-1]) == "Hello" +def test_flow_state_snapshots_normalize_non_finite_floats_before_agent_spec_export() -> None: + flow = Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name="echo", + description="Echo input", + func=lambda bad: str(bad), + input_descriptors=[AnyProperty(name="bad")], + ) + ), + CompleteStep(name="end"), + ] + ) + conversation = flow.start_conversation(inputs={"bad": float("nan")}) + + status, span_recorder = execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + include_variable_state=False, + ), + ) + + assert isinstance(status, FinishedStatus) + + flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) 
+ state_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) + + assert len(state_snapshot_events) == 2 + assert all( + event.state_snapshot["conversation"]["inputs"]["bad"] == "NaN" + for event in state_snapshot_events + ) + assert all( + json.loads(json.dumps(event.state_snapshot, allow_nan=False)) == event.state_snapshot + for event in state_snapshot_events + ) + + @pytest.mark.anyio async def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end_async() -> None: flow = Flow.from_steps( diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py index 6d501aced..8125ea735 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py @@ -386,7 +386,9 @@ def test_conversation_turn_snapshot_payload_round_trips_through_json_like_run_ag assert isinstance(status, UserMessageRequestStatus) assert state_snapshot_events[-1].state_snapshot is not None - run_agent_input_state = json.loads(json.dumps(state_snapshot_events[-1].state_snapshot)) + run_agent_input_state = json.loads( + json.dumps(state_snapshot_events[-1].state_snapshot, allow_nan=False) + ) restored_conversation = restore_conversation_from_snapshot_payload(run_agent_input_state) restored_conversation.append_user_message("hello") resumed_status = restored_conversation.execute() diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py b/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py index 71f05ffeb..1a1e5b19c 100644 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py @@ -12,7 +12,6 @@ from wayflowcore.events import Event, EventListener from wayflowcore.events.event import StateSnapshotEvent from wayflowcore.events.eventlistener import 
register_event_listeners -from wayflowcore.executors.executionstatus import FinishedStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow from wayflowcore.property import AnyProperty @@ -20,11 +19,7 @@ from wayflowcore.steps import CompleteStep, OutputMessageStep, VariableWriteStep from wayflowcore.variable import Variable -from ..testhelpers.statesnapshots import ( - create_output_flow_conversation, - execute_with_state_snapshots, - snapshot_message, -) +from ..testhelpers.statesnapshots import create_output_flow_conversation @dataclass(frozen=True) @@ -47,41 +42,50 @@ def _create_non_json_variable_state_conversation() -> Conversation: ).start_conversation(inputs={custom_variable.name: _SerializableButNotJson(value="x")}) -def test_state_snapshot_emission_survives_broken_extra_state_builder() -> None: +def test_state_snapshot_emission_propagates_extra_state_builder_failures() -> None: def broken_builder(_conversation: Conversation) -> dict[str, object]: raise RuntimeError("boom") conversation = create_output_flow_conversation() - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - extra_state_builder=broken_builder, - ), - ) + with pytest.raises(RuntimeError, match="boom"): + conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=broken_builder, + ) + ) + + assert conversation.get_last_message() is None + + +def test_state_snapshot_emission_rejects_non_strict_json_extra_state() -> None: + conversation = create_output_flow_conversation() - assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 2 - assert all(snapshot_event.extra_state is None for snapshot_event in state_snapshot_events) + with pytest.raises(TypeError, 
match="Extra snapshot state .* strict JSON-serializable"): + conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=lambda _conversation: {"ui": {"preview_count": float("nan")}}, + ) + ) + + assert conversation.get_last_message() is None -def test_state_snapshot_emission_survives_unserializable_variable_state() -> None: +def test_state_snapshot_emission_propagates_unserializable_variable_state() -> None: conversation = _create_non_json_variable_state_conversation() - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - include_variable_state=True, - ), - ) - - assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 2 - assert state_snapshot_events[0].variable_state == {"custom": None} - assert state_snapshot_events[-1].variable_state is None - assert snapshot_message(state_snapshot_events[-1]) == "done" + with pytest.raises(TypeError, match="Variable 'custom' contains a non-JSON-serializable"): + conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + include_variable_state=True, + ) + ) + + assert conversation.get_last_message() is not None + assert conversation.get_last_message().content == "done" class _FailOnTerminalSnapshot(EventListener): diff --git a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py index a92575df0..5fb86666d 100644 --- a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py +++ b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py @@ -66,6 +66,24 @@ def _build_snapshot_flow(custom_variable: Variable) -> Flow: ) +def _build_non_finite_input_snapshot_flow() -> Flow: + return 
Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name="echo", + description="Echo input", + func=lambda bad: str(bad), + input_descriptors=[AnyProperty(name="bad")], + output_descriptors=[StringProperty(name="out")], + ) + ), + CompleteStep(name="end"), + ], + name="non_finite_snapshot_flow", + ) + + def _walk_scalars(value: Any): if isinstance(value, dict): for inner_value in value.values(): @@ -93,7 +111,7 @@ def test_dump_conversation_state_is_json_serializable_and_lightweight() -> None: serialized_conversation_state = serialize_conversation_state(conversation) deserialized_conversation_state = deserialize_conversation_state(serialized_conversation_state) - assert json.loads(json.dumps(snapshot)) == snapshot + assert json.loads(json.dumps(snapshot, allow_nan=False)) == snapshot assert deserialized_conversation_state["_component_type"] == conversation.__class__.__name__ assert variable_state == {"custom": "custom-value"} assert snapshot["conversation"]["component_type"] == "Flow" @@ -176,6 +194,27 @@ def test_dump_variable_state_rejects_non_json_serializable_values() -> None: dump_variable_state(conversation) +@pytest.mark.parametrize( + ("value", "expected_dumped_value"), + [ + pytest.param(float("nan"), "NaN", id="nan"), + pytest.param(float("inf"), "Infinity", id="infinity"), + pytest.param(float("-inf"), "-Infinity", id="negative-infinity"), + ], +) +def test_dump_conversation_state_normalizes_non_finite_floats_for_strict_json( + value: float, + expected_dumped_value: str, +) -> None: + flow = _build_non_finite_input_snapshot_flow() + conversation = flow.start_conversation(inputs={"bad": value}) + + snapshot = dump_conversation_state(conversation) + + assert json.loads(json.dumps(snapshot, allow_nan=False)) == snapshot + assert snapshot["conversation"]["inputs"]["bad"] == expected_dumped_value + + def test_conversation_state_roundtrip_preserves_pending_tool_results() -> None: client_tool = ClientTool( name="client_lookup", From 
42bcdd66646ef8c8aa091d1f5830924c7d45c706 Mon Sep 17 00:00:00 2001 From: Son Le Date: Mon, 23 Mar 2026 10:59:36 +0100 Subject: [PATCH 13/13] Simplify state snapshot test infrastructure and coverage --- docs/wayflowcore/source/core/changelog.rst | 4 +- .../test_state_snapshot_tracing_agent.py | 377 +++++------- .../test_state_snapshot_tracing_flow.py | 453 ++++++++------- .../test_state_snapshot_tracing_nested.py | 209 ------- .../test_state_snapshot_event_tracing.py | 27 + .../test_state_snapshot_event_validation.py | 36 -- ...ate_snapshot_runtime_conversation_turns.py | 434 -------------- .../test_state_snapshot_runtime_core.py | 535 ++++++++++++++++++ ...t_state_snapshot_runtime_internal_turns.py | 490 ---------------- .../test_state_snapshot_runtime_nested.py | 331 ----------- .../test_state_snapshot_runtime_resilience.py | 113 ---- .../test_conversation_state_snapshot.py | 285 +++++----- .../tests/testhelpers/agentspec_tracing.py | 132 ----- .../testhelpers/state_snapshot_testutils.py | 84 +++ .../tests/testhelpers/statesnapshots.py | 265 --------- 15 files changed, 1140 insertions(+), 2635 deletions(-) delete mode 100644 wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py delete mode 100644 wayflowcore/tests/events/test_state_snapshot_event_validation.py delete mode 100644 wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py create mode 100644 wayflowcore/tests/events/test_state_snapshot_runtime_core.py delete mode 100644 wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py delete mode 100644 wayflowcore/tests/events/test_state_snapshot_runtime_nested.py delete mode 100644 wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py delete mode 100644 wayflowcore/tests/testhelpers/agentspec_tracing.py create mode 100644 wayflowcore/tests/testhelpers/state_snapshot_testutils.py delete mode 100644 wayflowcore/tests/testhelpers/statesnapshots.py diff --git a/docs/wayflowcore/source/core/changelog.rst 
b/docs/wayflowcore/source/core/changelog.rst index 76b10992b..cd9cc1487 100644 --- a/docs/wayflowcore/source/core/changelog.rst +++ b/docs/wayflowcore/source/core/changelog.rst @@ -68,9 +68,9 @@ Possibly Breaking Changes Bug fixes ^^^^^^^^^ -* **Serialization imports:** +* **State snapshot test coverage:** - Reduced import coupling in conversation state serialization so the public conversation state helpers remain directly re-exported from ``wayflowcore.serialization``. + Reduced duplicated flow/agent, sync/async, nested-conversation, internal-turn, serialization, and Agent Spec tracing wrappers in the state snapshot tests, and moved the reusable state snapshot fixtures plus explicit tracing/runtime helpers into shared test helper modules/plugins. WayFlow 26.1.1 -------------- diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py index 090a979a0..6ef34b9dd 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py @@ -4,12 +4,9 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
-from dataclasses import asdict, dataclass -from typing import Any, cast +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Sequence -from pyagentspec.adapters.wayflow import AgentSpecLoader -from pyagentspec.agent import Agent as AgentSpecAgent -from pyagentspec.llms import VllmConfig from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart from pyagentspec.tracing.events import Event as AgentSpecEvent @@ -18,70 +15,19 @@ from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan from pyagentspec.tracing.spans import LlmGenerationSpan as AgentSpecLlmGenerationSpan from pyagentspec.tracing.spans import Span as AgentSpecSpan +from pyagentspec.tracing.trace import Trace as AgentSpecTrace from wayflowcore import Agent as WayflowAgent +from wayflowcore.agentspec.tracing import AgentSpecEventListener +from wayflowcore.events.eventlistener import register_event_listeners from wayflowcore.executors.executionstatus import UserMessageRequestStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.models.vllmmodel import VllmModel -from wayflowcore.serialization import dump_conversation_state - -from ..testhelpers.agentspec_tracing import ( - SnapshotEventsSeenAtSpanEndRecorder, - events, - execute_with_trace, - single_span, - spans, -) -from ..testhelpers.patching import patch_llm -from ..testhelpers.statesnapshots import ( - restore_conversation_from_snapshot_payload, - snapshot_message, - snapshot_status_types, -) - - -@dataclass(frozen=True) -class ExportedAGUIStateSnapshot: - conversation_id: str - snapshot: dict[str, Any] - - -@dataclass(frozen=True) -class RetrievalPreplan: - summary: str - entries: list[str] - ready_to_proceed: bool - - -@dataclass(frozen=True) -class RetrievalAssumption: - text: str - status: str - - 
-@dataclass(frozen=True) -class RetrievalUIState: - preplan: RetrievalPreplan - assumptions: list[RetrievalAssumption] - -@dataclass(frozen=True) -class RetrievalAgentState: - thread_id: str - agent_type: str - llm_model_name: str - default_schema: str - input_document: str - message_count: int - last_response: str - ui: RetrievalUIState +from ..testhelpers.patching import patch_llm -class AGUIStateSnapshotExporter(AgentSpecSpanProcessor): - def __init__(self) -> None: - super().__init__() - self.exported_snapshots: list[ExportedAGUIStateSnapshot] = [] - +class _PassiveSpanProcessor(AgentSpecSpanProcessor): def on_start(self, span: AgentSpecSpan) -> None: return None @@ -95,23 +41,10 @@ async def on_end_async(self, span: AgentSpecSpan) -> None: return None def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - if not isinstance(event, AgentSpecStateSnapshotEmitted): - return - - conversation_snapshot = (event.state_snapshot or {}).get("conversation", {}) - self.exported_snapshots.append( - ExportedAGUIStateSnapshot( - conversation_id=event.conversation_id, - snapshot={ - "messages": conversation_snapshot.get("messages", []), - "input": conversation_snapshot.get("inputs", {}).get("input"), - "agent_state": (event.extra_state or {}).get("agent_state", {}), - }, - ) - ) + return None async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - self.on_event(event, span) + return None def startup(self) -> None: return None @@ -126,142 +59,112 @@ async def shutdown_async(self) -> None: return None -_RETRIEVAL_INPUTS = { - "input": "How many orders last week?", - "thread_id": "thread-123", - "agent_type": "planner", - "llm_model_name": "gpt-5-mini", - "default_schema": "sales", - "input_document": "Only use the sales schema and weekly order metrics.", -} - -_RETRIEVAL_UI_STATE = RetrievalUIState( - preplan=RetrievalPreplan( - summary="Inspect weekly sales orders and answer concisely.", - entries=[ - "Inspect the active schema", 
- "Aggregate last week's orders", - "Return the final answer", - ], - ready_to_proceed=True, - ), - assumptions=[ - RetrievalAssumption(text="Use the sales schema only", status="approved"), - RetrievalAssumption(text="Week boundaries follow UTC", status="auto_approved"), - ], -) - - -def _build_retrieval_agent_state( +class _SnapshotSpanRecorder(_PassiveSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.started_spans: list[AgentSpecSpan] = [] + + def on_start(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + async def on_start_async(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + +class _EventsSeenAtSpanEndRecorder(_PassiveSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.events_by_span_id: dict[str, list[AgentSpecEvent]] = {} + + def on_end(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + +def _recorded_spans( + span_recorder: _SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> list[AgentSpecSpan]: + return [span for span in span_recorder.started_spans if isinstance(span, span_type)] + + +def _single_span( + span_recorder: _SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> AgentSpecSpan: + matching_spans = _recorded_spans(span_recorder, span_type) + assert len(matching_spans) == 1 + return matching_spans[0] + + +def _span_events( + span: AgentSpecSpan, + event_type: type[AgentSpecEvent], +) -> list[AgentSpecEvent]: + return [event for event in span.events if isinstance(event, event_type)] + + +def _execute_with_trace( + conversation, *, - conversation_inputs: dict[str, Any], - message_count: int, - last_response: str, -) -> RetrievalAgentState: - return RetrievalAgentState( - thread_id=conversation_inputs["thread_id"], - agent_type=conversation_inputs["agent_type"], - 
llm_model_name=conversation_inputs["llm_model_name"], - default_schema=conversation_inputs["default_schema"], - input_document=conversation_inputs["input_document"], - message_count=message_count, - last_response=last_response, - ui=_RETRIEVAL_UI_STATE, - ) + state_snapshot_policy, + span_processors: Sequence[AgentSpecSpanProcessor] = (), + contexts: Sequence[AbstractContextManager[Any]] = (), +) -> tuple[Any, _SnapshotSpanRecorder]: + span_recorder = _SnapshotSpanRecorder() + listener = AgentSpecEventListener() + with ExitStack() as stack: + for context in contexts: + stack.enter_context(context) + stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, *span_processors])) + stack.enter_context(register_event_listeners([listener])) + status = conversation.execute(state_snapshot_policy=state_snapshot_policy) -def test_agent_state_snapshots_support_the_agui_retrieval_export_flow() -> None: - assistant_message = "I checked the warehouse and found 42 orders last week." - wayflow_agent = cast( - WayflowAgent, - AgentSpecLoader().load_component( - AgentSpecAgent( - name="retrieval_agent", - llm_config=VllmConfig(name="llm", url="http://mock.url", model_id="mock.model"), - system_prompt="You are a helpful retrieval agent.", - ) - ), - ) - conversation = wayflow_agent.start_conversation(inputs=_RETRIEVAL_INPUTS) - conversation.append_user_message(_RETRIEVAL_INPUTS["input"]) - - agui_exporter = AGUIStateSnapshotExporter() - - def build_extra_state(conversation) -> dict[str, Any]: - conversation_snapshot = dump_conversation_state(conversation)["conversation"] - messages = conversation_snapshot["messages"] - last_response = next( - ( - message.get("content") - for message in reversed(messages) - if message.get("role") == "assistant" and message.get("content") - ), - "", - ) - return { - "agent_state": asdict( - _build_retrieval_agent_state( - conversation_inputs=conversation.inputs, - message_count=len(messages), - last_response=last_response, - ) - ) - } - - 
status, span_recorder = execute_with_trace( + return status, span_recorder + + +def _make_single_turn_agent_conversation() -> tuple[str, VllmModel, Any]: + assistant_message = "Hello from agent" + llm = VllmModel(model_id="mock.model", host_port="http://mock.url", name="agent") + conversation = WayflowAgent(llm=llm).start_conversation() + conversation.append_user_message("Hi") + return assistant_message, llm, conversation + + +def _snapshot_message(snapshot_event: AgentSpecStateSnapshotEmitted) -> str | None: + messages = (snapshot_event.state_snapshot or {}).get("conversation", {}).get("messages", []) + if not messages: + return None + return messages[-1].get("content") + + +def test_conversation_turn_snapshots_attach_to_the_agent_span() -> None: + assistant_message, llm, conversation = _make_single_turn_agent_conversation() + status, span_recorder = _execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - extra_state_builder=build_extra_state, + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS ), - span_processors=[agui_exporter], - contexts=[patch_llm(wayflow_agent.llm, [assistant_message], patch_internal=True)], + contexts=[patch_llm(llm, [assistant_message], patch_internal=True)], ) assert isinstance(status, UserMessageRequestStatus) - agent_span = single_span(span_recorder, AgentSpecAgentExecutionSpan) - assert events(agent_span, AgentSpecAgentExecutionStart) - state_snapshot_events = events(agent_span, AgentSpecStateSnapshotEmitted) - assert len(state_snapshot_events) == 2 - - final_snapshot_event = state_snapshot_events[-1] - assert final_snapshot_event.state_snapshot is not None - snapshot_payload = final_snapshot_event.state_snapshot - restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) - runtime_messages = snapshot_payload["conversation"]["messages"] - expected_agent_state = asdict( - _build_retrieval_agent_state( - 
conversation_inputs=_RETRIEVAL_INPUTS, - message_count=len(runtime_messages), - last_response=assistant_message, - ) - ) + agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + state_snapshot_events = _span_events(agent_span, AgentSpecStateSnapshotEmitted) - assert final_snapshot_event.conversation_id == conversation.conversation_id - assert snapshot_payload["conversation"]["inputs"]["input"] == _RETRIEVAL_INPUTS["input"] - assert runtime_messages[-1]["content"] == assistant_message - assert final_snapshot_event.extra_state == {"agent_state": expected_agent_state} - assert dump_conversation_state(restored_conversation)["execution"]["status_handled"] is False - - assert len(agui_exporter.exported_snapshots) == 2 - assert agui_exporter.exported_snapshots[-1] == ExportedAGUIStateSnapshot( - conversation_id=conversation.conversation_id, - snapshot={ - "messages": runtime_messages, - "input": _RETRIEVAL_INPUTS["input"], - "agent_state": expected_agent_state, - }, - ) + assert _span_events(agent_span, AgentSpecAgentExecutionStart) + assert len(state_snapshot_events) == 2 + assert state_snapshot_events[-1].conversation_id == conversation.conversation_id + assert _snapshot_message(state_snapshot_events[-1]) == assistant_message -def test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_spans() -> None: - assistant_message = "Hello from agent" - llm = VllmModel(model_id="mock.model", host_port="http://mock.url", name="agent") - agent = WayflowAgent(llm=llm) - conversation = agent.start_conversation() - conversation.append_user_message("Hi") - status, span_recorder = execute_with_trace( +def test_node_turn_snapshots_attach_to_the_agent_span_not_llm_spans() -> None: + assistant_message, llm, conversation = _make_single_turn_agent_conversation() + status, span_recorder = _execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( state_snapshot_interval=StateSnapshotInterval.NODE_TURNS @@ -271,65 +174,57 @@ def 
test_agent_node_turn_state_snapshots_are_mapped_into_the_agent_span_not_llm_ assert isinstance(status, UserMessageRequestStatus) - agent_span = single_span(span_recorder, AgentSpecAgentExecutionSpan) - state_snapshot_events = events(agent_span, AgentSpecStateSnapshotEmitted) + agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + state_snapshot_events = _span_events(agent_span, AgentSpecStateSnapshotEmitted) assert len(state_snapshot_events) == 4 - assert [event.state_snapshot["execution"]["curr_iter"] for event in state_snapshot_events] == [ - 0, - 0, - 1, - 1, - ] - assert snapshot_status_types(state_snapshot_events) == [ - None, - None, - None, - "UserMessageRequestStatus", - ] - assert snapshot_message(state_snapshot_events[-1]) == assistant_message - - llm_spans = spans(span_recorder, AgentSpecLlmGenerationSpan) - assert llm_spans + assert state_snapshot_events[-1].state_snapshot["execution"]["status"]["type"] == ( + "UserMessageRequestStatus" + ) + assert _snapshot_message(state_snapshot_events[-1]) == assistant_message assert not any( isinstance(event, AgentSpecStateSnapshotEmitted) - for span in llm_spans + for span in _recorded_spans(span_recorder, AgentSpecLlmGenerationSpan) for event in span.events ) -def test_agent_final_state_snapshot_is_visible_to_span_processors_inside_on_end() -> None: - assistant_message = "Hello from agent" - llm = VllmModel(model_id="mock.model", host_port="http://mock.url", name="agent") - agent = WayflowAgent(llm=llm) - conversation = agent.start_conversation() - conversation.append_user_message("Hi") - on_end_recorder = SnapshotEventsSeenAtSpanEndRecorder() - - status, span_recorder = execute_with_trace( +def test_final_agent_snapshot_is_visible_to_span_processors_inside_on_end() -> None: + assistant_message, llm, conversation = _make_single_turn_agent_conversation() + events_seen_at_end_recorder = _EventsSeenAtSpanEndRecorder() + status, span_recorder = _execute_with_trace( conversation, 
state_snapshot_policy=StateSnapshotPolicy( state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS ), - span_processors=[on_end_recorder], + span_processors=[events_seen_at_end_recorder], contexts=[patch_llm(llm, [assistant_message], patch_internal=True)], ) assert isinstance(status, UserMessageRequestStatus) - agent_span = single_span(span_recorder, AgentSpecAgentExecutionSpan) - events_seen_at_end = on_end_recorder.events_by_span_id[agent_span.id] + agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + events_seen_at_end = events_seen_at_end_recorder.events_by_span_id[agent_span.id] assert any(isinstance(event, AgentSpecAgentExecutionEnd) for event in events_seen_at_end) - assert ( - len( - [ - event - for event in events_seen_at_end - if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - ) - == 2 - ) assert isinstance(events_seen_at_end[-1], AgentSpecStateSnapshotEmitted) - assert snapshot_message(events_seen_at_end[-1]) == assistant_message + assert _snapshot_message(events_seen_at_end[-1]) == assistant_message + + +def test_agent_snapshot_extra_state_is_passed_through_verbatim() -> None: + assistant_message, llm, conversation = _make_single_turn_agent_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, + ), + contexts=[patch_llm(llm, [assistant_message], patch_internal=True)], + ) + + assert isinstance(status, UserMessageRequestStatus) + + agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + state_snapshot_events = _span_events(agent_span, AgentSpecStateSnapshotEmitted) + + assert state_snapshot_events[-1].extra_state == {"ui": {"active_tab": "plan"}} diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py index 
4a3a8d05a..b29ee11ef 100644 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py @@ -4,14 +4,18 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. -import json +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Sequence import pytest +from pyagentspec.tracing.events import Event as AgentSpecEvent from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted +from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan from pyagentspec.tracing.spans import NodeExecutionSpan as AgentSpecNodeExecutionSpan +from pyagentspec.tracing.spans import Span as AgentSpecSpan from pyagentspec.tracing.spans import ToolExecutionSpan as AgentSpecToolExecutionSpan from pyagentspec.tracing.trace import Trace as AgentSpecTrace @@ -20,170 +24,207 @@ from wayflowcore.executors.executionstatus import FinishedStatus from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy from wayflowcore.flow import Flow -from wayflowcore.property import AnyProperty -from wayflowcore.steps import CompleteStep, OutputMessageStep, ToolExecutionStep +from wayflowcore.steps import CompleteStep, FlowExecutionStep, OutputMessageStep, ToolExecutionStep from wayflowcore.tools import ServerTool -from ..testhelpers.agentspec_tracing import ( - SnapshotEventsSeenAtSpanEndRecorder, - SnapshotSpanRecorder, - events, - execute_with_trace, - single_span, - spans, -) -from ..testhelpers.statesnapshots import ( - 
snapshot_message, - snapshot_step_histories, -) - - -def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end() -> None: - flow = Flow.from_steps( - [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], - step_names=["single_step", "end"], - ) - conversation = flow.start_conversation() - status, span_recorder = execute_with_trace( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, - ), - ) - assert isinstance(status, FinishedStatus) +class _PassiveSpanProcessor(AgentSpecSpanProcessor): + def on_start(self, span: AgentSpecSpan) -> None: + return None - flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) - assert events(flow_span, AgentSpecFlowExecutionStart) - state_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) + async def on_start_async(self, span: AgentSpecSpan) -> None: + return None - assert len(state_snapshot_events) == 2 - final_snapshot_event = state_snapshot_events[-1] - assert final_snapshot_event.conversation_id == conversation.conversation_id - assert snapshot_message(final_snapshot_event) == "Hello" - assert final_snapshot_event.extra_state == {"ui": {"active_tab": "plan"}} - assert flow_span.end_time is not None - assert "variable_state" not in final_snapshot_event.model_dump(mask_sensitive_information=False) - assert flow_span in span_recorder.ended_spans - - -def test_flow_final_state_snapshot_is_visible_to_span_processors_inside_on_end() -> None: - flow = Flow.from_steps( - [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], - step_names=["single_step", "end"], - ) - conversation = flow.start_conversation() - on_end_recorder = SnapshotEventsSeenAtSpanEndRecorder() + def on_end(self, span: AgentSpecSpan) -> None: + return None - status, span_recorder = execute_with_trace( - conversation, - 
state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - span_processors=[on_end_recorder], - ) + async def on_end_async(self, span: AgentSpecSpan) -> None: + return None - assert isinstance(status, FinishedStatus) + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None - flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) - events_seen_at_end = on_end_recorder.events_by_span_id[flow_span.id] + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None - assert any(isinstance(event, AgentSpecFlowExecutionEnd) for event in events_seen_at_end) - assert ( - len( - [ - event - for event in events_seen_at_end - if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - ) - == 2 + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + +class _SnapshotSpanRecorder(_PassiveSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.started_spans: list[AgentSpecSpan] = [] + + def on_start(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + async def on_start_async(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + +class _EventsSeenAtSpanEndRecorder(_PassiveSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.events_by_span_id: dict[str, list[AgentSpecEvent]] = {} + + def on_end(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + +def _recorded_spans( + span_recorder: _SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> list[AgentSpecSpan]: + return [span for span in span_recorder.started_spans if isinstance(span, span_type)] + + +def 
_single_span( + span_recorder: _SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> AgentSpecSpan: + matching_spans = _recorded_spans(span_recorder, span_type) + assert len(matching_spans) == 1 + return matching_spans[0] + + +def _span_events( + span: AgentSpecSpan, + event_type: type[AgentSpecEvent], +) -> list[AgentSpecEvent]: + return [event for event in span.events if isinstance(event, event_type)] + + +def _execute_with_trace( + conversation, + *, + state_snapshot_policy, + span_processors: Sequence[AgentSpecSpanProcessor] = (), + contexts: Sequence[AbstractContextManager[Any]] = (), +) -> tuple[Any, _SnapshotSpanRecorder]: + span_recorder = _SnapshotSpanRecorder() + listener = AgentSpecEventListener() + + with ExitStack() as stack: + for context in contexts: + stack.enter_context(context) + stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, *span_processors])) + stack.enter_context(register_event_listeners([listener])) + status = conversation.execute(state_snapshot_policy=state_snapshot_policy) + + return status, span_recorder + + +def _make_output_flow() -> Flow: + return Flow.from_steps( + [ + OutputMessageStep(message_template="Hello", name="single_step"), + CompleteStep(name="end"), + ], + name="simple_output_flow", ) - assert isinstance(events_seen_at_end[-1], AgentSpecStateSnapshotEmitted) - assert snapshot_message(events_seen_at_end[-1]) == "Hello" -def test_flow_state_snapshots_normalize_non_finite_floats_before_agent_spec_export() -> None: - flow = Flow.from_steps( +def _make_tool_flow() -> Flow: + return Flow.from_steps( [ ToolExecutionStep( tool=ServerTool( - name="echo", - description="Echo input", - func=lambda bad: str(bad), - input_descriptors=[AnyProperty(name="bad")], - ) + name="say_hi", + description="Say hi", + func=lambda: "hi", + input_descriptors=[], + ), + name="tool_step", ), CompleteStep(name="end"), - ] + ], + name="tool_flow", ) - conversation = flow.start_conversation(inputs={"bad": float("nan")}) - 
status, span_recorder = execute_with_trace( + +def _make_nested_parent_flow_conversation(): + child_flow = Flow.from_steps( + [ + OutputMessageStep(message_template="child", name="child_message"), + CompleteStep(name="end"), + ], + name="child_flow", + ) + parent_flow = Flow.from_steps( + [ + FlowExecutionStep(flow=child_flow, name="child_flow_step"), + OutputMessageStep(message_template="parent", name="parent_message"), + CompleteStep(name="end"), + ], + name="parent_flow", + ) + return parent_flow.start_conversation() + + +def _snapshot_message(snapshot_event: AgentSpecStateSnapshotEmitted) -> str | None: + messages = (snapshot_event.state_snapshot or {}).get("conversation", {}).get("messages", []) + if not messages: + return None + return messages[-1].get("content") + + +def test_conversation_turn_snapshots_attach_to_the_flow_span() -> None: + conversation = _make_output_flow().start_conversation() + status, span_recorder = _execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - include_variable_state=False, + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS ), ) assert isinstance(status, FinishedStatus) - flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) - state_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + state_snapshot_events = _span_events(flow_span, AgentSpecStateSnapshotEmitted) + assert _span_events(flow_span, AgentSpecFlowExecutionStart) assert len(state_snapshot_events) == 2 - assert all( - event.state_snapshot["conversation"]["inputs"]["bad"] == "NaN" - for event in state_snapshot_events - ) - assert all( - json.loads(json.dumps(event.state_snapshot, allow_nan=False)) == event.state_snapshot - for event in state_snapshot_events - ) + assert state_snapshot_events[-1].conversation_id == conversation.conversation_id + assert 
_snapshot_message(state_snapshot_events[-1]) == "Hello" -@pytest.mark.anyio -async def test_flow_state_snapshots_are_mapped_into_the_flow_span_before_flow_end_async() -> None: - flow = Flow.from_steps( - [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], - step_names=["single_step", "end"], +def test_final_flow_snapshot_is_visible_to_span_processors_inside_on_end() -> None: + conversation = _make_output_flow().start_conversation() + events_seen_at_end_recorder = _EventsSeenAtSpanEndRecorder() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + span_processors=[events_seen_at_end_recorder], ) - conversation = flow.start_conversation() - span_recorder = SnapshotSpanRecorder() - - async with AgentSpecTrace(span_processors=[span_recorder]): - with register_event_listeners([AgentSpecEventListener()]): - status = await conversation.execute_async( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, - ) - ) assert isinstance(status, FinishedStatus) - flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) - state_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + events_seen_at_end = events_seen_at_end_recorder.events_by_span_id[flow_span.id] - assert len(state_snapshot_events) == 2 - final_snapshot_event = state_snapshot_events[-1] - assert final_snapshot_event.conversation_id == conversation.conversation_id - assert snapshot_message(final_snapshot_event) == "Hello" - assert final_snapshot_event.extra_state == {"ui": {"active_tab": "plan"}} - assert flow_span.end_time is not None - assert "variable_state" not in final_snapshot_event.model_dump(mask_sensitive_information=False) - assert 
flow_span in span_recorder.ended_spans - - -def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans() -> None: - flow = Flow.from_steps( - [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], - step_names=["single_step", "end"], - ) - conversation = flow.start_conversation() - status, span_recorder = execute_with_trace( + assert any(isinstance(event, AgentSpecFlowExecutionEnd) for event in events_seen_at_end) + assert isinstance(events_seen_at_end[-1], AgentSpecStateSnapshotEmitted) + assert _snapshot_message(events_seen_at_end[-1]) == "Hello" + + +def test_node_turn_snapshots_attach_to_the_flow_span_not_node_or_tool_spans() -> None: + conversation = _make_tool_flow().start_conversation() + status, span_recorder = _execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( state_snapshot_interval=StateSnapshotInterval.NODE_TURNS @@ -192,107 +233,60 @@ def test_node_turn_state_snapshots_are_mapped_into_the_flow_span_not_node_spans( assert isinstance(status, FinishedStatus) - flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) - flow_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) - - assert len(flow_snapshot_events) == 8 - assert snapshot_step_histories(flow_snapshot_events) == [ - [], - [], - ["__StartStep__"], - ["__StartStep__"], - ["__StartStep__", "single_step"], - ["__StartStep__", "single_step"], - ["__StartStep__", "single_step", "end"], - ["__StartStep__", "single_step", "end"], - ] - node_spans = spans(span_recorder, AgentSpecNodeExecutionSpan) - assert node_spans + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + flow_snapshot_events = _span_events(flow_span, AgentSpecStateSnapshotEmitted) + + assert len(flow_snapshot_events) > 2 + assert not any( + isinstance(event, AgentSpecStateSnapshotEmitted) + for span in _recorded_spans(span_recorder, AgentSpecNodeExecutionSpan) + for event in span.events + ) assert not any( isinstance(event, 
AgentSpecStateSnapshotEmitted) - for span in node_spans + for span in _recorded_spans(span_recorder, AgentSpecToolExecutionSpan) for event in span.events ) -@pytest.mark.parametrize( - ("interval", "expected_step_histories"), - [ - pytest.param( - StateSnapshotInterval.TOOL_TURNS, - [ - [], - ["__StartStep__", "step_0"], - ["__StartStep__", "step_0"], - ["__StartStep__", "step_0", "end"], - ], - id="tool_turns", - ), - pytest.param( - StateSnapshotInterval.ALL_INTERNAL_TURNS, - [ - [], - [], - ["__StartStep__"], - ["__StartStep__"], - ["__StartStep__", "step_0"], - ["__StartStep__", "step_0"], - ["__StartStep__", "step_0"], - ["__StartStep__", "step_0"], - ["__StartStep__", "step_0", "end"], - ["__StartStep__", "step_0", "end"], - ], - id="all_internal_turns", - ), - ], -) -def test_internal_flow_state_snapshots_follow_conversation_ownership_for_agent_spec( - interval: StateSnapshotInterval, - expected_step_histories: list[list[str]], -) -> None: - flow = Flow.from_steps( - [ - ToolExecutionStep( - tool=ServerTool( - name="say_hi", - description="Say hi", - func=lambda: "hi", - input_descriptors=[], - ) - ), - CompleteStep(name="end"), - ] - ) - conversation = flow.start_conversation() - status, span_recorder = execute_with_trace( +def test_nested_flow_execution_exports_snapshots_only_on_the_root_flow_span() -> None: + conversation = _make_nested_parent_flow_conversation() + status, span_recorder = _execute_with_trace( conversation, - state_snapshot_policy=StateSnapshotPolicy(state_snapshot_interval=interval), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), ) assert isinstance(status, FinishedStatus) - flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) - flow_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) - assert snapshot_step_histories(flow_snapshot_events) == expected_step_histories - - tool_spans = spans(span_recorder, AgentSpecToolExecutionSpan) - 
node_spans = spans(span_recorder, AgentSpecNodeExecutionSpan) - assert tool_spans - assert node_spans - assert not any( - isinstance(event, AgentSpecStateSnapshotEmitted) - for span in [*tool_spans, *node_spans] - for event in span.events - ) + flow_spans = _recorded_spans(span_recorder, AgentSpecFlowExecutionSpan) + assert len(flow_spans) == 2 + + flow_spans_by_name = { + next( + event for event in span.events if isinstance(event, AgentSpecFlowExecutionStart) + ).flow.name: span + for span in flow_spans + } + parent_span = flow_spans_by_name["parent_flow"] + child_span = flow_spans_by_name["child_flow"] + + parent_snapshot_events = _span_events(parent_span, AgentSpecStateSnapshotEmitted) + child_snapshot_events = _span_events(child_span, AgentSpecStateSnapshotEmitted) + + assert len(parent_snapshot_events) == 2 + assert [event.conversation_id for event in parent_snapshot_events] == [ + conversation.conversation_id, + conversation.conversation_id, + ] + assert _snapshot_message(parent_snapshot_events[-1]) == "parent" + assert child_snapshot_events == [] -def test_off_policy_does_not_bridge_state_snapshots_into_agent_spec_spans() -> None: - flow = Flow.from_steps( - [OutputMessageStep(message_template="Hello"), CompleteStep(name="end")], - step_names=["single_step", "end"], - ) - conversation = flow.start_conversation() - status, span_recorder = execute_with_trace( +def test_off_policy_disables_flow_state_snapshot_export() -> None: + conversation = _make_output_flow().start_conversation() + status, span_recorder = _execute_with_trace( conversation, state_snapshot_policy=StateSnapshotPolicy( state_snapshot_interval=StateSnapshotInterval.OFF @@ -301,14 +295,12 @@ def test_off_policy_does_not_bridge_state_snapshots_into_agent_spec_spans() -> N assert isinstance(status, FinishedStatus) - flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) - assert events(flow_span, AgentSpecFlowExecutionStart) - assert events(flow_span, AgentSpecFlowExecutionEnd) - assert 
not events(flow_span, AgentSpecStateSnapshotEmitted) + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + assert _span_events(flow_span, AgentSpecStateSnapshotEmitted) == [] -def test_only_the_opening_state_snapshot_is_exported_when_a_turn_raises() -> None: - flow = Flow.from_steps( +def test_raised_turn_exports_only_the_opening_flow_snapshot() -> None: + conversation = Flow.from_steps( [ ToolExecutionStep( tool=ServerTool( @@ -320,12 +312,12 @@ def test_only_the_opening_state_snapshot_is_exported_when_a_turn_raises() -> Non ), CompleteStep(name="end"), ] - ) - conversation = flow.start_conversation() - span_recorder = SnapshotSpanRecorder() + ).start_conversation() + span_recorder = _SnapshotSpanRecorder() + listener = AgentSpecEventListener() with AgentSpecTrace(span_processors=[span_recorder]): - with register_event_listeners([AgentSpecEventListener()]): + with register_event_listeners([listener]): with pytest.raises(RuntimeError, match="boom"): conversation.execute( state_snapshot_policy=StateSnapshotPolicy( @@ -333,9 +325,8 @@ def test_only_the_opening_state_snapshot_is_exported_when_a_turn_raises() -> Non ) ) - flow_span = single_span(span_recorder, AgentSpecFlowExecutionSpan) - state_snapshot_events = events(flow_span, AgentSpecStateSnapshotEmitted) + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + state_snapshot_events = _span_events(flow_span, AgentSpecStateSnapshotEmitted) assert len(state_snapshot_events) == 1 assert state_snapshot_events[0].state_snapshot["execution"]["status"] is None - assert flow_span in span_recorder.ended_spans diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py deleted file mode 100644 index 6f91c00ba..000000000 --- a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_nested.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright © 2026 Oracle and/or its affiliates. 
-# -# This software is under the Apache License 2.0 -# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License -# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. - -from pyagentspec.tracing.events import Event as AgentSpecEvent -from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart -from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted -from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor -from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan -from pyagentspec.tracing.spans import Span as AgentSpecSpan - -from wayflowcore.executors.executionstatus import FinishedStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy -from wayflowcore.flow import Flow -from wayflowcore.steps import CompleteStep, FlowExecutionStep, OutputMessageStep, ParallelMapStep - -from ..testhelpers.agentspec_tracing import execute_with_trace, spans -from ..testhelpers.statesnapshots import ( - snapshot_message, - snapshot_runtime_conversation_ids, -) - - -class SnapshotRuntimeIdsByConversationExporter(AgentSpecSpanProcessor): - def __init__(self) -> None: - super().__init__() - self.runtime_ids_by_conversation_id: dict[str, list[str]] = {} - - def on_start(self, span: AgentSpecSpan) -> None: - return None - - async def on_start_async(self, span: AgentSpecSpan) -> None: - return None - - def on_end(self, span: AgentSpecSpan) -> None: - return None - - async def on_end_async(self, span: AgentSpecSpan) -> None: - return None - - def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - if isinstance(event, AgentSpecStateSnapshotEmitted) and event.state_snapshot is not None: - self.runtime_ids_by_conversation_id.setdefault(event.conversation_id, []).append( - event.state_snapshot["conversation"]["id"] - ) - - async def 
on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - self.on_event(event, span) - - def startup(self) -> None: - return None - - def shutdown(self) -> None: - return None - - async def startup_async(self) -> None: - return None - - async def shutdown_async(self) -> None: - return None - - -def test_nested_flow_state_snapshots_stay_on_the_root_flow_span_for_shared_conversations() -> None: - child_flow = Flow.from_steps( - [OutputMessageStep(message_template="child"), CompleteStep(name="end")], - step_names=["child_message", "end"], - name="child_flow", - ) - parent_flow = Flow.from_steps( - [ - FlowExecutionStep(flow=child_flow), - OutputMessageStep(message_template="parent"), - CompleteStep(name="end"), - ], - step_names=["child_flow_step", "parent_message", "end"], - name="parent_flow", - ) - conversation = parent_flow.start_conversation() - status, span_recorder = execute_with_trace( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - - flow_spans = spans(span_recorder, AgentSpecFlowExecutionSpan) - assert len(flow_spans) == 2 - - flow_spans_by_name = { - next( - event for event in span.events if isinstance(event, AgentSpecFlowExecutionStart) - ).flow.name: span - for span in flow_spans - } - parent_span = flow_spans_by_name["parent_flow"] - child_span = flow_spans_by_name["child_flow"] - - parent_snapshot_events = [ - event for event in parent_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - child_snapshot_events = [ - event for event in child_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - assert len(parent_snapshot_events) == 2 - assert [event.conversation_id for event in parent_snapshot_events] == [ - conversation.conversation_id, - conversation.conversation_id, - ] - assert snapshot_runtime_conversation_ids(parent_snapshot_events) == [ - conversation.id, - 
conversation.id, - ] - assert snapshot_message(parent_snapshot_events[-1]) == "parent" - assert not child_snapshot_events - - -def test_nested_node_turn_state_snapshots_export_only_root_runtime_conversation_to_agent_spec() -> ( - None -): - child_flow = Flow.from_steps( - [OutputMessageStep(message_template="child"), CompleteStep(name="end")], - step_names=["child_message", "end"], - name="child_flow", - ) - parent_flow = Flow.from_steps( - [ - FlowExecutionStep(flow=child_flow), - OutputMessageStep(message_template="parent"), - CompleteStep(name="end"), - ], - step_names=["child_flow_step", "parent_message", "end"], - name="parent_flow", - ) - conversation = parent_flow.start_conversation() - status, span_recorder = execute_with_trace( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - - flow_spans = spans(span_recorder, AgentSpecFlowExecutionSpan) - flow_spans_by_name = { - next( - event for event in span.events if isinstance(event, AgentSpecFlowExecutionStart) - ).flow.name: span - for span in flow_spans - } - parent_span = flow_spans_by_name["parent_flow"] - child_span = flow_spans_by_name["child_flow"] - - parent_snapshot_events = [ - event for event in parent_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - child_snapshot_events = [ - event for event in child_span.events if isinstance(event, AgentSpecStateSnapshotEmitted) - ] - - assert parent_snapshot_events - assert snapshot_runtime_conversation_ids(parent_snapshot_events) == [ - conversation.id for _ in parent_snapshot_events - ] - assert not child_snapshot_events - - -def test_parallel_map_snapshots_leave_agent_spec_exporters_with_root_resumable_state() -> None: - child_flow = Flow.from_steps( - [OutputMessageStep(message_template="item={{item}}"), CompleteStep(name="end")], - step_names=["child_message", "end"], - name="parallel_map_child", - ) - parent_flow = 
Flow.from_steps( - [ - ParallelMapStep( - flow=child_flow, - unpack_input={"item": "."}, - name="parallel_map", - ), - CompleteStep(name="end"), - ], - step_names=["parallel_map", "end"], - name="parallel_map_parent", - ) - conversation = parent_flow.start_conversation( - inputs={ParallelMapStep.ITERATED_INPUT: ["a", "b"]} - ) - snapshot_runtime_id_exporter = SnapshotRuntimeIdsByConversationExporter() - - status, _ = execute_with_trace( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ), - span_processors=[snapshot_runtime_id_exporter], - ) - - assert isinstance(status, FinishedStatus) - assert snapshot_runtime_id_exporter.runtime_ids_by_conversation_id[conversation.conversation_id] - assert snapshot_runtime_id_exporter.runtime_ids_by_conversation_id[ - conversation.conversation_id - ] == [conversation.id] * len( - snapshot_runtime_id_exporter.runtime_ids_by_conversation_id[conversation.conversation_id] - ) diff --git a/wayflowcore/tests/events/test_state_snapshot_event_tracing.py b/wayflowcore/tests/events/test_state_snapshot_event_tracing.py index 0b12281ba..930d18f9b 100644 --- a/wayflowcore/tests/events/test_state_snapshot_event_tracing.py +++ b/wayflowcore/tests/events/test_state_snapshot_event_tracing.py @@ -9,6 +9,33 @@ from wayflowcore.events.event import _PII_TEXT_MASK, StateSnapshotEvent +def test_state_snapshot_event_requires_conversation_id() -> None: + with pytest.raises(ValueError, match="conversation_id"): + StateSnapshotEvent() + + +def test_state_snapshot_event_allows_missing_state_snapshot() -> None: + event = StateSnapshotEvent(conversation_id="conversation-123") + + assert event.state_snapshot is None + + +def test_state_snapshot_event_requires_state_snapshot_to_be_a_dictionary() -> None: + with pytest.raises(ValueError, match="state_snapshot must be a dictionary"): + StateSnapshotEvent( + conversation_id="conversation-123", + state_snapshot="not-a-dictionary", + ) + + +def 
test_state_snapshot_event_requires_runtime_conversation_id() -> None: + with pytest.raises(ValueError, match=r"state_snapshot\['conversation'\]\['id'\]"): + StateSnapshotEvent( + conversation_id="conversation-123", + state_snapshot={"conversation": {"messages": []}}, + ) + + @pytest.mark.parametrize("mask_sensitive_information", [True, False]) def test_state_snapshot_event_serialization( mask_sensitive_information: bool, diff --git a/wayflowcore/tests/events/test_state_snapshot_event_validation.py b/wayflowcore/tests/events/test_state_snapshot_event_validation.py deleted file mode 100644 index 0fb30ed68..000000000 --- a/wayflowcore/tests/events/test_state_snapshot_event_validation.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright © 2026 Oracle and/or its affiliates. -# -# This software is under the Apache License 2.0 -# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License -# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
- -import pytest - -from wayflowcore.events.event import StateSnapshotEvent - - -def test_state_snapshot_event_requires_conversation_id() -> None: - with pytest.raises(ValueError, match="conversation_id"): - StateSnapshotEvent() - - -def test_state_snapshot_event_allows_missing_state_snapshot() -> None: - event = StateSnapshotEvent(conversation_id="conversation-123") - - assert event.state_snapshot is None - - -def test_state_snapshot_event_requires_state_snapshot_to_be_a_dictionary() -> None: - with pytest.raises(ValueError, match="state_snapshot must be a dictionary"): - StateSnapshotEvent( - conversation_id="conversation-123", - state_snapshot="not-a-dictionary", - ) - - -def test_state_snapshot_event_requires_runtime_conversation_id() -> None: - with pytest.raises(ValueError, match=r"state_snapshot\['conversation'\]\['id'\]"): - StateSnapshotEvent( - conversation_id="conversation-123", - state_snapshot={"conversation": {"messages": []}}, - ) diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py deleted file mode 100644 index 8125ea735..000000000 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_conversation_turns.py +++ /dev/null @@ -1,434 +0,0 @@ -# Copyright © 2026 Oracle and/or its affiliates. -# -# This software is under the Apache License 2.0 -# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License -# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
- -import json -from typing import Any - -import pytest - -from wayflowcore.events import Event, EventListener -from wayflowcore.events.event import StateSnapshotEvent -from wayflowcore.events.eventlistener import register_event_listeners -from wayflowcore.executors._events.event import EventType -from wayflowcore.executors.executionstatus import ( - FinishedStatus, - ToolRequestStatus, - UserMessageRequestStatus, -) -from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy -from wayflowcore.flow import Flow -from wayflowcore.serialization import dump_conversation_state -from wayflowcore.steps import CompleteStep, InputMessageStep, OutputMessageStep, ToolExecutionStep -from wayflowcore.tools import ClientTool, ToolResult - -from ..test_interrupts import OnEventExecutionInterrupt -from ..testhelpers.statesnapshots import ( - MutatingExecutionEndInterrupt, - SnapshotCollector, - assert_terminal_snapshot, - create_agent_conversation, - create_output_flow_conversation, - create_tool_flow_conversation, - execute_with_state_snapshots, - execute_with_state_snapshots_async, - restore_conversation_from_snapshot_payload, - snapshot_status_types, -) - - -class _LiveConversationSnapshotObserver(EventListener): - def __init__(self, conversation) -> None: - self.conversation = conversation - self.live_snapshots: list[dict[str, Any]] = [] - - def __call__(self, event: Event) -> None: - if not isinstance(event, StateSnapshotEvent): - return - - self.live_snapshots.append(dump_conversation_state(self.conversation)) - - -@pytest.mark.parametrize( - ( - "conversation_factory", - "expected_status_class", - "expected_status_type", - "expected_message", - ), - [ - pytest.param( - create_output_flow_conversation, - FinishedStatus, - "FinishedStatus", - "Hello", - id="flow", - ), - pytest.param( - create_agent_conversation, - UserMessageRequestStatus, - 
"UserMessageRequestStatus", - "Hello from agent", - id="agent", - ), - ], -) -def test_conversation_turn_policy_records_opening_and_closing_checkpoints( - conversation_factory, - expected_status_class, - expected_status_type: str, - expected_message: str, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, expected_status_class) - assert snapshot_status_types(state_snapshot_events) == [None, expected_status_type] - assert_terminal_snapshot( - state_snapshot_events, - expected_status_type=expected_status_type, - expected_message=expected_message, - ) - - -@pytest.mark.anyio -@pytest.mark.parametrize( - ( - "conversation_factory", - "expected_status_class", - "expected_status_type", - "expected_message", - ), - [ - pytest.param( - create_output_flow_conversation, - FinishedStatus, - "FinishedStatus", - "Hello", - id="flow", - ), - pytest.param( - create_agent_conversation, - UserMessageRequestStatus, - "UserMessageRequestStatus", - "Hello from agent", - id="agent", - ), - ], -) -async def test_conversation_turn_policy_records_opening_and_closing_checkpoints_async( - conversation_factory, - expected_status_class, - expected_status_type: str, - expected_message: str, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = await execute_with_state_snapshots_async( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, expected_status_class) - assert snapshot_status_types(state_snapshot_events) == [None, expected_status_type] - assert_terminal_snapshot( - state_snapshot_events, - expected_status_type=expected_status_type, - expected_message=expected_message, - ) - - -@pytest.mark.parametrize( - 
("conversation_factory", "expected_status_class"), - [ - pytest.param(create_output_flow_conversation, FinishedStatus, id="flow"), - pytest.param(create_agent_conversation, UserMessageRequestStatus, id="agent"), - ], -) -def test_off_policy_disables_state_snapshot_emission( - conversation_factory, - expected_status_class, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.OFF - ), - ) - - assert isinstance(status, expected_status_class) - assert state_snapshot_events == [] - - -@pytest.mark.parametrize( - ("conversation_factory", "expected_message"), - [ - pytest.param(create_output_flow_conversation, "Hello", id="flow"), - pytest.param(create_agent_conversation, "Hello from agent", id="agent"), - ], -) -def test_conversation_turn_policy_records_interrupted_turn_end_checkpoints( - conversation_factory, - expected_message: str, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - execution_interrupts=[OnEventExecutionInterrupt(EventType.EXECUTION_END)], - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, InterruptedExecutionStatus) - assert snapshot_status_types(state_snapshot_events) == [None, "InterruptedExecutionStatus"] - assert_terminal_snapshot( - state_snapshot_events, - expected_status_type="InterruptedExecutionStatus", - expected_message=expected_message, - ) - - -def test_conversation_turn_policy_keeps_only_the_opening_checkpoint_when_turn_raises() -> None: - def explode() -> str: - raise RuntimeError("boom") - - conversation = create_tool_flow_conversation( - explode, - name="explode", - description="Raise an error", - ) - collector = SnapshotCollector() - - with register_event_listeners([collector]): - 
with pytest.raises(RuntimeError, match="boom"): - conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ) - ) - - assert len(collector.state_snapshot_events) == 1 - assert collector.state_snapshot_events[0].state_snapshot["execution"]["status"] is None - - -def test_conversation_turn_policy_reflects_real_interrupt_side_effects_once() -> None: - conversation = create_output_flow_conversation() - interrupt = MutatingExecutionEndInterrupt() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - execution_interrupts=[interrupt], - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert interrupt.count == 1 - assert conversation.inputs["preview_count"] == 1 - assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] - assert state_snapshot_events[-1].state_snapshot["conversation"]["inputs"]["preview_count"] == 1 - - -@pytest.mark.parametrize( - ( - "interval", - "execution_interrupts", - "expected_status_class", - "expected_status_type", - ), - [ - pytest.param( - StateSnapshotInterval.CONVERSATION_TURNS, - None, - FinishedStatus, - "FinishedStatus", - id="conversation_turns-finished", - ), - pytest.param( - StateSnapshotInterval.TOOL_TURNS, - None, - FinishedStatus, - "FinishedStatus", - id="tool_turns-finished", - ), - pytest.param( - StateSnapshotInterval.NODE_TURNS, - None, - FinishedStatus, - "FinishedStatus", - id="node_turns-finished", - ), - pytest.param( - StateSnapshotInterval.ALL_INTERNAL_TURNS, - None, - FinishedStatus, - "FinishedStatus", - id="all_internal_turns-finished", - ), - pytest.param( - StateSnapshotInterval.CONVERSATION_TURNS, - [OnEventExecutionInterrupt(EventType.EXECUTION_END)], - InterruptedExecutionStatus, - "InterruptedExecutionStatus", - id="conversation_turns-interrupted", - ), - ], -) -def 
test_closing_turn_snapshot_is_emitted_before_live_conversation_status_commit( - interval: StateSnapshotInterval, - execution_interrupts, - expected_status_class, - expected_status_type: str, -) -> None: - conversation = create_output_flow_conversation() - collector = SnapshotCollector() - observer = _LiveConversationSnapshotObserver(conversation) - - with register_event_listeners([collector, observer]): - status = conversation.execute( - execution_interrupts=execution_interrupts, - state_snapshot_policy=StateSnapshotPolicy(state_snapshot_interval=interval), - ) - - assert isinstance(status, expected_status_class) - assert conversation.status is status - assert conversation.status_handled is False - assert observer.live_snapshots[-1]["execution"]["status"] is None - assert collector.state_snapshot_events[-1].state_snapshot is not None - assert collector.state_snapshot_events[-1].state_snapshot["execution"]["status"]["type"] == ( - expected_status_type - ) - assert observer.live_snapshots[-1] != { - "conversation": collector.state_snapshot_events[-1].state_snapshot["conversation"], - "execution": collector.state_snapshot_events[-1].state_snapshot["execution"], - } - - -def test_conversation_turn_snapshot_payload_can_resume_waiting_for_client_tool_result() -> None: - client_tool = ClientTool( - name="client_lookup", - description="Look up some data on the client side", - parameters={}, - ) - conversation = Flow.from_steps( - [ - ToolExecutionStep(tool=client_tool), - CompleteStep(name="end"), - ], - name="snapshot_client_tool_resume_flow", - ).start_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, ToolRequestStatus) - assert state_snapshot_events[-1].state_snapshot is not None - restored_conversation = restore_conversation_from_snapshot_payload( - 
state_snapshot_events[-1].state_snapshot - ) - assert isinstance(restored_conversation.status, ToolRequestStatus) - - tool_request = restored_conversation.status.tool_requests[0] - restored_conversation.append_tool_result( - ToolResult(tool_request_id=tool_request.tool_request_id, content="client-result") - ) - resumed_status = restored_conversation.execute() - - assert isinstance(resumed_status, FinishedStatus) - tool_result_messages = [ - message.tool_result - for message in restored_conversation.get_messages() - if message.tool_result - ] - assert len(tool_result_messages) == 1 - assert tool_result_messages[0].tool_request_id == tool_request.tool_request_id - assert tool_result_messages[0].content == "client-result" - - -def test_conversation_turn_snapshot_payload_round_trips_through_json_like_run_agent_input_state() -> ( - None -): - conversation = Flow.from_steps( - [ - InputMessageStep("Please answer"), - OutputMessageStep("done"), - ], - name="snapshot_user_json_roundtrip_resume_flow", - ).start_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, UserMessageRequestStatus) - assert state_snapshot_events[-1].state_snapshot is not None - - run_agent_input_state = json.loads( - json.dumps(state_snapshot_events[-1].state_snapshot, allow_nan=False) - ) - restored_conversation = restore_conversation_from_snapshot_payload(run_agent_input_state) - restored_conversation.append_user_message("hello") - resumed_status = restored_conversation.execute() - - assert isinstance(resumed_status, FinishedStatus) - assert [message.content for message in restored_conversation.get_messages()] == [ - "Please answer", - "hello", - "done", - ] - - -@pytest.mark.anyio -async def test_conversation_turn_snapshot_payload_can_resume_waiting_for_user_input_async() -> None: - conversation = 
Flow.from_steps( - [ - InputMessageStep("Please answer"), - OutputMessageStep("done"), - ], - name="snapshot_user_resume_flow", - ).start_conversation() - - status, state_snapshot_events = await execute_with_state_snapshots_async( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, UserMessageRequestStatus) - assert state_snapshot_events[-1].state_snapshot is not None - restored_conversation = restore_conversation_from_snapshot_payload( - state_snapshot_events[-1].state_snapshot - ) - restored_conversation.append_user_message("hello") - resumed_status = await restored_conversation.execute_async() - - assert isinstance(resumed_status, FinishedStatus) - assert [message.content for message in restored_conversation.get_messages()] == [ - "Please answer", - "hello", - "done", - ] diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_core.py b/wayflowcore/tests/events/test_state_snapshot_runtime_core.py new file mode 100644 index 000000000..76df65fd6 --- /dev/null +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_core.py @@ -0,0 +1,535 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +from dataclasses import dataclass + +import pytest + +from wayflowcore.conversation import Conversation +from wayflowcore.events import Event, EventListener +from wayflowcore.events.event import StateSnapshotEvent +from wayflowcore.events.eventlistener import register_event_listeners +from wayflowcore.executors._events.event import EventType +from wayflowcore.executors.executionstatus import FinishedStatus +from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy +from wayflowcore.flow import Flow +from wayflowcore.property import AnyProperty +from wayflowcore.serialization import dump_conversation_state +from wayflowcore.serialization.serializer import FrozenSerializableDataclass +from wayflowcore.steps import ( + CompleteStep, + FlowExecutionStep, + OutputMessageStep, + ParallelFlowExecutionStep, + ToolExecutionStep, + VariableWriteStep, +) +from wayflowcore.tools import ServerTool +from wayflowcore.variable import Variable + +from ..test_interrupts import OnEventExecutionInterrupt +from ..testhelpers.state_snapshot_testutils import ( + execute_with_state_snapshots, + execute_with_state_snapshots_async, + restore_conversation_from_snapshot_payload, + snapshot_status_types, +) + + +class _LiveConversationSnapshotObserver(EventListener): + def __init__(self, conversation) -> None: + self.conversation = conversation + self.live_snapshots: list[dict[str, Any]] = [] + + def __call__(self, event: Event) -> None: + if isinstance(event, StateSnapshotEvent): + self.live_snapshots.append(dump_conversation_state(self.conversation)) + + +def _make_output_conversation(): + return Flow.from_steps( + [ + OutputMessageStep(message_template="Hello", name="single_step"), + CompleteStep(name="end"), + ] + ).start_conversation() + + +def _make_tool_flow_conversation(): + return Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + 
name="say_hi", + description="Say hi", + func=lambda: "hi", + input_descriptors=[], + ), + name="tool_step", + ), + CompleteStep(name="end"), + ] + ).start_conversation() + + +def _make_parent_child_conversation(): + child_flow = Flow.from_steps( + [ + OutputMessageStep(message_template="child", name="child_message"), + CompleteStep(name="child_end"), + ], + name="child_flow", + ) + parent_flow = Flow.from_steps( + [ + FlowExecutionStep(flow=child_flow, name="child_flow_step"), + OutputMessageStep(message_template="parent", name="parent_message"), + CompleteStep(name="end"), + ], + name="parent_flow", + ) + return parent_flow.start_conversation() + + +def _make_parallel_parent_conversation(): + return Flow.from_steps( + [ + ParallelFlowExecutionStep( + flows=[ + Flow.from_steps( + [ + OutputMessageStep( + message_template="left", + output_mapping={OutputMessageStep.OUTPUT: "left_message"}, + ), + CompleteStep(name="left_end"), + ], + name="left_child_flow", + ), + Flow.from_steps( + [ + OutputMessageStep( + message_template="right", + output_mapping={OutputMessageStep.OUTPUT: "right_message"}, + ), + CompleteStep(name="right_end"), + ], + name="right_child_flow", + ), + ], + name="parallel_children", + ), + CompleteStep(name="end"), + ], + name="parallel_parent_flow", + ).start_conversation() + + +def _snapshot_message(snapshot_event: StateSnapshotEvent) -> str | None: + messages = snapshot_event.state_snapshot["conversation"]["messages"] + if not messages: + return None + return messages[-1].get("content") + + +def _snapshot_step_histories(snapshot_events: list[StateSnapshotEvent]) -> list[list[str]]: + return [ + snapshot_event.state_snapshot["execution"]["step_history"] + for snapshot_event in snapshot_events + ] + + +def _execute_tool_flow_with_interval( + interval: StateSnapshotInterval, +) -> tuple[object, list[StateSnapshotEvent]]: + return execute_with_state_snapshots( + _make_tool_flow_conversation(), + 
state_snapshot_policy=StateSnapshotPolicy(state_snapshot_interval=interval), + ) + + +@dataclass(frozen=True) +class _SerializableButNotJson(FrozenSerializableDataclass): + value: str + + +class _FailOnTerminalSnapshot(EventListener): + def __call__(self, event: Event) -> None: + if not isinstance(event, StateSnapshotEvent): + return + + execution_status = (event.state_snapshot or {}).get("execution", {}).get("status") + if execution_status is not None: + raise RuntimeError("snapshot sink failed") + + +def test_off_policy_emits_no_state_snapshots() -> None: + status, state_snapshot_events = execute_with_state_snapshots( + _make_output_conversation(), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.OFF + ), + ) + + assert isinstance(status, FinishedStatus) + assert state_snapshot_events == [] + + +def test_conversation_turns_emit_opening_and_closing_snapshots() -> None: + status, state_snapshot_events = execute_with_state_snapshots( + _make_output_conversation(), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 2 + assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + assert _snapshot_message(state_snapshot_events[-1]) == "Hello" + assert all( + isinstance(snapshot_event.state_snapshot.get("conversation_state"), str) + for snapshot_event in state_snapshot_events + ) + + +def test_terminal_snapshot_is_synthesized_before_live_status_commit() -> None: + conversation = _make_output_conversation() + collector = [] + observer = _LiveConversationSnapshotObserver(conversation) + + class _Collector(EventListener): + def __call__(self, event: Event) -> None: + if isinstance(event, StateSnapshotEvent): + collector.append(event) + + with register_event_listeners([_Collector(), observer]): + status = conversation.execute( + 
state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) + ) + + assert isinstance(status, FinishedStatus) + assert conversation.status is status + assert observer.live_snapshots[-1]["execution"]["status"] is None + assert collector[-1].state_snapshot["execution"]["status"]["type"] == "FinishedStatus" + assert collector[-1].state_snapshot["execution"]["status_handled"] is False + assert observer.live_snapshots[-1] != { + "conversation": collector[-1].state_snapshot["conversation"], + "execution": collector[-1].state_snapshot["execution"], + } + + +def test_interrupted_conversation_turn_emits_terminal_snapshot() -> None: + status, state_snapshot_events = execute_with_state_snapshots( + _make_output_conversation(), + execution_interrupts=[OnEventExecutionInterrupt(EventType.EXECUTION_END)], + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, InterruptedExecutionStatus) + assert snapshot_status_types(state_snapshot_events) == [None, "InterruptedExecutionStatus"] + assert state_snapshot_events[-1].state_snapshot["execution"]["status_handled"] is False + + +def test_node_turns_emit_internal_step_snapshots_and_only_root_turns_are_resumable() -> None: + status, state_snapshot_events = execute_with_state_snapshots( + _make_output_conversation(), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + ) + + step_histories = [ + snapshot_event.state_snapshot["execution"]["step_history"] + for snapshot_event in state_snapshot_events + ] + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) > 2 + assert snapshot_status_types(state_snapshot_events)[0] is None + assert snapshot_status_types(state_snapshot_events)[-1] == "FinishedStatus" + assert all( + status_type is None for status_type in snapshot_status_types(state_snapshot_events)[1:-1] + ) 
+ assert [] in step_histories + assert ["__StartStep__"] in step_histories + assert ["__StartStep__", "single_step"] in step_histories + assert ["__StartStep__", "single_step", "end"] in step_histories + assert isinstance(state_snapshot_events[0].state_snapshot.get("conversation_state"), str) + assert isinstance(state_snapshot_events[-1].state_snapshot.get("conversation_state"), str) + assert all( + "conversation_state" not in snapshot_event.state_snapshot + for snapshot_event in state_snapshot_events[1:-1] + ) + + +def test_tool_turns_emit_tool_boundary_snapshots_and_only_root_turns_are_resumable() -> None: + status, state_snapshot_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.TOOL_TURNS + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 4 + assert snapshot_status_types(state_snapshot_events) == [None, None, None, "FinishedStatus"] + assert isinstance(state_snapshot_events[0].state_snapshot.get("conversation_state"), str) + assert isinstance(state_snapshot_events[-1].state_snapshot.get("conversation_state"), str) + assert all( + "conversation_state" not in snapshot_event.state_snapshot + for snapshot_event in state_snapshot_events[1:-1] + ) + + +def test_all_internal_turns_emit_more_snapshots_than_tool_turns() -> None: + _, tool_turn_events = _execute_tool_flow_with_interval(StateSnapshotInterval.TOOL_TURNS) + status, all_internal_turn_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.ALL_INTERNAL_TURNS + ) + + assert isinstance(status, FinishedStatus) + assert len(all_internal_turn_events) > len(tool_turn_events) + + +def test_all_internal_turns_emit_more_snapshots_than_node_turns() -> None: + _, node_turn_events = _execute_tool_flow_with_interval(StateSnapshotInterval.NODE_TURNS) + status, all_internal_turn_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.ALL_INTERNAL_TURNS + ) + + assert isinstance(status, FinishedStatus) + assert len(all_internal_turn_events) > 
len(node_turn_events) + + +def test_all_internal_turns_emit_opening_and_terminal_statuses() -> None: + status, all_internal_turn_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.ALL_INTERNAL_TURNS + ) + + assert isinstance(status, FinishedStatus) + assert snapshot_status_types(all_internal_turn_events)[0] is None + assert snapshot_status_types(all_internal_turn_events)[-1] == "FinishedStatus" + + +def test_all_internal_turns_include_node_step_boundaries() -> None: + status, all_internal_turn_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.ALL_INTERNAL_TURNS + ) + step_histories = _snapshot_step_histories(all_internal_turn_events) + + assert isinstance(status, FinishedStatus) + assert ["__StartStep__"] in step_histories + assert ["__StartStep__", "tool_step"] in step_histories + assert ["__StartStep__", "tool_step", "end"] in step_histories + + +def test_all_internal_turns_keep_only_root_turns_resumable() -> None: + status, all_internal_turn_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.ALL_INTERNAL_TURNS + ) + + assert isinstance(status, FinishedStatus) + assert isinstance(all_internal_turn_events[0].state_snapshot.get("conversation_state"), str) + assert isinstance(all_internal_turn_events[-1].state_snapshot.get("conversation_state"), str) + assert all( + "conversation_state" not in snapshot_event.state_snapshot + for snapshot_event in all_internal_turn_events[1:-1] + ) + + +@pytest.mark.anyio +async def test_execute_async_emits_conversation_turn_snapshots() -> None: + status, state_snapshot_events = await execute_with_state_snapshots_async( + _make_output_conversation(), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + + +def test_parallel_conversation_turn_snapshots_stay_on_the_root_conversation() -> None: 
+ conversation = _make_parallel_parent_conversation() + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + assert [snapshot_event.conversation_id for snapshot_event in state_snapshot_events] == [ + conversation.conversation_id, + conversation.conversation_id, + ] + assert [ + snapshot_event.state_snapshot["conversation"]["id"] + for snapshot_event in state_snapshot_events + ] == [conversation.id, conversation.id] + + restored_conversation = restore_conversation_from_snapshot_payload( + state_snapshot_events[-1].state_snapshot + ) + resumed_status = restored_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + assert sorted(message.content for message in restored_conversation.get_messages()) == [ + "left", + "right", + ] + + +def test_nested_conversation_turn_snapshots_stay_on_the_root_conversation() -> None: + conversation = _make_parent_child_conversation() + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 2 + assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + assert [snapshot_event.conversation_id for snapshot_event in state_snapshot_events] == [ + conversation.conversation_id, + conversation.conversation_id, + ] + assert [ + snapshot_event.state_snapshot["conversation"]["id"] + for snapshot_event in state_snapshot_events + ] == [conversation.id, conversation.id] + assert _snapshot_message(state_snapshot_events[-1]) == "parent" + + +def 
test_nested_node_turn_snapshots_can_capture_child_runtime_conversation_ids() -> None: + conversation = _make_parent_child_conversation() + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { + conversation.conversation_id + } + assert isinstance(state_snapshot_events[0].state_snapshot.get("conversation_state"), str) + assert isinstance(state_snapshot_events[-1].state_snapshot.get("conversation_state"), str) + assert all( + "conversation_state" not in snapshot_event.state_snapshot + for snapshot_event in state_snapshot_events[1:-1] + ) + assert any( + snapshot_event.state_snapshot["conversation"]["id"] != conversation.id + for snapshot_event in state_snapshot_events[1:-1] + ) + + +def test_state_snapshot_emits_variable_state_for_successful_flow_execution() -> None: + custom_variable = Variable(name="custom", type=AnyProperty()) + conversation = Flow.from_steps( + [ + VariableWriteStep( + variable=custom_variable, + input_mapping={VariableWriteStep.VALUE: custom_variable.name}, + ), + OutputMessageStep(message_template="done"), + CompleteStep(name="end"), + ], + variables=[custom_variable], + ).start_conversation(inputs={custom_variable.name: "stored-value"}) + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + include_variable_state=True, + ), + ) + + assert isinstance(status, FinishedStatus) + assert state_snapshot_events[0].variable_state == {"custom": None} + assert state_snapshot_events[-1].variable_state == {"custom": "stored-value"} + + +def test_state_snapshot_emission_propagates_extra_state_builder_failures() -> None: + output_conversation = 
_make_output_conversation() + + def broken_builder(_conversation: Conversation) -> dict[str, object]: + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + output_conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=broken_builder, + ) + ) + + assert output_conversation.get_last_message() is None + + +def test_state_snapshot_emission_rejects_non_strict_json_extra_state() -> None: + output_conversation = _make_output_conversation() + + with pytest.raises(TypeError, match="Extra snapshot state .* strict JSON-serializable"): + output_conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=lambda _conversation: {"ui": {"preview_count": float("nan")}}, + ) + ) + + assert output_conversation.get_last_message() is None + + +def test_state_snapshot_emission_propagates_unserializable_variable_state() -> None: + custom_variable = Variable(name="custom", type=AnyProperty()) + conversation = Flow.from_steps( + [ + VariableWriteStep( + variable=custom_variable, + input_mapping={VariableWriteStep.VALUE: custom_variable.name}, + ), + OutputMessageStep(message_template="done"), + CompleteStep(name="end"), + ], + variables=[custom_variable], + ).start_conversation(inputs={custom_variable.name: _SerializableButNotJson(value="x")}) + + with pytest.raises(TypeError, match="Variable 'custom' contains a non-JSON-serializable"): + conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + include_variable_state=True, + ) + ) + + assert conversation.get_last_message() is not None + assert conversation.get_last_message().content == "done" + + +def test_state_snapshot_listener_failures_propagate_to_the_caller() -> None: + output_conversation = _make_output_conversation() + + with 
register_event_listeners([_FailOnTerminalSnapshot()]): + with pytest.raises(RuntimeError, match="snapshot sink failed"): + output_conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) + ) + + assert output_conversation.get_last_message() is not None + assert output_conversation.get_last_message().content == "Hello" diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py b/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py deleted file mode 100644 index a4f8ba902..000000000 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_internal_turns.py +++ /dev/null @@ -1,490 +0,0 @@ -# Copyright © 2026 Oracle and/or its affiliates. -# -# This software is under the Apache License 2.0 -# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License -# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
- -import json - -import pytest - -from wayflowcore.agent import Agent -from wayflowcore.conversation import Conversation -from wayflowcore.events.eventlistener import register_event_listeners -from wayflowcore.executors._events.event import EventType -from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus -from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy -from wayflowcore.flow import Flow -from wayflowcore.property import StringProperty -from wayflowcore.steps import CompleteStep, ToolExecutionStep -from wayflowcore.tools import ServerTool - -from ..conftest import disable_streaming -from ..test_interrupts import OnEventExecutionInterrupt -from ..testhelpers.statesnapshots import ( - SnapshotCollector, - SnapshotSerializableDummyModel, - create_agent_conversation, - create_output_flow_conversation, - create_tool_calling_agent_conversation, - create_tool_flow_conversation, - execute_with_state_snapshots, - execute_with_state_snapshots_async, - snapshot_status_types, - snapshot_step_histories, -) - - -def _make_snapshot_size_stress_conversation() -> Conversation: - repeated_description = "serialized tool description " + ("D" * 1000) - tool_steps = [ - ToolExecutionStep( - tool=ServerTool( - name=f"tool_{index}", - description=f"{repeated_description}-{index}", - func=lambda index=index: f"value-{index}", - input_descriptors=[], - output_descriptors=[StringProperty(name=f"out_{index}")], - ), - name=f"step_{index}", - ) - for index in range(8) - ] - - return Flow.from_steps( - steps=[*tool_steps, CompleteStep(name="end")], - step_names=[*(f"step_{index}" for index in range(8)), "end"], - name="state_snapshot_size_stress_flow", - ).start_conversation() - - -def _snapshot_payload_num_bytes(snapshot_payload: dict[str, object]) -> int: - return len(json.dumps(snapshot_payload, sort_keys=True)) - - 
-@pytest.mark.parametrize( - ( - "conversation_factory", - "expected_status_class", - "expected_status_types", - "expected_snapshot_count", - "expected_curr_iters", - ), - [ - pytest.param( - create_output_flow_conversation, - FinishedStatus, - [None, None, None, None, None, None, None, "FinishedStatus"], - 8, - None, - id="flow", - ), - pytest.param( - create_agent_conversation, - UserMessageRequestStatus, - [None, None, None, "UserMessageRequestStatus"], - 4, - [0, 1], - id="agent", - ), - ], -) -def test_node_turn_policy_tracks_flow_steps_and_agent_iterations( - conversation_factory, - expected_status_class, - expected_status_types: list[str | None], - expected_snapshot_count: int, - expected_curr_iters: list[int] | None, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ), - ) - - assert isinstance(status, expected_status_class) - assert len(state_snapshot_events) == expected_snapshot_count - assert snapshot_status_types(state_snapshot_events) == expected_status_types - if expected_curr_iters is not None: - assert [ - state_snapshot_events[1].state_snapshot["execution"]["curr_iter"], - state_snapshot_events[2].state_snapshot["execution"]["curr_iter"], - ] == expected_curr_iters - - -@pytest.mark.anyio -@pytest.mark.parametrize( - ( - "conversation_factory", - "expected_status_class", - "expected_status_types", - "expected_snapshot_count", - "expected_curr_iters", - ), - [ - pytest.param( - create_output_flow_conversation, - FinishedStatus, - [None, None, None, None, None, None, None, "FinishedStatus"], - 8, - None, - id="flow", - ), - pytest.param( - create_agent_conversation, - UserMessageRequestStatus, - [None, None, None, "UserMessageRequestStatus"], - 4, - [0, 1], - id="agent", - ), - ], -) -async def test_node_turn_policy_tracks_flow_steps_and_agent_iterations_async( - 
conversation_factory, - expected_status_class, - expected_status_types: list[str | None], - expected_snapshot_count: int, - expected_curr_iters: list[int] | None, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = await execute_with_state_snapshots_async( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ), - ) - - assert isinstance(status, expected_status_class) - assert len(state_snapshot_events) == expected_snapshot_count - assert snapshot_status_types(state_snapshot_events) == expected_status_types - if expected_curr_iters is not None: - assert [ - state_snapshot_events[1].state_snapshot["execution"]["curr_iter"], - state_snapshot_events[2].state_snapshot["execution"]["curr_iter"], - ] == expected_curr_iters - - -@pytest.mark.parametrize( - ("conversation_factory", "interrupt_event"), - [ - pytest.param(create_output_flow_conversation, EventType.STEP_EXECUTION_START, id="flow"), - pytest.param(create_agent_conversation, EventType.GENERATION_START, id="agent"), - ], -) -def test_node_turn_policy_keeps_partial_progress_when_interrupted_mid_turn( - conversation_factory, - interrupt_event: EventType, -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - execution_interrupts=[OnEventExecutionInterrupt(interrupt_event)], - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ), - ) - - assert isinstance(status, InterruptedExecutionStatus) - assert snapshot_status_types(state_snapshot_events) == [None, None] - - -def test_flow_node_turn_policy_uses_iteration_start_and_end_boundaries() -> None: - conversation = create_output_flow_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ), 
- ) - - assert isinstance(status, FinishedStatus) - assert snapshot_step_histories(state_snapshot_events) == [ - [], - [], - ["__StartStep__"], - ["__StartStep__"], - ["__StartStep__", "step_0"], - ["__StartStep__", "step_0"], - ["__StartStep__", "step_0", "end"], - ["__StartStep__", "step_0", "end"], - ] - - -def test_node_turn_policy_keeps_only_root_turn_checkpoints_resumable() -> None: - conversation = create_output_flow_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert [ - isinstance(snapshot_event.state_snapshot.get("conversation_state"), str) - for snapshot_event in state_snapshot_events - ] == [True, False, False, False, False, False, False, True] - - -def test_node_turn_policy_stays_lightweight_under_snapshot_size_stress() -> None: - conversation = _make_snapshot_size_stress_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - - snapshot_payloads = [snapshot_event.state_snapshot for snapshot_event in state_snapshot_events] - assert all(payload is not None for payload in snapshot_payloads) - snapshot_payloads = [payload for payload in snapshot_payloads if payload is not None] - internal_snapshot_payloads = snapshot_payloads[1:-1] - assert internal_snapshot_payloads - assert all("conversation_state" not in payload for payload in internal_snapshot_payloads) - - largest_root_conversation_state = max( - ( - payload["conversation_state"] - for payload in snapshot_payloads - if isinstance(payload.get("conversation_state"), str) - ), - key=len, - ) - actual_total_bytes = sum(_snapshot_payload_num_bytes(payload) for payload in snapshot_payloads) - 
inflated_total_bytes = sum( - ( - _snapshot_payload_num_bytes(payload) - if "conversation_state" in payload - else _snapshot_payload_num_bytes( - {**payload, "conversation_state": largest_root_conversation_state} - ) - ) - for payload in snapshot_payloads - ) - - assert actual_total_bytes < inflated_total_bytes * 0.2 - - -def test_internal_snapshots_do_not_reuse_the_previous_turn_status() -> None: - llm = SnapshotSerializableDummyModel() - llm.set_next_output(["Hello from agent", "Hello again"]) - conversation = Agent(llm=llm).start_conversation() - conversation.append_user_message("Hi") - collector = SnapshotCollector() - policy = StateSnapshotPolicy(state_snapshot_interval=StateSnapshotInterval.NODE_TURNS) - - with register_event_listeners([collector]): - first_status = conversation.execute(state_snapshot_policy=policy) - assert isinstance(first_status, UserMessageRequestStatus) - - first_status.submit_user_response("Continue") - second_status = conversation.execute(state_snapshot_policy=policy) - - assert isinstance(second_status, UserMessageRequestStatus) - assert len(collector.state_snapshot_events) == 8 - - second_turn_internal_snapshots = collector.state_snapshot_events[5:7] - assert snapshot_status_types(second_turn_internal_snapshots) == [None, None] - assert all( - snapshot_event.state_snapshot["execution"]["status_handled"] is False - for snapshot_event in second_turn_internal_snapshots - ) - - -@pytest.mark.parametrize( - ( - "conversation_factory", - "execution_interrupts", - "execution_context", - "expected_status_class", - "expected_status_types", - ), - [ - pytest.param( - lambda: create_tool_flow_conversation(lambda: "hi"), - None, - None, - FinishedStatus, - [None, None, None, "FinishedStatus"], - id="flow-success", - ), - pytest.param( - create_tool_calling_agent_conversation, - [OnEventExecutionInterrupt(EventType.TOOL_CALL_END)], - disable_streaming(), - InterruptedExecutionStatus, - [None, None, None], - id="agent-tool-end-interrupt", - ), - 
], -) -def test_tool_turn_policy_records_real_tool_boundaries( - conversation_factory, - execution_interrupts, - execution_context, - expected_status_class, - expected_status_types: list[str | None], -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - execution_interrupts=execution_interrupts, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.TOOL_TURNS - ), - execution_context=execution_context, - ) - - assert isinstance(status, expected_status_class) - assert snapshot_status_types(state_snapshot_events) == expected_status_types - - -def test_tool_turn_policy_keeps_only_root_turn_checkpoints_resumable() -> None: - conversation = create_tool_flow_conversation(lambda: "hi") - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.TOOL_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert [ - isinstance(snapshot_event.state_snapshot.get("conversation_state"), str) - for snapshot_event in state_snapshot_events - ] == [True, False, False, True] - - -@pytest.mark.anyio -@pytest.mark.parametrize( - ( - "conversation_factory", - "execution_interrupts", - "execution_context", - "expected_status_class", - "expected_status_types", - ), - [ - pytest.param( - lambda: create_tool_flow_conversation(lambda: "hi"), - None, - None, - FinishedStatus, - [None, None, None, "FinishedStatus"], - id="flow-success", - ), - pytest.param( - create_tool_calling_agent_conversation, - [OnEventExecutionInterrupt(EventType.TOOL_CALL_END)], - disable_streaming(), - InterruptedExecutionStatus, - [None, None, None], - id="agent-tool-end-interrupt", - ), - ], -) -async def test_tool_turn_policy_records_real_tool_boundaries_async( - conversation_factory, - execution_interrupts, - execution_context, - expected_status_class, - 
expected_status_types: list[str | None], -) -> None: - conversation = conversation_factory() - - status, state_snapshot_events = await execute_with_state_snapshots_async( - conversation, - execution_interrupts=execution_interrupts, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.TOOL_TURNS - ), - execution_context=execution_context, - ) - - assert isinstance(status, expected_status_class) - assert snapshot_status_types(state_snapshot_events) == expected_status_types - - -def test_all_internal_turn_policy_combines_node_and_tool_boundaries() -> None: - conversation = create_tool_flow_conversation(lambda: "hi") - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.ALL_INTERNAL_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 10 - assert snapshot_status_types(state_snapshot_events) == [None] * 9 + ["FinishedStatus"] - - -@pytest.mark.anyio -async def test_all_internal_turn_policy_combines_node_and_tool_boundaries_async() -> None: - conversation = create_tool_flow_conversation(lambda: "hi") - - status, state_snapshot_events = await execute_with_state_snapshots_async( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.ALL_INTERNAL_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 10 - assert snapshot_status_types(state_snapshot_events) == [None] * 9 + ["FinishedStatus"] - - -@pytest.mark.parametrize( - ("interval", "expected_status_types"), - [ - pytest.param( - StateSnapshotInterval.CONVERSATION_TURNS, - [None, "FinishedStatus"], - id="conversation_turns", - ), - pytest.param( - StateSnapshotInterval.TOOL_TURNS, - [None, None, None, "FinishedStatus"], - id="tool_turns", - ), - pytest.param( - StateSnapshotInterval.NODE_TURNS, - [None, None, None, 
None, None, None, None, "FinishedStatus"], - id="node_turns", - ), - pytest.param( - StateSnapshotInterval.ALL_INTERNAL_TURNS, - [None, None, None, None, None, None, None, None, None, "FinishedStatus"], - id="all_internal_turns", - ), - ], -) -def test_snapshot_interval_policies_include_conversation_turns_cumulatively( - interval: StateSnapshotInterval, - expected_status_types: list[str | None], -) -> None: - conversation = create_tool_flow_conversation(lambda: "hi") - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy(state_snapshot_interval=interval), - ) - - assert isinstance(status, FinishedStatus) - assert snapshot_status_types(state_snapshot_events) == expected_status_types diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py b/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py deleted file mode 100644 index 657aafada..000000000 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_nested.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright © 2026 Oracle and/or its affiliates. -# -# This software is under the Apache License 2.0 -# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License -# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
- -import pytest - -from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy -from wayflowcore.flow import Flow -from wayflowcore.steps import ( - CompleteStep, - FlowExecutionStep, - OutputMessageStep, - ParallelFlowExecutionStep, - ParallelMapStep, -) - -from ..testhelpers.statesnapshots import ( - create_nested_agent_step_flow_conversation, - create_parallel_child_flow, - execute_with_state_snapshots, - execute_with_state_snapshots_async, - restore_conversation_from_snapshot_payload, - snapshot_message, - snapshot_runtime_conversation_ids, - snapshot_status_types, -) - - -def test_state_snapshot_policy_is_inherited_by_nested_sub_conversations() -> None: - child_flow = Flow.from_steps( - [ - OutputMessageStep(message_template="child"), - CompleteStep(name="end"), - ] - ) - parent_flow = Flow.from_steps( - [ - FlowExecutionStep(flow=child_flow), - CompleteStep(name="end"), - ] - ) - conversation = parent_flow.start_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 2 - assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { - conversation.conversation_id - } - assert snapshot_runtime_conversation_ids(state_snapshot_events) == [ - conversation.id, - conversation.id, - ] - - -@pytest.mark.anyio -async def test_state_snapshot_policy_is_inherited_by_nested_sub_conversations_async() -> None: - child_flow = Flow.from_steps( - [ - OutputMessageStep(message_template="child"), - CompleteStep(name="end"), - ] - ) - parent_flow = Flow.from_steps( - [ - FlowExecutionStep(flow=child_flow), - CompleteStep(name="end"), - ] - ) - conversation = 
parent_flow.start_conversation() - - status, state_snapshot_events = await execute_with_state_snapshots_async( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 2 - assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { - conversation.conversation_id - } - assert snapshot_runtime_conversation_ids(state_snapshot_events) == [ - conversation.id, - conversation.id, - ] - - -def test_nested_root_turn_snapshot_payload_can_resume_the_logical_parent_conversation() -> None: - child_flow = Flow.from_steps( - [ - OutputMessageStep(message_template="child"), - CompleteStep(name="child_end"), - ], - name="child_flow", - ) - parent_flow = Flow.from_steps( - [ - FlowExecutionStep(flow=child_flow), - OutputMessageStep(message_template="parent"), - CompleteStep(name="parent_end"), - ], - name="parent_flow", - ) - conversation = parent_flow.start_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - root_turn_snapshot_event = state_snapshot_events[-1] - root_turn_snapshot = root_turn_snapshot_event.state_snapshot - assert root_turn_snapshot is not None - assert root_turn_snapshot_event.conversation_id == conversation.conversation_id - assert root_turn_snapshot["conversation"]["id"] == conversation.id - - restored_conversation = restore_conversation_from_snapshot_payload(root_turn_snapshot) - assert restored_conversation.id == conversation.id - assert restored_conversation.conversation_id == conversation.conversation_id - - resumed_status = restored_conversation.execute() - - assert isinstance(resumed_status, FinishedStatus) - assert [message.content for message in 
restored_conversation.get_messages()] == [ - "child", - "parent", - ] - - -def test_state_snapshot_policy_is_inherited_by_parallel_sub_conversations() -> None: - conversation = Flow.from_steps( - [ - ParallelFlowExecutionStep( - flows=[ - create_parallel_child_flow("left_output", "left"), - create_parallel_child_flow("right_output", "right"), - ] - ), - CompleteStep(name="end"), - ] - ).start_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert len(state_snapshot_events) == 2 - assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { - conversation.conversation_id - } - assert snapshot_runtime_conversation_ids(state_snapshot_events) == [ - conversation.id, - conversation.id, - ] - assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] - - -def test_parallel_root_turn_snapshot_payloads_can_resume_the_logical_parent_conversation() -> None: - left_child_flow = Flow.from_steps( - [ - OutputMessageStep( - message_template="left", - output_mapping={OutputMessageStep.OUTPUT: "left_message"}, - ), - CompleteStep(name="left_end"), - ], - name="left_child_flow", - ) - right_child_flow = Flow.from_steps( - [ - OutputMessageStep( - message_template="right", - output_mapping={OutputMessageStep.OUTPUT: "right_message"}, - ), - CompleteStep(name="right_end"), - ], - name="right_child_flow", - ) - conversation = Flow.from_steps( - [ - ParallelFlowExecutionStep( - flows=[ - left_child_flow, - right_child_flow, - ] - ), - CompleteStep(name="end"), - ] - ).start_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, 
FinishedStatus) - - for snapshot_event in state_snapshot_events: - snapshot_payload = snapshot_event.state_snapshot - assert snapshot_payload is not None - - restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) - resumed_status = restored_conversation.execute() - - assert isinstance(resumed_status, FinishedStatus) - assert sorted(message.content for message in restored_conversation.get_messages()) == [ - "left", - "right", - ] - - -def test_parallel_map_emits_only_resumable_parent_turn_snapshots() -> None: - child_flow = Flow.from_steps( - [ - OutputMessageStep(message_template="item={{item}}"), - CompleteStep(name="child_end"), - ], - name="parallel_map_child", - ) - conversation = Flow.from_steps( - [ - ParallelMapStep( - flow=child_flow, - unpack_input={"item": "."}, - name="parallel_map", - ), - CompleteStep(name="end"), - ], - name="parallel_map_parent", - ).start_conversation(inputs={ParallelMapStep.ITERATED_INPUT: ["a", "b"]}) - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] - - for snapshot_event in state_snapshot_events: - snapshot_payload = snapshot_event.state_snapshot - assert snapshot_payload is not None - - restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) - resumed_status = restored_conversation.execute() - - assert isinstance(resumed_status, FinishedStatus) - assert sorted(message.content for message in restored_conversation.get_messages()) == [ - "item=a", - "item=b", - ] - - -def test_nested_node_turn_snapshots_keep_child_runtime_conversation_identity() -> None: - child_flow = Flow.from_steps( - [ - OutputMessageStep(message_template="child"), - CompleteStep(name="child_end"), - ], - name="child_flow", - 
) - parent_flow = Flow.from_steps( - [ - FlowExecutionStep(flow=child_flow), - OutputMessageStep(message_template="parent"), - CompleteStep(name="parent_end"), - ], - name="parent_flow", - ) - conversation = parent_flow.start_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.NODE_TURNS - ), - ) - - assert isinstance(status, FinishedStatus) - assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { - conversation.conversation_id - } - assert any( - snapshot_event.state_snapshot["conversation"]["id"] != conversation.id - for snapshot_event in state_snapshot_events - ) - - -def test_state_snapshot_policy_is_inherited_by_nested_agent_steps() -> None: - conversation = create_nested_agent_step_flow_conversation() - - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) - - assert isinstance(status, UserMessageRequestStatus) - assert snapshot_status_types(state_snapshot_events) == [None, "UserMessageRequestStatus"] - assert [snapshot_event.conversation_id for snapshot_event in state_snapshot_events] == [ - conversation.conversation_id, - conversation.conversation_id, - ] - assert snapshot_runtime_conversation_ids(state_snapshot_events) == [ - conversation.id, - conversation.id, - ] - assert snapshot_message(state_snapshot_events[-1]) == "agent answer" diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py b/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py deleted file mode 100644 index 1a1e5b19c..000000000 --- a/wayflowcore/tests/events/test_state_snapshot_runtime_resilience.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright © 2026 Oracle and/or its affiliates. 
-# -# This software is under the Apache License 2.0 -# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License -# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. - -from dataclasses import dataclass - -import pytest - -from wayflowcore.conversation import Conversation -from wayflowcore.events import Event, EventListener -from wayflowcore.events.event import StateSnapshotEvent -from wayflowcore.events.eventlistener import register_event_listeners -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy -from wayflowcore.flow import Flow -from wayflowcore.property import AnyProperty -from wayflowcore.serialization.serializer import FrozenSerializableDataclass -from wayflowcore.steps import CompleteStep, OutputMessageStep, VariableWriteStep -from wayflowcore.variable import Variable - -from ..testhelpers.statesnapshots import create_output_flow_conversation - - -@dataclass(frozen=True) -class _SerializableButNotJson(FrozenSerializableDataclass): - value: str - - -def _create_non_json_variable_state_conversation() -> Conversation: - custom_variable = Variable(name="custom", type=AnyProperty()) - return Flow.from_steps( - [ - VariableWriteStep( - variable=custom_variable, - input_mapping={VariableWriteStep.VALUE: custom_variable.name}, - ), - OutputMessageStep(message_template="done"), - CompleteStep(name="end"), - ], - variables=[custom_variable], - ).start_conversation(inputs={custom_variable.name: _SerializableButNotJson(value="x")}) - - -def test_state_snapshot_emission_propagates_extra_state_builder_failures() -> None: - def broken_builder(_conversation: Conversation) -> dict[str, object]: - raise RuntimeError("boom") - - conversation = create_output_flow_conversation() - - with pytest.raises(RuntimeError, match="boom"): - conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, 
- extra_state_builder=broken_builder, - ) - ) - - assert conversation.get_last_message() is None - - -def test_state_snapshot_emission_rejects_non_strict_json_extra_state() -> None: - conversation = create_output_flow_conversation() - - with pytest.raises(TypeError, match="Extra snapshot state .* strict JSON-serializable"): - conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - extra_state_builder=lambda _conversation: {"ui": {"preview_count": float("nan")}}, - ) - ) - - assert conversation.get_last_message() is None - - -def test_state_snapshot_emission_propagates_unserializable_variable_state() -> None: - conversation = _create_non_json_variable_state_conversation() - - with pytest.raises(TypeError, match="Variable 'custom' contains a non-JSON-serializable"): - conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, - include_variable_state=True, - ) - ) - - assert conversation.get_last_message() is not None - assert conversation.get_last_message().content == "done" - - -class _FailOnTerminalSnapshot(EventListener): - def __call__(self, event: Event) -> None: - if not isinstance(event, StateSnapshotEvent): - return - - execution_status = (event.state_snapshot or {}).get("execution", {}).get("status") - if execution_status is not None: - raise RuntimeError("snapshot sink failed") - - -def test_state_snapshot_listener_failures_propagate_to_the_caller() -> None: - conversation = create_output_flow_conversation() - - with register_event_listeners([_FailOnTerminalSnapshot()]): - with pytest.raises(RuntimeError, match="snapshot sink failed"): - conversation.execute( - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ) - ) - - assert conversation.get_last_message() is not None - assert conversation.get_last_message().content == "Hello" diff --git 
a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py index 5fb86666d..1004ab250 100644 --- a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py +++ b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py @@ -12,7 +12,6 @@ from wayflowcore.controlconnection import ControlFlowEdge from wayflowcore.conversation import Conversation from wayflowcore.dataconnection import DataFlowEdge -from wayflowcore.executors._flowconversation import FlowConversation from wayflowcore.executors.executionstatus import ( FinishedStatus, ToolRequestStatus, @@ -22,7 +21,6 @@ from wayflowcore.flow import Flow from wayflowcore.property import AnyProperty, StringProperty from wayflowcore.serialization import ( - deserialize_conversation, deserialize_conversation_state, dump_conversation_state, dump_variable_state, @@ -41,7 +39,7 @@ from wayflowcore.tools import ClientTool, ServerTool, ToolResult, register_server_tool from wayflowcore.variable import Variable -from ..testhelpers.statesnapshots import ( +from ..testhelpers.state_snapshot_testutils import ( execute_with_state_snapshots, restore_conversation_from_snapshot_payload, ) @@ -54,7 +52,7 @@ def __str__(self) -> str: def _build_snapshot_flow(custom_variable: Variable) -> Flow: return Flow.from_steps( - steps=[ + [ VariableWriteStep( variable=custom_variable, input_mapping={VariableWriteStep.VALUE: custom_variable.name}, @@ -75,7 +73,6 @@ def _build_non_finite_input_snapshot_flow() -> Flow: description="Echo input", func=lambda bad: str(bad), input_descriptors=[AnyProperty(name="bad")], - output_descriptors=[StringProperty(name="out")], ) ), CompleteStep(name="end"), @@ -84,6 +81,44 @@ def _build_non_finite_input_snapshot_flow() -> Flow: ) +def _make_snapshot_flow_conversation( + *, + variable_type: StringProperty | AnyProperty, + input_value: Any, +) -> tuple[Variable, Conversation]: + custom_variable = Variable( + 
name="custom", + type=variable_type, + description="Custom variable used for snapshot serialization tests", + ) + conversation = _build_snapshot_flow(custom_variable).start_conversation( + inputs={custom_variable.name: input_value} + ) + conversation.execute() + return custom_variable, conversation + + +def _build_user_input_resume_flow() -> Flow: + return Flow.from_steps( + [InputMessageStep("Please answer"), OutputMessageStep("done")], + name="resume_flow", + ) + + +def _conversation_turn_snapshot_payload( + conversation: Conversation, +) -> tuple[object, dict[str, Any]]: + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert state_snapshot_events[-1].state_snapshot is not None + return status, state_snapshot_events[-1].state_snapshot + + def _walk_scalars(value: Any): if isinstance(value, dict): for inner_value in value.values(): @@ -96,99 +131,58 @@ def _walk_scalars(value: Any): yield value -def test_dump_conversation_state_is_json_serializable_and_lightweight() -> None: - custom_variable = Variable( - name="custom", - type=StringProperty(), - description="Custom variable used for snapshot serialization tests", +def test_dump_conversation_state_is_strict_json_serializable_and_lightweight() -> None: + _, conversation = _make_snapshot_flow_conversation( + variable_type=StringProperty(), + input_value="custom-value", ) - flow = _build_snapshot_flow(custom_variable) - conversation = flow.start_conversation(inputs={custom_variable.name: "custom-value"}) - conversation.execute() snapshot = dump_conversation_state(conversation) - variable_state = dump_variable_state(conversation) - serialized_conversation_state = serialize_conversation_state(conversation) - deserialized_conversation_state = deserialize_conversation_state(serialized_conversation_state) assert json.loads(json.dumps(snapshot, allow_nan=False)) == 
snapshot - assert deserialized_conversation_state["_component_type"] == conversation.__class__.__name__ - assert variable_state == {"custom": "custom-value"} + assert dump_variable_state(conversation) == {"custom": "custom-value"} assert snapshot["conversation"]["component_type"] == "Flow" assert snapshot["conversation"]["messages"][-1]["content"] == "Hello there" - assert all( not isinstance(scalar, (Conversation, Flow, OutputMessageStep)) for scalar in _walk_scalars(snapshot) ) -def test_dump_conversation_state_overrides_execution_fields_without_mutating_conversation() -> None: - custom_variable = Variable( - name="custom", - type=StringProperty(), - description="Custom variable used for snapshot serialization tests", - ) - flow = _build_snapshot_flow(custom_variable) - conversation = flow.start_conversation(inputs={custom_variable.name: "custom-value"}) - conversation.execute() - - previous_status = conversation.status - previous_status_handled = conversation.status_handled - - snapshot = dump_conversation_state( - conversation, - status=None, - status_handled=True, +def test_dump_conversation_state_includes_runtime_conversation_ids() -> None: + _, conversation = _make_snapshot_flow_conversation( + variable_type=StringProperty(), + input_value="custom-value", ) - assert snapshot["execution"]["status"] is None - assert snapshot["execution"]["status_handled"] is True - assert conversation.status is previous_status - assert conversation.status_handled is previous_status_handled - - -def test_dump_conversation_state_includes_runtime_conversation_id() -> None: - custom_variable = Variable( - name="custom", - type=StringProperty(), - description="Custom variable used for snapshot serialization tests", - ) - flow = _build_snapshot_flow(custom_variable) - conversation = flow.start_conversation(inputs={custom_variable.name: "custom-value"}) - conversation.execute() - snapshot = dump_conversation_state(conversation) assert snapshot["conversation"]["id"] == conversation.id 
assert snapshot["conversation"]["conversation_id"] == conversation.conversation_id -def test_dump_conversation_state_does_not_overload_status_conversation_identity() -> None: - custom_variable = Variable( - name="custom", - type=StringProperty(), - description="Custom variable used for snapshot serialization tests", +def test_dump_conversation_state_status_overrides_do_not_mutate_live_conversation() -> None: + _, conversation = _make_snapshot_flow_conversation( + variable_type=StringProperty(), + input_value="custom-value", ) - flow = _build_snapshot_flow(custom_variable) - conversation = flow.start_conversation(inputs={custom_variable.name: "custom-value"}) - conversation.execute() - snapshot = dump_conversation_state(conversation) + previous_status = conversation.status + previous_status_handled = conversation.status_handled + + snapshot = dump_conversation_state(conversation, status=None, status_handled=True) - assert snapshot["execution"]["status"]["type"] == "FinishedStatus" - assert "conversation_id" not in snapshot["execution"]["status"] + assert snapshot["execution"]["status"] is None + assert snapshot["execution"]["status_handled"] is True + assert conversation.status is previous_status + assert conversation.status_handled is previous_status_handled def test_dump_variable_state_rejects_non_json_serializable_values() -> None: - custom_variable = Variable( - name="custom", - type=AnyProperty(), - description="Custom variable used for snapshot serialization tests", + _, conversation = _make_snapshot_flow_conversation( + variable_type=AnyProperty(), + input_value=_UnserializableValue(), ) - flow = _build_snapshot_flow(custom_variable) - conversation = flow.start_conversation(inputs={custom_variable.name: _UnserializableValue()}) - conversation.execute() with pytest.raises(TypeError, match="Variable 'custom' contains a non-JSON-serializable"): dump_variable_state(conversation) @@ -206,8 +200,7 @@ def 
test_dump_conversation_state_normalizes_non_finite_floats_for_strict_json( value: float, expected_dumped_value: str, ) -> None: - flow = _build_non_finite_input_snapshot_flow() - conversation = flow.start_conversation(inputs={"bad": value}) + conversation = _build_non_finite_input_snapshot_flow().start_conversation(inputs={"bad": value}) snapshot = dump_conversation_state(conversation) @@ -215,20 +208,19 @@ def test_dump_conversation_state_normalizes_non_finite_floats_for_strict_json( assert snapshot["conversation"]["inputs"]["bad"] == expected_dumped_value -def test_conversation_state_roundtrip_preserves_pending_tool_results() -> None: +def test_serialized_conversation_roundtrip_preserves_pending_tool_results() -> None: client_tool = ClientTool( name="client_lookup", description="Look up some data on the client side", parameters={}, ) - flow = Flow.from_steps( + conversation = Flow.from_steps( [ ToolExecutionStep(tool=client_tool), CompleteStep(name="end"), ], name="tool_resume_flow", - ) - conversation = flow.start_conversation() + ).start_conversation() status = conversation.execute() assert isinstance(status, ToolRequestStatus) @@ -238,16 +230,6 @@ def test_conversation_state_roundtrip_preserves_pending_tool_results() -> None: ToolResult(tool_request_id=tool_request.tool_request_id, content="client-result") ) - snapshot = dump_conversation_state(conversation) - assert snapshot["execution"]["status"]["type"] == "ToolRequestStatus" - assert snapshot["execution"]["status"]["tool_results"] == [ - { - "tool_request_id": tool_request.tool_request_id, - "content": "client-result", - } - ] - assert all(message.tool_result is None for message in conversation.get_messages()) - loaded_conversation = load_conversation_state( deserialize_conversation_state(serialize_conversation_state(conversation)) ) @@ -261,8 +243,8 @@ def test_conversation_state_roundtrip_preserves_pending_tool_results() -> None: ] resumed_status = loaded_conversation.execute() - assert 
isinstance(resumed_status, FinishedStatus) + assert isinstance(resumed_status, FinishedStatus) tool_result_messages = [ message.tool_result for message in loaded_conversation.get_messages() if message.tool_result ] @@ -271,84 +253,66 @@ def test_conversation_state_roundtrip_preserves_pending_tool_results() -> None: assert tool_result_messages[0].content == "client-result" -def test_load_conversation_state_restores_a_runnable_conversation() -> None: - flow = Flow.from_steps( - [InputMessageStep("Please answer"), OutputMessageStep("done")], - name="resume_flow", +def test_emitted_snapshot_payload_restores_waiting_for_user_input() -> None: + status, snapshot_payload = _conversation_turn_snapshot_payload( + _build_user_input_resume_flow().start_conversation() ) - conversation = flow.start_conversation() - status = conversation.execute() assert isinstance(status, UserMessageRequestStatus) - loaded_conversation = load_conversation_state( - deserialize_conversation_state(serialize_conversation_state(conversation)) - ) - - assert isinstance(loaded_conversation, FlowConversation) - loaded_conversation.append_user_message("hello") - resumed_status = loaded_conversation.execute() - - assert isinstance(resumed_status, FinishedStatus) - assert [message.content for message in loaded_conversation.get_messages()] == [ - "Please answer", - "hello", - "done", - ] - - -def test_deserialize_conversation_restores_a_runnable_conversation() -> None: - flow = Flow.from_steps( - [InputMessageStep("Please answer"), OutputMessageStep("done")], - name="resume_flow", + restored_conversation = restore_conversation_from_snapshot_payload( + json.loads(json.dumps(snapshot_payload, allow_nan=False)) ) - conversation = flow.start_conversation() - - status = conversation.execute() - assert isinstance(status, UserMessageRequestStatus) - - deserialized_conversation = deserialize_conversation(serialize_conversation_state(conversation)) - - assert isinstance(deserialized_conversation, FlowConversation) - 
deserialized_conversation.append_user_message("hello") - resumed_status = deserialized_conversation.execute() + restored_conversation.append_user_message("hello") + resumed_status = restored_conversation.execute() assert isinstance(resumed_status, FinishedStatus) - assert [message.content for message in deserialized_conversation.get_messages()] == [ + assert [message.content for message in restored_conversation.get_messages()] == [ "Please answer", "hello", "done", ] -def test_load_conversation_state_uses_the_given_deserialization_context() -> None: - tool = ServerTool( - name="say_hi", - description="Say hi", - func=lambda: "hi", - input_descriptors=[], +def test_emitted_snapshot_payload_restores_waiting_for_client_tool_result() -> None: + client_tool = ClientTool( + name="client_lookup", + description="Look up some data on the client side", + parameters={}, ) - flow = Flow.from_steps( + conversation = Flow.from_steps( [ - ToolExecutionStep(tool=tool), + ToolExecutionStep(tool=client_tool), CompleteStep(name="end"), ], - name="tool_flow", - ) + name="snapshot_client_tool_resume_flow", + ).start_conversation() - deserialization_context = DeserializationContext() - register_server_tool(tool, deserialization_context.registered_tools) + status, snapshot_payload = _conversation_turn_snapshot_payload(conversation) - conversation = load_conversation_state( - deserialize_conversation_state(serialize_conversation_state(flow.start_conversation())), - deserialization_context=deserialization_context, + assert isinstance(status, ToolRequestStatus) + + restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) + assert isinstance(restored_conversation.status, ToolRequestStatus) + + tool_request = restored_conversation.status.tool_requests[0] + restored_conversation.append_tool_result( + ToolResult(tool_request_id=tool_request.tool_request_id, content="client-result") ) + resumed_status = restored_conversation.execute() - assert isinstance(conversation, 
FlowConversation) - assert isinstance(conversation.execute(), FinishedStatus) + assert isinstance(resumed_status, FinishedStatus) + tool_result_messages = [ + message.tool_result + for message in restored_conversation.get_messages() + if message.tool_result + ] + assert len(tool_result_messages) == 1 + assert tool_result_messages[0].tool_request_id == tool_request.tool_request_id + assert tool_result_messages[0].content == "client-result" -def test_emitted_snapshot_conversation_state_restores_variable_dependent_continuation() -> None: +def test_emitted_snapshot_payload_restores_variable_dependent_continuation() -> None: customer_name = Variable( name="customer_name", type=StringProperty(), @@ -396,18 +360,11 @@ def test_emitted_snapshot_conversation_state_restores_variable_dependent_continu ) conversation = flow.start_conversation(inputs={customer_name.name: "Alice"}) - status, state_snapshot_events = execute_with_state_snapshots( - conversation, - state_snapshot_policy=StateSnapshotPolicy( - state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS - ), - ) + status, snapshot_payload = _conversation_turn_snapshot_payload(conversation) assert isinstance(status, UserMessageRequestStatus) - assert state_snapshot_events[-1].state_snapshot is not None - snapshot_payload = state_snapshot_events[-1].state_snapshot - restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) + restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) assert dump_variable_state(restored_conversation) == {"customer_name": "Alice"} restored_conversation.append_user_message("Need pricing") @@ -419,3 +376,29 @@ def test_emitted_snapshot_conversation_state_restores_variable_dependent_continu "Need pricing", "Stored Alice. 
Reply: Need pricing", ] + + +def test_load_conversation_state_uses_the_given_deserialization_context() -> None: + tool = ServerTool( + name="say_hi", + description="Say hi", + func=lambda: "hi", + input_descriptors=[], + ) + flow = Flow.from_steps( + [ + ToolExecutionStep(tool=tool), + CompleteStep(name="end"), + ], + name="tool_flow", + ) + + deserialization_context = DeserializationContext() + register_server_tool(tool, deserialization_context.registered_tools) + + conversation = load_conversation_state( + deserialize_conversation_state(serialize_conversation_state(flow.start_conversation())), + deserialization_context=deserialization_context, + ) + + assert isinstance(conversation.execute(), FinishedStatus) diff --git a/wayflowcore/tests/testhelpers/agentspec_tracing.py b/wayflowcore/tests/testhelpers/agentspec_tracing.py deleted file mode 100644 index 7ecd00e0f..000000000 --- a/wayflowcore/tests/testhelpers/agentspec_tracing.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright © 2026 Oracle and/or its affiliates. -# -# This software is under the Apache License 2.0 -# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License -# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
- -from contextlib import AbstractContextManager, ExitStack -from typing import Any, Sequence - -from pyagentspec.tracing.events import Event as AgentSpecEvent -from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor -from pyagentspec.tracing.spans import Span as AgentSpecSpan -from pyagentspec.tracing.trace import Trace as AgentSpecTrace - -from wayflowcore.agentspec.tracing import AgentSpecEventListener -from wayflowcore.events.eventlistener import register_event_listeners - - -class SnapshotSpanRecorder(AgentSpecSpanProcessor): - def __init__(self) -> None: - super().__init__() - self.started_spans: list[AgentSpecSpan] = [] - self.ended_spans: list[AgentSpecSpan] = [] - - def on_start(self, span: AgentSpecSpan) -> None: - self.started_spans.append(span) - - async def on_start_async(self, span: AgentSpecSpan) -> None: - self.started_spans.append(span) - - def on_end(self, span: AgentSpecSpan) -> None: - self.ended_spans.append(span) - - async def on_end_async(self, span: AgentSpecSpan) -> None: - self.ended_spans.append(span) - - def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - def startup(self) -> None: - return None - - def shutdown(self) -> None: - return None - - async def startup_async(self) -> None: - return None - - async def shutdown_async(self) -> None: - return None - - -class SnapshotEventsSeenAtSpanEndRecorder(AgentSpecSpanProcessor): - def __init__(self) -> None: - super().__init__() - self.events_by_span_id: dict[str, list[AgentSpecEvent]] = {} - - def on_start(self, span: AgentSpecSpan) -> None: - return None - - async def on_start_async(self, span: AgentSpecSpan) -> None: - return None - - def on_end(self, span: AgentSpecSpan) -> None: - self.events_by_span_id[span.id] = list(span.events) - - async def on_end_async(self, span: AgentSpecSpan) -> None: - 
self.events_by_span_id[span.id] = list(span.events) - - def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: - return None - - def startup(self) -> None: - return None - - def shutdown(self) -> None: - return None - - async def startup_async(self) -> None: - return None - - async def shutdown_async(self) -> None: - return None - - -def execute_with_trace( - conversation, - *, - state_snapshot_policy, - span_processors: Sequence[AgentSpecSpanProcessor] = (), - contexts: Sequence[AbstractContextManager[Any]] = (), -): - span_recorder = SnapshotSpanRecorder() - listener = AgentSpecEventListener() - - with ExitStack() as stack: - for context in contexts: - stack.enter_context(context) - stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, *span_processors])) - stack.enter_context(register_event_listeners([listener])) - status = conversation.execute(state_snapshot_policy=state_snapshot_policy) - - return status, span_recorder - - -def spans( - span_recorder: SnapshotSpanRecorder, - span_type: type[AgentSpecSpan], -) -> list[AgentSpecSpan]: - return [span for span in span_recorder.started_spans if isinstance(span, span_type)] - - -def single_span( - span_recorder: SnapshotSpanRecorder, - span_type: type[AgentSpecSpan], -) -> AgentSpecSpan: - matching_spans = spans(span_recorder, span_type) - assert len(matching_spans) == 1 - return matching_spans[0] - - -def events( - span: AgentSpecSpan, - event_type: type[AgentSpecEvent], -) -> list[AgentSpecEvent]: - return [event for event in span.events if isinstance(event, event_type)] diff --git a/wayflowcore/tests/testhelpers/state_snapshot_testutils.py b/wayflowcore/tests/testhelpers/state_snapshot_testutils.py new file mode 100644 index 000000000..3b4ff7664 --- /dev/null +++ b/wayflowcore/tests/testhelpers/state_snapshot_testutils.py @@ -0,0 +1,84 @@ +# Copyright © 2026 Oracle and/or its 
affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from contextlib import AbstractContextManager, nullcontext +from typing import Any, Sequence + +from wayflowcore.conversation import Conversation +from wayflowcore.events.event import Event, StateSnapshotEvent +from wayflowcore.events.eventlistener import EventListener, register_event_listeners +from wayflowcore.executors.interrupts.executioninterrupt import ExecutionInterrupt +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotPolicy +from wayflowcore.serialization import deserialize_conversation, dump_conversation_state + + +class SnapshotCollector(EventListener): + def __init__(self) -> None: + self.state_snapshot_events: list[StateSnapshotEvent] = [] + + def __call__(self, event: Event) -> None: + if isinstance(event, StateSnapshotEvent): + self.state_snapshot_events.append(event) + + +def snapshot_status_types(snapshot_events: Sequence[StateSnapshotEvent]) -> list[str | None]: + return [ + status["type"] if (status := snapshot_event.state_snapshot["execution"]["status"]) else None + for snapshot_event in snapshot_events + ] + + +def execute_with_state_snapshots( + conversation: Conversation, + *, + state_snapshot_policy: StateSnapshotPolicy, + execution_interrupts: Sequence[ExecutionInterrupt] | None = None, + execution_context: AbstractContextManager[Any] | None = None, +) -> tuple[object, list[StateSnapshotEvent]]: + collector = SnapshotCollector() + + with execution_context or nullcontext(): + with register_event_listeners([collector]): + status = conversation.execute( + execution_interrupts=execution_interrupts, + state_snapshot_policy=state_snapshot_policy, + ) + + return status, collector.state_snapshot_events + + +async def execute_with_state_snapshots_async( + conversation: Conversation, + *, 
+ state_snapshot_policy: StateSnapshotPolicy, + execution_interrupts: Sequence[ExecutionInterrupt] | None = None, + execution_context: AbstractContextManager[Any] | None = None, +) -> tuple[object, list[StateSnapshotEvent]]: + collector = SnapshotCollector() + + with execution_context or nullcontext(): + with register_event_listeners([collector]): + status = await conversation.execute_async( + execution_interrupts=execution_interrupts, + state_snapshot_policy=state_snapshot_policy, + ) + + return status, collector.state_snapshot_events + + +def restore_conversation_from_snapshot_payload( + snapshot_payload: dict[str, Any], +) -> Conversation: + assert snapshot_payload["runtime"] == "wayflow" + assert snapshot_payload["schema_version"] == 1 + assert isinstance(snapshot_payload["conversation_state"], str) + + restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) + assert dump_conversation_state(restored_conversation) == { + "conversation": snapshot_payload["conversation"], + "execution": snapshot_payload["execution"], + } + return restored_conversation diff --git a/wayflowcore/tests/testhelpers/statesnapshots.py b/wayflowcore/tests/testhelpers/statesnapshots.py deleted file mode 100644 index 3ae28040a..000000000 --- a/wayflowcore/tests/testhelpers/statesnapshots.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright © 2026 Oracle and/or its affiliates. -# -# This software is under the Apache License 2.0 -# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License -# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
- -import threading -from contextlib import AbstractContextManager, nullcontext -from typing import Any, Callable, Sequence - -from wayflowcore.agent import Agent -from wayflowcore.conversation import Conversation -from wayflowcore.events.event import Event, StateSnapshotEvent -from wayflowcore.events.eventlistener import EventListener, register_event_listeners -from wayflowcore.executors._executionstate import ConversationExecutionState -from wayflowcore.executors.interrupts.executioninterrupt import ( - ExecutionInterrupt, - InterruptedExecutionStatus, - _NullExecutionInterrupt, -) -from wayflowcore.executors.statesnapshotpolicy import StateSnapshotPolicy -from wayflowcore.flow import Flow -from wayflowcore.messagelist import Message, MessageType -from wayflowcore.property import AnyProperty, StringProperty -from wayflowcore.serialization import deserialize_conversation, dump_conversation_state -from wayflowcore.serialization.serializer import SerializableNeedToBeImplementedMixin -from wayflowcore.steps import ( - AgentExecutionStep, - CompleteStep, - OutputMessageStep, - ToolExecutionStep, - VariableWriteStep, -) -from wayflowcore.tools import ServerTool, ToolRequest, tool -from wayflowcore.variable import Variable - -from .dummy import DummyModel - - -class SnapshotCollector(EventListener): - def __init__(self) -> None: - self.state_snapshot_events: list[StateSnapshotEvent] = [] - - def __call__(self, event: Event) -> None: - if isinstance(event, StateSnapshotEvent): - self.state_snapshot_events.append(event) - - -class SnapshotSerializableDummyModel(DummyModel): - @property - def config(self) -> dict[str, Any]: - return {"model_id": self.model_id} - - -def snapshot_status_types(snapshot_events: Sequence[Any]) -> list[str | None]: - return [ - status["type"] if (status := snapshot_event.state_snapshot["execution"]["status"]) else None - for snapshot_event in snapshot_events - ] - - -def snapshot_message(snapshot_event: Any) -> str | None: - messages = 
snapshot_event.state_snapshot["conversation"]["messages"] - if not messages: - return None - return messages[-1].get("content") - - -def snapshot_runtime_conversation_ids(snapshot_events: Sequence[Any]) -> list[str]: - return [ - snapshot_event.state_snapshot["conversation"]["id"] for snapshot_event in snapshot_events - ] - - -def snapshot_step_histories(snapshot_events: Sequence[Any]) -> list[list[str]]: - return [ - snapshot_event.state_snapshot["execution"]["step_history"] - for snapshot_event in snapshot_events - ] - - -def execute_with_state_snapshots( - conversation: Conversation, - *, - state_snapshot_policy: StateSnapshotPolicy, - execution_interrupts: Sequence[ExecutionInterrupt] | None = None, - execution_context: AbstractContextManager[Any] | None = None, -) -> tuple[object, list[StateSnapshotEvent]]: - collector = SnapshotCollector() - - with execution_context or nullcontext(): - with register_event_listeners([collector]): - status = conversation.execute( - execution_interrupts=execution_interrupts, - state_snapshot_policy=state_snapshot_policy, - ) - - return status, collector.state_snapshot_events - - -async def execute_with_state_snapshots_async( - conversation: Conversation, - *, - state_snapshot_policy: StateSnapshotPolicy, - execution_interrupts: Sequence[ExecutionInterrupt] | None = None, - execution_context: AbstractContextManager[Any] | None = None, -) -> tuple[object, list[StateSnapshotEvent]]: - collector = SnapshotCollector() - - with execution_context or nullcontext(): - with register_event_listeners([collector]): - status = await conversation.execute_async( - execution_interrupts=execution_interrupts, - state_snapshot_policy=state_snapshot_policy, - ) - - return status, collector.state_snapshot_events - - -class MutatingExecutionEndInterrupt(SerializableNeedToBeImplementedMixin, _NullExecutionInterrupt): - def __init__(self) -> None: - self.lock = threading.Lock() - self.count = 0 - super().__init__() - - def _on_execution_end( - self, - 
state: ConversationExecutionState, - conversation: Conversation, - ) -> InterruptedExecutionStatus | None: - conversation.inputs["preview_count"] = conversation.inputs.get("preview_count", 0) + 1 - self.count += 1 - return None - - -class _UnserializableVariableValue: - pass - - -def create_tool_flow_conversation( - func: Callable[[], object], - *, - name: str = "say_hi", - description: str = "Say hi", -) -> Conversation: - return Flow.from_steps( - [ - ToolExecutionStep( - tool=ServerTool( - name=name, - description=description, - func=func, - input_descriptors=[], - ) - ), - CompleteStep(name="end"), - ] - ).start_conversation() - - -def create_output_flow_conversation(message: str = "Hello") -> Conversation: - return Flow.from_steps( - [ - OutputMessageStep(message_template=message), - CompleteStep(name="end"), - ] - ).start_conversation() - - -def create_agent_conversation(message: str = "Hello from agent") -> Conversation: - llm = SnapshotSerializableDummyModel() - llm.set_next_output(message) - conversation = Agent(llm=llm).start_conversation() - conversation.append_user_message("Hi") - return conversation - - -def create_tool_calling_agent_conversation() -> Conversation: - @tool - def do_nothing_tool() -> str: - """Do nothing tool.""" - return "Tool called successfully" - - llm = SnapshotSerializableDummyModel() - llm.set_next_output( - { - "Please use the do_nothing_tool": Message( - message_type=MessageType.TOOL_REQUEST, - content="I am calling the do nothing tool", - tool_requests=[ToolRequest("do_nothing_tool", {}, "tc1")], - ) - } - ) - conversation = Agent(llm=llm, tools=[do_nothing_tool], max_iterations=10).start_conversation() - conversation.append_user_message("Please use the do_nothing_tool") - return conversation - - -def create_nested_agent_step_flow_conversation() -> Conversation: - llm = SnapshotSerializableDummyModel() - llm.set_next_output("agent answer") - child_agent = Agent(llm=llm) - conversation = Flow.from_steps( - 
[AgentExecutionStep(agent=child_agent), CompleteStep(name="end")] - ).start_conversation() - conversation.append_user_message("dummy") - return conversation - - -def create_parallel_child_flow(output_name: str, output_value: str) -> Flow: - return Flow.from_steps( - [ - ToolExecutionStep( - tool=ServerTool( - name=f"tool_{output_name}", - description=f"Return {output_name}", - input_descriptors=[], - output_descriptors=[StringProperty(name=output_name)], - func=lambda: output_value, - ) - ), - CompleteStep(name="end"), - ] - ) - - -def create_unserializable_variable_conversation() -> Conversation: - custom_variable = Variable(name="custom", type=AnyProperty()) - return Flow.from_steps( - [ - VariableWriteStep( - variable=custom_variable, - input_mapping={VariableWriteStep.VALUE: custom_variable.name}, - ), - OutputMessageStep(message_template="done"), - CompleteStep(name="end"), - ], - variables=[custom_variable], - ).start_conversation(inputs={custom_variable.name: _UnserializableVariableValue()}) - - -def assert_terminal_snapshot( - snapshot_events: Sequence[object], - *, - expected_status_type: str, - expected_message: str, -) -> None: - assert snapshot_status_types(snapshot_events)[-1] == expected_status_type - assert snapshot_message(snapshot_events[-1]) == expected_message - assert snapshot_events[-1].state_snapshot["execution"]["status_handled"] is False - - -def restore_conversation_from_snapshot_payload( - snapshot_payload: dict[str, Any], -) -> Conversation: - assert snapshot_payload["runtime"] == "wayflow" - assert snapshot_payload["schema_version"] == 1 - assert isinstance(snapshot_payload["conversation_state"], str) - - restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) - assert dump_conversation_state(restored_conversation) == { - "conversation": snapshot_payload["conversation"], - "execution": snapshot_payload["execution"], - } - return restored_conversation