diff --git a/docs/wayflowcore/source/core/api/events.rst b/docs/wayflowcore/source/core/api/events.rst index cd3996709..98cb8e5b0 100644 --- a/docs/wayflowcore/source/core/api/events.rst +++ b/docs/wayflowcore/source/core/api/events.rst @@ -55,6 +55,10 @@ Events .. autoclass:: wayflowcore.events.event.ConversationExecutionFinishedEvent :exclude-members: to_tracing_info +.. _statesnapshotevent: +.. autoclass:: wayflowcore.events.event.StateSnapshotEvent + :exclude-members: to_tracing_info + .. _toolexecutionstartevent: .. autoclass:: wayflowcore.events.event.ToolExecutionStartEvent :exclude-members: to_tracing_info diff --git a/docs/wayflowcore/source/core/api/serialization.rst b/docs/wayflowcore/source/core/api/serialization.rst index 28d1e3c5a..f58c85fb7 100644 --- a/docs/wayflowcore/source/core/api/serialization.rst +++ b/docs/wayflowcore/source/core/api/serialization.rst @@ -33,6 +33,28 @@ Deserialization .. autofunction:: wayflowcore.serialization.serializer.autodeserialize +Conversation State Snapshots +---------------------------- + +.. _dumpconversationstate: +.. autofunction:: wayflowcore.serialization.dump_conversation_state + +.. _serializeconversationstate: +.. autofunction:: wayflowcore.serialization.serialize_conversation_state + +.. _deserializeconversationstate: +.. autofunction:: wayflowcore.serialization.deserialize_conversation_state + +.. _loadconversationstate: +.. autofunction:: wayflowcore.serialization.load_conversation_state + +.. _deserializeconversation: +.. autofunction:: wayflowcore.serialization.deserialize_conversation + +.. _dumpvariablestate: +.. autofunction:: wayflowcore.serialization.dump_variable_state + + Plugins ------- diff --git a/docs/wayflowcore/source/core/changelog.rst b/docs/wayflowcore/source/core/changelog.rst index 337935eb3..cd9cc1487 100644 --- a/docs/wayflowcore/source/core/changelog.rst +++ b/docs/wayflowcore/source/core/changelog.rst @@ -7,6 +7,12 @@ WayFlow |current_version| New features ^^^^^^^^^^^^ +* **State snapshot tracing events:** + + Added configurable conversation state snapshots for tracing, with emission at conversation, node, or tool boundaries and bridging into Agent Spec state snapshot events. + Added resumable conversation state serialization so persisted conversations can be restored and continued. + Snapshot emission is covered on both synchronous and asynchronous execution paths, with snapshot ownership currently scoped to the active conversation. + * **OAuth support for MCP Clients:** MCP Clients now support OAuth-based authorization. @@ -62,6 +68,10 @@ Possibly Breaking Changes Bug fixes ^^^^^^^^^ +* **State snapshot test coverage:** + + Reduced duplicated flow/agent, sync/async, nested-conversation, internal-turn, serialization, and Agent Spec tracing wrappers in the state snapshot tests, and moved the reusable state snapshot fixtures plus explicit tracing/runtime helpers into shared test helper modules/plugins. + WayFlow 26.1.1 -------------- diff --git a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst index 8a1de2286..f32847323 100644 --- a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst +++ b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst @@ -152,6 +152,67 @@ Here's an example of how to use it in your code. :start-after: .. start-##_Enable_Agent_Spec_Tracing :end-before: .. end-##_Enable_Agent_Spec_Tracing +State snapshot events +--------------------- + +WayFlow can also emit ``StateSnapshotEvent`` payloads at conversation, step, and tool +boundaries by passing a ``StateSnapshotPolicy`` to ``conversation.execute()`` or +``conversation.execute_async()``. When ``AgentSpecEventListener`` is registered and +the installed ``pyagentspec`` version exposes ``StateSnapshotEmitted``, these runtime +snapshots are bridged into Agent Spec ``StateSnapshotEmitted`` events on the +owning conversation/component span. The WayFlow-only ``variable_state`` payload +is not bridged, because Agent Spec does not define WayFlow variable semantics. +When ``include_variable_state=True``, variable values must already be +JSON-serializable. +``StateSnapshotEvent.conversation_id`` is the logical/public conversation id, +while ``state_snapshot["conversation"]["id"]`` identifies the concrete runtime +conversation instance that emitted the snapshot. The lightweight +``state_snapshot["conversation"]`` / ``state_snapshot["execution"]`` sections are +intended for inspection and tracing. Only the root conversation-turn +checkpoints emitted for the conversation passed directly to ``execute()`` / +``execute_async()`` include ``state_snapshot["conversation_state"]``, which is +the authoritative serialized WayFlow conversation blob used for resumability. +Internal tool/node snapshots intentionally omit that serialized blob to stay +lightweight. To restore from a resumable checkpoint, use +``wayflowcore.serialization.deserialize_conversation(...)`` or +``deserialize_conversation_state(...)`` together with +``load_conversation_state(...)``. +Snapshot intervals are cumulative. ``TOOL_TURNS`` includes the +``CONVERSATION_TURNS`` checkpoints, ``NODE_TURNS`` includes the +``CONVERSATION_TURNS`` checkpoints, and ``ALL_INTERNAL_TURNS`` includes all +conversation, tool, and node boundaries. Only the conversation passed directly +to ``execute()`` / ``execute_async()`` emits the authoritative turn-level +resumability checkpoints for that run. Nested child conversations may still +emit internal tracing snapshots, but those child-runtime snapshots are tracing +checkpoints unless and until a stronger contract is introduced. +``CONVERSATION_TURNS`` emits an opening snapshot at execution start and a +closing turn snapshot at the end of the turn. That closing payload is emitted +before the live conversation object commits the new status, but its serialized +payload is synthesized so that after ``execute()`` returns it matches the +committed state seen by the caller. Snapshots are emitted only when the +corresponding boundary event occurs. If a turn is interrupted mid-turn, WayFlow +does not synthesize a turn-end snapshot; the latest already-emitted opening or +internal snapshot is the recovery point. +For flows, ``NODE_TURNS`` uses flow-iteration start/end events, which align with +per-step execution. For agents, the same policy emits snapshots around each +decision-loop iteration. Tool start/end snapshots are emitted only for +``TOOL_TURNS`` and ``ALL_INTERNAL_TURNS``. + +.. code-block:: python + + from wayflowcore.executors.statesnapshotpolicy import ( + StateSnapshotInterval, + StateSnapshotPolicy, + ) + + status = conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + include_variable_state=True, + extra_state_builder=lambda conv: {"ui": {"active_tab": "plan"}}, + ) + ) + Agent Spec Exporting/Loading ============================ diff --git a/wayflowcore/src/wayflowcore/agentspec/tracing.py b/wayflowcore/src/wayflowcore/agentspec/tracing.py index e42147179..a102a6b1c 100644 --- a/wayflowcore/src/wayflowcore/agentspec/tracing.py +++ b/wayflowcore/src/wayflowcore/agentspec/tracing.py @@ -4,6 +4,7 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. import json +from dataclasses import dataclass from typing import Dict, Optional, Union, cast from pyagentspec import Component as AgentSpecComponent @@ -25,6 +26,7 @@ from pyagentspec.tracing.events import LlmGenerationResponse as AgentSpecLlmGenerationResponse from pyagentspec.tracing.events import NodeExecutionEnd as AgentSpecNodeExecutionEnd from pyagentspec.tracing.events import NodeExecutionStart as AgentSpecNodeExecutionStart +from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted from pyagentspec.tracing.events import ToolExecutionRequest as AgentSpecToolExecutionRequest from pyagentspec.tracing.events import ToolExecutionResponse as AgentSpecToolExecutionResponse from pyagentspec.tracing.events.llmgeneration import ToolCall as AgentSpecToolCall @@ -39,16 +41,19 @@ from wayflowcore._utils.formatting import stringify from wayflowcore.agentspec import AgentSpecExporter from wayflowcore.component import Component +from wayflowcore.conversation import Conversation, _get_active_conversations from wayflowcore.events.event import ( AgentExecutionFinishedEvent, AgentExecutionStartedEvent, ConversationMessageStreamChunkEvent, + EndSpanEvent, Event, ExceptionRaisedEvent, FlowExecutionFinishedEvent, FlowExecutionStartedEvent, LlmGenerationRequestEvent, LlmGenerationResponseEvent, + StateSnapshotEvent, StepInvocationResultEvent, StepInvocationStartEvent, ToolExecutionResultEvent, @@ -59,6 +64,12 @@ from wayflowcore.tracing.span import LlmGenerationSpan, get_active_span_stack, get_current_span +@dataclass(frozen=True) +class _ConversationSpanOwner: + runtime_conversation_id: str + span: AgentSpecSpan + + class AgentSpecEventListener(EventListener): """Event listener that emits traces according to the Open Agent Spec Tracing standard""" @@ -70,6 +81,12 @@ def __init__(self) -> None: self.agentspec_exporter: AgentSpecExporter = AgentSpecExporter() # We keep a registry of conversions, so that we do not repeat the conversion for the same object twice self.agentspec_components_registry: Dict[str, AgentSpecComponent] = {} + # State snapshots belong to the span that owns their logical + # conversation_id, not necessarily to the runtime span that was active + # when the snapshot event was emitted. Nested flow sub-conversations can + # intentionally reuse the same logical conversation_id, so we also track + # the live runtime conversation id that currently owns that stream. + self._conversation_span_owners: Dict[str, _ConversationSpanOwner] = {} # Track last assistant message id and a robust mapping tool_request_id -> assistant message id. # Some providers may emit tool events before final assistant message id is known; we allow # temporarily missing ids and backfill on LLM response. @@ -83,16 +100,94 @@ def _convert_to_agentspec(self, component: Component) -> AgentSpecComponent: ) return self.agentspec_components_registry[component.id] + def _get_active_wayflow_conversation(self) -> Conversation | None: + active_conversations = _get_active_conversations(return_copy=False) + if not active_conversations: + return None + return active_conversations[-1] + + def _register_current_conversation_span(self, agentspec_span: AgentSpecSpan) -> None: + active_conversation = self._get_active_wayflow_conversation() + if active_conversation is None: + return + + current_owner = self._conversation_span_owners.get(active_conversation.conversation_id) + if ( + current_owner is not None + and current_owner.runtime_conversation_id != active_conversation.id + and current_owner.span.end_time is None + ): + return + + self._conversation_span_owners[active_conversation.conversation_id] = ( + _ConversationSpanOwner( + runtime_conversation_id=active_conversation.id, + span=agentspec_span, + ) + ) + + def _get_snapshot_owner_span( + self, + event: StateSnapshotEvent, + current_agentspec_span: AgentSpecSpan | None, + ) -> AgentSpecSpan | None: + # Keep snapshot ownership resolution centralized here. Today we only + # support direct span ownership plus shared-conversation routing for + # nested flows. Future multi-agent tracing can extend this method to + # route snapshots to a dedicated Swarm/ManagerWorkers wrapper span + # without changing the StateSnapshotEvent handling below. + owner = self._conversation_span_owners.get(event.conversation_id) + if owner is not None: + snapshot_conversation = ( + event.state_snapshot.get("conversation") if event.state_snapshot else {} + ) + snapshot_runtime_conversation_id = ( + snapshot_conversation.get("id") if isinstance(snapshot_conversation, dict) else None + ) + if snapshot_runtime_conversation_id != owner.runtime_conversation_id: + return None + return owner.span + + return current_agentspec_span + + def _get_current_conversation_owner(self) -> _ConversationSpanOwner | None: + active_conversation = self._get_active_wayflow_conversation() + if active_conversation is None: + return None + return self._conversation_span_owners.get(active_conversation.conversation_id) + + def _span_has_state_snapshot(self, agentspec_span: AgentSpecSpan) -> bool: + return any( + isinstance(span_event, AgentSpecStateSnapshotEmitted) + for span_event in agentspec_span.events + ) + + def _span_has_execution_end_event(self, agentspec_span: AgentSpecSpan) -> bool: + return any( + isinstance(span_event, (AgentSpecFlowExecutionEnd, AgentSpecAgentExecutionEnd)) + for span_event in agentspec_span.events + ) + + def _end_conversation_span_if_ready(self, agentspec_span: AgentSpecSpan) -> None: + owner = self._get_current_conversation_owner() + if ( + owner is not None + and owner.span is agentspec_span + and self._span_has_state_snapshot(agentspec_span) + ): + return + agentspec_span.end() + def __call__(self, event: Event) -> None: # We intercept the wayflow events, and based on the type of event: # - if it corresponds to a span start event, we create the corresponding agent spec span, and we start it # - we map the wayflow event to the corresponding agent spec one, and we emit that # - if it corresponds to a span end event, we retrieve the corresponding agent spec span, and we close it current_span = get_current_span() - if not current_span: + if current_span is None: return - current_agentspec_span = self.agentspec_spans_registry.get(current_span.span_id, None) current_span_name = current_span.name or "" + current_agentspec_span = self.agentspec_spans_registry.get(current_span.span_id, None) event_name = event.name or "" match event: case LlmGenerationRequestEvent(): @@ -312,6 +407,20 @@ def __call__(self, event: Event) -> None: ) ) current_agentspec_span.end() + case StateSnapshotEvent(): + owner_span = self._get_snapshot_owner_span(event, current_agentspec_span) + if not owner_span: + return + snapshot_event = AgentSpecStateSnapshotEmitted( + id=event.event_id, + name=event_name, + conversation_id=event.conversation_id, + state_snapshot=event.state_snapshot, + extra_state=event.extra_state, + ) + owner_span.add_event(snapshot_event) + if owner_span.end_time is None and self._span_has_execution_end_event(owner_span): + owner_span.end() case FlowExecutionStartedEvent(): # Flow execution starts. Create the new agent spec span, start it, add the event agentspec_flow = cast( @@ -332,8 +441,11 @@ def __call__(self, event: Event) -> None: inputs={}, ) ) + self._register_current_conversation_span(current_agentspec_span) case FlowExecutionFinishedEvent(): - # Flow execution ends. Add the event to the agent spec span and close the span + # Flow execution ends. Add the event to the agent spec span. If this span owns + # the logical conversation checkpoint stream, delay closing until the final + # StateSnapshotEvent is bridged so span processors still see the final snapshot. if not current_agentspec_span: return agentspec_flow = cast( @@ -354,7 +466,7 @@ def __call__(self, event: Event) -> None: branch_selected=branch_selected, ) ) - current_agentspec_span.end() + self._end_conversation_span_if_ready(current_agentspec_span) case AgentExecutionStartedEvent(): # Agent execution starts. Create the new agent spec span, start it, add the event agentspec_agent = cast( @@ -375,8 +487,11 @@ def __call__(self, event: Event) -> None: inputs={}, ) ) + self._register_current_conversation_span(current_agentspec_span) case AgentExecutionFinishedEvent(): - # Agent execution ends. Add the event to the agent spec span and close the span + # Agent execution ends. Add the event to the agent spec span. If this span owns + # the logical conversation checkpoint stream, delay closing until the final + # StateSnapshotEvent is bridged so span processors still see the final snapshot. if not current_agentspec_span: return agentspec_agent = cast( @@ -395,7 +510,7 @@ def __call__(self, event: Event) -> None: outputs=outputs, ) ) - current_agentspec_span.end() + self._end_conversation_span_if_ready(current_agentspec_span) case ExceptionRaisedEvent(): if not current_agentspec_span: return @@ -408,3 +523,7 @@ def __call__(self, event: Event) -> None: exception_stacktrace=str(event.exception.__traceback__), ) ) + case EndSpanEvent(): + if not current_agentspec_span or current_agentspec_span.end_time is not None: + return + current_agentspec_span.end() diff --git a/wayflowcore/src/wayflowcore/conversation.py b/wayflowcore/src/wayflowcore/conversation.py index 6d347d1b9..5da41403d 100644 --- a/wayflowcore/src/wayflowcore/conversation.py +++ b/wayflowcore/src/wayflowcore/conversation.py @@ -1,4 +1,4 @@ -# Copyright © 2025 Oracle and/or its affiliates. +# Copyright © 2025, 2026 Oracle and/or its affiliates. # # This software is under the Apache License 2.0 # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License @@ -32,6 +32,7 @@ ToolRequestStatus, UserMessageRequestStatus, ) +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotPolicy from wayflowcore.messagelist import Message, MessageContent, MessageList from wayflowcore.planning import ExecutionPlan from wayflowcore.tokenusage import TokenUsage @@ -114,6 +115,7 @@ def _register_event(self, event: Event) -> None: def execute( self, execution_interrupts: Optional[Sequence["ExecutionInterrupt"]] = None, + state_snapshot_policy: Optional[StateSnapshotPolicy] = None, ) -> "ExecutionStatus": """ Execute the conversation and get its ``ExecutionStatus`` based on the outcome. @@ -122,12 +124,16 @@ def execute( finished the conversation. """ return run_async_in_sync( - self.execute_async, execution_interrupts, method_name="execute_async" + self.execute_async, + execution_interrupts, + state_snapshot_policy, + method_name="execute_async", ) async def execute_async( self, execution_interrupts: Optional[Sequence["ExecutionInterrupt"]] = None, + state_snapshot_policy: Optional[StateSnapshotPolicy] = None, ) -> "ExecutionStatus": """ Execute the conversation and get its ``ExecutionStatus`` based on the outcome. @@ -138,8 +144,13 @@ async def execute_async( if self.status_handled is False: self._update_conversation_with_status() - with _register_conversation(self): - new_status = await self.component.runner.execute_async(self, execution_interrupts) + from wayflowcore.executors._statesnapshot_eventlistener import ( + get_state_snapshot_execution_context_for_conversation, + ) + + with get_state_snapshot_execution_context_for_conversation(self, state_snapshot_policy): + with _register_conversation(self): + new_status = await self.component.runner.execute_async(self, execution_interrupts) self.status = new_status self.status_handled = False diff --git a/wayflowcore/src/wayflowcore/events/event.py b/wayflowcore/src/wayflowcore/events/event.py index 4c9c39f1c..0f6d36653 100644 --- a/wayflowcore/src/wayflowcore/events/event.py +++ b/wayflowcore/src/wayflowcore/events/event.py @@ -794,6 +794,58 @@ def to_tracing_info(self, mask_sensitive_information: bool = True) -> Dict[str, } +@dataclass(frozen=True) +class StateSnapshotEvent(Event): + """Event emitted by WayFlow when a conversation state snapshot is recorded. + + ``conversation_id`` is the logical/public conversation id. When a snapshot is + present, ``state_snapshot["conversation"]["id"]`` identifies the runtime + conversation instance described by the payload. WayFlow-emitted payloads + include the authoritative serialized state in + ``state_snapshot["conversation_state"]`` only for the root conversation-turn + checkpoints owned by the conversation that began the current ``execute()`` / + ``execute_async()`` run. All snapshots keep + ``state_snapshot["conversation"]`` and ``state_snapshot["execution"]`` as + the lightweight inspection view. Nested child snapshots and internal + tool/node snapshots are primarily tracing checkpoints and may be filtered by + downstream bridges that need a single checkpoint owner per logical + conversation. + """ + + conversation_id: str = field(default_factory=_required_attribute("conversation_id", str)) + state_snapshot: Optional[Dict[str, Any]] = None + extra_state: Optional[Dict[str, Any]] = None + variable_state: Optional[Dict[str, Any]] = None + + def __post_init__(self) -> None: + if self.state_snapshot is None: + return + if not isinstance(self.state_snapshot, dict): + raise ValueError("state_snapshot must be a dictionary") + + snapshot_conversation = self.state_snapshot.get("conversation") + if not isinstance(snapshot_conversation, dict): + raise ValueError("state_snapshot must contain a 'conversation' object") + if not isinstance(snapshot_conversation.get("id"), str): + raise ValueError("state_snapshot['conversation']['id'] must be a string") + + def to_tracing_info(self, mask_sensitive_information: bool = True) -> Dict[str, Any]: + def _masked(value: Optional[Dict[str, Any]]) -> Any: + if value is None: + return None + if mask_sensitive_information: + return _PII_TEXT_MASK + return value + + return { + **super().to_tracing_info(mask_sensitive_information=mask_sensitive_information), + "conversation_id": self.conversation_id, + "state_snapshot": _masked(self.state_snapshot), + "extra_state": _masked(self.extra_state), + "variable_state": _masked(self.variable_state), + } + + @dataclass(frozen=True) class AgentNextActionDecisionStartEvent(Event): """ diff --git a/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py new file mode 100644 index 000000000..ea1a7fb49 --- /dev/null +++ b/wayflowcore/src/wayflowcore/executors/_statesnapshot_eventlistener.py @@ -0,0 +1,456 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from __future__ import annotations + +import json +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, Dict, Iterator, Optional, cast + +from wayflowcore.conversation import Conversation, _get_active_conversations +from wayflowcore.events import Event, EventListener +from wayflowcore.events.event import ( + AgentExecutionFinishedEvent, + AgentExecutionIterationFinishedEvent, + AgentExecutionIterationStartedEvent, + AgentExecutionStartedEvent, + ExceptionRaisedEvent, + FlowExecutionFinishedEvent, + FlowExecutionIterationFinishedEvent, + FlowExecutionIterationStartedEvent, + FlowExecutionStartedEvent, + StateSnapshotEvent, + ToolExecutionResultEvent, + ToolExecutionStartEvent, +) +from wayflowcore.events.eventlistener import ( + get_event_listeners, + record_event, + register_event_listeners, +) +from wayflowcore.executors._events.event import EventType as ExecutionEventType +from wayflowcore.executors._executor import ExecutionInterruptedException +from wayflowcore.executors.executionstatus import ExecutionStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy +from wayflowcore.serialization.conversation import ( + _UNSET, + _serialize_conversation_state_with_runtime_overrides, + dump_conversation_state, + dump_variable_state, + serialize_conversation_state, +) +from wayflowcore.tracing.span import AgentExecutionSpan, FlowExecutionSpan, get_current_span + +_STATE_SNAPSHOT_RUNTIME = "wayflow" +_STATE_SNAPSHOT_SCHEMA_VERSION = 1 +_STATE_SNAPSHOT_INTERVALS_BY_POLICY = { + StateSnapshotInterval.CONVERSATION_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + }, + StateSnapshotInterval.TOOL_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + StateSnapshotInterval.TOOL_TURNS, + }, + StateSnapshotInterval.NODE_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + StateSnapshotInterval.NODE_TURNS, + }, + StateSnapshotInterval.ALL_INTERNAL_TURNS: { + StateSnapshotInterval.CONVERSATION_TURNS, + StateSnapshotInterval.TOOL_TURNS, + StateSnapshotInterval.NODE_TURNS, + }, +} + + +_STATE_SNAPSHOT_POLICIES: ContextVar[Dict[str, StateSnapshotPolicy]] = ContextVar( + "_STATE_SNAPSHOT_POLICIES", + default={}, +) +"""Execution-local mapping of active conversations to their effective snapshot policy.""" + +_STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID: ContextVar[Optional[str]] = ContextVar( + "_STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID", + default=None, +) +"""Runtime conversation id for the root conversation of the current snapshot-enabled run.""" + + +def _get_state_snapshot_policies( + return_copy: bool = True, +) -> Dict[str, StateSnapshotPolicy]: + state_snapshot_policies = _STATE_SNAPSHOT_POLICIES.get() + return state_snapshot_policies.copy() if return_copy else state_snapshot_policies + + +def _get_parent_state_snapshot_policy( + conversation: Conversation, +) -> Optional[StateSnapshotPolicy]: + active_conversations = _get_active_conversations(return_copy=False) + if not active_conversations or active_conversations[-1] is conversation: + return None + return _get_state_snapshot_policy(active_conversations[-1]) + + +def get_effective_state_snapshot_policy_for_conversation( + conversation: Conversation, + state_snapshot_policy: Optional[StateSnapshotPolicy], +) -> Optional[StateSnapshotPolicy]: + return ( + state_snapshot_policy + if state_snapshot_policy is not None + else _get_parent_state_snapshot_policy(conversation) + ) + + +def _is_state_snapshot_execution_root(conversation: Conversation) -> bool: + return _STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID.get() == conversation.id + + +def _get_state_snapshot_policy( + conversation: Conversation, +) -> Optional[StateSnapshotPolicy]: + return _get_state_snapshot_policies(return_copy=False).get(conversation.id) + + +@contextmanager +def _use_state_snapshot_policy( + conversation: Conversation, + state_snapshot_policy: Optional[StateSnapshotPolicy], +) -> Iterator[None]: + # Copy-on-write is needed here because child anyio tasks inherit the current + # context, including references to mutable ContextVar values. + state_snapshot_policies = _get_state_snapshot_policies(return_copy=True) + if state_snapshot_policy is None: + state_snapshot_policies.pop(conversation.id, None) + else: + state_snapshot_policies[conversation.id] = state_snapshot_policy + + token = _STATE_SNAPSHOT_POLICIES.set(state_snapshot_policies) + try: + yield + finally: + _STATE_SNAPSHOT_POLICIES.reset(token) + + +@contextmanager +def _use_state_snapshot_execution_root( + conversation: Conversation, +) -> Iterator[None]: + current_root_conversation_id = _STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID.get() + if current_root_conversation_id is not None: + yield + return + + token = _STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID.set(conversation.id) + try: + yield + finally: + _STATE_SNAPSHOT_EXECUTION_ROOT_CONVERSATION_ID.reset(token) + + +def _build_extra_state( + conversation: Conversation, + state_snapshot_policy: StateSnapshotPolicy, +) -> Optional[Dict[str, Any]]: + if state_snapshot_policy.extra_state_builder is None: + return None + + extra_state = state_snapshot_policy.extra_state_builder(conversation) + if extra_state is None: + return None + if not isinstance(extra_state, dict): + raise TypeError( + f"Expected extra snapshot state for conversation '{conversation.conversation_id}' to be a dictionary" + ) + + try: + return cast(Dict[str, Any], json.loads(json.dumps(extra_state, allow_nan=False))) + except (TypeError, ValueError) as exc: + raise TypeError( + f"Extra snapshot state for conversation '{conversation.conversation_id}' must be strict JSON-serializable" + ) from exc + + +def _get_snapshot_policy_for_interval( + conversation: Conversation, + required_snapshot_interval: StateSnapshotInterval, +) -> Optional[StateSnapshotPolicy]: + state_snapshot_policy = _get_state_snapshot_policy(conversation) + if state_snapshot_policy is None: + return None + + snapshot_interval = state_snapshot_policy.state_snapshot_interval + if snapshot_interval == StateSnapshotInterval.OFF: + return None + + if required_snapshot_interval in _STATE_SNAPSHOT_INTERVALS_BY_POLICY[snapshot_interval]: + return state_snapshot_policy + return None + + +def _build_variable_state( + conversation: Conversation, + state_snapshot_policy: StateSnapshotPolicy, +) -> Optional[dict[str, Any]]: + if not state_snapshot_policy.include_variable_state: + return None + + return dump_variable_state(conversation) + + +def _build_state_snapshot_payload( + conversation: Conversation, + *, + include_conversation_state: bool, + status: object = _UNSET, + status_handled: object = _UNSET, +) -> dict[str, Any]: + # The snapshot payload should match the intended runtime view at the + # boundary where it is emitted. Opening/tool/node snapshots mask any + # previous turn status, and turn-end snapshots can override the runtime + # fields before the live conversation object commits that new status. + dumped_state = dump_conversation_state( + conversation, + status=status, + status_handled=status_handled, + ) + payload = { + "runtime": _STATE_SNAPSHOT_RUNTIME, + "schema_version": _STATE_SNAPSHOT_SCHEMA_VERSION, + "conversation": dumped_state["conversation"], + "execution": dumped_state["execution"], + } + if include_conversation_state: + payload["conversation_state"] = ( + serialize_conversation_state(conversation) + if status is _UNSET and status_handled is _UNSET + else _serialize_conversation_state_with_runtime_overrides( + conversation, + status=cast( + Optional[ExecutionStatus], + conversation.status if status is _UNSET else status, + ), + status_handled=cast( + bool, + conversation.status_handled if status_handled is _UNSET else status_handled, + ), + ) + ) + return payload + + +def _record_state_snapshot( + conversation: Conversation, + required_snapshot_interval: StateSnapshotInterval, + *, + status: object = _UNSET, + status_handled: object = _UNSET, +) -> None: + state_snapshot_policy = _get_snapshot_policy_for_interval( + conversation, required_snapshot_interval + ) + if state_snapshot_policy is None: + return + + if ( + required_snapshot_interval == StateSnapshotInterval.CONVERSATION_TURNS + and not _is_state_snapshot_execution_root(conversation) + ): + # The conversation passed to `execute()` / `execute_async()` owns the + # resumable turn-level checkpoint stream. Nested child conversations may + # still emit internal tracing snapshots, but they do not emit competing + # conversation-turn checkpoints for the same run. + return + + # Snapshot delivery is part of the execution contract: downstream listener + # or storage failures must propagate to the caller instead of being silently + # converted into best-effort behavior. + record_event( + StateSnapshotEvent( + conversation_id=conversation.conversation_id, + state_snapshot=_build_state_snapshot_payload( + conversation, + include_conversation_state=( + required_snapshot_interval == StateSnapshotInterval.CONVERSATION_TURNS + ), + status=status, + status_handled=status_handled, + ), + extra_state=_build_extra_state(conversation, state_snapshot_policy), + variable_state=_build_variable_state(conversation, state_snapshot_policy), + ) + ) + + +class StateSnapshotEventListener(EventListener): + """Emit state snapshots for the active conversation.""" + + def __init__( + self, + conversation: Conversation, + post_interrupts: bool, + ) -> None: + self.conversation = conversation + self.post_interrupts = post_interrupts + + def _record_snapshot( + self, + required_snapshot_interval: StateSnapshotInterval, + *, + status: object = None, + status_handled: object = False, + ) -> None: + _record_state_snapshot( + self.conversation, + required_snapshot_interval, + status=status, + status_handled=status_handled, + ) + + def _handle_pre_interrupt_event( + self, + event: Event, + ) -> None: + # Agents do not expose node execution events, so NODE_TURNS maps to + # flow iteration boundaries and agent iteration boundaries. + match event: + case AgentExecutionStartedEvent() | FlowExecutionStartedEvent(): + self._record_snapshot(StateSnapshotInterval.CONVERSATION_TURNS) + case ToolExecutionStartEvent() | ToolExecutionResultEvent(): + self._record_snapshot(StateSnapshotInterval.TOOL_TURNS) + case ( + FlowExecutionIterationStartedEvent() + | FlowExecutionIterationFinishedEvent() + | AgentExecutionIterationStartedEvent() + | AgentExecutionIterationFinishedEvent() + ): + self._record_snapshot(StateSnapshotInterval.NODE_TURNS) + + def _should_record_interrupted_turn_end_snapshot( + self, + ) -> bool: + return ( + bool(self.conversation.state.events) + and (self.conversation.state.events[-1].type == ExecutionEventType.EXECUTION_END) + and isinstance(get_current_span(), (FlowExecutionSpan, AgentExecutionSpan)) + ) + + def _owns_current_conversation(self, current_conversation: Conversation) -> bool: + # This is the intended extension point for future multi-agent snapshot + # ownership rules. Today a listener only reacts for its own active + # conversation. Follow-up PRs can widen this here to parent wrapper + # conversations (for example swarms or manager-workers) without + # changing the snapshot emission logic elsewhere in this listener. + return current_conversation.id == self.conversation.id + + def _handle_post_interrupt_event(self, event: Event) -> None: + match event: + case FlowExecutionFinishedEvent( + execution_status=execution_status + ) | AgentExecutionFinishedEvent(execution_status=execution_status): + self._record_snapshot( + StateSnapshotInterval.CONVERSATION_TURNS, + status=execution_status, + status_handled=False, + ) + case ExceptionRaisedEvent(exception=ExecutionInterruptedException() as exception): + if self._should_record_interrupted_turn_end_snapshot(): + self._record_snapshot( + StateSnapshotInterval.CONVERSATION_TURNS, + status=exception.execution_status, + status_handled=False, + ) + + def __call__(self, event: Event) -> None: + if isinstance(event, StateSnapshotEvent): + return + + active_conversations = _get_active_conversations(return_copy=False) + if not active_conversations: + return + + current_conversation = active_conversations[-1] + if not self._owns_current_conversation(current_conversation): + return + + if self.post_interrupts: + self._handle_post_interrupt_event(event) + else: + self._handle_pre_interrupt_event(event) + + +@contextmanager +def get_state_snapshot_event_listener_context_for_conversation( + conversation: Conversation, + *, + post_interrupts: bool, +) -> Iterator[StateSnapshotEventListener]: + current_listener = next( + ( + event_listener + for event_listener in get_event_listeners() + if isinstance(event_listener, StateSnapshotEventListener) + and event_listener.conversation.id == conversation.id + and event_listener.post_interrupts == post_interrupts + ), + None, + ) + + if current_listener is not None: + yield current_listener + else: + listener = StateSnapshotEventListener(conversation, post_interrupts=post_interrupts) + with register_event_listeners([listener]): + yield listener + + +@contextmanager +def get_state_snapshot_execution_context_for_conversation( + conversation: Conversation, + state_snapshot_policy: Optional[StateSnapshotPolicy], +) -> Iterator[None]: + """ + Activate the effective snapshot policy for one `conversation.execute(...)` turn. + + Child conversations inherit the currently active parent policy unless they + explicitly override it. Only the conversation that started the current + snapshot-enabled execution emits conversation-turn checkpoints; nested + children may still emit internal snapshots that describe their own runtime + state. When snapshots are enabled, listener registration happens here in the + runtime order the execution model depends on: + 1. pre-interrupt snapshot listener + 2. interrupts listener + 3. post-interrupt snapshot listener + """ + active_state_snapshot_policy = get_effective_state_snapshot_policy_for_conversation( + conversation, + state_snapshot_policy, + ) + + with _use_state_snapshot_policy(conversation, active_state_snapshot_policy): + if active_state_snapshot_policy is None: + yield + return + + from wayflowcore.executors._interrupts_eventlistener import ( + get_interrupts_event_listener_context_for_conversation, + ) + + with ( + _use_state_snapshot_execution_root(conversation), + get_state_snapshot_event_listener_context_for_conversation( + conversation, + post_interrupts=False, + ), + get_interrupts_event_listener_context_for_conversation(conversation), + get_state_snapshot_event_listener_context_for_conversation( + conversation, + post_interrupts=True, + ), + ): + yield diff --git a/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py b/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py new file mode 100644 index 000000000..baa7cd6b3 --- /dev/null +++ b/wayflowcore/src/wayflowcore/executors/statesnapshotpolicy.py @@ -0,0 +1,59 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional + +if TYPE_CHECKING: + from wayflowcore.conversation import Conversation + + +class StateSnapshotInterval(str, Enum): + """ + Configure which execution boundaries emit state snapshots. + + `CONVERSATION_TURNS` + Emit an opening turn snapshot before execution starts and a closing + turn snapshot at the end of the turn. The closing payload is emitted + before the live conversation commits the new status, but the payload is + synthesized so it matches the post-return conversation state. This is + the default policy because it gives a stable turn-level checkpoint + without emitting snapshots for every internal step. + + `TOOL_TURNS` + Emit the `CONVERSATION_TURNS` snapshots plus snapshots around each tool + invocation (`TOOL_START` and `TOOL_END`). + + `NODE_TURNS` + Emit the `CONVERSATION_TURNS` snapshots plus snapshots around each + internal node boundary. For flows this means per-step snapshots; for + agents it maps to decision-loop iteration boundaries. + + `ALL_INTERNAL_TURNS` + Emit the `CONVERSATION_TURNS`, `TOOL_TURNS`, and `NODE_TURNS` + snapshots. + + `OFF` + Disable state snapshot emission entirely. + """ + + CONVERSATION_TURNS = "conversation_turns" + TOOL_TURNS = "tool_turns" + NODE_TURNS = "node_turns" + ALL_INTERNAL_TURNS = "all_internal_turns" + OFF = "off" + + +@dataclass(frozen=True) +class StateSnapshotPolicy: + """Execution-time policy controlling WayFlow state snapshot emission.""" + + state_snapshot_interval: StateSnapshotInterval = StateSnapshotInterval.CONVERSATION_TURNS + include_variable_state: bool = True + extra_state_builder: Optional[Callable[["Conversation"], Optional[Dict[str, Any]]]] = None diff --git a/wayflowcore/src/wayflowcore/serialization/__init__.py b/wayflowcore/src/wayflowcore/serialization/__init__.py index 5d90ab78e..d4cec303e 100644 --- a/wayflowcore/src/wayflowcore/serialization/__init__.py +++ b/wayflowcore/src/wayflowcore/serialization/__init__.py @@ -4,6 +4,14 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. +from .conversation import ( + deserialize_conversation, + deserialize_conversation_state, + dump_conversation_state, + dump_variable_state, + load_conversation_state, + serialize_conversation_state, +) from .serializer import ( autodeserialize, deserialize, @@ -15,7 +23,13 @@ __all__ = [ "autodeserialize", "deserialize", + "deserialize_conversation", + "deserialize_conversation_state", "deserialize_from_dict", + "dump_conversation_state", + "dump_variable_state", + "load_conversation_state", "serialize", + "serialize_conversation_state", "serialize_to_dict", ] diff --git a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py index e10eea2a7..799fbf3b2 100644 --- a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py +++ b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py @@ -134,7 +134,9 @@ from wayflowcore.agentspec.components import ( PluginVllmEmbeddingConfig as AgentSpecPluginVllmEmbeddingConfig, ) -from wayflowcore.agentspec.components import all_serialization_plugin +from wayflowcore.agentspec.components import ( + all_serialization_plugin, +) from wayflowcore.agentspec.components.agent import ExtendedAgent as AgentSpecExtendedAgent from wayflowcore.agentspec.components.contextprovider import ( PluginConstantContextProvider as AgentSpecPluginConstantContextProvider, @@ -377,7 +379,9 @@ from wayflowcore.models.ociclientconfig import ( OCIClientConfigWithUserAuthentication as RuntimeOCIClientConfigWithUserAuthentication, ) -from wayflowcore.models.openaicompatiblemodel import EMPTY_API_KEY +from wayflowcore.models.openaicompatiblemodel import ( + EMPTY_API_KEY, +) from wayflowcore.models.openaicompatiblemodel import ( OpenAICompatibleModel as RuntimeOpenAICompatibleModel, ) diff --git a/wayflowcore/src/wayflowcore/serialization/conversation.py b/wayflowcore/src/wayflowcore/serialization/conversation.py new file mode 100644 index 000000000..5756b3aa4 --- /dev/null +++ b/wayflowcore/src/wayflowcore/serialization/conversation.py @@ -0,0 +1,597 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from __future__ import annotations + +import json +import math +import warnings +from contextlib import contextmanager +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING, Any, Iterator, Optional, cast + +import yaml + +from wayflowcore.executors.executionstatus import ( + AuthChallengeRequestStatus, + ExecutionStatus, + FinishedStatus, + ToolExecutionConfirmationStatus, + ToolRequestStatus, + UserMessageRequestStatus, +) +from wayflowcore.serialization.context import DeserializationContext, SerializationContext +from wayflowcore.serialization.serializer import ( + autodeserialize_from_dict, + serialize, + serialize_any_to_dict_or_stringify, +) +from wayflowcore.tools.tools import ToolRequest, ToolResult + +if TYPE_CHECKING: + from wayflowcore.conversation import Conversation + from wayflowcore.executors._agentconversation import AgentConversation + from wayflowcore.executors._flowconversation import FlowConversation + from wayflowcore.messagelist import Message, MessageContent + from wayflowcore.serialization.plugins import ( + WayflowDeserializationPlugin, + WayflowSerializationPlugin, + ) + +_UNSET = object() + + +def _dump_conversation_reference(conversation: "Conversation") -> dict[str, Any]: + return { + "id": conversation.id, + "conversation_id": conversation.conversation_id, + "conversation_type": conversation.__class__.__name__, + } + + +def _dump_component_reference(component: Any) -> dict[str, Any]: + return { + "component_id": component.id, + "component_type": component.__class__.__name__, + } + + +def _dump_json_compatible_value(value: Any) -> Any: + from wayflowcore._utils.formatting import stringify + from wayflowcore.component import Component + from wayflowcore.conversation import Conversation + + dumped_value: Any + if value is None or isinstance(value, (bool, int, str)): + dumped_value = value + elif isinstance(value, float): + dumped_value = value if math.isfinite(value) else stringify(value) + elif isinstance(value, bytes): + dumped_value = value.decode("utf-8", errors="replace") + elif isinstance(value, datetime): + dumped_value = value.isoformat() + elif isinstance(value, Enum): + dumped_value = _dump_json_compatible_value(value.value) + elif isinstance(value, Conversation): + dumped_value = _dump_conversation_reference(value) + elif isinstance(value, Component): + dumped_value = _dump_component_reference(value) + elif isinstance(value, dict): + dumped_value = { + str(key): _dump_json_compatible_value(inner_value) for key, inner_value in value.items() + } + elif isinstance(value, (list, tuple, set)): + dumped_value = [_dump_json_compatible_value(inner_value) for inner_value in value] + else: + serialized_value = serialize_any_to_dict_or_stringify( + value, SerializationContext(root=value) + ) + if serialized_value is value: + dumped_value = stringify(value) + else: + dumped_value = _dump_json_compatible_value(serialized_value) + return dumped_value + + +def _dump_string_keyed_mapping(values: dict[Any, Any]) -> dict[str, Any]: + return {str(key): _dump_json_compatible_value(value) for key, value in values.items()} + + +def _dump_flow_input_output_key_values(values: dict[Any, Any]) -> dict[str, Any]: + return { + f"{step_name}.{value_name}": _dump_json_compatible_value(value) + for (step_name, value_name), value in values.items() + } + + +def _dump_tool_requests( + tool_requests: Optional[list[ToolRequest]], +) -> list[Optional[dict[str, Any]]]: + return [_dump_tool_request(tool_request) for tool_request in tool_requests or []] + + +def _dump_tool_request(tool_request: Optional[ToolRequest]) -> Optional[dict[str, Any]]: + if tool_request is None: + dumped_tool_request = None + else: + dumped_tool_request = { + "name": tool_request.name, + "tool_request_id": tool_request.tool_request_id, + "args": _dump_json_compatible_value(tool_request.args), + "requires_confirmation": tool_request._requires_confirmation, + "tool_execution_confirmed": tool_request._tool_execution_confirmed, + "tool_rejection_reason": tool_request._tool_rejection_reason, + } + return dumped_tool_request + + +def _dump_tool_result(tool_result: Optional[ToolResult]) -> Optional[dict[str, Any]]: + if tool_result is None: + dumped_tool_result = None + else: + dumped_tool_result = { + "tool_request_id": tool_result.tool_request_id, + "content": _dump_json_compatible_value(tool_result.content), + } + return dumped_tool_result + + +def _dump_tool_related_execution_status( + execution_status: ToolRequestStatus | ToolExecutionConfirmationStatus, +) -> dict[str, Any]: + dumped_status = { + "tool_requests": _dump_tool_requests(execution_status.tool_requests), + } + if isinstance(execution_status, ToolRequestStatus): + dumped_status["tool_results"] = [ + _dump_tool_result(tool_result) for tool_result in execution_status._tool_results or [] + ] + return dumped_status + + +def _dump_message_content(content: MessageContent) -> dict[str, Any]: + from wayflowcore.messagelist import ImageContent, TextContent + + content_type = getattr(content, "type", content.__class__.__name__) + + if isinstance(content, TextContent): + dumped_content = { + "type": content.type, + "content": content.content, + } + elif isinstance(content, ImageContent): + dumped_content = { + "type": content.type, + "base64_content": content.base64_content, + } + else: + serialized_content = _dump_json_compatible_value(content) + if isinstance(serialized_content, dict): + if "type" in serialized_content: + dumped_content = serialized_content + else: + dumped_content = {"type": content_type, **serialized_content} + else: + dumped_content = { + "type": content_type, + "content": serialized_content, + } + return dumped_content + + +def _dump_message(message: Message) -> dict[str, Any]: + dumped_message: dict[str, Any] = { + "role": message.role, + "message_type": message.message_type.value if message.message_type else None, + "sender": message.sender, + "recipients": sorted(message.recipients), + "time_created": message.time_created.isoformat(), + "time_updated": message.time_updated.isoformat(), + "content": message.content, + "contents": [_dump_message_content(content) for content in message.contents], + } + + tool_requests = _dump_tool_requests(message.tool_requests) + dumped_tool_result = _dump_tool_result(message.tool_result) + + if tool_requests: + dumped_message["tool_requests"] = tool_requests + if dumped_tool_result is not None: + dumped_message["tool_result"] = dumped_tool_result + return dumped_message + + +def _dump_execution_status(execution_status: Optional[ExecutionStatus]) -> Optional[dict[str, Any]]: + if execution_status is None: + return None + + dumped_status: dict[str, Any] = {"type": execution_status.__class__.__name__} + if isinstance(execution_status, FinishedStatus): + dumped_status["output_values"] = _dump_json_compatible_value(execution_status.output_values) + dumped_status["complete_step_name"] = execution_status.complete_step_name + elif isinstance(execution_status, UserMessageRequestStatus): + dumped_status["message"] = _dump_message(execution_status.message) + elif isinstance(execution_status, (ToolRequestStatus, ToolExecutionConfirmationStatus)): + dumped_status.update(_dump_tool_related_execution_status(execution_status)) + elif isinstance(execution_status, AuthChallengeRequestStatus): + dumped_status["client_transport_id"] = execution_status.client_transport_id + return dumped_status + + +def _dump_conversation_info(conversation: "Conversation") -> dict[str, Any]: + return { + **_dump_conversation_reference(conversation), + "component_type": conversation.component.__class__.__name__, + "name": conversation.name, + "inputs": _dump_json_compatible_value(conversation.inputs), + "messages": [_dump_message(message) for message in conversation.get_messages()], + } + + +def _dump_common_execution_info( + conversation: "Conversation", + *, + status: object = _UNSET, + status_handled: object = _UNSET, +) -> dict[str, Any]: + return { + "current_step_name": conversation.current_step_name, + "status": _dump_execution_status( + conversation.status if status is _UNSET else cast(Optional[ExecutionStatus], status) + ), + "status_handled": ( + conversation.status_handled if status_handled is _UNSET else cast(bool, status_handled) + ), + } + + +def _dump_flow_execution_info(conversation: "FlowConversation") -> dict[str, Any]: + return { + "step_history": list(conversation.state.step_history), + "nesting_level": conversation.state.nesting_level, + "input_output_key_values": _dump_flow_input_output_key_values( + conversation.state.input_output_key_values + ), + "flow_output_values": _dump_json_compatible_value( + conversation.state._flow_output_value_dict + ), + "context_key_values": _dump_json_compatible_value(conversation.state.context_key_values), + "internal_context_key_values": _dump_string_keyed_mapping( + conversation.state.internal_context_key_values + ), + } + + +def _dump_agent_execution_info(conversation: "AgentConversation") -> dict[str, Any]: + return { + "curr_iter": conversation.state.curr_iter, + "has_confirmed_conversation_exit": conversation.state.has_confirmed_conversation_exit, + "tool_call_queue": _dump_tool_requests(conversation.state.tool_call_queue), + "current_tool_request": _dump_tool_request(conversation.state.current_tool_request), + "current_flow_conversation": _dump_json_compatible_value( + conversation.state.current_flow_conversation + ), + "current_sub_component_conversations": _dump_string_keyed_mapping( + conversation.state.current_sub_component_conversations + ), + } + + +def _dump_component_execution_info(conversation: "Conversation") -> dict[str, Any]: + from wayflowcore.executors._agentconversation import AgentConversation + from wayflowcore.executors._flowconversation import FlowConversation + + if isinstance(conversation, FlowConversation): + return _dump_flow_execution_info(conversation) + if isinstance(conversation, AgentConversation): + return _dump_agent_execution_info(conversation) + return {} + + +def dump_conversation_state( + conversation: "Conversation", + *, + status: object = _UNSET, + status_handled: object = _UNSET, +) -> dict[str, Any]: + """ + Return a strict-JSON-serializable runtime snapshot of a conversation. + + The returned dictionary is a lightweight inspection/tracing view. It captures + the user-visible conversation state and the runtime execution state without + embedding live component objects or the authoritative serialized + conversation blob used for resumability. Optional ``status`` and + ``status_handled`` overrides are available so callers can snapshot a + slightly adjusted view of the current runtime state without mutating the + conversation itself. + + Parameters + ---------- + conversation: + Conversation instance to snapshot. + status: + Optional execution status override to include in the dumped state instead + of ``conversation.status``. + status_handled: + Optional ``status_handled`` override to include in the dumped state + instead of ``conversation.status_handled``. + + Returns + ------- + dict[str, Any] + JSON-compatible conversation snapshot containing ``conversation`` and + ``execution`` sections. + """ + return { + "conversation": _dump_conversation_info(conversation), + "execution": { + **_dump_common_execution_info( + conversation, + status=status, + status_handled=status_handled, + ), + **_dump_component_execution_info(conversation), + }, + } + + +@contextmanager +def _use_conversation_runtime_overrides( + conversation: "Conversation", + *, + status: Optional[ExecutionStatus], + status_handled: bool, +) -> Iterator[None]: + previous_status = conversation.status + previous_status_handled = conversation.status_handled + conversation.status = status + conversation.status_handled = status_handled + try: + yield + finally: + conversation.status = previous_status + conversation.status_handled = previous_status_handled + + +def _serialize_conversation_state_with_runtime_overrides( + conversation: "Conversation", + *, + status: Optional[ExecutionStatus], + status_handled: bool, + serialization_context: Optional[SerializationContext] = None, + plugins: Optional[list["WayflowSerializationPlugin"]] = None, +) -> str: + """Serialize a conversation as if its runtime status fields already matched a snapshot.""" + + with _use_conversation_runtime_overrides( + conversation, + status=status, + status_handled=status_handled, + ): + return serialize_conversation_state( + conversation, + serialization_context=serialization_context, + plugins=plugins, + ) + + +def serialize_conversation_state( + conversation: "Conversation", + serialization_context: Optional[SerializationContext] = None, + plugins: Optional[list["WayflowSerializationPlugin"]] = None, +) -> str: + """ + Serialize a conversation into its stable textual state representation. + + This is the string form meant for storage or transport when the full runtime + conversation needs to be preserved for later loading. Unlike + ``dump_conversation_state()``, this serializes the actual conversation object + graph using WayFlow serialization. For WayFlow-emitted state snapshots, this + is the authoritative resumable blob stored under + ``state_snapshot["conversation_state"]``. + + Parameters + ---------- + conversation: + Conversation instance to serialize. + serialization_context: + Optional serialization context to use. + plugins: + Optional serialization plugins to use. + + Returns + ------- + str + Serialized conversation state string. + """ + return serialize( + conversation, + serialization_context=serialization_context, + plugins=plugins, + ) + + +def deserialize_conversation_state(state: str) -> dict[str, Any]: + """ + Parse a serialized conversation state string into a dictionary. + + This is the dictionary-level counterpart of + ``serialize_conversation_state()``. It is useful when callers need to inspect + or adjust the serialized payload before loading it back into a live + ``Conversation`` object, including the ``conversation_state`` string emitted + in WayFlow state snapshots. + + Parameters + ---------- + state: + Serialized conversation state string. + + Returns + ------- + dict[str, Any] + Parsed serialized state. + + Raises + ------ + TypeError + If the serialized payload does not deserialize into a dictionary. + """ + loaded_state = yaml.safe_load(state) + if not isinstance(loaded_state, dict): + raise TypeError("Serialized conversation state must deserialize into a dictionary.") + return cast(dict[str, Any], loaded_state) + + +def load_conversation_state( + state: dict[str, Any], + deserialization_context: Optional[DeserializationContext] = None, + plugins: Optional[list["WayflowDeserializationPlugin"]] = None, +) -> "Conversation": + """ + Reconstruct a live conversation from a serialized state dictionary. + + The input dictionary is expected to come from + ``deserialize_conversation_state()`` or another equivalent WayFlow + serialization source. When resuming from a WayFlow state snapshot payload, + first parse ``state_snapshot["conversation_state"]`` with + ``deserialize_conversation_state()`` and then pass the resulting dictionary + here. + + Parameters + ---------- + state: + Serialized conversation state as a dictionary. + deserialization_context: + Optional deserialization context to use. This is the preferred way to + provide tool registries or plugins. + plugins: + Optional deserialization plugins. When a deserialization context is + already provided, plugins should be attached to that context instead. + + Returns + ------- + Conversation + Reconstructed live conversation instance. + + Raises + ------ + TypeError + If the deserialized object is not a ``Conversation``. + """ + from wayflowcore.conversation import Conversation + + deserialization_context = _resolve_deserialization_context( + deserialization_context=deserialization_context, + plugins=plugins, + ) + + conversation = autodeserialize_from_dict(state, deserialization_context) + if not isinstance(conversation, Conversation): + raise TypeError( + f"Loaded object is of type {conversation.__class__.__name__}, not Conversation." + ) + return conversation + + +def deserialize_conversation( + conversation_state: str, + deserialization_context: Optional[DeserializationContext] = None, + plugins: Optional[list["WayflowDeserializationPlugin"]] = None, +) -> "Conversation": + """ + Reconstruct a conversation directly from its serialized string form. + + This is a convenience wrapper around + ``deserialize_conversation_state()`` followed by ``load_conversation_state()``. + It is the simplest restore API when you already have the serialized + ``conversation_state`` string from a WayFlow snapshot payload. + + Parameters + ---------- + conversation_state: + Serialized conversation state string. + deserialization_context: + Optional deserialization context to use. + plugins: + Optional deserialization plugins. + + Returns + ------- + Conversation + Reconstructed live conversation instance. + """ + return load_conversation_state( + deserialize_conversation_state(conversation_state), + deserialization_context=deserialization_context, + plugins=plugins, + ) + + +def dump_variable_state(conversation: "Conversation") -> Optional[dict[str, Any]]: + """ + Return the strict-JSON-serializable runtime-owned variable state for a conversation. + + Only flow conversations expose runtime variable storage. For other + conversation types, this returns ``None``. + + Parameters + ---------- + conversation: + Conversation whose runtime-owned variable values should be dumped. + + Returns + ------- + dict[str, Any] | None + Strict-JSON-compatible mapping of variable names to values for flow + conversations, otherwise ``None``. + + Raises + ------ + TypeError + If a variable contains a value that cannot be represented as JSON. + """ + from wayflowcore.executors._flowconversation import FlowConversation + + if not isinstance(conversation, FlowConversation): + return None + + variable_state: dict[str, Any] = {} + for variable_name, variable_value in conversation.state.variable_store.items(): + try: + serialized_value = json.dumps(variable_value, sort_keys=True, allow_nan=False) + except (TypeError, ValueError) as e: + raise TypeError( + f"Variable '{variable_name}' contains a non-JSON-serializable value of type {type(variable_value).__name__}" + ) from e + variable_state[variable_name] = cast(Any, json.loads(serialized_value)) + return variable_state + + +def _resolve_deserialization_context( + *, + deserialization_context: Optional[DeserializationContext], + plugins: Optional[list["WayflowDeserializationPlugin"]], +) -> DeserializationContext: + if deserialization_context is None: + return DeserializationContext(plugins=plugins) + if plugins is not None: + warnings.warn( + "A list of plugins was provided together with a deserialization context instance in `load_conversation_state`. " + "Do not pass the plugins to `load_conversation_state`, but create the context instance passing the list of plugins instead.", + UserWarning, + ) + return deserialization_context + + +__all__ = [ + "deserialize_conversation", + "deserialize_conversation_state", + "dump_conversation_state", + "dump_variable_state", + "load_conversation_state", + "serialize_conversation_state", +] diff --git a/wayflowcore/src/wayflowcore/steps/flowexecutionstep.py b/wayflowcore/src/wayflowcore/steps/flowexecutionstep.py index 1e5b5c0cb..8a0b66a18 100644 --- a/wayflowcore/src/wayflowcore/steps/flowexecutionstep.py +++ b/wayflowcore/src/wayflowcore/steps/flowexecutionstep.py @@ -250,7 +250,9 @@ async def _invoke_step_async( flow=self.flow, inputs=inputs, ) - status = await sub_conversation.execute_async() + status = sub_conversation.status + if not isinstance(status, FinishedStatus): + status = await sub_conversation.execute_async() if isinstance(status, InterruptedExecutionStatus): return StepResult( diff --git a/wayflowcore/src/wayflowcore/steps/parallelflowexecutionstep.py b/wayflowcore/src/wayflowcore/steps/parallelflowexecutionstep.py index 16bfc8fe1..75878abc0 100644 --- a/wayflowcore/src/wayflowcore/steps/parallelflowexecutionstep.py +++ b/wayflowcore/src/wayflowcore/steps/parallelflowexecutionstep.py @@ -12,7 +12,6 @@ from wayflowcore._metadata import MetadataType from wayflowcore._utils.async_helpers import run_async_function_in_parallel -from wayflowcore.executors._flowexecutor import FlowConversationExecutor from wayflowcore.executors.executionstatus import ExecutionStatus, FinishedStatus from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus from wayflowcore.property import Property @@ -264,13 +263,13 @@ async def _invoke_step_async( sub_conversations_to_run = [ sub_conversation for sub_conversation in sub_conversations - if sub_conversation.status is None or not isinstance(sub_conversation, FinishedStatus) + if not isinstance(sub_conversation.status, FinishedStatus) ] # We collect the statuses of the conversations that did already finish as we need them for the cleanup finished_statuses = [ sub_conversation.status for sub_conversation in sub_conversations - if sub_conversation.status is not None and isinstance(sub_conversation, FinishedStatus) + if isinstance(sub_conversation.status, FinishedStatus) ] async def _run_single_flow_target(sub_conv: "FlowConversation") -> ExecutionStatus: @@ -311,11 +310,8 @@ async def _run_single_flow_target(sub_conv: "FlowConversation") -> ExecutionStat f"Illegal response from a subflow: some subflow returned a non-finished status: {non_finished_status}" ) - for sub_conv in sub_conversations: - FlowConversationExecutor().cleanup_sub_conversation( - sub_conv.state, - self, - ) + for flow in self.flows: + conversation._cleanup_sub_conversation(self, flow.id) if not all(isinstance(status, FinishedStatus) for status in statuses): raise RuntimeError( diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py new file mode 100644 index 000000000..6ef34b9dd --- /dev/null +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_agent.py @@ -0,0 +1,230 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Sequence + +from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd +from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart +from pyagentspec.tracing.events import Event as AgentSpecEvent +from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted +from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor +from pyagentspec.tracing.spans import AgentExecutionSpan as AgentSpecAgentExecutionSpan +from pyagentspec.tracing.spans import LlmGenerationSpan as AgentSpecLlmGenerationSpan +from pyagentspec.tracing.spans import Span as AgentSpecSpan +from pyagentspec.tracing.trace import Trace as AgentSpecTrace + +from wayflowcore import Agent as WayflowAgent +from wayflowcore.agentspec.tracing import AgentSpecEventListener +from wayflowcore.events.eventlistener import register_event_listeners +from wayflowcore.executors.executionstatus import UserMessageRequestStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy +from wayflowcore.models.vllmmodel import VllmModel + +from ..testhelpers.patching import patch_llm + + +class _PassiveSpanProcessor(AgentSpecSpanProcessor): + def on_start(self, span: AgentSpecSpan) -> None: + return None + + async def on_start_async(self, span: AgentSpecSpan) -> None: + return None + + def on_end(self, span: AgentSpecSpan) -> None: + return None + + async def on_end_async(self, span: AgentSpecSpan) -> None: + return None + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + +class _SnapshotSpanRecorder(_PassiveSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.started_spans: list[AgentSpecSpan] = [] + + def on_start(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + async def on_start_async(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + +class _EventsSeenAtSpanEndRecorder(_PassiveSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.events_by_span_id: dict[str, list[AgentSpecEvent]] = {} + + def on_end(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + +def _recorded_spans( + span_recorder: _SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> list[AgentSpecSpan]: + return [span for span in span_recorder.started_spans if isinstance(span, span_type)] + + +def _single_span( + span_recorder: _SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> AgentSpecSpan: + matching_spans = _recorded_spans(span_recorder, span_type) + assert len(matching_spans) == 1 + return matching_spans[0] + + +def _span_events( + span: AgentSpecSpan, + event_type: type[AgentSpecEvent], +) -> list[AgentSpecEvent]: + return [event for event in span.events if isinstance(event, event_type)] + + +def _execute_with_trace( + conversation, + *, + state_snapshot_policy, + span_processors: Sequence[AgentSpecSpanProcessor] = (), + contexts: Sequence[AbstractContextManager[Any]] = (), +) -> tuple[Any, _SnapshotSpanRecorder]: + span_recorder = _SnapshotSpanRecorder() + listener = AgentSpecEventListener() + + with ExitStack() as stack: + for context in contexts: + stack.enter_context(context) + stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, *span_processors])) + stack.enter_context(register_event_listeners([listener])) + status = conversation.execute(state_snapshot_policy=state_snapshot_policy) + + return status, span_recorder + + +def _make_single_turn_agent_conversation() -> tuple[str, VllmModel, Any]: + assistant_message = "Hello from agent" + llm = VllmModel(model_id="mock.model", host_port="http://mock.url", name="agent") + conversation = WayflowAgent(llm=llm).start_conversation() + conversation.append_user_message("Hi") + return assistant_message, llm, conversation + + +def _snapshot_message(snapshot_event: AgentSpecStateSnapshotEmitted) -> str | None: + messages = (snapshot_event.state_snapshot or {}).get("conversation", {}).get("messages", []) + if not messages: + return None + return messages[-1].get("content") + + +def test_conversation_turn_snapshots_attach_to_the_agent_span() -> None: + assistant_message, llm, conversation = _make_single_turn_agent_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + contexts=[patch_llm(llm, [assistant_message], patch_internal=True)], + ) + + assert isinstance(status, UserMessageRequestStatus) + + agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + state_snapshot_events = _span_events(agent_span, AgentSpecStateSnapshotEmitted) + + assert _span_events(agent_span, AgentSpecAgentExecutionStart) + assert len(state_snapshot_events) == 2 + assert state_snapshot_events[-1].conversation_id == conversation.conversation_id + assert _snapshot_message(state_snapshot_events[-1]) == assistant_message + + +def test_node_turn_snapshots_attach_to_the_agent_span_not_llm_spans() -> None: + assistant_message, llm, conversation = _make_single_turn_agent_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + contexts=[patch_llm(llm, [assistant_message], patch_internal=True)], + ) + + assert isinstance(status, UserMessageRequestStatus) + + agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + state_snapshot_events = _span_events(agent_span, AgentSpecStateSnapshotEmitted) + + assert len(state_snapshot_events) == 4 + assert state_snapshot_events[-1].state_snapshot["execution"]["status"]["type"] == ( + "UserMessageRequestStatus" + ) + assert _snapshot_message(state_snapshot_events[-1]) == assistant_message + assert not any( + isinstance(event, AgentSpecStateSnapshotEmitted) + for span in _recorded_spans(span_recorder, AgentSpecLlmGenerationSpan) + for event in span.events + ) + + +def test_final_agent_snapshot_is_visible_to_span_processors_inside_on_end() -> None: + assistant_message, llm, conversation = _make_single_turn_agent_conversation() + events_seen_at_end_recorder = _EventsSeenAtSpanEndRecorder() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + span_processors=[events_seen_at_end_recorder], + contexts=[patch_llm(llm, [assistant_message], patch_internal=True)], + ) + + assert isinstance(status, UserMessageRequestStatus) + + agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + events_seen_at_end = events_seen_at_end_recorder.events_by_span_id[agent_span.id] + + assert any(isinstance(event, AgentSpecAgentExecutionEnd) for event in events_seen_at_end) + assert isinstance(events_seen_at_end[-1], AgentSpecStateSnapshotEmitted) + assert _snapshot_message(events_seen_at_end[-1]) == assistant_message + + +def test_agent_snapshot_extra_state_is_passed_through_verbatim() -> None: + assistant_message, llm, conversation = _make_single_turn_agent_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=lambda _conversation: {"ui": {"active_tab": "plan"}}, + ), + contexts=[patch_llm(llm, [assistant_message], patch_internal=True)], + ) + + assert isinstance(status, UserMessageRequestStatus) + + agent_span = _single_span(span_recorder, AgentSpecAgentExecutionSpan) + state_snapshot_events = _span_events(agent_span, AgentSpecStateSnapshotEmitted) + + assert state_snapshot_events[-1].extra_state == {"ui": {"active_tab": "plan"}} diff --git a/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py new file mode 100644 index 000000000..b29ee11ef --- /dev/null +++ b/wayflowcore/tests/agentspec/test_state_snapshot_tracing_flow.py @@ -0,0 +1,332 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Sequence + +import pytest +from pyagentspec.tracing.events import Event as AgentSpecEvent +from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd +from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart +from pyagentspec.tracing.events import StateSnapshotEmitted as AgentSpecStateSnapshotEmitted +from pyagentspec.tracing.spanprocessor import SpanProcessor as AgentSpecSpanProcessor +from pyagentspec.tracing.spans import FlowExecutionSpan as AgentSpecFlowExecutionSpan +from pyagentspec.tracing.spans import NodeExecutionSpan as AgentSpecNodeExecutionSpan +from pyagentspec.tracing.spans import Span as AgentSpecSpan +from pyagentspec.tracing.spans import ToolExecutionSpan as AgentSpecToolExecutionSpan +from pyagentspec.tracing.trace import Trace as AgentSpecTrace + +from wayflowcore.agentspec.tracing import AgentSpecEventListener +from wayflowcore.events.eventlistener import register_event_listeners +from wayflowcore.executors.executionstatus import FinishedStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy +from wayflowcore.flow import Flow +from wayflowcore.steps import CompleteStep, FlowExecutionStep, OutputMessageStep, ToolExecutionStep +from wayflowcore.tools import ServerTool + + +class _PassiveSpanProcessor(AgentSpecSpanProcessor): + def on_start(self, span: AgentSpecSpan) -> None: + return None + + async def on_start_async(self, span: AgentSpecSpan) -> None: + return None + + def on_end(self, span: AgentSpecSpan) -> None: + return None + + async def on_end_async(self, span: AgentSpecSpan) -> None: + return None + + def on_event(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + async def on_event_async(self, event: AgentSpecEvent, span: AgentSpecSpan) -> None: + return None + + def startup(self) -> None: + return None + + def shutdown(self) -> None: + return None + + async def startup_async(self) -> None: + return None + + async def shutdown_async(self) -> None: + return None + + +class _SnapshotSpanRecorder(_PassiveSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.started_spans: list[AgentSpecSpan] = [] + + def on_start(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + async def on_start_async(self, span: AgentSpecSpan) -> None: + self.started_spans.append(span) + + +class _EventsSeenAtSpanEndRecorder(_PassiveSpanProcessor): + def __init__(self) -> None: + super().__init__() + self.events_by_span_id: dict[str, list[AgentSpecEvent]] = {} + + def on_end(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + async def on_end_async(self, span: AgentSpecSpan) -> None: + self.events_by_span_id[span.id] = list(span.events) + + +def _recorded_spans( + span_recorder: _SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> list[AgentSpecSpan]: + return [span for span in span_recorder.started_spans if isinstance(span, span_type)] + + +def _single_span( + span_recorder: _SnapshotSpanRecorder, + span_type: type[AgentSpecSpan], +) -> AgentSpecSpan: + matching_spans = _recorded_spans(span_recorder, span_type) + assert len(matching_spans) == 1 + return matching_spans[0] + + +def _span_events( + span: AgentSpecSpan, + event_type: type[AgentSpecEvent], +) -> list[AgentSpecEvent]: + return [event for event in span.events if isinstance(event, event_type)] + + +def _execute_with_trace( + conversation, + *, + state_snapshot_policy, + span_processors: Sequence[AgentSpecSpanProcessor] = (), + contexts: Sequence[AbstractContextManager[Any]] = (), +) -> tuple[Any, _SnapshotSpanRecorder]: + span_recorder = _SnapshotSpanRecorder() + listener = AgentSpecEventListener() + + with ExitStack() as stack: + for context in contexts: + stack.enter_context(context) + stack.enter_context(AgentSpecTrace(span_processors=[span_recorder, *span_processors])) + stack.enter_context(register_event_listeners([listener])) + status = conversation.execute(state_snapshot_policy=state_snapshot_policy) + + return status, span_recorder + + +def _make_output_flow() -> Flow: + return Flow.from_steps( + [ + OutputMessageStep(message_template="Hello", name="single_step"), + CompleteStep(name="end"), + ], + name="simple_output_flow", + ) + + +def _make_tool_flow() -> Flow: + return Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name="say_hi", + description="Say hi", + func=lambda: "hi", + input_descriptors=[], + ), + name="tool_step", + ), + CompleteStep(name="end"), + ], + name="tool_flow", + ) + + +def _make_nested_parent_flow_conversation(): + child_flow = Flow.from_steps( + [ + OutputMessageStep(message_template="child", name="child_message"), + CompleteStep(name="end"), + ], + name="child_flow", + ) + parent_flow = Flow.from_steps( + [ + FlowExecutionStep(flow=child_flow, name="child_flow_step"), + OutputMessageStep(message_template="parent", name="parent_message"), + CompleteStep(name="end"), + ], + name="parent_flow", + ) + return parent_flow.start_conversation() + + +def _snapshot_message(snapshot_event: AgentSpecStateSnapshotEmitted) -> str | None: + messages = (snapshot_event.state_snapshot or {}).get("conversation", {}).get("messages", []) + if not messages: + return None + return messages[-1].get("content") + + +def test_conversation_turn_snapshots_attach_to_the_flow_span() -> None: + conversation = _make_output_flow().start_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + state_snapshot_events = _span_events(flow_span, AgentSpecStateSnapshotEmitted) + + assert _span_events(flow_span, AgentSpecFlowExecutionStart) + assert len(state_snapshot_events) == 2 + assert state_snapshot_events[-1].conversation_id == conversation.conversation_id + assert _snapshot_message(state_snapshot_events[-1]) == "Hello" + + +def test_final_flow_snapshot_is_visible_to_span_processors_inside_on_end() -> None: + conversation = _make_output_flow().start_conversation() + events_seen_at_end_recorder = _EventsSeenAtSpanEndRecorder() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + span_processors=[events_seen_at_end_recorder], + ) + + assert isinstance(status, FinishedStatus) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + events_seen_at_end = events_seen_at_end_recorder.events_by_span_id[flow_span.id] + + assert any(isinstance(event, AgentSpecFlowExecutionEnd) for event in events_seen_at_end) + assert isinstance(events_seen_at_end[-1], AgentSpecStateSnapshotEmitted) + assert _snapshot_message(events_seen_at_end[-1]) == "Hello" + + +def test_node_turn_snapshots_attach_to_the_flow_span_not_node_or_tool_spans() -> None: + conversation = _make_tool_flow().start_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + flow_snapshot_events = _span_events(flow_span, AgentSpecStateSnapshotEmitted) + + assert len(flow_snapshot_events) > 2 + assert not any( + isinstance(event, AgentSpecStateSnapshotEmitted) + for span in _recorded_spans(span_recorder, AgentSpecNodeExecutionSpan) + for event in span.events + ) + assert not any( + isinstance(event, AgentSpecStateSnapshotEmitted) + for span in _recorded_spans(span_recorder, AgentSpecToolExecutionSpan) + for event in span.events + ) + + +def test_nested_flow_execution_exports_snapshots_only_on_the_root_flow_span() -> None: + conversation = _make_nested_parent_flow_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + + flow_spans = _recorded_spans(span_recorder, AgentSpecFlowExecutionSpan) + assert len(flow_spans) == 2 + + flow_spans_by_name = { + next( + event for event in span.events if isinstance(event, AgentSpecFlowExecutionStart) + ).flow.name: span + for span in flow_spans + } + parent_span = flow_spans_by_name["parent_flow"] + child_span = flow_spans_by_name["child_flow"] + + parent_snapshot_events = _span_events(parent_span, AgentSpecStateSnapshotEmitted) + child_snapshot_events = _span_events(child_span, AgentSpecStateSnapshotEmitted) + + assert len(parent_snapshot_events) == 2 + assert [event.conversation_id for event in parent_snapshot_events] == [ + conversation.conversation_id, + conversation.conversation_id, + ] + assert _snapshot_message(parent_snapshot_events[-1]) == "parent" + assert child_snapshot_events == [] + + +def test_off_policy_disables_flow_state_snapshot_export() -> None: + conversation = _make_output_flow().start_conversation() + status, span_recorder = _execute_with_trace( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.OFF + ), + ) + + assert isinstance(status, FinishedStatus) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + assert _span_events(flow_span, AgentSpecStateSnapshotEmitted) == [] + + +def test_raised_turn_exports_only_the_opening_flow_snapshot() -> None: + conversation = Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name="explode", + description="Raise an error", + func=lambda: (_ for _ in ()).throw(RuntimeError("boom")), + input_descriptors=[], + ) + ), + CompleteStep(name="end"), + ] + ).start_conversation() + + span_recorder = _SnapshotSpanRecorder() + listener = AgentSpecEventListener() + with AgentSpecTrace(span_processors=[span_recorder]): + with register_event_listeners([listener]): + with pytest.raises(RuntimeError, match="boom"): + conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) + ) + + flow_span = _single_span(span_recorder, AgentSpecFlowExecutionSpan) + state_snapshot_events = _span_events(flow_span, AgentSpecStateSnapshotEmitted) + + assert len(state_snapshot_events) == 1 + assert state_snapshot_events[0].state_snapshot["execution"]["status"] is None diff --git a/wayflowcore/tests/events/test_state_snapshot_event_tracing.py b/wayflowcore/tests/events/test_state_snapshot_event_tracing.py new file mode 100644 index 000000000..930d18f9b --- /dev/null +++ b/wayflowcore/tests/events/test_state_snapshot_event_tracing.py @@ -0,0 +1,70 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +import pytest + +from wayflowcore.events.event import _PII_TEXT_MASK, StateSnapshotEvent + + +def test_state_snapshot_event_requires_conversation_id() -> None: + with pytest.raises(ValueError, match="conversation_id"): + StateSnapshotEvent() + + +def test_state_snapshot_event_allows_missing_state_snapshot() -> None: + event = StateSnapshotEvent(conversation_id="conversation-123") + + assert event.state_snapshot is None + + +def test_state_snapshot_event_requires_state_snapshot_to_be_a_dictionary() -> None: + with pytest.raises(ValueError, match="state_snapshot must be a dictionary"): + StateSnapshotEvent( + conversation_id="conversation-123", + state_snapshot="not-a-dictionary", + ) + + +def test_state_snapshot_event_requires_runtime_conversation_id() -> None: + with pytest.raises(ValueError, match=r"state_snapshot\['conversation'\]\['id'\]"): + StateSnapshotEvent( + conversation_id="conversation-123", + state_snapshot={"conversation": {"messages": []}}, + ) + + +@pytest.mark.parametrize("mask_sensitive_information", [True, False]) +def test_state_snapshot_event_serialization( + mask_sensitive_information: bool, +) -> None: + event = StateSnapshotEvent( + conversation_id="conversation-123", + state_snapshot={"conversation": {"id": "conversation-runtime-123", "messages": []}}, + extra_state={"ui": {"active_tab": "plan"}}, + variable_state={"count": 2}, + name="snapshot", + event_id="evt-1", + timestamp=12, + ) + + serialized_event = event.to_tracing_info(mask_sensitive_information=mask_sensitive_information) + + assert serialized_event["event_type"] == "StateSnapshotEvent" + assert serialized_event["conversation_id"] == "conversation-123" + assert serialized_event["name"] == "snapshot" + assert serialized_event["event_id"] == "evt-1" + assert serialized_event["timestamp"] == 12 + + if mask_sensitive_information: + assert serialized_event["state_snapshot"] == _PII_TEXT_MASK + assert serialized_event["extra_state"] == _PII_TEXT_MASK + assert serialized_event["variable_state"] == _PII_TEXT_MASK + else: + assert serialized_event["state_snapshot"] == { + "conversation": {"id": "conversation-runtime-123", "messages": []} + } + assert serialized_event["extra_state"] == {"ui": {"active_tab": "plan"}} + assert serialized_event["variable_state"] == {"count": 2} diff --git a/wayflowcore/tests/events/test_state_snapshot_runtime_core.py b/wayflowcore/tests/events/test_state_snapshot_runtime_core.py new file mode 100644 index 000000000..76df65fd6 --- /dev/null +++ b/wayflowcore/tests/events/test_state_snapshot_runtime_core.py @@ -0,0 +1,535 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from dataclasses import dataclass + +import pytest + +from wayflowcore.conversation import Conversation +from wayflowcore.events import Event, EventListener +from wayflowcore.events.event import StateSnapshotEvent +from wayflowcore.events.eventlistener import register_event_listeners +from wayflowcore.executors._events.event import EventType +from wayflowcore.executors.executionstatus import FinishedStatus +from wayflowcore.executors.interrupts.executioninterrupt import InterruptedExecutionStatus +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy +from wayflowcore.flow import Flow +from wayflowcore.property import AnyProperty +from wayflowcore.serialization import dump_conversation_state +from wayflowcore.serialization.serializer import FrozenSerializableDataclass +from wayflowcore.steps import ( + CompleteStep, + FlowExecutionStep, + OutputMessageStep, + ParallelFlowExecutionStep, + ToolExecutionStep, + VariableWriteStep, +) +from wayflowcore.tools import ServerTool +from wayflowcore.variable import Variable + +from ..test_interrupts import OnEventExecutionInterrupt +from ..testhelpers.state_snapshot_testutils import ( + execute_with_state_snapshots, + execute_with_state_snapshots_async, + restore_conversation_from_snapshot_payload, + snapshot_status_types, +) + + +class _LiveConversationSnapshotObserver(EventListener): + def __init__(self, conversation) -> None: + self.conversation = conversation + self.live_snapshots: list[dict[str, Any]] = [] + + def __call__(self, event: Event) -> None: + if isinstance(event, StateSnapshotEvent): + self.live_snapshots.append(dump_conversation_state(self.conversation)) + + +def _make_output_conversation(): + return Flow.from_steps( + [ + OutputMessageStep(message_template="Hello", name="single_step"), + CompleteStep(name="end"), + ] + ).start_conversation() + + +def _make_tool_flow_conversation(): + return Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name="say_hi", + description="Say hi", + func=lambda: "hi", + input_descriptors=[], + ), + name="tool_step", + ), + CompleteStep(name="end"), + ] + ).start_conversation() + + +def _make_parent_child_conversation(): + child_flow = Flow.from_steps( + [ + OutputMessageStep(message_template="child", name="child_message"), + CompleteStep(name="child_end"), + ], + name="child_flow", + ) + parent_flow = Flow.from_steps( + [ + FlowExecutionStep(flow=child_flow, name="child_flow_step"), + OutputMessageStep(message_template="parent", name="parent_message"), + CompleteStep(name="end"), + ], + name="parent_flow", + ) + return parent_flow.start_conversation() + + +def _make_parallel_parent_conversation(): + return Flow.from_steps( + [ + ParallelFlowExecutionStep( + flows=[ + Flow.from_steps( + [ + OutputMessageStep( + message_template="left", + output_mapping={OutputMessageStep.OUTPUT: "left_message"}, + ), + CompleteStep(name="left_end"), + ], + name="left_child_flow", + ), + Flow.from_steps( + [ + OutputMessageStep( + message_template="right", + output_mapping={OutputMessageStep.OUTPUT: "right_message"}, + ), + CompleteStep(name="right_end"), + ], + name="right_child_flow", + ), + ], + name="parallel_children", + ), + CompleteStep(name="end"), + ], + name="parallel_parent_flow", + ).start_conversation() + + +def _snapshot_message(snapshot_event: StateSnapshotEvent) -> str | None: + messages = snapshot_event.state_snapshot["conversation"]["messages"] + if not messages: + return None + return messages[-1].get("content") + + +def _snapshot_step_histories(snapshot_events: list[StateSnapshotEvent]) -> list[list[str]]: + return [ + snapshot_event.state_snapshot["execution"]["step_history"] + for snapshot_event in snapshot_events + ] + + +def _execute_tool_flow_with_interval( + interval: StateSnapshotInterval, +) -> tuple[object, list[StateSnapshotEvent]]: + return execute_with_state_snapshots( + _make_tool_flow_conversation(), + state_snapshot_policy=StateSnapshotPolicy(state_snapshot_interval=interval), + ) + + +@dataclass(frozen=True) +class _SerializableButNotJson(FrozenSerializableDataclass): + value: str + + +class _FailOnTerminalSnapshot(EventListener): + def __call__(self, event: Event) -> None: + if not isinstance(event, StateSnapshotEvent): + return + + execution_status = (event.state_snapshot or {}).get("execution", {}).get("status") + if execution_status is not None: + raise RuntimeError("snapshot sink failed") + + +def test_off_policy_emits_no_state_snapshots() -> None: + status, state_snapshot_events = execute_with_state_snapshots( + _make_output_conversation(), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.OFF + ), + ) + + assert isinstance(status, FinishedStatus) + assert state_snapshot_events == [] + + +def test_conversation_turns_emit_opening_and_closing_snapshots() -> None: + status, state_snapshot_events = execute_with_state_snapshots( + _make_output_conversation(), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 2 + assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + assert _snapshot_message(state_snapshot_events[-1]) == "Hello" + assert all( + isinstance(snapshot_event.state_snapshot.get("conversation_state"), str) + for snapshot_event in state_snapshot_events + ) + + +def test_terminal_snapshot_is_synthesized_before_live_status_commit() -> None: + conversation = _make_output_conversation() + collector = [] + observer = _LiveConversationSnapshotObserver(conversation) + + class _Collector(EventListener): + def __call__(self, event: Event) -> None: + if isinstance(event, StateSnapshotEvent): + collector.append(event) + + with register_event_listeners([_Collector(), observer]): + status = conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) + ) + + assert isinstance(status, FinishedStatus) + assert conversation.status is status + assert observer.live_snapshots[-1]["execution"]["status"] is None + assert collector[-1].state_snapshot["execution"]["status"]["type"] == "FinishedStatus" + assert collector[-1].state_snapshot["execution"]["status_handled"] is False + assert observer.live_snapshots[-1] != { + "conversation": collector[-1].state_snapshot["conversation"], + "execution": collector[-1].state_snapshot["execution"], + } + + +def test_interrupted_conversation_turn_emits_terminal_snapshot() -> None: + status, state_snapshot_events = execute_with_state_snapshots( + _make_output_conversation(), + execution_interrupts=[OnEventExecutionInterrupt(EventType.EXECUTION_END)], + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, InterruptedExecutionStatus) + assert snapshot_status_types(state_snapshot_events) == [None, "InterruptedExecutionStatus"] + assert state_snapshot_events[-1].state_snapshot["execution"]["status_handled"] is False + + +def test_node_turns_emit_internal_step_snapshots_and_only_root_turns_are_resumable() -> None: + status, state_snapshot_events = execute_with_state_snapshots( + _make_output_conversation(), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + ) + + step_histories = [ + snapshot_event.state_snapshot["execution"]["step_history"] + for snapshot_event in state_snapshot_events + ] + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) > 2 + assert snapshot_status_types(state_snapshot_events)[0] is None + assert snapshot_status_types(state_snapshot_events)[-1] == "FinishedStatus" + assert all( + status_type is None for status_type in snapshot_status_types(state_snapshot_events)[1:-1] + ) + assert [] in step_histories + assert ["__StartStep__"] in step_histories + assert ["__StartStep__", "single_step"] in step_histories + assert ["__StartStep__", "single_step", "end"] in step_histories + assert isinstance(state_snapshot_events[0].state_snapshot.get("conversation_state"), str) + assert isinstance(state_snapshot_events[-1].state_snapshot.get("conversation_state"), str) + assert all( + "conversation_state" not in snapshot_event.state_snapshot + for snapshot_event in state_snapshot_events[1:-1] + ) + + +def test_tool_turns_emit_tool_boundary_snapshots_and_only_root_turns_are_resumable() -> None: + status, state_snapshot_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.TOOL_TURNS + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 4 + assert snapshot_status_types(state_snapshot_events) == [None, None, None, "FinishedStatus"] + assert isinstance(state_snapshot_events[0].state_snapshot.get("conversation_state"), str) + assert isinstance(state_snapshot_events[-1].state_snapshot.get("conversation_state"), str) + assert all( + "conversation_state" not in snapshot_event.state_snapshot + for snapshot_event in state_snapshot_events[1:-1] + ) + + +def test_all_internal_turns_emit_more_snapshots_than_tool_turns() -> None: + _, tool_turn_events = _execute_tool_flow_with_interval(StateSnapshotInterval.TOOL_TURNS) + status, all_internal_turn_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.ALL_INTERNAL_TURNS + ) + + assert isinstance(status, FinishedStatus) + assert len(all_internal_turn_events) > len(tool_turn_events) + + +def test_all_internal_turns_emit_more_snapshots_than_node_turns() -> None: + _, node_turn_events = _execute_tool_flow_with_interval(StateSnapshotInterval.NODE_TURNS) + status, all_internal_turn_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.ALL_INTERNAL_TURNS + ) + + assert isinstance(status, FinishedStatus) + assert len(all_internal_turn_events) > len(node_turn_events) + + +def test_all_internal_turns_emit_opening_and_terminal_statuses() -> None: + status, all_internal_turn_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.ALL_INTERNAL_TURNS + ) + + assert isinstance(status, FinishedStatus) + assert snapshot_status_types(all_internal_turn_events)[0] is None + assert snapshot_status_types(all_internal_turn_events)[-1] == "FinishedStatus" + + +def test_all_internal_turns_include_node_step_boundaries() -> None: + status, all_internal_turn_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.ALL_INTERNAL_TURNS + ) + step_histories = _snapshot_step_histories(all_internal_turn_events) + + assert isinstance(status, FinishedStatus) + assert ["__StartStep__"] in step_histories + assert ["__StartStep__", "tool_step"] in step_histories + assert ["__StartStep__", "tool_step", "end"] in step_histories + + +def test_all_internal_turns_keep_only_root_turns_resumable() -> None: + status, all_internal_turn_events = _execute_tool_flow_with_interval( + StateSnapshotInterval.ALL_INTERNAL_TURNS + ) + + assert isinstance(status, FinishedStatus) + assert isinstance(all_internal_turn_events[0].state_snapshot.get("conversation_state"), str) + assert isinstance(all_internal_turn_events[-1].state_snapshot.get("conversation_state"), str) + assert all( + "conversation_state" not in snapshot_event.state_snapshot + for snapshot_event in all_internal_turn_events[1:-1] + ) + + +@pytest.mark.anyio +async def test_execute_async_emits_conversation_turn_snapshots() -> None: + status, state_snapshot_events = await execute_with_state_snapshots_async( + _make_output_conversation(), + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + + +def test_parallel_conversation_turn_snapshots_stay_on_the_root_conversation() -> None: + conversation = _make_parallel_parent_conversation() + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + assert [snapshot_event.conversation_id for snapshot_event in state_snapshot_events] == [ + conversation.conversation_id, + conversation.conversation_id, + ] + assert [ + snapshot_event.state_snapshot["conversation"]["id"] + for snapshot_event in state_snapshot_events + ] == [conversation.id, conversation.id] + + restored_conversation = restore_conversation_from_snapshot_payload( + state_snapshot_events[-1].state_snapshot + ) + resumed_status = restored_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + assert sorted(message.content for message in restored_conversation.get_messages()) == [ + "left", + "right", + ] + + +def test_nested_conversation_turn_snapshots_stay_on_the_root_conversation() -> None: + conversation = _make_parent_child_conversation() + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert len(state_snapshot_events) == 2 + assert snapshot_status_types(state_snapshot_events) == [None, "FinishedStatus"] + assert [snapshot_event.conversation_id for snapshot_event in state_snapshot_events] == [ + conversation.conversation_id, + conversation.conversation_id, + ] + assert [ + snapshot_event.state_snapshot["conversation"]["id"] + for snapshot_event in state_snapshot_events + ] == [conversation.id, conversation.id] + assert _snapshot_message(state_snapshot_events[-1]) == "parent" + + +def test_nested_node_turn_snapshots_can_capture_child_runtime_conversation_ids() -> None: + conversation = _make_parent_child_conversation() + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.NODE_TURNS + ), + ) + + assert isinstance(status, FinishedStatus) + assert {snapshot_event.conversation_id for snapshot_event in state_snapshot_events} == { + conversation.conversation_id + } + assert isinstance(state_snapshot_events[0].state_snapshot.get("conversation_state"), str) + assert isinstance(state_snapshot_events[-1].state_snapshot.get("conversation_state"), str) + assert all( + "conversation_state" not in snapshot_event.state_snapshot + for snapshot_event in state_snapshot_events[1:-1] + ) + assert any( + snapshot_event.state_snapshot["conversation"]["id"] != conversation.id + for snapshot_event in state_snapshot_events[1:-1] + ) + + +def test_state_snapshot_emits_variable_state_for_successful_flow_execution() -> None: + custom_variable = Variable(name="custom", type=AnyProperty()) + conversation = Flow.from_steps( + [ + VariableWriteStep( + variable=custom_variable, + input_mapping={VariableWriteStep.VALUE: custom_variable.name}, + ), + OutputMessageStep(message_template="done"), + CompleteStep(name="end"), + ], + variables=[custom_variable], + ).start_conversation(inputs={custom_variable.name: "stored-value"}) + + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + include_variable_state=True, + ), + ) + + assert isinstance(status, FinishedStatus) + assert state_snapshot_events[0].variable_state == {"custom": None} + assert state_snapshot_events[-1].variable_state == {"custom": "stored-value"} + + +def test_state_snapshot_emission_propagates_extra_state_builder_failures() -> None: + output_conversation = _make_output_conversation() + + def broken_builder(_conversation: Conversation) -> dict[str, object]: + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + output_conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=broken_builder, + ) + ) + + assert output_conversation.get_last_message() is None + + +def test_state_snapshot_emission_rejects_non_strict_json_extra_state() -> None: + output_conversation = _make_output_conversation() + + with pytest.raises(TypeError, match="Extra snapshot state .* strict JSON-serializable"): + output_conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + extra_state_builder=lambda _conversation: {"ui": {"preview_count": float("nan")}}, + ) + ) + + assert output_conversation.get_last_message() is None + + +def test_state_snapshot_emission_propagates_unserializable_variable_state() -> None: + custom_variable = Variable(name="custom", type=AnyProperty()) + conversation = Flow.from_steps( + [ + VariableWriteStep( + variable=custom_variable, + input_mapping={VariableWriteStep.VALUE: custom_variable.name}, + ), + OutputMessageStep(message_template="done"), + CompleteStep(name="end"), + ], + variables=[custom_variable], + ).start_conversation(inputs={custom_variable.name: _SerializableButNotJson(value="x")}) + + with pytest.raises(TypeError, match="Variable 'custom' contains a non-JSON-serializable"): + conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS, + include_variable_state=True, + ) + ) + + assert conversation.get_last_message() is not None + assert conversation.get_last_message().content == "done" + + +def test_state_snapshot_listener_failures_propagate_to_the_caller() -> None: + output_conversation = _make_output_conversation() + + with register_event_listeners([_FailOnTerminalSnapshot()]): + with pytest.raises(RuntimeError, match="snapshot sink failed"): + output_conversation.execute( + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ) + ) + + assert output_conversation.get_last_message() is not None + assert output_conversation.get_last_message().content == "Hello" diff --git a/wayflowcore/tests/integration/steps/test_flow_execution_step.py b/wayflowcore/tests/integration/steps/test_flow_execution_step.py index 94783d0b8..d03f9d7df 100644 --- a/wayflowcore/tests/integration/steps/test_flow_execution_step.py +++ b/wayflowcore/tests/integration/steps/test_flow_execution_step.py @@ -22,6 +22,7 @@ CompleteStep, FlowExecutionStep, InputMessageStep, + OutputMessageStep, ToolExecutionStep, ) from wayflowcore.tools import ClientTool, ServerTool, ToolResult @@ -210,6 +211,58 @@ def test_sub_conversation_shares_same_message_list_as_main_conversation() -> Non assert conversation.get_messages() == subconversation.get_messages() +def test_subflow_execution_reuses_finished_child_conversation_and_preserves_branch() -> None: + child_success = CompleteStep(name="child_success") + child_message = OutputMessageStep(message_template="child", name="child_message") + child_flow = Flow( + begin_step=child_message, + steps={ + "child_message": child_message, + "child_success": child_success, + }, + control_flow_edges=[ + ControlFlowEdge(child_message, child_success), + ], + ) + child_step = FlowExecutionStep(flow=child_flow, name="child_step") + success_message = OutputMessageStep(message_template="parent-success", name="success_message") + parent_flow = Flow( + begin_step=child_step, + steps={ + "child_step": child_step, + "success_message": success_message, + }, + control_flow_edges=[ + ControlFlowEdge(child_step, success_message, source_branch="child_success"), + ControlFlowEdge(success_message, None), + ], + ) + conversation = parent_flow.start_conversation() + + sub_conversation = conversation._get_or_create_current_sub_conversation( + step=child_step, + flow=child_flow, + inputs={}, + ) + sub_status = sub_conversation.execute() + assert isinstance(sub_status, FinishedStatus) + assert sub_status.complete_step_name == "child_success" + + async def _should_not_execute_again() -> FinishedStatus: + raise AssertionError("finished child conversation should not be executed again") + + sub_conversation.execute_async = _should_not_execute_again # type: ignore[method-assign] + + status = conversation.execute() + + assert isinstance(status, FinishedStatus) + assert status.output_values == {"output_message": "parent-success"} + assert [message.content for message in conversation.get_messages()] == [ + "child", + "parent-success", + ] + + def test_subflow_execution_might_yield_when_flow_contains_yielding_steps() -> None: step = FlowExecutionStep( flow=Flow.from_steps( diff --git a/wayflowcore/tests/integration/steps/test_parallel_flow_execution_step.py b/wayflowcore/tests/integration/steps/test_parallel_flow_execution_step.py index e1a41b75c..65a1ec5e1 100644 --- a/wayflowcore/tests/integration/steps/test_parallel_flow_execution_step.py +++ b/wayflowcore/tests/integration/steps/test_parallel_flow_execution_step.py @@ -9,10 +9,11 @@ import pytest from wayflowcore import Tool +from wayflowcore.executors.executionstatus import FinishedStatus from wayflowcore.flow import Flow from wayflowcore.flowhelpers import run_step_and_return_outputs from wayflowcore.property import IntegerProperty, ListProperty, StringProperty -from wayflowcore.steps import InputMessageStep, OutputMessageStep, ToolExecutionStep +from wayflowcore.steps import CompleteStep, InputMessageStep, OutputMessageStep, ToolExecutionStep from wayflowcore.steps.parallelflowexecutionstep import ParallelFlowExecutionStep from wayflowcore.tools import ServerTool @@ -225,6 +226,53 @@ def test_parallel_subflow_execution_cannot_yield(): assert step.might_yield is False +def test_parallel_subflow_execution_reuses_previously_finished_child_conversations() -> None: + left_flow = get_flow_from_tools([get_tool(output_name="left_output")]) + right_flow = get_flow_from_tools([get_tool(output_name="right_output")]) + step = ParallelFlowExecutionStep(flows=[left_flow, right_flow], name="parallel") + conversation = Flow.from_steps([step, CompleteStep(name="end")]).start_conversation() + + for sub_flow in step.flows: + sub_conversation = conversation._get_or_create_current_sub_conversation( + step=step, + flow=sub_flow, + inputs={}, + sub_conversation_id=sub_flow.id, + ) + status = sub_conversation.execute() + assert isinstance(status, FinishedStatus) + + async def _should_not_execute_again() -> FinishedStatus: + raise AssertionError("finished child conversation should not be executed again") + + sub_conversation.execute_async = _should_not_execute_again # type: ignore[method-assign] + + status = conversation.execute() + + assert isinstance(status, FinishedStatus) + assert status.output_values == { + "left_output": "{}_from_left_output", + "right_output": "{}_from_right_output", + } + + +def test_parallel_subflow_execution_cleans_up_custom_id_sub_conversations_from_parent() -> None: + left_flow = get_flow_from_tools([get_tool(output_name="left_output")]) + right_flow = get_flow_from_tools([get_tool(output_name="right_output")]) + step = ParallelFlowExecutionStep(flows=[left_flow, right_flow], name="parallel") + conversation = Flow.from_steps([step, CompleteStep(name="end")]).start_conversation() + + status = conversation.execute() + + assert isinstance(status, FinishedStatus) + assert ( + conversation._get_current_sub_conversation(step, sub_conversation_id=left_flow.id) is None + ) + assert ( + conversation._get_current_sub_conversation(step, sub_conversation_id=right_flow.id) is None + ) + + def test_step_raises_when_two_flows_have_inputs_with_same_name_but_different_types(): tool_1 = ServerTool( name="n", description="d", input_descriptors=[ListProperty(name="i")], func=lambda: "" diff --git a/wayflowcore/tests/serialization/test_conversation_state_snapshot.py b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py new file mode 100644 index 000000000..1004ab250 --- /dev/null +++ b/wayflowcore/tests/serialization/test_conversation_state_snapshot.py @@ -0,0 +1,404 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +import json +from typing import Any + +import pytest + +from wayflowcore.controlconnection import ControlFlowEdge +from wayflowcore.conversation import Conversation +from wayflowcore.dataconnection import DataFlowEdge +from wayflowcore.executors.executionstatus import ( + FinishedStatus, + ToolRequestStatus, + UserMessageRequestStatus, +) +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotInterval, StateSnapshotPolicy +from wayflowcore.flow import Flow +from wayflowcore.property import AnyProperty, StringProperty +from wayflowcore.serialization import ( + deserialize_conversation_state, + dump_conversation_state, + dump_variable_state, + load_conversation_state, + serialize_conversation_state, +) +from wayflowcore.serialization.context import DeserializationContext +from wayflowcore.steps import ( + CompleteStep, + InputMessageStep, + OutputMessageStep, + ToolExecutionStep, + VariableReadStep, + VariableWriteStep, +) +from wayflowcore.tools import ClientTool, ServerTool, ToolResult, register_server_tool +from wayflowcore.variable import Variable + +from ..testhelpers.state_snapshot_testutils import ( + execute_with_state_snapshots, + restore_conversation_from_snapshot_payload, +) + + +class _UnserializableValue: + def __str__(self) -> str: + return "custom-value" + + +def _build_snapshot_flow(custom_variable: Variable) -> Flow: + return Flow.from_steps( + [ + VariableWriteStep( + variable=custom_variable, + input_mapping={VariableWriteStep.VALUE: custom_variable.name}, + ), + OutputMessageStep(message_template="Hello there"), + ], + variables=[custom_variable], + name="snapshot_flow", + ) + + +def _build_non_finite_input_snapshot_flow() -> Flow: + return Flow.from_steps( + [ + ToolExecutionStep( + tool=ServerTool( + name="echo", + description="Echo input", + func=lambda bad: str(bad), + input_descriptors=[AnyProperty(name="bad")], + ) + ), + CompleteStep(name="end"), + ], + name="non_finite_snapshot_flow", + ) + + +def _make_snapshot_flow_conversation( + *, + variable_type: StringProperty | AnyProperty, + input_value: Any, +) -> tuple[Variable, Conversation]: + custom_variable = Variable( + name="custom", + type=variable_type, + description="Custom variable used for snapshot serialization tests", + ) + conversation = _build_snapshot_flow(custom_variable).start_conversation( + inputs={custom_variable.name: input_value} + ) + conversation.execute() + return custom_variable, conversation + + +def _build_user_input_resume_flow() -> Flow: + return Flow.from_steps( + [InputMessageStep("Please answer"), OutputMessageStep("done")], + name="resume_flow", + ) + + +def _conversation_turn_snapshot_payload( + conversation: Conversation, +) -> tuple[object, dict[str, Any]]: + status, state_snapshot_events = execute_with_state_snapshots( + conversation, + state_snapshot_policy=StateSnapshotPolicy( + state_snapshot_interval=StateSnapshotInterval.CONVERSATION_TURNS + ), + ) + + assert state_snapshot_events[-1].state_snapshot is not None + return status, state_snapshot_events[-1].state_snapshot + + +def _walk_scalars(value: Any): + if isinstance(value, dict): + for inner_value in value.values(): + yield from _walk_scalars(inner_value) + return + if isinstance(value, list): + for inner_value in value: + yield from _walk_scalars(inner_value) + return + yield value + + +def test_dump_conversation_state_is_strict_json_serializable_and_lightweight() -> None: + _, conversation = _make_snapshot_flow_conversation( + variable_type=StringProperty(), + input_value="custom-value", + ) + + snapshot = dump_conversation_state(conversation) + + assert json.loads(json.dumps(snapshot, allow_nan=False)) == snapshot + assert dump_variable_state(conversation) == {"custom": "custom-value"} + assert snapshot["conversation"]["component_type"] == "Flow" + assert snapshot["conversation"]["messages"][-1]["content"] == "Hello there" + assert all( + not isinstance(scalar, (Conversation, Flow, OutputMessageStep)) + for scalar in _walk_scalars(snapshot) + ) + + +def test_dump_conversation_state_includes_runtime_conversation_ids() -> None: + _, conversation = _make_snapshot_flow_conversation( + variable_type=StringProperty(), + input_value="custom-value", + ) + + snapshot = dump_conversation_state(conversation) + + assert snapshot["conversation"]["id"] == conversation.id + assert snapshot["conversation"]["conversation_id"] == conversation.conversation_id + + +def test_dump_conversation_state_status_overrides_do_not_mutate_live_conversation() -> None: + _, conversation = _make_snapshot_flow_conversation( + variable_type=StringProperty(), + input_value="custom-value", + ) + + previous_status = conversation.status + previous_status_handled = conversation.status_handled + + snapshot = dump_conversation_state(conversation, status=None, status_handled=True) + + assert snapshot["execution"]["status"] is None + assert snapshot["execution"]["status_handled"] is True + assert conversation.status is previous_status + assert conversation.status_handled is previous_status_handled + + +def test_dump_variable_state_rejects_non_json_serializable_values() -> None: + _, conversation = _make_snapshot_flow_conversation( + variable_type=AnyProperty(), + input_value=_UnserializableValue(), + ) + + with pytest.raises(TypeError, match="Variable 'custom' contains a non-JSON-serializable"): + dump_variable_state(conversation) + + +@pytest.mark.parametrize( + ("value", "expected_dumped_value"), + [ + pytest.param(float("nan"), "NaN", id="nan"), + pytest.param(float("inf"), "Infinity", id="infinity"), + pytest.param(float("-inf"), "-Infinity", id="negative-infinity"), + ], +) +def test_dump_conversation_state_normalizes_non_finite_floats_for_strict_json( + value: float, + expected_dumped_value: str, +) -> None: + conversation = _build_non_finite_input_snapshot_flow().start_conversation(inputs={"bad": value}) + + snapshot = dump_conversation_state(conversation) + + assert json.loads(json.dumps(snapshot, allow_nan=False)) == snapshot + assert snapshot["conversation"]["inputs"]["bad"] == expected_dumped_value + + +def test_serialized_conversation_roundtrip_preserves_pending_tool_results() -> None: + client_tool = ClientTool( + name="client_lookup", + description="Look up some data on the client side", + parameters={}, + ) + conversation = Flow.from_steps( + [ + ToolExecutionStep(tool=client_tool), + CompleteStep(name="end"), + ], + name="tool_resume_flow", + ).start_conversation() + + status = conversation.execute() + assert isinstance(status, ToolRequestStatus) + + tool_request = status.tool_requests[0] + conversation.append_tool_result( + ToolResult(tool_request_id=tool_request.tool_request_id, content="client-result") + ) + + loaded_conversation = load_conversation_state( + deserialize_conversation_state(serialize_conversation_state(conversation)) + ) + loaded_snapshot = dump_conversation_state(loaded_conversation) + + assert loaded_snapshot["execution"]["status"]["tool_results"] == [ + { + "tool_request_id": tool_request.tool_request_id, + "content": "client-result", + } + ] + + resumed_status = loaded_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + tool_result_messages = [ + message.tool_result for message in loaded_conversation.get_messages() if message.tool_result + ] + assert len(tool_result_messages) == 1 + assert tool_result_messages[0].tool_request_id == tool_request.tool_request_id + assert tool_result_messages[0].content == "client-result" + + +def test_emitted_snapshot_payload_restores_waiting_for_user_input() -> None: + status, snapshot_payload = _conversation_turn_snapshot_payload( + _build_user_input_resume_flow().start_conversation() + ) + + assert isinstance(status, UserMessageRequestStatus) + + restored_conversation = restore_conversation_from_snapshot_payload( + json.loads(json.dumps(snapshot_payload, allow_nan=False)) + ) + restored_conversation.append_user_message("hello") + resumed_status = restored_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + assert [message.content for message in restored_conversation.get_messages()] == [ + "Please answer", + "hello", + "done", + ] + + +def test_emitted_snapshot_payload_restores_waiting_for_client_tool_result() -> None: + client_tool = ClientTool( + name="client_lookup", + description="Look up some data on the client side", + parameters={}, + ) + conversation = Flow.from_steps( + [ + ToolExecutionStep(tool=client_tool), + CompleteStep(name="end"), + ], + name="snapshot_client_tool_resume_flow", + ).start_conversation() + + status, snapshot_payload = _conversation_turn_snapshot_payload(conversation) + + assert isinstance(status, ToolRequestStatus) + + restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) + assert isinstance(restored_conversation.status, ToolRequestStatus) + + tool_request = restored_conversation.status.tool_requests[0] + restored_conversation.append_tool_result( + ToolResult(tool_request_id=tool_request.tool_request_id, content="client-result") + ) + resumed_status = restored_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + tool_result_messages = [ + message.tool_result + for message in restored_conversation.get_messages() + if message.tool_result + ] + assert len(tool_result_messages) == 1 + assert tool_result_messages[0].tool_request_id == tool_request.tool_request_id + assert tool_result_messages[0].content == "client-result" + + +def test_emitted_snapshot_payload_restores_variable_dependent_continuation() -> None: + customer_name = Variable( + name="customer_name", + type=StringProperty(), + description="Customer name persisted across resumable snapshots", + ) + capture_name = VariableWriteStep( + variable=customer_name, + input_mapping={VariableWriteStep.VALUE: customer_name.name}, + name="capture_name", + ) + ask_follow_up = InputMessageStep( + message_template="How can I help {{customer_name}}?", + name="ask_follow_up", + ) + read_name = VariableReadStep(variable=customer_name, name="read_name") + final_message = OutputMessageStep( + message_template="Stored {{stored_name}}. Reply: {{reply}}", + name="final_message", + ) + flow = Flow( + begin_step=capture_name, + steps={ + "capture_name": capture_name, + "ask_follow_up": ask_follow_up, + "read_name": read_name, + "final_message": final_message, + }, + control_flow_edges=[ + ControlFlowEdge(capture_name, ask_follow_up), + ControlFlowEdge(ask_follow_up, read_name), + ControlFlowEdge(read_name, final_message), + ControlFlowEdge(final_message, None), + ], + data_flow_edges=[ + DataFlowEdge( + ask_follow_up, + InputMessageStep.USER_PROVIDED_INPUT, + final_message, + "reply", + ), + DataFlowEdge(read_name, VariableReadStep.VALUE, final_message, "stored_name"), + ], + variables=[customer_name], + name="snapshot_variable_resume_flow", + ) + conversation = flow.start_conversation(inputs={customer_name.name: "Alice"}) + + status, snapshot_payload = _conversation_turn_snapshot_payload(conversation) + + assert isinstance(status, UserMessageRequestStatus) + + restored_conversation = restore_conversation_from_snapshot_payload(snapshot_payload) + assert dump_variable_state(restored_conversation) == {"customer_name": "Alice"} + + restored_conversation.append_user_message("Need pricing") + resumed_status = restored_conversation.execute() + + assert isinstance(resumed_status, FinishedStatus) + assert [message.content for message in restored_conversation.get_messages()] == [ + "How can I help Alice?", + "Need pricing", + "Stored Alice. Reply: Need pricing", + ] + + +def test_load_conversation_state_uses_the_given_deserialization_context() -> None: + tool = ServerTool( + name="say_hi", + description="Say hi", + func=lambda: "hi", + input_descriptors=[], + ) + flow = Flow.from_steps( + [ + ToolExecutionStep(tool=tool), + CompleteStep(name="end"), + ], + name="tool_flow", + ) + + deserialization_context = DeserializationContext() + register_server_tool(tool, deserialization_context.registered_tools) + + conversation = load_conversation_state( + deserialize_conversation_state(serialize_conversation_state(flow.start_conversation())), + deserialization_context=deserialization_context, + ) + + assert isinstance(conversation.execute(), FinishedStatus) diff --git a/wayflowcore/tests/testhelpers/state_snapshot_testutils.py b/wayflowcore/tests/testhelpers/state_snapshot_testutils.py new file mode 100644 index 000000000..3b4ff7664 --- /dev/null +++ b/wayflowcore/tests/testhelpers/state_snapshot_testutils.py @@ -0,0 +1,84 @@ +# Copyright © 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from contextlib import AbstractContextManager, nullcontext +from typing import Any, Sequence + +from wayflowcore.conversation import Conversation +from wayflowcore.events.event import Event, StateSnapshotEvent +from wayflowcore.events.eventlistener import EventListener, register_event_listeners +from wayflowcore.executors.interrupts.executioninterrupt import ExecutionInterrupt +from wayflowcore.executors.statesnapshotpolicy import StateSnapshotPolicy +from wayflowcore.serialization import deserialize_conversation, dump_conversation_state + + +class SnapshotCollector(EventListener): + def __init__(self) -> None: + self.state_snapshot_events: list[StateSnapshotEvent] = [] + + def __call__(self, event: Event) -> None: + if isinstance(event, StateSnapshotEvent): + self.state_snapshot_events.append(event) + + +def snapshot_status_types(snapshot_events: Sequence[StateSnapshotEvent]) -> list[str | None]: + return [ + status["type"] if (status := snapshot_event.state_snapshot["execution"]["status"]) else None + for snapshot_event in snapshot_events + ] + + +def execute_with_state_snapshots( + conversation: Conversation, + *, + state_snapshot_policy: StateSnapshotPolicy, + execution_interrupts: Sequence[ExecutionInterrupt] | None = None, + execution_context: AbstractContextManager[Any] | None = None, +) -> tuple[object, list[StateSnapshotEvent]]: + collector = SnapshotCollector() + + with execution_context or nullcontext(): + with register_event_listeners([collector]): + status = conversation.execute( + execution_interrupts=execution_interrupts, + state_snapshot_policy=state_snapshot_policy, + ) + + return status, collector.state_snapshot_events + + +async def execute_with_state_snapshots_async( + conversation: Conversation, + *, + state_snapshot_policy: StateSnapshotPolicy, + execution_interrupts: Sequence[ExecutionInterrupt] | None = None, + execution_context: AbstractContextManager[Any] | None = None, +) -> tuple[object, list[StateSnapshotEvent]]: + collector = SnapshotCollector() + + with execution_context or nullcontext(): + with register_event_listeners([collector]): + status = await conversation.execute_async( + execution_interrupts=execution_interrupts, + state_snapshot_policy=state_snapshot_policy, + ) + + return status, collector.state_snapshot_events + + +def restore_conversation_from_snapshot_payload( + snapshot_payload: dict[str, Any], +) -> Conversation: + assert snapshot_payload["runtime"] == "wayflow" + assert snapshot_payload["schema_version"] == 1 + assert isinstance(snapshot_payload["conversation_state"], str) + + restored_conversation = deserialize_conversation(snapshot_payload["conversation_state"]) + assert dump_conversation_state(restored_conversation) == { + "conversation": snapshot_payload["conversation"], + "execution": snapshot_payload["execution"], + } + return restored_conversation