From 202598f954653c8cbdfe7ab086bcf09ea37ca4af Mon Sep 17 00:00:00 2001 From: Conrad Date: Sat, 27 Jun 2026 17:50:41 -0400 Subject: [PATCH 1/7] build: Add uvloop, the stdlib_parity marker, and worker-process coverage Add uvloop as a dev dependency so the parity suite runs under both the default asyncio loop and uvloop, register the stdlib_parity pytest marker, and measure coverage from spawned worker processes so worker- side dispatch code counts toward the gate. --- wool/.coveragerc | 8 ++++++++ wool/pyproject.toml | 2 ++ 2 files changed, 10 insertions(+) diff --git a/wool/.coveragerc b/wool/.coveragerc index 1ccdb537..474e88ea 100644 --- a/wool/.coveragerc +++ b/wool/.coveragerc @@ -1,5 +1,13 @@ [run] source = src/wool +# Measure code that runs in spawned worker processes (WorkerProcess is a +# ``multiprocessing.get_context("spawn").Process``). Without this the +# integration suite's worker-side coverage (service.py/session.py) is +# invisible, understating what those tests actually exercise. ``sigterm`` +# flushes coverage when a worker is stopped via SIGTERM at teardown. 
+concurrency = multiprocessing,thread +parallel = true +sigterm = true omit = src/wool/__init__.py src/wool/cli.py diff --git a/wool/pyproject.toml b/wool/pyproject.toml index 05158357..21fdac5c 100644 --- a/wool/pyproject.toml +++ b/wool/pyproject.toml @@ -50,6 +50,7 @@ dev = [ "pytest-grpc-aio~=0.3.0", "pytest-mock", "ruff", + "uvloop", ] [project.scripts] @@ -94,6 +95,7 @@ addopts = "--cov --cov-config=.coveragerc" pythonpath = ["."] markers = [ "integration: end-to-end integration tests against real WorkerPool", + "stdlib_parity: stdlib contextvars propagation parity pins", ] [tool.ruff] From fa8f6d161bd87569e017e9f8fc5abee982dc9362 Mon Sep 17 00:00:00 2001 From: Conrad Date: Sat, 27 Jun 2026 17:50:41 -0400 Subject: [PATCH 2/7] refactor!: Rename the wire context message to ChainManifest Reshape the protocol wire contract for the chain model: rename the per-dispatch wire message to ChainManifest carrying value-only ContextVar entries, adjust the package re-exports, and rename the protocol exception module to its plural form. --- wool/proto/wire.proto | 47 +++++++--------- wool/src/wool/protocol/__init__.py | 62 +++++++++-------------- wool/src/wool/protocol/_wire.py | 74 ++++++++++++++++++++++++++++ wool/src/wool/protocol/exception.py | 3 -- wool/src/wool/protocol/exceptions.py | 24 +++++++++ 5 files changed, 143 insertions(+), 67 deletions(-) create mode 100644 wool/src/wool/protocol/_wire.py delete mode 100644 wool/src/wool/protocol/exception.py create mode 100644 wool/src/wool/protocol/exceptions.py diff --git a/wool/proto/wire.proto b/wool/proto/wire.proto index 1b5d7d15..de36f498 100644 --- a/wool/proto/wire.proto +++ b/wool/proto/wire.proto @@ -30,39 +30,34 @@ message TaskEnvelope { string tag = 4; } -// Wire shape for a single wool.ContextVar within a Context +// Wire shape for a single wool.ContextVar within a ChainManifest // snapshot. 
Mirrors the in-memory wool.ContextVar identity -// ((namespace, name) pair) and carries the var's per-snapshot -// state: an optional cloudpickled value and the hex-encoded ids -// of any wool.Tokens this var has minted that were subsequently -// consumed by wool.ContextVar.reset in the enclosing Context's -// logical chain. +// and carries the var's per-snapshot state: an optional +// cloudpickled value. An absent value signals a reset-to- +// no-prior-value on the sender's chain. message ContextVar { // Namespace component of the ContextVar identity. string namespace = 1; // Name component of the ContextVar identity. string name = 2; - // Cloudpickled value. Unset when the var has no current - // value in this snapshot (e.g. it was reset to no prior - // value but a consumed token still needs to propagate). + // Cloudpickled value. Unset when the var has no + // current value in this snapshot. optional bytes value = 3; - // Hex-encoded ids of wool.Tokens minted by this var that - // have been consumed by wool.ContextVar.reset in this - // logical chain. Rides forward- and back-propagation so - // receivers (a) see consumed tokens as used and cannot - // double-reset and (b) pop the corresponding var from - // their local Context, completing the reset signal. - repeated string consumed_tokens = 4; + // Field 4 was consumed_tokens (hex-encoded wool.Token ids). + // Reserved so its number and name are never reused. + reserved 4; + reserved "consumed_tokens"; } -// wool.Context wire shape. Rides on every Request and Response -// frame to carry the caller's wool.ContextVar snapshot and the -// wool.Context id that scopes the logical execution chain. -message Context { - // wool.Context id as hex string. Stable across sequential +// wool.Chain wire shape — the manifest of a chain crossing the +// wire. Rides on every Request and Response frame to carry the +// sender's wool.ContextVar snapshot and the wool.Chain id that +// scopes the logical execution chain. +message ChainManifest { + // wool.Chain id as hex string. Stable across sequential // awaits inside the same asyncio task; fresh on // asyncio.create_task boundaries. 
Empty for root dispatches - // that will start a new context on the receiver. + // that will start a new chain on the receiver. string id = 1; // ContextVar entries — each one carries identity, optional // value, and any consumed-token ids for that var. @@ -104,10 +99,10 @@ message Request { Message send = 3; Message throw = 4; } - // Caller's wool.Context snapshot. Propagates the caller's + // Caller's wool.Chain snapshot. Propagates the caller's // state to the worker on task frames, and forward-propagates // caller mutations to the worker on streaming frames. - Context context = 5; + ChainManifest context = 5; } message Response { @@ -117,9 +112,9 @@ message Response { Message result = 3; Message exception = 4; } - // Worker's post-yield/post-return wool.Context snapshot, + // Worker's post-yield/post-return wool.Chain snapshot, // back-propagated to the caller. - Context context = 5; + ChainManifest context = 5; } message StopRequest { diff --git a/wool/src/wool/protocol/__init__.py b/wool/src/wool/protocol/__init__.py index d5e799fb..4029d5a9 100644 --- a/wool/src/wool/protocol/__init__.py +++ b/wool/src/wool/protocol/__init__.py @@ -1,53 +1,39 @@ -import os -import sys from importlib.metadata import PackageNotFoundError from importlib.metadata import version -from typing import Protocol try: __version__ = version("wool") except PackageNotFoundError: __version__ = "unknown" -sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) - -try: - from wool.protocol.wire_pb2 import Ack - from wool.protocol.wire_pb2 import ChannelOptions - from wool.protocol.wire_pb2 import Context - from wool.protocol.wire_pb2 import ContextVar - from wool.protocol.wire_pb2 import Message - from wool.protocol.wire_pb2 import Nack - from wool.protocol.wire_pb2 import Request - from wool.protocol.wire_pb2 import Response - from wool.protocol.wire_pb2 import RuntimeContext - from wool.protocol.wire_pb2 import StopRequest - from wool.protocol.wire_pb2 import Task - from 
wool.protocol.wire_pb2 import TaskEnvelope - from wool.protocol.wire_pb2 import Void - from wool.protocol.wire_pb2 import WorkerMetadata - from wool.protocol.wire_pb2_grpc import WorkerServicer - from wool.protocol.wire_pb2_grpc import WorkerStub - from wool.protocol.wire_pb2_grpc import add_WorkerServicer_to_server -except ImportError as e: - from wool.protocol.exception import ProtobufImportError - - raise ProtobufImportError(e) from e - - -class AddServicerToServerProtocol(Protocol): - @staticmethod - def __call__(servicer, server) -> None: ... - - -add_to_server: dict[type[WorkerServicer], AddServicerToServerProtocol] = { - WorkerServicer: add_WorkerServicer_to_server, -} +from wool.protocol._wire import Ack as Ack +from wool.protocol._wire import ( + AddServicerToServerProtocol as AddServicerToServerProtocol, +) +from wool.protocol._wire import ChainManifest as ChainManifest +from wool.protocol._wire import ChannelOptions as ChannelOptions +from wool.protocol._wire import ContextVar as ContextVar +from wool.protocol._wire import Message as Message +from wool.protocol._wire import Nack as Nack +from wool.protocol._wire import Request as Request +from wool.protocol._wire import Response as Response +from wool.protocol._wire import RuntimeContext as RuntimeContext +from wool.protocol._wire import StopRequest as StopRequest +from wool.protocol._wire import Task as Task +from wool.protocol._wire import TaskEnvelope as TaskEnvelope +from wool.protocol._wire import Void as Void +from wool.protocol._wire import WorkerMetadata as WorkerMetadata +from wool.protocol._wire import WorkerServicer as WorkerServicer +from wool.protocol._wire import WorkerStub as WorkerStub +from wool.protocol._wire import add_to_server as add_to_server +from wool.protocol._wire import ( + add_WorkerServicer_to_server as add_WorkerServicer_to_server, +) __all__ = [ "Ack", + "ChainManifest", "ChannelOptions", - "Context", "ContextVar", "Message", "Nack", diff --git 
a/wool/src/wool/protocol/_wire.py b/wool/src/wool/protocol/_wire.py new file mode 100644 index 00000000..1c551e32 --- /dev/null +++ b/wool/src/wool/protocol/_wire.py @@ -0,0 +1,74 @@ +import os +import sys +from typing import Protocol + +# Path hack — grpc-tools emits a sibling-flat ``import wire_pb2`` +# inside the generated ``wire_pb2_grpc.py``, which only resolves with +# the package directory on ``sys.path``. Upstream wontfix: see +# protocolbuffers/protobuf#1491 and grpc/grpc#29459 (which endorses +# this exact workaround). Removal tracked in wool-labs/wool#236. +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) + +try: + from wool.protocol.wire_pb2 import Ack + from wool.protocol.wire_pb2 import ChainManifest + from wool.protocol.wire_pb2 import ChannelOptions + from wool.protocol.wire_pb2 import ContextVar + from wool.protocol.wire_pb2 import Message + from wool.protocol.wire_pb2 import Nack + from wool.protocol.wire_pb2 import Request + from wool.protocol.wire_pb2 import Response + from wool.protocol.wire_pb2 import RuntimeContext + from wool.protocol.wire_pb2 import StopRequest + from wool.protocol.wire_pb2 import Task + from wool.protocol.wire_pb2 import TaskEnvelope + from wool.protocol.wire_pb2 import Void + from wool.protocol.wire_pb2 import WorkerMetadata + from wool.protocol.wire_pb2_grpc import WorkerServicer + from wool.protocol.wire_pb2_grpc import WorkerStub + from wool.protocol.wire_pb2_grpc import add_WorkerServicer_to_server +except ImportError as e: + from wool.protocol.exceptions import ProtobufImportError + + raise ProtobufImportError(e) from e + + +class AddServicerToServerProtocol(Protocol): + """Callable signature shape for ``add_*Servicer_to_server`` helpers. + + grpc-tools generates one such helper per service. The + :data:`add_to_server` table maps the service class to its + matching helper so the worker process can look up the correct + add-to-server function at startup. 
+ """ + + @staticmethod + def __call__(servicer, server) -> None: ... + + +add_to_server: dict[type[WorkerServicer], AddServicerToServerProtocol] = { + WorkerServicer: add_WorkerServicer_to_server, +} + + +__all__ = [ + "Ack", + "AddServicerToServerProtocol", + "ChainManifest", + "ChannelOptions", + "ContextVar", + "Message", + "Nack", + "Request", + "Response", + "RuntimeContext", + "StopRequest", + "Task", + "TaskEnvelope", + "Void", + "WorkerMetadata", + "WorkerServicer", + "WorkerStub", + "add_WorkerServicer_to_server", + "add_to_server", +] diff --git a/wool/src/wool/protocol/exception.py b/wool/src/wool/protocol/exception.py deleted file mode 100644 index 9345999e..00000000 --- a/wool/src/wool/protocol/exception.py +++ /dev/null @@ -1,3 +0,0 @@ -class ProtobufImportError(ImportError): - def __init__(self, exception: ImportError): - super().__init__(f"{str(exception)} - ensure protocol buffers are compiled.") diff --git a/wool/src/wool/protocol/exceptions.py b/wool/src/wool/protocol/exceptions.py new file mode 100644 index 00000000..a6ad4cd1 --- /dev/null +++ b/wool/src/wool/protocol/exceptions.py @@ -0,0 +1,24 @@ +"""The protocol package's import-failure signal. + +Provides `ProtobufImportError`: the exception that surfaces a failed +import of the generated protobuf wire modules. +""" + + +class ProtobufImportError(ImportError): + """Raised when the generated protobuf wire modules fail to import. + + The generated modules are build artifacts — absent from a source + checkout until the protobuf definitions are compiled — so a failed + import usually means a missing build step, not a packaging defect. + The wrapper preserves the underlying error's message and appends + that remedy. Subclassing `ImportError` keeps the failure catchable + as the import error it is. + + :param exception: + The underlying `ImportError` from the failed import, whose + message this exception's own message extends. 
+ """ + + def __init__(self, exception: ImportError): + super().__init__(f"{str(exception)} - ensure protocol buffers are compiled.") From 3446b5ce088d475bdcd89f456157bef49baaa429 Mon Sep 17 00:00:00 2001 From: Conrad Date: Sat, 27 Jun 2026 17:50:41 -0400 Subject: [PATCH 3/7] refactor!: Re-found the context model on stdlib contextvars Replace the two-system context model with a single Wool-owned stdlib contextvars.ContextVar holding an immutable Chain. Chain state is an index over per-variable backing vars; ChainManifest carries decoded wire state and mounts it onto a live or fresh chain. Add the task factory (copy-on-fork, displacement detection), the chain-ownership guard, the stdlib-aligned to_thread, the RuntimeContext overlay, and a shared variable registry. Root the exception hierarchy at WoolError and WoolWarning, with ChainContention, ContextVarCollision, TaskFactoryDisplaced, SerializationError, ChainSerializationError, and SerializationWarning underneath. Remove the home-grown Token and the base, stub, and token modules. 
--- wool/src/wool/{exception.py => exceptions.py} | 0 wool/src/wool/runtime/context/__init__.py | 30 - wool/src/wool/runtime/context/base.py | 953 ------------------ wool/src/wool/runtime/context/chain.py | 357 +++++++ wool/src/wool/runtime/context/exceptions.py | 337 +++++++ wool/src/wool/runtime/context/factory.py | 632 ++++++++++++ wool/src/wool/runtime/context/guard.py | 109 ++ wool/src/wool/runtime/context/manifest.py | 479 +++++++++ wool/src/wool/runtime/context/registry.py | 203 +--- wool/src/wool/runtime/context/runtime.py | 144 +++ wool/src/wool/runtime/context/stub.py | 117 --- wool/src/wool/runtime/context/threading.py | 116 +++ wool/src/wool/runtime/context/token.py | 297 ------ wool/src/wool/runtime/context/var.py | 518 ++++++---- 14 files changed, 2512 insertions(+), 1780 deletions(-) rename wool/src/wool/{exception.py => exceptions.py} (100%) delete mode 100644 wool/src/wool/runtime/context/base.py create mode 100644 wool/src/wool/runtime/context/chain.py create mode 100644 wool/src/wool/runtime/context/exceptions.py create mode 100644 wool/src/wool/runtime/context/factory.py create mode 100644 wool/src/wool/runtime/context/guard.py create mode 100644 wool/src/wool/runtime/context/manifest.py create mode 100644 wool/src/wool/runtime/context/runtime.py delete mode 100644 wool/src/wool/runtime/context/stub.py create mode 100644 wool/src/wool/runtime/context/threading.py delete mode 100644 wool/src/wool/runtime/context/token.py diff --git a/wool/src/wool/exception.py b/wool/src/wool/exceptions.py similarity index 100% rename from wool/src/wool/exception.py rename to wool/src/wool/exceptions.py diff --git a/wool/src/wool/runtime/context/__init__.py b/wool/src/wool/runtime/context/__init__.py index 90caca8d..e69de29b 100644 --- a/wool/src/wool/runtime/context/__init__.py +++ b/wool/src/wool/runtime/context/__init__.py @@ -1,30 +0,0 @@ -from wool.runtime.context.base import Context -from wool.runtime.context.base import ContextAlreadyBound -from 
wool.runtime.context.base import ContextDecodeWarning -from wool.runtime.context.base import RuntimeContext -from wool.runtime.context.base import attached as attached -from wool.runtime.context.base import copy_context -from wool.runtime.context.base import create_task -from wool.runtime.context.base import current_context -from wool.runtime.context.base import dispatch_timeout as dispatch_timeout -from wool.runtime.context.base import install_task_factory as install_task_factory -from wool.runtime.context.registry import context_registry as context_registry -from wool.runtime.context.registry import lock as lock -from wool.runtime.context.registry import scope_key as scope_key -from wool.runtime.context.registry import var_registry as var_registry -from wool.runtime.context.token import Token -from wool.runtime.context.var import ContextVar -from wool.runtime.context.var import ContextVarCollision - -__all__ = [ - "Context", - "ContextAlreadyBound", - "ContextDecodeWarning", - "ContextVar", - "ContextVarCollision", - "RuntimeContext", - "Token", - "copy_context", - "create_task", - "current_context", -] diff --git a/wool/src/wool/runtime/context/base.py b/wool/src/wool/runtime/context/base.py deleted file mode 100644 index db28f498..00000000 --- a/wool/src/wool/runtime/context/base.py +++ /dev/null @@ -1,953 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextvars -import logging -import threading -import warnings -import weakref -from contextlib import contextmanager -from contextlib import nullcontext -from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import Coroutine -from typing import Final -from typing import Generator -from typing import ItemsView -from typing import Iterator -from typing import KeysView -from typing import NoReturn -from typing import SupportsIndex -from typing import TypeVar -from typing import ValuesView -from typing import cast -from uuid import UUID -from uuid 
import uuid4 - -import wool -from wool import protocol -from wool.exception import WoolWarning -from wool.runtime.context.registry import context_registry -from wool.runtime.context.registry import lock -from wool.runtime.context.registry import scope_key -from wool.runtime.context.registry import token_registry -from wool.runtime.context.registry import var_registry -from wool.runtime.context.stub import resolve_stub -from wool.runtime.serializer import Serializer -from wool.runtime.typing import Undefined -from wool.runtime.typing import UndefinedType - -if TYPE_CHECKING: - from wool.runtime.context.stub import StubPin - from wool.runtime.context.token import Token - from wool.runtime.context.var import ContextVar - - -# Ambient per-Context dispatch timeout in seconds. ``None`` means no -# timeout. The value scopes to whichever execution chain is currently -# active and rides through nested dispatches until reset or overridden. -dispatch_timeout: Final[contextvars.ContextVar[float | None]] = contextvars.ContextVar( - "dispatch_timeout", default=None -) - -_log = logging.getLogger(__name__) - -_loops_with_factory: weakref.WeakSet[asyncio.AbstractEventLoop] = weakref.WeakSet() - -T = TypeVar("T") - - -# public -class ContextDecodeWarning(WoolWarning): - """Emitted when a wire :class:`protocol.Context` fails to decode. - - Wool's wire protocol treats context propagation as ancillary - state — failures to decode incoming context never preempt the - primary signal (the routine's return value or raised exception). - Instead a :class:`ContextDecodeWarning` is emitted so callers - that depend on context state can detect the inconsistency. 
- - Callers that prefer strict semantics — treat any decode failure - as fatal — can opt in by promoting the warning to an exception:: - - import warnings - import wool - - warnings.filterwarnings("error", category=wool.ContextDecodeWarning) - - Under strict mode, per-var failures inside :meth:`Context.to_protobuf` - and :meth:`Context.from_protobuf` aggregate into a single - :class:`BaseExceptionGroup` raised after the loop completes — every - bad var surfaces, not just the first. - """ - - -# public -class ContextAlreadyBound(RuntimeError): - """Raised when a task is bound to a :class:`Context` more than once. - - Enforces the one-shot contract: a task is bound exactly once, - at creation time. A second binding attempt would silently stomp - the prior binding, masking bugs where the wrong :class:`Context` - (and thus the wrong chain ID) rides through nested dispatches. - """ - - -# public -class RuntimeContext: - """Block-scoped runtime option overrides for wool routines. - - Used as a context manager to override runtime options (currently - only :data:`dispatch_timeout`) for the duration of a block. Auto- - captured on every :class:`Task` at construction time, which ships - the caller's snapshot across the wire so the worker restores it - before running the routine. - - :param dispatch_timeout: - Default timeout for task dispatch operations. ``None`` means - no timeout. Leaving this argument out (the default sentinel) - has two effects: ``__enter__`` skips setting the stdlib var - — useful for "no-override" usage as a context manager — and - :meth:`to_protobuf` substitutes the current scope's live - :data:`dispatch_timeout` at encode time, so a bare - ``RuntimeContext()`` constructed for wire transport still - propagates the encoder's effective timeout to the receiver. 
- """ - - _dispatch_timeout: float | None | UndefinedType - _dispatch_timeout_token: contextvars.Token[float | None] | None - - def __init__( - self, - *, - dispatch_timeout: float | None | UndefinedType = Undefined, - ) -> None: - self._dispatch_timeout = dispatch_timeout - self._dispatch_timeout_token = None - - def __enter__(self) -> RuntimeContext: - if self._dispatch_timeout is not Undefined: - self._dispatch_timeout_token = dispatch_timeout.set(self._dispatch_timeout) - return self - - def __exit__(self, *_): - if self._dispatch_timeout_token is not None: - dispatch_timeout.reset(self._dispatch_timeout_token) - self._dispatch_timeout_token = None - - @classmethod - def get_current(cls) -> RuntimeContext: - """Capture the current stdlib :data:`dispatch_timeout` value.""" - return cls(dispatch_timeout=dispatch_timeout.get()) - - @classmethod - def from_protobuf(cls, context: protocol.RuntimeContext) -> RuntimeContext: - """Reconstruct from a :class:`protocol.RuntimeContext` message.""" - return cls( - dispatch_timeout=( - context.dispatch_timeout - if context.HasField("dispatch_timeout") - else None - ) - ) - - def to_protobuf(self) -> protocol.RuntimeContext: - """Serialize to a :class:`protocol.RuntimeContext` message. - - When the instance was constructed without an explicit - ``dispatch_timeout`` (i.e., the default sentinel), the live - :data:`dispatch_timeout` value from the current scope is - captured at encode time and rides the wire. An explicit - :data:`None` skips emission, so the receiver inherits its - own scope's default. - """ - message = protocol.RuntimeContext() - timeout = self._dispatch_timeout - if timeout is Undefined: - timeout = dispatch_timeout.get() - if timeout is not None: - message.dispatch_timeout = timeout - return message - - -# public -class Context: - """Snapshot of :class:`wool.ContextVar` state and context ID, - scoped to a single task at a time. 
- - Mirrors :class:`contextvars.Context` at the surface — supports - the mapping and container protocols and scopes mutations via - :meth:`Context.run` — but is a parallel mechanism with no shared - state. A :class:`wool.Context` and a :class:`contextvars.Context` - never interact: stdlib :meth:`contextvars.Context.run` does not - fork or clear the :class:`wool.Context`, and vice versa. The - Wool task factory is the boundary where Wool's fork-on-task - semantics engage. - - Beyond the snapshot of :class:`ContextVar` values, a - :class:`Context` carries a ``UUID`` that identifies the - logical chain it belongs to. - - .. caution:: - At most one task may run inside a given :class:`Context` at - a time. :meth:`run` raises :class:`RuntimeError` on - re-entry. - """ - - __slots__ = ( - "_id", - "_data", - "_lock", - "_running", - "_stub_pins", - "_used_tokens", - "_external_used_tokens", - "_bound_task", - "__weakref__", - ) - - _id: UUID - _data: dict[ContextVar[Any], Any] - _lock: threading.Lock - _running: bool - _stub_pins: set[StubPin] - # Tokens consumed locally that still have a live in-process - # instance. Auto-prunes via :class:`weakref.WeakSet` when the - # last reference to a Token is dropped: a consumed Token whose - # only role was double-reset detection has nothing left to - # block once it is unreachable, so its ID can be reclaimed - # along with the instance. - _used_tokens: weakref.WeakSet[Token] - # Tokens known to be consumed but lacking a live in-process - # :class:`Token` instance — typically wire-supplied entries that - # arrived in :meth:`from_protobuf` without a same-process pickle - # round-trip having registered the Token under - # :data:`token_registry`. The map carries each consumed-token id - # alongside the :class:`wool.ContextVar` key the token reset, so - # the receiving :meth:`Context.update` can pop the var from - # :attr:`_data` and propagate the reset signal even when the - # token instance never materialized locally. 
When a matching - # :class:`Token` later materializes (via - # :meth:`Token._reconstitute`'s promotion hook) the entry - # migrates from this map into :attr:`_used_tokens`. - _external_used_tokens: dict[UUID, tuple[str, str]] - # Weakref to the asyncio task currently scoped via - # :func:`_context_scope`. Set on factory-routed task entry and - # cleared on exit, giving "first task wins for the routine's - # lifetime" semantics — a second factory-routed task targeting - # the same Context while the first is still active raises before - # acquiring :meth:`_guard`. ``None`` outside of a - # :func:`_context_scope` block, or for sync-only callers. - _bound_task: weakref.ref[asyncio.Task[Any]] | None - - def __init__(self) -> None: - self._init_state(context_id=uuid4(), data={}) - - def __iter__(self) -> Iterator[ContextVar[Any]]: - return iter(self._data) - - def __getitem__(self, var: ContextVar[T]) -> T: - return self._data[var] - - def __contains__(self, var: Any) -> bool: - return var in self._data - - def __len__(self) -> int: - return len(self._data) - - def __repr__(self) -> str: - return f"" - - def __reduce_ex__(self, _protocol: SupportsIndex) -> NoReturn: - """Refuse pickle, ``copy.copy``, and ``copy.deepcopy``. - - Mirrors :class:`contextvars.Context`. A snapshot disconnected - from live state is uniformly a footgun, so the rejection - applies under both vanilla pickling and Wool's own pickler — - intentionally no :meth:`__wool_reduce__` is defined. Callers - wanting in-process duplication must use :meth:`Context.copy` - explicitly; cross-process propagation rides - :meth:`to_protobuf` and :meth:`from_protobuf` instead. 
- """ - raise TypeError( - "cannot pickle 'wool.Context' object — use Context.copy() " - "for in-process duplication" - ) - - @property - def id(self) -> UUID: - """The UUID that identifies this :class:`Context`'s logical chain.""" - return self._id - - @classmethod - def from_protobuf( - cls, - wire_context: protocol.Context, - *, - serializer: Serializer | None = None, - ) -> Context: - """Reconstruct a :class:`Context` from a wire :class:`protocol.Context`. - - Walks ``wire_context.vars`` once, materializing each entry into - the receiver's :class:`Context`: the var identity resolves - through the process-wide :class:`wool.ContextVar` registry - (or pins a stub if undeclared), the optional serialized value - is deserialized into the receiver's data dict, and any - consumed-token IDs the entry carries promote to the live-Token - slot or stash into :attr:`_external_used_tokens` for later - reset propagation. Stub creation pins the stub onto the - current :class:`Context`, so a lazy-import receiver sees the - propagated value as soon as it later declares the var — the - same stub-promotion path a pickled :class:`ContextVar` - instance uses when it rides through ``__reduce__`` embedded - in a routine argument. - - :param wire_context: - The wire :class:`protocol.Context` to decode. - :param serializer: - Deserializer for values. ``None`` (default) selects - :data:`wool.__serializer__`. - - Decode failures emit :class:`ContextDecodeWarning` and the - offending entry is skipped — surviving entries decode - normally. A malformed wire context id falls back to a fresh - UUID with the failure recorded as the same warning class. - Operators who want strict behavior promote the warning via - ``PYTHONWARNINGS=error::wool.ContextDecodeWarning``; under - strict mode the failures aggregate into a single - :class:`BaseExceptionGroup` raised after the decode loop so - callers learn about every failure, not just the first. 
- - :raises BaseExceptionGroup: - Under strict mode, when one or more entries fail to - decode; the group's peers are the per-failure - :class:`ContextDecodeWarning` instances. - """ - if serializer is None: - serializer = wool.__serializer__ - failures: list[ContextDecodeWarning] = [] - try: - ctx_id = UUID(hex=wire_context.id) if wire_context.id else uuid4() - except ValueError as e: - try: - warnings.warn( - f"Failed to decode wire context id {wire_context.id!r}: {e}", - ContextDecodeWarning, - stacklevel=2, - ) - except ContextDecodeWarning as raised: - failures.append(raised) - ctx_id = uuid4() - data: dict[ContextVar[Any], Any] = {} - ctx = cls._reconstitute(ctx_id, data) - with attached(ctx, guarded=False): - for entry in wire_context.vars: - var_key = (entry.namespace, entry.name) - var = resolve_stub(var_key, ctx) - if entry.HasField("value"): - try: - data[var] = serializer.loads(entry.value) - except Exception as e: - try: - warnings.warn( - f"Failed to deserialize wool.ContextVar " - f"{var_key!r}: {e}", - ContextDecodeWarning, - stacklevel=2, - ) - except ContextDecodeWarning as raised: - failures.append(raised) - # Each consumed-token ID is either promoted to the - # live-Token slot when the same-process registry has - # an instance for it, or stashed in - # :attr:`_external_used_tokens` keyed by the var so - # a subsequent :meth:`Context.update` can pop the - # corresponding var from the receiver's data — the - # reset signal propagates even when the token - # instance never materializes locally. 
- for token_id_hex in entry.consumed_tokens: - try: - token_id = UUID(hex=token_id_hex) - except ValueError as e: - try: - warnings.warn( - f"Failed to decode consumed-token ID " - f"{token_id_hex!r} for var {var_key!r}: {e}", - ContextDecodeWarning, - stacklevel=2, - ) - except ContextDecodeWarning as raised: - failures.append(raised) - continue - live = token_registry.get(token_id) - if live is not None: - if not live._used: - live._used = True - ctx._used_tokens.add(live) - else: - ctx._external_used_tokens[token_id] = var_key - if failures: - raise BaseExceptionGroup( - "wool context decode failed", - failures, - ) - return ctx - - def to_protobuf( - self, - *, - serializer: Serializer | None = None, - ) -> protocol.Context: - """Snapshot this :class:`Context` to a wire :class:`protocol.Context`. - - Each var observable in the snapshot — by carrying a value or - by being the source of a consumed token — emits one - :class:`protocol.ContextVar` entry. The entry's ``value`` is - set when the var has a current binding in this :class:`Context` - and unset otherwise (a var that was set and then reset to no - prior value still rides the wire so its consumed-token IDs - propagate). Default-only values — vars that have never been - explicitly set in this :class:`Context` and have no consumed - tokens — are absent from the snapshot. - - Per-var encode failures emit :class:`ContextDecodeWarning` - and the offending key is skipped — surviving vars encode - normally, mirroring the per-entry resilience of - :meth:`from_protobuf`. Operators who want strict behavior - promote the warning via - ``PYTHONWARNINGS=error::wool.ContextDecodeWarning``; under - strict mode the per-var failures aggregate into a single - :class:`BaseExceptionGroup` raised after the loop so callers - learn about every bad var, not just the first. - - :param serializer: - Serializer for values. ``None`` (default) selects - :data:`wool.__serializer__`. 
- :raises BaseExceptionGroup: - Under strict mode, when one or more vars fail to encode; - the group's peers are the per-var - :class:`ContextDecodeWarning` instances. - """ - if serializer is None: - serializer = wool.__serializer__ - wire_context = protocol.Context(id=self._id.hex) - failures: list[ContextDecodeWarning] = [] - encoded_values: dict[tuple[str, str], bytes] = {} - failed_keys: set[tuple[str, str]] = set() - for var, value in self.items(): - try: - encoded_values[var._key] = serializer.dumps(value) - except Exception as e: - failed_keys.add(var._key) - try: - warnings.warn( - f"Failed to serialize wool.ContextVar {var._key!r}: {e}", - ContextDecodeWarning, - stacklevel=2, - ) - except ContextDecodeWarning as raised: - failures.append(raised) - token_ids_by_key: dict[tuple[str, str], list[str]] = {} - for token_id, var_key in self._consumed_entries(): - # A var whose value failed to serialize is suppressed - # entirely — emitting consumed tokens without a value would - # propagate a phantom reset on the receiver, which would - # interpret the half-encoded entry as "reset and not - # re-set" via consumed_keys minus sender_data_keys in - # :meth:`update`. - if var_key in failed_keys: - continue - token_ids_by_key.setdefault(var_key, []).append(token_id.hex) - for var_key in set(encoded_values).union(token_ids_by_key): - namespace, name = var_key - entry = wire_context.vars.add(namespace=namespace, name=name) - if var_key in encoded_values: - entry.value = encoded_values[var_key] - if var_key in token_ids_by_key: - entry.consumed_tokens.extend(token_ids_by_key[var_key]) - if failures: - raise BaseExceptionGroup( - "wool context encode failed for one or more vars", - failures, - ) - return wire_context - - def has_state(self) -> bool: - """Return True if this :class:`Context` carries any observable - state — var bindings, locally-consumed tokens, or externally- - supplied consumed-token entries awaiting promotion. 
- - Distinct from :meth:`__bool__` / :meth:`__len__`, which follow - the mapping-container contract and report only on var - bindings. Wire-side callers use :meth:`has_state` to skip - no-op merges from empty wire frames. - """ - return ( - bool(self._data) - or bool(self._used_tokens) - or bool(self._external_used_tokens) - ) - - def copy(self) -> Context: - """Return a shallow copy of this :class:`Context` with a fresh ID. - - Mirrors :meth:`contextvars.Context.copy` — the copy is a - new logical chain with its own UUID, so mutations to the - copy do not affect this :class:`Context` and nested - dispatches fired under the copy carry its fresh ID, not this - :class:`Context`'s. The consumed-token set is not carried - across: any :class:`Token` minted under this - :class:`Context`'s UUID is already incompatible with the - copy's UUID for :meth:`ContextVar.reset` purposes. - """ - return Context._reconstitute(uuid4(), dict(self)) - - def run(self, fn: Callable[..., T], /, *args: Any, **kwargs: Any) -> T: - """Run the specified callable in this :class:`Context`. - - Installs this :class:`Context` as the current scope's active - :class:`Context`, runs the callable, then restores the - previous :class:`Context`. Mutations made by the callable go - directly into this :class:`Context`. Affects only Wool's - per-scope :class:`Context` registry; the surrounding - :class:`contextvars.Context` is untouched. Compose with - :meth:`contextvars.Context.run` to scope both at once:: - - wool_ctx.run(stdlib_ctx.run, fn, *args, **kwargs) - - :raises RuntimeError: - If this :class:`Context` is already running a task. 
- """ - with attached(self): - return fn(*args, **kwargs) - - def get(self, var: ContextVar[T], default: Any = None) -> Any: - """Return *var*'s value in this :class:`Context`, or *default* if unset.""" - return self._data.get(var, default) - - def keys(self) -> KeysView[ContextVar[Any]]: - """Return a view of every :class:`ContextVar` bound in this :class:`Context`.""" - return self._data.keys() - - def values(self) -> ValuesView[Any]: - """Return a view of every value bound in this :class:`Context`.""" - return self._data.values() - - def items(self) -> ItemsView[ContextVar[Any], Any]: - """Return a view of every (var, value) pair bound in this :class:`Context`.""" - return self._data.items() - - def update(self, other: Context) -> None: - """Apply *other*'s vars and used-token state to this :class:`Context`. - - One-way: *other* is the source of truth for overlapping - keys. This :class:`Context`'s ID is unchanged. Mirrors - :meth:`dict.update` semantics for the var map but extends it - with reset propagation: a var consumed by *other* and absent - from *other*'s data (i.e. reset and not subsequently re-set) - is popped from this :class:`Context`'s data so the merge - carries the reset signal, not just the post-set state. Live - Tokens from *other*'s :attr:`_used_tokens` are added to this - :class:`Context`'s :attr:`_used_tokens` as-is — their - ``_used`` flag was already flipped by the originating - :meth:`ContextVar.reset` call before they landed in - *other*'s set; :meth:`update` does not re-flip them. External - consumed-token entries from *other* not yet known to this - :class:`Context` are resolved against :data:`token_registry`: - a live match promotes directly into :attr:`_used_tokens` and - ``_used`` is flipped if not already, otherwise the (id, - var_key) entry joins :attr:`_external_used_tokens`. 
- """ - self._data.update(other._data) - sender_data_keys = {var._key for var in other._data} - consumed_keys = {token._key for token in other._used_tokens} - consumed_keys.update(other._external_used_tokens.values()) - for key in consumed_keys - sender_data_keys: - receiver_var = var_registry.get(key) - if receiver_var is not None: - self._data.pop(receiver_var, None) - for token in other._used_tokens: - self._used_tokens.add(token) - self._external_used_tokens.pop(token._id, None) - known = self._consumed_token_ids() - for token_id, var_key in other._external_used_tokens.items(): - if token_id in known: - continue - live = token_registry.get(token_id) - if live is not None: - if not live._used: - live._used = True - self._used_tokens.add(live) - else: - self._external_used_tokens[token_id] = var_key - - @classmethod - def _reconstitute( - cls, - context_id: UUID, - data: dict[ContextVar[Any], Any], - ) -> Context: - """Rebuild a :class:`Context` from externally-supplied parts. - - Bypasses ``__init__`` to adopt an externally-supplied - context ID and data dict, for callers that already hold the - canonical identity and state to rebuild a :class:`Context` - around. Not a copy — the dict reference is taken as-is. The - consumed-token slots start empty; callers populate them via - :meth:`update`, :meth:`from_protobuf`, or direct mutation - as their wire shape dictates. - """ - instance: Context = object.__new__(cls) - instance._init_state(context_id=context_id, data=data) - return instance - - def _consumed_entries(self) -> Iterator[tuple[UUID, tuple[str, str]]]: - """Yield ``(token_id, var_key)`` for every consumed token - tracked by this :class:`Context`, deduped across - :attr:`_used_tokens` and :attr:`_external_used_tokens`. 
- - Iterating :attr:`_used_tokens` materializes only Tokens that - are still alive — :class:`weakref.WeakSet` skips entries - whose referents have been collected — so the resulting - sequence elides IDs whose double-reset detection role is no - longer load-bearing in this process. External entries cover - IDs whose owning Token was never reconstituted locally; if - an ID appears in both stores the live-token entry wins. - """ - seen: set[UUID] = set() - for token in self._used_tokens: - seen.add(token._id) - yield token._id, token._key - for token_id, var_key in self._external_used_tokens.items(): - if token_id in seen: - continue - yield token_id, var_key - - def _consumed_token_ids(self) -> set[UUID]: - """Return every consumed-token ID this :class:`Context` - tracks, across both live and external stores.""" - return {token_id for token_id, _ in self._consumed_entries()} - - def _init_state( - self, - *, - context_id: UUID, - data: dict[ContextVar[Any], Any], - ) -> None: - self._id = context_id - self._data = data - self._lock = threading.Lock() - self._running = False - self._stub_pins = set() - self._used_tokens = weakref.WeakSet() - self._external_used_tokens = {} - self._bound_task = None - - @contextmanager - def _guard(self) -> Iterator[None]: - """Enforce the single-task invariant for the wrapped block. - - Acquires the running flag under :attr:`_lock` on entry - (raising :class:`RuntimeError` if another task is already - running inside this :class:`Context`) and releases it on - exit. Thread-safe. - """ - with self._lock: - if self._running: - raise RuntimeError( - "wool.Context is already running; at most one " - "task may run inside a given Context at a time" - ) - self._running = True - try: - yield - finally: - with self._lock: - self._running = False - - -# public -def current_context() -> Context: - """Return the live :class:`wool.Context` for the current execution scope. 
- - Inside an asyncio task, looks up the task's :class:`Context` in - the process-wide registry. Outside a task (sync code), uses a - per-thread fallback. If no :class:`Context` exists for the - current scope, one is created lazily and registered. - """ - _ensure_task_factory_installed() - with lock: - existing = context_registry.get() - if existing is not None: - return existing - fresh = Context() - context_registry[scope_key()] = fresh - return fresh - - -@contextmanager -def attached(ctx: Context, *, guarded: bool = True) -> Iterator[None]: - """Install *ctx* as the current scope's :class:`Context` for the - duration of the ``with`` block. - - Scoped install/restore. Holds :meth:`Context._guard` for the - block by default — the discipline that enforces the single- - task-per-:class:`Context` invariant for user code running - inside *ctx*. Pass ``guarded=False`` for framework-internal - decode plumbing where the :class:`Context` only needs to be - visible for transitive :class:`wool.ContextVar` / - :class:`wool.Token` reconstitution inside ``serializer.loads`` - calls and the caller is not running user code in *ctx*. - - :raises RuntimeError: - Under ``guarded=True``, if another task is already running - inside *ctx*. - """ - guard = ctx._guard() if guarded else nullcontext() - with guard: - token = context_registry.set(ctx) - try: - yield - finally: - context_registry.reset(token) - - -# public -def copy_context() -> Context: - """Return a shallow copy of the current :class:`wool.Context`. - - Mirrors :func:`contextvars.copy_context` — returns a shallow - copy of the current scope's context as a new :class:`Context` - instance. The copy receives a fresh logical-chain ID, so it is - independent of the source's chain for dispatch, tracing, and - :class:`Token` scoping purposes. 
- """ - return current_context().copy() - - -def install_task_factory( - loop: asyncio.AbstractEventLoop | None = None, -) -> None: - """Install Wool's task factory on the given (or running) loop. - - Composes with an existing factory if one is set, so that - asyncio child tasks created via ``create_task`` inherit a - forked :class:`Context`. Idempotent — a subsequent call on a - loop that already has the Wool-wrapped factory installed is a - no-op. :func:`current_context` self-installs the factory on - first contact, so user code that touches Wool's API without - first calling :func:`install_task_factory` still gets fork-on- - task semantics for tasks created after the first contact. - - **Ordering contract** — If a user installs their own task factory - *after* Wool's, Wool's wrapping of child coroutines is dropped - and copy-on-fork breaks silently for subsequently-created tasks. - Install Wool's factory last (or compose manually) when other - libraries also want a factory on the same loop. - """ - if loop is None: - loop = asyncio.get_running_loop() - - existing = loop.get_task_factory() - if existing is not None and getattr(existing, "__wool_wrapped__", False): - _log.debug(f"wool-composed task factory already installed on {loop}") - return - inner = existing if existing is not None else _default_task_factory - - def wool_factory( - loop: asyncio.AbstractEventLoop, - coro: Coroutine[Any, Any, Any] | Generator[Any, None, Any], - *, - context: Context | contextvars.Context | None = None, - **kwargs: Any, - ) -> asyncio.Task[Any]: - # Widen to ``Coroutine | Generator`` to satisfy typeshed's - # ``_CoroutineLike[_T]`` contravariant parameter — the - # ``Generator`` arm exists for pre-3.8 generator-coroutines - # and is unreachable from asyncio's modern create_task path, - # but the static type must accept it for ``wool_factory`` to - # be a valid ``_TaskFactory``. 
Narrow back to ``Coroutine`` - # for the body, which uses ``Coroutine``-only operations - # (await semantics, ``_context_scope`` wrapping). - coro = cast(Coroutine[Any, Any, Any], coro) - if isinstance(context, Context): - # Explicit wool.Context: hide it from asyncio (which would - # call ctx.run(step_fn) per step and fragment _guard's - # held-across-awaits semantics into per-step entries) and - # instead wrap the coroutine so the guard + attach span - # the whole routine. ``_context_scope`` installs the - # :class:`Context` in the registry under the new task's - # identity for the routine's lifetime. - task = inner(loop, _context_scope(context, coro), **kwargs) - _register(task, context) # pyright: ignore[reportArgumentType] - return task # pyright: ignore[reportReturnType] - # No wool.Context: forward stdlib ``context=`` if supplied - # (so a third-party-supplied :class:`contextvars.Context` is - # honored verbatim) and fork the parent's wool.Context at - # task creation. The new task is registered with a fork (or - # a fresh Context if no parent is bound), so wool inherits - # the parent's bindings under a fresh chain id without - # depending on stdlib Context boundaries. 
- if context is not None: - kwargs["context"] = context - task = inner(loop, coro, **kwargs) - child = ( - parent.copy() - if (parent := context_registry.get()) is not None - else Context() - ) - _register(task, child) # pyright: ignore[reportArgumentType] - return task # pyright: ignore[reportReturnType] - - wool_factory.__wool_wrapped__ = True # pyright: ignore[reportFunctionMemberAccess] - wool_factory.__wool_inner__ = inner # pyright: ignore[reportFunctionMemberAccess] - loop.set_task_factory(wool_factory) - if existing is None: - _log.debug(f"wool task factory installed on {loop}") - else: - _log.debug( - f"wool task factory composed with existing factory {existing} on {loop}", - ) - - -# public -def create_task( - coro: Coroutine[Any, Any, T], - *, - name: str | None = None, - context: Context | None = None, -) -> asyncio.Task[T]: - """Create a task on the running loop, optionally pre-bound to a - :class:`wool.Context`. - - Mirrors :func:`asyncio.create_task` — uses the running loop, no - explicit loop parameter, and accepts the same ``name``/``context`` - keywords in the same order. asyncio's stdlib ``context=`` kwarg - is typed for :class:`contextvars.Context`, and :class:`wool.Context` - cannot subclass it (the stdlib C type disallows subclassing); this - helper hides the cast so callers do not need - ``# pyright: ignore[reportArgumentType]`` at every call site. The wool task - factory's interception of wool-Context-typed ``context`` is what - actually does the binding — it wraps the coroutine so the - single-task guard is held continuously across awaits and pins - the :class:`Context` to the new task in the process-wide - registry. - - Calling :func:`asyncio.create_task` directly with - ``context=wool_ctx`` is functionally identical when wool's task - factory is installed on the running loop; this helper exists - purely as a typing shim. To schedule on a non-running loop, call - :meth:`AbstractEventLoop.create_task` directly. 
- - :param coro: - The coroutine to run. - :param name: - Optional task name (forwarded to the factory). - :param context: - Optional :class:`wool.Context` to bind. When supplied, the - wool task factory wraps *coro* so :meth:`Context._guard` is - held across the task's lifetime; concurrent attempts to enter - the same Context from another task raise :class:`RuntimeError` - immediately. When ``None``, the new task inherits a fork of - the parent's :class:`wool.Context` via the factory's copy-on- - fork path. - :returns: - The freshly created :class:`asyncio.Task`. - """ - # stdlib's create_task is typed for contextvars.Context; the wool - # factory accepts a duck-typed wool.Context. - return asyncio.create_task(coro, name=name, context=context) # pyright: ignore[reportArgumentType] - - -def _register(task: asyncio.Task[Any], ctx: Context) -> None: - """Pin *ctx* to *task* in the process-wide :data:`context_registry`. - - Enforces the one-shot contract — a task is bound exactly once - at creation — so duplicate bindings surface immediately as - :class:`ContextAlreadyBound` rather than silently stomping - prior state. - - :raises ContextAlreadyBound: - If *task* is already bound to a :class:`Context` (one-shot - contract — see :class:`ContextAlreadyBound`). - """ - with lock: - if task in context_registry: - raise ContextAlreadyBound( - f"task {task!r} is already bound to {context_registry[task]!r}" - ) - context_registry[task] = ctx - - -async def _context_scope(ctx: Context, coro: Coroutine[Any, Any, T]) -> T: - """Run *coro* with *ctx* attached, the single-task guard held, - and the Context's ``_bound_task`` slot pinned to the current - task for the coroutine's lifetime. 
- - Implements the "first-task-wins for the routine's lifetime" - binding for ``loop.create_task(coro, context=wool_ctx)``: a - second task targeting the same Context while the first is - still mid-flight raises before acquiring :meth:`_guard`, - catching the cross-task interleaving that :meth:`_guard`'s - ``_running`` flag alone could miss for concurrent attempts - spread across asyncio loop ticks. Holds :meth:`_guard` - continuously across the coroutine's awaits so yield frames - cannot be interleaved by another task. - """ - current = asyncio.current_task() - if current is not None: - with ctx._lock: - if ctx._bound_task is not None: - bound = ctx._bound_task() - if bound is not None and bound is not current and not bound.done(): - # Close the un-awaited coroutine to suppress the - # "coroutine was never awaited" RuntimeWarning that - # would otherwise leak at GC. - coro.close() - raise RuntimeError( - "wool.Context is bound to another live task; " - "first-task-wins for the routine's lifetime" - ) - ctx._bound_task = weakref.ref(current) - try: - with attached(ctx): - return await coro - finally: - if current is not None: - with ctx._lock: - if ctx._bound_task is not None and ctx._bound_task() is current: - ctx._bound_task = None - - -def _default_task_factory( - loop: asyncio.AbstractEventLoop, - coro: Coroutine[Any, Any, Any], - **kwargs: Any, -) -> asyncio.Task[Any]: - """Fall-back task factory matching :meth:`AbstractEventLoop.create_task`. - - Used when no user factory is installed, so wool's factory has - a uniform inner layer to delegate to. - """ - return asyncio.Task(coro, loop=loop, **kwargs) - - -def _ensure_task_factory_installed() -> None: - """Self-install Wool's task factory on the running loop if absent. - - Lets user code that touches Wool without first calling - :func:`install_task_factory` still get fork-on-task semantics - for tasks created after the first Wool API contact. No-ops in - sync contexts (no running loop). 
The :data:`_loops_with_factory` - weak set short-circuits the lookup to a single membership check - after the first install on a given loop. - """ - try: - loop = asyncio.get_running_loop() - except RuntimeError: - return - if loop in _loops_with_factory: - return - install_task_factory(loop) - _loops_with_factory.add(loop) diff --git a/wool/src/wool/runtime/context/chain.py b/wool/src/wool/runtime/context/chain.py new file mode 100644 index 00000000..74545c2e --- /dev/null +++ b/wool/src/wool/runtime/context/chain.py @@ -0,0 +1,357 @@ +"""The chain-state model. + +Provides `Chain`: the immutable record of a logical +execution chain that the rest of the context subsystem mutates, +ships across the wire, forks per task, and polices for contention. +""" + +from __future__ import annotations + +import asyncio +import threading +import weakref +from dataclasses import dataclass +from dataclasses import field +from dataclasses import replace +from typing import TYPE_CHECKING +from typing import Any +from uuid import UUID +from uuid import uuid4 + +import wool +from wool.runtime.context.factory import ensure_task_factory_installed +from wool.runtime.context.manifest import ChainManifest +from wool.runtime.context.registry import var_registry +from wool.runtime.typing import Undefined + +if TYPE_CHECKING: + from wool.runtime.context.manifest import ContextVarManifest + + +@dataclass(frozen=True, eq=False) +class Chain: + """Immutable index of Wool chain state. + + Wool chain state — i.e., the logical-chain UUID, the set of bound + `wool.ContextVar` instances, and their resets — rides in a + single Wool-owned `contextvars.ContextVar` as a + `Chain` instance. The chain is an *index*, not a value + store: each `wool.ContextVar`'s value lives in its own + backing `contextvars.ContextVar`. 
+ + A chain identifies one serial branch of the program's async call + tree: the logical call stack descending from the most recent + `asyncio.create_task` fork, on which every frame executes + strictly in sequence. The branch is not one + `contextvars.Context` — it spans every context copy + descended from its arming, e.g., event-loop callback copies, explicit + `contextvars.Context.run` re-entries, even the worker-side + context a dispatch arms with the caller's chain id — which is what + gives a routine awaited on a worker the same context continuity + as a process-local await. `asyncio.create_task` always + forks onto a fresh chain, the chain-level analogue of stdlib + copy-on-fork, so concurrent branches can never mutate one + another's chain state; an execution that enters a chain it does + not own has circumvented that discipline and fails loudly instead + (see `wool.ChainContention`). + + The Wool-owned `contextvars.ContextVar` becomes a permanent + member of any `contextvars.Context` once that context is + armed with a `Chain` instance. An armed context + additionally carries one backing variable per bound + `wool.ContextVar`, so a `contextvars.copy_context` + of an armed context enumerates ``1 + N`` Wool-owned variables. + An unarmed context holds none of them and is indistinguishable + from a plain `contextvars.Context`. A context is armed when a + chain is mounted into it (see *Lifecycle*). + + Because the chain and every backing context variable live in a + `contextvars.Context`, Wool state is governed by standard + context semantics: it's copied into new tasks at creation (by + default) and into event-loop callbacks at scheduling; code run + explicitly inside a captured context — e.g., via + `contextvars.Context.run` — observes and mutates the Wool + state that context carries; and writes made inside a copy stay in + that copy, never leaking back into the originating scope. 
+ + **Lifecycle** + + A `Chain` instance is strictly *live*: + fresh (constructed directly or via `_fork`) or mounted + (installed in a `contextvars.Context` via `mount`). + Decoded-but-unmounted wire state lives on + `~wool.runtime.context.manifest.ChainManifest` and is + installed by `from_manifest`. `mount` is the single transition + from pure data → installed: it stamps ``thread`` / ``task`` from + the calling scope and applies any pending resets. + + **Indexing asymmetry** + + ``vars`` keys on `wool.ContextVar` + instance identity (safe because the process-wide var registry + enforces a singleton per ``(namespace, name)`` key); ``resets`` + keys on the ``(namespace, name)`` tuple directly so the reset + signal survives a wire round-trip without requiring the receiver + to have declared the variable. The two indices reference the same + logical concept under different key spaces by design. + + :param id: + UUID identifying the logical execution chain. Defaults to a + freshly minted UUID. Re-minted on every task fork by the Wool + task factory; an explicit value lets a worker preserve the + caller's chain id when arming on an inbound wire frame. + :param thread: + `threading.get_ident` of the OS thread that owns this + chain — stamped by `mount`. Defaults to ``0`` (no owner) + so construction is pure data; `mount` is the single + owner-stamping site. The chain-contention guard compares it + against the accessing thread; see `wool.ChainContention`. + :param task: + Weak reference to the `asyncio.Task` that owns this + chain — stamped by `mount`. Defaults to ``None`` (no + owner). ``None`` also marks a chain armed outside any task + (synchronous code, or a `wool.to_thread` worker thread). + The chain-contention guard compares it against the running + task so a second task entering the chain fails loud. This + field is process-local and never crosses the wire. + :param vars: + Set of bound `wool.ContextVar` instances. 
Membership + is the index of "bound in this chain": ``X in vars`` is + equivalent to ``X._backing`` resolving to a + non-`~wool.runtime.typing.Undefined` value in the active + `contextvars.Context`. + :param resets: + ``(namespace, name)`` keys of variables reset to no prior + value and not since re-set. The token-independent "drop this + variable" signal for the wire merge (`from_manifest`); survives + even if the resetting token is collected before the reset + propagates. + :param stubs: + Undeclared-stub `wool.ContextVar` instances observed + while decoding a chain manifest, held strongly so a lazy-import + receiver can still promote them when it declares the variable. + """ + + id: UUID = field(default_factory=uuid4) + # A chain instance's thread and task are runtime bookkeeping, not + # part of the chain's identity, so we exclude them from comparison. + thread: int = field(default=0, repr=False, compare=False) + task: weakref.ref[asyncio.Future[Any]] | None = field( + default=None, repr=False, compare=False + ) + vars: frozenset[ContextVarManifest[Any]] = field(default_factory=frozenset) + resets: frozenset[tuple[str, str]] = field(default_factory=frozenset) + stubs: frozenset[ContextVarManifest[Any]] = field(default_factory=frozenset) + + def __post_init__(self) -> None: + """Coerce the container fields to frozensets. + + Callers — including `dataclasses.replace` — may pass any + iterable; coercion preserves the frozen facade regardless. + """ + if not isinstance(self.vars, frozenset): + object.__setattr__(self, "vars", frozenset(self.vars)) + if not isinstance(self.resets, frozenset): + object.__setattr__(self, "resets", frozenset(self.resets)) + if not isinstance(self.stubs, frozenset): + object.__setattr__(self, "stubs", frozenset(self.stubs)) + + def to_manifest(self) -> ChainManifest: + """Snapshot this live chain into a decoded `ChainManifest`. + + The send-side counterpart to `from_manifest`. 
Reads each bound + variable's value from its backing `contextvars.ContextVar` in the + *calling* `contextvars.Context` and captures it inline on the + returned manifest, ready for + `~wool.runtime.context.manifest.ChainManifest.to_protobuf` to + serialise. A chain spans many contexts, so the requirement is + membership, not identity: this method *must* run inside a context + that carries this chain's bindings — run anywhere else, the reads + observe that context's values (or none), not this chain's. + + A variable whose backing resolves to + `~wool.runtime.typing.Undefined` is absent from the snapshot; a + variable reset to no prior value still rides along via ``resets`` + so the reset propagates regardless of source. + + This is a pure read — it never serialises, so it cannot raise a + serialisation error. The bound-key singleton and the + bound/reset disjointness invariants are asserted here, at the + snapshot boundary, so any manifest reaching ``to_protobuf`` is + already well-formed: ``vars`` holds singletons keyed by + ``(namespace, name)`` (enforced at construction via + ``var_registry`` and `ContextVarCollision`), and no key is both + bound and reset-pending. + """ + assert len({(v.namespace, v.name) for v in self.vars}) == len(self.vars), ( + "singleton invariant violated: duplicate var keys in Chain.vars" + ) + assert not (self.resets & {v._key for v in self.vars}), ( + "disjointness invariant violated: keys both bound and reset-pending in Chain" + ) + vars: dict[ContextVarManifest[Any], Any] = {} + for var in self.vars: + value = var._backing.get(Undefined) + if value is Undefined: + continue + vars[var] = value + return ChainManifest( + id=self.id, + vars=vars, + resets=self.resets, + stubs=self.stubs, + ) + + def mount(self, *, owned: bool = True) -> Chain: + """Arm this chain as the live chain in the current `contextvars.Context`. 
+ + Installs an evolved copy of this chain — re-stamped so the + current thread (and, when *owned*, the owning task) hold it — + and applies each `resets` signal by rewinding the matching + backing, when declared in this process, to + `~wool.runtime.typing.Undefined` (idempotent for resets the + chain already carried). Must run inside the owning task's real + `contextvars.Context` so backing-variable state and any native + tokens minted afterward bind to it. + + ``mount`` is the shared arming step. The local path + (`wool.ContextVar.set` on first-arm, `wool.ContextVar.reset`, + and the per-task fork) arms an already-live chain directly; the + wire-ingress path arrives through `from_manifest`, which drains + and merges the decoded state into a fresh chain and then + delegates here. + + ``mount`` is also the keystone arming boundary: every arming — + caller-side or wire-ingress — transits through here and + unconditionally ensures Wool's task factory is on the running + loop via + `~wool.runtime.context.factory.ensure_task_factory_installed` + (a no-op once installed). That call doubles as the displacement + tripwire: if a third-party factory has displaced Wool's, it + raises `wool.TaskFactoryDisplaced` — so a worker-side mount + surfaces displacement at frame ingress, not only on the next + `wool.ContextVar` set. + + :param owned: + When ``True`` (the default), stamp the current asyncio task + as the chain's owner so the cross-task contention guard + (`assert_chain_owner`) can arbitrate by task identity. The + worker-side per-step driver passes ``False``: its cached + `contextvars.Context` is driven by a succession of distinct + step-tasks, so the chain is owned thread-wise and + task-agnostically — stamping any one step-task would make + the next trip the guard. The OS-thread stamp is applied + either way. + :returns: + The installed Chain — the evolved copy with the owner + re-stamped. 
+ """ + ensure_task_factory_installed() + for key in self.resets: + receiver_var = var_registry.get(key) + if receiver_var is not None: + receiver_var._backing.set(Undefined) + task = None + if owned: + try: + task = asyncio.current_task() + except RuntimeError: + pass + installed = self._evolve( + thread=threading.get_ident(), + task=weakref.ref(task) if task is not None else None, + ) + wool.__chain__.set(installed) + return installed + + @classmethod + def from_manifest( + cls, + manifest: ChainManifest, + *, + owned: bool, + merge_with: Chain | None = None, + ) -> Chain: + """Drain a decoded `ChainManifest` into the backings and arm it. + + The wire-ingress counterpart to `to_manifest`, entered through + `Frame.mount`. Drains the manifest's decoded values into their + backing variables, builds the receiver chain — merging onto a live + receiver when *merge_with* is given (the receiver keeps its chain + id), or seeding fresh state from the manifest alone when it is + ``None`` — and delegates the arming (reset rewind, owner stamp, + task factory, `wool.__chain__` set) to `mount`. + + :param manifest: + The decoded-but-unmounted chain snapshot to install. + :param owned: + Forwarded to `mount`. Caller-side mounts pass ``True``; the + worker-side per-step driver passes ``False`` because the + cached `contextvars.Context` is driven by successive + step-tasks rather than owned by one. + :param merge_with: + The live receiver chain to union the manifest onto, or ``None`` + to seed a fresh chain from the manifest alone. + """ + # Drain values into the backings first so the built Chain + # observes them via ``vars``. Backing writes mint stray native + # tokens — discarded with the local scope. + for var, value in manifest.vars.items(): + var._backing.set(value) + # Build the vars index from the manifest, optionally unioned with + # the live receiver's bindings (the receiver keeps its chain id + # through the merge). 
+ manifest_vars = frozenset(manifest.vars.keys()) + if merge_with is None: + merged_vars = manifest_vars + resets = manifest.resets + stubs = manifest.stubs + chain_id = manifest.id + else: + merged_vars = merge_with.vars | manifest_vars + manifest_touched = {var._key for var in manifest_vars} | manifest.resets + resets = manifest.resets | (merge_with.resets - manifest_touched) + stubs = merge_with.stubs | manifest.stubs + chain_id = merge_with.id + # Reset-only keys drop out of ``vars`` so the reset propagates; the + # rewind itself happens in mount. + for key in manifest.resets: + receiver_var = var_registry.get(key) + if receiver_var is not None: + merged_vars = merged_vars - {receiver_var} + target = cls(id=chain_id, vars=merged_vars, resets=resets, stubs=stubs) + return target.mount(owned=owned) + + def _evolve(self, **changes: Any) -> Chain: + """Return a copy of this chain with *changes* applied.""" + return replace(self, **changes) + + def _fork(self) -> Chain: + """Fork this chain into a fresh logical chain. + + The fork inherits the variable bindings (the ``vars`` index) + and stub pins on a freshly minted ``id``, with no owner + stamps — like any fresh Chain, it acquires its owners at the + subsequent `mount`. The reset signals are dropped: a + `wool.Token` minted in the parent chain is already + incompatible with the fork's chain for + `wool.ContextVar.reset`, so the fork starts clean. The + backing variables' values ride the fork natively — the forked + task runs in a `contextvars.copy_context` copy. This is + the copy-on-fork the task factory applies at every task + creation. + + **Re-handoff is undefined behaviour.** A fork minted on one + thread/task and then re-driven elsewhere — e.g., a + `wool.to_thread` worker's fork captured back into the + loop thread and passed to `asyncio.create_task` — + is unsupported. It will fail loudly only at the next + `wool.ContextVar` access on the re-handed chain, via + the chain-contention guard. 
Forks are intended to be owned and + retired by the scope that created them. + """ + return Chain( + id=uuid4(), + vars=self.vars, + stubs=self.stubs, + ) diff --git a/wool/src/wool/runtime/context/exceptions.py b/wool/src/wool/runtime/context/exceptions.py new file mode 100644 index 00000000..d9737895 --- /dev/null +++ b/wool/src/wool/runtime/context/exceptions.py @@ -0,0 +1,337 @@ +"""The context subsystem's exceptions and warnings. + +Single home for every error and warning the context subsystem +raises. The serialization branch — `SerializationError`, its +strict-mode aggregator `ChainSerializationError`, and the +non-fatal `SerializationWarning` — sits alongside `ChainContention` +(the chain-ownership guard's signal), `ContextVarCollision` (raised +on a duplicate variable key), and `TaskFactoryDisplaced` (raised +when a third-party task factory displaces Wool's). +""" + +from __future__ import annotations + +import asyncio +from typing import Any +from typing import Literal +from uuid import UUID + +from wool.exceptions import WoolError +from wool.exceptions import WoolWarning + + +# public +class SerializationError(WoolError): + """Raised when a value cannot be serialized across the Wool wire. + + Raised atomically when a single value fails to encode. The + strict-mode aggregator `ChainSerializationError` subclasses + this to collect per-variable chain-manifest failures, so catching + `SerializationError` matches every wire serialization failure, + atomic or aggregated. + + :param cause: + The underlying exception that produced the failure. Also + carried on ``__cause__`` when raised via exception chaining. + :param value_repr: + Optional ``repr()``-style preview of the value that failed to + encode, for diagnostics when the cause exception's own message + does not name the offending value. 
+    """
+
+    def __init__(
+        self,
+        *args: Any,
+        cause: BaseException | None = None,
+        value_repr: str | None = None,
+    ) -> None:
+        super().__init__(*args)
+        self.cause = cause
+        self.value_repr = value_repr
+
+
+# public
+class ChainSerializationError(SerializationError):
+    """Aggregator raised when a value cannot be serialized across the Wool wire.
+
+    Strict mode (see `SerializationWarning` for promotion recipe)
+    turns per-variable serialization warnings into errors. The wire
+    codec catches each promoted warning and raises the batch as a
+    single `ChainSerializationError` once its encode or decode loop
+    completes, so every bad variable surfaces, not just the first.
+
+    When a routine also raises a primary exception, the aggregator
+    rides on that primary's ``__cause__``. Any existing ``except``
+    clauses wrapping the routine keep matching unchanged.
+
+    The instance survives a pickle/copy round-trip with its message
+    intact: it is rebuilt from the collected warnings rather than from
+    the synthesized summary string held in ``args``.
+
+    :param warnings:
+        The promoted `SerializationWarning` instances, in emission
+        order. They are kept on the `warnings` attribute; the
+        exception's own message is a synthesized summary, so ``str()``
+        reads as a count of failures rather than a tuple of warnings.
+        Positional arguments that are not `SerializationWarning`
+        instances are ignored.
+    """
+
+    def __init__(self, *warnings: SerializationWarning) -> None:
+        self.warnings: tuple[SerializationWarning, ...] = tuple(
+            w for w in warnings if isinstance(w, SerializationWarning)
+        )
+        count = len(self.warnings)
+        noun = "variable" if count == 1 else "variables"
+        super().__init__(f"{count} context {noun} failed to serialize across the wire")
+
+    def __reduce__(self) -> tuple[Any, ...]:
+        # The inherited ``BaseException.__reduce__`` rebuilds through
+        # ``cls(*self.args)`` — i.e. from the summary string alone. The
+        # ``isinstance`` filter in ``__init__`` drops that string, so the
+        # copy would re-synthesize "0 context variables ..." while its
+        # restored ``warnings`` attribute still held every entry.
+        # Rebuilding from the warnings keeps message and attribute in
+        # agreement; the instance ``__dict__`` rides along as state so
+        # any other attributes set on the exception survive too.
+        return (type(self), self.warnings, self.__dict__)
+
+
+# public
+class SerializationWarning(WoolWarning):
+    """Emitted when a value cannot be serialized across the Wool wire.
+
+    Wool's wire protocol treats chain propagation and exception
+    fidelity as ancillary state: a failure there never preempts the
+    routine's primary signal, i.e., its return value or raised exception.
+    The failure is reported through this warning instead, so callers
+    that depend on the ancillary state can detect the inconsistency.
+ + Every emission site shares this one class, so callers that prefer + strict semantics promote serialization failures to errors with a + single category-level filter:: + + import warnings + import wool + + warnings.filterwarnings("error", category=wool.SerializationWarning) + + Under strict mode the wire codec aggregates the promoted + per-variable warnings into a single `ChainSerializationError` + (see that class for the aggregation contract). + + The structured fields below identify the failure programmatically; + each may be ``None`` if the emitting site does not supply it. + + :param cause: + The underlying exception, if applicable. + :param var_key: + The ``(namespace, name)`` identity of the variable whose value + failed to serialize. + :param direction: + Which side of the wire hop failed. + :param original_type: + The exception class that fallback reconstruction replaced when + a routine's exception could not be rebuilt with full fidelity. + """ + + def __init__( + self, + *args: Any, + cause: BaseException | None = None, + var_key: tuple[str, str] | None = None, + direction: Literal["encode", "decode"] | None = None, + original_type: type | None = None, + ) -> None: + super().__init__(*args) + self.cause = cause + self.var_key = var_key + self.direction = direction + self.original_type = original_type + + +_KIND_MESSAGES: dict[str, str] = { + "thread": ( + "wool.ContextVar accessed from thread {current_thread} but chain " + "{chain_id} is owned by thread {owning_thread}; an armed Wool " + "context cannot be shared across OS threads in parallel. Use " + "wool.to_thread to offload work onto a fresh, detached chain." + ), + "task": ( + "wool.ContextVar accessed by task {current_task!r} but chain " + "{chain_id} is owned by task {owning_task!r}; a Wool chain cannot " + "be entered by two tasks at once. 
Each task must run on " + "its own chain — create child tasks the ordinary way (the task " + "factory forks a fresh chain per task), or pass a fresh " + "contextvars.copy_context() to each create_task instead of " + "sharing one." + ), + "create_task": ( + "the same armed contextvars.Context was passed to create_task " + "while an earlier task running in it is still live (chain " + "{chain_id}). An armed context cannot be shared across " + "concurrently-live tasks — both tasks would corrupt each other's " + "Wool state through the single context it holds. Omit context= " + "(the default copies the context per task) or pass a fresh " + "contextvars.copy_context() to each task." + ), +} + + +# public +class ChainContention(WoolError): + """Raised when a Wool chain is entered by a thread or task that does + not own it. + + Wool enforces strictly serial execution within a logical chain: at + most one OS thread *and* one `asyncio.Task` may run code + under a given chain at a time. The guard has two dimensions — an + OS-thread check and an asyncio-task check within a single thread — + and engages only on an *armed* context (one carrying Wool chain + state). It fires at the point a `wool.ContextVar` is read + or written, not at a boundary crossing; offloaded code that never + touches a Wool variable is never flagged. An armed + `contextvars.Context` re-passed to + `asyncio.create_task` is also rejected up front (the factory + installed by `wool.install_task_factory` performs this + rejection; see `wool.runtime.context.factory`) — the + ``"create_task"`` kind below distinguishes that case. + + The supported way to run Wool-aware work on another OS thread is + `wool.to_thread`, which forks a fresh, detached chain for + the worker. + + See ``wool/src/wool/runtime/context/README.md`` for the model + context. + + :param chain_id: + UUID of the chain whose ownership was violated. 
+ :param kind: + ``"thread"`` for a cross-thread access, ``"task"`` for a + cross-task access, ``"create_task"`` for an armed context + re-passed to `asyncio.create_task`. + :param owning_thread: + Owner thread identity, when *kind* is ``"thread"``. + :param current_thread: + Offending thread identity, when *kind* is ``"thread"``. + :param owning_task: + Owner task, when *kind* is ``"task"``. + :param current_task: + Offending task, when *kind* is ``"task"``. + """ + + chain_id: UUID + kind: Literal["thread", "task", "create_task"] + owning_thread: int | None + current_thread: int | None + owning_task: asyncio.Future[Any] | None + current_task: asyncio.Future[Any] | None + + def __init__( + self, + *, + chain_id: UUID, + kind: Literal["thread", "task", "create_task"], + owning_thread: int | None = None, + current_thread: int | None = None, + owning_task: asyncio.Future[Any] | None = None, + current_task: asyncio.Future[Any] | None = None, + ) -> None: + # Validate ``kind`` explicitly so an unknown value raises a + # typed ``ValueError`` from a known origin rather than a bare + # ``KeyError`` from inside the exception's own constructor. + # The Literal annotation guards static call sites; this guard + # covers dynamic call sites (most notably ``__reduce__``-driven + # cross-process reconstruction where a forward-compat receiver + # might decode a ``kind`` value it does not know). 
+        if kind not in _KIND_MESSAGES:
+            raise ValueError(
+                f"unknown ChainContention kind: {kind!r}; "
+                f"expected one of {sorted(_KIND_MESSAGES)}"
+            )
+        message = _KIND_MESSAGES[kind].format(
+            chain_id=chain_id,
+            owning_thread=owning_thread,
+            current_thread=current_thread,
+            owning_task=owning_task,
+            current_task=current_task,
+        )
+        super().__init__(message)
+        self.chain_id = chain_id
+        self.kind = kind
+        self.owning_thread = owning_thread
+        self.current_thread = current_thread
+        self.owning_task = owning_task
+        self.current_task = current_task
+
+    def __reduce__(self) -> tuple[Any, ...]:
+        # ``ChainContention`` crosses the wire via
+        # `_safely_serialize_exception`. The default
+        # ``BaseException.__reduce__`` rebuilds through ``cls(*args)``
+        # where ``args`` is the pre-formatted message, and the
+        # type-preserving fallback likewise rebuilds via
+        # ``cls(*exc.args)`` — both of which our keyword-only
+        # constructor rejects. This override instead returns a
+        # ``(callable, args)`` pair: the module-level
+        # `_reconstruct_chain_contention` helper plus the structured
+        # fields as positional arguments. The helper re-invokes the
+        # keyword-only constructor, so the structured fields ride the
+        # wire and the message is re-composed by ``__init__`` on the
+        # receiver.
+        # NOTE(review): ``owning_task`` / ``current_task`` are passed
+        # through as-is; asyncio task objects are generally not
+        # picklable — confirm the serializer copes with kind="task".
+ return ( + _reconstruct_chain_contention, + ( + self.chain_id, + self.kind, + self.owning_thread, + self.current_thread, + self.owning_task, + self.current_task, + ), + ) + + +def _reconstruct_chain_contention( + chain_id: UUID, + kind: Literal["thread", "task", "create_task"], + owning_thread: int | None, + current_thread: int | None, + owning_task: asyncio.Future[Any] | None, + current_task: asyncio.Future[Any] | None, +) -> ChainContention: + """Module-level constructor for `ChainContention.__reduce__`.""" + return ChainContention( + chain_id=chain_id, + kind=kind, + owning_thread=owning_thread, + current_thread=current_thread, + owning_task=owning_task, + current_task=current_task, + ) + + +# public +class ContextVarCollision(WoolError): + """Raised on construction of a second ContextVar under an existing key. + + Keys must be unique within the inferred package namespace. Library + authors should pass ``namespace=`` explicitly when constructing + variables from shared factory code; application code can rely on + the implicit package-name inference. + + Detection is best-effort under garbage collection: the process-wide + variable registry holds `ContextVar` instances weakly, so a + key frees up once its previous instance is collected and a later + construction under that key then succeeds instead of colliding. In + practice `ContextVar` instances are module-level singletons + held for the process lifetime, so a genuine collision always + raises. + """ + + +# public +class TaskFactoryDisplaced(WoolError): + """Raised when Wool's task factory has been displaced by a later one. + + Wool installs its task factory on a loop, composing with any + factory already present. A third-party factory installed *after* + Wool's silently drops Wool's wrapping: child tasks created on that + loop thereafter no longer fork onto fresh chains — copy-on-fork is + lost. 
A non-forked child inherits its parent's owning-task + identity and trips `wool.ChainContention` on its first + `wool.ContextVar` access, or, when the parent has already + finished, silently reuses the parent's chain identity. + + The displacement is detected reactively: Wool cannot intercept + `loop.set_task_factory`, so it is noticed only on the next + `wool.ContextVar` access (or other path that arms a chain). + Raised unconditionally — not as a warning that callers opt to + promote — because displacement is structurally fatal to chain + propagation across every subsequent task on the loop, materially + different from per-variable wire-state corruption + (`wool.SerializationWarning`) which can degrade gracefully + on individual entries. + + Install Wool's task factory last, or compose factories manually, + to avoid the displacement entirely. + """ diff --git a/wool/src/wool/runtime/context/factory.py b/wool/src/wool/runtime/context/factory.py new file mode 100644 index 00000000..9725afe9 --- /dev/null +++ b/wool/src/wool/runtime/context/factory.py @@ -0,0 +1,632 @@ +from __future__ import annotations + +import asyncio +import contextvars +import logging +import threading +import weakref +from typing import Any +from typing import Callable +from typing import Coroutine +from typing import Generator +from typing import TypeVar +from typing import cast + +import wool +from wool.runtime.context.exceptions import ChainContention +from wool.runtime.context.exceptions import TaskFactoryDisplaced + +_log = logging.getLogger(__name__) + +T = TypeVar("T") + + +def context_is_armed(context: contextvars.Context) -> bool: + """Return ``True`` if *context* carries a Wool chain. + + A `contextvars.Context` in which no `wool.ContextVar` + has been set never holds the Wool-owned context variable at all — it + is *unarmed* and behaves as a plain `contextvars.Context`. 
+ + Inspects an *explicit* `contextvars.Context` (e.g., a + ``copy_context()`` snapshot or a child task's materialised context), + not the active one — the task factory consults it on a child's + freshly copied context before scheduling it. + """ + return wool.__chain__ in context + + +_loops_with_factory: weakref.WeakSet[asyncio.AbstractEventLoop] = weakref.WeakSet() + +# Loops where Wool's factory has been observed displaced. Populated by +# the weakref.finalize callback on Wool's factory object (fires the +# moment the loop drops its reference inside loop.set_task_factory) +# and by ``_release``'s done-callback check (covers the corner where +# the third-party stash holds the reference but doesn't invoke it). +# Consulted by ``ensure_task_factory_installed`` (and by anything +# else that needs to surface displacement loudly to user code) so a +# user-Wool entry point flagged on a displaced loop raises +# `TaskFactoryDisplaced` regardless of whether mount is still +# the trigger path. +_displaced_loops: weakref.WeakSet[asyncio.AbstractEventLoop] = weakref.WeakSet() + +# Every task's running contextvars.Context, keyed by id() and mapped to +# the live task running under it. Lets the task factory detect a context +# shared across concurrently-live tasks (which would silently corrupt +# Wool state — see wool_factory). contextvars.Context is unhashable, so +# id() is the key; it is safe because while an entry exists the +# registered task pins its context alive, so the id cannot be reused. +# _release clears the entry when the task ends. wool_factory materialises +# the context for every task it creates and registers it here — not only +# explicitly-passed ones — so re-passing a live task's own context to a +# second create_task is caught. +_task_contexts: dict[int, asyncio.Future[Any]] = {} + + +class _PendingSentinel: + """Sentinel marking a reserved-but-not-yet-populated context slot. 
+
+    Placed into ``_task_contexts`` by ``wool_factory`` *before*
+    invoking the inner factory so two threads concurrently calling
+    ``loop.create_task(coro, context=same_armed_ctx)`` cannot both
+    pass the owner-check, race the registration, and silently share
+    a chain. The first thread sees ``None``, reserves with the
+    sentinel under the lock; the second thread sees the sentinel and
+    treats it as a live owner (raising ``ChainContention``).
+    """
+
+
+# Module-level singleton. Annotated ``Any`` so it can occupy a
+# `_task_contexts` slot (whose declared value type is a future); it
+# is only ever recognised via ``is _PENDING`` at the use site, so the
+# loose annotation exists purely for the dict.
+_PENDING: Any = _PendingSentinel()
+
+# Guards (a) the read-modify-write in `install_task_factory`
+# (``loop.get_task_factory()`` then ``loop.set_task_factory()``) so two
+# threads cannot double-install on one loop, and (b) every read/write of
+# `_task_contexts` so the per-task registration/release sequence
+# stays consistent. Both critical sections are narrow (~5 lines each),
+# so a single lock for both is cheaper than per-table sharding.
+_lock = threading.Lock()
+
+
+def _on_factory_collected(
+    loop_ref: weakref.ref[asyncio.AbstractEventLoop],
+) -> None:
+    """``weakref.finalize`` callback for the Wool factory installed on a loop.
+
+    Fires when Wool's factory object is garbage-collected. On CPython
+    this is synchronous with the displacement: a third party calling
+    ``loop.set_task_factory(other)`` drops the loop's last reference
+    to Wool's factory, the refcount hits zero, and this callback runs
+    immediately in the displacer's stack frame.
+
+    Filters out legitimate teardown (loop collected, closed, or not
+    running) so the loop is flagged only on actual displacement. Flags
+    the loop in `_displaced_loops` so the next user-Wool entry point
+    consulting the flag raises `TaskFactoryDisplaced` loudly to
+    user code. Raises from finalize callbacks are caught by the GC
+    machinery and printed as ``Exception ignored in:`` — not user-
+    visible exceptions — so the flag-and-surface-later pattern is
+    needed for a guaranteed user-facing raise.
+
+    Displacement is only marked when ``loop.is_running()``. In the
+    finalize-during-shutdown window (between ``loop.close()``
+    initiation and ``is_closed()`` returning True) the loop is *not*
+    flagged; the event is logged at DEBUG rather than WARNING, since a
+    non-running loop has no Wool API surface that could surface the
+    flag anyway.
+
+    :param loop_ref:
+        Weak reference to the loop the factory was installed on; a
+        dead reference means the loop itself is gone and nothing is
+        recorded.
+    """
+    loop = loop_ref()
+    if loop is None or loop.is_closed():
+        return
+    if not loop.is_running():
+        _log.debug(
+            "Wool's task factory finalize fired on non-running loop %r — "
+            "treating as legitimate teardown rather than displacement.",
+            loop,
+        )
+        return
+    _displaced_loops.add(loop)
+    _log.warning(
+        "Wool's task factory has been displaced from %r; child tasks "
+        "no longer fork onto fresh chains. The next wool.ContextVar "
+        "access on this loop will raise TaskFactoryDisplaced.",
+        loop,
+    )
+
+
+def _release(
+    task: asyncio.Future[Any],
+    key: int,
+    coro: Coroutine[Any, Any, Any] | None,
+) -> None:
+    """Drop *key* from `_task_contexts` once *task* is done.
+
+    Removes the entry only if *task* is still the registered owner: a
+    sequential reuse of the same context object — a fresh task created
+    with it after this one finished but before this done-callback ran —
+    may already hold the slot, and must not be evicted.
+
+    *coro* is the inner coroutine of a task whose coroutine the factory
+    wrapped in `_forked_scope`. By the time this done-callback
+    runs the task is finished, so *coro* is either already exhausted
+    (the wrapper ran and awaited it to completion) or never started
+    (the wrapper was cancelled before its first step).
It is closed + unconditionally: `close` is a no-op on an exhausted coroutine, + and on a never-started one it suppresses the "coroutine was never + awaited" `RuntimeWarning` that would otherwise leak at GC. + *coro* is ``None`` for an unwrapped (unarmed) task, whose coroutine + asyncio owns and closes itself. + + Also runs a displacement check as a backstop to the + `_on_factory_collected` finalize callback: if a third party + stashes Wool's factory reference but never calls it (so the + finalize never fires) and creates new tasks via its own factory, + eventually a pre-displacement Wool-tracked task completes — that + completion runs ``_release`` on this loop and we observe that the + current factory is no longer Wool-wrapped. + """ + with _lock: + if _task_contexts.get(key) is task: + del _task_contexts[key] + if coro is not None: + # Close defensively. If ``coro.close()`` raises (e.g., a + # third-party generator with a bug in ``close()``), the + # ``_release`` body must continue so the displacement backstop + # below still runs. Without this guard, asyncio would surface + # the raise as "Exception in callback" and silently skip the + # displacement check, defeating the backstop for the + # stash-but-don't-call corner. + try: + coro.close() + except (KeyboardInterrupt, SystemExit): # pragma: no cover + raise + except BaseException: # pragma: no cover + _log.debug("coro.close() raised during _release; ignored", exc_info=True) + try: + loop = task.get_loop() + except RuntimeError: # pragma: no cover — orphaned future has no loop + # ``get_loop`` raises on an orphaned future. Nothing to check. 
+ return + if loop in _loops_with_factory and loop not in _displaced_loops: + if not _is_wool_factory(loop.get_task_factory()): + _displaced_loops.add(loop) + _log.warning( + "Wool's task factory has been displaced from %r " + "(detected at task-completion time); the next " + "wool.ContextVar access on this loop will raise " + "TaskFactoryDisplaced.", + loop, + ) + + +async def _forked_scope(coro: Coroutine[Any, Any, T]) -> T: + """Run *coro* under a freshly-forked chain. + + The Wool task factory wraps every *armed* child coroutine in this + scope (an unarmed task runs its coroutine bare — there is no + context to fork, and wrapping would make the task's coroutine + identity diverge from a plain asyncio task). Running inside the new + task's own `contextvars.Context`, it forks the inherited + context — minting a fresh chain UUID and adopting the running + thread as owner — so a child task never shares its parent's chain. + """ + # The factory only wraps an armed child in this scope (see the gate + # at the wool_factory definition), so the chain is armed by + # construction; a bare ``get`` raises ``LookupError`` loudly if the + # gate-and-wrap drift ever breaks. + context = wool.__chain__.get() + # ``mount`` stamps owners and installs the task factory; fork + # produces an owner-less Chain on a fresh chain id, so routing it + # through mount is the canonical "make this fork live" step. + context._fork().mount() + return await coro + + +# Memoize per-factory `_is_wool_factory` outcomes, with built-in +# cycle detection. Wool's stamps (``__wool_wrapped__``, +# ``__wool_inner__``) are written once at install time and never +# mutated by Wool afterward, so the result is invariant under any +# code path Wool initiates. Caching short-circuits the steady-state +# hot path (``ensure_task_factory_installed`` fires from every +# ``wool.ContextVar.set()``'s ``mount``). 
A per-walk ``seen`` set
+# defends against an adversarial third-party wrapper that points
+# ``__wool_inner__`` back into the chain — without it, the walk
+# spins indefinitely under the install lock and freezes every
+# install attempt across the process.
+_wool_factory_cache: weakref.WeakKeyDictionary[Callable[..., Any], bool] = (
+    weakref.WeakKeyDictionary()
+)
+
+
+def _is_wool_factory(factory: Any) -> bool:
+    """Return ``True`` if any layer of *factory*'s composition is wool.
+
+    Walks the ``__wool_inner__`` chain Wool stamps on its wrapper —
+    each Wool wrap remembers the factory it composed over via this
+    attribute, terminated by ``None`` at the bottom. The chain lets the
+    idempotency check in `install_task_factory` detect Wool even
+    when buried under a third-party factory installed *over* a prior
+    Wool install: ``wool → third-party → wool`` would otherwise pass
+    the outer-only check and re-wrap into a double-fork hazard. The
+    walk can only see through a layer that itself exposes
+    ``__wool_inner__``.
+
+    Memoized via a `weakref.WeakKeyDictionary` and cycle-guarded
+    by a per-walk ``seen`` set. Layers that cannot serve as cache keys
+    (not weakref-able, or unhashable) are inspected directly and simply
+    never memoized.
+
+    :param factory:
+        The task factory to inspect. ``None`` — what
+        ``loop.get_task_factory()`` returns when no factory is set —
+        is reported as not-Wool.
+    :returns:
+        ``True`` when *factory* or any layer reachable through
+        ``__wool_inner__`` carries a truthy ``__wool_wrapped__``.
+    """
+    chain: list[Any] = []
+    seen: set[int] = set()
+    current: Any = factory
+    result = False
+    while current is not None and id(current) not in seen:
+        seen.add(id(current))
+        chain.append(current)
+        # A cache hit — on the outermost layer or mid-walk —
+        # short-circuits. The lookup itself raises ``TypeError`` for a
+        # layer that cannot be a WeakKeyDictionary key; treat that as a
+        # miss so the displacement check in ``_release`` never dies
+        # inside a done-callback.
+        try:
+            cached = _wool_factory_cache.get(current)
+        except TypeError:
+            cached = None
+        if cached is not None:
+            result = cached
+            break
+        if getattr(current, "__wool_wrapped__", False):
+            result = True
+            break
+        current = getattr(current, "__wool_inner__", None)
+    # Memoize every layer we walked. Layers that fall out (e.g., not
+    # weakref-able) silently skip — the next call re-walks them.
+    for f in chain:
+        try:
+            _wool_factory_cache[f] = result
+        except TypeError:  # pragma: no cover — non-weakref-able layer; skip
+            pass
+    return result
+
+
+# public
+def install_task_factory(
+    loop: asyncio.AbstractEventLoop | None = None,
+) -> None:
+    """Install Wool's task factory on the given (or running) loop.
+
+    Composes with an existing factory if one is set, so that asyncio
+    child tasks created via ``create_task`` fork the parent's Wool
+    chain onto a fresh chain. Idempotent — a subsequent call on a
+    loop that already has the Wool-wrapped factory installed is a
+    no-op. The first `wool.ContextVar.set` self-installs the
+    factory on the running loop, so user code that touches Wool's API
+    without first calling `install_task_factory` still gets
+    fork-on-task semantics for tasks created after that first set.
+
+    **Ordering contract** — If a user installs their own task factory
+    *after* Wool's, Wool's wrapping of child coroutines is dropped
+    and copy-on-fork breaks silently for subsequently-created tasks.
+    Install Wool's factory last (or compose manually) when other
+    libraries also want a factory on the same loop.
+
+    **Composed-factory contract** — When Wool composes with an existing
+    factory, it forwards to that inner factory both ``context=`` and
+    whatever keyword arguments the running loop handed Wool. Wool
+    always supplies ``context=`` — it materialises each task's
+    `contextvars.Context` itself so the chain-contention guard
+    can register it. CPython's ``asyncio`` additionally hands the
+    factory ``name=``; other conformant loops (uvloop, for one) apply
+    the task name themselves and do not forward it — so a composed
+    inner factory must accept arbitrary ``**kwargs`` and must not
+    *depend* on ``name=`` being present.
An inner factory written to + the legacy two-argument ``(loop, coro)`` signature raises + `TypeError` under composition and is unsupported — this + holds even on Python versions where bare stdlib would have called + that factory without ``context=``, because Wool's guard requires + the explicit context. + + Installation is one-way — Wool does not provide an uninstall + path; the wrapped factory stays on the loop until the loop is + closed. + + See ``wool/src/wool/runtime/context/README.md`` for the model + context. + + :param loop: + The event loop to install the factory on. When ``None`` (the + default), the running loop is resolved via + `asyncio.get_running_loop`; calling with ``loop=None`` + outside a running loop raises `RuntimeError`. The call + is idempotent — installing on a loop that already has the + Wool-wrapped factory is a no-op — and composes with any + existing factory by stamping ``__wool_inner__`` so the prior + factory is invoked underneath Wool's wrapper. + """ + if loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError as e: + raise RuntimeError( + "install_task_factory() with loop=None must run inside a " + "running event loop, or pass loop= explicitly" + ) from e + + # Guard the read-modify-write so two threads installing on the same + # loop cannot both pass the wool-detection check and double-install + # into a ``wool → wool`` composition. Today this is reactive + # protection against a third-party factory installed lazily from + # another thread; the lock is narrow enough that single-threaded + # callers pay no measurable cost. + with _lock: + existing = loop.get_task_factory() + if existing is not None and _is_wool_factory(existing): + _log.debug(f"Wool-composed task factory already installed on {loop}") + return + # Inline default — a one-line passthrough to ``asyncio.Task`` + # that gives ``inner`` a uniform non-None value, avoiding a + # named helper whose only purpose was to be that placeholder. 
+ inner = ( + existing + if existing is not None + else (lambda lp, cr, **kw: asyncio.Task(cr, loop=lp, **kw)) + ) + + wool_factory = _build_wool_factory(inner) + wool_factory.__wool_wrapped__ = True # pyright: ignore[reportFunctionMemberAccess] + wool_factory.__wool_inner__ = existing # pyright: ignore[reportFunctionMemberAccess] + loop.set_task_factory(wool_factory) # pyright: ignore[reportArgumentType] + # Register the loop for displacement monitoring — including loops + # set up by a direct install_task_factory() call (e.g., worker- + # process loops), not only those routed through + # ensure_task_factory_installed. + _loops_with_factory.add(loop) + # Clear the displacement flag atomically with the + # re-install so a recovery path (``install_task_factory(loop)`` + # after the displacement was detected) self-heals. Without + # this discard, ``ensure_task_factory_installed`` would keep + # short-circuiting on ``if loop in _displaced_loops: + # raise TaskFactoryDisplaced(...)`` even after the user + # explicitly puts Wool's factory back on top. + _displaced_loops.discard(loop) + # Detect displacement at the moment it happens. When a third + # party calls ``loop.set_task_factory(other)`` without + # composing through ``install_task_factory``, the loop drops + # its reference to ``wool_factory`` and the refcount hits + # zero synchronously (CPython); the ``weakref.finalize`` + # callback fires right there in the displacer's stack frame + # and flags the loop so the next user-Wool entry point + # raises. The ``loop.is_closed()`` filter avoids false + # positives at legitimate loop shutdown. 
def _build_wool_factory(
    inner: Callable[..., asyncio.Future[Any]],
) -> Callable[..., asyncio.Task[Any]]:
    """Return Wool's task-factory closure composing over *inner*.

    Hoisted out of `install_task_factory` so the lock-held
    install body stays narrow. The returned closure captures *inner*
    by reference and exposes ``__wool_wrapped__`` / ``__wool_inner__``
    on itself; `install_task_factory` stamps them after this
    returns.

    *inner* is typed as ``Callable[..., asyncio.Future[Any]]`` to
    match typeshed's ``asyncio.events._TaskFactory`` protocol — what
    ``loop.get_task_factory()`` returns. Wool's own ``wool_factory``
    closure unconditionally materialises an `asyncio.Task`,
    so the outer return stays ``Task[Any]``.

    :param inner:
        The factory that actually creates the task. It is invoked as
        ``inner(loop, coro, **kwargs)`` and always receives an explicit
        ``context`` keyword, because the closure writes the resolved
        context into *kwargs* before delegating.
    :returns:
        The ``wool_factory`` closure. It is returned unstamped; this
        function sets no attributes on it.
    """

    def wool_factory(
        loop: asyncio.AbstractEventLoop,
        coro: Coroutine[Any, Any, Any] | Generator[Any, None, Any],
        *,
        context: contextvars.Context | None = None,
        **kwargs: Any,
    ) -> asyncio.Task[Any]:
        """Create the task for *coro* through *inner* and register the
        `contextvars.Context` it runs in.

        :param loop:
            Forwarded unchanged to *inner*.
        :param coro:
            The user coroutine. It is closed before any exception
            leaves this function, so a failed creation does not leave
            it un-awaited.
        :param context:
            The context to run the task in. When ``None`` a fresh
            ``contextvars.copy_context()`` is used.
        :param kwargs:
            Forwarded to *inner*, with ``context`` added.
        :raises ChainContention:
            When the passed *context* already has a pending
            reservation, or is armed and registered to a task that is
            not done.
        """
        # Widen to ``Coroutine | Generator`` to satisfy typeshed's
        # ``_CoroutineLike[_T]`` contravariant parameter — the
        # ``Generator`` arm exists for pre-3.8 generator-coroutines
        # and is unreachable from asyncio's modern create_task path,
        # but the static type must accept it for ``wool_factory`` to
        # be a valid ``_TaskFactory``. Narrow back to ``Coroutine``
        # for the body.
        coro = cast(Coroutine[Any, Any, Any], coro)
        if context is None:
            # Materialise the context the task will run in here, rather
            # than letting asyncio.Task copy_context() internally — it
            # is the same call one frame up — so every task's context
            # is a known object the chain-contention guard registers
            # below. A fresh copy is unique and cannot alias a live
            # entry, so an ordinary task never trips the guard.
            context = contextvars.copy_context()
            reserved_pending = False
        else:
            # An *armed* contextvars.Context already driving a live task
            # must not be handed to a second create_task: Wool chain
            # state lives in the one ``wool.__chain__`` binding the context
            # holds, so two tasks running in it would read and write
            # each other's context and silently corrupt both chains.
            # The rejection is armed-gated — an unarmed shared context
            # carries no chain to corrupt and is permitted, exactly as
            # stdlib asyncio permits it; if it is armed later, the
            # chain-owner guard (assert_chain_owner) catches the
            # second task then. Because every task's context is registered
            # below, this catches a context re-passed from any live task
            # — including one obtained via ``task.get_context()`` — not
            # only explicitly-created ones.
            #
            # Reserve the slot with a ``_PENDING`` sentinel
            # *inside the same locked section* as the owner check, so
            # two threads concurrently calling
            # ``loop.create_task(coro, context=same_armed_ctx)`` cannot
            # both pass the "owner is None" check and race the
            # registration. The first thread reserves; the second sees
            # the pending reservation and treats it as a live owner.
            armed_at_entry = context_is_armed(context)
            with _lock:
                owner = _task_contexts.get(id(context))
                if owner is None and armed_at_entry:
                    _task_contexts[id(context)] = _PENDING  # type: ignore[assignment]
                    reserved_pending = True
                else:
                    reserved_pending = False
            if owner is not None and (
                owner is _PENDING or (not owner.done() and armed_at_entry)
            ):
                # Close the un-awaited coroutine to suppress the
                # "coroutine was never awaited" RuntimeWarning that
                # would otherwise leak at GC, then fail loud.
                coro.close()
                # ``context[wool.__chain__]`` is the armed Wool Chain — by
                # the ``context_is_armed`` check above — so its
                # ``chain_id`` is the one being doubly-driven. Surface
                # it in the exception for diagnostics. Also
                # surface the owning task (the registered owner, if
                # any) and the current task so the structured
                # diagnostic fields are populated when available.
                wool_ctx = context[wool.__chain__]
                try:
                    current = asyncio.current_task()
                except RuntimeError:  # pragma: no cover — always under a running loop
                    current = None
                raise ChainContention(
                    chain_id=wool_ctx.id,
                    kind="create_task",
                    # A ``_PENDING`` owner is a sentinel, not a task, so
                    # it is reported as ``None``.
                    owning_task=owner if isinstance(owner, asyncio.Future) else None,
                    current_task=current,
                )
        kwargs["context"] = context
        # Wrap the coroutine in ``_forked_scope`` only when the creating
        # context is armed: an armed child must fork onto a fresh chain.
        # An unarmed task has no context to fork, and wrapping it would
        # make Task.get_coro()/repr()/auto-name reflect ``_forked_scope``
        # instead of the user coroutine — a divergence from a plain
        # asyncio task — so it runs the coroutine bare. An armed task
        # does take that get_coro()/repr() divergence: it is the
        # accepted cost of copy-on-fork, not an oversight — wrapping is
        # the only way to run the child under its own forked chain.
        scope_coro: Coroutine[Any, Any, Any] | None
        # Gate the wrap on the *materialised child*'s armed state, not
        # the caller's: when the caller is armed but explicitly passes
        # an unarmed ``contextvars.Context`` as ``context=``, only the
        # child's view matters — wrapping an unarmed child would make
        # ``Task.get_coro()`` / ``repr()`` reflect ``_forked_scope`` for
        # no reason (the inner fork would be a no-op).
        if not context_is_armed(context):
            scope_coro = None
            # ``inner`` is statically typed to return ``Future | Task``
            # (the ``_TaskFactory`` protocol); a task factory in
            # practice always returns a ``Task``.
            try:
                task = inner(loop, coro, **kwargs)
            except BaseException:
                # The user coroutine never reached a task — close it
                # to suppress the "coroutine was never awaited" warning
                # at GC, then re-raise. Also clean up the
                # _PENDING reservation if we made one, so the slot
                # doesn't pin forever.
                coro.close()
                if reserved_pending:  # pragma: no cover — no pending slot when unarmed
                    with _lock:
                        if _task_contexts.get(id(context)) is _PENDING:
                            del _task_contexts[id(context)]
                raise
        else:
            scope_coro = _forked_scope(coro)
            try:
                task = inner(loop, scope_coro, **kwargs)
            except BaseException:
                # ``inner`` raising leaves the wrapper coroutine (and
                # therefore the user coroutine inside it) unawaited.
                # Close both to suppress the "coroutine was never
                # awaited" warning at GC, symmetric with the proactive
                # close on the re-passed-armed-context branch above.
                scope_coro.close()
                coro.close()
                if reserved_pending:
                    with _lock:
                        if _task_contexts.get(id(context)) is _PENDING:
                            del _task_contexts[id(context)]
                raise
        key = id(context)
        # An eager inner factory can step the task to completion inside
        # the inner(...) call above, before this registration. That is
        # harmless: a re-passed already-done context is caught by the
        # ``not owner.done()`` check, not by registration ordering, and
        # the done-callback below still fires (add_done_callback on an
        # already-done task schedules via call_soon) to clear the slot.
        # Replace the ``_PENDING`` reservation we placed in the
        # owner check above with the real task; for non-reserved
        # contexts the entry is fresh.
        with _lock:
            _task_contexts[key] = task
        # ``k`` and ``c`` are bound as lambda defaults, so the callback
        # holds the values computed for this call.
        task.add_done_callback(
            lambda finished, k=key, c=scope_coro: _release(finished, k, c)
        )
        return task  # pyright: ignore[reportReturnType]

    return wool_factory
+ + Lets user code that touches Wool without first calling + `install_task_factory` still get fork-on-task semantics for + tasks created after the first Wool API contact. No-ops in sync + contexts (no running loop). The `_loops_with_factory` weak + set short-circuits the lookup to a single membership check. Both + `install_task_factory` and this function add the loop to the + set, so a loop set up by a direct `install_task_factory` + call — e.g., a worker-process loop — is displacement-monitored + exactly like one self-installed here. + + If a later call finds Wool's factory has been displaced from a + loop it was previously installed on — a third-party factory + installed after Wool's — it raises `TaskFactoryDisplaced`, + since copy-on-fork is silently lost for tasks created after the + displacement. Displacement is detected via three converging + paths: `_displaced_loops` (set by + `_on_factory_collected`'s ``weakref.finalize`` callback the + moment the loop drops Wool's factory reference, and by + `_release`'s done-callback backstop), the post-displacement + factory inspection at this site, and any direct re-check the + user-Wool entry point performs. + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + if loop in _displaced_loops: + # Flagged by ``_on_factory_collected`` (synchronous with the + # ``loop.set_task_factory`` that displaced us on CPython) or + # by ``_release`` (the stash-but-don't-call corner). Surface + # loudly to user code regardless of whether this site's own + # factory inspection also catches it. + raise TaskFactoryDisplaced( + "Wool's task factory was displaced by a task factory " + "installed after it; child tasks on this loop no longer " + "fork onto fresh chains (copy-on-fork is lost). Install " + "Wool's task factory last, or compose factories manually." + ) + if loop in _loops_with_factory: + # Wool's factory was installed on this loop earlier. 
def assert_chain_owner(chain: Chain | None) -> None:
    """Fail loudly when *chain* is touched by a thread or task that does
    not own it.

    Two ownership checks run in a fixed order — OS thread first, asyncio
    task second. The order matters: a foreign thread is rejected by the
    thread check before `asyncio.current_task` is ever consulted.

    The call is a no-op when any of the following holds:

    * *chain* is ``None`` (the context is unarmed);
    * the calling thread is ``chain.thread`` and the chain has no live
      owner task (none recorded, the task was collected, or it is done);
    * `asyncio.current_task` raises `RuntimeError` (no running loop) or
      returns ``None`` (the caller is not inside a task);
    * the running task is the owner itself.

    The thread owner is compared by its `threading.get_ident` value,
    an integer the OS may hand to a later thread once the owner exits.

    :param chain:
        The active wool `~wool.runtime.context.chain.Chain`, or
        ``None`` for an unarmed context.
    :raises ChainContention:
        With ``kind="thread"`` when the calling thread is not
        ``chain.thread``; with ``kind="task"`` when a live owner task
        exists and the running task is a different one.
    """
    if chain is None:
        return

    # Thread dimension.
    caller_ident = threading.get_ident()
    if chain.thread != caller_ident:
        raise ChainContention(
            chain_id=chain.id,
            kind="thread",
            owning_thread=chain.thread,
            current_thread=caller_ident,
        )

    # Task dimension — only meaningful while a live owner task exists.
    owner = _resolve_owning_task(chain)
    if owner is None:
        return
    try:
        running = asyncio.current_task()
    except RuntimeError:
        # No running loop on this thread: same as "outside any task".
        return
    if running is not None and running is not owner:
        raise ChainContention(
            chain_id=chain.id,
            kind="task",
            owning_task=owner,
            current_task=running,
        )


def _resolve_owning_task(context: Chain) -> asyncio.Future[Any] | None:
    """Return the live task that owns *context*, or ``None``.

    ``None`` is returned in three situations: ``context.task`` is
    ``None`` (no owner recorded), the stored reference resolves to
    ``None`` (the owner is gone), or the owner reports ``done()``.

    :param context:
        The active wool `~wool.runtime.context.chain.Chain`. Must not
        be ``None``; `assert_chain_owner` returns early on an unarmed
        context before calling this helper.
    """
    owner_ref = context.task
    candidate = owner_ref() if owner_ref is not None else None
    if candidate is not None and not candidate.done():
        return candidate
    return None
+ +A *manifest* is the wire-side abstraction of a live type: it travels in +place of the concrete thing and knows nothing of the live runtime. This +module houses both. `ContextVarManifest` is the identity/data layer of a +single variable — its ``(namespace, name)`` key, default, and backing — +which `wool.ContextVar` subclasses to add live behavior. `ChainManifest` +is the decoded snapshot of a whole chain. ("Stub" is reserved for the +``_stub`` *state*, which either a bare manifest or a promoted variable +can occupy; it is no longer a type.) + +A `ChainManifest` is the successfully-deserialised state of a wire +`~wool.protocol.ChainManifest`: the variable bindings, reset signals, +and stub pins recovered from the wire. Distinct from `Chain` (the live +chain): the manifest holds its values inline rather than in the +contextvar backings, and it knows nothing about the live runtime — it +neither reads backings nor arms a context. It is the staging form on +both sides of the wire. + +`ChainManifest.from_protobuf` and `ChainManifest.to_protobuf` are the +codec, the two `protocol.ChainManifest` ↔ `ChainManifest` halves. +Crossing the `ChainManifest` ↔ `Chain` boundary — the half that touches +the backings — lives on `Chain`: `Chain.to_manifest` snapshots a live +chain into a manifest, and `Chain.from_manifest` drains a manifest into +the backings and arms it. + +`Frame.from_protobuf` decodes the optional wire ``context`` field and +stores either the manifest or a `ChainSerializationError` on the decoded +frame. `Frame.mount` drives the install through `Chain.from_manifest`. +User code does not construct manifests; the frame layer is the only +entry into the codec. 
T = TypeVar("T")


class ContextVarManifest(Generic[T]):
    """The wire-abstraction of a single context variable.

    Carries the ``(namespace, name)`` key, the constructor default, and
    the backing `contextvars.ContextVar` — the identity a variable holds
    independently of any live `~wool.runtime.context.chain.Chain`.
    `~wool.runtime.context.var.ContextVar` subclasses this and adds the
    chain-bound `get`/`set`/`reset` behavior; a bare
    ``ContextVarManifest`` is what `resolve_stub` mints on the receiver
    when a wire frame references a key no local variable has declared
    yet.

    Instances are process-wide singletons per key, registered in
    `~wool.runtime.context.registry.var_registry`. The ``_stub`` flag
    records whether an instance is still an undeclared placeholder — a
    *state* a bare manifest or a promoted `ContextVar` can occupy;
    promotion (see `~wool.runtime.context.var.promote`) clears it.

    Not part of the public surface — user code constructs
    `wool.ContextVar`, never this base directly.
    """

    __slots__ = (
        "_name",
        "_namespace",
        "_key",
        "_default",
        "_stub",
        "_backing",
        "__weakref__",
    )

    _name: str
    _namespace: str
    _key: tuple[str, str]
    _default: T | UndefinedType
    _stub: bool
    _backing: contextvars.ContextVar[T | UndefinedType]

    def __repr__(self) -> str:
        """Return a debug representation naming the variable's key.

        Includes the concrete class name, the ``namespace:name`` key,
        and — only when one was supplied — the constructor default.

        NOTE(review): the previous body returned an empty f-string and
        never used ``default_part``. The original text is not
        recoverable here, so the exact format below is a reconstruction
        — confirm against the intended output.
        """
        default_part = (
            f" default={self._default!r}" if self._default is not Undefined else ""
        )
        return (
            f"<{type(self).__name__} "
            f"{self._namespace}:{self._name}{default_part}>"
        )

    def __reduce_ex__(self, _protocol: SupportsIndex) -> NoReturn:
        """Reject vanilla pickling.

        ContextVar identity is registered against the process-wide
        `~wool.runtime.context.registry.var_registry`; restoring an
        instance outside Wool's dispatch path bypasses the stub-
        promotion and collision-detection that
        `~wool.runtime.context.var.ContextVar._reconstitute` relies on.
        Wool's own pickler consults ``reducer_override`` (and therefore
        `~wool.runtime.context.var.ContextVar.__wool_reduce__`) before
        ``__reduce_ex__``, so this guard is invisible to Wool's
        serialization.

        `copy.copy` and `copy.deepcopy` also route
        through ``__reduce_ex__`` and are rejected for the same
        reason — a registry-bound ContextVar has no meaningful copy
        semantics.

        :param _protocol:
            The pickle protocol number; ignored.
        :raises TypeError:
            Always.
        """
        raise TypeError(
            "wool.ContextVar cannot be pickled via vanilla pickle/cloudpickle; "
            "it is serialized automatically when dispatched through Wool's "
            "runtime."
        )

    @property
    def name(self) -> str:
        """The variable's name, matching the `contextvars.ContextVar` API."""
        return self._name

    @property
    def namespace(self) -> str:
        """The namespace this variable belongs to."""
        return self._namespace

    @classmethod
    def _build(
        cls,
        key: tuple[str, str],
        default: Any,
        *,
        stub: bool,
    ) -> ContextVarManifest[Any]:
        """Construct a `ContextVarManifest` instance with field assignment.

        Single source of truth for the
        ``object.__new__`` + per-field-assignment idiom shared by
        `~wool.runtime.context.var.ContextVar.__new__` (the user-facing
        construction path, which builds the subtype) and `resolve_stub`
        (the wire-boundary path). The instance is *not* registered in
        `~wool.runtime.context.registry.var_registry` — callers do that
        under the registry lock. The backing stdlib variable is created
        once and shared by every chain for the process lifetime; it
        carries no ``contextvars``-level default —
        `~wool.runtime.context.var.ContextVar.get` owns the default-
        resolution ladder, and "unset" is the
        `~wool.runtime.typing.Undefined` sentinel value.

        :param key:
            The ``(namespace, name)`` pair identifying the variable.
        :param default:
            The constructor default, or
            `~wool.runtime.typing.Undefined` when none was given.
        :param stub:
            Initial value of the ``_stub`` flag.
        :returns:
            A new, unregistered instance of *cls*.
        """
        namespace, name = key
        # ``object.__new__`` sidesteps any subclass ``__new__``; every
        # slot is then filled explicitly.
        instance: ContextVarManifest[Any] = object.__new__(cls)
        instance._name = name
        instance._namespace = namespace
        instance._key = key
        instance._default = default
        instance._stub = stub
        instance._backing = contextvars.ContextVar(
            f"__wool_var__:{namespace}:{name}", default=Undefined
        )
        return instance
@dataclass
class ChainManifest:
    """A deserialised, successfully-decoded chain manifest.

    Carries the decoded chain state — the var-to-value mapping and the
    reset signals — inline, ready for `Chain.from_manifest` to drain into
    the backings. A decode *failure* is never a ChainManifest:
    `Frame.from_protobuf` captures it as a
    `~wool.runtime.context.exceptions.ChainSerializationError` instead.

    Present on a `Frame` iff the frame carried non-empty receive-side
    state: a chain manifest with bindings or resets. An empty chain
    manifest decodes to ``None`` so `Frame.mount` can no-op.

    The transition into a live `Chain` is one-way and lives on `Chain`:
    `Chain.from_manifest` drains `vars` into the backing variables,
    merges the decoded state onto a live receiver (or seeds a fresh chain
    when unarmed), and arms the result. `Frame.mount` is the entry point.
    """

    # Chain identifier, parsed from the wire ``id`` hex string or freshly
    # generated when the wire id is empty.
    id: UUID
    # Decoded variable bindings, keyed by the resolved variable object.
    vars: dict[ContextVarManifest[Any], Any]
    # ``(namespace, name)`` keys whose wire entry carried no value.
    resets: frozenset[tuple[str, str]]
    # Variables held strongly by this manifest: still-undeclared stubs
    # plus the variable behind every reset-only entry.
    stubs: frozenset[ContextVarManifest[Any]]

    @classmethod
    def empty(cls) -> ChainManifest:
        """Return a fresh empty manifest carrying a new chain id.

        The default for `DispatchSession.decoded` when the initial
        dispatch frame carries no chain manifest. A present-but-empty
        manifest keeps ``session.decoded.vars`` an empty dict so the
        backpressure hook's attribute access stays shape-consistent
        whether or not the inbound frame carried chain state. Returned
        fresh per call to avoid sharing mutable state across dispatches.
        """
        return cls(
            id=uuid4(),
            vars={},
            resets=frozenset(),
            stubs=frozenset(),
        )

    @classmethod
    def from_protobuf(
        cls,
        wire: _protocol.ChainManifest,
        *,
        serializer: _Serializer,
    ) -> ChainManifest:
        """Decode a wire `~wool.protocol.ChainManifest` into a manifest.

        Pure decode — resolves variable identities and deserialises
        values but never touches a `contextvars.Context`. Walks
        ``wire.vars`` once: each variable identity resolves through the
        process-wide `wool.ContextVar` registry (or registers a
        stub if undeclared, pinned into the returned manifest). An
        entry carrying a ``value`` populates `vars`; an
        entry with no ``value`` records a reset-to-no-value signal in
        `resets`.

        Per-entry decode failures emit `SerializationWarning` and the
        offending entry is skipped — surviving entries decode normally.
        A duplicate key emits the same warning and the later occurrence
        is ignored. Under strict mode the failures aggregate into a
        single `ChainSerializationError` raised after the decode loop —
        no partial manifest is surfaced. `Frame.from_protobuf` captures
        that error as the frame's ``chain_manifest`` value instead of
        letting it propagate, so `Frame.mount` (and, for the initial
        dispatch frame, `DispatchSession.__aenter__`) can raise it or
        walk-and-append it onto an exception payload's ``__context__``
        chain rather than preempting the payload.

        The chain id is handled separately from the entries. An *empty*
        wire id falls back to a fresh UUID. A non-empty id that does not
        parse as hex raises `ChainSerializationError` immediately,
        whether or not strict mode is active; there is no fallback for a
        malformed id.

        :param wire:
            The wire message to decode.
        :param serializer:
            Used to ``loads`` each entry's ``value`` bytes.
        :returns:
            The decoded manifest.
        :raises ChainSerializationError:
            Always when the wire chain id is non-empty and malformed;
            otherwise only under strict mode, when one or more entries
            fail to decode or are duplicated.
        """
        failures: list[SerializationWarning] = []
        # Chain-id parse failure is a *structural* protocol
        # error, distinct from per-var data errors. Raise
        # unconditionally regardless of strict-mode warning filter:
        # without a valid chain id the receiver cannot correlate
        # subsequent frames against the same logical caller, and a
        # silently-replaced ``uuid4()`` would route follow-up frames
        # to a fresh cached contextvars.Context (silent state loss).
        try:
            chain_id = UUID(hex=wire.id) if wire.id else uuid4()
        except ValueError as e:
            raise ChainSerializationError(
                SerializationWarning(
                    f"Failed to decode chain id {wire.id!r}: {e}",
                    cause=e,
                    direction="decode",
                ),
            ) from e
        vars: dict[ContextVarManifest[Any], Any] = {}
        resets: set[tuple[str, str]] = set()
        stubs: set[ContextVarManifest[Any]] = set()
        failed_keys: set[tuple[str, str]] = set()
        seen_keys: set[tuple[str, str]] = set()
        for entry in wire.vars:
            var_key = (entry.namespace, entry.name)
            # Duplicate keys on the wire are explicitly undefined
            # behaviour; the first occurrence wins. A value-failed
            # first occurrence pins the key in ``failed_keys``, so a
            # later legitimate reset signal for the same key is
            # dropped. Emit a typed warning so the encoder bug surfaces
            # without disturbing the first-occurrence-wins semantics.
            if var_key in seen_keys:
                # ``warnings.warn`` raises the warning instance itself
                # when the active filter turns it into an error; that is
                # how a strict-mode failure is collected here.
                try:
                    warnings.warn(
                        SerializationWarning(
                            f"Duplicate wool.ContextVar key {var_key!r} in wire "
                            f"context — second occurrence ignored",
                            var_key=var_key,
                            direction="decode",
                        ),
                        stacklevel=2,
                    )
                except SerializationWarning as raised:
                    failures.append(raised)
                continue
            seen_keys.add(var_key)
            var = resolve_stub(var_key)
            if var._stub:
                stubs.add(var)
            if entry.HasField("value"):
                try:
                    vars[var] = serializer.loads(entry.value)
                except Exception as e:
                    failed_keys.add(var_key)
                    try:
                        warnings.warn(
                            SerializationWarning(
                                f"Failed to deserialize wool.ContextVar "
                                f"{var_key!r}: {e}",
                                cause=e,
                                var_key=var_key,
                                direction="decode",
                            ),
                            stacklevel=2,
                        )
                    except SerializationWarning as raised:
                        failures.append(raised)
            if var_key in failed_keys:
                # A variable whose value failed to deserialize is absent
                # from ``vars``; recording its reset signal would
                # let a subsequent merge read a phantom reset and drop
                # a live binding on the receiver.
                continue
            if not entry.HasField("value"):
                # No value on a surviving entry: the variable is in a
                # reset-to-no-value state in the sender's chain.
                resets.add(var_key)
                # For every reset-only entry, ensure the
                # receiver-side ContextVar is pinned via ``stubs``
                # regardless of ``_stub`` status. Without this pin a
                # garbage-collectible non-stub instance can drop
                # between decode and apply, leaving
                # ``var_registry.get(key)`` returning ``None`` and the
                # reset silently swallowed.
                stubs.add(var)
        if failures:
            raise ChainSerializationError(*failures)
        return cls(
            id=chain_id,
            vars=vars,
            resets=frozenset(resets),
            stubs=frozenset(stubs),
        )

    def to_protobuf(
        self,
        *,
        serializer: _Serializer | None = None,
    ) -> _protocol.ChainManifest:
        """Serialise this manifest to a wire `~wool.protocol.ChainManifest`.

        The encode half of the send path: `Chain.to_manifest` captures the
        live backing values inline on this manifest, and this method turns
        them into wire bytes. Each entry in `vars` emits one
        `~wool.protocol.ContextVar` with a populated ``value``; each key in
        `resets` emits a value-less entry so the reset propagates
        regardless of source. The two source sets are disjoint by
        construction (asserted in `Chain.to_manifest`), so every key emits
        exactly one entry.

        Per-variable encode failures emit `SerializationWarning` and the
        offending key is suppressed entirely — its reset signal included,
        so the receiver cannot read a phantom reset for a variable whose
        value failed to ship. Under strict mode
        (``PYTHONWARNINGS=error::wool.SerializationWarning``) the failures
        aggregate into a single `ChainSerializationError` raised after the
        loop.

        Encoding is deterministic: entries are emitted in sorted key order,
        so identical manifest state encodes to byte-identical frames across
        processes — preserving content-addressed caching and replay-style
        fingerprinting despite hash-randomised set iteration.

        :param serializer:
            Used to ``dumps`` each value. Defaults to
            ``wool.__serializer__`` when ``None``.
        :returns:
            The populated wire message, with ``id`` set to this
            manifest's id in hex.
        :raises ChainSerializationError:
            Under strict mode, when one or more variables fail to encode.
        """
        if serializer is None:
            serializer = wool.__serializer__
        wire = _protocol.ChainManifest(id=self.id.hex)
        failures: list[SerializationWarning] = []
        encoded_values: dict[tuple[str, str], bytes] = {}
        failed_keys: set[tuple[str, str]] = set()
        # Encode every value first; wire entries are only added in the
        # sorted pass below.
        for var, value in self.vars.items():
            try:
                encoded_values[var._key] = serializer.dumps(value)
            except Exception as e:
                failed_keys.add(var._key)
                try:
                    warnings.warn(
                        SerializationWarning(
                            f"Failed to serialize wool.ContextVar {var._key!r}: {e}",
                            cause=e,
                            var_key=var._key,
                            direction="encode",
                        ),
                        stacklevel=2,
                    )
                except SerializationWarning as raised:
                    failures.append(raised)
        # Failed keys are suppressed entirely — no phantom resets (see the docstring).
        reset_keys = {key for key in self.resets if key not in failed_keys}
        # Sorted emission — the byte-determinism guarantee in the docstring.
        for var_key in sorted(set(encoded_values) | reset_keys):
            namespace, name = var_key
            entry = wire.vars.add(namespace=namespace, name=name)
            # A reset-only key leaves ``value`` unset; the absent field
            # is the reset signal.
            if var_key in encoded_values:
                entry.value = encoded_values[var_key]
        if failures:
            raise ChainSerializationError(*failures)
        return wire
- - Plain :class:`weakref.WeakValueDictionary` is unpicklable - (weakrefs reject pickle), and cloudpickle serializes bound - classmethods like :meth:`Token._reconstitute` by walking the - function's globals and capturing each name by value. Without - this override the by-value walk crashes on the registry; with - it, the registry reduces to a zero-arg lookup of this module's - :data:`token_registry` attribute, keeping the actual contents - process-local. - - The reduction is exposed only to Wool's pickler via - :meth:`__wool_reduce__`; vanilla :func:`pickle.dumps` and - :func:`cloudpickle.dumps` are rejected by :meth:`__reduce_ex__` - so the registry never silently leaves the dispatch path. - """ - - def __wool_reduce__(self) -> tuple[Callable[..., _TokenRegistry], tuple[()]]: - return (_resolve_token_registry, ()) - - def __reduce_ex__(self, _protocol: SupportsIndex) -> NoReturn: - raise TypeError( - "_TokenRegistry cannot be pickled via vanilla pickle/cloudpickle; " - "it is serialized automatically when dispatched through Wool's " - "runtime." - ) - - -# Process-wide weak registry of live :class:`Token` instances keyed -# by ID. Preserves pickle identity within a process and resolves -# incoming wire IDs to live tokens so their ``_used`` flag can be -# flipped on merge. Weak values auto-prune when a token is GC'd, so -# transient tokens from a ``set``/``reset`` loop do not accumulate. -token_registry: Final[_TokenRegistry] = _TokenRegistry() - - -def _resolve_token_registry() -> _TokenRegistry: - """Return the process-wide :class:`Token` registry. - - Module-level shim used by :meth:`_TokenRegistry.__wool_reduce__` - so cloudpickle's lookup-and-qualname path can pickle the registry - reference by name instead of by value. MUST stay at module level; - cloudpickle's by-reference lookup requires a stable qualname. 
- """ - return token_registry - - -class _ContextToken: - """Restore cookie returned by :meth:`_ContextRegistry.set` and - consumed by :meth:`_ContextRegistry.reset` to undo a - :class:`Context` installation. - - Mirrors :class:`contextvars.Token` in spirit: opaque to the - caller, carrying enough state to return the registry slot to - exactly where it was (including the "was-unset" case, which - :meth:`_ContextRegistry.reset` pops rather than rewriting to - ``None``). - """ - - __slots__ = ("_key", "_previous", "_used") - - _key: asyncio.Task[Any] | threading.Thread - _previous: Context | None - _used: bool - - def __init__( - self, - key: asyncio.Task[Any] | threading.Thread, - previous: Context | None, - ) -> None: - self._key = key - self._previous = previous - self._used = False - - -class _ContextRegistry( - weakref.WeakKeyDictionary["asyncio.Task[Any] | threading.Thread", "Context"] -): - """WeakKey registry of per-scope :class:`Context` bindings. - - Implements the standard :class:`MutableMapping` protocol — - ``__getitem__`` raises :class:`KeyError` on miss (no auto-create). - Overrides :meth:`get` and :meth:`setdefault` to default the *key* - argument to :func:`scope_key` when omitted (or passed as - :data:`Undefined`), so callers asking about "the current scope" - do not have to spell out the lookup. Adds two wool-specific - methods: :meth:`set` returns a :class:`_ContextToken` that - captures the previous slot state for later :meth:`reset` - restoration. - - Pickles by module-attribute reference under Wool's pickler. - Plain :class:`weakref.WeakKeyDictionary` is unpicklable - (weakrefs reject pickle), and cloudpickle serializes bound - classmethods like :meth:`Token._reconstitute` by walking the - function's globals and capturing each name by value. 
Without - the reduce hooks the by-value walk crashes on the registry; with - them, the registry reduces to a zero-arg lookup of this module's - :data:`context_registry` attribute, keeping the actual contents - process-local. - """ - - def get( # pyright: ignore[reportIncompatibleMethodOverride] - self, - key: asyncio.Task[Any] | threading.Thread | UndefinedType = Undefined, - default: Context | None = None, - ) -> Context | None: - if key is Undefined: - key = scope_key() - return super().get(key, default) # pyright: ignore[reportArgumentType] - - def setdefault( # pyright: ignore[reportIncompatibleMethodOverride] - self, - key: asyncio.Task[Any] | threading.Thread | UndefinedType = Undefined, - default: Context | None = None, - ) -> Context | None: - if key is Undefined: - key = scope_key() - return super().setdefault(key, default) # pyright: ignore[reportArgumentType] - - def set(self, ctx: Context) -> _ContextToken: - with lock: - key = scope_key() - previous = self.get(key) - self[key] = ctx - return _ContextToken(key, previous) - - def reset(self, token: _ContextToken) -> None: - if token._used: - raise RuntimeError("token already consumed by reset") - token._used = True - with lock: - if token._previous is None: - self.pop(token._key, None) - else: - self[token._key] = token._previous - - def __wool_reduce__(self) -> tuple[Callable[..., _ContextRegistry], tuple[()]]: - return (_resolve_context_registry, ()) - - def __reduce_ex__(self, _protocol: SupportsIndex) -> NoReturn: - raise TypeError( - "_ContextRegistry cannot be pickled via vanilla pickle/cloudpickle; " - "it is serialized automatically when dispatched through Wool's " - "runtime." - ) - - -context_registry: Final[_ContextRegistry] = _ContextRegistry() - - -def _resolve_context_registry() -> _ContextRegistry: - """Return the process-wide :class:`Context` registry. 
- - Module-level shim used by :meth:`_ContextRegistry.__wool_reduce__` - so cloudpickle's lookup-and-qualname path can pickle the registry - reference by name instead of by value. MUST stay at module level; - cloudpickle's by-reference lookup requires a stable qualname. - """ - return context_registry - - -def scope_key() -> asyncio.Task[Any] | threading.Thread: - """Identify the current execution scope (asyncio task, or - thread for sync callers). - """ - try: - return asyncio.current_task() or threading.current_thread() - except RuntimeError: - return threading.current_thread() +# Process-wide identity map for ``(namespace, name) → ContextVar``. +# Keyed values are ``ContextVarManifest`` — the identity/data base — so a +# still-undeclared stub and a promoted ``ContextVar`` both register here. +# Weak values so unreferenced stubs are reclaimed when the embedding +# object graph drops them; ``lock`` above guards concurrent registration +# so two threads cannot observe an intermediate state. +var_registry: Final[ + weakref.WeakValueDictionary[tuple[str, str], ContextVarManifest[Any]] +] = weakref.WeakValueDictionary() diff --git a/wool/src/wool/runtime/context/runtime.py b/wool/src/wool/runtime/context/runtime.py new file mode 100644 index 00000000..cbc2eaf5 --- /dev/null +++ b/wool/src/wool/runtime/context/runtime.py @@ -0,0 +1,144 @@ +"""Block-scoped runtime option overrides. + +Houses `RuntimeContext` — the user-facing context manager +that overrides per-chain runtime options (currently `dispatch_timeout`) +for the duration of a ``with`` block — and the ambient +`dispatch_timeout` `contextvars.ContextVar` it scopes. + +Distinct from `wool.runtime.context.chain` (the +`Chain` chain-state model) and +`wool.runtime.context.manifest` (the chain-manifest codec): runtime +options are a separate concern from Wool's chain-context machinery. 
+`dispatch_timeout` is a plain stdlib `contextvars.ContextVar` +— it does *not* ride the Wool-owned `wool.__chain__` variable +and is therefore unaffected by `Chain.mount`. Each dispatch +encodes its current value via `RuntimeContext` and the worker +re-installs that value before running the routine. +""" + +from __future__ import annotations + +import contextvars +from typing import Final + +from wool import protocol +from wool.runtime.typing import Undefined +from wool.runtime.typing import UndefinedType + +# Ambient per-chain dispatch timeout in seconds. ``None`` means no +# timeout. The value scopes to whichever execution chain is currently +# active and rides through nested dispatches until reset or +# overridden. This is a plain stdlib `contextvars.ContextVar` — +# distinct from the Wool-owned context variable that carries +# `wool.ContextVar` state. +dispatch_timeout: Final[contextvars.ContextVar[float | None]] = contextvars.ContextVar( + "dispatch_timeout", default=None +) + + +# public +class RuntimeContext: + """Block-scoped runtime option overrides for wool routines. + + Used as a context manager to override runtime options (currently + only `dispatch_timeout`) for the duration of a block. Auto- + captured on every `Task` at construction time and serialised + onto each dispatch so the worker restores the caller's effective + `dispatch_timeout` before running the routine. + + :param dispatch_timeout: + Default timeout for task dispatch operations as a positive + ``float``. Omit (or pass ``None``) to inherit the surrounding + scope's `dispatch_timeout`. ``None`` and ``Undefined`` are + observationally equivalent on the wire: `to_protobuf` + substitutes the live `dispatch_timeout` at encode time on + inherit, and omits + the field when the resolved value is ``None``. So a bare + ``RuntimeContext()`` constructed for wire transport still + propagates the encoder's effective timeout to the receiver. 
+ """ + + _dispatch_timeout: float | UndefinedType + _dispatch_timeout_token: contextvars.Token[float | None] | None + + def __init__( + self, + *, + dispatch_timeout: float | None | UndefinedType = Undefined, + ) -> None: + # Collapse ``None`` to ``Undefined`` on input. The two + # were already observationally equivalent on the wire; this + # extends the equivalence to ``__enter__`` so a caller passing + # explicit ``None`` no longer surprisingly overrides the + # surrounding scope's timeout to ``None``. + self._dispatch_timeout = ( + Undefined if dispatch_timeout is None else dispatch_timeout + ) + self._dispatch_timeout_token = None + + def __enter__(self) -> RuntimeContext: + # Block-scoped, single-use: re-entering would overwrite the + # outer token without releasing it, leaking the original + # ``dispatch_timeout`` binding for the chain's lifetime. + if self._dispatch_timeout_token is not None: + raise RuntimeError( + "RuntimeContext is already active in a `with` block; " + "instances are block-scoped and single-use as context " + "managers" + ) + if self._dispatch_timeout is not Undefined: + self._dispatch_timeout_token = dispatch_timeout.set(self._dispatch_timeout) + return self + + def __exit__(self, *_): + if self._dispatch_timeout_token is not None: + dispatch_timeout.reset(self._dispatch_timeout_token) + self._dispatch_timeout_token = None + + @classmethod + def get_current(cls) -> RuntimeContext: + """Capture the current stdlib `dispatch_timeout` value.""" + return cls(dispatch_timeout=dispatch_timeout.get()) + + @classmethod + def from_protobuf(cls, context: protocol.RuntimeContext) -> RuntimeContext: + """Reconstruct from a `protocol.RuntimeContext` message. + + Mirrors `to_protobuf`'s omit-when-None semantic: an + absent ``dispatch_timeout`` field on the wire decodes to the + `Undefined` sentinel rather than explicit ``None`` so + the receiver's existing scope inherits through unchanged. 
+ Decoding the absent field as explicit ``None`` would force + the receiver's `dispatch_timeout` to ``None`` on + `__enter__`, silently overriding whatever timeout the + receiver's scope had set — a round-trip non-identity for the + "encoder had no timeout, receiver had a scope timeout" case. + """ + return cls( + dispatch_timeout=( + context.dispatch_timeout + if context.HasField("dispatch_timeout") + else Undefined + ) + ) + + def to_protobuf(self) -> protocol.RuntimeContext: + """Serialize to a `protocol.RuntimeContext` message. + + When the instance carries no explicit ``dispatch_timeout`` + (omitted or ``None``; `__init__` collapses both to the sentinel), + the live `dispatch_timeout` value from the current scope is + captured at encode time and rides the wire. The field is omitted + only when that resolved value is ``None``, so the receiver then + inherits its own scope's default. + """ + message = protocol.RuntimeContext() + timeout = self._dispatch_timeout + if timeout is Undefined: + timeout = dispatch_timeout.get() + if timeout is not None: + message.dispatch_timeout = timeout + return message + + +__all__ = ["RuntimeContext", "dispatch_timeout"] diff --git a/wool/src/wool/runtime/context/stub.py b/wool/src/wool/runtime/context/stub.py deleted file mode 100644 index f45869fa..00000000 --- a/wool/src/wool/runtime/context/stub.py +++ /dev/null @@ -1,117 +0,0 @@ -from __future__ import annotations - -import weakref -from typing import TYPE_CHECKING -from typing import Any - -from wool.runtime.context.registry import lock -from wool.runtime.context.registry import var_registry -from wool.runtime.typing import Undefined - -if TYPE_CHECKING: - from wool.runtime.context.base import Context - from wool.runtime.context.var import ContextVar - - -# Weakly indexes :class:`StubPin` instances by var key for O(1) -# release on promotion. Entries auto-prune when a pin has no strong -# refs (all pinning :class:`Context` instances have died without -# promotion occurring).
-_stub_pin_anchors: weakref.WeakValueDictionary[tuple[str, str], StubPin] = ( - weakref.WeakValueDictionary() -) - - -class StubPin: - """Severable anchor that keeps a stub :class:`ContextVar` alive. - - Held strongly by each :class:`Context` that observed the stub's - creation (via :attr:`Context._stub_pins`), and weakly indexed by - var key in :data:`_stub_pin_anchors`. Because the only path from - a :class:`Context` to its pinned stub goes through this anchor, - promotion can release the stub in O(1) by nulling :attr:`stub` - — live :class:`Context` instances retain the gutted anchor until - they are themselves collected, but the stub itself is free to - be reclaimed as soon as user references drop. - """ - - __slots__ = ("stub", "__weakref__") - - stub: ContextVar[Any] | None - - def __init__(self, stub: ContextVar[Any]) -> None: - self.stub = stub - - -def pin_stub(stub: ContextVar[Any], ctx: Context) -> None: - """Pin a freshly reconstructed stub to *ctx* and the global index.""" - anchor = StubPin(stub) - _stub_pin_anchors[stub._key] = anchor - ctx._stub_pins.add(anchor) - - -def release_stub(key: tuple[str, str]) -> None: - """Release the stub pin for *key* so the stub can be reclaimed. - - Pops the anchor from the global index and severs its strong - reference to the stub; any live :class:`Context` still holding - the (now gutted) anchor in its pin set drops it naturally when - the :class:`Context` itself is collected. - """ - anchor = _stub_pin_anchors.pop(key, None) - if anchor is not None: - anchor.stub = None - - -def resolve_stub( - key: tuple[str, str], - ctx: Context, - *, - default: Any = Undefined, -) -> ContextVar[Any]: - """Return the :class:`ContextVar` registered under *key*, creating a - stub pinned to *ctx* if no authoritative declaration exists yet. 
- - Unifies the two ingress paths that may encounter an unregistered - var key on a receiving process: the pickle-embedded :class:`ContextVar` - instance path (via :meth:`ContextVar._reconstitute`) and the wire - snapshot path (via :meth:`Context.from_protobuf`). Both route - through this helper so a lazy-import receiver converges on the - same state regardless of whether the value arrived as a bare - wire entry or embedded in a pickled var reference. - - Pass *default* to seed the stub's default before promotion when - that information is available on the ingress side (the pickle - path carries it; the wire-snapshot path does not). - """ - # Local import breaks the wool.runtime.context.var ↔ - # wool.runtime.context.stub cycle: var imports stub helpers at - # module level, so stub must defer its ContextVar import until - # both modules have finished loading. - from wool.runtime.context.var import ContextVar - - with lock: - existing = var_registry.get(key) - if existing is not None: - # Fold the supplied default into a default-less stub: the - # wire-snapshot path supplies no default (wire bytes don't - # carry it), but the pickle-embedded path does. Whichever - # ingress encounters the key second must not silently - # discard a known default. - if ( - existing._stub - and existing._default is Undefined - and default is not Undefined - ): - existing._default = default - return existing - namespace, name = key - stub: ContextVar[Any] = object.__new__(ContextVar) - stub._name = name - stub._namespace = namespace - stub._key = key - stub._default = default - stub._stub = True - var_registry[key] = stub - pin_stub(stub, ctx) - return stub diff --git a/wool/src/wool/runtime/context/threading.py b/wool/src/wool/runtime/context/threading.py new file mode 100644 index 00000000..b53ce5e8 --- /dev/null +++ b/wool/src/wool/runtime/context/threading.py @@ -0,0 +1,116 @@ +"""Wool-aware thread offload. 
+ +`to_thread` is the supported way to run blocking work on a +worker thread from an armed Wool chain. It forks the caller's chain +onto a fresh, **detached** chain owned by the worker thread, then +runs the offloaded callable under the forked context — preserving +`wool.ContextVar` bindings without sharing chain identity. + +The fork is the keystone: plain `asyncio.to_thread` copies the +caller's `contextvars.Context` verbatim, leaving the worker +thread on the *same* logical Wool chain. The first +`wool.ContextVar` access on the worker thread then trips the +chain-contention guard and raises `wool.ChainContention`. +`to_thread` here forks instead — a fresh chain UUID, owner- +stamped for the worker thread, with no merge-back path to the caller. + +Lives in its own module so `wool.runtime.context.factory` +retains its narrow charter — the task factory and its displacement +bookkeeping — and the thread-offload surface stays discoverable in +its own right. +""" + +from __future__ import annotations + +import asyncio +import contextvars +from typing import Any +from typing import Callable +from typing import TypeVar + +import wool + +T = TypeVar("T") + + +# public +async def to_thread( + func: Callable[..., T], + /, + *args: Any, + **kwargs: Any, +) -> T: + """Offload a blocking call to a worker thread on a fresh Wool chain. + + Mirrors `asyncio.to_thread` — runs *func* in the event + loop's default executor and awaits the result — but additionally + forks Wool chain state: the worker thread runs under a freshly + minted, **detached** chain. The offloaded function sees a copy of + the caller's `wool.ContextVar` bindings under a new chain + UUID owned by the worker thread; mutations it makes do not + propagate back to the caller (no merge-back). 
+ + When called from an unarmed context (no `wool.ContextVar` + ever set, no incoming chain manifest merged), the worker thread + inherits the caller's plain `contextvars` context with no + Wool chain — there is no chain to fork — and behaves exactly like + `asyncio.to_thread` plus a no-op. + + This is the supported way to run Wool-aware work in another OS + thread. Plain `asyncio.to_thread` from an armed context + copies the caller's chain verbatim into the worker thread, placing + a second runner on one chain in genuine parallelism; the worker + thread's first `wool.ContextVar` access then trips the + chain-contention guard and raises + `wool.ChainContention`. Use this function instead. + + See ``wool/src/wool/runtime/context/README.md`` for the model + context. + + **Re-handoff is undefined behaviour.** A `to_thread` chain + is owner-stamped for the worker thread and detached on purpose; + handing the worker-side `contextvars.Context` back to the + caller and re-driving it (e.g., ``loop.create_task(coro, + context=captured)``) is unsupported. The behaviour is not silently + incorrect — it will fail loudly only at the chain's next + `wool.ContextVar` access — but the failure is not + explicitly designed for and the diagnostic surfaces no clearer + than the underlying guard raise. + + :param func: + The blocking callable to offload. + :param args: + Positional arguments forwarded to *func*. + :param kwargs: + Keyword arguments forwarded to *func*. + :returns: + The value returned by *func*. + """ + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + + def _run() -> T: + # Runs inside the worker thread under the copied context. The + # fork must happen here, on the executor thread, not before + # copy_context() on the loop thread: Chain._fork stamps + # ``thread`` with threading.get_ident(), so forking any + # earlier would mint the fork owned by the loop thread and the + # worker's first wool.ContextVar access would self-trip the + # guard. 
Forked here, the offloaded chain is detached and owned + # by the worker, with no path back to the caller's chain. The + # copied context is deliberately not registered in + # ``_task_contexts`` (unlike a task-factory creation): it never + # escapes ``_run``, so it can never be re-passed to a + # create_task and need the guard. + context = wool.__chain__.get(None) + if context is not None: + # Route the fork through mount so the forked chain's + # owner-thread is stamped to this worker thread, not the + # loop thread that called copy_context(). + context._fork().mount() + return func(*args, **kwargs) + + return await loop.run_in_executor(None, ctx.run, _run) + + +__all__ = ["to_thread"] diff --git a/wool/src/wool/runtime/context/token.py b/wool/src/wool/runtime/context/token.py deleted file mode 100644 index 3ae6ff91..00000000 --- a/wool/src/wool/runtime/context/token.py +++ /dev/null @@ -1,297 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import ClassVar -from typing import Generic -from typing import NoReturn -from typing import SupportsIndex -from typing import TypeVar -from uuid import UUID -from uuid import uuid4 - -from wool.runtime.context.registry import context_registry -from wool.runtime.context.registry import token_registry -from wool.runtime.context.registry import var_registry -from wool.runtime.typing import Undefined -from wool.runtime.typing import UndefinedType - -if TYPE_CHECKING: - from wool.runtime.context.base import Context - from wool.runtime.context.var import ContextVar - -T = TypeVar("T") - - -# public -class Token(Generic[T]): - """Picklable token for reverting a :class:`ContextVar` mutation. - - Mirrors :class:`contextvars.Token`: single-use, same-var - rejection, and scoped to the :class:`wool.Context` in which it - was created. 
Attempting to :meth:`ContextVar.reset` with a token - minted in a different :class:`wool.Context` raises - :class:`ValueError`. Same-Context identity is checked by - comparing the :class:`wool.Context` UUID — the canonical chain - identity that holds across in-process and cross-process - boundaries. - - On construction the token registers itself in the process-wide - :data:`token_registry` so a same-process pickle round-trip - resolves back to this instance and wire-supplied consumed-token - UUIDs can be matched against live tokens for used-flag promotion. - - :param var: - The :class:`ContextVar` whose mutation this token reverts. - Only its key is captured — the var instance itself is - resolved on demand at :meth:`ContextVar.reset` time via - :data:`var_registry`. - :param old_value: - The value to restore on :meth:`ContextVar.reset`. - :data:`~wool.runtime.typing.Undefined` (also exposed as - :attr:`Token.MISSING`) signals the var was unset before - the corresponding :meth:`ContextVar.set`, so reset pops - the var from the :class:`Context`'s data rather than - restoring a value. - :param context: - The :class:`wool.Context` in which the corresponding - :meth:`ContextVar.set` ran. Only its UUID is retained — - the live Context reference is intentionally not held so - long-lived tokens do not pin their originating Context. 
- """ - - __slots__ = ( - "_id", - "_key", - "_old_value", - "_context_id", - "_used", - "__weakref__", - ) - - _id: UUID - _key: tuple[str, str] - _old_value: T | UndefinedType - _context_id: UUID - _used: bool - - MISSING: ClassVar[UndefinedType] = Undefined - - def __init__( - self, - var: ContextVar[T], - old_value: T | UndefinedType, - context: Context, - ): - self._id = uuid4() - self._key = var._key - self._old_value = old_value - self._context_id = context._id - self._used = False - token_registry[self._id] = self - - def __wool_reduce__(self) -> tuple[Callable[..., Token[T]], tuple[Any, ...]]: - """Return constructor args for unpickling via Wool's pickler. - - Token's transport carries identity (``_id``, ``_key``, - ``_context_id``) plus the consumed-state bit and old value - needed for cross-process reset. Token is guarded against - vanilla pickling (see :meth:`__reduce_ex__`); this method - is invoked only by Wool's own pickler. - """ - return ( - Token._reconstitute, - ( - self._id, - self._key, - self._old_value, - self._context_id, - self._used, - ), - ) - - def __reduce_ex__(self, _protocol: SupportsIndex) -> NoReturn: - """Reject vanilla pickling. - - Token reset semantics are scoped to a live :class:`wool.Context` - and the process-wide :data:`token_registry`, neither of which - is reconstructible from a vanilla pickle taken outside the - dispatch path. Wool's own pickler consults - ``reducer_override`` (and therefore :meth:`__wool_reduce__`) - before ``__reduce_ex__``, so this guard is invisible to - Wool's serialization. - - :func:`copy.copy` and :func:`copy.deepcopy` also route - through ``__reduce_ex__`` and are rejected for the same - reason — a registry-bound Token has no meaningful copy - semantics. - - :raises TypeError: - Always. - """ - raise TypeError( - "wool.Token cannot be pickled via vanilla pickle/cloudpickle; " - "it is serialized automatically when dispatched through Wool's " - "runtime." 
- ) - - def __repr__(self) -> str: - used_marker = " used" if self._used else "" - return f"" - - @property - def id(self) -> UUID: - """The UUID that identifies this :class:`Token` on the wire.""" - return self._id - - @property - def var(self) -> ContextVar[T]: - """Return the :class:`ContextVar` this token was created for. - - Resolves the stored key against the process-wide - :data:`var_registry`. Raises :class:`KeyError` if the key - is not registered locally — typically a cross-process token - whose owning :class:`ContextVar` was never declared on this - side; the caller can declare the var first or use the - wire-snapshot ingress path that pins a stub eagerly. - """ - return var_registry[self._key] - - @property - def old_value(self) -> T | UndefinedType: - """The prior value the var held before the :meth:`ContextVar.set` - call that produced this token. Returns :attr:`Token.MISSING` - if the var had no value set. - - :attr:`Token.MISSING` is a singleton; check for it via - identity (``token.old_value is Token.MISSING``) rather than - :func:`isinstance` since the underlying sentinel type is - internal and not part of the public API. - """ - return self._old_value - - @property - def used(self) -> bool: - """Whether this :class:`Token` has been consumed by :meth:`ContextVar.reset`. - - ``True`` once a successful :meth:`ContextVar.reset` call has - passed this token, whether the call occurred in this process - or in another process running in a :class:`Context` whose - used-token state has since been merged back (via pickle - round-trip or back-propagation into this token via the - process-wide token registry). Tokens are single-use across - the logical chain; a second :meth:`ContextVar.reset` raises - :class:`RuntimeError`. 
- """ - return self._used - - @staticmethod - def _promote_external_used( - active: Context | None, instance: Token[T], token_id: UUID - ) -> None: - """Migrate *token_id* from :attr:`Context._external_used_tokens` - into :attr:`Context._used_tokens` when *active* carries it under - the former. - - The active :class:`Context`'s external map is authoritative for - consumed-state in this chain — a wire-supplied entry only lands - there when an upstream :meth:`Context.to_protobuf` listed it - under ``consumed_tokens``. Promoting the entry into the auto- - pruning :class:`weakref.WeakSet` engages lifetime cleanup once - the live :class:`Token` is reclaimed, and brings :attr:`_used` - up to date even when the pickle's wire bit is stale. - """ - if active is None: - return - if token_id not in active._external_used_tokens: - return - if not instance._used: - instance._used = True - active._external_used_tokens.pop(token_id, None) - active._used_tokens.add(instance) - - @staticmethod - def _sync_state(token: Token[T], wire_used: bool) -> None: - """Monotonically advance *token*'s ``_used`` flag to match a - wire payload's used flag. - - :attr:`Token._used` is a one-way bit (``False → True``); a - wire payload reporting ``used=True`` for an id whose - registry instance still reads ``_used=False`` indicates the - token was consumed in a snapshot taken after the registry - instance was last updated. Bringing the live flag up to - date keeps the registry coherent across out-of-order pickle - round-trips — notably, the case where a user pickles a - token before and after reset, the original is GC'd, and the - older snapshot is unpickled first (registering a stub with - ``_used=False``) before the newer snapshot is unpickled. - Without this catch-up the registry would stay at - ``_used=False`` and a subsequent :meth:`ContextVar.reset` - would succeed against an already-consumed token. 
- - In the dispatch pipeline this is typically a no-op because - :meth:`Context.update` runs before result decode and flips - ``_used`` first via the wire's ``consumed_tokens`` field; - by the time :meth:`_reconstitute` sees a token, the - registry instance already reflects the consumed state. - """ - if wire_used and not token._used: - token._used = True - - @classmethod - def _reconstitute( - cls, - token_id: UUID, - key: tuple[str, str], - old_value: T | UndefinedType, - context_id: UUID, - used: bool, - ) -> Token[T]: - """Rebuild a :class:`Token` from externally-supplied parts. - - Same-process pickle identity is preserved via - :data:`token_registry`: if a live :class:`Token` already exists - under *token_id*, it is returned with its ``_used`` flag synced - via :meth:`_sync_state` so an out-of-order pickle round-trip - does not strand the registry at a stale ``_used=False``. - Cross-process callers see a fresh stub whose ``_used`` flag - is adopted from the pickled tuple. The token stores the var - *key* directly; the owning :class:`ContextVar` is resolved - on demand via the :attr:`var` property. - - When the active :class:`Context` carries this token's UUID - in :attr:`Context._external_used_tokens` — typically because - the consumed UUID arrived on an earlier wire frame before - the Token instance itself — the fresh instance is promoted - into :attr:`Context._used_tokens` so subsequent emissions - and merges go through the auto-pruning weak-set path. - """ - existing: Token[T] | None = token_registry.get(token_id) - if existing is not None: - cls._sync_state(existing, used) - instance = existing - else: - candidate: Token[T] = object.__new__(cls) - candidate._id = token_id - candidate._key = key - candidate._old_value = old_value - candidate._context_id = context_id - candidate._used = used - # Single-instance registry claim. 
Single-threaded - # asyncio per worker plus the single-task-per-Context - # invariant means only one task unpickles a given - # token's bytes at a time, so concurrent ``setdefault`` - # calls cannot occur in steady state. ``setdefault`` is - # used over plain assignment to express second-caller- - # wins (first candidate wins the slot; later callers - # receive the winner). Note that - # ``WeakValueDictionary.setdefault`` is not GIL-atomic - # — its check-then-insert is pure Python — so coherence - # rests on the architectural invariant rather than on - # dict atomicity. - instance = token_registry.setdefault(token_id, candidate) - if instance is not candidate: - cls._sync_state(instance, used) - cls._promote_external_used(context_registry.get(), instance, token_id) - return instance diff --git a/wool/src/wool/runtime/context/var.py b/wool/src/wool/runtime/context/var.py index 83731604..92e26353 100644 --- a/wool/src/wool/runtime/context/var.py +++ b/wool/src/wool/runtime/context/var.py @@ -1,21 +1,22 @@ from __future__ import annotations +import contextvars import inspect from typing import Any from typing import Callable from typing import Final -from typing import Generic -from typing import NoReturn -from typing import SupportsIndex from typing import TypeVar +from typing import cast from typing import overload -from wool.runtime.context.base import current_context +import wool +from wool.runtime.context.chain import Chain +from wool.runtime.context.exceptions import ContextVarCollision +from wool.runtime.context.guard import assert_chain_owner +from wool.runtime.context.manifest import ContextVarManifest +from wool.runtime.context.manifest import resolve_stub from wool.runtime.context.registry import lock from wool.runtime.context.registry import var_registry -from wool.runtime.context.stub import release_stub -from wool.runtime.context.stub import resolve_stub -from wool.runtime.context.token import Token from wool.runtime.typing import Undefined from 
wool.runtime.typing import UndefinedType @@ -25,69 +26,79 @@ # public -class ContextVarCollision(Exception): - """Raised when two distinct :class:`ContextVar` instances are - constructed with the same ``(namespace, name)`` key. - - Keys must be unique within the inferred package namespace. Library - authors should pass ``namespace=`` explicitly when constructing - vars from shared factory code; application code can rely on the - implicit package-name inference. - """ - - -# public -class ContextVar(Generic[T]): +class ContextVar(ContextVarManifest[T]): """Propagating context variable that crosses worker boundaries. - Mirrors :class:`contextvars.ContextVar` at the surface: construct - with a name and optional default; call :meth:`get`, :meth:`set`, - :meth:`reset`. Unlike :class:`contextvars.ContextVar`, instances + Mirrors `contextvars.ContextVar` across the `get` / `set` / + `reset` call shapes: construct with a name and optional default, + then call `get`, `set`, `reset`. Two deliberate divergences bound + that parity. First, `get`, `set`, and `reset` additionally raise + `wool.ChainContention` when a chain is entered by a thread or + asyncio task other than the one that owns it — stdlib + `contextvars.ContextVar` never raises it; see + `wool.ChainContention` for the full scenario catalogue. Second, + the `contextvars.Token` that `set` returns is process-local, and + its ``var`` attribute references the variable's internal backing + rather than this `ContextVar`, so ``wool_var.set(x).var is + wool_var`` is `False` where stdlib's is `True` — the supported + reset path is ``wool_var.reset(token)`` and ``token.var`` is not a + supported attribute. Unlike `contextvars.ContextVar`, instances pickle across process boundaries and their values propagate through ``@wool.routine`` dispatches. - **Identity model** — Every :class:`ContextVar` has a unique + **Identity model** — Every `ContextVar` has a unique ``(namespace, name)`` key. 
The ``name`` is the first positional argument; the ``namespace`` is inferred from the top-level package of the calling frame or provided explicitly via ``namespace=``. Two distinct instances constructed under the same key raise - :class:`ContextVarCollision`. + `ContextVarCollision`. **Namespace stability** — The inferred namespace is the top-level package of the calling frame. This is deliberately coarse so that wire keys stay stable when a module is refactored deeper within - its package — a rolling deploy that moves - ``myapp.auth.tokens`` to ``myapp.auth.credentials.tokens`` - continues to propagate values between caller and worker. The - trade-off is that two subpackages of the same library cannot - define distinct vars with the same ``name`` without one of them - passing ``namespace=`` explicitly; the construction raises - :class:`ContextVarCollision` instead. - - **Storage model** — Values are stored in the current - :class:`wool.Context` (one per :class:`asyncio.Task`, one per - thread for sync code) — separate state from the surrounding - :class:`contextvars.Context`. Child tasks fork a copy of the - parent's :class:`wool.Context` on creation when Wool's task - factory is installed on the running loop. + its package — a rolling deploy that moves ``myapp.auth.tokens`` + to ``myapp.auth.credentials.tokens`` continues to propagate + values between caller and worker. The trade-off is that two + subpackages of the same library cannot define distinct variables + with the same ``name`` without one of them passing ``namespace=`` + explicitly; the construction raises `ContextVarCollision` + instead. + + **Storage model** — Values ride in a single immutable + `~wool.runtime.context.chain.Chain` held in one + Wool-owned stdlib `contextvars.ContextVar`. 
Because the + Wool chain rides in stdlib ``contextvars``, ``wool.ContextVar`` + values propagate with stdlib visibility across every conformant + event loop and every cooperative asyncio scheduling edge — task + creation, ``call_soon``/``call_later``/``call_at``, + ``add_reader``/``add_writer``/``add_signal_handler``, + ``Future.add_done_callback``. The first `set` on a context + *arms* it: a chain UUID is minted and the chain-contention guard + engages. A context in which no `ContextVar` has been set + is unarmed and behaves as a plain `contextvars.Context`. + Once armed, the Wool-owned variable is a permanent member of the + `contextvars.Context`: a `contextvars.copy_context` + of an armed context carries one extra variable, and it stays even + after every `ContextVar` is reset. + + Child tasks fork a copy of the parent's context under a fresh + chain UUID when Wool's task factory is installed on the running + loop. Values propagated across the wire must be cloudpicklable. + Variable serialisation is dispatch-path-only — vanilla + ``pickle.dumps`` / ``cloudpickle.dumps`` / `copy.copy` / + `copy.deepcopy` raise `TypeError`; see + `__reduce_ex__`. """ - __slots__ = ( - "_name", - "_namespace", - "_key", - "_default", - "_stub", - "__weakref__", - ) - - _name: str - _namespace: str - _key: tuple[str, str] - _default: T | UndefinedType - _stub: bool + # No instance slots of its own — the data layer (the key, default, + # stub flag, and backing) lives on ``ContextVarManifest``. This + # empty-slots invariant is load-bearing: it keeps a manifest layout- + # compatible with ``ContextVar`` so ``promote`` can reassign + # ``__class__`` in place. Adding a slot here breaks that swap (see + # ``promote`` and its guard test). + __slots__ = () @overload def __new__( @@ -116,7 +127,7 @@ def __new__( namespace: str | None = None, default: T | UndefinedType = Undefined, ) -> ContextVar[T]: - """Resolve or construct the :class:`ContextVar` for *namespace:name*. 
+ """Resolve or construct the `ContextVar` for *namespace:name*. The lookup, the registry insert, and the new instance's observable state are all serialized under the registry lock @@ -127,16 +138,19 @@ def __new__( * No prior registration — a fresh instance is constructed and registered. - * A stub already registered (seeded by an earlier pickle-path - ingress before any user declaration) — promoted in place. + * A stub already registered (seeded by an earlier wire + ingress — pickle-embedded or chain-manifest — before any user + declaration) — promoted in place. An explicit ``default=`` wins; an implicit - :data:`~wool.runtime.typing.Undefined` preserves whatever + `~wool.runtime.typing.Undefined` preserves whatever default the stub already carries, mirroring - :func:`~wool.runtime.context.stub.resolve_stub`'s + `resolve_stub`'s "don't silently discard a known default" rule. - * A non-stub registration already exists — :class:`ContextVarCollision` + * A non-stub registration already exists — `ContextVarCollision` raises; keys must be unique within a namespace. """ + if not isinstance(name, str): + raise TypeError("context variable name must be a str") if namespace is None: namespace = _infer_namespace() key = (namespace, name) @@ -144,22 +158,16 @@ def __new__( existing = var_registry.get(key) if existing is not None: if existing._stub: + promoted = promote(existing) if default is not Undefined: - existing._default = default - existing._stub = False - release_stub(key) - return existing + promoted._default = default + return promoted raise ContextVarCollision( f"wool.ContextVar {key!r} is already registered " f"({existing!r}). Keys must be unique within a " f"namespace." 
) - instance = super().__new__(cls) - instance._name = name - instance._namespace = namespace - instance._key = key - instance._default = default - instance._stub = False + instance = cast(ContextVar[T], cls._build(key, default, stub=False)) var_registry[key] = instance return instance @@ -168,82 +176,46 @@ def __wool_reduce__( ) -> tuple[Callable[..., ContextVar[Any]], tuple[Any, ...]]: """Return constructor args for unpickling via Wool's pickler. - A :class:`ContextVar` is a key for resolving a value from - the active :class:`wool.Context`; its pickled state is - therefore the key plus the constructor default, never a - value snapshot. State propagation rides on the wire-context - path (:meth:`Context.to_protobuf` walks the sender's - ``_data``; :meth:`Context.from_protobuf` populates the - receiver's). The pickle path stays pure-identity so a - reconstituted var is a key only — `var.get()` on the - receiver resolves through the receiver's :class:`Context` - without the unpickle ever writing to it. - - ContextVar is guarded against vanilla pickling (see - :meth:`__reduce_ex__`); this method is invoked only by - Wool's own pickler. + A `ContextVar` is a key for resolving a value from the active + `~wool.runtime.context.chain.Chain`; its pickled state is + therefore the key plus the constructor default, never a captured + value. State propagation rides on the chain-manifest path + (`~wool.runtime.context.chain.Chain.to_manifest` snapshots the + sender's context and + `~wool.runtime.context.manifest.ChainManifest.to_protobuf` + encodes it; + `~wool.runtime.context.manifest.ChainManifest.from_protobuf` + decodes the wire frame and + `~wool.runtime.context.chain.Chain.from_manifest` populates the + receiver's context). The pickle path stays pure-identity so a + reconstituted variable is a key only — ``var.get()`` on the + receiver resolves through the receiver's context without the + unpickle ever writing to it. 
+ + The pickle surface lives on `ContextVar` rather than + `~wool.runtime.context.manifest.ContextVarManifest` because + unpickling reconstitutes a *usable* variable — the receiver calls + `get`/`set`/`reset` on it — so `_reconstitute` upgrades a bare + manifest to a behavioral `ContextVar`. The vanilla-pickle guard + (`__reduce_ex__`) lives on the base, covering both flavors. """ return ( ContextVar._reconstitute, (self._namespace, self._name, self._default), ) - def __reduce_ex__(self, _protocol: SupportsIndex) -> NoReturn: - """Reject vanilla pickling. - - ContextVar identity is registered against the process-wide - :data:`var_registry`; restoring an instance outside Wool's - dispatch path bypasses the stub-promotion and collision- - detection that :meth:`_reconstitute` relies on. Wool's own - pickler consults ``reducer_override`` (and therefore - :meth:`__wool_reduce__`) before ``__reduce_ex__``, so this - guard is invisible to Wool's serialization. - - :func:`copy.copy` and :func:`copy.deepcopy` also route - through ``__reduce_ex__`` and are rejected for the same - reason — a registry-bound ContextVar has no meaningful - copy semantics. - - :raises TypeError: - Always. - """ - raise TypeError( - "wool.ContextVar cannot be pickled via vanilla pickle/cloudpickle; " - "it is serialized automatically when dispatched through Wool's " - "runtime." - ) - - def __repr__(self) -> str: - default_part = ( - f" default={self._default!r}" if self._default is not Undefined else "" - ) - return ( - f"" - ) - - @property - def name(self) -> str: - """The variable's name, matching the :class:`contextvars.ContextVar` API.""" - return self._name - - @property - def namespace(self) -> str: - """The namespace this var belongs to.""" - return self._namespace - @overload def get(self) -> T: ... @overload def get(self, default: T, /) -> T: ... 
- # ``*args`` sentinel pattern mirrors :meth:`contextvars.ContextVar.get` — - # distinguishes "no default supplied" (raise :class:`LookupError`) from - # "default is :data:`None`" (return :data:`None`). The user-facing surface + # ``*args`` sentinel pattern mirrors `contextvars.ContextVar.get` — + # distinguishes "no default supplied" (raise `LookupError`) from + # "default is `None`" (return `None`). The user-facing surface # is constrained by the two ``@overload`` declarations above. def get(self, *args: T) -> T: - """Return the current value in the active :class:`wool.Context`. + """Return the current value in the active context. :param default: Optional fallback returned when the variable has no value @@ -251,73 +223,184 @@ def get(self, *args: T) -> T: :returns: The current value, the supplied fallback, or the constructor default. + :raises TypeError: + If more than one positional argument is supplied. :raises LookupError: If the variable has no value, no fallback, and no default. + :raises ChainContention: + If the active chain is being entered by a thread or asyncio + task other than the one that owns it. See + `wool.ChainContention` for full detail. """ - ctx = current_context() - try: - return ctx._data[self] - except KeyError: - if args: - return args[0] - if self._default is not Undefined: - return self._default - raise LookupError(self) - - def set(self, value: T) -> Token[T]: - """Set the variable's value in the active :class:`wool.Context`. + if len(args) > 1: + raise TypeError(f"get expected at most 1 argument, got {len(args)}") + # Resolve through the backing variable. ``Undefined`` — whether + # the backing variable was never set in this Chain or was + # reset/merged to the sentinel value — means "fall through to + # the default ladder" (get argument, constructor default, + # LookupError). 
Guard at the user-facing API boundary to + # match `set` and `reset` — the storage layer + # (raw stdlib contextvars.ContextVar) carries no guard, so + # ``self._backing.get`` is direct stdlib plumbing. + assert_chain_owner(wool.__chain__.get(None)) + value = self._backing.get(Undefined) + if value is not Undefined: + return value + if args: + return args[0] + if self._default is not Undefined: + return self._default + raise LookupError(self) + + def set(self, value: T) -> contextvars.Token[T | UndefinedType]: + """Set the variable's value in the active context. + + The first `set` on an unarmed context arms it — mints a + fresh chain UUID, installs the first context, and self-installs + Wool's task factory on the running loop (raising + `~wool.TaskFactoryDisplaced` if Wool's factory was + previously installed on the loop but has since been displaced + by a third-party factory installed after it). The value rides + in the variable's backing + `contextvars.ContextVar`; the context's ``vars`` index + gains an entry for this variable the first time it is bound in + the chain. The factory install is performed by + `~wool.runtime.context.chain.Chain.mount`, which this + method routes through on every set, but the user-visible effect + chain is the same. :param value: The new value. :returns: - A :class:`Token` usable with :meth:`reset` to restore - the previous value. + A `contextvars.Token` usable with `reset` to + restore the previous value. The token is process-local, and + its ``var`` attribute references the variable's internal + backing `contextvars.ContextVar` rather than this + `ContextVar` (see the class docstring); the supported reset + path is ``wool_var.reset(token)``. + :raises ChainContention: + If the active chain is being entered by a thread or asyncio + task other than the one that owns it. See + `wool.ChainContention` for full detail. 
""" - ctx = current_context() - old_value = ctx._data.get(self, Undefined) - ctx._data[self] = value - return Token(self, old_value, ctx) - - def reset(self, token: Token[T]) -> None: + # Enforce the chain-ownership invariant against the *currently + # installed* `Chain` before doing anything else. + # `Chain.mount` below unconditionally re-stamps the + # owning thread/task, so a guard check that runs after mount + # would silently transfer ownership to the calling thread — + # the exact corruption `ChainContention` is meant to + # surface. The guard lives at the user-facing API boundary + # (here, plus `get` and `reset`); the storage + # layer (raw stdlib contextvars.ContextVar) is plumbing. + assert_chain_owner(wool.__chain__.get(None)) + chain = wool.__chain__.get(None) + if chain is None: + # First set on this chain: arm it with a fresh chain + # UUID. ``Chain.mount`` below is the single point at + # which Wool's task factory self-installs on the running + # loop — every code path that arms a chain transits + # through here. + chain = Chain() + # Add this variable to the chain's vars index the first time + # it is bound in the chain, and clear any prior reset signal — + # the variable now carries a value again. When it is already + # indexed and not reset, the chain is unchanged. + if self not in chain.vars or self._key in chain.resets: + chain = chain._evolve( + vars=chain.vars | {self}, + resets=chain.resets - {self._key}, + ) + # Mount before mutating the backing so that a mount-time + # raise (e.g., ``TaskFactoryDisplaced``) rolls back cleanly + # without leaving the backing in a state inconsistent with the + # Wool Chain. Mount in the set path doesn't touch this + # variable's backing — the evolve above removed ``self._key`` + # from ``resets``, so mount's reset-drain loop skips it, so + # the later ``self._backing.set(value)`` operates on a freshly + # mounted-but-unmodified backing. 
+ chain.mount() + return self._backing.set(value) + + def reset(self, token: contextvars.Token[T | UndefinedType]) -> None: """Restore the variable to the value it had before *token*. - Matches :meth:`contextvars.ContextVar.reset` semantics, - scoped to the :class:`wool.Context`: the token must have - been created in the same :class:`wool.Context` as the one - currently active. Same-Context identity is checked by - comparing the :class:`wool.Context` UUID — the canonical - chain identity that holds across in-process and cross- - process boundaries. + Matches `contextvars.ContextVar.reset` semantics. The + reset delegates to the backing variable's native + `contextvars.ContextVar.reset`, so stdlib itself + enforces single-use, rejects a token whose ``var`` is not this + variable's backing, and rejects a token reset in a different + `contextvars.Context` than the one it was minted in. :param token: - A token previously returned by :meth:`set`. - :raises RuntimeError: - If the token has already been used. + A `contextvars.Token` previously returned by + `set`. ``token.var`` references the variable's + internal backing `contextvars.ContextVar` rather + than this `ContextVar` instance — supported reset + path is ``wool_var.reset(token)``, not direct stdlib + reset. :raises ValueError: If the token was created by a different - :class:`ContextVar` or in a different - :class:`wool.Context`. + `ContextVar` or in a different + `contextvars.Context` (surfaced by stdlib + `contextvars.ContextVar.reset`). + :raises RuntimeError: + If the token has already been used (surfaced by stdlib + single-use enforcement). + :raises ChainContention: + If the active chain is being entered by a thread or asyncio + task other than the one that owns it. See + `wool.ChainContention` for full detail. 
""" - if token._key != self._key: - raise ValueError("Token was created by a different ContextVar") - ctx = current_context() - if token._context_id != ctx._id: - raise ValueError("Token was created in a different wool.Context") - if token._used: - raise RuntimeError("Token has already been used") - token._used = True - # Track the live Token via :class:`weakref.WeakSet` so the ID - # is reclaimed automatically when the Token is collected. If - # this UUID arrived earlier as a wire-supplied external entry - # (e.g. propagated from a prior hop before the local Token - # materialized), promote it now so emission goes through the - # auto-pruning path. - ctx._used_tokens.add(token) - ctx._external_used_tokens.pop(token._id, None) - if token._old_value is Undefined: - ctx._data.pop(self, None) + # Enforce the chain-ownership invariant against the *currently + # installed* `Chain` before doing anything else. + # `Chain.mount` below unconditionally re-stamps the + # owning thread/task, so a guard check that runs after mount + # would silently transfer ownership to the calling thread. + # See `set` for the analogous guard placement. + assert_chain_owner(wool.__chain__.get(None)) + # Atomicity: run the native reset first. Stdlib + # ``ContextVar.reset`` is atomic — if it rejects (wrong var, + # wrong Context, already used) observable state is unchanged. + # Mounting first would break that contract: by the time + # stdlib raised, the Wool ``Chain`` would already have been + # evolved + installed and (for the unset case) the backing + # rewound to ``Undefined`` via the mount drain. + # Native-reset-first preserves stdlib parity: + # if the native reset raises, no Wool bookkeeping happens; + # if it succeeds, evolve + mount commit the Wool view. + self._backing.reset(token) + # Below this point the native reset has succeeded; mutate Wool + # bookkeeping. A subsequent mount-time raise (e.g. 
+ # ``TaskFactoryDisplaced``) does leave the backing in its + # restored state — but mount-time raises on the reset path + # are pathological (the chain was already armed; reset just + # rewound it) and the diagnostic surface intent of mount + # raising is unrelated to reset atomicity. + context = wool.__chain__.get(None) + if context is None: + # Unarmed context — native reset already raised the + # appropriate ValueError, so this branch is unreachable + # except via test mocks. Safe to no-op. + return + # ``token.old_value`` is `contextvars.Token.MISSING` when + # the variable was never set in this Chain (true first-set), + # ``Undefined`` (Wool's sentinel) when the backing was + # rewound by a prior ``mount`` drain, and the prior value + # otherwise. Both ``MISSING`` and ``Undefined`` mean "reset to + # no observable value" for the Wool ``Chain`` bookkeeping. + old_value = token.old_value + old_was_unset = old_value is contextvars.Token.MISSING or old_value is Undefined + if old_was_unset: + new_vars = context.vars - {self} + new_resets = context.resets | {self._key} else: - ctx._data[self] = token._old_value + new_vars = context.vars | {self} + new_resets = context.resets - {self._key} + evolved = context._evolve( + vars=new_vars, + resets=new_resets, + ) + evolved.mount() @classmethod def _reconstitute( @@ -326,24 +409,27 @@ def _reconstitute( name: str, default: Any, ) -> ContextVar[Any]: - """Rebuild or resolve a :class:`ContextVar` from externally- - supplied parts. - - Routes through :func:`resolve_stub` for the lookup-or-stub - path so the wire-snapshot ingress (via - :meth:`Context.from_protobuf`) and the pickle ingress (this - method) converge on a single creation site. Pickle restores - identity only — the receiver's :class:`Context` is the - source of truth for value lookup, populated via the - wire-context path rather than as a side-effect of - unpickling. + """Rebuild or resolve a usable `ContextVar` from pickled parts. 
+ + Routes through `~wool.runtime.context.manifest.resolve_stub` so + the chain-manifest ingress and the pickle ingress (this method) + converge on a single registry entry per key. Where the manifest + ingress is content with a bare + `~wool.runtime.context.manifest.ContextVarManifest`, the pickle + ingress needs behavior — the receiver invokes `get`/`set`/`reset` + — so `_as_contextvar` upgrades a freshly minted manifest to a + functional `ContextVar` in place. The instance stays flagged as a + stub + (``_stub`` is untouched) so a later user declaration still + promotes it without `ContextVarCollision`. Pickle restores + identity only — the receiver's context, populated via the + chain-manifest path, is the source of truth for value lookup. """ - ctx = current_context() - return resolve_stub((namespace, name), ctx, default=default) + return _as_contextvar(resolve_stub((namespace, name), default=default)) def _infer_namespace() -> str: - """Infer the namespace for a :class:`ContextVar` constructor call. + """Infer the namespace for a `ContextVar` constructor call. Walks up the call stack from the current frame, skipping frames from any ``wool.runtime.context`` submodule, and returns the @@ -357,3 +443,43 @@ def _infer_namespace() -> str: return module.partition(".")[0] frame = frame.f_back return "__main__" # pragma: no cover — stack always has a caller + + +def _as_contextvar(manifest: ContextVarManifest[T]) -> ContextVar[T]: + """Retag a bare `ContextVarManifest` as a behavioral `ContextVar` in place. + + The upgrade preserves identity: ``manifest`` keeps its backing + `contextvars.ContextVar` and its registry registration, so a value + drained into the backing before the upgrade survives, and every + immutable `~wool.runtime.context.chain.Chain` that already captured + the manifest observes the upgrade without rebuild. In-place is + mandatory: a fresh object would leave those chains pointing at a + stale manifest and break the one-instance-per-key invariant. 
+ + ``_stub`` is left untouched — this only grants behavior, not + declared status (`promote` is the declaration transition). + + Reassigning ``__class__`` is sound only because `ContextVar` adds no + instance slots over `ContextVarManifest` (``__slots__ = ()``), + keeping the two memory layouts compatible. A future slot on + `ContextVar` would make this raise `TypeError` at the first upgrade; + ``test_contextvar_should_declare_empty_slots`` guards the invariant. + """ + if type(manifest) is ContextVarManifest: + manifest.__class__ = ContextVar + return cast(ContextVar[T], manifest) + + +def promote(manifest: ContextVarManifest[T]) -> ContextVar[T]: + """Promote a stub-state manifest to a *declared* `ContextVar`, in place. + + The receiver-side counterpart to `resolve_stub`: when user code + finally declares a variable whose key a wire ingress already + registered as a stub, `ContextVar.__new__` calls this to upgrade the + registered placeholder. Grants behavior via `_as_contextvar` and + clears ``_stub`` so the variable counts as declared — a subsequent + declaration of the same key then raises `ContextVarCollision`. + """ + promoted = _as_contextvar(manifest) + promoted._stub = False + return promoted From 3fa072d925deddc3cebf0b358cf6e3d37a5e96ce Mon Sep 17 00:00:00 2001 From: Conrad Date: Sat, 27 Jun 2026 17:51:27 -0400 Subject: [PATCH 4/7] refactor!: Adopt the per-frame chain architecture on the worker Wrap every worker-wire request and response in a Frame subclass that carries its own ChainManifest, and mount it into the active chain at receive time as a single canonical step. Frame.chain_manifest is a union of the decoded manifest, the decode error, or absent; a strict- mode decode failure is deferred to mount, which raises it or chains it onto an exception payload's __context__. The dispatch driver caches one contextvars.Context per chain id so async-generator frames reuse it across yields. 
--- wool/src/wool/runtime/worker/connection.py | 1613 +++++++++---------- wool/src/wool/runtime/worker/frame.py | 686 ++++++++ wool/src/wool/runtime/worker/interceptor.py | 9 +- wool/src/wool/runtime/worker/metadata.py | 13 +- wool/src/wool/runtime/worker/pool.py | 4 +- wool/src/wool/runtime/worker/process.py | 2 +- wool/src/wool/runtime/worker/proxy.py | 2 +- wool/src/wool/runtime/worker/service.py | 486 ++---- wool/src/wool/runtime/worker/session.py | 1330 +++++++++------ 9 files changed, 2532 insertions(+), 1613 deletions(-) create mode 100644 wool/src/wool/runtime/worker/frame.py diff --git a/wool/src/wool/runtime/worker/connection.py b/wool/src/wool/runtime/worker/connection.py index 4bc83ea8..c48831b4 100644 --- a/wool/src/wool/runtime/worker/connection.py +++ b/wool/src/wool/runtime/worker/connection.py @@ -17,11 +17,18 @@ import wool from wool import protocol -from wool.runtime import context from wool.runtime.resourcepool import ResourcePool from wool.runtime.routine.task import Task from wool.runtime.serializer import Serializer from wool.runtime.worker.base import ChannelOptions +from wool.runtime.worker.frame import ExceptionResponseFrame +from wool.runtime.worker.frame import Frame +from wool.runtime.worker.frame import NextRequestFrame +from wool.runtime.worker.frame import RequestFrame +from wool.runtime.worker.frame import ResultResponseFrame +from wool.runtime.worker.frame import SendRequestFrame +from wool.runtime.worker.frame import TaskRequestFrame +from wool.runtime.worker.frame import ThrowRequestFrame _DispatchCall: TypeAlias = grpc.aio.StreamStreamCall[protocol.Request, protocol.Response] _PoolKey: TypeAlias = tuple[str, grpc.ChannelCredentials | None, ChannelOptions] @@ -31,925 +38,905 @@ _log = logging.getLogger(__name__) -@dataclass -class _Channel: - """Internal holder for a pooled gRPC channel and its resources.""" +# public +class UnexpectedResponse(Exception): + """Raised when a worker returns an unexpected response type. 
- channel: grpc.aio.Channel - stub: protocol.WorkerStub - semaphore: asyncio.Semaphore + This exception indicates a protocol violation where the worker's + response doesn't match the expected format (e.g., missing acknowledgment + or returning an unrecognized payload type). + """ - async def close(self): - """Close the underlying gRPC channel.""" - await self.channel.close() +# public +class RpcError(Exception): + """Raised when a gRPC call to a worker fails with a non-transient + error. -def _channel_factory(key): - """Create a new :class:`_Channel` for the given pool key. + Non-transient errors indicate persistent issues with the worker + that are unlikely to be resolved by retrying (e.g., invalid + arguments, unimplemented methods, permission denied, + server-side bugs, version skew). - :param key: - Tuple of ``(target, credentials, options)``. - :returns: - A new :class:`_Channel` instance. - """ - target, credentials, options = key - grpc_options = [ - ("grpc.max_receive_message_length", options.max_receive_message_length), - ("grpc.max_send_message_length", options.max_send_message_length), - ("grpc.keepalive_time_ms", options.keepalive_time_ms), - ("grpc.keepalive_timeout_ms", options.keepalive_timeout_ms), - ( - "grpc.keepalive_permit_without_calls", - int(options.keepalive_permit_without_calls), - ), - ("grpc.http2.max_pings_without_data", options.max_pings_without_data), - ("grpc.max_concurrent_streams", options.max_concurrent_streams), - ( - "grpc.default_compression_algorithm", - options.compression.value, - ), - ] - if credentials is not None: - channel = grpc.aio.secure_channel(target, credentials, options=grpc_options) - else: - channel = grpc.aio.insecure_channel(target, options=grpc_options) - stub = protocol.WorkerStub(channel) - return _Channel(channel, stub, asyncio.Semaphore(options.max_concurrent_streams)) + **Worker-health exception contract.** Load-balancer strategies + treat exception classes from `WorkerConnection.dispatch` + as a 
three-way classification: + - `TransientRpcError` — worker is hiccupping + (``UNAVAILABLE`` / ``DEADLINE_EXCEEDED`` / + ``RESOURCE_EXHAUSTED``); the strategy should **skip** to + the next worker without eviction. The worker may recover. + - `RpcError` (non-transient) — worker is unhealthy + (``INTERNAL``, ``FAILED_PRECONDITION``, malformed Nack + dump, version skew, etc.); the strategy should **evict**. + Today's binary policy is "evict on first occurrence"; + health-aware forgiveness (N-strikes) is a follow-up. + - Any other class — caller-fault (`UnexpectedResponse` for a + malformed or unrecognized worker response, parse-phase + failures re-raised as the original exception type, caller-side + encode failures, programming bugs); the strategy + **propagates** to the caller without touching the pool. A + strict-mode chain-encode failure surfaces in this bucket as a + `~wool.runtime.context.exceptions.ChainSerializationError` (or + a `BaseExceptionGroup` wrapping one) and propagates unwrapped. -async def _channel_finalizer(channel: _Channel): - """Close the gRPC channel held by a :class:`_Channel`. + Strategy authors implementing `LoadBalancerLike` MUST + honor this contract: a strategy that catches `Exception` + indiscriminately will silently evict workers on every + caller-side bug, wiping the pool over time. - :param channel: - The :class:`_Channel` to finalize. + **Constructor invariant.** ``code`` and ``details`` are both + optional, but a real gRPC fault always carries ``details`` (the + human-readable failure), with ``code`` set whenever the status is + known; the all-``None`` form is reserved as a test sentinel. 
""" - await channel.close() + def __init__( + self, + code: grpc.StatusCode | None = None, + details: str | None = None, + ): + self.code = code + self.details = details + if code is not None and details is not None: + super().__init__(f"{code.name}: {details}") + elif code is not None: # pragma: no cover + super().__init__(code.name) + elif details is not None: + super().__init__(details) + else: # pragma: no cover + super().__init__() -_channel_pool: ResourcePool[_Channel] = ResourcePool( - factory=_channel_factory, finalizer=_channel_finalizer, ttl=60 -) +# public +class TransientRpcError(RpcError): + """Raised when a gRPC call to a worker fails with a transient error. -async def clear_channel_pool() -> None: - """Close and clear every gRPC channel in the process-wide pool. + Transient errors indicate temporary issues that may be resolved by + retrying the operation, such as: - Invalidates cached channels across every pool key, including - UDS targets. + - ``UNAVAILABLE``: Worker temporarily unavailable + - ``DEADLINE_EXCEEDED``: Request took too long + - ``RESOURCE_EXHAUSTED``: Worker temporarily overloaded """ - await _channel_pool.clear() - - -_TEARDOWN_TIMEOUT: Final = 60.0 - -async def _complete_teardown(teardown: Coroutine[Any, Any, None]) -> None: - """Drive *teardown* to completion, immune to caller cancellation. - Resource teardown registered on an :class:`AsyncExitStack` - awaits its release callbacks. When the caller task carries a - pending cancellation — externally via ``task.cancel()`` or after - a worker-side ``CancelledError`` bumped ``cancelling()`` in - :meth:`_DispatchStream._read_next` — asyncio would pre-empt the - next suspending teardown ``await`` and skip the remaining - callbacks, leaking a pooled resource reference. +# public +class WorkerConnection: + """gRPC connection to a worker for task dispatch. 
- Running *teardown* as a shielded child task gives it an - independent cancellation state, so its ``await`` boundaries run - uninterrupted. A cancellation observed while waiting is deferred - and re-raised once teardown finishes, so the caller still - observes the cancel. + Acquires pooled gRPC channels keyed by ``(target, credentials, + options)``. Each `dispatch` call obtains a reference-counted + channel from the module-level pool, primes an async generator that + holds its own reference, then releases the dispatch-scope reference. + The channel stays alive until the caller finishes consuming the + result stream. - A teardown-side exception other than ``CancelledError`` propagates - and supersedes a deferred cancel, mirroring ``finally`` precedence. - ``KeyboardInterrupt`` and ``SystemExit`` are captured off the child - task — where they would otherwise escape straight to the event-loop - runner via ``Task.__step`` — and re-raised in the caller's context. - If teardown does not finish within :data:`_TEARDOWN_TIMEOUT` the - caller is unblocked and the shielded task is left running detached - so the release still completes. - """ - interrupt: KeyboardInterrupt | SystemExit | None = None + **Cleanup semantics on cancellation.** Every code path that owns + an in-flight gRPC call wraps its body in + ``try / except BaseException`` so that ``asyncio.CancelledError`` + (a `BaseException` subclass, not `Exception`) still + triggers ``call.cancel()`` before re-raising. The cancel itself + is swallowed at `Exception` (not `BaseException`) — + cleanup-during-cleanup should let ``KeyboardInterrupt`` propagate + rather than silently drop a process-level signal. - async def _run() -> None: - # Capture process-level signals here, inside the child task's - # own frame: a ``KeyboardInterrupt``/``SystemExit`` raised by a - # task escapes to the event-loop runner rather than to the - # awaiter, so it must not be left to propagate out of the task. 
- nonlocal interrupt - try: - await teardown - except (KeyboardInterrupt, SystemExit) as exc: - interrupt = exc + :param target: + Worker URI. Supports multiple formats: - task = asyncio.ensure_future(_run()) - deferred: asyncio.CancelledError | None = None - while True: - try: - async with asyncio.timeout(_TEARDOWN_TIMEOUT): - await asyncio.shield(task) - except TimeoutError: - # Teardown is wedged — only reachable under pathological - # pool-lock contention. Stop blocking the caller; the - # shielded task keeps running so the release still - # completes, just not synchronously. - _log.warning( - "Routine teardown exceeded %.0fs; pooled-resource " - "release deferred to a detached task.", - _TEARDOWN_TIMEOUT, - ) - break - except asyncio.CancelledError as exc: - if not task.done(): - # Caller cancelled mid-teardown — keep the shielded - # task running and re-await it on the next iteration. - deferred = exc - continue - raise - break - if interrupt is not None: - raise interrupt - if deferred is not None: - raise deferred + - ``host:port`` - DNS name or IP with port + - ``dns://host:port`` - Explicit DNS resolution + - ``ipv4:address:port`` - IPv4 address + - ``ipv6:[address]:port`` - IPv6 address + - ``unix:path`` - Unix domain socket + Examples: ``localhost:50051``, ``192.0.2.1:50051`` + :param credentials: + Optional channel credentials for TLS/mTLS connections. + :param options: + Optional channel options controlling gRPC message + size limits, keepalive, concurrency, and compression. + See `ChannelOptions` for defaults. The + ``max_concurrent_streams`` field sizes the per-channel + concurrency semaphore. -class _DispatchStream(Generic[_T]): - """Async iterator wrapper for streaming task results from workers. + **Usage:** - Handles iteration over gRPC response streams and deserializes - task results or raises exceptions received from remote workers. + .. code-block:: python - :param call: - The underlying gRPC response stream. 
+ conn = WorkerConnection("localhost:50051") + async for result in conn.dispatch(task): + process(result) + await conn.close() """ + TRANSIENT_ERRORS: Final = { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.DEADLINE_EXCEEDED, + grpc.StatusCode.RESOURCE_EXHAUSTED, + } + def __init__( self, - call: _DispatchCall, - task: Task, - serializer: Serializer | None = None, + target: str, + *, + credentials: grpc.ChannelCredentials | None = None, + options: ChannelOptions | None = None, ): - self._call = call - self._task = task - self._serializer: Serializer = ( - serializer if serializer is not None else wool.__serializer__ - ) - self._iter = aiter(call) - self._closed = False - self._running = False + self._target = target + self._credentials = credentials + self._options = options if options is not None else ChannelOptions() + self._key: _PoolKey = (target, credentials, self._options) + self._uds_key: _PoolKey | None = None - async def __anext__(self) -> _T: - """Get the next response from the stream. + async def dispatch( + self, + task: Task, + *, + timeout: float | None = None, + ) -> AsyncGenerator[protocol.Message, None]: + """Dispatch a task to the remote worker for execution. - Sends a ``next`` request to the server to advance the remote - generator, then reads and returns the next result. + Sends the task to the worker via gRPC, waits for acknowledgment, + and returns an async iterator that streams back results. Respects + concurrency limits and applies timeout to the dispatch phase only + (semaphore acquisition and acknowledgment). + + **Chain decode failures (caller-side).** + Each response frame may carry a back-propagated chain manifest + that needs decoding before the caller can merge worker-side + mutations. The chain manifest is **ancillary state** under wool's + protocol contract: per-entry decode failures emit + `wool.SerializationWarning` instances inside + `~wool.runtime.context.manifest.ChainManifest.from_protobuf`. 
+ Under the warnings system's default filter these surface once + as warnings and decoding returns the partial manifest; under a + filter that promotes `wool.SerializationWarning` to an + error, + `~wool.runtime.context.manifest.ChainManifest.from_protobuf` + aggregates the per-entry warnings into a + `wool.ChainSerializationError` and raises in place of + returning. Caller-side handling after loading the primary + signal: + + * On a result frame, the `wool.ChainSerializationError` + raises as the primary — strict mode loses the result value + but every decode failure surfaces, not just the first. The + result cannot be trusted alongside a chain manifest that failed + to apply. + * On an exception frame, the decode error chains onto the + worker exception's ``__context__`` (implicit context, set + directly rather than via ``raise ... from``). The worker + exception class is preserved so the caller's existing + ``except RoutineError`` continues to catch — no migration + to ``except*`` required. The decode error remains visible + in the traceback through context chaining. Under the + default filter the per-entry warnings emit once during + decode and the worker exception raises unchained. + :param task: + The `Task` instance to dispatch to the worker. + :param timeout: + Timeout in seconds for semaphore acquisition and task + acknowledgment. If ``None``, no timeout is applied. Does not + apply to the execution phase. :returns: - The next task result from the worker. - :raises StopAsyncIteration: - When the stream is exhausted or after aclose() is called. - :raises RuntimeError: - If another iteration is already in progress. - :raises UnexpectedResponse: - If the response payload is unrecognised (neither a - result nor an exception), if a result or exception - dump cannot be deserialised (e.g. 
cloudpickle version - skew, missing class on the caller's path, truncated - bytes, etc.), or if the worker ships a non-:class:`Exception` - :class:`BaseException` payload other than - :class:`asyncio.CancelledError` (e.g. :class:`KeyboardInterrupt`, - :class:`SystemExit`, user-defined :class:`BaseException` - subclasses, etc.). - :raises asyncio.CancelledError: - When the worker-side routine raises - :class:`asyncio.CancelledError` from its body (or is - externally cancelled and propagates the - ``CancelledError`` out). Mirrors stdlib's ``await task`` - semantics where ``raise CancelledError`` from the - awaitee is indistinguishable from - ``task.cancel()`` — both transition the task to - ``CANCELLED`` and the caller's ``await`` raises - ``CancelledError``. The caller task's ``cancelling()`` - count is incremented synchronously with the raise, - mirroring stdlib's local-cancel state shape so that - ``if cancelling() > 0: raise`` re-raise gates and - ``current_task().uncancel()`` absorbers behave - identically for worker-side and local cancels. - :raises Exception: - The worker-side routine's exception, re-raised in its - original class. The class is narrowed to - :class:`Exception` for non-:class:`CancelledError` - :class:`BaseException` subclasses (:class:`KeyboardInterrupt`, - :class:`SystemExit`, or user-defined :class:`BaseException` - subclasses): these are degraded to - :class:`UnexpectedResponse` so process-level signals - cannot be smuggled across the wire and trip caller-side - signal handlers. :class:`UnexpectedResponse` is not a - :class:`RpcError` subclass, so the load balancer treats - it as a caller-fault and does not evict the worker. - Caller-side gRPC cancellation arrives via a different - path, not via this exception. 
- """ - if self._closed: # pragma: no cover - raise StopAsyncIteration - if self._running: # pragma: no cover - raise RuntimeError("anext(): asynchronous generator is already running") - self._running = True - try: - request = protocol.Request( - next=protocol.Void(), - context=context.current_context().to_protobuf( - serializer=self._serializer - ), - ) - await self._call.write(request) - result = await self._read_next() - return result - finally: - self._running = False + An async iterator that yields task results from the worker. + :raises TransientRpcError: + If the worker returns a transient RPC error (UNAVAILABLE, + DEADLINE_EXCEEDED, or RESOURCE_EXHAUSTED) or the local + dispatch-phase timeout fires (also classified as + DEADLINE_EXCEEDED). + :raises RpcError: + If the worker returns a non-transient RPC error or + rejects with a Nack whose dumped exception cannot be + deserialized (malformed-dump fallback). + :raises UnexpectedResponse: + If the worker doesn't acknowledge the task. + :raises ValueError: + If the timeout value is not positive. - async def _read_next(self) -> _T: - """Read the next response from the stream without writing — - for paths that have already written their own request. + If the worker rejects the task during the parse phase due + to a malformed task payload, the original exception class + is deserialized from the Nack and re-raised so the caller + observes the actual failure class rather than an opaque + protocol error. A malformed Nack payload falls back to + `RpcError`. + + Encode-side failures (e.g., a strict-mode + `wool.ChainSerializationError` aggregating + `wool.SerializationWarning` peers raised by + `~wool.runtime.context.manifest.ChainManifest.to_protobuf` when + an unpicklable `wool.ContextVar` value is set) + propagate unwrapped: + the load-balancer contract treats only `RpcError` + instances as worker-health concerns, so a caller-side encode + failure surfaces directly to the caller rather than evicting + workers. 
+ """ + if timeout is not None and timeout <= 0: + raise ValueError("Dispatch timeout must be positive") - Applies the response's :class:`Context` into the caller's - current :class:`Context` — var mutations and consumed-token - state both ride back-propagation. + if ( + metadata := wool.__worker_metadata__ + ) is not None and metadata.address == self._target: + if (uds_address := wool.__worker_uds_address__) is not None: + key = (uds_address, None, self._options) + self._uds_key = key + else: + key = self._key + else: + key = self._key - :returns: - The next task result from the worker. - """ + stream = self._execute(task, key, timeout) try: - response = await anext(self._iter) - # Wool treats response context as ancillary state. Per-var - # decode failures aggregate inside - # :meth:`Context.from_protobuf` and surface as a - # :class:`BaseExceptionGroup` only under strict mode; on the - # primary-signal path we bundle them with the worker - # exception (or the result-bearing response's group) so - # callers can extract both signals via ``except*``. - decode_failures: list[BaseException] = [] - try: - incoming_context = context.Context.from_protobuf( - response.context, serializer=self._serializer - ) - except BaseExceptionGroup as eg: - decode_failures.extend(eg.exceptions) - else: - if incoming_context.has_state(): - context.current_context().update(incoming_context) - if response.HasField("result"): - try: - result = self._serializer.loads(response.result.dump) - except Exception as exc: - # Degrade malformed result payloads to - # :class:`UnexpectedResponse` so callers can - # ``except UnexpectedResponse`` uniformly while - # the original pickle/import failure remains on - # ``__cause__`` for diagnostic chains. Load - # balancer treats this as caller-fault and does - # not evict the worker (typically a version - # skew on a shared result class). 
- raise UnexpectedResponse( - "Worker shipped a malformed result payload" - ) from exc - if decode_failures: - raise BaseExceptionGroup( - "response context decode failed", - decode_failures, - ) - return result - elif response.HasField("exception"): - # Degrade malformed exception payloads (cloudpickle - # version skew, missing class on the caller's path, - # truncated bytes, worker-side serializer bug) to - # :class:`UnexpectedResponse` so the load balancer - # treats it as a caller-fault and does not evict the - # worker for what is typically a version-skew issue. - # Mirrors the non-Exception payload degradation - # below; the parse-phase Nack path keeps its - # :class:`RpcError` fallback because worker-side - # parse rejection has different worker-health - # semantics than a routine-time decode mismatch. - try: - worker_exc = self._serializer.loads(response.exception.dump) - except Exception as exc: - # Preserve the original pickle/import failure - # via manual ``__cause__`` chaining — we assign - # ``worker_exc`` and continue into the - # narrowing + note-attachment block below, so - # ``raise X from Y`` syntax isn't applicable - # here. The later ``raise worker_exc`` honors - # the manually-set ``__cause__`` identically to - # ``raise X from Y``. - worker_exc = UnexpectedResponse( - "Worker shipped a malformed exception payload" - ) - worker_exc.__cause__ = exc - worker_exc.__suppress_context__ = True - # See ``__anext__``'s ``:raises Exception:`` / - # ``:raises asyncio.CancelledError:`` for the - # narrowing contract. ``CancelledError`` is allowed - # to propagate raw to mirror stdlib's ``await - # task`` semantics where a routine that self-raises - # ``CancelledError`` is indistinguishable from one - # that was externally cancelled. 
Other non-Exception - # ``BaseException`` subclasses are degraded to - # :class:`UnexpectedResponse` (not :class:`RpcError`) - # so process-level signals cannot be smuggled and - # the load balancer does not evict the worker for a - # routine-level fault. - if not isinstance(worker_exc, (Exception, asyncio.CancelledError)): - worker_exc = UnexpectedResponse( - "Worker shipped a non-Exception payload in " - f"Response.exception: {type(worker_exc).__name__}" - ) - if decode_failures: - # Attach decode failures to the worker exception - # rather than wrap both in a - # :class:`BaseExceptionGroup`. Mirrors the - # worker's encode-side handling - # (:mod:`wool.runtime.worker.service`), so the - # caller's existing ``except`` against the - # routine's exception class keeps matching — - # users don't have to migrate to ``except*``. - try: - for w in decode_failures: - worker_exc.add_note(f"wool context warning: {w}") - except (AttributeError, TypeError): - pass - try: - setattr( - worker_exc, - "__wool_context_warnings__", - decode_failures, - ) - except AttributeError: - pass - # Mirror stdlib's local-cancel state shape: bump - # ``current_task().cancelling()`` synchronously and - # forward the worker's cancel message so idiomatic - # ``except CancelledError`` patterns - # (``if cancelling() > 0: raise`` re-raise gates, - # ``current_task().uncancel()`` absorbers) and any - # caller that introspects task state behave - # identically for worker-side and local cancels. The - # next-cycle ``CancelledError`` that ``Task.cancel()`` - # schedules is suppressed by ``uncancel()`` per - # asyncio's contract. 
- if isinstance(worker_exc, asyncio.CancelledError): - current = asyncio.current_task() - if current is not None: - cancel_msg = worker_exc.args[0] if worker_exc.args else None - current.cancel(cancel_msg) - raise worker_exc + await stream.__anext__() # Prime: pins resources + handshake + except grpc.RpcError as error: + code = error.code() + details = error.details() or str(error) + if code in self.TRANSIENT_ERRORS: + raise TransientRpcError(code, details) from error else: - raise UnexpectedResponse( - f"Expected 'result' or 'exception' response, " - f"received '{response.WhichOneof('payload')}'" - ) - except BaseException: - # Cancel the underlying gRPC call on any abnormal exit - # — including ``asyncio.CancelledError`` (a - # ``BaseException`` subclass), so cancellation - # propagates without leaking the in-flight call. - # Mirrors stdlib ``await agen.__anext__()`` cleanup - # semantics: any non-normal-return exit triggers - # resource cleanup before re-raising. - # - # The inner cancel-swallow is ``Exception``, not - # ``BaseException``: this is cleanup-during-cleanup, - # so a ``KeyboardInterrupt`` mid-cancel should - # propagate rather than be silently dropped. - try: - self._call.cancel() - except Exception: - pass - raise + raise RpcError(code, details) from error + except asyncio.TimeoutError as error: + # Local dispatch-phase timeout is the same semantic as + # gRPC DEADLINE_EXCEEDED — request took too long. Wrap + # so the load-balancer contract only needs to know + # about `RpcError`. Worker isn't presumed + # unhealthy; transient-class makes the LB skip without + # eviction. + raise TransientRpcError( + grpc.StatusCode.DEADLINE_EXCEEDED, + "Local dispatch-phase timeout exceeded", + ) from error - async def aclose(self) -> None: - """Close the async generator and cancel the underlying gRPC call. + return cast(AsyncGenerator[protocol.Message, None], stream) - This method provides proper cleanup for async generators decorated - with @routine. 
When called, it cancels the gRPC stream to the worker, - which triggers cleanup on the worker side. + async def close(self): + """Close the connection and release all pooled resources. - Implements the async generator protocol's aclose() method to match - native Python async generator behavior. This method is idempotent - and can be safely called multiple times. + Clears the pooled channel entries for both the TCP key and, + if a UDS address is available, the UDS key. Idempotent: safe + to call multiple times or on connections that were never used. """ - if self._closed: # pragma: no cover - return - - self._closed = True try: - self._call.cancel() - except Exception: + await _channel_pool.clear(self._key) + except KeyError: pass + if self._uds_key is not None: + try: + await _channel_pool.clear(self._uds_key) + except KeyError: + pass - async def asend(self, value): - """Send a value into the remote async generator. + async def _handshake( + self, + call: _DispatchCall, + task: Task, + ) -> None: + """Send the dispatch request and await the worker's acknowledgement. - Serializes *value*, writes it as a ``Message`` frame to the - bidirectional stream, and returns the next yielded result. + Caller is responsible for channel-permit and call-cancel + lifecycle; `_execute` pins both on its exit stack so any + failure here triggers the registered cleanup callbacks during + unwind. - :param value: - The value to send into the generator. - :returns: - The next yielded value from the remote generator. - :raises StopAsyncIteration: - When the remote generator is exhausted or the stream - has been closed. - :raises RuntimeError: - If another iteration is already in progress. + On a Nack (parse-phase worker rejection), re-raises the + worker's original exception unchanged. On a malformed + Nack payload (loads raises, or yields a non-Exception), + falls back to `RpcError`. 
""" - if self._closed: # pragma: no cover - raise StopAsyncIteration - if self._running: # pragma: no cover - raise RuntimeError("asend(): asynchronous generator is already running") - self._running = True - try: - request = protocol.Request( - send=protocol.Message(dump=self._serializer.dumps(value)), - context=context.current_context().to_protobuf( - serializer=self._serializer - ), + request = TaskRequestFrame.for_send( + task, serializer=wool.__serializer__ + ).to_protobuf() + await call.write(request) + response = await anext(aiter(call)) + if response.HasField("nack"): + # Every Nack carries a typed parse-phase exception. + # Deserialize and re-raise so the caller observes the + # actual failure class rather than an opaque RpcError. + # Envelope-level rejections (e.g., protocol-version + # mismatch) ride gRPC status codes, not Nack — those + # land in `dispatch`'s ``except grpc.RpcError`` + # arm instead. + try: + raised = wool.__serializer__.loads(response.nack.exception.dump) + except Exception: + raised = None + # Narrowed to ``Exception`` to match + # ``Rejected.original``'s typed contract (worker + # constructs ``Rejected`` only from + # ``except Exception``). A worker that ships a + # non-``Exception`` ``BaseException`` would be a worker + # bug; degrade to `RpcError` rather than smuggle + # cancel/interrupt signals across the wire. A malformed + # dump (loads raises) lands here too. + if isinstance(raised, Exception): + raise raised from None + raise RpcError(details="Task rejected by worker (malformed Nack payload)") + if not response.HasField("ack"): + raise UnexpectedResponse( + f"Expected 'ack' response, received '{response.WhichOneof('payload')}'" ) - await self._call.write(request) - result = await self._read_next() - return result - finally: - self._running = False - async def athrow(self, typ, val=None, tb=None): - """Throw an exception into the remote async generator. 
+ async def _execute( + self, + task: Task, + key: _PoolKey, + timeout: float | None, + ) -> AsyncGenerator[protocol.Message | None, None]: + """Async generator that owns the full dispatch lifecycle. - Serializes the exception and sends it as a ``Message`` frame. - The remote generator receives the exception via ``athrow()`` - and may handle or propagate it. + Pins the channel pool ref, the channel-concurrency permit, + and the gRPC call's cancel hook on a single + `AsyncExitStack`. + Completes the handshake before yielding to the caller; any + exit path — setup failure, priming-yield ``GeneratorExit``, + mid-stream exception, natural end of stream — unwinds the + stack and releases every resource exactly once. - :param typ: - The exception type or instance to throw. - :param val: - The exception value (if *typ* is a type). - :param tb: - The exception traceback. - :returns: - The next yielded value from the remote generator. - :raises StopAsyncIteration: - When the remote generator is exhausted or the stream - has been closed. - :raises RuntimeError: - If another iteration is already in progress. + The stack unwind is driven through `_complete_teardown` + so the release callbacks run to completion even when the + caller task is mid-cancellation — otherwise a pending + ``CancelledError`` could pre-empt ``AsyncExitStack.__aexit__`` + and leak a pooled channel reference. 
""" - if self._closed: # pragma: no cover - raise StopAsyncIteration - if self._running: # pragma: no cover - raise RuntimeError("athrow(): asynchronous generator is already running") - self._running = True + stack = AsyncExitStack() try: - if isinstance(typ, BaseException): # pragma: no cover - exc = typ - elif val is not None: - exc = val - else: # pragma: no cover - exc = typ() - - request = protocol.Request( - throw=protocol.Message(dump=self._serializer.dumps(exc)), - context=context.current_context().to_protobuf( - serializer=self._serializer - ), - ) - await self._call.write(request) - result = await self._read_next() - return result - finally: - self._running = False + channel = await stack.enter_async_context(_channel_pool.get(key)) + # Acquire the concurrency permit and complete the + # handshake under the dispatch-phase timeout. + # ``Semaphore.acquire()`` is cancel-safe: if cancelled + # before it returns, no permit is taken; if it returns, + # the next line (sync) registers the release. The two- + # step "acquire then register" is therefore atomic with + # respect to cancellation. + async with asyncio.timeout(timeout): + await channel.semaphore.acquire() + stack.callback(channel.semaphore.release) -# public -class UnexpectedResponse(Exception): - """Raised when a worker returns an unexpected response type. + call: _DispatchCall = channel.stub.dispatch() - This exception indicates a protocol violation where the worker's - response doesn't match the expected format (e.g., missing acknowledgment - or returning an unrecognized payload type). - """ + # Cancel the in-flight gRPC call on any unwind. + # Swallow ``Exception`` (not ``BaseException``) so + # a buggy stub's ``cancel()`` does not replace + # whatever exception is unwinding the stack; + # cleanup-during-cleanup. 
+ def _safe_cancel() -> None: + try: + call.cancel() + except Exception: + pass + stack.callback(_safe_cancel) + await self._handshake(call, task) -# public -class RpcError(Exception): - """Raised when a gRPC call to a worker fails with a non-transient - error. - - Non-transient errors indicate persistent issues with the worker - that are unlikely to be resolved by retrying (e.g., invalid - arguments, unimplemented methods, permission denied, - server-side bugs, version skew). + # Priming yield. All resources are pinned on the stack + # and the worker has acknowledged the task. The + # caller's ``__anext__`` prime returns here. + yield - **Worker-health exception contract.** Load-balancer strategies - treat exception classes from :meth:`WorkerConnection.dispatch` - as a three-way classification: + stream = _DispatchStream(call, task) + try: + sent = None + result = await anext(stream) + while True: + try: + sent = yield result + except GeneratorExit: + # Short-circuit before ``except + # BaseException`` below catches and + # ``athrow``s the GeneratorExit into the + # inner stream. Cancellation of the + # in-flight gRPC call happens via the + # AsyncExitStack's ``_safe_cancel`` + # callback on stack unwind — single + # resource ownership, single cancel. + return + except BaseException as exc: + result = await stream.athrow(type(exc), exc) + else: + result = await stream.asend(sent) + except StopAsyncIteration: + return + # Other abnormal exits (``asyncio.CancelledError``, + # routine exceptions, mid-stream gRPC errors) propagate + # uncaught; the AsyncExitStack's ``_safe_cancel`` + # callback fires on unwind to cancel the in-flight + # gRPC call. + finally: + # Shield the stack unwind from caller cancellation so + # every pooled-resource release callback runs — see + # `_complete_teardown`. 
``aclose()`` drives each + # registered ``__aexit__`` with no exception info; that + # is equivalent to the implicit ``async with`` exit only + # because every context manager on this stack is + # exception-agnostic. + await _complete_teardown(stack.aclose()) - - :class:`TransientRpcError` — worker is hiccupping - (``UNAVAILABLE`` / ``DEADLINE_EXCEEDED`` / - ``RESOURCE_EXHAUSTED``); the strategy should **skip** to - the next worker without eviction. The worker may recover. - - :class:`RpcError` (non-transient) — worker is unhealthy - (``INTERNAL``, ``FAILED_PRECONDITION``, malformed Nack - dump, version skew, etc.); the strategy should **evict**. - Today's binary policy is "evict on first occurrence"; - health-aware forgiveness (N-strikes) is a follow-up. - - Any other class — caller-fault (parse-phase failures - re-raised as the original exception type, caller-side - encode failures, programming bugs); the strategy - **propagates** to the caller without touching the pool. - Strategy authors implementing :class:`LoadBalancerLike` MUST - honor this contract: a strategy that catches :class:`Exception` - indiscriminately will silently evict workers on every - caller-side bug, wiping the pool over time. 
- """ +@dataclass +class _Channel: + """Internal holder for a pooled gRPC channel and its resources.""" - def __init__( - self, - code: grpc.StatusCode | None = None, - details: str | None = None, - ): - self.code = code - self.details = details - if code is not None and details is not None: - super().__init__(f"{code.name}: {details}") - elif code is not None: # pragma: no cover - super().__init__(code.name) - elif details is not None: - super().__init__(details) - else: # pragma: no cover - super().__init__() + channel: grpc.aio.Channel + stub: protocol.WorkerStub + semaphore: asyncio.Semaphore + async def close(self): + """Close the underlying gRPC channel.""" + await self.channel.close() -# public -class TransientRpcError(RpcError): - """Raised when a gRPC call to a worker fails with a transient error. - Transient errors indicate temporary issues that may be resolved by - retrying the operation, such as: +def _channel_factory(key): + """Create a new `_Channel` for the given pool key. - - ``UNAVAILABLE``: Worker temporarily unavailable - - ``DEADLINE_EXCEEDED``: Request took too long - - ``RESOURCE_EXHAUSTED``: Worker temporarily overloaded + :param key: + Tuple of ``(target, credentials, options)``. + :returns: + A new `_Channel` instance. 
""" + target, credentials, options = key + grpc_options = [ + ("grpc.max_receive_message_length", options.max_receive_message_length), + ("grpc.max_send_message_length", options.max_send_message_length), + ("grpc.keepalive_time_ms", options.keepalive_time_ms), + ("grpc.keepalive_timeout_ms", options.keepalive_timeout_ms), + ( + "grpc.keepalive_permit_without_calls", + int(options.keepalive_permit_without_calls), + ), + ("grpc.http2.max_pings_without_data", options.max_pings_without_data), + ("grpc.max_concurrent_streams", options.max_concurrent_streams), + ( + "grpc.default_compression_algorithm", + options.compression.value, + ), + ] + if credentials is not None: + channel = grpc.aio.secure_channel(target, credentials, options=grpc_options) + else: + channel = grpc.aio.insecure_channel(target, options=grpc_options) + stub = protocol.WorkerStub(channel) + return _Channel(channel, stub, asyncio.Semaphore(options.max_concurrent_streams)) -# public -class WorkerConnection: - """gRPC connection to a worker for task dispatch. - - Acquires pooled gRPC channels keyed by ``(target, credentials, - options)``. Each :meth:`dispatch` call obtains a reference-counted - channel from the module-level pool, primes an async generator that - holds its own reference, then releases the dispatch-scope reference. - The channel stays alive until the caller finishes consuming the - result stream. - - **Cleanup semantics on cancellation.** Every code path that owns - an in-flight gRPC call wraps its body in - ``try / except BaseException`` so that ``asyncio.CancelledError`` - (a :class:`BaseException` subclass, not :class:`Exception`) still - triggers ``call.cancel()`` before re-raising. The cancel itself - is swallowed at :class:`Exception` (not :class:`BaseException`) — - cleanup-during-cleanup should let ``KeyboardInterrupt`` propagate - rather than silently drop a process-level signal. +async def _channel_finalizer(channel: _Channel): + """Close the gRPC channel held by a `_Channel`. 
- **Usage:** + :param channel: + The `_Channel` to finalize. + """ + await channel.close() - .. code-block:: python - conn = WorkerConnection("localhost:50051") - async for result in conn.dispatch(task): - process(result) - await conn.close() +_channel_pool: ResourcePool[_Channel] = ResourcePool( + factory=_channel_factory, finalizer=_channel_finalizer, ttl=60 +) - :param target: - Worker URI. Supports multiple formats: - - ``host:port`` - DNS name or IP with port - - ``dns://host:port`` - Explicit DNS resolution - - ``ipv4:address:port`` - IPv4 address - - ``ipv6:[address]:port`` - IPv6 address - - ``unix:path`` - Unix domain socket +async def clear_channel_pool() -> None: + """Close and clear every gRPC channel in the process-wide pool. - Examples: ``localhost:50051``, ``192.0.2.1:50051`` - :param credentials: - Optional channel credentials for TLS/mTLS connections. - :param options: - Optional channel options controlling gRPC message - size limits, keepalive, concurrency, and compression. - See :class:`ChannelOptions` for defaults. The - ``max_concurrent_streams`` field sizes the per-channel - concurrency semaphore. + Invalidates cached channels across every pool key, including + UDS targets. """ + await _channel_pool.clear() - TRANSIENT_ERRORS: Final = { - grpc.StatusCode.UNAVAILABLE, - grpc.StatusCode.DEADLINE_EXCEEDED, - grpc.StatusCode.RESOURCE_EXHAUSTED, - } - def __init__( - self, - target: str, - *, - credentials: grpc.ChannelCredentials | None = None, - options: ChannelOptions | None = None, - ): - self._target = target - self._credentials = credentials - self._options = options if options is not None else ChannelOptions() - self._key: _PoolKey = (target, credentials, self._options) - self._uds_key: _PoolKey | None = None +_TEARDOWN_TIMEOUT: Final = 60.0 - async def dispatch( - self, - task: Task, - *, - timeout: float | None = None, - ) -> AsyncGenerator[protocol.Message, None]: - """Dispatch a task to the remote worker for execution. 
- Sends the task to the worker via gRPC, waits for acknowledgment, - and returns an async iterator that streams back results. Respects - concurrency limits and applies timeout to the dispatch phase only - (semaphore acquisition and acknowledgment). +async def _complete_teardown(teardown: Coroutine[Any, Any, None]) -> None: + """Drive *teardown* to completion, immune to caller cancellation. - **Context decode failures (caller-side).** - Each response frame may carry a back-propagated wire context - that needs decoding before the caller can merge worker-side - mutations. Wire context is **ancillary state** under wool's - protocol contract: per-entry decode failures emit - :class:`wool.ContextDecodeWarning` instances inside - :meth:`Context.from_protobuf`. Under the warnings system's - default filter these surface once as warnings and decoding - returns the partial Context; under a filter that promotes - :class:`wool.ContextDecodeWarning` to an error, - :meth:`Context.from_protobuf` aggregates the per-entry - exceptions into a :class:`BaseExceptionGroup` and raises in - place of returning. Caller-side handling after loading the - primary signal: - - * On a result frame, if decoding aggregated, the - :class:`BaseExceptionGroup` raises in place of the return — - strict mode loses the primary value but every decode - failure surfaces, not just the first. - * On an exception frame, decode failures are attached to - the worker exception via PEP 678 ``__notes__`` (visible - in tracebacks) and a ``__wool_context_warnings__`` - attribute (programmatic access), mirroring the worker's - encode-side handling. The worker exception class is - preserved so the caller's existing - ``except RoutineError`` continues to catch without - migration to ``except*``. Under the default filter the - per-entry warnings emit once during decode and the worker - exception raises unwrapped. + Resource teardown registered on an `AsyncExitStack` + awaits its release callbacks. 
When the caller task carries a + pending cancellation — externally via ``task.cancel()`` or from + a worker-side ``CancelledError`` re-raised into it by + `_DispatchStream._read_next` — asyncio would pre-empt the + next suspending teardown ``await`` and skip the remaining + callbacks, leaking a pooled resource reference. - :param task: - The :class:`Task` instance to dispatch to the worker. - :param timeout: - Timeout in seconds for semaphore acquisition and task - acknowledgment. If ``None``, no timeout is applied. Does not - apply to the execution phase. - :returns: - An async iterator that yields task results from the worker. - :raises TransientRpcError: - If the worker returns a transient RPC error (UNAVAILABLE, - DEADLINE_EXCEEDED, or RESOURCE_EXHAUSTED) or the local - dispatch-phase timeout fires (also classified as - DEADLINE_EXCEEDED). - :raises RpcError: - If the worker returns a non-transient RPC error or - rejects with a Nack whose dumped exception cannot be - deserialized (malformed-dump fallback). - :raises UnexpectedResponse: - If the worker doesn't acknowledge the task. - :raises ValueError: - If the timeout value is not positive. + Running *teardown* as a shielded child task gives it an + independent cancellation state, so its ``await`` boundaries run + uninterrupted. A cancellation observed while waiting is deferred + and re-raised once teardown finishes, so the caller still + observes the cancel. - If the worker rejects the task during the parse phase due - to a malformed task payload, the original exception class - is deserialized from the Nack and re-raised so the caller - observes the actual failure class rather than an opaque - protocol error. A malformed Nack payload falls back to - :class:`RpcError`. - - Encode-side failures (e.g. 
a strict-mode - :class:`BaseExceptionGroup` of - :class:`wool.ContextDecodeWarning` peers raised by - :meth:`Context.to_protobuf` when an unpicklable - :class:`wool.ContextVar` value is set) propagate unwrapped: - the load-balancer contract treats only :class:`RpcError` - instances as worker-health concerns, so a caller-side encode - failure surfaces directly to the caller rather than evicting - workers. - """ - if timeout is not None and timeout <= 0: - raise ValueError("Dispatch timeout must be positive") + A teardown-side exception other than ``CancelledError`` propagates + and supersedes a deferred cancel, mirroring ``finally`` precedence. + ``KeyboardInterrupt`` and ``SystemExit`` are captured off the child + task — where they would otherwise escape straight to the event-loop + runner via ``Task.__step`` — and re-raised in the caller's context. + If teardown does not finish within `_TEARDOWN_TIMEOUT` the + caller is unblocked and the shielded task is left running detached + so the release still completes. + """ + interrupt: KeyboardInterrupt | SystemExit | None = None - if ( - metadata := wool.__worker_metadata__ - ) is not None and metadata.address == self._target: - if (uds_address := wool.__worker_uds_address__) is not None: - key = (uds_address, None, self._options) - self._uds_key = key - else: - key = self._key - else: - key = self._key + async def _run() -> None: + # Capture process-level signals here, inside the child task's + # own frame: a ``KeyboardInterrupt``/``SystemExit`` raised by a + # task escapes to the event-loop runner rather than to the + # awaiter, so it must not be left to propagate out of the task. 
+ nonlocal interrupt + try: + await teardown + except (KeyboardInterrupt, SystemExit) as exc: + interrupt = exc - stream = self._execute(task, key, timeout) + task = asyncio.ensure_future(_run()) + deferred: asyncio.CancelledError | None = None + while True: try: - await stream.__anext__() # Prime: pins resources + handshake - except grpc.RpcError as error: - code = error.code() - details = error.details() or str(error) - if code in self.TRANSIENT_ERRORS: - raise TransientRpcError(code, details) from error - else: - raise RpcError(code, details) from error - except asyncio.TimeoutError as error: - # Local dispatch-phase timeout is the same semantic as - # gRPC DEADLINE_EXCEEDED — request took too long. Wrap - # so the load-balancer contract only needs to know - # about :class:`RpcError`. Worker isn't presumed - # unhealthy; transient-class makes the LB skip without - # eviction. - raise TransientRpcError( - grpc.StatusCode.DEADLINE_EXCEEDED, - "Local dispatch-phase timeout exceeded", - ) from error + async with asyncio.timeout(_TEARDOWN_TIMEOUT): + await asyncio.shield(task) + except TimeoutError: + # Teardown is wedged — only reachable under pathological + # pool-lock contention. Stop blocking the caller; the + # shielded task keeps running so the release still + # completes, just not synchronously. + _log.warning( + "Routine teardown exceeded %.0fs; pooled-resource " + "release deferred to a detached task.", + _TEARDOWN_TIMEOUT, + ) + # If the shielded task is still in flight at + # timeout, a captured process-level interrupt could + # surface from it later with no awaiter. Log so an + # operator can correlate. The done-with-interrupt case + # falls through to the post-loop ``raise interrupt``. + if not task.done(): + _log.debug( + "Routine teardown detached with shielded task " + "still pending; a process-level interrupt may " + "surface later from the detached task." 
+ ) + break + except asyncio.CancelledError as exc: + if not task.done(): + # Caller cancelled mid-teardown — keep the shielded + # task running and re-await it on the next iteration. + deferred = exc + continue + raise + break + if interrupt is not None: + raise interrupt + if deferred is not None: + raise deferred - return cast(AsyncGenerator[protocol.Message, None], stream) - async def close(self): - """Close the connection and release all pooled resources. +class _DispatchStream(Generic[_T]): + """Async iterator wrapper for streaming task results from workers. - Clears the pooled channel entries for both the TCP key and, - if a UDS address is available, the UDS key. Idempotent: safe - to call multiple times or on connections that were never used. - """ - try: - await _channel_pool.clear(self._key) - except KeyError: - pass - if self._uds_key is not None: - try: - await _channel_pool.clear(self._uds_key) - except KeyError: - pass + Handles iteration over gRPC response streams and deserializes + task results or raises exceptions received from remote workers. - async def _handshake( + :param call: + The underlying gRPC response stream. + """ + + def __init__( self, call: _DispatchCall, - wire_task: protocol.Task, - ) -> None: - """Send the dispatch request and wait for the worker's - acknowledgement. Caller is responsible for channel-permit - and call-cancel lifecycle; :meth:`_execute` pins both on - its exit stack so any failure here triggers the registered - cleanup callbacks during unwind. + task: Task, + serializer: Serializer | None = None, + ): + self._call = call + self._task = task + self._serializer: Serializer = ( + serializer if serializer is not None else wool.__serializer__ + ) + self._iter = aiter(call) + self._closed = False + self._running = False - On a Nack (parse-phase worker rejection), re-raises the - worker's original exception unchanged. 
On a malformed - Nack payload (loads raises, or yields a non-Exception), - falls back to :class:`RpcError`. + async def __anext__(self) -> _T: + """Get the next response from the stream. + + Sends a ``next`` request to the server to advance the remote + generator, then reads and returns the next result. + + :returns: + The next task result from the worker. + :raises StopAsyncIteration: + When the stream is exhausted or after aclose() is called. + :raises RuntimeError: + If another iteration is already in progress. + :raises UnexpectedResponse: + If the response payload is unrecognised (the wire's + ``oneof`` carries neither a ``result`` nor an + ``exception``), or if the worker ships a + non-`Exception` `BaseException` payload + other than `asyncio.CancelledError` (e.g., + `KeyboardInterrupt`, `SystemExit`, + user-defined `BaseException` subclasses). Both + are protocol-shape violations. + :raises pickle.UnpicklingError: + If a result or exception payload cannot be deserialised + (cloudpickle version skew, missing class on the caller's + path, truncated bytes, worker-side serializer bug, etc.). + The original exception type from the serializer + propagates with no wrapping — `pickle.PickleError` + subclasses, `AttributeError`, + `ImportError`, and similar all surface raw. The + load balancer treats anything outside `RpcError` + as caller-fault and does not evict the worker. + :raises asyncio.CancelledError: + When the worker-side routine raises + `asyncio.CancelledError` from its body (or is + externally cancelled and propagates the + ``CancelledError`` out). Mirrors stdlib's ``await task`` + semantics where ``raise CancelledError`` from the + awaitee is indistinguishable from + ``task.cancel()`` — both transition the task to + ``CANCELLED`` and the caller's ``await`` raises + ``CancelledError``. 
This site does not bump the caller + task's ``cancelling()`` count — the worker-shipped + ``CancelledError`` is re-raised as-is, leaving the + awaiter's cancelling count untouched exactly as stdlib + does when the awaitee raises ``CancelledError``. + Had the count been bumped, a caller that catches the error and + continues to ``await`` a recovery path would be re-interrupted + at the next checkpoint until it called ``current_task().uncancel()`` + — a step the wool-naive caller cannot reasonably know to add. + :raises Exception: + The worker-side routine's exception, re-raised in its + original class. The class is narrowed to + `Exception` for non-`CancelledError` + `BaseException` subclasses (`KeyboardInterrupt`, + `SystemExit`, or user-defined `BaseException` + subclasses): these are degraded to + `UnexpectedResponse` so process-level signals + cannot be smuggled across the wire and trip caller-side + signal handlers. `UnexpectedResponse` is not a + `RpcError` subclass, so the load balancer treats + it as a caller-fault and does not evict the worker. + Caller-side gRPC cancellation arrives via a different + path, not via this exception. + :raises wool.ChainSerializationError: + Under strict mode, when the response's chain manifest + fails to decode. On a result frame it raises as the + primary (the routine's value is dropped — a result + cannot be trusted alongside a chain manifest that failed to + apply). On an exception frame it is appended to the + tail of the worker exception's ``__context__`` chain + (preserving any routine-side chain the worker brought + via tblib) — neither failure caused the other, so + ``__context__`` is the honest channel rather than + ``__cause__``.
""" - request = protocol.Request( - task=wire_task, - context=context.current_context().to_protobuf(), + if self._closed: # pragma: no cover + raise StopAsyncIteration + return await self._send_and_read( + NextRequestFrame.for_send(serializer=self._serializer), + method_name="anext", ) - await call.write(request) - response = await anext(aiter(call)) - if response.HasField("nack"): - # Every Nack carries a typed parse-phase exception. - # Deserialize and re-raise so the caller observes the - # actual failure class rather than an opaque RpcError. - # Envelope-level rejections (e.g., protocol-version - # mismatch) ride gRPC status codes, not Nack — those - # land in :meth:`dispatch`'s ``except grpc.RpcError`` - # arm instead. + + async def _send_and_read( + self, request_frame: RequestFrame, *, method_name: str + ) -> _T: + """Run the shared guard/write/read/finally choreography for one round-trip. + + Centralizes the guard/write/read/finally shape so each + caller (``__anext__`` / ``asend`` / ``athrow``) stays a + two-line wrapper. *method_name* is used purely for the + already-running guard's error message so the diagnostic + still names the caller-facing entry point. + """ + if self._running: # pragma: no cover + raise RuntimeError( + f"{method_name}(): asynchronous generator is already running" + ) + self._running = True + try: + await self._call.write(request_frame.to_protobuf()) + return await self._read_next() + finally: + self._running = False + + async def _read_next(self) -> _T: + """Read the next response from the stream without writing. + + Serves paths that have already written their own request. + Decodes the wire envelope into the matching response leaf + (typically `ResultResponseFrame` or + `ExceptionResponseFrame`), then merges the response's + chain manifest into the caller's active chain — variable + mutations and reset signals both ride back-propagation. + + :returns: + The next task result from the worker. 
+ """ + try: + response = await anext(self._iter) + # Up-front protocol-shape check. ``Frame.from_protobuf`` + # raises ``ValueError`` for an unset payload oneof; surface + # that as ``UnexpectedResponse`` (the caller-side + # protocol-violation channel) before any decode so the + # serializer never sees malformed bytes. + kind = response.WhichOneof("payload") + if kind not in ("result", "exception"): + raise UnexpectedResponse( + f"Expected 'result' or 'exception' response, received '{kind}'" + ) + # Decode the response envelope. A payload deserialization + # failure (cloudpickle/pickle error, missing class on the + # caller's path, version skew, etc.) propagates with its + # original type — the load balancer treats anything outside + # RpcError as caller-fault. + frame = Frame.from_protobuf(response, serializer=self._serializer) + + if isinstance(frame, ResultResponseFrame): + # A strict-mode chain-manifest decode failure is fatal + # on a result frame — the value can't be trusted + # alongside a chain manifest that failed to apply — so + # ChainSerializationError from frame.mount() propagates raw + # and the result is dropped. + frame.mount() + return frame.payload + + elif isinstance(frame, ExceptionResponseFrame): + exception = frame.payload + + # Narrow non-Exception payloads up-front so subsequent + # code can assume `exception` is Exception | + # CancelledError (raisable and chainable). Catches both + # non-BaseException payloads (dict, string, arbitrary + # objects from a buggy/malicious worker) and + # non-Exception BaseException subclasses + # (KeyboardInterrupt, SystemExit, user-defined + # BaseException). Process-level signals cannot be + # smuggled across the wire; the original is preserved + # on UnexpectedResponse.__context__ when it's a + # BaseException (Python rejects non-BaseException as + # __context__). 
+ if not isinstance(exception, (Exception, asyncio.CancelledError)): + original = exception + exception = UnexpectedResponse( + "Worker shipped a non-Exception payload in " + f"Response.exception: {type(original).__name__}" + ) + if isinstance(original, BaseException): + exception.__context__ = original + # Replace the frame's payload so the decode-error + # chaining walks the validated exception's + # ``__context__`` rather than the raw worker-shipped + # non-Exception payload. + frame.payload = exception + + # Mount the chain manifest. ``ExceptionResponseFrame`` + # carries ``_chain_exceptions`` so a + # deferred ChainSerializationError is silently chained onto + # the payload exception's ``__context__`` walked to the + # bottom — no raise propagates here. The two failures + # are independent (the worker raised X for routine + # reasons; the chain manifest failed to apply for + # serializer reasons), so neither caused the other and + # ``__cause__`` would overclaim. + frame.mount() + + # Worker-shipped ``CancelledError`` propagates as-is; + # this site does not bump ``current_task().cancelling()`` + # to mirror stdlib's local-cancel state shape. That + # would deviate from ``await task`` semantics: stdlib + # does not bump the awaiter's cancelling count when the + # awaitee raises CancelledError. A caller that catches + # ``CancelledError`` and continues to ``await`` + # something else (a recovery path) would be + # re-interrupted at the next checkpoint until it called + # ``current_task().uncancel()`` — a step the wool-naive + # caller cannot reasonably know to add. 
+ raise exception + + else: # pragma: no cover — guarded by the up-front kind check + raise UnexpectedResponse( + f"Expected 'result' or 'exception' response, " + f"received {type(frame).__name__}" + ) + except BaseException: try: - raised = wool.__serializer__.loads(response.nack.exception.dump) + self._call.cancel() except Exception: - raised = None - # Narrowed to ``Exception`` to match - # ``Rejected.original``'s typed contract (worker - # constructs ``Rejected`` only from - # ``except Exception``). A worker that ships a - # non-``Exception`` ``BaseException`` would be a worker - # bug; degrade to :class:`RpcError` rather than smuggle - # cancel/interrupt signals across the wire. A malformed - # dump (loads raises) lands here too. - if isinstance(raised, Exception): - raise raised from None - raise RpcError(details="Task rejected by worker (malformed Nack payload)") - if not response.HasField("ack"): - raise UnexpectedResponse( - f"Expected 'ack' response, received '{response.WhichOneof('payload')}'" - ) + pass + raise - async def _execute( - self, - task: Task, - key: _PoolKey, - timeout: float | None, - ) -> AsyncGenerator[protocol.Message | None, None]: - """Async generator that owns the full dispatch lifecycle. + async def aclose(self) -> None: # pragma: no cover — no production caller + """Close the async generator and cancel the underlying gRPC call. - Pins the channel pool ref, the channel-concurrency permit, - and the gRPC call's cancel hook on a single - :class:`AsyncExitStack`. - Completes the handshake before yielding to the caller; any - exit path — setup failure, priming-yield ``GeneratorExit``, - mid-stream exception, natural end of stream — unwinds the - stack and releases every resource exactly once. + This method provides proper cleanup for async generators decorated + with @routine. When called, it cancels the gRPC stream to the worker, + which triggers cleanup on the worker side. 
- The stack unwind is driven through :func:`_complete_teardown` - so the release callbacks run to completion even when the - caller task is mid-cancellation — otherwise a pending - ``CancelledError`` could pre-empt ``AsyncExitStack.__aexit__`` - and leak a pooled channel reference. + Implements the async generator protocol's aclose() method to match + native Python async generator behavior. This method is idempotent + and can be safely called multiple times. """ - stack = AsyncExitStack() + if self._closed: + return + + self._closed = True try: - channel = await stack.enter_async_context(_channel_pool.get(key)) - wire_task = task.to_protobuf() + self._call.cancel() + except Exception: + pass - # Acquire the concurrency permit and complete the - # handshake under the dispatch-phase timeout. - # ``Semaphore.acquire()`` is cancel-safe: if cancelled - # before it returns, no permit is taken; if it returns, - # the next line (sync) registers the release. The two- - # step "acquire then register" is therefore atomic with - # respect to cancellation. - async with asyncio.timeout(timeout): - await channel.semaphore.acquire() - stack.callback(channel.semaphore.release) + async def asend(self, value): + """Send a value into the remote async generator. - call: _DispatchCall = channel.stub.dispatch() + Serializes *value*, writes it as a ``Message`` frame to the + bidirectional stream, and returns the next yielded result. - # Cancel the in-flight gRPC call on any unwind. - # Swallow ``Exception`` (not ``BaseException``) so - # a buggy stub's ``cancel()`` does not replace - # whatever exception is unwinding the stack; - # cleanup-during-cleanup. - def _safe_cancel() -> None: - try: - call.cancel() - except Exception: - pass + :param value: + The value to send into the generator. + :returns: + The next yielded value from the remote generator. + :raises StopAsyncIteration: + When the remote generator is exhausted or the stream + has been closed. 
+ :raises RuntimeError: + If another iteration is already in progress. + """ + if self._closed: # pragma: no cover + raise StopAsyncIteration + return await self._send_and_read( + SendRequestFrame.for_send(value, serializer=self._serializer), + method_name="asend", + ) - stack.callback(_safe_cancel) - await self._handshake(call, wire_task) + async def athrow(self, typ, val=None, tb=None): + """Throw an exception into the remote async generator. - # Priming yield. All resources are pinned on the stack - # and the worker has acknowledged the task. The - # caller's ``__anext__`` prime returns here. - yield + Serializes the exception and sends it as a ``Message`` frame. + The remote generator receives the exception via ``athrow()`` + and may handle or propagate it. - stream = _DispatchStream(call, task) - try: - sent = None - result = await anext(stream) - while True: - try: - sent = yield result - except GeneratorExit: - # Short-circuit before ``except - # BaseException`` below catches and - # ``athrow``s the GeneratorExit into the - # inner stream. Cancellation of the - # in-flight gRPC call happens via the - # AsyncExitStack's ``_safe_cancel`` - # callback on stack unwind — single - # resource ownership, single cancel. - return - except BaseException as exc: - result = await stream.athrow(type(exc), exc) - else: - result = await stream.asend(sent) - except StopAsyncIteration: - return - # Other abnormal exits (``asyncio.CancelledError``, - # routine exceptions, mid-stream gRPC errors) propagate - # uncaught; the AsyncExitStack's ``_safe_cancel`` - # callback fires on unwind to cancel the in-flight - # gRPC call. - finally: - # Shield the stack unwind from caller cancellation so - # every pooled-resource release callback runs — see - # :func:`_complete_teardown`. 
``aclose()`` drives each - # registered ``__aexit__`` with no exception info; that - # is equivalent to the implicit ``async with`` exit only - # because every context manager on this stack is - # exception-agnostic. - await _complete_teardown(stack.aclose()) + :param typ: + The exception type or instance to throw. + :param val: + The exception value (if *typ* is a type). + :param tb: + The exception traceback. + :returns: + The next yielded value from the remote generator. + :raises StopAsyncIteration: + When the remote generator is exhausted or the stream + has been closed. + :raises RuntimeError: + If another iteration is already in progress. + """ + if self._closed: # pragma: no cover + raise StopAsyncIteration + if isinstance(typ, BaseException): # pragma: no cover + exc = typ + elif val is not None: + exc = val + else: # pragma: no cover + exc = typ() + return await self._send_and_read( + ThrowRequestFrame.for_send(exc, serializer=self._serializer), + method_name="athrow", + ) diff --git a/wool/src/wool/runtime/worker/frame.py b/wool/src/wool/runtime/worker/frame.py new file mode 100644 index 00000000..6e3698c1 --- /dev/null +++ b/wool/src/wool/runtime/worker/frame.py @@ -0,0 +1,686 @@ +"""Unified wire-frame abstraction for dispatch. + +Provides the `Frame` hierarchy: deserialised views of the +`~wool.protocol.Request` / `~wool.protocol.Response` envelopes a +dispatch exchanges, with `Frame.from_protobuf` decoding an inbound +envelope to its frame and `Frame.mount` applying the decoded chain +state onto the active chain. + +The hierarchy is three-tiered: the abstract `Frame` base, the +`RequestFrame` / `ResponseFrame` envelope intermediates, and one +concrete frame per protobuf payload oneof. 
+""" + +from __future__ import annotations + +import contextvars +import pickle +import warnings +from dataclasses import dataclass +from dataclasses import field +from typing import TYPE_CHECKING +from typing import Any +from typing import ClassVar +from typing import Generic +from typing import Self +from typing import TypeVar +from typing import cast +from typing import get_args + +import wool +from wool import protocol +from wool.runtime.context.chain import Chain +from wool.runtime.context.exceptions import ChainSerializationError +from wool.runtime.context.exceptions import SerializationWarning +from wool.runtime.context.manifest import ChainManifest +from wool.runtime.routine.task import Task +from wool.runtime.typing import Undefined +from wool.runtime.typing import UndefinedType + +if TYPE_CHECKING: + from wool.runtime.serializer import Serializer + + +T = TypeVar("T") + + +@dataclass +class Frame(Generic[T]): + """Abstract base class for the frames a dispatch exchanges. + + A frame is the deserialised view of one + `~wool.protocol.Request` or `~wool.protocol.Response` + envelope: the application ``payload`` plus an optional + `~wool.runtime.context.manifest.ChainManifest` — or the + `~wool.runtime.context.exceptions.ChainSerializationError` raised + when its wire context failed to decode. + + `mount` is the single "receive this frame" entry point, + collapsing the receive paths into one; its docstring owns the + caller-side / worker-side install and the exception-frame + chaining contract. 
+ """ + + payload: T + + _wire_type_name: ClassVar[str] + _payload_field: ClassVar[str] + + chain_manifest: ChainManifest | ChainSerializationError | None = None + + _wire_chain_manifest: protocol.ChainManifest | None = field( + default=None, init=False, repr=False + ) + _serializer: Serializer | None = field(default=None, init=False, repr=False) + _frame_by_field: ClassVar[dict[str, type[Frame]]] = {} + _chain_exceptions: ClassVar[bool] = False + + def __init_subclass__( + cls, *, field: str | None = None, wire_type: str | None = None, **kwargs: Any + ) -> None: + super().__init_subclass__(**kwargs) + if wire_type is not None: + cls._wire_type_name = wire_type + # A frame whose payload is a BaseException chains a chain-manifest + # decode error onto that payload rather than propagating (see + # `mount`). Derive the flag from the bound payload type so it + # cannot drift from the payload, and a new exception frame gets + # the behaviour for free. + for base in getattr(cls, "__orig_bases__", ()): + args = get_args(base) + if args and isinstance(args[0], type) and issubclass(args[0], BaseException): + cls._chain_exceptions = True + break + # A concrete frame declares its protobuf payload-oneof name via + # ``field=`` and is registered for dispatch. Subclasses that + # pass no ``field`` — the abstract intermediates + # (``RequestFrame`` / ``ResponseFrame``) — are not dispatch + # targets and are skipped. + if field is None: + return + cls._payload_field = field + registered = Frame._frame_by_field.setdefault(field, cls) + if registered is not cls: + raise TypeError( + f"{cls.__qualname__} claims payload field {field!r}, already " + f"registered to {registered.__qualname__}; each concrete frame " + f"must map to a distinct protobuf oneof field." + ) + + @classmethod + def from_protobuf( + cls, + wire: protocol.Request | protocol.Response, + *, + serializer: Serializer | None = None, + ) -> Frame: + """Decode an incoming wire envelope into the matching frame. 
+ + Matches ``wire.WhichOneof("payload")`` against the frame + registry; the matched frame class's `_decode_payload` + deserialises the per-variant sub-message, and the wire context + decodes through `ChainManifest.from_protobuf`. A strict-mode + decode failure is not raised here — it is captured as the + `~wool.runtime.context.exceptions.ChainSerializationError` value + of ``chain_manifest`` and deferred to `mount`, which raises it or + chains it onto an exception payload. + + :raises ValueError: + If the payload oneof is unset. + """ + + if serializer is None: + serializer = wool.__serializer__ + field_name = wire.WhichOneof("payload") + if field_name is None: + raise ValueError("wire envelope has no payload set") + # ``_frame_by_field`` is exhaustive over the protobuf payload + # oneof, so a direct lookup is total; a missing key is an + # invariant violation (registry/schema drift) and fails loudly + # via ``KeyError`` rather than being a reachable input error. + frame_cls = cls._frame_by_field[field_name] + payload = frame_cls._decode_payload(wire, serializer=serializer) + chain_manifest: ChainManifest | ChainSerializationError | None + if wire.HasField("context"): + try: + chain_manifest = ChainManifest.from_protobuf( + wire.context, serializer=serializer + ) + except ChainSerializationError as decode_err: + # Defer the strict-mode failure: carry it as the manifest + # value so `mount` raises it (or chains it onto an + # exception payload) instead of preempting the payload here. + chain_manifest = decode_err + else: + chain_manifest = None + return frame_cls(payload=payload, chain_manifest=chain_manifest) + + def to_protobuf(self) -> protocol.Request | protocol.Response: + """Encode this frame as its wire envelope. + + Resolves the wire envelope class via the intermediate's + `_wire_type_name`, builds it with the frame's + `_encode_payload` kwargs, then copies in + `_wire_chain_manifest` if non-None. 
+ + Lazy-wire-frame: the optional ``context`` field is omitted + entirely when `_wire_chain_manifest` is ``None`` (the active + chain was unarmed at `for_send` time). + """ + + serializer = self._serializer or wool.__serializer__ + wire_cls = getattr(protocol, self._wire_type_name) + wire = wire_cls(**self._encode_payload(serializer=serializer)) + if self._wire_chain_manifest is not None: + wire.context.CopyFrom(self._wire_chain_manifest) + return wire + + @classmethod + def _decode_payload(cls, wire: Any, *, serializer: Serializer) -> Any: + """Deserialise this frame's payload sub-message from ``wire``. + + Override on each concrete frame. The default raises so an + abstract / mistyped frame surfaces loudly at decode time. + """ + raise NotImplementedError(f"{cls.__name__}._decode_payload must be overridden") + + def _encode_payload(self, *, serializer: Serializer) -> dict[str, Any]: + """Return a kwargs dict for the wire-envelope constructor. + + The dict has exactly one key — the frame's `_payload_field` + — mapped to the encoded sub-message. Override on each + concrete frame. + """ + raise NotImplementedError( + f"{type(self).__name__}._encode_payload must be overridden" + ) + + @classmethod + def for_send( + cls, + payload: T = None, # type: ignore[assignment] + *, + serializer: Serializer | None = None, + ) -> Self: + """Build an outbound boundary frame ready for `to_protobuf`. + + Boundary frames carry no chain manifest; mid-stream frames mix in + `_ChainManifestFrame` to capture or accept one. + """ + if serializer is None: + serializer = wool.__serializer__ + frame = cls(payload=payload, chain_manifest=None) + frame._serializer = serializer + return frame + + def mount(self, ctx: contextvars.Context | None = None) -> None: + """Apply the frame's chain manifest onto the active chain. 
+ + Gates in order: no-op when ``chain_manifest`` is None; raise the + deferred `ChainSerializationError` first so an all-empty + manifest carrying a strict-mode decode failure still + surfaces; then short-circuit when the manifest has no + observable state. Otherwise route through + `~wool.runtime.context.chain.Chain.from_manifest`: + + * Caller-side mount — ``ctx`` is ``None`` (the default). + The install runs in the current + `contextvars.Context`, stamps the calling thread / + task as owners, ensures Wool's task factory is on the + loop, and merges with the active `Chain` when + armed. + * Worker-side mount — ``ctx`` is the chain's cached + `contextvars.Context`. The install runs inside + ``ctx.run(...)`` so backing writes land in the chain's + context, with no task-owner stamp (the cached context is + driven by a succession of step-tasks, so the chain is owned + thread-wise rather than by any one task). The task factory is + still ensured — a no-op on the already-equipped worker loop, + but the site that surfaces a displaced factory; the per-step + driver separately bypasses the factory when it constructs + step tasks (see + `~wool.runtime.worker.session._create_step_task`). + + Frames whose payload is an exception + (`ThrowRequestFrame`, `NackResponseFrame`, + `ExceptionResponseFrame`) carry `_chain_exceptions`. Any + `ChainSerializationError` raised by the mount is appended + onto `payload`'s ``__context__`` chain rather than + propagating — the caller (or routine, in the throw case) + still gets to handle the primary exception, with the decode + failure surfaced as a chained cause. + + :param ctx: + Optional `contextvars.Context` to install into. + When provided, the install runs via ``ctx.run(...)``. + Pass ``None`` (the default) for caller-side mounts + running in the current context. + + :raises ChainSerializationError: + When ``chain_manifest`` is a `ChainSerializationError` on a + frame that does not chain onto its payload. 
+ """ + + try: + self._do_mount(ctx) + except ChainSerializationError as decode_err: + if not self._chain_exceptions: + raise + # The ``_chain_exceptions`` flag is only set on + # frames whose payload is a BaseException + # (ThrowRequestFrame, NackResponseFrame, + # ExceptionResponseFrame); pyright can't narrow ``T`` + # through the ClassVar correlation. + payload: BaseException = self.payload # type: ignore[assignment] + # Cycle guard on the ``__context__`` walk. An + # adversarial / malformed payload whose ``__context__`` + # eventually points back into the chain would spin this + # walk indefinitely; the visited-set short-circuits + # the first time we'd revisit a node and attaches the + # decode error there. + node = payload + seen: set[int] = {id(node)} + while node.__context__ is not None and id(node.__context__) not in seen: + node = node.__context__ + seen.add(id(node)) + node.__context__ = decode_err + + def _do_mount(self, ctx: contextvars.Context | None) -> None: + """The actual manifest-apply logic, called from `mount`.""" + if self.chain_manifest is None: + return + if isinstance(self.chain_manifest, ChainSerializationError): + # A deferred strict-mode wire-decode failure captured by + # `from_protobuf`. Raising routes it through `mount`'s + # raise-or-chain handler. + raise self.chain_manifest + # Short-circuit a manifest that decoded successfully but binds + # no vars and records no resets — nothing to install. + if not (self.chain_manifest.vars or self.chain_manifest.resets): + return + + if ctx is None: + # Caller-side: stamp the owning task, ensure the factory is + # installed, merge with the current Chain (``None`` falls + # through to the unarmed-receiver fresh-install path inside + # `Chain.from_manifest`). + Chain.from_manifest( + self.chain_manifest, + owned=True, + merge_with=wool.__chain__.get(None), + ) + else: + # Worker-side: install inside the chain's cached + # contextvars.Context. 
The cached context is driven by + # successive step-tasks, so the chain is armed + # task-agnostically (``owned=False``); the factory is still + # ensured (a no-op on the worker loop, but the displacement + # tripwire). + ctx.run( + Chain.from_manifest, + self.chain_manifest, + owned=False, + ) + + +@dataclass +class RequestFrame(Frame[T], wire_type="Request"): + """Abstract intermediate for worker-bound dispatch frames. + + Each concrete request frame inherits this intermediate's + ``_wire_type_name = "Request"`` and declares its own + `_payload_field`. + """ + + def to_protobuf(self) -> protocol.Request: + """Encode this request frame as a `protocol.Request`. + + Narrows `Frame.to_protobuf`'s ``Request | Response`` + union return to ``Request`` so call sites that already know + they hold a ``RequestFrame`` (e.g., ``call.write(...)``) can + consume the wire envelope without an extra ``cast`` / + ``isinstance`` step. + """ + return cast(protocol.Request, super().to_protobuf()) + + +@dataclass +class ResponseFrame(Frame[T], wire_type="Response"): + """Abstract intermediate for caller-bound dispatch frames. + + Each concrete response frame inherits this intermediate's + ``_wire_type_name = "Response"`` and declares its own + `_payload_field`. + """ + + def to_protobuf(self) -> protocol.Response: + """Encode this response frame as a `protocol.Response`. + + Narrows `Frame.to_protobuf`'s ``Request | Response`` + union return to ``Response`` so async-generator yield sites + (e.g., ``yield response.to_protobuf()`` into a stream typed + ``AsyncGenerator[Response, ...]``) type-check without an + extra ``cast`` / ``isinstance`` step. + """ + return cast(protocol.Response, super().to_protobuf()) + + +@dataclass +class _ChainManifestFrame(Frame[T]): + """Mixin for mid-stream frames that carry a chain manifest. 
+ + Extends `for_send` with `wire_chain_manifest`: auto-captures the active + chain when omitted, or accepts an explicit chain manifest (``None`` + suppresses the field). The worker's result / terminal-exception + paths pass a chain manifest captured inside the routine's cached + `contextvars.Context`; auto-capturing on the worker's main loop + would observe the wrong scope. + """ + + @classmethod + def for_send( + cls, + payload: T = None, # type: ignore[assignment] + *, + serializer: Serializer | None = None, + wire_chain_manifest: protocol.ChainManifest | None | UndefinedType = Undefined, + ) -> Self: + """Build an outbound frame carrying the active (or given) chain manifest.""" + if serializer is None: + serializer = wool.__serializer__ + frame = super().for_send(payload, serializer=serializer) + if wire_chain_manifest is Undefined: + try: + wire_chain_manifest = ( + wool.__chain__.get().to_manifest().to_protobuf(serializer=serializer) + ) + except LookupError: + wire_chain_manifest = None + frame._wire_chain_manifest = wire_chain_manifest + return frame + + +@dataclass +class TaskRequestFrame(RequestFrame[Task], field="task"): + """Initial dispatch request carrying a `wool.Task`.""" + + @classmethod + def _decode_payload(cls, wire: protocol.Request, *, serializer: Serializer) -> Any: + + return Task.from_protobuf(wire.task) + + def _encode_payload(self, *, serializer: Serializer) -> dict[str, Any]: + return {"task": self.payload.to_protobuf()} + + +@dataclass +class NextRequestFrame(RequestFrame[None], _ChainManifestFrame[None], field="next"): + """Mid-stream pull request (no payload).""" + + @classmethod + def _decode_payload(cls, wire: protocol.Request, *, serializer: Serializer) -> Any: + return None + + def _encode_payload(self, *, serializer: Serializer) -> dict[str, Any]: + + return {"next": protocol.Void()} + + +@dataclass +class SendRequestFrame(RequestFrame[Any], _ChainManifestFrame[Any], field="send"): + """Mid-stream ``asend`` request carrying an 
arbitrary payload.""" + + @classmethod + def _decode_payload(cls, wire: protocol.Request, *, serializer: Serializer) -> Any: + return serializer.loads(wire.send.dump) + + def _encode_payload(self, *, serializer: Serializer) -> dict[str, Any]: + + return {"send": protocol.Message(dump=serializer.dumps(self.payload))} + + +@dataclass +class ThrowRequestFrame( + RequestFrame[BaseException], _ChainManifestFrame[BaseException], field="throw" +): + """Mid-stream ``athrow`` request carrying an exception. + + Its `BaseException` payload makes `mount` chain a chain-manifest + decode error onto the throw payload's ``__context__`` rather than + abandoning the throw entirely. + """ + + @classmethod + def _decode_payload(cls, wire: protocol.Request, *, serializer: Serializer) -> Any: + return serializer.loads(wire.throw.dump) + + def _encode_payload(self, *, serializer: Serializer) -> dict[str, Any]: + + return {"throw": protocol.Message(dump=serializer.dumps(self.payload))} + + +@dataclass +class AckResponseFrame(ResponseFrame[None], field="ack"): + """Worker's protocol-version ack for the initial dispatch frame.""" + + @classmethod + def _decode_payload(cls, wire: protocol.Response, *, serializer: Serializer) -> Any: + return None + + def _encode_payload(self, *, serializer: Serializer) -> dict[str, Any]: + + return {"ack": protocol.Ack(version=protocol.__version__)} + + +@dataclass +class NackResponseFrame(ResponseFrame[BaseException], field="nack"): + """Dispatch-handler rejection response carrying an exception. + + Its `BaseException` payload makes `mount` chain a chain-manifest + decode error onto the payload's ``__context__``. + The payload is serialised via `_safely_serialize_exception` + so an un-picklable exception instance degrades to a + type-preserving fallback. 
+ """ + + @classmethod + def _decode_payload(cls, wire: protocol.Response, *, serializer: Serializer) -> Any: + return serializer.loads(wire.nack.exception.dump) + + def _encode_payload(self, *, serializer: Serializer) -> dict[str, Any]: + + return { + "nack": protocol.Nack( + exception=protocol.Message( + dump=_safely_serialize_exception(serializer, self.payload), + ), + ) + } + + +@dataclass +class ResultResponseFrame(ResponseFrame[Any], _ChainManifestFrame[Any], field="result"): + """Routine yield / return value response.""" + + @classmethod + def _decode_payload(cls, wire: protocol.Response, *, serializer: Serializer) -> Any: + return serializer.loads(wire.result.dump) + + def _encode_payload(self, *, serializer: Serializer) -> dict[str, Any]: + + return {"result": protocol.Message(dump=serializer.dumps(self.payload))} + + +@dataclass +class ExceptionResponseFrame( + ResponseFrame[BaseException], _ChainManifestFrame[BaseException], field="exception" +): + """Routine-raised exception response. + + Its `BaseException` payload makes `mount` chain a chain-manifest + decode error onto the payload's ``__context__``. + The payload is serialised via `_safely_serialize_exception` + so an un-picklable exception instance degrades to a + type-preserving fallback. + """ + + @classmethod + def _decode_payload(cls, wire: protocol.Response, *, serializer: Serializer) -> Any: + return serializer.loads(wire.exception.dump) + + def _encode_payload(self, *, serializer: Serializer) -> dict[str, Any]: + + return { + "exception": protocol.Message( + dump=_safely_serialize_exception(serializer, self.payload), + ) + } + + +def _safely_serialize_exception( + serializer: Serializer, + exc: BaseException, +) -> bytes: + """Serialize ``exc``, preserving the exception class when the + original instance carries un-picklable state. + + Prevents un-picklable exception state from converting a + wool-class failure on the wire into a generic gRPC stream + error on the caller side. 
The serializer is + `CloudpickleSerializer`, which fails on un-picklable + input via `pickle.PickleError`, `TypeError` + for un-picklable C types, `AttributeError` for + un-picklable local closures, or `RecursionError` for + deeply self-referential graphs. + + **Type-preserving fallback.** Stdlib exception pickling + round-trips ``(type, args, __dict__)``, dropping + `__traceback__`, `__cause__`, `__context__`, + and `__suppress_context__` — those four are not part of + the exception's identity. `__notes__` and other + ``__dict__`` attributes survive the round-trip; cloudpickle (with + tblib) preserves `__cause__` so the dispatch handler's + ``raise routine_exc from encode_err`` chaining for strict-mode + chain-encode failures survives the wire. When a routine-level + exception accumulates state that drags an un-picklable C-level + object into the graph (e.g., a worker-thread frame on + `__traceback__` reachable via `__cause__`), the + first ``dumps`` raises but the exception's *class* and *args* + are still picklable on their own. Reconstruct a clean instance + and reship — the caller's ``except RoutineError`` still + matches, mirroring the stdlib pickle contract for exceptions. + Side-channel attachments (notes, the ``__cause__`` chain) are + lost on the reconstructed instance because the fallback builds + a fresh ``cls(*exc.args)``; the wire-survival guarantee holds + only on the primary path. + + If even ``cls(*exc.args)`` cannot be constructed or pickled + (constructor side effects, unpicklable args), demote to a + stdlib `RuntimeError` carrying the original class + name and message — always picklable, so the third ``dumps`` + cannot fail. + """ + try: + return serializer.dumps(exc) + except (pickle.PickleError, TypeError, AttributeError, RecursionError) as primary: + primary_err = primary + try: + cls = type(exc) + clean = cls(*exc.args) + # Walk the ``__cause__`` chain recursively, depth-bounded + # and cycle-guarded. 
Reconstruct each level via the + # ``cls(*args)`` shape so tracebacks (a frequent source of + # cloudpickle failures) don't take the chain down with them. + _MAX_CAUSE_DEPTH = 64 + seen_ids: set[int] = {id(exc)} + head = clean + original = exc.__cause__ + depth = 0 + while ( + original is not None + and id(original) not in seen_ids + and depth < _MAX_CAUSE_DEPTH + ): + seen_ids.add(id(original)) + depth += 1 + try: + cause_cls = type(original) + clean_args = tuple( + type(a)(*a.args) if isinstance(a, BaseException) else a + for a in original.args + ) + clean_cause = cause_cls(*clean_args) + head.__cause__ = clean_cause + head.__suppress_context__ = True + head = clean_cause + original = original.__cause__ + except Exception: + # Reconstruction failed at this level; stop walking. + # The rest of the chain is lost — the fidelity warning + # below reports the loss. + break + # Surface exception-fidelity loss via a typed warning. The + # warning is also attached to the reconstructed exception's + # ``__context__`` so the caller's diagnostic stack carries + # the fidelity signal alongside the (degraded) exception + # they actually catch. + fidelity_warning = SerializationWarning( + f"Exception {type(exc).__qualname__!r} required fallback " + f"reconstruction; some side-channel state was lost.", + cause=primary_err, + original_type=type(exc), + ) + try: + warnings.warn(fidelity_warning, stacklevel=2) + except SerializationWarning: + # Strict-mode: the warning was promoted to an error. The + # contract is that fidelity loss is non-fatal — even + # under strict mode the caller's primary signal (the + # routine exception) must reach them. Swallow and continue + # to ship the degraded copy with the warning attached + # via ``__context__`` below. + pass + # Attach the warning via ``__context__`` so the diagnostic + # rides on the wire alongside the reconstructed exception + # without replacing the caller's primary catchable type. 
Only + # attach if no ``__context__`` was already set; preserves + # whatever the caller's exception machinery established. + if clean.__context__ is None: + clean.__context__ = fidelity_warning + try: + return serializer.dumps(clean) + except Exception: + # The attached fidelity warning's ``cause`` (the primary + # pickle failure) — or a reconstructed ``__cause__`` level — + # can itself carry the same un-picklable state that broke the + # primary dumps (e.g., an env-dependent worker-thread + # traceback). Strip the side-channel attachments and ship the + # bare reconstructed exception so the caller still catches the + # original type rather than a demoted ``RuntimeError``: the + # class and args are picklable on their own — the primary + # failure was the attachments, not the identity. This also + # makes the shipped type deterministic across environments + # where traceback picklability differs. + clean.__cause__ = None + clean.__context__ = None + clean.__suppress_context__ = True + return serializer.dumps(clean) + except Exception: + # Honor the docstring's "third dumps cannot fail" promise: + # any reconstruction failure (over-eager ``__init__`` + # validation, custom ``__new__``, un-picklable args, etc.) + # — not just the narrow pickle/type/recursion set — must + # fall through to the always-picklable ``RuntimeError`` + # demotion. ``KeyboardInterrupt``/``SystemExit`` still + # propagate since they are ``BaseException``-only. + cls_name = type(exc).__name__ + # Guard the f-string. ``__str__`` is user-overridable and + # can raise (touches per-instance state, delegates to a + # buggy ``__repr__``, etc.). The bare class name is a + # string attribute lookup and cannot raise — last resort + # so the safety net always succeeds. 
+ try: + message = f"{cls_name}: {exc!s}" + except Exception: + message = cls_name + return serializer.dumps(RuntimeError(message)) diff --git a/wool/src/wool/runtime/worker/interceptor.py b/wool/src/wool/runtime/worker/interceptor.py index 1e899ad5..9eb0f645 100644 --- a/wool/src/wool/runtime/worker/interceptor.py +++ b/wool/src/wool/runtime/worker/interceptor.py @@ -58,15 +58,18 @@ async def version_checked_handler( request_msg.ParseFromString(first_bytes) task_bytes = request_msg.task.SerializeToString() envelope.ParseFromString(task_bytes) + # F38 — keep parse_version inside the envelope-parse + # try/except so any failure (TypeError for a non-str + # input, etc.) routes through ``FAILED_PRECONDITION`` + # rather than propagating as gRPC ``UNKNOWN``. + client_version = parse_version(envelope.version) + local_version = parse_version(protocol.__version__) except Exception: await context.abort( grpc.StatusCode.FAILED_PRECONDITION, "Failed to parse version envelope", ) - client_version = parse_version(envelope.version) - local_version = parse_version(protocol.__version__) - if client_version is None or local_version is None: await context.abort( grpc.StatusCode.FAILED_PRECONDITION, diff --git a/wool/src/wool/runtime/worker/metadata.py b/wool/src/wool/runtime/worker/metadata.py index 4608b3a0..e564020f 100644 --- a/wool/src/wool/runtime/worker/metadata.py +++ b/wool/src/wool/runtime/worker/metadata.py @@ -7,8 +7,7 @@ import grpc -from wool.protocol import ChannelOptions as ChannelOptionsProtobuf -from wool.protocol import WorkerMetadata as WorkerMetadataProtobuf +from wool import protocol as wire from wool.runtime.worker.base import ChannelOptions @@ -53,7 +52,7 @@ class WorkerMetadata: options: ChannelOptions = field(default_factory=ChannelOptions, hash=False) @classmethod - def from_protobuf(cls, protobuf: WorkerMetadataProtobuf) -> WorkerMetadata: + def from_protobuf(cls, protobuf: wire.WorkerMetadata) -> WorkerMetadata: """Create a WorkerMetadata instance from a 
protobuf message. :param protobuf: @@ -75,14 +74,14 @@ def from_protobuf(cls, protobuf: WorkerMetadataProtobuf) -> WorkerMetadata: options=cls._options_from_protobuf(protobuf), ) - def to_protobuf(self) -> WorkerMetadataProtobuf: + def to_protobuf(self) -> wire.WorkerMetadata: """Convert this WorkerMetadata instance to a protobuf message. :returns: A protobuf WorkerMetadata message containing this instance's data. """ - msg = WorkerMetadataProtobuf( + msg = wire.WorkerMetadata( uid=str(self.uid), address=self.address, pid=self.pid, @@ -92,7 +91,7 @@ def to_protobuf(self) -> WorkerMetadataProtobuf: secure=self.secure, ) msg.connection.CopyFrom( - ChannelOptionsProtobuf( + wire.ChannelOptions( max_receive_message_length=self.options.max_receive_message_length, max_send_message_length=self.options.max_send_message_length, keepalive_time_ms=self.options.keepalive_time_ms, @@ -106,7 +105,7 @@ def to_protobuf(self) -> WorkerMetadataProtobuf: return msg @classmethod - def _options_from_protobuf(cls, protobuf: WorkerMetadataProtobuf) -> ChannelOptions: + def _options_from_protobuf(cls, protobuf: wire.WorkerMetadata) -> ChannelOptions: """Reconstruct ChannelOptions from the protobuf connection config. 
:param protobuf: diff --git a/wool/src/wool/runtime/worker/pool.py b/wool/src/wool/runtime/worker/pool.py index 0a702d69..7663dea5 100644 --- a/wool/src/wool/runtime/worker/pool.py +++ b/wool/src/wool/runtime/worker/pool.py @@ -17,8 +17,8 @@ from typing_extensions import deprecated -from wool.exception import WoolWarning -from wool.runtime.context import install_task_factory +from wool.exceptions import WoolWarning +from wool.runtime.context.factory import install_task_factory from wool.runtime.discovery.base import DiscoveryLike from wool.runtime.discovery.base import DiscoveryPublisherLike from wool.runtime.discovery.base import DiscoverySubscriberLike diff --git a/wool/src/wool/runtime/worker/process.py b/wool/src/wool/runtime/worker/process.py index 0c4cf353..c4a6751e 100644 --- a/wool/src/wool/runtime/worker/process.py +++ b/wool/src/wool/runtime/worker/process.py @@ -23,7 +23,7 @@ import wool from wool import protocol -from wool.runtime.context import install_task_factory +from wool.runtime.context.factory import install_task_factory from wool.runtime.resourcepool import ResourcePool from wool.runtime.worker.auth import CredentialContext from wool.runtime.worker.auth import WorkerCredentials diff --git a/wool/src/wool/runtime/worker/proxy.py b/wool/src/wool/runtime/worker/proxy.py index dfc65253..0f918d1e 100644 --- a/wool/src/wool/runtime/worker/proxy.py +++ b/wool/src/wool/runtime/worker/proxy.py @@ -25,7 +25,7 @@ from packaging.version import Version import wool -from wool.exception import WoolWarning +from wool.exceptions import WoolWarning from wool.runtime.discovery.base import DiscoveryEvent from wool.runtime.discovery.base import DiscoverySubscriberLike from wool.runtime.discovery.local import LocalDiscovery diff --git a/wool/src/wool/runtime/worker/service.py b/wool/src/wool/runtime/worker/service.py index 51a60f1c..bf233714 100644 --- a/wool/src/wool/runtime/worker/service.py +++ b/wool/src/wool/runtime/worker/service.py @@ -19,12 +19,13 @@ import 
wool from wool import protocol -from wool.runtime.context import attached -from wool.runtime.context import install_task_factory +from wool.runtime.context.exceptions import SerializationError +from wool.runtime.context.factory import install_task_factory from wool.runtime.discovery import __subscriber_pool__ from wool.runtime.resourcepool import ResourcePool from wool.runtime.routine.task import Task -from wool.runtime.serializer import Serializer +from wool.runtime.worker.frame import AckResponseFrame +from wool.runtime.worker.frame import NackResponseFrame from wool.runtime.worker.session import DispatchSession from wool.runtime.worker.session import Rejected @@ -32,138 +33,21 @@ _DRAIN_TIMEOUT: Final[float] = 5.0 """Wall-clock timeout in seconds for the multi-generation task drain -in :meth:`WorkerService._destroy_worker_loop`. Generous enough for a +in `WorkerService._destroy_worker_loop`. Generous enough for a normal chain of ``finally``-scheduled cleanup tasks to unwind, short enough not to stall worker-loop teardown; past this timeout the drain gives up, with the daemon-thread reap as the backstop.""" -def _safely_serialize_exception( - serializer: Serializer, - exc: BaseException, -) -> bytes: - """Serialize *exc*, preserving the exception class when the - original instance carries un-picklable state. - - Prevents un-picklable exception state from converting a - wool-class failure on the wire into a generic gRPC stream - error on the caller side. The serializer is - :class:`CloudpickleSerializer`, which fails on un-picklable - input via :class:`pickle.PickleError`, :class:`TypeError` - for un-picklable C types, :class:`AttributeError` for - un-picklable local closures, or :class:`RecursionError` for - deeply self-referential graphs. 
- - **Type-preserving fallback.** Stdlib exception pickling - round-trips ``(type, args, __dict__)``, dropping - :attr:`__traceback__`, :attr:`__cause__`, :attr:`__context__`, - and :attr:`__suppress_context__` — those four are not part of - the exception's identity. :attr:`__notes__` and other - ``__dict__`` attributes (including the wool-private - ``__wool_context_warnings__`` set by the dispatch handler when - strict-mode :class:`wool.ContextDecodeWarning` peers fire on - snapshot encode) survive the round-trip. When a routine-level - exception accumulates state that drags an un-picklable C-level - object into the graph (e.g. a worker-thread frame on - :attr:`__traceback__` reachable via :attr:`__cause__`), the - first ``dumps`` raises but the exception's *class* and *args* - are still picklable on their own. Reconstruct a clean instance - and reship — the caller's ``except RoutineError`` still - matches, mirroring the stdlib pickle contract for exceptions. - Side-channel attachments (notes, ``__wool_context_warnings__``) - are lost on the reconstructed instance because the fallback - builds a fresh ``cls(*exc.args)``; the wire-survival guarantee - holds only on the primary path. - - If even ``cls(*exc.args)`` cannot be constructed or pickled - (constructor side effects, unpicklable args), demote to a - stdlib :class:`RuntimeError` carrying the original class - name and message — always picklable, so the third ``dumps`` - cannot fail. - """ - try: - return serializer.dumps(exc) - except (pickle.PickleError, TypeError, AttributeError, RecursionError): - pass - try: - cls = type(exc) - clean = cls(*exc.args) - return serializer.dumps(clean) - except Exception: - # Honor the docstring's "third dumps cannot fail" promise: - # any reconstruction failure (over-eager ``__init__`` - # validation, custom ``__new__``, un-picklable args, etc.) - # — not just the narrow pickle/type/recursion set — must - # fall through to the always-picklable ``RuntimeError`` - # demotion. 
``KeyboardInterrupt``/``SystemExit`` still - # propagate since they are ``BaseException``-only. - cls_name = type(exc).__name__ - # Guard the f-string. ``__str__`` is user-overridable and - # can raise (touches per-instance state, delegates to a - # buggy ``__repr__``, etc.). The bare class name is a - # string attribute lookup and cannot raise — last resort - # so the safety net always succeeds. - try: - message = f"{cls_name}: {exc!s}" - except Exception: - message = cls_name - return serializer.dumps(RuntimeError(message)) - - -def _attach_strict_mode_warnings(exc: BaseException, encode_exc: BaseException) -> None: - """Attach strict-mode :class:`wool.ContextDecodeWarning` peers to a - routine exception. - - When strict mode promotes :class:`wool.ContextDecodeWarning` to an - exception, the post-run snapshot encode (``session.context.to_protobuf``) - raises a :class:`BaseExceptionGroup` of warning peers. Attach the peers - to *exc* via PEP 678 ``__notes__`` (visible in tracebacks) and a - ``__wool_context_warnings__`` attribute (programmatic access). The - routine exception's type is preserved, so the caller's existing - ``except RoutineError:`` clause continues to catch — no migration to - ``except*`` or ``except ExceptionGroup`` required. - - Both attachment paths are best-effort. ``add_note`` may raise on a - subclass with an overridden ``__setattr__`` or an unusual C-level - storage policy; ``setattr`` may raise ``AttributeError`` on frozen - dataclass exceptions or slotted layouts without ``__dict__``. Either - is swallowed so the routine's primary signal still ships — the - warnings simply will not ride on a type that rejects them. - - Lifted out of the dispatch handler's terminal-exception clause as a - flat sync helper so coverage tooling on Python 3.11 (``sys.settrace``) - can track it; the deeply nested original lived inside an async - generator's nested except arm and was opaque to pre-PEP-669 tracing. 
- - :param exc: - The routine's primary exception, annotated in place. - :param encode_exc: - The encode-time failure carrying the warning peers. - """ - if isinstance(encode_exc, BaseExceptionGroup): - context_warnings: list[BaseException] = list(encode_exc.exceptions) - else: - context_warnings = [encode_exc] - try: - for w in context_warnings: - exc.add_note(f"wool context warning: {w}") - except (AttributeError, TypeError): - pass - try: - setattr(exc, "__wool_context_warnings__", context_warnings) - except AttributeError: - pass - - # public @dataclass(frozen=True) class BackpressureContext: - """Snapshot of worker state provided to backpressure hooks. + """Worker state provided to backpressure hooks. :param active_task_count: Number of tasks currently executing on this worker. :param task: - The incoming :class:`~wool.runtime.routine.task.Task` being + The incoming `~wool.runtime.routine.task.Task` being evaluated for admission. """ @@ -186,13 +70,6 @@ class BackpressureLike(Protocol): Pass ``None`` (the default) to accept all tasks unconditionally. - The hook runs after the caller's wire-shipped ContextVar snapshot - is applied to the handler's context, so a hook that reads a - :class:`wool.ContextVar` (e.g., a tenant id) observes the caller's - value for that dispatch. This enables tenant- or request-scoped - admission decisions without plumbing values through the - :class:`BackpressureContext` explicitly. - Both sync and async implementations are supported:: def sync_hook(ctx: BackpressureContext) -> bool: @@ -207,8 +84,8 @@ def __call__(self, ctx: BackpressureContext) -> bool | Awaitable[bool]: """Evaluate whether to reject the incoming task. :param ctx: - Snapshot of the worker's current state and the incoming - task. + The worker's current dispatch state (active task count, + the incoming task). :returns: ``True`` to reject the task, ``False`` to accept it. 
""" @@ -216,13 +93,13 @@ def __call__(self, ctx: BackpressureContext) -> bool | Awaitable[bool]: class _ReadOnlyEvent: - """A read-only wrapper around :class:`asyncio.Event`. + """A read-only wrapper around `asyncio.Event`. Provides access to check if an event is set and wait for it to be set, but prevents external code from setting or clearing the event. :param event: - The underlying :class:`asyncio.Event` to wrap. + The underlying `asyncio.Event` to wrap. """ def __init__(self, event: asyncio.Event): @@ -241,6 +118,7 @@ async def wait(self) -> None: await self._event.wait() +# public class WorkerService(protocol.WorkerServicer): """gRPC service for task execution. @@ -249,12 +127,12 @@ class WorkerService(protocol.WorkerServicer): results back to the client. Handles graceful shutdown by rejecting new tasks while allowing - in-flight tasks to complete. Exposes :attr:`stopping` and - :attr:`stopped` events for lifecycle monitoring. + in-flight tasks to complete. Exposes `stopping` and + `stopped` events for lifecycle monitoring. :param backpressure: Optional admission control hook. See - :class:`BackpressureLike`. ``None`` (default) accepts all + `BackpressureLike`. ``None`` (default) accepts all tasks unconditionally. """ @@ -268,7 +146,15 @@ def __init__(self, *, backpressure: BackpressureLike | None = None): self._stopping = asyncio.Event() self._docket = set() self._backpressure = backpressure - # Budget for the loop-teardown join, set by :meth:`_stop` + # Strong refs to live ``session.cancel()`` tasks scheduled + # from ``_propagate_cancel_on_done`` (a gRPC-internal-thread + # callback that hops the cancel onto the main loop). Without + # a strong ref the task is eligible for GC mid-flight, + # causing "Task was destroyed but it is pending" and "exception + # was never retrieved" hazards. The done-callback discards + # the task once the cancel propagation has completed. 
+ self._cancel_propagators: set[asyncio.Task[None]] = set() + # Budget for the loop-teardown join, set by `_stop` # from the StopRequest's ``timeout``. ``0`` means "do not # synchronously wait" (worker thread closes its own loop # after ``run_forever`` returns; ``daemon=True`` reaps it @@ -291,7 +177,7 @@ def stopping(self) -> _ReadOnlyEvent: """Read-only event signaling that the service is stopping. :returns: - A :class:`_ReadOnlyEvent`. + A `_ReadOnlyEvent`. """ return _ReadOnlyEvent(self._stopping) @@ -300,7 +186,7 @@ def stopped(self) -> _ReadOnlyEvent: """Read-only event signaling that the service has stopped. :returns: - A :class:`_ReadOnlyEvent`. + A `_ReadOnlyEvent`. """ return _ReadOnlyEvent(self._stopped) @@ -311,35 +197,34 @@ async def dispatch( ) -> AsyncIterator[protocol.Response]: """Execute a task in the current event loop. - Reads the first :class:`~wool.protocol.Request` from - the bidirectional stream to obtain the :class:`Task`, then + Reads the first `~wool.protocol.Request` from + the bidirectional stream to obtain the `Task`, then schedules it for execution. For async generators, subsequent ``Message`` frames are forwarded into the generator via ``asend()``. - **Context serialization failures (worker-side).** - Wire context is **ancillary state** under wool's protocol - contract: a failure to serialize the post-run snapshot or to - deserialize an incoming context (initial request or + **Chain serialization failures (worker-side).** + The chain manifest is **ancillary state** under wool's protocol + contract: a failure to serialize the post-run chain manifest or to + decode the incoming chain manifest (initial request or mid-stream frame) is non-fatal in non-strict mode but fatal in strict mode. Both modes emit a - :class:`wool.ContextDecodeWarning` for each failure. 
- - *Non-strict mode (default).* The routine still runs — with a - fresh empty context as fallback when initial-frame - deserialization fails — and the back-propagated snapshot is - replaced with an empty context when post-run serialization - fails. A snapshot serialization failure that coincides with - a routine exception rides back as peers in a - :class:`BaseExceptionGroup` (extending an existing group - when the routine exception is already grouped) rather than - as nested causes, so the caller observes both signals at - the same level. + `wool.SerializationWarning` for each failure. + + *Non-strict mode (default).* The routine still runs — when + initial-frame deserialization fails, each unreadable entry is + dropped with a warning and the routine runs under whatever + partial chain decoded (an entirely unreadable frame leaves + the worker chain unarmed) — and the back-propagated chain manifest + is replaced with an empty chain manifest when post-run serialization + fails. Caller-side per-var warnings emit through the standard + warnings machinery; there is no aggregated error to ride + back alongside the routine's signal. *Strict mode* (e.g., - ``PYTHONWARNINGS=error::wool.ContextDecodeWarning``). The + ``PYTHONWARNINGS=error::wool.SerializationWarning``). The warning promotes to an exception. The dispatch handler - catches :class:`wool.ContextDecodeWarning` raised before the + catches `wool.SerializationWarning` raised before the routine starts and ships it via the routine-exception channel — the routine does not run — so the caller catches the same warning class symmetrically with caller-side strict @@ -352,12 +237,12 @@ async def dispatch( terminal signals ride on ``Response.exception``. The dispatch FSM is ``Ack? (Result* (Exception | ε)) | Nack``. Code that emits a Nack after an Ack would violate the caller-side - consumer contract in :class:`WorkerConnection`. + consumer contract in `WorkerConnection`. 
:param request_iterator: The incoming bidirectional request stream. :param context: - The :class:`grpc.aio.ServicerContext` for this request. + The `grpc.aio.ServicerContext` for this request. :yields: One of four terminal shapes per dispatch. @@ -365,7 +250,7 @@ async def dispatch( ``nack`` payload carries the parse-time failure (with ``exception`` set to the dumped original cause). No preceding ``Ack``. Triggered by malformed task id, - strict-mode :class:`wool.ContextDecodeWarning`, + strict-mode `wool.SerializationWarning`, cloudpickle errors on the task callable, ImportError on a missing module, or non-async callable. @@ -377,17 +262,17 @@ async def dispatch( optionally followed by zero or more ``result`` Responses, then a single terminal ``exception`` Response carrying the dumped routine / handler-level - failure plus a ``context`` snapshot. + failure plus a ``context`` frame. **Routine-failure-with-encode-failure variant** — same as the routine-failure path except the terminal ``Response`` drops the ``context`` field. The post-run - snapshot itself failed to serialize (strict-mode - :class:`wool.ContextDecodeWarning`); the encode peers - are attached to the routine exception via PEP 678 - ``__notes__`` and a ``__wool_context_warnings__`` - attribute, so the caller-visible exception class is - preserved. + chain manifest itself failed to serialize (strict-mode + `wool.ChainSerializationError` aggregating per-var + warnings); the encode error rides on the routine + exception as ``__cause__`` via ``raise from`` chaining, + so the caller-visible exception class is preserved and + the encode error remains visible in the traceback. 
**Operator pre-emption.** A worker-side graceful shutdown cancels in-flight dispatches and the underlying @@ -407,10 +292,10 @@ async def dispatch( async with self._loop_pool.get("worker") as (loop, _): # Instantiate before ``async with`` so a ``Rejected`` raised - # from :meth:`DispatchSession.__aenter__` (parse-phase failure) + # from `DispatchSession.__aenter__` (parse-phase failure) # leaves ``session`` bound for the ``except Rejected`` arm's # access to ``session.serializer`` (always ``wool.__serializer__``, - # cloudpickle, set in :meth:`__init__`). + # cloudpickle, set in `__init__`). session = DispatchSession(request_iterator, loop) # Register a deterministic cancellation propagation hook @@ -425,10 +310,10 @@ async def dispatch( # thread when the RPC reaches a terminal state; on a # client-side cancellation it fires with # ``context.cancelled()`` == True, and we schedule - # :meth:`DispatchSession.cancel` on the main loop via + # `DispatchSession.cancel` on the main loop via # ``call_soon_threadsafe`` so the routine task is # cancelled cross-loop on the same path - # ``WorkerService._cancel`` uses for graceful shutdown. + # ``WorkerService._preempt`` uses for graceful shutdown. # ``DispatchSession.cancel`` is idempotent, so the # dispatch handler's own except-clause cancel is a no-op # if this callback raced ahead. Avoids the watcher-task @@ -441,10 +326,19 @@ async def dispatch( def _propagate_cancel_on_done(ctx) -> None: if not ctx.cancelled(): return + + def _spawn() -> None: + # Hold a strong ref to the cancel-propagator task + # so it is not GC'd mid-flight (which would emit + # "Task was destroyed but it is pending" and + # "exception was never retrieved" warnings). The + # done-callback discards once the task completes. 
+ task = main_loop.create_task(session.cancel()) + self._cancel_propagators.add(task) + task.add_done_callback(self._cancel_propagators.discard) + try: - main_loop.call_soon_threadsafe( - lambda: main_loop.create_task(session.cancel()) - ) + main_loop.call_soon_threadsafe(_spawn) except RuntimeError: # Main loop already closed (graceful shutdown # raced us). Nothing to propagate to; the @@ -452,7 +346,7 @@ def _propagate_cancel_on_done(ctx) -> None: # surrounding context-manager unwind. pass - # grpc.aio's :meth:`ServicerContext.add_done_callback` is + # grpc.aio's `ServicerContext.add_done_callback` is # typed in typeshed as # ``Callable[[_DoneCallback[_TRequest, _TResponse]], None]`` # where ``_DoneCallback`` is a generic callable *class*. @@ -468,25 +362,23 @@ def _propagate_cancel_on_done(ctx) -> None: async with session: if self._backpressure is not None: backpressure = self._backpressure - # ``guarded=False`` — the dispatch task is not - # running the routine itself, only reading - # caller-shipped wool.ContextVar values for the - # hook. The single-task ownership of - # ``session.context`` belongs to the worker - # task scheduled lazily on the first - # ``__aiter__`` call below; entering the - # guard here would race that scheduling - # under ``Context._lock``. + # The hook observes dispatch-time worker + # state (active task count, the incoming + # task). Caller-shipped wool.ContextVar + # values are not exposed: under the + # per-frame architecture the Task frame + # carries no chain manifest, so the + # decoded manifest at admission time is + # always empty. 
try: - with attached(session.context, guarded=False): - decision = backpressure( - BackpressureContext( - active_task_count=len(self._docket), - task=session.task, - ) + decision = backpressure( + BackpressureContext( + active_task_count=len(self._docket), + task=session.task, ) - if isawaitable(decision): - decision = await decision + ) + if isawaitable(decision): + decision = await decision except Exception: # User-supplied backpressure hook crashed. # Log so the operator notices, then abort @@ -508,12 +400,35 @@ def _propagate_cancel_on_done(ctx) -> None: ) async with self._tracked(session, context): - yield protocol.Response( - ack=protocol.Ack(version=protocol.__version__) - ) + # Under the per-frame architecture, boundary + # frames (Ack/Nack/Task) carry no chain manifest — + # the chain manifest lives only on mid-stream payload + # frames (Next/Send/Throw/Result/Exception). The + # dispatch handler is on the main loop and never + # carries the worker's chain itself. + yield AckResponseFrame.for_send( + serializer=session.serializer, + ).to_protobuf() try: async for response in session: - yield response.to_protobuf(serializer=session.serializer) + # Wrap the main-loop encode in a + # try/except for pickle/type errors so a + # mid-stream encode failure is shipped as + # a typed terminal SerializationError + # rather than misclassified as a routine + # raise. Worker-shipped + # ExceptionResponseFrames continue + # to flow unchanged: only the main-loop + # ``response.to_protobuf()`` is wrapped. + try: + wire = response.to_protobuf() + except (pickle.PickleError, TypeError) as ee: + raise SerializationError( + f"Failed to encode result payload: {ee}", + cause=ee, + value_repr=repr(response.payload), + ) from ee + yield wire except (Exception, asyncio.CancelledError) as e: # Cancel the session before drain on the # error path so a routine suspended @@ -533,101 +448,58 @@ def _propagate_cancel_on_done(ctx) -> None: # the cancel failure to the caller. 
try: await session.cancel() + except (KeyboardInterrupt, SystemExit): + # A process-exit signal must win even over + # the routine's primary ``e`` — the + # dispatch is being torn down regardless. + raise except BaseException: pass # All worker-side and main-side failures # land here: routine exceptions raised in - # :func:`_step` propagate through the + # `_drive_step` propagate through the # response queue and out of - # :meth:`DispatchSession.__aiter__` raw; + # `DispatchSession.__aiter__` raw; # pre-stream worker setup failures surface - # via :meth:`_ResponseQueue.get` raising on - # close; mid-stream context-decode / - # update failures escape :func:`_step` the - # same way; handler-level failures (e.g. + # via `_ResponseQueue.get` raising on + # close; mid-stream chain-manifest-decode / + # update failures escape `_drive_step` the + # same way; handler-level failures (e.g., # ``response.to_protobuf`` raising) raise # directly here; gRPC stream cancellation # raises ``CancelledError`` mid-iteration. - # Drain the worker before snapshotting - # ``session.context``: worker-failure - # paths arrive with the worker already - # finalized (so drain is a no-op), but - # cancellation and main-loop handler- - # level failures leave the worker mid- - # ``_step``, racing the snapshot's read - # of ``_data`` against the worker's - # ``work_ctx.update`` / - # ``work_ctx.to_protobuf`` writes. - # :meth:`DispatchSession.drain` is - # idempotent — :meth:`__aexit__` will + # Drain the worker before reading + # ``session._final_wire_chain_manifest``: the + # worker task encodes it inside its own + # Chain and publishes it from its + # ``finally``, so the read must wait for + # the worker task to finish. Worker-failure + # paths arrive already finalized (drain is + # a no-op); cancellation and main-loop + # handler-level failures leave the worker + # mid-``_drive_step``, so the drain is what + # guarantees the publish has happened. 
+ # `DispatchSession.drain` is + # idempotent — `__aexit__` will # call it again on the way out. On the # external-cancellation path drain may # re-raise ``CancelledError`` before - # the snapshot can be built; the gRPC + # the chain manifest can be built; the gRPC # stream is being torn down anyway, so # losing the terminal Response is # acceptable — the caller has no # consumer left. await session.drain() - # Unwrap PEP 525's auto-conversion for - # coroutine routines so the caller's - # ``await routine()`` surfaces the - # original :class:`StopAsyncIteration` - # raw — matching stdlib coroutine - # semantics. The wrap happens in - # :meth:`DispatchSession._iterate` (the - # asyncgen transport layer): when a - # coroutine raises StopAsyncIteration, - # _ResponseQueue.get re-raises it inside - # _iterate's body, and PEP 525 converts - # it to ``RuntimeError("async generator - # raised StopAsyncIteration")`` with the - # original SAI on ``__cause__``. - # Streaming routines keep the - # RuntimeError shape — that already - # matches stdlib ``async for x in - # agen()`` semantics. - if ( - not session.streaming - and isinstance(e, RuntimeError) - and isinstance(e.__cause__, StopAsyncIteration) - ): - e = e.__cause__ - try: - wire_context = session.context.to_protobuf( - serializer=session.serializer - ) - except Exception as encode_exc: - # Strict-mode-only path: attach the - # encoded ``ContextDecodeWarning`` - # peers to ``e`` so the caller's - # ``except RoutineError`` clause keeps - # matching. See - # :func:`_attach_strict_mode_warnings` - # for the attachment contract and - # rationale. Drops the post-run - # ``context`` field on the wire (the - # snapshot itself failed); peers ride - # on the routine exception via PEP 678 - # ``__notes__`` and - # ``__wool_context_warnings__``. 
- _attach_strict_mode_warnings(e, encode_exc) - yield protocol.Response( - exception=protocol.Message( - dump=_safely_serialize_exception( - session.serializer, e - ) - ), - ) - else: - yield protocol.Response( - exception=protocol.Message( - dump=_safely_serialize_exception( - session.serializer, e - ) - ), - context=wire_context, - ) + # Encode-error vs. lazy-wire-frame + # routing (plus the PEP 525 SAI unwrap + # for coroutine routines) lives on + # `DispatchSession.terminal_response`. + # That single call keeps session._final_* + # access encapsulated rather than read + # here. + yield session.terminal_response( + e, serializer=session.serializer + ).to_protobuf() except Rejected as e: # Parse-phase failure (malformed task payload). # Reported via Nack so the client deserializes the @@ -637,15 +509,10 @@ def _propagate_cancel_on_done(ctx) -> None: # ``wool.__serializer__`` (cloudpickle). Same path as # ``Response.exception`` post-Ack — symmetry on the # wire. - yield protocol.Response( - nack=protocol.Nack( - exception=protocol.Message( - dump=_safely_serialize_exception( - session.serializer, e.original - ) - ), - ), - ) + yield NackResponseFrame.for_send( + e.original, + serializer=session.serializer, + ).to_protobuf() return except AbortError: # Intentional ``context.abort(...)`` calls inside the @@ -681,7 +548,7 @@ async def stop( :param request: The protobuf stop request containing the wait timeout. :param context: - The :class:`grpc.aio.ServicerContext` for this request. + The `grpc.aio.ServicerContext` for this request. :returns: An empty protobuf response indicating completion. """ @@ -696,16 +563,16 @@ def _create_worker_loop( ) -> tuple[asyncio.AbstractEventLoop, threading.Thread]: """Create a new event loop running on a dedicated daemon thread. 
- The thread target wraps :meth:`asyncio.AbstractEventLoop.run_forever` - in a ``try/finally`` that calls :meth:`asyncio.AbstractEventLoop.close` + The thread target wraps `asyncio.AbstractEventLoop.run_forever` + in a ``try/finally`` that calls `asyncio.AbstractEventLoop.close` once ``run_forever`` returns. Closing the loop from the worker thread (rather than the caller's thread inside - :meth:`_destroy_worker_loop`) eliminates the race that produced + `_destroy_worker_loop`) eliminates the race that produced ``RuntimeError("Cannot close a running event loop")`` when a caller-side close raced the still-active ``run_forever``. :param key: - The :class:`ResourcePool` cache key (unused). + The `ResourcePool` cache key (unused). :returns: A tuple of the event loop and the thread running it. """ @@ -722,7 +589,7 @@ def _run_then_close(): thread.start() return loop, thread - def _destroy_worker_loop( + async def _destroy_worker_loop( self, loop_thread: tuple[asyncio.AbstractEventLoop, threading.Thread], ) -> None: @@ -731,20 +598,20 @@ def _destroy_worker_loop( Drains successive generations of pending tasks on the worker loop, then signals the loop to stop. A cancelled task's ``finally`` clause can schedule a second generation - of tasks (e.g. follow-up cleanup, fire-and-forget logging, + of tasks (e.g., follow-up cleanup, fire-and-forget logging, further cancellations, etc.); the drain cancels and awaits successive generations until none remain or - :data:`_DRAIN_TIMEOUT` elapses, so the loop closes without + `_DRAIN_TIMEOUT` elapses, so the loop closes without leaking ``Task was destroyed but it is pending!`` warnings. If a routine schedules cleanup-of-cleanup past the budget, the drain stops anyway and the daemon-thread reap remains the backstop. 
The loop is closed by the worker - thread itself (see :meth:`_create_worker_loop`'s + thread itself (see `_create_worker_loop`'s ``_run_then_close`` target), not from this caller's thread — eliminating the close-while-running race. - Joins the worker thread for up to :attr:`_stop_timeout` - seconds (set by :meth:`_stop` from the StopRequest's + Joins the worker thread for up to `_stop_timeout` + seconds (set by `_stop` from the StopRequest's ``timeout``). ``timeout=0`` means "do not wait"; positive values bound the synchronous wait; ``None`` means "wait indefinitely" (caller asked for unlimited graceful shutdown). @@ -801,7 +668,12 @@ async def _shutdown(): timeout = self._stop_timeout if timeout is None or timeout > 0: - thread.join(timeout=timeout) + # Offload the synchronous ``thread.join`` to a worker + # thread so the main loop keeps pumping while we wait. + # ``ResourcePool._await`` already dispatches coroutine + # finalizers, so changing this function to ``async def`` + # is transparent at the call site. + await asyncio.get_running_loop().run_in_executor(None, thread.join, timeout) @asynccontextmanager async def _tracked( @@ -809,19 +681,19 @@ async def _tracked( session: DispatchSession, context: ServicerContext, ) -> AsyncIterator[None]: - """Add *session* to :attr:`_docket` for the duration of the + """Add *session* to `_docket` for the duration of the yield, removing it on exit. The docket is the registry of in-flight - :class:`DispatchSession` instances that :meth:`_stop` + `DispatchSession` instances that `_stop` pre-empts on graceful shutdown. The CM scope mirrors the dispatch handler's iteration scope, so an in-flight dispatch is always either tracked or already finalized. 
- Re-checks :attr:`_stopping` on entry to close the - check-to-register window in :meth:`dispatch` — a concurrent - :meth:`_stop` between the entry gate and docket registration - would otherwise admit a session that :meth:`_preempt` never + Re-checks `_stopping` on entry to close the + check-to-register window in `dispatch` — a concurrent + `_stop` between the entry gate and docket registration + would otherwise admit a session that `_preempt` never sees, leaving it to be torn down indirectly by loop-pool teardown rather than the explicit cancel path. """ @@ -840,7 +712,7 @@ async def _stop(self, *, timeout: float | None = 0) -> None: if timeout is not None and timeout < 0: timeout = None # Stash the StopRequest's timeout for the loop-teardown - # finalizer (read by :meth:`_destroy_worker_loop`) before + # finalizer (read by `_destroy_worker_loop`) before # any ``await`` so it is always set when ``_loop_pool.clear`` # later invokes the finalizer. ``timeout=0`` (the default) # → don't synchronously join; positive → bound the join; @@ -863,9 +735,9 @@ async def _preempt(self, *, timeout: float | None = 0) -> None: The service-wide pre-emption entry point. Waits for running tasks to complete or cancels them depending on the timeout - value. Calls :meth:`DispatchSession.cancel` on each session + value. Calls `DispatchSession.cancel` on each session in the docket when forced cancellation is required, which - propagates :class:`asyncio.CancelledError` to the routine. + propagates `asyncio.CancelledError` to the routine. 
The caller observes ``CancelledError`` from ``await routine()``, matching stdlib's ``task.cancel()`` semantics — operator pre-emption is indistinguishable from diff --git a/wool/src/wool/runtime/worker/session.py b/wool/src/wool/runtime/worker/session.py index a5a0dabb..de603fea 100644 --- a/wool/src/wool/runtime/worker/session.py +++ b/wool/src/wool/runtime/worker/session.py @@ -2,12 +2,13 @@ Layered abstractions cover the worker-side dispatch lifetime: -- :class:`_RequestQueue` / :class:`_ResponseQueue` — cross-loop +- `_RequestQueue` / `_ResponseQueue` — cross-loop queues bridging the gRPC main loop and the worker loop. -- :func:`_step` — inner routine stepper. One step per request, - yields one :class:`_Response`. Routine-shape variation - (coroutine = one step; async-generator = N steps) lives here. -- :class:`DispatchSession` — per-dispatch async context manager +- `_drive_step` — inner routine stepper. One step per + request, runs inside the chain's cached + `contextvars.Context`. Routine-shape variation (coroutine + = one step; async-generator = N steps) lives here. +- `DispatchSession` — per-dispatch async context manager and iterator that owns parse, lazy worker scheduling, drive, drain, and cancel. 
@@ -16,11 +17,14 @@ from __future__ import annotations +__all__ = ["DispatchSession", "Rejected"] + import asyncio import concurrent.futures +import contextvars +import copy import logging from contextlib import AsyncExitStack -from dataclasses import dataclass from inspect import isasyncgenfunction from inspect import iscoroutinefunction from typing import Any @@ -28,347 +32,42 @@ from typing import AsyncIterator from typing import Coroutine from typing import Final -from typing import Literal +from typing import TypeVar from typing import assert_never from typing import cast +from uuid import UUID +from weakref import WeakValueDictionary import wool from wool import protocol -from wool.runtime.context import Context -from wool.runtime.context import attached +from wool.runtime.context.exceptions import ChainSerializationError +from wool.runtime.context.manifest import ChainManifest from wool.runtime.routine.task import Task from wool.runtime.routine.task import routine_scope from wool.runtime.serializer import Serializer from wool.runtime.worker.connection import _complete_teardown - -__all__ = ["DispatchSession", "Rejected"] +from wool.runtime.worker.frame import ExceptionResponseFrame +from wool.runtime.worker.frame import Frame +from wool.runtime.worker.frame import NextRequestFrame +from wool.runtime.worker.frame import RequestFrame +from wool.runtime.worker.frame import ResponseFrame +from wool.runtime.worker.frame import ResultResponseFrame +from wool.runtime.worker.frame import SendRequestFrame +from wool.runtime.worker.frame import TaskRequestFrame +from wool.runtime.worker.frame import ThrowRequestFrame + +_T = TypeVar("_T") _log = logging.getLogger(__name__) -class _EndOfStream: - """Marker type for the end-of-stream sentinel pushed onto - :class:`_RequestQueue` and :class:`_ResponseQueue` to wake a - suspended ``get`` after :meth:`close`. 
Identity is unique by - construction (one instance, :data:`_EOS`); the dedicated type - parameterizes both queues precisely without falling back to - ``object`` or a string ``Literal``. - """ - - -_EOS: Final[_EndOfStream] = _EndOfStream() -"""Singleton sentinel marking end of a queue-based dispatch stream.""" - - -@dataclass -class _Response: - """One frame on the response side of the dispatch protocol. - - Carries a successful step result and the post-step - :class:`protocol.Context` snapshot so the handler can ship - caller-visible mutations on the Response. Failures never reach - this type — they propagate raw out of :func:`_step` and ship - through the dispatch handler's terminal-exception clause - instead. - """ - - result: Any - context: protocol.Context - - def to_protobuf(self, *, serializer: Serializer) -> protocol.Response: - """Build a :class:`protocol.Response` from this frame. - - Serializes ``result`` via *serializer* and attaches the - post-step context (ID + var snapshot) on the response. - - :param serializer: - Serializer for the payload (:data:`wool.__serializer__`). - """ - return protocol.Response( - result=protocol.Message(dump=serializer.dumps(self.result)), - context=self.context, - ) - - -@dataclass -class _Request: - """One request on the dispatch protocol. - - Wire-decoded on the caller (main-loop) side via - :meth:`from_protobuf`, pushed cross-loop to the worker, and - consumed by :func:`_step` on the worker loop. - - :param action: - The async-generator step verb: ``"next"`` advances without - a value (``asend(None)``; also synthesized for coroutine - routines that take a single step), ``"send"`` advances - with a payload (``asend(payload)``), ``"throw"`` injects - an exception (``athrow(payload)``). - :param payload: - The decoded payload for ``send``/``throw``; ``None`` for - ``next``. 
- :param caller_wire_context: - The caller's :class:`protocol.Context` (to be decoded and - merged into the worker's ``work_ctx`` before the step runs). - """ - - action: Literal["next", "send", "throw"] - payload: Any - caller_wire_context: protocol.Context - - @classmethod - def from_protobuf( - cls, - request: protocol.Request, - *, - work_ctx: Context, - serializer: Serializer, - ) -> _Request: - """Decode a :class:`protocol.Request` into a request object. - - Reads the ``payload`` oneof and decodes ``send``/``throw`` - bodies via *serializer* under - ``attached(work_ctx, guarded=False)`` — so any pickled - :class:`wool.ContextVar` / :class:`wool.Token` in the payload - reconstitutes against ``work_ctx`` rather than lazily - registering a Context on the dispatch handler's transient - task. The ``request.context`` field is forwarded as - ``caller_wire_context`` for the worker-loop side to decode and - merge into ``work_ctx`` before the routine step runs. - - :param request: - The incoming :class:`protocol.Request`. - :param work_ctx: - The dispatch handler's :class:`Context`, used as the - attach scope for payload decode. - :param serializer: - Serializer for the payload — always :data:`wool.__serializer__`. - :raises ValueError: - If the ``payload`` oneof is unset or unknown — the wire - envelope parsed cleanly but carries no recognizable - iteration command. 
- """ - match request.WhichOneof("payload"): - case "next": - return cls("next", None, request.context) - case "send": - with attached(work_ctx, guarded=False): - value = serializer.loads(request.send.dump) - return cls("send", value, request.context) - case "throw": - with attached(work_ctx, guarded=False): - exc = serializer.loads(request.throw.dump) - return cls("throw", exc, request.context) - case _: # pragma: no cover — defensive default for proto oneof - raise ValueError( - f"unknown request payload oneof: {request.WhichOneof('payload')!r}" - ) - - -class _RequestQueue: - """Cross-loop queue carrying gRPC request envelopes from the - main (gRPC) loop to the worker loop's :func:`_step` driver. - - Producers on the main loop push :class:`protocol.Request` - envelopes via :meth:`put`. The consumer on the worker loop pulls - them via :meth:`get`, which decodes each envelope into a - :class:`_Request` via :meth:`_Request.from_protobuf` before - returning. Decoding on the worker side keeps payload - deserialization (which may reconstitute pickled - :class:`wool.ContextVar` / :class:`wool.Token` instances under - ``work_ctx``) inside the same task that owns ``work_ctx`` for - the routine's lifetime. - - Closure: :meth:`close` pushes a sentinel so :meth:`get` returns - :data:`None` once the producer side is done. - """ - - def __init__( - self, - work_ctx: Context, - worker_loop: asyncio.AbstractEventLoop, - *, - serializer: Serializer, - ) -> None: - self._queue: asyncio.Queue[protocol.Request | _EndOfStream] = asyncio.Queue() - self._work_ctx = work_ctx - self._worker_loop = worker_loop - self._serializer = serializer - - def put(self, request: protocol.Request) -> None: - """Push a :class:`protocol.Request` onto the queue. - - Cross-loop safe — schedules the put on the worker loop via - :func:`asyncio.AbstractEventLoop.call_soon_threadsafe`. 
- """ - self._worker_loop.call_soon_threadsafe(self._queue.put_nowait, request) - - async def get(self) -> _Request | None: - """Pop the next decoded :class:`_Request`, or :data:`None` - when the queue has been :meth:`close`\\ d. - - Awaitable on the worker loop only. - """ - item = await self._queue.get() - if isinstance(item, _EndOfStream): - return None - return _Request.from_protobuf( - item, work_ctx=self._work_ctx, serializer=self._serializer - ) - - def close(self) -> None: - """Signal end of input by pushing the close sentinel. - Cross-loop safe.""" - self._worker_loop.call_soon_threadsafe(self._queue.put_nowait, _EOS) - - -class _ResponseQueue: - """Cross-loop queue carrying :class:`_Response` frames from the - worker loop's :func:`_step` driver back to the main (gRPC) - loop's :meth:`DispatchSession.__aiter__`. - - Producers on the worker loop push frames via :meth:`put` and - signal end-of-stream via :meth:`close`. The consumer on the - main loop pulls them via :meth:`get`, which returns :data:`None` - after a clean termination (the routine exhausted or returned) - and **raises** the worker task's underlying exception when the - worker died — the queue holds a reference to the - worker-completion :class:`concurrent.futures.Future` so the - sentinel-and-failure check co-locates with the close sentinel - that triggers it. The exception propagates out of - :meth:`DispatchSession.__aiter__` for the dispatch handler's - terminal-exception clause to ship. - """ - - def __init__( - self, - main_loop: asyncio.AbstractEventLoop, - worker_done: concurrent.futures.Future, - ) -> None: - # Unbounded by necessity: both response-frame pushes (the - # data path) and ``_EOS`` pushes (close + ``_on_done``) - # share this queue via ``put_nowait``, so a hard cap would - # need to leave headroom for one or two sentinel slots. 
The - # actual invariant — bounded by producer/consumer - # alternation in :func:`_run` and - # :meth:`DispatchSession._iterate` to ≤1 response in flight - # — is enforced structurally there: the worker pushes one - # response, then awaits the next request before pushing - # again. A future change that decouples that cadence - # (prefetch, batching) needs to add explicit backpressure - # here rather than relying on this queue to provide it. - self._queue: asyncio.Queue[_Response | _EndOfStream] = asyncio.Queue() - self._main_loop = main_loop - self._worker_done = worker_done - - def put(self, response: _Response) -> None: - """Push a :class:`_Response` onto the queue. - - Cross-loop safe — schedules the put on the main loop via - :func:`asyncio.AbstractEventLoop.call_soon_threadsafe`. - """ - self._main_loop.call_soon_threadsafe(self._queue.put_nowait, response) - - async def get(self) -> _Response | None: - """Pop the next response, or :data:`None` after a clean - :meth:`close`. - - **Raises** the worker task's exception when the close - sentinel arrives and ``worker_done`` carries one — - surfacing worker failures (pre-stream, routine-time, or - cancellation) up to :meth:`DispatchSession.__aiter__` so they - propagate to the dispatch handler's terminal-exception - clause. - - Awaitable on the main loop only. - """ - result = await self._queue.get() - if isinstance(result, _EndOfStream): - # The worker-completion future is the synchronization - # primitive: when the worker dies with an exception, - # ``worker_done`` is set before the close sentinel is - # observable here, so reading the exception (if any) - # surfaces worker failures alongside the EOS sentinel. - # A clean routine end may close before the worker task - # finishes, in which case ``worker_done`` is still - # pending — return ``None`` either way. 
- if self._worker_done.done(): - exc = self._worker_done.exception() - if exc is not None: - raise exc - return None - return result - - def close(self) -> None: - """Signal end of responses by pushing the close sentinel. - Cross-loop safe.""" - self._main_loop.call_soon_threadsafe(self._queue.put_nowait, _EOS) - - -async def _step( - routine: Coroutine | AsyncGenerator, - streaming: bool, - request: _Request, - work_ctx: Context, - *, - serializer: Serializer, -) -> _Response: - """Drive *routine* through one *request* and return the - corresponding :class:`_Response`. - - Decodes the caller's wire context, merges state into - ``work_ctx``, then steps the routine (``await routine`` for - coroutines; ``asend|athrow`` for async-generators). Returns a - result-bearing :class:`_Response` carrying the post-step - snapshot of ``work_ctx``. - - Most exceptions propagate raw — :class:`StopAsyncIteration`, - routine-raised exceptions, snapshot encode failures — because - the dispatch handler ships the next failure it catches as the - routine's terminal frame on the wire. A - :class:`BaseExceptionGroup` from the per-step caller-context - decode is rewrapped with a "mid-stream" label so it's - distinguishable from the initial-frame variant in tracebacks; - the rewrap preserves the umbrella class so the constructor's - auto-downgrade still routes Exception-only peers along the - routine-failure path and leaves non-Exception peers (e.g. - ``KeyboardInterrupt``) to tear the dispatch down rather than - ship as a typed response. 
- """ - try: - incoming = Context.from_protobuf( - request.caller_wire_context, serializer=serializer - ) - except BaseExceptionGroup as eg: - raise BaseExceptionGroup( - "mid-stream request context decode failed", - list(eg.exceptions), - ) from eg - if incoming.has_state(): - work_ctx.update(incoming) - if streaming: - gen = cast(AsyncGenerator, routine) - match request.action: - case "next": - value = await gen.asend(None) - case "send": - value = await gen.asend(request.payload) - case "throw": - value = await gen.athrow(request.payload) - case _: # pragma: no cover - assert_never(request.action) - else: - value = await cast(Coroutine, routine) - return _Response(result=value, context=work_ctx.to_protobuf(serializer=serializer)) - - class Rejected(Exception): - """Raised by :meth:`DispatchSession.__aenter__` when the dispatch - parse phase fails — :class:`wool.Context` decode, - :class:`wool.Task` rebuild, or routine-type validation. + """Raised by `DispatchSession.__aenter__` when the dispatch + parse phase fails — Wool chain-manifest decode, + `wool.Task` rebuild, or routine-type validation. The dispatch handler catches this and replies with a Nack - whose ``exception`` field carries :attr:`original` serialized + whose ``exception`` field carries `original` serialized via the session's ``serializer`` attribute (always ``wool.__serializer__``, cloudpickle). Same path as a routine-time failure's ``Response.exception``. @@ -389,40 +88,43 @@ class DispatchSession: worker-side lifetime end-to-end: - **Parse phase** (``__aenter__``) reads the first - :class:`protocol.Request` off *request_iterator* and parses - it: decodes the caller's wool.Context snapshot, rebuilds the - wool.Task under ``attached(context, guarded=False)``, and - validates the routine type. Failures are wrapped in - :class:`Rejected` so the dispatch handler can surface them - via Nack-with-exception. 
A first-request read failure (empty - iterator, gRPC error) propagates raw — no parsed payload - exists to serialize for the caller. + `protocol.Request` off *request_iterator* and parses + it: decodes the caller's Wool chain manifest, rebuilds + the wool.Task, and validates the routine type. Failures are + wrapped in `Rejected` so the dispatch handler can + surface them via Nack-with-exception. A first-request read + failure (empty iterator, gRPC error) propagates raw — no + parsed payload exists to serialize for the caller. The worker-loop driver is **not** scheduled here; that - happens lazily on the first ``__aiter__`` call so the - dispatch handler can run pre-iteration decisions - (backpressure) against the parsed task and context without - contending with the worker for ``Context._guard()``. + happens lazily on the first ``__aiter__`` call. The dispatch + handler runs pre-iteration decisions (backpressure) against + the parsed task and the decoded + `ChainManifest.vars + ` + mapping off `decoded` before the worker task is + scheduled. - **Iteration** (``__aiter__``) schedules the worker driver on first call and drives the request/response loop on the main - loop. Sets up cross-loop :class:`_RequestQueue` / - :class:`_ResponseQueue` and submits a worker-loop task that - enters :func:`routine_scope` for the parsed task and drives the - routine through :func:`_step`. The - :class:`concurrent.futures.Future` held by the response + loop. Sets up cross-loop `_RequestQueue` / + `_ResponseQueue` and submits a worker-loop task that + enters `routine_scope` for the parsed task and drives the + routine through `_drive_step`. The + `concurrent.futures.Future` held by the response queue surfaces pre-stream worker failures so they propagate - out of :meth:`_ResponseQueue.get` rather than hang + out of `_ResponseQueue.get` rather than hang iteration. 
Forwards each subsequent - :class:`protocol.Request` from *request_iterator* through - the request queue and yields one :class:`_Response` per + `protocol.Request` from *request_iterator* through + the request queue and yields one `ResponseFrame` per response. The coroutine path synthesizes a single ``"next"`` request. Pre-stream worker failures raise out of - :meth:`_ResponseQueue.get` and propagate raw — the dispatch + `_ResponseQueue.get` and propagate raw — the dispatch handler's terminal-exception clause builds the wire response - with a snapshot of ``self.context`` and the dumped exception. + from `_final_wire_chain_manifest` (encoded by the worker task + inside its own Chain) and the dumped exception. - - **Teardown** (``__aexit__``) calls :meth:`drain` (close + - **Teardown** (``__aexit__``) calls `drain` (close request queue + await ``worker_done``) before unwinding the exit stack. Worker exceptions surfaced via ``worker_done`` are silently swallowed — pre-stream and @@ -437,16 +139,16 @@ class DispatchSession: - **Cancellation** (``cancel``) sets a flag, cancels the worker driver task on the worker loop, and pushes ``_EOS`` onto the response queue. The worker-task cancellation is - what propagates :class:`asyncio.CancelledError` into a - routine mid-``_step``; without it, a compute-bound or + what propagates `asyncio.CancelledError` into a + routine mid-``_drive_step``; without it, a compute-bound or sleeping routine would run to natural completion after the caller has gone away. Used by the service's docket-cancel path on shutdown and by the dispatch handler as the on-exit cleanup hook. Idempotent. - Public attributes ``.task``, ``.context``, ``.serializer`` are + Public attributes ``.task``, ``.decoded``, ``.serializer`` are populated on enter for use by the dispatch handler (e.g., - backpressure). + backpressure, which dry-run-mounts ``.decoded``). :param request_iterator: The bidirectional request stream. 
The first frame is read @@ -459,7 +161,7 @@ class DispatchSession: """ task: Task - context: Context + decoded: ChainManifest serializer: Serializer def __init__( @@ -476,10 +178,24 @@ def __init__( self._response_queue: _ResponseQueue | None = None self._worker_done: concurrent.futures.Future | None = None self._worker_task: asyncio.Task | None = None - self._iterator: AsyncGenerator[_Response, None] | None = None + self._iterator: AsyncGenerator[ResponseFrame, None] | None = None + # The initial `RequestFrame` decoded in __aenter__ — its + # chain manifest (when non-None) carries the caller's chain + # manifest state plus every ``wool.Token`` captured from the + # task args/kwargs payload. The worker driver's initialization + # mounts it inside its own ``contextvars.Context`` before the + # routine is constructed. + self._initial_frame: RequestFrame | None = None + # The worker driver encodes its final chain manifest inside the + # work Chain (Chain.to_manifest reads the backing variables) + # and publishes it here for the dispatch handler's terminal- + # exception path. ``_final_encode_error`` carries a strict-mode + # encode failure instead. + self._final_wire_chain_manifest: protocol.ChainManifest | None = None + self._final_encode_error: BaseException | None = None # All dispatch serializes through cloudpickle; this one - # serializer covers the payload, the context snapshots, and - # any :class:`Rejected` dumped exception (including pre-parse + # serializer covers the payload, the chain-manifest frames, and + # any `Rejected` dumped exception (including pre-parse # failures such as StopAsyncIteration or a malformed frame). self.serializer: Serializer = wool.__serializer__ @@ -487,32 +203,33 @@ def __init__( def streaming(self) -> bool: """Whether the parsed task is an async-generator routine. - Set by :meth:`__aenter__` after the first request is parsed + Set by `__aenter__` after the first request is parsed and the callable is validated. 
Read-only — exposed so the dispatch handler can decide whether to unwrap PEP 525's synthesized ``RuntimeError("async generator raised StopAsyncIteration")`` back to its original - :class:`StopAsyncIteration` for coroutine routines, + `StopAsyncIteration` for coroutine routines, without reaching across the privacy boundary. """ return self._streaming async def _safe_aclose_stack(self) -> None: - """Defensively close :attr:`_stack` and swallow routine + """Defensively close `_stack` and swallow routine teardown failures. - Used by :meth:`__aenter__`'s error arms so an aclose - failure (e.g. resource teardown raising) cannot replace + Used by `__aenter__`'s error arms so an aclose + failure (e.g., resource teardown raising) cannot replace the original parse error en route to the dispatch - handler's Nack channel. ``KeyboardInterrupt`` and - ``SystemExit`` propagate — process-level signals must - not be silently dropped during cleanup. + handler's Nack channel. Only `Exception` subclasses are + swallowed; `BaseException` signals + (``KeyboardInterrupt``, ``SystemExit``, + ``CancelledError``) propagate — process-level and + cancellation signals must not be silently dropped during + cleanup. 
""" try: await _complete_teardown(self._stack.aclose()) - except (KeyboardInterrupt, SystemExit): - raise - except Exception: + except Exception: # pragma: no cover — no failing callback on the stack yet pass async def __aenter__(self) -> DispatchSession: @@ -545,34 +262,26 @@ async def __aenter__(self) -> DispatchSession: "first request must carry a Task in its `payload` " f"oneof; observed {request.WhichOneof('payload')!r}" ) - try: - self.context = Context.from_protobuf( - request.context, serializer=self.serializer - ) - except BaseExceptionGroup as eg: - # Rewrap with the umbrella class so the - # constructor's auto-downgrade decides the - # propagation path: all-Exception peers (today's - # only case via :class:`ContextDecodeWarning`) - # produce an :class:`ExceptionGroup`, which the - # outer ``except Exception`` arm wraps as - # :class:`Rejected` and ships via Nack-with- - # exception. A non-Exception peer (e.g. a - # ``CancelledError`` or ``KeyboardInterrupt``) - # would keep the result as a true - # :class:`BaseExceptionGroup`, falling through to - # the ``except BaseException`` arm below where it - # propagates raw — Nack is the wrong channel for - # cancellation/interrupt signals; they should - # tear the dispatch task down rather than be - # encoded as a typed parse rejection. - raise BaseExceptionGroup( - "request context decode failed", - list(eg.exceptions), - ) from eg - - with attached(self.context, guarded=False): - self.task = Task.from_protobuf(request.task) + # Decode the initial frame. A strict-mode decode failure is + # captured as the frame's ``chain_manifest`` value rather + # than raised mid-decode — surface it up-front so the outer + # ``except Exception`` arm wraps it in ``Rejected`` for Nack + # transport (preserving the caller's ``except + # ChainSerializationError`` semantics). 
+ decoded = Frame.from_protobuf(request, serializer=self.serializer) + assert isinstance(decoded, TaskRequestFrame) + self._initial_frame = decoded + manifest = self._initial_frame.chain_manifest + if isinstance(manifest, ChainSerializationError): + raise manifest + assert isinstance(self._initial_frame.payload, Task) + self.task = self._initial_frame.payload + # Backpressure inspection on the dispatch handler reads + # ``session.decoded`` — surface the manifest there. An + # empty manifest (no chain-manifest state) is absent on the + # frame; fall back to a stateless one so callers see a + # consistent attribute shape. + self.decoded = manifest or ChainManifest.empty() if not ( iscoroutinefunction(self.task.callable) @@ -590,7 +299,7 @@ async def __aenter__(self) -> DispatchSession: await self._safe_aclose_stack() raise # Register drain on the exit stack so it runs as part of the - # LIFO unwind. If drain raises (e.g. a ``CancelledError`` + # LIFO unwind. If drain raises (e.g., a ``CancelledError`` # reaching it during graceful shutdown), the stack still # unwinds the remaining callbacks. self._stack.push_async_callback(self.drain) @@ -599,51 +308,218 @@ async def __aenter__(self) -> DispatchSession: def _schedule_worker(self) -> None: """Set up the cross-loop request/response queues and schedule the worker driver. Called lazily from - ``__aiter__`` on the first iteration so that backpressure - and other pre-iteration decisions run before any worker - task acquires the routine's :class:`Context` guard — - otherwise a main-loop ``attached(self.context)`` would race - the worker's ``_context_scope`` ``_guard()`` and spuriously - raise on every dispatch with a backpressure hook. - - Short-circuits when :meth:`cancel` was called before the - first :meth:`__aiter__`: the queues stay ``None`` and - :meth:`_iterate` returns an empty stream. 
+ ``__aiter__`` on the first iteration — after the dispatch + handler has already run its pre-iteration decisions + (backpressure) against the decoded + `ChainManifest.vars + ` + mapping off `decoded`. The worker driver's + initialization frame mounts `decoded` for real inside + its own `contextvars.Context`. + + Short-circuits when `cancel` was called before the + first `__aiter__`: the queues stay ``None`` and + `_iterate` returns an empty stream. + + **Closure capture chain.** The worker driver + (``_start`` → ``_run`` → ``_on_done``) is structured as + three nested closures rather than a dataclass-shaped + ``_WorkerDriver`` because the per-step state + (``request_queue``, ``response_queue``, ``worker_done``, + ``serializer``, ``streaming``, ``worker_loop``, ``work_task``) + is captured fresh per dispatch from the session attributes, + and the closure form keeps the capture site adjacent to its + consumers. Captures, by layer: + + * Top-level locals in `_schedule_worker` (this method): + ``main_loop``, ``worker_done``, ``request_queue``, + ``response_queue``, ``work_task``, ``serializer``, + ``streaming``, ``worker_loop``. Re-bound from + ``self.task`` / ``self.serializer`` / ``self._streaming`` / + ``self._worker_loop`` so ``_run``'s loop body can read + them without going through ``self`` (cheaper in the hot + path). + * ``_start`` closes over the above; constructs the worker + coroutine via ``_run()`` and schedules it on the worker + loop. ``_on_done`` (nested inside ``_start``) closes over + ``worker_done`` and ``response_queue`` so the task's + completion callback can settle the future and wake any + pending main-loop consumer. 
+ * ``_run`` (the driver coroutine) additionally allocates a + local ``chain_registry: WeakValueDictionary[UUID, + contextvars.Context]`` per dispatch — the wool chain id to + cached-context map every frame resolves through — and + ``loop`` (rebound from ``worker_loop`` so the closure body + avoids a ``asyncio.get_running_loop()`` per iteration). + Mutates ``self._final_wire_chain_manifest`` / + ``self._final_encode_error`` on the routine-failure path + (the only writes back to ``self``). + + A flatter dataclass-shaped ``_WorkerDriver`` is a possible + follow-up; the current closure form is the minimum that + keeps the captures readable per-layer and threads the + chain registry's lifetime to the dispatch's lifetime + naturally (it goes out of scope with ``_run``). """ if self._cancelled: return main_loop = asyncio.get_running_loop() worker_done: concurrent.futures.Future = concurrent.futures.Future() - request_queue = _RequestQueue( - self.context, self._worker_loop, serializer=self.serializer - ) + request_queue = _RequestQueue(self._worker_loop, serializer=self.serializer) response_queue = _ResponseQueue(main_loop, worker_done) self._request_queue = request_queue self._response_queue = response_queue work_task = self.task - work_ctx = self.context serializer = self.serializer streaming = self._streaming + worker_loop = self._worker_loop def _start(): async def _run(): + # The worker loop is the running loop here — capture + # it explicitly so `_create_step_task` can + # construct step tasks bound to it without a fresh + # `asyncio.get_running_loop` per iteration. + loop = worker_loop + # Per-dispatch chain registry: each wool chain id seen + # on a request frame maps to the cached + # ``contextvars.Context`` that scopes that chain's + # bindings on this worker. Frames sharing a chain id + # reuse one context, so a routine's cross-yield var + # state — a ``set`` before a yield and its ``reset`` + # after — lands in the same context (stdlib async-gen + # parity). 
Frames carrying *distinct* chain ids within + # one stream get distinct contexts, so they do not + # pollute each other (stdlib drives each step in the + # caller's own context). Held weakly: within a dispatch + # the live ``ctx`` local keeps the entry alive; the + # registry is dropped wholesale when ``_run`` returns. + chain_registry: WeakValueDictionary[UUID, contextvars.Context] = ( + WeakValueDictionary() + ) + # Working context for unarmed frames. An unarmed caller + # propagates no chain id, so there is nothing to key on + # yet; successive unarmed frames share this one context + # (so a routine's contextvar mutations carry across + # yields) and it is left *unarmed* — a stateless + # dispatch never installs a wool Chain, so a plain + # ``asyncio.to_thread`` offload inside the routine + # copies a bare context and never trips ChainContention. + # If the routine's own ``set`` arms it, the post-step + # hook below indexes it in ``chain_registry`` under the + # minted chain id, so the back-propagated next frame + # (now carrying that id) reuses it. + dispatch_ctx: contextvars.Context | None = None + try: async with routine_scope(work_task) as routine: - while (request := await request_queue.get()) is not None: + while (raw_request := await request_queue.get()) is not None: + # Decode per-frame inside the loop so a + # single malformed wire envelope ships a + # typed terminal on the response stream + # rather than propagating raw out of the + # worker driver and surfacing as an opaque + # task death. The dispatch is no longer + # serviceable after a decode failure + # (the wire framing is broken), so we + # ship the decode error as an + # ExceptionResponseFrame and break. 
+ try: + decoded = Frame.from_protobuf( + raw_request, serializer=serializer + ) + except Exception as decode_err: + response_queue.put( + ExceptionResponseFrame.for_send( + decode_err, serializer=serializer + ) + ) + break + # Narrow to the mid-stream request subset + # (initial TaskRequestFrame is consumed by + # the parse phase in ``__aenter__``; only + # Next/Send/Throw flow through this loop) + # so `_drive_step` gets the precise + # union without an extra cast. + assert isinstance( + decoded, + (NextRequestFrame, SendRequestFrame, ThrowRequestFrame), + ) + request = decoded + manifest = request.chain_manifest + if manifest is None or isinstance( + manifest, ChainSerializationError + ): + # Unarmed frame, or a deferred decode + # failure with no chain id to route to: + # drive it in the dispatch's single + # working context, created lazily and + # left *unarmed*. Successive unarmed + # frames share it so the routine's + # contextvar mutations carry across + # yields; a stateless routine never + # installs a wool Chain, so the worker + # context stays unarmed. For a decode + # failure, ``request.mount`` below raises + # it (or chains it onto a throw payload) + # — nothing is actually mounted. + if dispatch_ctx is None: + dispatch_ctx = contextvars.copy_context() + ctx = dispatch_ctx + else: + # Armed frame: key on the propagated + # chain id. A miss allocates a fresh + # context, so frames from distinct chain + # ids within one stream stay isolated. A + # chain the routine itself armed from the + # working context was indexed here after + # that step (see the post-step hook + # below), so the back-propagated frame + # carrying that id reuses the same + # context. The ``copy_context`` snapshot + # inherits the driver task's bindings + # (``routine_scope`` has entered + # ``task.runtime_context`` and set + # ``wool.__proxy__``). 
+ cached = chain_registry.get(manifest.id) + if cached is None: + cached = contextvars.copy_context() + chain_registry[manifest.id] = cached + ctx = cached + + # Single mount entry point. Frame.mount + # routes through the unified + # `Chain.from_manifest` pipeline inside + # ``ctx.run(...)`` and handles the + # exception-frame decode-error-chaining + # (ThrowRequestFrame's + # ``_chain_exceptions`` walk) so + # the worker driver no longer needs the + # hand-rolled apply/walk block. + request.mount(ctx) + + # Drive the step via the top-level + # `_drive_step` helper. The helper + # builds the per-step coroutine (asend / + # athrow / coroutine-itself) and runs it + # inside the cached context via a Task + # constructed directly (bypassing the + # loop's task factory; see + # `_create_step_task`). try: - response = await _step( + value = await _drive_step( routine, streaming, request, - work_ctx, - serializer=serializer, + ctx, + loop=loop, ) except StopAsyncIteration: # Streaming SAI = clean end-of-stream; # break the driver loop. Coroutine SAI # propagates so the asyncgen-transport - # path (:meth:`_iterate`) ships it; + # path (`_iterate`) ships it; # ``WorkerService.dispatch`` unwraps # PEP 525's synthesized RuntimeError # back to the original SAI for the @@ -652,7 +528,88 @@ async def _run(): if not streaming: raise break - response_queue.put(response) + except BaseException: + # Routine raised mid-step. Pin the + # chain's manifest for the dispatch + # handler's terminal-exception clause + # *inside* the step's ``ctx`` so the + # encoded snapshot reflects the chain + # the routine actually ran on (not the + # most-recent ``last_ctx`` — which under + # cross-chain interleaving might point + # at a sibling chain). Encode errors + # land on ``_final_encode_error`` so + # the handler chains them onto the + # routine's failure via ``__cause__``. 
+ try: + self._final_wire_chain_manifest = ctx.run( + lambda: ( + wool.__chain__.get() + .to_manifest() + .to_protobuf(serializer=serializer) + ) + ) + except LookupError: + self._final_wire_chain_manifest = None + except BaseException as encode_err: + self._final_encode_error = encode_err + self._final_wire_chain_manifest = None + raise + + # Encode the post-step chain manifest from + # within the cached contextvars.Context so the response + # reflects the routine's mutations to the + # chain's bindings. Pin the captured state + # to ``self._final_wire_chain_manifest`` as the + # step commits: the dispatch handler's + # terminal-exception path reads this field + # without re-encoding, and pinning here + # associates it with the step's chain *by + # construction* (inside the same + # ``ctx.run`` that produced it), eliminating + # the cross-chain-interleaving provenance + # ambiguity a finally-block encode of + # ``last_ctx`` would have. A strict-mode + # encode failure pins the error and + # re-raises so the dispatch handler's + # terminal-exception path ships the + # encode error on the wire. + try: + captured = ctx.run( + lambda: ( + wool.__chain__.get() + .to_manifest() + .to_protobuf(serializer=serializer) + ) + ) + except LookupError: + captured = None + except BaseException as encode_err: + self._final_encode_error = encode_err + self._final_wire_chain_manifest = None + raise + self._final_wire_chain_manifest = captured + # If the routine armed the dispatch's unarmed + # working context — its own first ``set`` + # mints a chain on this worker — index that + # context under the new chain id. The next + # frame carries that id once the ``set`` + # back-propagates and arms the caller, so it + # resolves to this same context: that is what + # lets a ``set`` before a yield and its + # ``reset`` after share one context. A + # non-None ``captured`` means the context is + # armed. 
+ if ctx is dispatch_ctx and captured is not None: + armed = ctx.run(wool.__chain__.get) + chain_registry.setdefault(armed.id, ctx) + response_queue.put( + ResultResponseFrame.for_send( + value, + serializer=serializer, + wire_chain_manifest=captured, + ) + ) if not streaming: break response_queue.close() @@ -660,29 +617,31 @@ async def _run(): # Stop the producer side immediately on any # ``_run`` exit, including mid-frame routine # exceptions that unwind past the ``async with - # routine_scope`` block. Pre-fix, the producer's - # ``_iterate`` kept queueing frames until its - # own ``finally`` (or external teardown via - # :meth:`drain`) closed the queue, leaving a - # window where requests accumulated against a - # worker that had already failed. + # routine_scope`` block. Otherwise the producer's + # ``_iterate`` keeps queueing frames until its own + # ``finally`` (or external teardown via `drain`) + # closes the queue, leaving a window where requests + # accumulate against a worker that has already + # failed. request_queue.close() + coro = _run() try: - task = self._worker_loop.create_task( - _run(), - context=work_ctx, # pyright: ignore[reportArgumentType] - ) + task = self._worker_loop.create_task(coro) except BaseException as e: # Late-loop-closure or task-factory failure: # ``call_soon_threadsafe`` succeeded earlier (loop # was open at scheduling time) but ``create_task`` # raises here because the loop has since closed # (or the factory itself rejected the coroutine). - # Settle ``worker_done`` so :meth:`drain` does not - # await an unresolved future, and close the + # ``create_task`` never took ownership of ``coro``, + # so close it explicitly — an orphaned coroutine + # leaks a "coroutine was never awaited" RuntimeWarning + # at GC. Settle ``worker_done`` so `drain` does + # not await an unresolved future, and close the # response queue so any pending - # :meth:`_ResponseQueue.get` returns immediately. + # `_ResponseQueue.get` returns immediately. 
+ coro.close() worker_done.set_exception(e) response_queue.close() return @@ -732,13 +691,111 @@ def _on_done(t: asyncio.Task): self._worker_loop.call_soon_threadsafe(_start) self._worker_done = worker_done + def terminal_response( + self, + exception: BaseException, + *, + serializer: Serializer, + ) -> ResponseFrame: + """Build the terminal `ExceptionResponseFrame` for a + routine-failure dispatch. + + Owns the encode-error vs. lazy-wire-frame decision and + the PEP 525 ``StopAsyncIteration`` unwrap, keeping + `_final_wire_chain_manifest` / `_final_encode_error` access + encapsulated here. + + Callers MUST have awaited `drain` before invoking this + so the worker task's in-context encode publish has settled. + The dispatch handler then yields + ``session.terminal_response(e, serializer=s).to_protobuf()`` + as the wire-side terminal frame. + + Two distinct shapes ride out: + + * **Lazy-wire-frame** (the worker stayed armed or stayed + unarmed cleanly): ship the routine exception alongside the + worker's final chain manifest. + * **Strict-mode encode failure**: the worker's in-context + `ChainManifest.to_protobuf` raised — typically a + `wool.ChainSerializationError` aggregating per-var + warnings. Chain the encode error as ``__cause__`` on a + *copy* of the routine exception (the live instance may be + a module-level singleton — for example, an interpreter- + cached `StopAsyncIteration` — or already propagating + elsewhere, so mutating its ``__cause__`` / + ``__suppress_context__`` would alter globally observable + state) and ship the result with no chain-manifest payload. + + Coroutine routines that raised `StopAsyncIteration` + get the PEP 525 ``RuntimeError("async generator raised + StopAsyncIteration")`` wrapping unwrapped so the caller's + ``await routine()`` surfaces the original SAI raw — matching + stdlib coroutine semantics. Streaming routines keep the + ``RuntimeError`` shape; that already matches stdlib + ``async for x in agen()`` semantics. 
+ + :param exception: + The routine-time exception captured by the dispatch + handler. + :param serializer: + The serializer to use for the response frame — typically + ``session.serializer``. + + :returns: + An `ExceptionResponseFrame` ready to encode via + `Frame.to_protobuf`. + """ + e = exception + # PEP 525 SAI unwrap (coroutine-only): the streaming + # transport in `_iterate` synthesizes the + # ``RuntimeError("async generator raised + # StopAsyncIteration")`` wrapping when a coroutine raises + # `StopAsyncIteration`. Unwrap so the caller sees the + # original SAI. + if ( + not self._streaming + and isinstance(e, RuntimeError) + and isinstance(e.__cause__, StopAsyncIteration) + ): + e = e.__cause__ + if self._final_encode_error is not None: + # Strict-mode encode failure: copy ``e`` before mutating + # so the live (possibly globally-shared) instance is not + # touched. Chain the encode error as ``__cause__`` for + # diagnostic visibility while keeping the caller's + # primary ``except`` clause matching the original type. + e = copy.copy(e) + e.__cause__ = self._final_encode_error + e.__suppress_context__ = True + return ExceptionResponseFrame.for_send( + e, + serializer=serializer, + # The encode error means there is no trustworthy + # final chain manifest to ship; explicit + # ``wire_chain_manifest=None`` suppresses the field + # entirely (and avoids an unrelated auto-capture + # from the main-loop scope). + wire_chain_manifest=None, + ) + # Lazy-wire-frame: when the worker stayed unarmed the + # captured chain manifest is ``None`` and + # `Frame.to_protobuf` omits the optional ``context`` + # field; the caller's apply-back skips the absent field. + return ExceptionResponseFrame.for_send( + e, + serializer=serializer, + wire_chain_manifest=self._final_wire_chain_manifest, + ) + async def drain(self) -> None: """Close the request queue and await the worker driver to complete. Idempotent — safe to call multiple times. 
- After this returns, ``self.context`` is no longer being - mutated by the worker, so the dispatch handler can safely - snapshot it for the terminal-exception response. + After this returns, the worker task has published its final + `_final_wire_chain_manifest` (or `_final_encode_error`), + so the dispatch handler can safely read it for the + terminal-exception response. Worker exceptions raised during the drain are swallowed but logged at ``WARNING``: pre-stream and routine-time @@ -802,28 +859,28 @@ async def drain(self) -> None: async def __aexit__(self, exc_type, exc_val, exc_tb): # Shield the stack unwind from caller cancellation so the # registered drain callback runs to completion — see - # :func:`_complete_teardown`. The registered managers never + # `_complete_teardown`. The registered managers never # suppress, so discarding the suppression return value # (always falsy here) is behaviour-preserving. await _complete_teardown(self._stack.aclose()) - def __aiter__(self) -> AsyncIterator[_Response]: + def __aiter__(self) -> AsyncIterator[ResponseFrame]: if self._iterator is None: self._schedule_worker() self._iterator = self._iterate() return self._iterator - async def _iterate(self) -> AsyncGenerator[_Response, None]: + async def _iterate(self) -> AsyncGenerator[ResponseFrame, None]: """Drive the request/response loop on the main loop. - Forwards each :class:`protocol.Request` to the request - queue and yields one :class:`_Response` per response + Forwards each `protocol.Request` to the request + queue and yields one `ResponseFrame` per response received. Coroutine path synthesizes a single ``"next"`` request. Pre-stream worker failures raise out of - :meth:`_ResponseQueue.get` and propagate to the dispatch + `_ResponseQueue.get` and propagate to the dispatch handler's terminal-exception clause. 
- Raises :class:`asyncio.CancelledError` when :meth:`cancel` + Raises `asyncio.CancelledError` when `cancel` has been invoked — mirroring stdlib's ``await task`` semantics where ``task.cancel()`` from any source (caller, routine self-raise, operator preempt) surfaces as @@ -849,7 +906,7 @@ async def _iterate(self) -> AsyncGenerator[_Response, None]: # from "cancel arrived after natural completion". Matches # stdlib ``task.cancel()``: cancel-before-completion # surfaces as ``CancelledError``; cancel-after-completion - # is a no-op. The pre-fix unconditional check would ship a + # is a no-op. An unconditional check would ship a # spurious trailing exception frame on a routine the # caller has already observed completing — invisible to # user code through wool's public API, but counted by any @@ -864,29 +921,27 @@ async def _iterate(self) -> AsyncGenerator[_Response, None]: break try: request_queue.put(protobuf_request) - except RuntimeError: - # Mirror :meth:`drain`'s tolerance: when - # the worker loop has been torn down + except _WorkerLoopClosed: + # The worker loop has been torn down # mid-stream (graceful shutdown landing - # between two main-loop pumps), - # ``call_soon_threadsafe`` raises - # ``RuntimeError("Event loop is - # closed")``. The dispatch is no longer - # serviceable; break cleanly so the - # stream terminates without - # misattributing the loop teardown as a - # routine failure. + # between two main-loop pumps). + # `_RequestQueue.put` raises the + # typed signal specifically for this case + # so a broad ``except RuntimeError`` here + # would not silently swallow an unrelated + # protocol violation. Break cleanly — the + # dispatch is no longer serviceable. break response = await response_queue.get() if response is None: # ``_EOS`` arrived. 
Two producers can push - # it: :meth:`cancel` closing the response - # queue explicitly, or :meth:`_on_done` + # it: `cancel` closing the response + # queue explicitly, or `_on_done` # closing it after the worker task # finalizes. The cancel-induced path is # the one the trailing check guards # against — the queue's pre-pushed _EOS - # may have raced :meth:`_on_done` settling + # may have raced `_on_done` settling # ``worker_done`` with the actual # ``CancelledError``. if self._cancelled: @@ -916,7 +971,7 @@ async def _iterate(self) -> AsyncGenerator[_Response, None]: # sentinel. ``close()`` is idempotent (it just # pushes ``_EOS``) so the eventual # second close from drain is a no-op. Mirrors - # :meth:`drain`'s tolerance for a closed worker + # `drain`'s tolerance for a closed worker # loop: during graceful shutdown the loop pool # may have torn the worker loop down already, # in which case ``call_soon_threadsafe`` raises @@ -927,15 +982,32 @@ async def _iterate(self) -> AsyncGenerator[_Response, None]: except RuntimeError: pass else: - # Mirror the streaming branch's ``RuntimeError`` guard: - # a worker loop torn down between ``__aiter__``'s - # ``_schedule_worker`` and this first put raises - # ``RuntimeError("Event loop is closed")`` out of - # ``call_soon_threadsafe``. Exit cleanly rather than + # Read the caller's prime ``NextRequestFrame`` off the + # request iterator and forward it to the worker. The + # frame carries the caller's auto-captured chain manifest; + # synthesising a manifest-less ``Request(next=Void())`` + # would bypass the per-frame chain-manifest propagation + # the worker driver expects (mid-stream frames carry + # the manifest; boundary frames don't). Streaming uses + # ``async for`` over the iterator at the top of this + # branch — the coroutine branch reads exactly one frame. 
+ # + # Fall back to a manifest-less synthetic ``Next`` if the + # caller closed the write side before sending the prime + # frame (``StopAsyncIteration``): the routine still runs, + # just without the caller's chain manifest. + try: + prime = await anext(aiter(self._request_iterator)) + except StopAsyncIteration: + prime = protocol.Request(next=protocol.Void()) + # Mirror the streaming branch: a worker loop torn down + # between ``__aiter__``'s ``_schedule_worker`` and this + # first put raises `_WorkerLoopClosed` out of + # `_RequestQueue.put`. Exit cleanly rather than # ship a transport-teardown failure as a routine fault. try: - request_queue.put(protocol.Request(next=protocol.Void())) - except RuntimeError: + request_queue.put(prime) + except _WorkerLoopClosed: return response = await response_queue.get() if response is None: @@ -950,28 +1022,28 @@ async def _iterate(self) -> AsyncGenerator[_Response, None]: async def cancel(self) -> None: """Signal cancellation. Idempotent. Cross-task safe. 
- Sets a flag observed by :meth:`_schedule_worker` (so a - cancellation arriving before the first :meth:`__aiter__` - short-circuits the worker schedule) and by :meth:`_iterate` - (so iteration surfaces :class:`asyncio.CancelledError` at + Sets a flag observed by `_schedule_worker` (so a + cancellation arriving before the first `__aiter__` + short-circuits the worker schedule) and by `_iterate` + (so iteration surfaces `asyncio.CancelledError` at the next yield boundary — mirroring stdlib's ``await task`` semantics where ``task.cancel()`` from any source produces the same observable), cancels the worker driver task on the - worker loop so a routine mid-``_step`` (e.g., + worker loop so a routine mid-``_drive_step`` (e.g., ``await asyncio.sleep(...)``) receives a - :class:`asyncio.CancelledError` rather than running to + `asyncio.CancelledError` rather than running to natural completion, and pushes ``_EOS`` onto the response - queue so any suspended :meth:`_ResponseQueue.get` returns + queue so any suspended `_ResponseQueue.get` returns ``None`` and unblocks the iterator. Worker-task cancellation is scheduled via ``loop.call_soon_threadsafe`` to remain cross-loop safe, and tolerates a closed worker loop (the dispatch is no longer serviceable; the existing ``RuntimeError`` swallow - on ``call_soon_threadsafe`` matches :meth:`drain`'s + on ``call_soon_threadsafe`` matches `drain`'s tolerance). - Unlike a direct ``aclose()`` on :attr:`_iterator`, this is + Unlike a direct ``aclose()`` on `_iterator`, this is safe to call from a task other than the one driving the iterator — no ``RuntimeError("asynchronous generator is already running")`` is possible because no aclose is @@ -979,7 +1051,7 @@ async def cancel(self) -> None: Suspension caveat. 
The three signals — ``_cancelled`` flag, worker-task cancel, ``_EOS`` push — together unblock - :meth:`_iterate`'s ``_ResponseQueue.get`` suspensions and any + `_iterate`'s ``_ResponseQueue.get`` suspensions and any inter-step ``_cancelled`` observation. They do NOT interrupt a request-iterator read in flight: a streaming dispatch idling on ``async for protobuf_request in @@ -987,7 +1059,7 @@ async def cancel(self) -> None: last yield and the caller's next ``asend``/``anext``) only observes ``_cancelled`` after a new frame arrives or after the gRPC layer tears the stream down. The operator-preempt - path (:meth:`WorkerService._preempt`) relies on the broader + path (`WorkerService._preempt`) relies on the broader gRPC server shutdown to cancel the stream and close that gap; ``cancel()`` alone does not match stdlib ``task.cancel()`` for that specific suspension. @@ -1002,3 +1074,303 @@ async def cancel(self) -> None: pass if self._response_queue is not None: self._response_queue.close() + + +class _EndOfStream: + """Marker type for the end-of-stream sentinel pushed onto + `_RequestQueue` and `_ResponseQueue` to wake a + suspended ``get`` after `close`. Identity is unique by + construction (one instance, `_EOS`); the dedicated type + parameterizes both queues precisely without falling back to + ``object`` or a string ``Literal``. + """ + + +_EOS: Final[_EndOfStream] = _EndOfStream() +"""Singleton sentinel marking end of a queue-based dispatch stream.""" + + +class _WorkerLoopClosed(Exception): + """Internal control-flow signal: the worker loop was torn down + during dispatch. + + Raised by `_RequestQueue.put` when + ``call_soon_threadsafe`` rejects because the worker loop is + closed. The streaming and coroutine branches of + `DispatchSession._iterate` catch this specifically to break + cleanly without misclassifying the graceful teardown as a routine + failure. 
+ + Extends `Exception` (not `RuntimeError`) so broad + ``except RuntimeError:`` patterns elsewhere can't accidentally + swallow the signal. Matches stdlib's control-flow-signal + convention — `StopIteration` and `StopAsyncIteration` + both extend `Exception` directly. + """ + + +class _RequestQueue: + """Cross-loop queue carrying gRPC request envelopes from the + main (gRPC) loop to the worker loop's `_run` driver. + + Producers on the main loop push `protocol.Request` + envelopes via `put`. The consumer on the worker loop pulls + them via `get`, which returns each envelope undecoded; the + `_run` driver then decodes it into a + `~wool.runtime.worker.frame.RequestFrame` via + `RequestFrame.from_protobuf`. Decoding on + the worker side keeps payload deserialization (which may + reconstitute pickled `wool.ContextVar` instances) inside + the worker-loop task that runs the routine under the work context. + + Closure: `close` pushes a sentinel so `get` returns + `None` once the producer side is done. + """ + + def __init__( + self, + worker_loop: asyncio.AbstractEventLoop, + *, + serializer: Serializer, + ) -> None: + self._queue: asyncio.Queue[protocol.Request | _EndOfStream] = asyncio.Queue() + self._worker_loop = worker_loop + self._serializer = serializer + + def put(self, request: protocol.Request) -> None: + """Push a `protocol.Request` onto the queue. + + Cross-loop safe — schedules the put on the worker loop via + `asyncio.AbstractEventLoop.call_soon_threadsafe`. + + Raises `_WorkerLoopClosed` if the worker loop has + already been torn down (``call_soon_threadsafe`` rejects + with ``RuntimeError("Event loop is closed")``). Callers + catch that typed signal specifically to break cleanly + without misclassifying graceful teardown as a routine + failure. Other `RuntimeError` instances propagate + unchanged — protocol violations remain visible rather than + being swallowed by a broad ``except RuntimeError`` clause. 
+ """ + try: + self._worker_loop.call_soon_threadsafe(self._queue.put_nowait, request) + except RuntimeError as e: + if self._worker_loop.is_closed(): + raise _WorkerLoopClosed() from e + # A non-closed-loop ``RuntimeError`` from + # ``call_soon_threadsafe`` is not reproducible without mocking + # the loop; surface it rather than swallow it. + raise # pragma: no cover + + async def get(self) -> protocol.Request | None: + """Pop the next request from the queue. + + Decoding is deferred to the consumer (``_run``) so a single + malformed wire envelope can be surfaced as a typed terminal + on the response stream rather than killing the worker driver + task mid-stream. + + Awaitable on the worker loop only. + """ + item = await self._queue.get() + if isinstance(item, _EndOfStream): + return None + return item + + def close(self) -> None: + """Signal end of input by pushing the close sentinel.""" + self._worker_loop.call_soon_threadsafe(self._queue.put_nowait, _EOS) + + +class _ResponseQueue: + """Cross-loop queue carrying `ResponseFrame` instances from + the worker loop's `_run` driver back to the main (gRPC) + loop's `DispatchSession.__aiter__`. + + Producers on the worker loop push frames via `put` and + signal end-of-stream via `close`. The consumer on the + main loop pulls them via `get`, which returns `None` + after a clean termination (the routine exhausted or returned) + and **raises** the worker task's underlying exception when the + worker died — the queue holds a reference to the + worker-completion `concurrent.futures.Future` so the + sentinel-and-failure check co-locates with the close sentinel + that triggers it. The exception propagates out of + `DispatchSession.__aiter__` for the dispatch handler's + terminal-exception clause to ship. 
+ """ + + def __init__( + self, + main_loop: asyncio.AbstractEventLoop, + worker_done: concurrent.futures.Future, + ) -> None: + # Unbounded by necessity: both response-frame pushes (the + # data path) and ``_EOS`` pushes (close + ``_on_done``) + # share this queue via ``put_nowait``, so a hard cap would + # need to leave headroom for one or two sentinel slots. The + # actual invariant — bounded by producer/consumer + # alternation in `_run` and + # `DispatchSession._iterate` to ≤1 response in flight + # — is enforced structurally there: the worker pushes one + # response, then awaits the next request before pushing + # again. A future change that decouples that cadence + # (prefetch, batching) needs to add explicit backpressure + # here rather than relying on this queue to provide it. + self._queue: asyncio.Queue[ResponseFrame | _EndOfStream] = asyncio.Queue() + self._main_loop = main_loop + self._worker_done = worker_done + + def put(self, response: ResponseFrame) -> None: + """Push a `ResponseFrame` onto the queue. + + Cross-loop safe — schedules the put on the main loop via + `asyncio.AbstractEventLoop.call_soon_threadsafe`. + """ + self._main_loop.call_soon_threadsafe(self._queue.put_nowait, response) + + async def get(self) -> ResponseFrame | None: + """Pop the next response, or `None` after a clean + `close`. + + **Raises** the worker task's exception when the close + sentinel arrives and ``worker_done`` carries one — + surfacing worker failures (pre-stream, routine-time, or + cancellation) up to `DispatchSession.__aiter__` so they + propagate to the dispatch handler's terminal-exception + clause. + + Awaitable on the main loop only. 
+ """ + result = await self._queue.get() + if isinstance(result, _EndOfStream): + # The worker-completion future is the synchronization + # primitive: when the worker dies with an exception, + # ``worker_done`` is set before the close sentinel is + # observable here, so reading the exception (if any) + # surfaces worker failures alongside the EOS sentinel. + # A clean routine end may close before the worker task + # finishes, in which case ``worker_done`` is still + # pending — return ``None`` either way. + if self._worker_done.done(): + exc = self._worker_done.exception() + if exc is not None: + raise exc + return None + return result + + def close(self) -> None: + """Signal end of responses by pushing the close sentinel. + Cross-loop safe.""" + self._main_loop.call_soon_threadsafe(self._queue.put_nowait, _EOS) + + +def _create_step_task( + coro: Coroutine[Any, Any, _T], + *, + loop: asyncio.AbstractEventLoop, + context: contextvars.Context, +) -> asyncio.Task[_T]: + """Create an `asyncio.Task` that runs *coro* in *context*, + bypassing Wool's fork-on-task factory. + + The per-step driver constructs each step's task directly via the + `asyncio.Task` constructor rather than ``loop.create_task``. + ``loop.create_task`` routes through whatever task factory the loop + has installed — Wool's factory included — and Wool's factory forks + the child onto a fresh chain. For an *internal* driver task whose + job is to drive a chain that already exists on the registry, + forking would mint a new chain id and break the registry lookup. + Going through the `asyncio.Task` constructor skips the + factory and runs *coro* in the supplied *context* unchanged. + + Localised here so the bypass site is identifiable in a grep, and + so a future change that needs uvloop's native task class can swap + the construction in one place without touching every caller. 
+ """ + return asyncio.Task(coro, loop=loop, context=context) + + +async def _drive_step( + routine: Any, + streaming: bool, + request: NextRequestFrame | SendRequestFrame | ThrowRequestFrame, + work_ctx: contextvars.Context, + *, + loop: asyncio.AbstractEventLoop, +) -> Any: + """Drive one step of the worker's per-routine loop. + + Builds the step coroutine from the request kind (``asend`` / + ``athrow`` for an async-generator routine, the routine itself + for a coroutine routine), runs it inside *work_ctx* as a freshly + constructed `asyncio.Task`, and returns the step's + yielded or returned value. + + This is the top-level per-step build-and-execute, kept as a + single-purpose function so it is readable on its own. + + The session loop in `DispatchSession._schedule_worker` + wraps the call to capture the post-step chain manifest, pin + `DispatchSession._final_wire_chain_manifest`, and translate the + routine's exception types into the streaming-vs-coroutine end- + of-stream signalling. + + :param routine: + The active routine — either the coroutine itself (non- + streaming) or an `AsyncGenerator` (streaming). + :param streaming: + ``True`` when *routine* is an async generator and each + request drives one ``asend`` / ``athrow``. + :param request: + The decoded request frame. Determines the step's flavor + (``Next``/``Send``/``Throw`` for streaming). + :param work_ctx: + The cached `contextvars.Context` for the request's + chain. The step task is constructed against this context so + backing-variable writes ride the chain's bindings. + :param loop: + The worker loop the step task is bound to. + + :returns: + The value the step coroutine yielded or returned. + + :raises BaseException: + Whatever the routine raises propagates — the caller's + wrapper translates `StopAsyncIteration` (streaming + end-of-stream) and pins the terminal chain manifest for the + routine-failure path. 
+ """ + step_coro: Coroutine[Any, Any, Any] + if streaming: + gen = cast(AsyncGenerator, routine) + if isinstance(request, NextRequestFrame): + step_coro = gen.asend(None) + elif isinstance(request, SendRequestFrame): + step_coro = gen.asend(request.payload) + elif isinstance(request, ThrowRequestFrame): + step_coro = gen.athrow(request.payload) + else: # pragma: no cover + assert_never(request) + else: + step_coro = cast(Coroutine[Any, Any, Any], routine) + + step_task = _create_step_task(step_coro, loop=loop, context=work_ctx) + try: + return await step_task + finally: + # Defensive: cancel the step task if the await was preempted so a + # routine mid-step doesn't run to natural completion after the + # caller has gone away. Under real cancellation + # ``asyncio.Task.cancel`` on the driver propagates through + # ``_fut_waiter`` to the step task, so the step is already + # ``done()`` here — this guard only fires for a step that ignores + # cancellation, a state real routines do not reach, hence the + # no-cover pragma. + if not step_task.done(): # pragma: no cover + step_task.cancel() + try: + await step_task + except (KeyboardInterrupt, SystemExit): + raise + except BaseException: + pass From fccb6fe448564e711ee375ee4cc79e46852157b8 Mon Sep 17 00:00:00 2001 From: Conrad Date: Sat, 27 Jun 2026 17:51:27 -0400 Subject: [PATCH 5/7] refactor!: Align the runtime layers and reshape the public API Read dispatch-time fields off RuntimeContext in the routine wrapper and Task, switch discovery to the protobuf wire namespace, and update the loadbalancer error handling to the ChainSerializationError shape. Reshape the public wool API: remove Context, current_context, copy_context, create_task, and ContextAlreadyBound; alias Token to contextvars.Token; and add to_thread and install_task_factory. 
--- wool/src/wool/__init__.py | 48 ++++++++----- wool/src/wool/runtime/discovery/lan.py | 4 +- wool/src/wool/runtime/discovery/local.py | 4 +- .../wool/runtime/loadbalancer/roundrobin.py | 3 +- wool/src/wool/runtime/resourcepool.py | 26 +++++-- wool/src/wool/runtime/routine/task.py | 52 ++++++++++++-- wool/src/wool/runtime/routine/wrapper.py | 70 +++++++++++-------- wool/src/wool/runtime/typing.py | 1 + 8 files changed, 139 insertions(+), 69 deletions(-) diff --git a/wool/src/wool/__init__.py b/wool/src/wool/__init__.py index a5a32425..98a1e740 100644 --- a/wool/src/wool/__init__.py +++ b/wool/src/wool/__init__.py @@ -1,22 +1,24 @@ import contextvars +from contextvars import Token from importlib.metadata import PackageNotFoundError from importlib.metadata import version +from typing import TYPE_CHECKING from typing import Final from tblib import pickling_support -from wool.exception import WoolError -from wool.exception import WoolWarning -from wool.runtime.context import Context -from wool.runtime.context import ContextAlreadyBound -from wool.runtime.context import ContextDecodeWarning -from wool.runtime.context import ContextVar -from wool.runtime.context import ContextVarCollision -from wool.runtime.context import RuntimeContext -from wool.runtime.context import Token -from wool.runtime.context import copy_context -from wool.runtime.context import create_task -from wool.runtime.context import current_context +from wool.exceptions import WoolError +from wool.exceptions import WoolWarning +from wool.runtime.context.exceptions import ChainContention +from wool.runtime.context.exceptions import ChainSerializationError +from wool.runtime.context.exceptions import ContextVarCollision +from wool.runtime.context.exceptions import SerializationError +from wool.runtime.context.exceptions import SerializationWarning +from wool.runtime.context.exceptions import TaskFactoryDisplaced +from wool.runtime.context.factory import install_task_factory +from 
wool.runtime.context.runtime import RuntimeContext +from wool.runtime.context.threading import to_thread +from wool.runtime.context.var import ContextVar from wool.runtime.discovery.base import Discovery from wool.runtime.discovery.base import DiscoveryEvent from wool.runtime.discovery.base import DiscoveryEventType @@ -40,6 +42,7 @@ from wool.runtime.serializer import CloudpickleSerializer from wool.runtime.serializer import Serializer from wool.runtime.typing import Factory +from wool.runtime.typing import UndefinedType from wool.runtime.worker.auth import WorkerCredentials from wool.runtime.worker.base import BoundWorkerFactory from wool.runtime.worker.base import Worker @@ -59,6 +62,9 @@ from wool.runtime.worker.service import BackpressureLike from wool.runtime.worker.service import WorkerService +if TYPE_CHECKING: + from wool.runtime.context.chain import Chain + pickling_support.install() try: @@ -68,6 +74,10 @@ __serializer__: Final[Serializer] = CloudpickleSerializer() +__chain__: Final[contextvars.ContextVar["Chain"]] = contextvars.ContextVar( + "__wool_chain__" +) + __proxy__: Final[contextvars.ContextVar[WorkerProxy | None]] = contextvars.ContextVar( "__proxy__", default=None ) @@ -89,9 +99,8 @@ "BackpressureContext", "BackpressureLike", "BoundWorkerFactory", - "Context", - "ContextAlreadyBound", - "ContextDecodeWarning", + "ChainContention", + "ChainSerializationError", "ContextVar", "ContextVarCollision", "Discovery", @@ -114,11 +123,15 @@ "RoundRobinLoadBalancer", "RpcError", "RuntimeContext", + "SerializationError", + "SerializationWarning", "Serializer", "Task", "TaskException", + "TaskFactoryDisplaced", "Token", "TransientRpcError", + "UndefinedType", "UnexpectedResponse", "WoolError", "WoolWarning", @@ -131,11 +144,10 @@ "WorkerPool", "WorkerProxy", "WorkerService", - "copy_context", - "create_task", - "current_context", "current_task", + "install_task_factory", "routine", + "to_thread", ] for symbol in __all__: diff --git 
a/wool/src/wool/runtime/discovery/lan.py b/wool/src/wool/runtime/discovery/lan.py index 55197a74..311e33ae 100644 --- a/wool/src/wool/runtime/discovery/lan.py +++ b/wool/src/wool/runtime/discovery/lan.py @@ -24,8 +24,8 @@ from zeroconf.asyncio import AsyncServiceBrowser from zeroconf.asyncio import AsyncZeroconf -from wool.exception import WoolError -from wool.exception import WoolWarning +from wool.exceptions import WoolError +from wool.exceptions import WoolWarning from wool.runtime.discovery.base import Discovery from wool.runtime.discovery.base import DiscoveryEvent from wool.runtime.discovery.base import DiscoveryEventType diff --git a/wool/src/wool/runtime/discovery/local.py b/wool/src/wool/runtime/discovery/local.py index 0f2e147d..5f488863 100644 --- a/wool/src/wool/runtime/discovery/local.py +++ b/wool/src/wool/runtime/discovery/local.py @@ -22,7 +22,7 @@ from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer -from wool.protocol import WorkerMetadata as WorkerMetadataProtobuf +from wool import protocol as wire from wool.runtime.discovery.base import Discovery from wool.runtime.discovery.base import DiscoveryEvent from wool.runtime.discovery.base import DiscoveryEventType @@ -701,7 +701,7 @@ def _deserialize_metadata(self, ref: str): assert memory_block.buf is not None size = struct.unpack_from("I", memory_block.buf, 0)[0] serialized = struct.unpack_from(f"{size}s", memory_block.buf, 4)[0] - protobuf = WorkerMetadataProtobuf.FromString(serialized) + protobuf = wire.WorkerMetadata.FromString(serialized) return WorkerMetadata.from_protobuf(protobuf) def _diff( diff --git a/wool/src/wool/runtime/loadbalancer/roundrobin.py b/wool/src/wool/runtime/loadbalancer/roundrobin.py index 893bc75c..d321b13d 100644 --- a/wool/src/wool/runtime/loadbalancer/roundrobin.py +++ b/wool/src/wool/runtime/loadbalancer/roundrobin.py @@ -37,8 +37,7 @@ class RoundRobinLoadBalancer(LoadBalancerLike): 
:class:`~wool.runtime.worker.connection.TransientRpcError`) as worker-health concerns. Other exceptions raised by :meth:`WorkerConnection.dispatch` — e.g. a strict-mode - :class:`BaseExceptionGroup` of - :class:`wool.ContextDecodeWarning` peers from a caller-side + :class:`wool.ChainSerializationError` from a caller-side encode failure, or a programming-error :class:`ValueError` — propagate to the caller unwrapped, so a fault that has nothing to do with worker health does not evict diff --git a/wool/src/wool/runtime/resourcepool.py b/wool/src/wool/runtime/resourcepool.py index 6517087c..f82f4209 100644 --- a/wool/src/wool/runtime/resourcepool.py +++ b/wool/src/wool/runtime/resourcepool.py @@ -366,13 +366,25 @@ async def _cleanup(self, key: Any) -> None: except RuntimeError: pass finally: - # Call finalizer - if self._finalizer: - try: - await self._await(self._finalizer, entry.obj) - except Exception: - pass - del self._cache[key] + # Evict from the cache *unconditionally*, before and + # regardless of how the finalizer exits. A finalized + # resource must never remain cached: if the finalizer + # raises — including ``CancelledError`` when cleanup runs + # under a cancelled teardown, which is a ``BaseException`` + # and so escapes ``except Exception`` — the entry must + # still be removed, or a later ``acquire`` hands back a + # torn-down resource (e.g. a closed event loop). The + # inner ``try`` lets the finalizer run for its side + # effects while the outer ``finally`` guarantees eviction + # and lets any cancellation propagate. 
+ try: + if self._finalizer: + try: + await self._await(self._finalizer, entry.obj) + except Exception: + pass + finally: + del self._cache[key] async def _await(self, func: Callable, *args) -> Any: """ diff --git a/wool/src/wool/runtime/routine/task.py b/wool/src/wool/runtime/routine/task.py index b458ad17..638a21ac 100644 --- a/wool/src/wool/runtime/routine/task.py +++ b/wool/src/wool/runtime/routine/task.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextvars import logging import traceback from collections.abc import Callable @@ -8,6 +9,7 @@ from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass +from dataclasses import field from inspect import isasyncgen from inspect import isasyncgenfunction from inspect import iscoroutinefunction @@ -30,7 +32,7 @@ import wool from wool import protocol -from wool.runtime.context import RuntimeContext +from wool.runtime.context.runtime import RuntimeContext Args = Tuple Kwargs = Dict @@ -76,7 +78,6 @@ def do_dispatch(flag: bool | None = None, /) -> bool | ContextManager[None]: return _do_dispatch_context_manager(flag) -# public @runtime_checkable class WorkerProxyLike(Protocol): """Protocol defining the interface required by Task for proxy objects. @@ -124,8 +125,8 @@ class Task(Generic[W]): Descriptive label identifying the call site, formatted as ``module.qualname:lineno`` by the ``@routine`` wrapper. :param runtime_context: - Snapshot of the active :class:`RuntimeContext` at construction - time, captured by :meth:`__post_init__` if not supplied. Ships + The active :class:`RuntimeContext` captured at construction + time by :meth:`__post_init__` if not supplied. Ships with the dispatch frame so the worker side can restore wire defaults (notably ``dispatch_timeout``) for the routine's execution. 
@@ -141,10 +142,18 @@ class Task(Generic[W]): exception: TaskException | None = None tag: str | None = None runtime_context: RuntimeContext | None = None + # Declare the single-entry guard slot at class scope, matching + # :class:`RuntimeContext._dispatch_timeout_token`'s precedent. + # ``init=False`` keeps the constructor surface unchanged; ``repr=False`` + # / ``compare=False`` keep the field out of repr() and equality so + # internal lifecycle state doesn't leak into either. + _task_token: contextvars.Token[Task | None] | None = field( + default=None, init=False, repr=False, compare=False + ) def __post_init__(self): """Validate the proxy, capture the calling task's id, and seed - a :class:`RuntimeContext` snapshot if one was not supplied. + a :class:`RuntimeContext` if one was not supplied. The runtime-context seed lets a Task built outside an active :class:`RuntimeContext` scope still ship the wire defaults @@ -154,7 +163,12 @@ def __post_init__(self): raise TypeError( f"proxy must conform to WorkerProxyLike, got {type(self.proxy).__name__}" ) - if caller := _current_task.get(): + # Auto-capture the calling task's id only when ``caller`` was + # not supplied explicitly — match the ``runtime_context`` seed + # pattern below. Unconditionally overwriting an explicit + # constructor argument violates the principle of least + # surprise for a public dataclass field. + if self.caller is None and (caller := _current_task.get()): self.caller = caller.id if self.runtime_context is None: self.runtime_context = RuntimeContext.get_current() @@ -165,8 +179,20 @@ def __enter__(self) -> Task: Bind this task to the current context. On exit, re-binds the calling task and records any propagating exception on :attr:`exception` for wire transport. + + :raises RuntimeError: + If the instance is already inside a ``with`` block + (``self._task_token is not None``). ``Task`` is + block-scoped and single-use as a context manager; + re-entering would leak the outer token. 
""" logging.debug(f"Entering {self.__class__.__name__} with ID {self.id}") + if self._task_token is not None: + raise RuntimeError( + f"{self.__class__.__name__} is already active in a `with` " + "block; instances are block-scoped and single-use as context " + "managers" + ) self._task_token = _current_task.set(self) return self @@ -203,7 +229,19 @@ def __exit__( for y in x.split("\n") ], ) - _current_task.reset(self._task_token) + # Guard against __exit__ invoked without a preceding + # successful __enter__. ``_task_token`` is ``None`` until + # __enter__ runs; without this guard, ``_current_task.reset(None)`` + # raises ``TypeError: instance of Token expected`` at runtime + # and pyright (rightly) flags the type as + # ``Token | None`` → ``Token``. Mirrors __enter__'s + # re-entry guard symmetrically: __enter__ raises on double- + # enter, __exit__ no-ops on double-exit. + token = self._task_token + if token is None: + return False + _current_task.reset(token) + self._task_token = None return False @classmethod diff --git a/wool/src/wool/runtime/routine/wrapper.py b/wool/src/wool/runtime/routine/wrapper.py index 185dcc96..006bb3d0 100644 --- a/wool/src/wool/runtime/routine/wrapper.py +++ b/wool/src/wool/runtime/routine/wrapper.py @@ -14,7 +14,7 @@ from uuid import uuid4 import wool -from wool.runtime.context import dispatch_timeout +from wool.runtime.context.runtime import dispatch_timeout from wool.runtime.routine.task import Task from wool.runtime.routine.task import do_dispatch @@ -112,41 +112,49 @@ async def ping() -> bool: ... consider using shared memory or passing references instead of the data itself. - **Context propagation and decode-failure semantics:** + **Chain propagation and decode-failure semantics:** - Routines run inside their own :class:`wool.Context` on the worker, - which receives the caller's :class:`wool.ContextVar` snapshot on - dispatch and ships post-run mutations back on the response. 
Wool - treats this propagation as **ancillary state** — separate from the - routine's primary signal (its return value or raised exception): + Routines run on the worker under the caller's Wool chain: the + caller's :class:`wool.ContextVar` chain manifest is decoded from the + dispatch frame and the routine runs under it, with post-run + mutations shipped back on the response. Wool treats this + propagation as **ancillary state** — separate from the routine's + primary signal (its return value or raised exception): - - **Primary signal preservation.** A failure to decode the wire - :class:`wool.Context` (e.g., cross-version pickle skew on a - single var value) never preempts the routine's outcome. The + - **Primary signal preservation.** A failure to decode the chain + manifest (e.g., cross-version pickle skew on a single + variable value) never preempts the routine's outcome. The result is still returned; a routine exception is still raised. - **Visibility via warning.** Every ancillary decode failure emits - a :class:`wool.ContextDecodeWarning` on the side that observed - the failure. On the caller, an exception-frame decode failure - is bundled alongside the routine's raised exception in a - :class:`BaseExceptionGroup` so both signals reach user code — - ``except*`` splits the worker exception from the per-entry - decode failures. + a :class:`wool.SerializationWarning` on the side that observed + the failure. Under the default warning filter the routine + exception's type is preserved — ``except RoutineError:`` still + matches as written — and the warnings remain visible in the + Python warnings stream. 
- **Strict mode (opt-in).** Promote the warning to an exception to treat ancillary failures as fatal:: import warnings import wool - warnings.filterwarnings("error", category=wool.ContextDecodeWarning) - - In strict mode :meth:`wool.Context.from_protobuf` aggregates - the per-entry exceptions into a :class:`BaseExceptionGroup` - that the caller observes in place of the primary signal: - - * Result frames lose the routine's return value — the group - raises instead. - * Exception frames preserve the routine exception as a peer - of the decode-failure group, so ``except*`` recovers both. + warnings.filterwarnings("error", category=wool.SerializationWarning) + + In strict mode the chain-manifest decode aggregates the per-entry + warnings into a single :class:`wool.ChainSerializationError` (a + :class:`wool.WoolError` subclass, catchable via + :class:`wool.SerializationError`, with the warnings on + :attr:`~wool.ChainSerializationError.warnings`): + + * **Result frames** lose the routine's return value — the + :class:`~wool.ChainSerializationError` raises in place of the + primary signal so the caller observes every bad entry, not + just the first. + * **Exception frames** still preserve the routine exception as + the primary signal; the :class:`~wool.ChainSerializationError` + rides on it as ``__cause__`` via ``raise routine_exc from + decode_err``. The routine exception's class is preserved so + ``except RoutineError:`` keeps matching — no migration to + ``except*`` required. Callers that want both result preservation *and* failure visibility should instead use ``warnings.catch_warnings(record=True)`` @@ -157,12 +165,12 @@ async def ping() -> bool: ... 
``multiprocessing`` propagates to spawned worker subprocesses by default:: - PYTHONWARNINGS = "error::wool.ContextDecodeWarning" + PYTHONWARNINGS = "error::wool.SerializationWarning" - When the worker promotes the warning, wool ships it back - through the routine-exception channel so the caller catches - the same ``wool.ContextDecodeWarning`` class — no - :class:`grpc.aio.AioRpcError` to special-case. + When the worker promotes the warning, wool ships the resulting + :class:`wool.ChainSerializationError` back through the + routine-exception channel so the caller catches the same error + class — no :class:`grpc.aio.AioRpcError` to special-case. Example usage with coroutines: diff --git a/wool/src/wool/runtime/typing.py b/wool/src/wool/runtime/typing.py index bb65958d..ffae2d4c 100644 --- a/wool/src/wool/runtime/typing.py +++ b/wool/src/wool/runtime/typing.py @@ -16,6 +16,7 @@ PassthroughWrapper = Callable[[F], F] +# public @final class UndefinedType(Enum): Undefined = "Undefined" From f4182dc3a10ab47dba4e6d72592057f5c6bd2865 Mon Sep 17 00:00:00 2001 From: Conrad Date: Sat, 27 Jun 2026 17:51:27 -0400 Subject: [PATCH 6/7] test: Rebuild the suite around the stdlib-aligned chain model Decompose the context tests into per-class mirror modules (chain, manifest, runtime, var, token, factory, guard, exceptions), add the stdlib_parity suite pinning Wool against the analogous contextvars and asyncio surfaces under both loops, add Frame round-trip coverage, and rewrite the worker and integration tests for the per-frame architecture. The unit suite holds the 98% coverage gate. 
--- wool/tests/conftest.py | 16 + wool/tests/helpers.py | 57 +- wool/tests/integration/_collision_fixtures.py | 2 +- wool/tests/integration/conftest.py | 104 +- wool/tests/integration/routines.py | 308 +- .../test_context_var_propagation.py | 2123 ++++++------ .../integration/test_pool_composition.py | 109 +- wool/tests/integration/test_unified_driver.py | 18 +- wool/tests/protocol/test_wire.py | 21 +- wool/tests/runtime/context/conftest.py | 7 +- wool/tests/runtime/context/test_base.py | 2965 ----------------- wool/tests/runtime/context/test_chain.py | 423 +++ wool/tests/runtime/context/test_exceptions.py | 457 +++ wool/tests/runtime/context/test_factory.py | 1233 +++++++ wool/tests/runtime/context/test_guard.py | 386 +++ wool/tests/runtime/context/test_manifest.py | 1170 +++++++ wool/tests/runtime/context/test_registry.py | 71 - wool/tests/runtime/context/test_runtime.py | 205 ++ wool/tests/runtime/context/test_token.py | 439 +-- wool/tests/runtime/context/test_var.py | 1213 ++++--- wool/tests/runtime/discovery/test_base.py | 6 +- wool/tests/runtime/discovery/test_local.py | 23 + .../runtime/loadbalancer/test_roundrobin.py | 164 +- wool/tests/runtime/routine/test_task.py | 292 +- wool/tests/runtime/test_resourcepool.py | 220 +- wool/tests/runtime/worker/conftest.py | 17 +- wool/tests/runtime/worker/test_auth.py | 87 +- wool/tests/runtime/worker/test_connection.py | 886 +++-- wool/tests/runtime/worker/test_frame.py | 801 +++++ wool/tests/runtime/worker/test_service.py | 2516 +++++--------- wool/tests/runtime/worker/test_session.py | 993 +++--- wool/tests/stdlib_parity/conftest.py | 78 + .../stdlib_parity/test_async_gen_aclose.py | 77 - .../stdlib_parity/test_context_parity.py | 403 +++ .../stdlib_parity/test_executor_offload.py | 332 ++ .../stdlib_parity/test_loop_callbacks.py | 411 +++ .../tests/stdlib_parity/test_task_creation.py | 979 ++++++ .../{test_exception.py => test_exceptions.py} | 14 +- wool/tests/test_public.py | 67 +- 39 files changed, 11732 
insertions(+), 7961 deletions(-) delete mode 100644 wool/tests/runtime/context/test_base.py create mode 100644 wool/tests/runtime/context/test_chain.py create mode 100644 wool/tests/runtime/context/test_exceptions.py create mode 100644 wool/tests/runtime/context/test_factory.py create mode 100644 wool/tests/runtime/context/test_guard.py create mode 100644 wool/tests/runtime/context/test_manifest.py delete mode 100644 wool/tests/runtime/context/test_registry.py create mode 100644 wool/tests/runtime/context/test_runtime.py create mode 100644 wool/tests/runtime/worker/test_frame.py create mode 100644 wool/tests/stdlib_parity/conftest.py delete mode 100644 wool/tests/stdlib_parity/test_async_gen_aclose.py create mode 100644 wool/tests/stdlib_parity/test_context_parity.py create mode 100644 wool/tests/stdlib_parity/test_executor_offload.py create mode 100644 wool/tests/stdlib_parity/test_loop_callbacks.py create mode 100644 wool/tests/stdlib_parity/test_task_creation.py rename wool/tests/{test_exception.py => test_exceptions.py} (82%) diff --git a/wool/tests/conftest.py b/wool/tests/conftest.py index 05016959..70a5988f 100644 --- a/wool/tests/conftest.py +++ b/wool/tests/conftest.py @@ -1,3 +1,4 @@ +import contextvars import logging from typing import Final @@ -13,6 +14,21 @@ _ignore_unknown = True +@pytest.hookimpl(tryfirst=True) +def pytest_pyfunc_call(pyfuncitem): + """Run each sync test in a fresh ``contextvars.copy_context`` so a + ``wool.ContextVar`` set in one test cannot leak its armed chain into + the next. Async tests self-isolate — their task copies the context — + and are owned by pytest-asyncio's ``pytest_pyfunc_call``, so defer. 
+ """ + if pyfuncitem.get_closest_marker("asyncio") is not None: + return None + argnames = pyfuncitem._fixtureinfo.argnames + testargs = {name: pyfuncitem.funcargs[name] for name in argnames} + contextvars.copy_context().run(pyfuncitem.obj, **testargs) + return True + + @pytest.hookimpl(tryfirst=True) def pytest_configure(config): if config.getoption("--wait-for-debugger"): diff --git a/wool/tests/helpers.py b/wool/tests/helpers.py index 17b796d9..a6012fc9 100644 --- a/wool/tests/helpers.py +++ b/wool/tests/helpers.py @@ -1,33 +1,36 @@ +import uuid +from collections.abc import Generator from contextlib import contextmanager -from typing import Iterator -from uuid import UUID -from wool import protocol -from wool.runtime.context import Context -from wool.runtime.context import attached +import wool + + +def _unique(stem: str) -> str: + """Return a process-unique variable name to avoid registry collisions.""" + return f"{stem}_{uuid.uuid4().hex}" @contextmanager -def scoped_context(id: UUID | None = None) -> Iterator[Context]: - """Test helper — install a wool.Context for the duration of the block. - - Mints a fresh chain id by default. Pass *id* to construct a - Context with a specific chain id, used by tests that exercise - chain-id-dependent semantics (e.g. ContextVar.reset's same-id - fallback). The id-bearing path goes through the public - ``Context.from_protobuf`` rather than the private - ``_reconstitute`` builder, since wool deliberately does not - expose an in-process id-only constructor (that would invite - duplicate-id Contexts and undercut the single-task-per-Context - invariant). On exit the prior scope's Context is restored. - - Attaches without claiming the single-task guard so tests can - invoke ``Context.run`` / ``attached(ctx)`` on the yielded - Context themselves. +def scoped_context() -> Generator[None]: + """Test helper — bracket a block of Wool chain mutations. 
+ + Per-test isolation lives in the ``pytest_pyfunc_call`` hook in + ``tests/conftest.py``, which runs each sync test in a fresh + :func:`contextvars.copy_context` (async tests self-isolate via their + task's context copy). With ``__chain__`` typed + :class:`~wool.runtime.context.chain.Chain` there is no settable + "unarmed" value to install in place, so this manager no longer + disarms; it is retained as a no-op scope around chain mutations. + """ + yield + + +def context_is_unarmed() -> bool: + """Test helper — return whether the current context carries no Wool state. + + A module-level, picklable function so it can be dispatched to a + :class:`~concurrent.futures.ProcessPoolExecutor` worker, where it + proves a bare ``run_in_executor`` offload carries no Wool chain + into a worker process. """ - if id is None: - ctx = Context() - else: - ctx = Context.from_protobuf(protocol.Context(id=id.hex)) - with attached(ctx, guarded=False): - yield ctx + return wool.__chain__.get(None) is None diff --git a/wool/tests/integration/_collision_fixtures.py b/wool/tests/integration/_collision_fixtures.py index c98efe7b..b3c2d0e1 100644 --- a/wool/tests/integration/_collision_fixtures.py +++ b/wool/tests/integration/_collision_fixtures.py @@ -42,7 +42,7 @@ # Strong-ref pin dict, populated on the worker side so constructed # ContextVar instances outlive their routine's return and remain in -# :attr:`wool.ContextVar._registry` (a WeakValueDictionary) when the +# the process-wide var_registry (a WeakValueDictionary) when the # sibling routine runs. 
_PINS: dict[str, wool.ContextVar] = {} diff --git a/wool/tests/integration/conftest.py b/wool/tests/integration/conftest.py index b67f8fc7..fd96ef86 100644 --- a/wool/tests/integration/conftest.py +++ b/wool/tests/integration/conftest.py @@ -29,8 +29,7 @@ from cryptography.x509.oid import NameOID from hypothesis import strategies as st -import wool -from wool.runtime.context import dispatch_timeout +from wool.runtime.context.runtime import dispatch_timeout from wool.runtime.discovery.local import LocalDiscovery from wool.runtime.loadbalancer.roundrobin import RoundRobinLoadBalancer from wool.runtime.worker.auth import WorkerCredentials @@ -128,7 +127,7 @@ class StrictWarnings(Enum): """Documents whether warnings are promoted to errors during dispatch. ``OFF`` leaves the ambient warning filter untouched. ``ALL_DECODABLE`` - promotes :class:`wool.ContextDecodeWarning` to an error for the + promotes :class:`wool.SerializationWarning` to an error for the duration of the dispatch and ships caller vars whose values can be cleanly decoded on the worker — under strict mode the dispatch is expected to complete without raising, proving the happy-path @@ -183,7 +182,7 @@ class Scenario: # and serves as a documentation annotation on test IDs — it does # not drive ``build_pool_from_scenario`` behavior today. Tests # that vary along this axis set it explicitly so the pytest ID - # reflects the behavior under exercise. + # reflects the warning regime under exercise. strict_warnings: StrictWarnings | None = None def __or__(self, other: Scenario) -> Scenario: @@ -245,14 +244,15 @@ def default_scenario( ctx_var_2: ContextVarPattern = ContextVarPattern.NONE, ctx_var_3: ContextVarPattern = ContextVarPattern.NONE, quorum: QuorumMode = QuorumMode.DEFAULT, + timeout: TimeoutKind = TimeoutKind.NONE, strict_warnings: StrictWarnings | None = None, ) -> Scenario: """Build a fully-populated :class:`Scenario` with sensible defaults. 
Used by happy-path integration tests that want to vary only one or two dimensions while leaving the rest at their canonical values. The - optional documentation field (``strict_warnings``) defaults to ``None`` - so it remains absent from the pytest ID unless explicitly set. + optional documentation field ``strict_warnings`` defaults to + ``None`` so it remains absent from the pytest ID unless explicitly set. """ return Scenario( shape=shape, @@ -261,7 +261,7 @@ def default_scenario( lb=LbFactory.CLASS_REF, credential=CredentialType.INSECURE, options=WorkerOptionsKind.DEFAULT, - timeout=TimeoutKind.NONE, + timeout=timeout, binding=binding, lazy=lazy, backpressure=backpressure, @@ -297,11 +297,17 @@ def subscribe(self, filter=None): @asynccontextmanager -async def build_pool_from_scenario(scenario, credentials_map): +async def build_pool_from_scenario(scenario, credentials_map, *, backpressure=None): """Build and enter a WorkerPool from a complete Scenario. Resolves each dimension to its concrete runtime value and yields the entered pool context. + + :param backpressure: + Optional admission-control hook that overrides the hook the + ``BackpressureMode`` dimension would otherwise resolve. Tests + that need a bespoke (e.g. context-var-aware) hook pass it here + rather than building a :class:`WorkerPool` by hand. 
""" missing = [ f.name @@ -403,13 +409,16 @@ async def _lan_async_cm(): lazy = scenario.lazy is LazyMode.LAZY - match scenario.backpressure: - case BackpressureMode.SYNC: - bp_hook = _sync_accept_hook - case BackpressureMode.ASYNC: - bp_hook = _async_accept_hook - case _: - bp_hook = None + if backpressure is not None: + bp_hook = backpressure + else: + match scenario.backpressure: + case BackpressureMode.SYNC: + bp_hook = _sync_accept_hook + case BackpressureMode.ASYNC: + bp_hook = _async_accept_hook + case _: + bp_hook = None match scenario.quorum: case QuorumMode.ABOVE_DEFAULT: @@ -717,15 +726,15 @@ def _assert_caller_vars(patterns, initial_values, *, shape=None): what the caller observes via back-propagation. The inner worker's mutations do not reach the outer worker's copied context because the nested dispatch crosses event loop boundaries; only the outer - worker's own writes are captured in its snapshot. + worker's own writes are captured in its context. For NESTED_ASYNC_GEN shapes, the async generator's final context - snapshot is sent with the last yield, not after exhaustion. + is sent with the last yield, not after exhaustion. Post-teardown mutations (UPSTREAM_RESET) are not visible to the caller because there is no subsequent yield to carry them. For DOWNSTREAM_OVERWRITE and DOWNSTREAM_RESET the outer worker's ``_pre_nested_setup`` runs before the first inner yield, so those - values ARE captured in per-yield snapshots. + values ARE captured in per-yield contexts. """ is_nested_gen = shape is RoutineShape.NESTED_ASYNC_GEN for var_name, pattern in patterns.items(): @@ -745,7 +754,7 @@ def _assert_caller_vars(patterns, initial_values, *, shape=None): case ContextVarPattern.DOWNSTREAM_OVERWRITE: # Under stdlib-mirror semantics, wool routines run in # the caller's stdlib Context — outer sets, inner - # overwrites in the same Context, and back-prop + # overwrites in the same Chain, and back-prop # carries the final state (inner's overwrite) to the # caller. 
Matches `await coro()` semantics. assert var.get() == f"inner-overwrite-{var_name}", ( @@ -764,7 +773,7 @@ def _assert_caller_vars(patterns, initial_values, *, shape=None): if is_nested_gen: # Async gen: inner sets "inner-set-" before # yielding; that value is captured in a - # per-yield snapshot and back-propagated to the + # per-yield context and back-propagated to the # caller. _post_nested_teardown runs after the # gen exhausts — no subsequent yield carries # its mutation — so the caller's final visible @@ -776,7 +785,7 @@ def _assert_caller_vars(patterns, initial_values, *, shape=None): else: # Coroutine: inner sets, then _post_nested_teardown # overwrites with outer-reset before the response - # snapshot ships. + # context ships. assert var.get() == f"outer-reset-{var_name}", ( f"UPSTREAM_RESET: expected outer-reset value, got {var.get()!r}" ) @@ -784,6 +793,10 @@ def _assert_caller_vars(patterns, initial_values, *, shape=None): # After iteration, caller should see the last # step value back-propagated. pass # validated inline during iteration + case ContextVarPattern.MID_STREAM_FORWARD: + # Forward propagation is asserted worker-side per + # frame — a mismatch raises a dispatch exception. + pass # validated worker-side during iteration def _cleanup_caller_vars(tokens): @@ -833,6 +846,30 @@ async def invoke_routine(scenario): gen = routine(obj, 3) else: gen = routine(3) + forward = { + k: v + for k, v in patterns.items() + if v is ContextVarPattern.MID_STREAM_FORWARD + } + # MID_STREAM_FORWARD drives the generator one step at a + # time, mutating the var to a per-step value before + # each ``__anext__`` so the worker frame observes the + # forward-propagated value. 
+ if forward: + step = 0 + while True: + for var_name in forward: + _CALLER_VARS[var_name].set(f"step-{step}") + try: + item = await gen.__anext__() + except StopAsyncIteration: + break + collected.append(item) + step += 1 + assert collected == [0, 1, 2] + if patterns: + _assert_caller_vars(patterns, initial_values, shape=shape) + return collected step = 0 async for item in gen: collected.append(item) @@ -939,7 +976,7 @@ async def invoke_routine(scenario): # ``TENANT_ID`` to a per-step value and nested-dispatches # ``get_tenant_id`` to read it back. Locks in the second # consequence of the #176 fix: ``_current_task`` and - # ``wool.Context`` remain active across the generator's + # ``wool.Chain`` remain active across the generator's # lifespan, so nested dispatches from inside a streaming # routine carry the streaming task as caller. collected = [ @@ -1034,6 +1071,13 @@ def _select_routine(shape, binding): ContextVarPattern.DOWNSTREAM_RESET, ContextVarPattern.UPSTREAM_RESET, ) +# MID_STREAM_FORWARD requires the caller to mutate the var before each +# ``__anext__`` and the worker to assert the forward-propagated value +# per frame. Only the plain ``async for`` shape drives the generator +# one step at a time through ``invoke_routine``; the asend/athrow/ +# aclose scripts and the single-yield shape do not, so the pattern is +# constrained to ASYNC_GEN_ANEXT. +_MID_STREAM_FORWARD_SHAPES = (RoutineShape.ASYNC_GEN_ANEXT,) def _is_grpc_internal(exc: BaseException) -> bool: @@ -1063,7 +1107,8 @@ def _pairwise_filter(row): always uses module-level routines) - D11/D12/D13 (ctx_var_1/2/3): DOWNSTREAM_OVERWRITE, DOWNSTREAM_RESET, UPSTREAM_RESET only valid with NESTED_* shapes; - PER_YIELD only valid with ASYNC_GEN_* shapes + PER_YIELD and MID_STREAM_FORWARD only valid with ASYNC_GEN_* + shapes - D14 (quorum) ABOVE_DEFAULT (quorum=2) requires PoolMode.EPHEMERAL — every other pool mode in the builder produces only one worker, so quorum=2 would block forever. 
@@ -1112,7 +1157,7 @@ def _pairwise_filter(row): and binding is not RoutineBinding.MODULE_FUNCTION ): return False - # Context var pattern constraints (indices 10, 11, 12) + # Chain var pattern constraints (indices 10, 11, 12) shape = row[0] for idx in (10, 11, 12): if len(row) > idx: @@ -1121,6 +1166,11 @@ def _pairwise_filter(row): return False if pattern is ContextVarPattern.PER_YIELD and shape not in _ASYNC_GEN_SHAPES: return False + if ( + pattern is ContextVarPattern.MID_STREAM_FORWARD + and shape not in _MID_STREAM_FORWARD_SHAPES + ): + return False # NESTED_ASYNC_GEN_READBACK's routine self-mutates # ``TENANT_ID`` per iteration; combining with framework- # driven patterns over the same caller vars would conflict. @@ -1273,6 +1323,8 @@ def _draw_ctx_var_pattern(draw): valid = [p for p in valid if p not in _NESTED_ONLY_PATTERNS] if shape not in _ASYNC_GEN_SHAPES: valid = [p for p in valid if p is not ContextVarPattern.PER_YIELD] + if shape not in _MID_STREAM_FORWARD_SHAPES: + valid = [p for p in valid if p is not ContextVarPattern.MID_STREAM_FORWARD] return draw(st.sampled_from(valid)) ctx_var_1 = _draw_ctx_var_pattern(draw) @@ -1398,10 +1450,10 @@ async def _clear_channel_pool(): # asyncio.Task whose ``contextvars.Context`` is a copy, so # wool.ContextVar mutations stay scoped to that copy and don't leak # to the next test. Sync integration helpers run in the pytest main -# Context — if they ever mutate routine-level vars, add an explicit +# Context — if they ever mutate routine-level vars, add an explicit # per-test teardown at that site rather than reviving a global # autouse cleanup. (The previous sync ``_clear_proxy_context`` -# autouse fixture mutated the pytest main Context, which async test +# autouse fixture mutated the pytest main Context, which async test # tasks never observe; it was load-bearing in appearance only.)
diff --git a/wool/tests/integration/routines.py b/wool/tests/integration/routines.py index 781501cf..2c9b12c4 100644 --- a/wool/tests/integration/routines.py +++ b/wool/tests/integration/routines.py @@ -34,6 +34,7 @@ class must live in a module that is importable on the worker. DOWNSTREAM_RESET = auto() UPSTREAM_RESET = auto() PER_YIELD = auto() + MID_STREAM_FORWARD = auto() # Module-level wool.ContextVars used by the propagation integration tests. @@ -41,7 +42,7 @@ class must live in a module that is importable on the worker. # worker when it unpickles a routine defined here. The import causes # the worker's own wool.ContextVar instances to self-register in the # process-wide registry under their ``":"`` keys; -# caller-side snapshots with matching keys resolve to the same logical +# caller-side contexts with matching keys resolve to the same logical # var on the worker. TENANT_ID: wool.ContextVar[str] = wool.ContextVar("tenant_id", default="unknown") REGION: wool.ContextVar[str] = wool.ContextVar("region", default="global") @@ -58,7 +59,7 @@ class must live in a module that is importable on the worker. _RESET_TOKENS: wool.ContextVar[dict] = wool.ContextVar("_reset_tokens", default={}) -def _resolve_var(name: str) -> "wool.ContextVar": +def _resolve_var(name: str) -> wool.ContextVar: """Look up a wool.ContextVar by logical name at call time. Resolves against the module's live globals each call. This @@ -87,6 +88,18 @@ def _execute_patterns(patterns, *, step=None): case ContextVarPattern.PER_YIELD: if step is not None: var.set(f"step-{step}") + case ContextVarPattern.MID_STREAM_FORWARD: + # Forward pattern: the caller mutates the var to a + # per-step value before each ``__anext__``; the worker + # asserts the forward-propagated value reached this + # frame. A mismatch raises, surfacing as a dispatch + # exception on the caller's ``__anext__``. 
+ if step is not None: + observed = var.get() + assert observed == f"step-{step}", ( + f"MID_STREAM_FORWARD {var_name}: worker expected " + f"forward-propagated 'step-{step}', got {observed!r}" + ) case ContextVarPattern.DOWNSTREAM_OVERWRITE: var.set(f"inner-overwrite-{var_name}") case ContextVarPattern.DOWNSTREAM_RESET: @@ -130,14 +143,15 @@ def _pre_nested_setup(patterns): def _post_nested_teardown(patterns): """Clean up after a nested dispatch returns (outer worker side). - For UPSTREAM_RESET the outer worker resets the var that the inner - worker set, using a token captured before the nested call. + For UPSTREAM_RESET the outer worker overwrites the var that the + inner worker set with a sentinel value, signalling that the outer + worker has completed its post-nested teardown step. """ for var_name, pattern in patterns.items(): var = _resolve_var(var_name) if pattern is ContextVarPattern.UPSTREAM_RESET: # The inner worker set this var; the outer worker now - # resets it back to whatever it was before. + # overwrites it with a sentinel to mark teardown. var.set(f"outer-reset-{var_name}") @@ -149,6 +163,18 @@ def _post_nested_teardown(patterns): } ) +# Patterns executed once per generator step (with a ``step`` index) +# rather than once before the first yield. ``PER_YIELD`` mutates the +# var per step (back-propagation direction); ``MID_STREAM_FORWARD`` +# asserts the caller's per-step mutation reached the worker frame +# (forward-propagation direction). +_PER_STEP_PATTERNS: frozenset[ContextVarPattern] = frozenset( + { + ContextVarPattern.PER_YIELD, + ContextVarPattern.MID_STREAM_FORWARD, + } +) + def context_pattern_aware(fn): """Decorator that reads TEST_PATTERNS and executes context-var mutations. @@ -209,18 +235,14 @@ async def wrapper(*args, **kwargs): inner_patterns["_inner"] = True TEST_PATTERNS.set(inner_patterns) - # Execute simple patterns (except PER_YIELD). + # Execute simple patterns (except the per-step ones). 
non_yield_simple = { - k: v - for k, v in non_nested.items() - if v is not ContextVarPattern.PER_YIELD + k: v for k, v in non_nested.items() if v not in _PER_STEP_PATTERNS } if non_yield_simple: _execute_patterns(non_yield_simple) - per_yield = { - k: v for k, v in non_nested.items() if v is ContextVarPattern.PER_YIELD - } + per_yield = {k: v for k, v in non_nested.items() if v in _PER_STEP_PATTERNS} step = 0 try: value = await gen.__anext__() @@ -631,7 +653,7 @@ async def streaming_nested_get_tenant_id(count: int): Mutates TENANT_ID to a per-iteration value, dispatches ``get_tenant_id`` (nested), and yields the observed value. - Verifies that the ``_current_task`` and ``wool.Context`` set by + Verifies that the ``_current_task`` and chain context set by the worker for the streaming routine remain active across the generator's lifespan — without that, the nested dispatch cannot find the caller's task and the propagation chain breaks. @@ -662,7 +684,7 @@ async def read_multi_vars() -> tuple[str, str]: """Coroutine that reads TENANT_ID and REGION simultaneously. Used to verify that multiple wool.ContextVars in the registry are - all snapshotted and restored correctly through a single dispatch. + all propagated and restored correctly through a single dispatch. """ return TENANT_ID.get(), REGION.get() @@ -701,14 +723,37 @@ async def mutate_on_each_yield(count: int): @wool.routine -async def return_current_context_id_hex() -> str: - """Coroutine that returns ``wool.current_context().id.hex`` from the worker. +async def return_current_chain_id_hex() -> str: + """Coroutine that returns the worker-side context ``chain_id`` hex. + + Used to verify that a dispatch boundary correctly arms the worker + on the caller's chain. The worker installs the caller's decoded + context via ``install_context``, so its ``chain_id`` equals the + caller's (or the child's, when dispatched from an asyncio child + task that has forked the chain). 
+ """ + + context = wool.__chain__.get(None) + assert context is not None + return context.id.hex + - Used to verify that the caller's context id propagates through a - dispatch boundary — the worker should observe the same id as the - caller captured pre-dispatch. +@wool.routine +async def stream_chain_id_hex(count: int): + """Async generator that yields the worker-side context ``chain_id`` hex. + + The streaming counterpart of :func:`return_current_chain_id_hex`. A + ``sleep(0)`` between yields forces the generator to suspend, so each + read happens after a genuine resume across an ``__anext__`` + boundary. Used to verify that two interleaved async-generator + dispatches both observe the caller's shared chain id. """ - return wool.current_context().id.hex + + for _ in range(count): + context = wool.__chain__.get(None) + assert context is not None + yield context.id.hex + await asyncio.sleep(0) @wool.routine @@ -752,88 +797,12 @@ async def touch_argument(argument): return argument -@wool.routine -async def accept_token_and_reset(token: wool.Token) -> str: - """Coroutine that calls ``var.reset(token)`` on the worker and returns var.get(). - - The caller sets ``TENANT_ID`` before dispatch and passes the - resulting Token. On the worker, ``reset`` restores the pre-set - value. The returned value is the post-reset read, which should - equal the default (or prior value) captured at ``set`` time. - """ - TENANT_ID.reset(token) - return TENANT_ID.get() - - -@wool.routine -async def accept_token_and_double_reset(token: wool.Token) -> str: - """Coroutine that calls ``reset(token)`` twice, the second call raising. - - Returns the caught exception's repr so the caller can verify a - RuntimeError ("Token has already been used") was raised inside the - routine. 
- """ - TENANT_ID.reset(token) - try: - TENANT_ID.reset(token) - except RuntimeError as exc: - return repr(exc) - return "no-error" - - -@wool.routine -async def accept_token_and_reset_on_yield(token: wool.Token): - """Async generator that resets *token* between two yields. - - Yields "before" first, then consumes the token via - ``TENANT_ID.reset(token)``, then yields "after". The caller can - observe per-yield back-propagation of the consumed-token state - across the reset boundary. - """ - yield "before" - TENANT_ID.reset(token) - yield "after" - - -@wool.routine -async def read_value_and_attempt_reset(token: wool.Token) -> tuple[str, str]: - """Coroutine that reads ``TENANT_ID`` and attempts to reset with *token*. - - Returns a ``(value, reset_outcome)`` pair. ``value`` is the - worker-side observation of ``TENANT_ID.get()`` before any reset - attempt; ``reset_outcome`` is the repr of the RuntimeError raised - when ``TENANT_ID.reset(token)`` is invoked against an - already-consumed Token, or the literal ``"no-error"`` string if - no exception was raised. Used to verify that a single dispatch - carrying both a value binding and the corresponding - consumed-token id under the same wire entry round-trips both - pieces of state to the worker. - """ - value = TENANT_ID.get() - try: - TENANT_ID.reset(token) - except RuntimeError as exc: - return value, repr(exc) - return value, "no-error" - - -@wool.routine -async def mint_tenant_token(value: str) -> wool.Token: - """Coroutine that mints a Token via ``TENANT_ID.set`` and returns it. - - Returns the Token straight to the caller so cross-process - lifecycle scenarios can begin from a worker-minted Token - instead of a caller-minted one. - """ - return TENANT_ID.set(value) - - @wool.routine async def mutate_then_raise_tenant_id(value: str) -> str: """Coroutine that sets ``TENANT_ID`` to *value* then raises ValueError. 
Used by exception-path back-propagation tests — the worker's mutation - should reach the caller via the exception's snapshot path. + should reach the caller via the exception's context path. """ TENANT_ID.set(value) raise ValueError("mutate_then_raise_tenant_id") @@ -846,7 +815,7 @@ async def yield_then_mutate_and_raise(sentinel: str): Yields ``"ready"`` first so the caller can iterate once, then sets ``TENANT_ID`` to *sentinel* before raising ``ValueError``. Used to verify that mid-stream mutations are back-propagated through the - exception snapshot. + exception context. """ yield "ready" TENANT_ID.set(sentinel) @@ -969,7 +938,7 @@ async def declare_and_read_unregistered_key( Exercises the stub-promotion path: the wire frame creates a stub in the registry with the caller's value applied to the active - Context; the in-routine ``ContextVar(name, namespace=...)`` call + context; the in-routine ``ContextVar(name, namespace=...)`` call finds the stub and promotes it in place, preserving the wire-applied value on the new authoritative declaration. """ @@ -981,12 +950,12 @@ async def declare_and_read_unregistered_key( async def read_dispatch_timeout() -> float | None: """Return the worker-side value of the ambient ``dispatch_timeout``. - Verifies that a caller-side :class:`wool.RuntimeContext` snapshot + Verifies that a caller-side :class:`wool.RuntimeContext` rides through the dispatch wire frame and is restored on the worker before the routine body executes — independent of the :class:`wool.ContextVar` propagation path. """ - from wool.runtime.context import dispatch_timeout + from wool.runtime.context.runtime import dispatch_timeout return dispatch_timeout.get() @@ -1002,3 +971,140 @@ async def mutate_then_nested_get_tenant_id(mid_value: str) -> str: """ TENANT_ID.set(mid_value) return await get_tenant_id() + + +def _decode_bomb_rebuild(): + """Rebuild callable referenced by :class:`DecodeBomb`'s ``__reduce__``. 
+ + Raises unconditionally — when a pickled :class:`DecodeBomb` is + unpickled (the worker-side ``decode_context`` step), this fires + and the per-entry decode fails. The caller never unpickles its own + outgoing context, so the failure is worker-side only. + """ + raise RuntimeError("decode bomb detonated on unpickle") + + +class DecodeBomb: + """A value that pickles cleanly but raises when unpickled. + + Models a version-skew payload: ``encode_context`` on the caller + pickles it without error (``__reduce__`` just stores the rebuild + tuple), but the worker's ``decode_context`` calls + :func:`_decode_bomb_rebuild`, which raises. Used to drive the + worker-side context-decode-failure → Nack path. + """ + + def __reduce__(self): + return (_decode_bomb_rebuild, ()) + + +@wool.routine +async def set_tenant_then_crash_worker(value: str) -> str: + """Set ``TENANT_ID`` then hard-crash the worker process. + + Sets the var (arming the chain), then calls ``os._exit`` so the + worker subprocess dies mid-dispatch without a graceful response. + The caller should observe an ``RpcError`` / ``UnexpectedResponse`` + and its own context state must stay intact — no half-merged + back-propagation from the crashed worker. + """ + import os + + TENANT_ID.set(value) + os._exit(70) + + +@wool.routine +async def set_tenant_then_sleep(value: str, sentinel_path: str, duration: float = 30.0): + """Set ``TENANT_ID`` then sleep — for armed-context cancellation. + + Writes ``"started"`` to *sentinel_path* immediately before + suspending on :func:`asyncio.sleep` so the caller can poll for + suspension. On :class:`asyncio.CancelledError` the marker is + overwritten with ``"cancelled"``. The var mutation is the partial + state whose fate under cancellation the caller pins. 
+ """ + TENANT_ID.set(value) + try: + with open(sentinel_path, "w") as f: + f.write("started") + await asyncio.sleep(duration) + except asyncio.CancelledError: + with open(sentinel_path, "w") as f: + f.write("cancelled") + raise + + +@wool.routine +async def set_and_reset_tenant_across_yield(value: str): + """Async generator that sets ``TENANT_ID``, yields, resets, yields. + + First sets ``TENANT_ID`` to *value* and yields ``"set"``; then + resets the var via the token from that set and yields ``"reset"``. + The caller observes the per-yield back-propagation across the + set→reset boundary: its own ``TENANT_ID`` tracks the value after + the first yield and reverts after the second. + """ + token = TENANT_ID.set(value) + yield "set" + TENANT_ID.reset(token) + yield "reset" + + +@wool.routine +async def read_unbound_default_less_var(namespace: str, name: str) -> str: + """Declare a default-less :class:`wool.ContextVar` and ``get()`` it. + + The var has no constructor default and is unbound, so ``get()`` + with no argument raises :class:`LookupError`. The exception must + surface to the caller through the exception back-propagation path. + """ + var = wool.ContextVar(name, namespace=namespace) + return var.get() + + +@wool.routine +async def count_wool_context_vars() -> int: + """Return the count of wool-owned ``contextvars.ContextVar``s. + + Enumerates ``contextvars.copy_context()`` and counts entries whose + name carries the ``__wool`` prefix — the context variable plus one + backing variable per bound :class:`wool.ContextVar`. + """ + import contextvars as _cv + + return sum(1 for var in _cv.copy_context() if var.name.startswith("__wool")) + + +@wool.routine +async def reenter_armed_chain_off_owner_thread(value: str) -> str: + """Arm the routine's chain, then read the var off the owner thread. + + Sets ``TENANT_ID`` (arming the routine's chain on the worker's + loop thread), then hands a ``wool.ContextVar`` access to a worker + thread via :func:`asyncio.to_thread`. 
``asyncio.to_thread`` copies + the surrounding ``contextvars`` context — chain UUID and owner + included — into the executor thread, so the off-thread ``get()`` + re-enters an armed chain from a thread other than its owner and + raises :class:`wool.ChainContention`. The exception surfaces + to the caller through the exception back-propagation path. + """ + TENANT_ID.set(value) + return await asyncio.to_thread(TENANT_ID.get) + + +@wool.routine +async def read_var_off_thread_via_wool_to_thread(value: str) -> str: + """Arm the routine's chain, then read the var off-thread via wool.to_thread. + + Sets ``TENANT_ID`` (arming the routine's chain on the worker's loop + thread), then offloads the var read to a worker thread via + :func:`wool.to_thread`. Unlike :func:`asyncio.to_thread` — which + copies the armed chain verbatim and trips + :class:`wool.ChainContention` off the owner thread — + ``wool.to_thread`` forks the chain onto a fresh, detached chain + owned by the worker thread, so the off-thread ``get()`` re-arms + cleanly and observes the forked copy of the value. + """ + TENANT_ID.set(value) + return await wool.to_thread(TENANT_ID.get) diff --git a/wool/tests/integration/test_context_var_propagation.py b/wool/tests/integration/test_context_var_propagation.py index 6aff58bc..91e68afd 100644 --- a/wool/tests/integration/test_context_var_propagation.py +++ b/wool/tests/integration/test_context_var_propagation.py @@ -1,7 +1,7 @@ """Integration tests for wool.ContextVar cross-worker propagation. These tests drive the full dispatch wire path — caller sets a wool -ContextVar, Task.to_protobuf serializes it, gRPC carries it to a real +ContextVar, encode_context serializes it, gRPC carries it to a real worker subprocess, the worker unpickles the callable (importing routines.py and populating its wool.ContextVar registry), and the routine observes the propagated value. 
They complement the in-process @@ -16,15 +16,15 @@ import contextvars import warnings +import grpc import pytest import wool -from wool.runtime.context import attached -from wool.runtime.context import dispatch_timeout +from wool.runtime.worker.connection import RpcError +from wool.runtime.worker.connection import UnexpectedResponse from . import _collision_fixtures from . import routines -from .conftest import ContextVarPattern from .conftest import PoolMode from .conftest import RoutineShape from .conftest import build_pool_from_scenario @@ -34,7 +34,7 @@ @pytest.mark.integration class TestContextVarPropagation: @pytest.mark.asyncio - async def test_coroutine_dispatch_propagates_wool_context_var( + async def test_coroutine_dispatch_should_propagate_wool_context_var( self, credentials_map, retry_grpc_internal ): """Test wool.ContextVar values reach a remote coroutine routine. @@ -64,7 +64,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_async_generator_dispatch_propagates_wool_context_var( + async def test_async_generator_dispatch_should_propagate_wool_context_var( self, credentials_map, retry_grpc_internal ): """Test propagation across async-generator suspension boundaries. @@ -104,7 +104,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_nested_dispatch_propagates_wool_context_var( + async def test_nested_dispatch_should_propagate_wool_context_var( self, credentials_map, retry_grpc_internal ): """Test propagation through a nested routine dispatch chain. @@ -136,7 +136,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_coroutine_mutation_is_visible_in_return_value( + async def test_coroutine_mutation_should_be_visible_in_return_value( self, credentials_map, retry_grpc_internal ): """Test a coroutine routine's mutation is readable via its own return value. 
@@ -167,7 +167,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_coroutine_mutation_back_propagates_to_caller( + async def test_coroutine_mutation_should_back_propagate_to_caller( self, credentials_map, retry_grpc_internal ): """Test a coroutine routine's mutation reaches the caller after dispatch. @@ -181,7 +181,7 @@ async def test_coroutine_mutation_back_propagates_to_caller( Then: The caller's value should equal the worker-side mutation — back-propagation applies the routine's change to the - caller's Context + caller's Chain """ # Arrange, act, & assert @@ -199,7 +199,44 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_concurrent_dispatches_observe_isolated_values( + async def test_coroutine_mutation_should_back_propagate_to_unarmed_caller( + self, credentials_map, retry_grpc_internal + ): + """Test a coroutine routine arms a previously-unarmed caller. + + Given: + An unarmed caller (no prior ``wool.ContextVar.set``) and a + DEFAULT pool running a coroutine that performs the first + ``var.set`` on the worker. + When: + The caller dispatches the routine and reads its own var + value after the routine returns. + Then: + The caller should observe the worker-side value — the + worker mints a fresh chain on the first ``var.set``, + ships it back on the result frame's chain manifest, and + the caller's apply-back arms the previously-unarmed + chain with the worker's bindings. This is the + stdlib-parity contract: ``await routine_that_sets(x)`` + makes ``x`` visible to the caller afterward. + """ + + # Arrange, act, & assert + async def body(): + scenario = default_scenario() + async with build_pool_from_scenario(scenario, credentials_map): + # Pre-dispatch baseline: caller is unarmed — the var + # resolves through to its constructor default + # ("unknown") since no chain has set it. 
+ assert routines.TENANT_ID.get() == "unknown" + await routines.mutate_and_read_tenant_id() + caller_value_after = routines.TENANT_ID.get() + assert caller_value_after == "mutated_on_worker" + + await retry_grpc_internal(body) + + @pytest.mark.asyncio + async def test_concurrent_dispatches_should_observe_isolated_values( self, credentials_map, retry_grpc_internal ): """Test concurrent dispatches with different values stay isolated. @@ -241,7 +278,7 @@ async def dispatch_with(value: str) -> str: await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_default_only_value_is_not_propagated( + async def test_default_only_value_should_not_be_propagated( self, credentials_map, retry_grpc_internal ): """Test defaults are not shipped through the propagation path. @@ -255,7 +292,7 @@ async def test_default_only_value_is_not_propagated( Then: The routine should see the worker-side class-level default ("unknown"), proving that default-only values are not - snapshotted into the protobuf payload + captured into the protobuf payload """ # Arrange, act, & assert @@ -268,10 +305,10 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_multiple_wool_context_vars_round_trip( + async def test_multiple_wool_context_vars_should_round_trip( self, credentials_map, retry_grpc_internal ): - """Test multiple registered vars are all snapshotted and restored. + """Test multiple registered vars are all propagated and restored. Given: Two module-level wool.ContextVars both set on the caller @@ -300,7 +337,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_async_generator_mutation_is_visible_in_yields( + async def test_async_generator_mutation_should_be_visible_in_yields( self, credentials_map, retry_grpc_internal ): """Test an async generator routine's mutations appear in yielded values. 
@@ -340,7 +377,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_async_generator_mutation_back_propagates_to_caller( + async def test_async_generator_mutation_should_back_propagate_to_caller( self, credentials_map, retry_grpc_internal ): """Test an async generator's mutation reaches the caller after exhaustion. @@ -355,7 +392,7 @@ async def test_async_generator_mutation_back_propagates_to_caller( Then: The caller's value should equal the routine's final mutation — back-propagation applies the final yield's - change to the caller's Context + change to the caller's Chain """ # Arrange, act, & assert @@ -377,7 +414,49 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_back_propagation_updates_caller_per_yield( + async def test_async_gen_per_yield_should_back_propagate_to_unarmed_caller( + self, credentials_map, retry_grpc_internal + ): + """Test an async-generator routine arms a previously-unarmed caller per yield. + + Given: + An unarmed caller (no prior ``wool.ContextVar.set``) and a + DEFAULT pool running an async generator that performs the + first ``var.set`` on the worker on every iteration. + When: + The caller iterates the generator and reads its own var + value after each yield. + Then: + * The pre-dispatch read resolves through to the var's + constructor default — the caller is unarmed. + * Each per-yield snapshot equals the worker's most-recent + ``var.set`` — back-propagation arms the previously- + unarmed caller via the response apply-back, then + updates the binding on every subsequent yield's mount. + * After exhaustion the caller observes the final yield's + binding, the stdlib-parity contract for + ``async for x in agen()`` when the routine sets state. 
+ """ + + # Arrange, act, & assert + async def body(): + scenario = default_scenario( + shape=RoutineShape.ASYNC_GEN_ANEXT, + ) + caller_snapshots: list[str] = [] + async with build_pool_from_scenario(scenario, credentials_map): + # Pre-dispatch baseline: caller is unarmed. + assert routines.TENANT_ID.get() == "unknown" + async for _ in routines.mutate_on_each_yield(3): + caller_snapshots.append(routines.TENANT_ID.get()) + caller_final = routines.TENANT_ID.get() + assert caller_snapshots == ["step-0", "step-1", "step-2"] + assert caller_final == "step-2" + + await retry_grpc_internal(body) + + @pytest.mark.asyncio + async def test_back_propagation_should_update_caller_per_yield( self, credentials_map, retry_grpc_internal ): """Test the caller observes back-propagated mutations after each yield. @@ -411,7 +490,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_async_generator_dispatch_matches_in_process_baseline( + async def test_async_generator_dispatch_should_match_in_process_baseline( self, credentials_map, retry_grpc_internal ): """Test remote async-generator per-yield mutations match the in-process baseline. @@ -423,7 +502,7 @@ async def test_async_generator_dispatch_matches_in_process_baseline( in-process (both using wool.ContextVar — this is a wool self-baseline, not a stdlib comparison; stdlib generators share the task context with their caller rather than - owning an isolated worker context) + owning an isolated worker chain) When: Both generators are iterated to completion Then: @@ -469,906 +548,436 @@ async def body(): @pytest.mark.integration -class TestStdlibEquivalence: +class TestWoolContextAcrossWorkers: @pytest.mark.asyncio - async def test_coroutine_mutation_matches_stdlib( + async def test_caller_chain_id_should_propagate_to_worker( self, credentials_map, retry_grpc_internal ): - """Test coroutine back-propagation diverges from stdlib copy-on-write. 
+ """Test the worker observes the same chain id as the caller. Given: - A wool.ContextVar set on the caller and a DEFAULT pool - running a coroutine that mutates the var, alongside an - equivalent plain stdlib contextvars.ContextVar exercised - via contextvars.copy_context().run() + A caller that arms its context by setting a wool.ContextVar + and a routine that returns the worker-side context + chain id hex. When: - Both the wool dispatch and the stdlib run complete + The caller dispatches the routine. Then: - Both paths return the same worker-side mutation result, - but only wool back-propagates the mutation to the - caller — stdlib's copy_context().run() leaves the - caller-side var untouched, while wool's dispatch causes - the caller to observe the worker's set value + It should return the caller's own chain id hex — + confirming the worker is armed on the caller's chain via + install_context. """ - # Arrange - stdlib_var: contextvars.ContextVar[str] = contextvars.ContextVar("stdlib_tenant") - - def stdlib_mutate() -> str: - stdlib_var.set("mutated_on_worker") - return stdlib_var.get() - - # Act & assert async def body(): scenario = default_scenario() async with build_pool_from_scenario(scenario, credentials_map): - # — stdlib path — - stdlib_var.set("caller-value") - ctx = contextvars.copy_context() - stdlib_result = ctx.run(stdlib_mutate) - stdlib_caller_after = stdlib_var.get() - - # — wool path — - token = routines.TENANT_ID.set("caller-value") - try: - wool_result = await routines.mutate_and_read_tenant_id() - wool_caller_after = routines.TENANT_ID.get() - finally: - routines.TENANT_ID.reset(token) - - assert wool_result == stdlib_result - assert wool_caller_after == "mutated_on_worker" - # stdlib copy_context().run() does NOT propagate back. - # wool back-propagates, so we just confirm wool's result - # matches the worker-side mutation. 
- assert stdlib_caller_after == "caller-value" - - await retry_grpc_internal(body) - - -@pytest.mark.integration -class TestWoolContextAcrossWorkers: - @pytest.mark.asyncio - async def test_caller_context_id_propagates_to_worker( - self, credentials_map, retry_grpc_internal - ): - """Test worker observes the caller's wool.Context id. + # Arrange + routines.TENANT_ID.set("armed") + caller = wool.__chain__.get(None) + assert caller is not None - Given: - A caller that captures ``current_context().id`` before a - dispatch and a routine that returns - ``current_context().id.hex`` from inside the worker - When: - The caller dispatches the routine - Then: - It should return the same id hex as the caller's - captured id hex - """ + # Act + observed_hex = await routines.return_current_chain_id_hex() - # Arrange, act, & assert - async def body(): - scenario = default_scenario() - async with build_pool_from_scenario(scenario, credentials_map): - caller_id = wool.current_context().id - observed_hex = await routines.return_current_context_id_hex() - assert observed_hex == caller_id.hex + # Assert + assert observed_hex == caller.id.hex await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_asyncio_child_task_forks_context_id( + async def test_asyncio_child_task_should_fork_chain_id( self, credentials_map, retry_grpc_internal ): - """Test asyncio child task dispatches fork a fresh context id. + """Test an asyncio child task runs on a freshly forked chain id. Given: - A caller that enters an asyncio child task via - ``create_task`` and captures ``current_context().id`` - inside the child before dispatch + An armed caller that enters an asyncio child task via + ``create_task`` and captures the parent and child chain + ids. When: - The child dispatches a routine that returns its own - observed context id + The child dispatches a routine. 
Then: - The routine should observe a different id from the - parent's id (stdlib fork parity) + It should observe the child's forked chain id on the + worker — the child's chain differs from the parent's + (stdlib copy-on-fork parity) and the worker arms on the + child's chain. """ - # Arrange, act, & assert async def body(): scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) async with build_pool_from_scenario(scenario, credentials_map): - parent_id = wool.current_context().id + # Arrange + routines.TENANT_ID.set("armed") + parent = wool.__chain__.get(None) + assert parent is not None + parent_id = parent.id async def _child(): - child_id = wool.current_context().id - observed_hex = await routines.return_current_context_id_hex() - return child_id, observed_hex + child = wool.__chain__.get(None) + assert child is not None + + # Act + observed_hex = await routines.return_current_chain_id_hex() + return child.id, observed_hex child_id, observed_hex = await asyncio.create_task(_child()) + # Assert assert child_id != parent_id assert observed_hex == child_id.hex await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_seeded_context_dispatch_propagates_var_bindings( + async def test_seeded_context_dispatch_should_propagate_var_bindings( self, credentials_map, retry_grpc_internal ): - """Test a Context pre-populated with var bindings ships those - bindings to the worker when the dispatch runs under it. + """Test var bindings seeded in a copied context ship to the worker. Given: - A freshly-constructed ``wool.Context`` (distinct from the - implicit current Context) populated with a TENANT_ID - binding via ``Context.run``, and a routine that returns - the var's observed value + A caller that seeds a TENANT_ID binding then copies the + live context with ``contextvars.copy_context``, and a + routine that returns the var's observed value. 
When: - The caller invokes the dispatch inside a - ``with attached(seed):`` block + The caller invokes the dispatch inside the copied + context's ``run``. Then: - The routine should return the seed value — the wire - snapshot picks up the seeded binding from the active - Context regardless of whether that Context was the - implicit current one or an explicitly constructed peer + It should return the seed value — the chain manifest picks + up the seeded binding from the copied context's run. """ - # Arrange, act, & assert async def body(): scenario = default_scenario() async with build_pool_from_scenario(scenario, credentials_map): - seed = wool.Context() - seed.run(lambda: routines.TENANT_ID.set("seed-value")) + # Arrange + token = routines.TENANT_ID.set("seed-value") + try: + forked = contextvars.copy_context() + finally: + routines.TENANT_ID.reset(token) - with attached(seed): - result = await routines.get_tenant_id() + # Act + # forked.run() activates the copied context only long + # enough for ensure_future to schedule the dispatch, so + # the routine's chain manifest is encoded from the seeded + # context; the await then completes outside run(). + result = await forked.run( + lambda: asyncio.ensure_future(routines.get_tenant_id()) + ) + + # Assert assert result == "seed-value" await retry_grpc_internal(body) @pytest.mark.integration -class TestTokenAcrossWorkers: +class TestExceptionPathBackPropagation: @pytest.mark.asyncio - async def test_pickled_token_resets_on_worker( + async def test_coroutine_exception_should_back_propagate_worker_mutation( self, credentials_map, retry_grpc_internal ): - """Test worker can reset via a caller-minted pickled Token. + """Test exception payload carries worker-side var mutations. 
Given: - A caller that sets TENANT_ID and captures the resulting - Token and a routine that accepts the Token and calls - ``var.reset(token)`` on the worker + A routine that sets TENANT_ID to a sentinel value then + raises ValueError When: - The caller dispatches passing the pickled Token + The caller dispatches and catches the exception Then: - The reset should succeed on the worker and the routine's - post-reset read should equal the var's pre-set default + The caller's TENANT_ID should reflect the worker-side + sentinel value (back-propagated via exception context + path) """ # Arrange, act, & assert async def body(): scenario = default_scenario() async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("caller-value") + token = routines.TENANT_ID.set("caller-original") try: - worker_read = await routines.accept_token_and_reset(token) + with pytest.raises(ValueError, match="mutate_then_raise_tenant_id"): + await routines.mutate_then_raise_tenant_id("exc-path-sentinel") + observed = routines.TENANT_ID.get() finally: - # The worker's reset may have consumed the local - # token via back-propagation; only reset locally - # if the token wasn't already used. - if not token.used: - routines.TENANT_ID.reset(token) - # Post-reset read on the worker restores pre-set Undefined, - # which surfaces as the var's constructor default. - assert worker_read == "unknown" + routines.TENANT_ID.reset(token) + assert observed == "exc-path-sentinel" await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_token_reused_on_worker_raises_runtime_error( + async def test_async_gen_exception_should_back_propagate_worker_mutation( self, credentials_map, retry_grpc_internal ): - """Test second reset with same Token raises RuntimeError. + """Test async-gen exception carries mid-stream mutations. 
Given: - A caller that sets TENANT_ID and dispatches a routine that - calls ``var.reset(token)`` once then attempts a second - reset with the same Token + An async-generator routine that yields once, then sets + TENANT_ID and raises on the next iteration When: - The second reset fires on the worker + The caller iterates and catches the exception Then: - The routine should catch RuntimeError ("Token has already - been used") and return its repr to the caller + The caller's TENANT_ID should reflect the last mid-stream + mutation performed on the worker before the raise """ # Arrange, act, & assert async def body(): - scenario = default_scenario() + scenario = default_scenario( + shape=RoutineShape.ASYNC_GEN_ANEXT, + pool_mode=PoolMode.EPHEMERAL, + ) async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("caller-value") + token = routines.TENANT_ID.set("caller-original") try: - observed = await routines.accept_token_and_double_reset(token) + gen = routines.yield_then_mutate_and_raise("mid-stream-sentinel") + try: + first = await gen.__anext__() + assert first == "ready" + match = "yield_then_mutate_and_raise" + with pytest.raises(ValueError, match=match): + await gen.__anext__() + finally: + await gen.aclose() + observed = routines.TENANT_ID.get() finally: - if not token.used: - routines.TENANT_ID.reset(token) - assert "Token has already been used" in observed + routines.TENANT_ID.reset(token) + assert observed == "mid-stream-sentinel" await retry_grpc_internal(body) + +@pytest.mark.integration +class TestAsyncioForkOnWorker: @pytest.mark.asyncio - async def test_caller_reset_after_worker_consumption_raises( + async def test_worker_child_mutation_should_not_leak_to_parent( self, credentials_map, retry_grpc_internal ): - """Test a caller reset of a worker-consumed Token raises. + """Test child-task mutation stays out of parent on the worker. 
Given: - A caller that sets TENANT_ID to "X", dispatches a routine - that consumes the Token via var.reset(token) on the - worker, and then sets TENANT_ID to "Y" after the dispatch - returns + A routine that sets TENANT_ID to ``"parent"``, spawns a + child asyncio task that sets TENANT_ID to ``"child"`` and + returns its read, and finally reads TENANT_ID from the + parent after the child completes When: - The caller invokes var.reset(token) a second time locally - — the worker already consumed the Token, and the - caller has a later set that must not be silently - reverted + The caller dispatches the routine Then: - The second reset should raise RuntimeError (Token is - logically single-use across processes, not just - in-process) and the caller's post-set value "Y" must - remain intact + It should return the original value for the parent's + post-child read (stdlib copy-on-fork parity) """ # Arrange, act, & assert async def body(): scenario = default_scenario() async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("X") - # Worker consumes the Token via var.reset(token). - await routines.accept_token_and_reset(token) - # Caller installs a fresh value AFTER the worker - # consumed the Token. A correct implementation must - # reject the caller's second reset and preserve "Y". 
- y_token = routines.TENANT_ID.set("Y") - try: - with pytest.raises(RuntimeError, match="already been used"): - routines.TENANT_ID.reset(token) - assert routines.TENANT_ID.get() == "Y" - finally: - routines.TENANT_ID.reset(y_token) + child_value, parent_value = await routines.spawn_and_mutate_tenant_id() + assert child_value == "child" + assert parent_value == "parent" await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_caller_reset_after_async_gen_consumed_token_raises( + async def test_worker_child_should_inherit_parent_value( self, credentials_map, retry_grpc_internal ): - """Test consumed-token state back-propagates from an async gen. + """Test child asyncio task inherits parent's pre-fork var value. Given: - A caller that sets TENANT_ID to "X", iterates an async- - generator routine that consumes the Token on one of its - yields, and then sets TENANT_ID to "Y" after exhaustion + A routine that sets TENANT_ID then spawns a child asyncio + task that reads TENANT_ID without mutating When: - The caller invokes var.reset(token) locally — the - generator already consumed the Token on the worker + The caller dispatches the routine Then: - The reset should raise RuntimeError (per-yield back- - propagation carries the consumed-token set to the - caller just like coroutine back-propagation) and the - caller's post-set value "Y" must remain intact + It should return the parent's pre-fork value for the + child's read """ # Arrange, act, & assert async def body(): - scenario = default_scenario(shape=RoutineShape.ASYNC_GEN_ANEXT) + scenario = default_scenario() async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("X") - async for _ in routines.accept_token_and_reset_on_yield(token): - pass - y_token = routines.TENANT_ID.set("Y") - try: - with pytest.raises(RuntimeError, match="already been used"): - routines.TENANT_ID.reset(token) - assert routines.TENANT_ID.get() == "Y" - finally: - 
routines.TENANT_ID.reset(y_token) + child_value = await routines.parent_sets_child_reads() + assert child_value == "parent-set" await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_worker_reset_of_caller_consumed_token_raises( + async def test_worker_sibling_children_should_be_isolated( self, credentials_map, retry_grpc_internal ): - """Test forward-propagated consumed tokens reject a worker reset. + """Test sibling asyncio children are mutually isolated. Given: - A caller that sets TENANT_ID, consumes the resulting - Token locally via var.reset(token), and then dispatches - a routine that tries to reset the same (already-consumed) - Token on the worker + A routine that spawns two children via ``asyncio.gather``, + each mutating TENANT_ID to different values, and a parent + read afterward When: - The dispatch runs — forward-propagation carries the - caller's consumed-token set to the worker's scoped - Context before the routine body executes + The caller dispatches the routine Then: - The worker's var.reset(token) call should raise - RuntimeError and the exception should surface to the - caller's await + Each child should observe its own value, and the parent's + var should remain unchanged (neither child leaks into the + other nor into the parent) """ # Arrange, act, & assert async def body(): scenario = default_scenario() async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("X") - routines.TENANT_ID.reset(token) # caller consumes first - with pytest.raises(RuntimeError, match="already been used"): - await routines.accept_token_and_reset(token) + ( + a_value, + b_value, + parent_value, + ) = await routines.two_children_mutate_tenant_id() + assert a_value == "alpha" + assert b_value == "beta" + # Parent never set TENANT_ID, and children's mutations are + # in their own forked contexts. The default surfaces here. 
+ assert parent_value == "unknown" await retry_grpc_internal(body) + +@pytest.mark.integration +class TestStubPromotionAcrossWorkers: @pytest.mark.asyncio - async def test_worker_minted_token_is_reusable_on_caller_then_rejects_reuse( + async def test_fresh_worker_should_promote_stub_without_collision( self, credentials_map, retry_grpc_internal ): - """Test a worker-minted Token round-trips and stays single-use. + """Test fresh worker unpickles stub then imports module. Given: - A routine that mints a Token via TENANT_ID.set(...) on - the worker and returns it to the caller, followed by - the caller consuming that Token locally via - TENANT_ID.reset(token) + A routine that imports and reads TENANT_ID, a caller that + sets the var, and a fresh EPHEMERAL worker that has not yet + imported the defining module When: - The caller invokes TENANT_ID.reset(token) a second time + The caller dispatches the routine Then: - The second reset should raise RuntimeError — the - Token is logically single-use regardless of which side - minted it, and the identity round-trip back from the - worker must preserve that contract + The worker should unpickle the var (stub), import the + module (promote the stub), and the routine should read + the propagated value without a collision """ # Arrange, act, & assert async def body(): - scenario = default_scenario() + scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) async with build_pool_from_scenario(scenario, credentials_map): - token = await routines.mint_tenant_token("W") - routines.TENANT_ID.reset(token) # consume once locally - with pytest.raises(RuntimeError, match="already been used"): + token = routines.TENANT_ID.set("stub-promotion-value") + try: + result = await routines.get_tenant_id() + finally: routines.TENANT_ID.reset(token) + assert result == "stub-promotion-value" await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_worker_reset_of_worker_minted_then_caller_consumed_token_raises( + async def 
test_sibling_routine_should_raise_context_var_collision( self, credentials_map, retry_grpc_internal ): - """Test forward-prop rejects worker reset of a worker-minted consumed Token. + """Test colliding sibling var raises on the caller. Given: - A worker that mints a Token and returns it to the - caller; the caller consumes the Token locally via - TENANT_ID.reset(token); and a second routine dispatch - that passes the same (now-consumed) Token to a worker - that attempts TENANT_ID.reset(token) + Two sibling routines that each construct a + ``wool.ContextVar`` with the same ``namespace:name`` key + in their function bodies, on a DEFAULT pool so both + dispatches land on the same worker (process-wide registry + isolation would otherwise mask the collision) When: - The second dispatch runs — forward-propagation carries - the caller's consumed-token state into the second - worker's scoped Context before the routine body executes + The caller dispatches the first sibling (registering the + key on the worker) then dispatches the second sibling Then: - The second worker's var.reset(token) should raise - RuntimeError and the exception should surface to the - caller's await, mirroring the caller-minted-Token case - for a Token that originated on the worker instead of on - the caller + The second dispatch should raise wool.ContextVarCollision + on the caller via the worker's exception context path """ # Arrange, act, & assert async def body(): - scenario = default_scenario() + # DEFAULT pool is size=1 so both dispatches land on the + # same worker process; the second construction under the + # already-registered key triggers the collision. 
+ scenario = default_scenario(pool_mode=PoolMode.DEFAULT) async with build_pool_from_scenario(scenario, credentials_map): - token = await routines.mint_tenant_token("W") - routines.TENANT_ID.reset(token) - with pytest.raises(RuntimeError, match="already been used"): - await routines.accept_token_and_reset(token) + # First dispatch registers the key on the worker. + first = await _collision_fixtures.sibling_a() + assert first == "sibling-a" + + # Second dispatch constructs a new var with the same + # key → ContextVarCollision propagates back. + with pytest.raises( + wool.ContextVarCollision, + match=_collision_fixtures.COLLIDING_KEY, + ): + await _collision_fixtures.sibling_b() await retry_grpc_internal(body) @pytest.mark.integration -class TestSelfDispatchTokenReset: - """Token reset under self-dispatch (DEFAULT pool, single in-process worker). - - Self-dispatch serializes the dispatch payload through cloudpickle - exactly like cross-process dispatch, so a routine receives a copy - of its arguments and a :class:`wool.Token` reset behaves - identically regardless of pool mode. These tests cover - caller-minted and worker-minted token resets through a - single-worker DEFAULT pool. - """ - +class TestForwardPropagationMidStream: + # NOTE: the plain ``async for`` mid-stream forward-propagation case + # (formerly ``test_mid_stream_mutation_reaches_next_anext``) is now + # the ``ContextVarPattern.MID_STREAM_FORWARD`` dimension member — + # the covering array in ``test_integration.py`` crosses it with + # ``pool_mode``/``binding``/``credential`` automatically. Only the + # ``asend``/``athrow`` shapes and the concurrency cases remain + # here: the per-step forward driver in ``invoke_routine`` covers + # ``ASYNC_GEN_ANEXT`` only, so those shapes are genuinely distinct. 
@pytest.mark.asyncio - async def test_reset_with_caller_minted_token_in_self_dispatch( + async def test_mid_stream_mutation_should_reach_asend_frame( self, credentials_map, retry_grpc_internal ): - """Test a self-dispatched routine resets a caller-minted Token without raising. + """Test caller mutation before asend reaches the worker frame. Given: - A DEFAULT pool (single in-process worker so the dispatch - target matches the caller process) and a caller that sets - ``TENANT_ID`` and captures the resulting ``wool.Token``. + An async-generator routine using ``asend`` that echoes + ``TENANT_ID.get()`` each iteration When: - The self-dispatched routine calls ``TENANT_ID.reset(token)`` - on the worker. + The caller calls ``gen.asend(x)`` with TENANT_ID set to a + distinct value before each send Then: - The reset should not raise :class:`ValueError`, and the - token should read as used afterward. + Each echoed value should reflect the caller's var value at + the moment of the corresponding ``asend`` frame """ # Arrange, act, & assert async def body(): - scenario = default_scenario(pool_mode=PoolMode.DEFAULT) + scenario = default_scenario(shape=RoutineShape.ASYNC_GEN_ASEND) async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("caller-value") + gen = routines.echo_tenant_id_on_send(3) try: - worker_read = await routines.accept_token_and_reset(token) + first = await gen.__anext__() + assert first == "ready" + collected: list[str] = [] + values = ["fs2-a", "fs2-b", "fs2-c"] + tokens: list = [] + try: + for v in values: + tokens.append(routines.TENANT_ID.set(v)) + collected.append(await gen.asend(None)) + finally: + for tok in reversed(tokens): + routines.TENANT_ID.reset(tok) finally: - if not token.used: - routines.TENANT_ID.reset(token) - assert worker_read == "unknown" - assert token.used is True + await gen.aclose() + assert collected == values await retry_grpc_internal(body) @pytest.mark.asyncio - async def 
test_reset_with_caller_minted_token_matches_cross_process( + async def test_mid_stream_mutation_should_reach_athrow_frame( self, credentials_map, retry_grpc_internal ): - """Test caller-minted Token reset matches self-dispatch and cross-process. + """Test caller mutation before athrow reaches the handler frame. Given: - The same caller-minted-Token-reset scenario run once under - a DEFAULT pool (self-dispatch) and once under an EPHEMERAL - pool (cross-process workers). + An async-generator routine whose ``athrow`` handler reads + ``TENANT_ID`` before yielding that value and returning When: - The self-dispatched routine calls ``TENANT_ID.reset(token)`` - on the worker in each pool mode. + The caller mutates TENANT_ID then calls ``gen.athrow`` Then: - Both runs should produce identical observable results — - proving self-dispatch is behaviorally identical to - cross-process dispatch. - """ - - # Arrange, act, & assert - async def body(): - observations = [] - for pool_mode in (PoolMode.DEFAULT, PoolMode.EPHEMERAL): - scenario = default_scenario(pool_mode=pool_mode) - async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("caller-value") - try: - worker_read = await routines.accept_token_and_reset(token) - finally: - if not token.used: - routines.TENANT_ID.reset(token) - observations.append((worker_read, token.used)) - assert observations[0] == observations[1] - assert observations[0] == ("unknown", True) - - await retry_grpc_internal(body) - - @pytest.mark.asyncio - async def test_reset_with_worker_minted_token_in_self_dispatch( - self, credentials_map, retry_grpc_internal - ): - """Test a worker-minted Token round-trips and stays single-use in self-dispatch. - - Given: - A DEFAULT pool (self-dispatch) and a routine that mints a - ``wool.Token`` on the worker via ``TENANT_ID.set`` and - returns it to the caller. - When: - The caller calls ``TENANT_ID.reset(token)`` locally, then - calls it a second time. 
- Then: - The first local reset should succeed and the second should - raise :class:`RuntimeError` — the worker-minted token is - single-use across the logical chain even when minted under - a self-dispatch. - """ - - # Arrange, act, & assert - async def body(): - scenario = default_scenario(pool_mode=PoolMode.DEFAULT) - async with build_pool_from_scenario(scenario, credentials_map): - token = await routines.mint_tenant_token("worker-value") - routines.TENANT_ID.reset(token) - with pytest.raises(RuntimeError, match="already been used"): - routines.TENANT_ID.reset(token) - - await retry_grpc_internal(body) - - @pytest.mark.asyncio - async def test_reset_with_caller_minted_token_on_async_gen_yield_in_self_dispatch( - self, credentials_map, retry_grpc_internal - ): - """Test a self-dispatched async generator resets a caller-minted Token. - - Given: - A DEFAULT pool (self-dispatch) and an async-generator - routine that resets a caller-minted ``wool.Token`` between - two yields. - When: - The caller iterates the generator to completion, then - calls ``TENANT_ID.reset(token)`` locally. - Then: - The worker-side reset inside the generator frame should not - raise :class:`ValueError`, and the caller's later reset - should raise :class:`RuntimeError` because the token was - already consumed on the worker. 
- """ - - # Arrange, act, & assert - async def body(): - scenario = default_scenario( - shape=RoutineShape.ASYNC_GEN_ANEXT, - pool_mode=PoolMode.DEFAULT, - ) - async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("caller-value") - collected = [ - item - async for item in routines.accept_token_and_reset_on_yield(token) - ] - assert collected == ["before", "after"] - with pytest.raises(RuntimeError, match="already been used"): - routines.TENANT_ID.reset(token) - - await retry_grpc_internal(body) - - -@pytest.mark.integration -class TestExceptionPathBackPropagation: - @pytest.mark.asyncio - async def test_coroutine_exception_back_propagates_worker_mutation( - self, credentials_map, retry_grpc_internal - ): - """Test exception payload carries worker-side var mutations. - - Given: - A routine that sets TENANT_ID to a sentinel value then - raises ValueError - When: - The caller dispatches and catches the exception - Then: - The caller's TENANT_ID should reflect the worker-side - sentinel value (back-propagated via exception snapshot - path) - """ - - # Arrange, act, & assert - async def body(): - scenario = default_scenario() - async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("caller-original") - try: - with pytest.raises(ValueError, match="mutate_then_raise_tenant_id"): - await routines.mutate_then_raise_tenant_id("exc-path-sentinel") - observed = routines.TENANT_ID.get() - finally: - routines.TENANT_ID.reset(token) - assert observed == "exc-path-sentinel" - - await retry_grpc_internal(body) - - @pytest.mark.asyncio - async def test_async_gen_exception_back_propagates_worker_mutation( - self, credentials_map, retry_grpc_internal - ): - """Test async-gen exception carries mid-stream mutations. 
- - Given: - An async-generator routine that yields once, then sets - TENANT_ID and raises on the next iteration - When: - The caller iterates and catches the exception - Then: - The caller's TENANT_ID should reflect the last mid-stream - mutation performed on the worker before the raise - """ - - # Arrange, act, & assert - async def body(): - scenario = default_scenario( - shape=RoutineShape.ASYNC_GEN_ANEXT, - pool_mode=PoolMode.EPHEMERAL, - ) - async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("caller-original") - try: - gen = routines.yield_then_mutate_and_raise("mid-stream-sentinel") - try: - first = await gen.__anext__() - assert first == "ready" - match = "yield_then_mutate_and_raise" - with pytest.raises(ValueError, match=match): - await gen.__anext__() - finally: - await gen.aclose() - observed = routines.TENANT_ID.get() - finally: - routines.TENANT_ID.reset(token) - assert observed == "mid-stream-sentinel" - - await retry_grpc_internal(body) - - -@pytest.mark.integration -class TestAsyncioForkOnWorker: - @pytest.mark.asyncio - async def test_worker_child_mutation_does_not_leak_to_parent( - self, credentials_map, retry_grpc_internal - ): - """Test child-task mutation stays out of parent on the worker. 
- - Given: - A routine that sets TENANT_ID to ``"parent"``, spawns a - child asyncio task that sets TENANT_ID to ``"child"`` and - returns its read, and finally reads TENANT_ID from the - parent after the child completes - When: - The caller dispatches the routine - Then: - It should return the original value for the parent's - post-child read (stdlib copy-on-fork parity) - """ - - # Arrange, act, & assert - async def body(): - scenario = default_scenario() - async with build_pool_from_scenario(scenario, credentials_map): - child_value, parent_value = await routines.spawn_and_mutate_tenant_id() - assert child_value == "child" - assert parent_value == "parent" - - await retry_grpc_internal(body) - - @pytest.mark.asyncio - async def test_worker_child_inherits_parent_value( - self, credentials_map, retry_grpc_internal - ): - """Test child asyncio task inherits parent's pre-fork var value. - - Given: - A routine that sets TENANT_ID then spawns a child asyncio - task that reads TENANT_ID without mutating - When: - The caller dispatches the routine - Then: - It should return the parent's pre-fork value for the - child's read - """ - - # Arrange, act, & assert - async def body(): - scenario = default_scenario() - async with build_pool_from_scenario(scenario, credentials_map): - child_value = await routines.parent_sets_child_reads() - assert child_value == "parent-set" - - await retry_grpc_internal(body) - - @pytest.mark.asyncio - async def test_worker_sibling_children_are_isolated( - self, credentials_map, retry_grpc_internal - ): - """Test sibling asyncio children are mutually isolated. 
- - Given: - A routine that spawns two children via ``asyncio.gather``, - each mutating TENANT_ID to different values, and a parent - read afterward - When: - The caller dispatches the routine - Then: - Each child should observe its own value, and the parent's - var should remain unchanged (neither child leaks into the - other nor into the parent) - """ - - # Arrange, act, & assert - async def body(): - scenario = default_scenario() - async with build_pool_from_scenario(scenario, credentials_map): - ( - a_value, - b_value, - parent_value, - ) = await routines.two_children_mutate_tenant_id() - assert a_value == "alpha" - assert b_value == "beta" - # Parent never set TENANT_ID, and children's mutations are - # in their own forked contexts. The default surfaces here. - assert parent_value == "unknown" - - await retry_grpc_internal(body) - - -@pytest.mark.integration -class TestStubPromotionAcrossWorkers: - @pytest.mark.asyncio - async def test_fresh_worker_promotes_stub_without_collision( - self, credentials_map, retry_grpc_internal - ): - """Test fresh worker unpickles stub then imports module. 
- - Given: - A routine that imports and reads TENANT_ID, a caller that - sets the var, and a fresh EPHEMERAL worker that has not yet - imported the defining module - When: - The caller dispatches the routine - Then: - The worker should unpickle the var (stub), import the - module (promote the stub), and the routine should read - the propagated value without a collision - """ - - # Arrange, act, & assert - async def body(): - scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) - async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("stub-promotion-value") - try: - result = await routines.get_tenant_id() - finally: - routines.TENANT_ID.reset(token) - assert result == "stub-promotion-value" - - await retry_grpc_internal(body) - - @pytest.mark.asyncio - async def test_sibling_routine_raises_context_var_collision( - self, credentials_map, retry_grpc_internal - ): - """Test colliding sibling var raises on the caller. - - Given: - Two sibling routines that each construct a - ``wool.ContextVar`` with the same ``namespace:name`` key - in their function bodies, on a DEFAULT pool so both - dispatches land on the same worker (process-wide registry - isolation would otherwise mask the collision) - When: - The caller dispatches the first sibling (registering the - key on the worker) then dispatches the second sibling - Then: - The second dispatch should raise wool.ContextVarCollision - on the caller via the worker's exception snapshot path - """ - - # Arrange, act, & assert - async def body(): - # DEFAULT pool is size=1 so both dispatches land on the - # same worker process; the second construction under the - # already-registered key triggers the collision. - scenario = default_scenario(pool_mode=PoolMode.DEFAULT) - async with build_pool_from_scenario(scenario, credentials_map): - # First dispatch registers the key on the worker. 
- first = await _collision_fixtures.sibling_a() - assert first == "sibling-a" - - # Second dispatch constructs a new var with the same - # key → ContextVarCollision propagates back. - with pytest.raises( - wool.ContextVarCollision, - match=_collision_fixtures.COLLIDING_KEY, - ): - await _collision_fixtures.sibling_b() - - await retry_grpc_internal(body) - - -@pytest.mark.integration -class TestForwardPropagationMidStream: - @pytest.mark.asyncio - async def test_mid_stream_mutation_reaches_next_anext( - self, credentials_map, retry_grpc_internal - ): - """Test caller mutation between __anext__ calls reaches worker. - - Given: - An async-generator routine that yields ``TENANT_ID.get()`` - on each iteration and a caller that mutates the var - between ``__anext__`` calls - When: - The caller drives the generator manually, setting the var - to a distinct value before each ``__anext__`` - Then: - Each yielded value should reflect the caller's most recent - value at the moment of the call - """ - - # Arrange, act, & assert - async def body(): - scenario = default_scenario(shape=RoutineShape.ASYNC_GEN_ANEXT) - async with build_pool_from_scenario(scenario, credentials_map): - gen = routines.stream_tenant_id_echo(3) - try: - collected: list[str] = [] - values = ["fs1-first", "fs1-second", "fs1-third"] - tokens: list = [] - try: - for v in values: - tokens.append(routines.TENANT_ID.set(v)) - collected.append(await gen.__anext__()) - finally: - for tok in reversed(tokens): - if not tok.used: - routines.TENANT_ID.reset(tok) - finally: - await gen.aclose() - assert collected == values - - await retry_grpc_internal(body) - - @pytest.mark.asyncio - async def test_mid_stream_mutation_reaches_asend_frame( - self, credentials_map, retry_grpc_internal - ): - """Test caller mutation before asend reaches the worker frame. 
- - Given: - An async-generator routine using ``asend`` that echoes - ``TENANT_ID.get()`` each iteration - When: - The caller calls ``gen.asend(x)`` with TENANT_ID set to a - distinct value before each send - Then: - Each echoed value should reflect the caller's var value at - the moment of the corresponding ``asend`` frame - """ - - # Arrange, act, & assert - async def body(): - scenario = default_scenario(shape=RoutineShape.ASYNC_GEN_ASEND) - async with build_pool_from_scenario(scenario, credentials_map): - gen = routines.echo_tenant_id_on_send(3) - try: - first = await gen.__anext__() - assert first == "ready" - collected: list[str] = [] - values = ["fs2-a", "fs2-b", "fs2-c"] - tokens: list = [] - try: - for v in values: - tokens.append(routines.TENANT_ID.set(v)) - collected.append(await gen.asend(None)) - finally: - for tok in reversed(tokens): - if not tok.used: - routines.TENANT_ID.reset(tok) - finally: - await gen.aclose() - assert collected == values - - await retry_grpc_internal(body) - - @pytest.mark.asyncio - async def test_mid_stream_mutation_reaches_athrow_frame( - self, credentials_map, retry_grpc_internal - ): - """Test caller mutation before athrow reaches the handler frame. 
- - Given: - An async-generator routine whose ``athrow`` handler reads - ``TENANT_ID`` before yielding that value and returning - When: - The caller mutates TENANT_ID then calls ``gen.athrow`` - Then: - The yielded value should reflect the caller's most recent - var value at the moment of the ``athrow`` call + The yielded value should reflect the caller's most recent + var value at the moment of the ``athrow`` call """ # Arrange, act, & assert @@ -1391,11 +1000,10 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_concurrent_mid_stream_mutations_remain_serialized( + async def test_concurrent_mid_stream_mutations_should_remain_serialized( self, credentials_map, retry_grpc_internal ): - """Test parallel async-generator dispatches with mid-stream - mutations remain isolated under concurrent load. + """Test concurrent async-gen dispatches keep mid-stream mutations isolated. Given: An EPHEMERAL pool sized to host multiple concurrent @@ -1434,8 +1042,7 @@ async def dispatch_with_prefix(prefix: str) -> list[str]: collected.append(await gen.__anext__()) finally: for tok in reversed(tokens): - if not tok.used: - routines.TENANT_ID.reset(tok) + routines.TENANT_ID.reset(tok) finally: await gen.aclose() return collected @@ -1454,12 +1061,10 @@ async def dispatch_with_prefix(prefix: str) -> list[str]: await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_concurrent_asend_against_single_generator_raises( + async def test_concurrent_asend_against_single_generator_should_raise( self, credentials_map, retry_grpc_internal ): - """Test two concurrent ``asend`` calls against the same wool - async-generator behave like Python native: one succeeds, the - other raises RuntimeError. + """Test concurrent asend on one wool async-gen: one succeeds, one raises. 
Given: An async-generator routine driven past its initial @@ -1520,7 +1125,7 @@ async def body(): @pytest.mark.integration class TestUnregisteredKeyBehavior: @pytest.mark.asyncio - async def test_worker_silently_drops_unknown_key( + async def test_worker_should_silently_drop_unknown_key( self, credentials_map, retry_grpc_internal ): """Test var unknown on worker is dropped, dispatch succeeds. @@ -1554,7 +1159,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_worker_stubs_unknown_key_visible_after_late_declaration( + async def test_worker_should_expose_stubbed_unknown_key_when_late_declared( self, credentials_map, retry_grpc_internal ): """Test wire stub becomes visible once worker declares the var. @@ -1594,28 +1199,44 @@ async def body(): @pytest.mark.integration -class TestCallerSideTaskFactoryFork: +class TestCallerSideChildTaskDispatch: + """Caller-side child asyncio task forks the chain and dispatches. + + Consolidates the formerly-separate ``TestCallerSideTaskFactoryFork`` + and ``TestForkedChildTaskDispatchAcrossWorkers`` classes — both + exercised the same "a child task created with ``create_task`` forks + the parent chain and dispatches a routine" scenario, differing only + in ``pool_mode``. The inheritance case is parametrized over + ``pool_mode`` rather than duplicated across two class bodies. + """ + @pytest.mark.asyncio - async def test_caller_child_task_inherits_var_through_dispatch( - self, credentials_map, retry_grpc_internal + @pytest.mark.parametrize( + "pool_mode", + [PoolMode.DEFAULT, PoolMode.EPHEMERAL], + ids=lambda m: m.name, + ) + async def test_caller_child_task_should_inherit_var_through_dispatch( + self, pool_mode, credentials_map, retry_grpc_internal ): """Test caller child asyncio task inherits var and dispatches correctly. 
Given: A caller that sets TENANT_ID and spawns an asyncio child task via ``create_task`` which dispatches a routine that - reads the var from the worker + reads the var from the worker, against a DEFAULT + (self-dispatch) and an EPHEMERAL (cross-process) pool When: The child task dispatches the routine Then: The routine should return the caller's propagated value, - proving the child task inherited the parent's context and - the dispatch propagated it to the worker + proving the child task's forked chain carries the parent's + variable bindings through the dispatch to the worker """ # Arrange, act, & assert async def body(): - scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) + scenario = default_scenario(pool_mode=pool_mode) async with build_pool_from_scenario(scenario, credentials_map): token = routines.TENANT_ID.set("parent-caller-value") try: @@ -1631,7 +1252,7 @@ async def _child(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_caller_child_dispatch_mutation_does_not_leak_to_parent( + async def test_caller_child_dispatch_mutation_should_not_leak_to_parent( self, credentials_map, retry_grpc_internal ): """Test caller child task's back-propagated mutation stays isolated. @@ -1674,11 +1295,48 @@ async def _child(): await retry_grpc_internal(body) + @pytest.mark.asyncio + async def test_concurrent_child_dispatches_should_be_isolated( + self, credentials_map, retry_grpc_internal + ): + """Test two concurrent child-task dispatches stay isolated. + + Given: + An armed caller and two child tasks created via + asyncio.create_task, each setting TENANT_ID to a distinct + value before dispatching a routine that reads it. + When: + Both tasks are gathered concurrently. + Then: + Each routine should observe its own task's value — the + task factory forks each child onto a fresh chain so + mutations do not cross between siblings. 
+ """ + + # Arrange, act, & assert + async def body(): + scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) + async with build_pool_from_scenario(scenario, credentials_map): + routines.TENANT_ID.set("caller-value") + + async def _dispatch(value: str) -> str: + routines.TENANT_ID.set(value) + await asyncio.sleep(0) + return await routines.get_tenant_id() + + first = asyncio.create_task(_dispatch("tenant-a")) + second = asyncio.create_task(_dispatch("tenant-b")) + results = await asyncio.gather(first, second) + + assert results == ["tenant-a", "tenant-b"] + + await retry_grpc_internal(body) + @pytest.mark.integration class TestSequentialDispatchIsolation: @pytest.mark.asyncio - async def test_sequential_dispatches_do_not_bleed_context( + async def test_sequential_dispatches_should_not_bleed_context( self, credentials_map, retry_grpc_internal ): """Test sequential dispatches do not leak var state between calls. @@ -1693,7 +1351,7 @@ async def test_sequential_dispatches_do_not_bleed_context( Then: The second dispatch should observe the caller's freshly set value, not the residual mutation from the first dispatch, - proving each dispatch snapshots the caller's current context + proving each dispatch captures the caller's current context independently """ @@ -1724,7 +1382,7 @@ async def body(): @pytest.mark.integration class TestSelfDispatchStreamingVarMutation: @pytest.mark.asyncio - async def test_self_dispatch_streaming_var_mutation_between_yields( + async def test_self_dispatch_streaming_should_apply_var_mutation_between_yields( self, credentials_map, retry_grpc_internal ): """Test self-dispatch streaming applies caller var mutations per yield. 
@@ -1761,8 +1419,7 @@ async def body(): collected.append(await gen.__anext__()) finally: for tok in reversed(tokens): - if not tok.used: - routines.TENANT_ID.reset(tok) + routines.TENANT_ID.reset(tok) finally: await gen.aclose() assert collected == values @@ -1773,7 +1430,7 @@ async def body(): @pytest.mark.integration class TestDurablePoolContextPropagation: @pytest.mark.asyncio - async def test_durable_pool_propagates_wool_context_var( + async def test_durable_pool_should_propagate_wool_context_var( self, credentials_map, retry_grpc_internal ): """Test wool.ContextVar propagation works through a DURABLE pool. @@ -1806,645 +1463,777 @@ async def body(): @pytest.mark.integration -class TestMergedWireShapeEndToEnd: - """End-to-end coverage for the merged wire shape introduced - when ``protocol.Context.vars`` became ``repeated ContextVar`` - with optional ``value`` and ``consumed_tokens`` fields under - the same entry. Each test exercises a caller setup whose - Context carries both a current value AND a consumed-token id - for the same var — a corner the prior shape (``map`` - plus ``repeated ConsumedToken``) could not express in a single - entry — and verifies the dispatch path round-trips both pieces - of state to the worker. - """ - +class TestMergedWireShapeEndToEnd: + """End-to-end coverage for the merged wire shape introduced + when ``protocol.ChainManifest.vars`` became ``repeated ContextVar`` + with optional ``value`` and ``consumed_tokens`` fields under + the same entry. Each test exercises a caller setup whose + Chain carries both a current value AND a consumed-token id + for the same var — a corner the prior shape (``map`` + plus ``repeated ConsumedToken``) could not express in a single + entry — and verifies the dispatch path round-trips both pieces + of state to the worker. 
+ """ + + @pytest.mark.asyncio + async def test_merged_entry_should_ride_forward_across_async_gen_frames( + self, credentials_map, retry_grpc_internal + ): + """Test per-frame propagation preserves the merged entry across anext frames. + + Given: + A caller that ran ``token = TENANT_ID.set("X")``, then + ``TENANT_ID.reset(token)``, then ``TENANT_ID.set("Y")`` + — same setup as the single-dispatch case but the routine + is an async generator that yields ``TENANT_ID.get()`` on + each iteration + When: + The caller iterates ``stream_tenant_id(count=3)`` to + completion + Then: + Every yield equals "Y" — the value rides the merged + entry on each per-frame request — and a subsequent + local ``TENANT_ID.reset(token)`` still raises + RuntimeError, confirming the streaming back-propagation + preserved the caller's consumed-token state rather than + silently clobbering it + """ + + # Arrange, act, & assert + async def body(): + scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) + async with build_pool_from_scenario(scenario, credentials_map): + token = routines.TENANT_ID.set("X") + routines.TENANT_ID.reset(token) + y_token = routines.TENANT_ID.set("Y") + try: + yielded: list[str] = [] + async for value in routines.stream_tenant_id(3): + yielded.append(value) + assert yielded == ["Y", "Y", "Y"] + with pytest.raises(RuntimeError, match="already been used"): + routines.TENANT_ID.reset(token) + finally: + routines.TENANT_ID.reset(y_token) + + await retry_grpc_internal(body) + + +@pytest.mark.integration +class TestSerializationWarningAcrossWorkers: + @pytest.mark.asyncio + async def test_unpicklable_var_value_should_emit_warning_and_complete_dispatch( + self, credentials_map, retry_grpc_internal + ): + """Test unpicklable wool.ContextVar value is dropped and dispatch survives. 
+ + Given: + A caller that sets ``TENANT_ID`` to a known value and a + second wool.ContextVar (REGION) to an unpicklable lambda, + and a routine that reads ``TENANT_ID`` + When: + The caller dispatches the routine under default warning + filters + Then: + ``wool.SerializationWarning`` should be emitted on the + caller side for the unpicklable var; the dispatch should + still complete; and the routine should observe the + propagated ``TENANT_ID`` value — primary signal preserved, + ancillary failure surfaced as a warning + """ + + # Arrange, act, & assert + async def body(): + scenario = default_scenario() + async with build_pool_from_scenario(scenario, credentials_map): + tenant_token = routines.TENANT_ID.set("survives-encode") + # Local lambdas are not picklable across processes + # (cloudpickle handles many cases, but a closure over + # a local non-importable scope is rejected by the + # default pickle protocol via cloudpickle.dumps when + # paired with a file-local symbol that has no qualname + # path the worker can resolve). Use an open file + # handle as a robust unpicklable sentinel. 
+ import socket + + unpicklable = socket.socket() + try: + region_token = routines.REGION.set(unpicklable) # pyright: ignore[reportArgumentType] + try: + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter( + "always", category=wool.SerializationWarning + ) + result = await routines.read_tenant_id_only() + finally: + routines.REGION.reset(region_token) + finally: + unpicklable.close() + routines.TENANT_ID.reset(tenant_token) + decode_warnings = [ + w for w in captured if issubclass(w.category, wool.SerializationWarning) + ] + assert decode_warnings, ( + f"Expected at least one SerializationWarning, got {captured!r}" + ) + assert any("region" in str(w.message) for w in decode_warnings), ( + f"Expected the warning to name the offending var; " + f"got {[str(w.message) for w in decode_warnings]!r}" + ) + assert result == "survives-encode" + + await retry_grpc_internal(body) + + @pytest.mark.asyncio + async def test_unpicklable_var_value_should_raise_error_when_strict_mode( + self, credentials_map, retry_grpc_internal + ): + """Test caller-side strict mode raises ChainSerializationError. + + Given: + A caller that sets a wool.ContextVar to an unpicklable + value, with ``warnings.simplefilter("error", + category=wool.SerializationWarning)`` active for the + duration of the dispatch attempt. + When: + The caller dispatches a routine — encode_context + discovers the unencodable var. + Then: + It should raise a ``wool.ChainSerializationError`` aggregating + ``wool.SerializationWarning`` instances on + ``.warnings``, with the offending var named in the + warning. The dispatch must NOT leave the caller — strict + mode promotes the warning before the wire frame is + constructed, and the load balancer's worker-health + contract treats only ``RpcError`` as a health concern, + so the error propagates unwrapped to the caller rather + than triggering worker eviction and a + ``NoWorkersAvailable`` fallback. 
+ """ + + # Arrange, act, & assert + async def body(): + scenario = default_scenario() + async with build_pool_from_scenario(scenario, credentials_map): + import socket + + unpicklable = socket.socket() + try: + region_token = routines.REGION.set(unpicklable) # pyright: ignore[reportArgumentType] + try: + with warnings.catch_warnings(): + warnings.simplefilter( + "error", category=wool.SerializationWarning + ) + with pytest.raises(wool.ChainSerializationError) as exc_info: + await routines.read_tenant_id_only() + finally: + routines.REGION.reset(region_token) + finally: + unpicklable.close() + warnings_list = exc_info.value.warnings + assert all( + isinstance(w, wool.SerializationWarning) for w in warnings_list + ), f"Expected only SerializationWarning items, got {warnings_list!r}" + assert any("region" in str(w) for w in warnings_list), ( + f"Expected the offending var to be named in a warning; " + f"got {[str(w) for w in warnings_list]!r}" + ) + + await retry_grpc_internal(body) + + +@pytest.mark.integration +class TestNestedDispatchMidChainMutation: @pytest.mark.asyncio - async def test_single_dispatch_carries_value_and_consumed_token( + async def test_outer_worker_mid_routine_mutation_should_reach_nested_inner_worker( self, credentials_map, retry_grpc_internal ): - """Test one dispatch propagates a current value and a - consumed-token id under the same merged wire entry. + """Test outer routine mutation propagates to a nested dispatch. 
Given: - A caller that ran ``token = TENANT_ID.set("X")``, then - ``TENANT_ID.reset(token)``, then ``TENANT_ID.set("Y")`` - — the var carries a current value "Y" alongside a - locally-consumed token under the same identity, with a - strong reference held to the token + A caller that sets TENANT_ID to "alpha", an EPHEMERAL pool + sized to permit two distinct workers, and an outer routine + that mutates TENANT_ID to "beta" before dispatching + ``get_tenant_id`` to a nested worker When: - The caller dispatches ``read_value_and_attempt_reset`` - passing the consumed token as the argument + The caller dispatches the outer routine Then: - The routine should observe ``TENANT_ID.get() == "Y"`` - on the worker AND ``TENANT_ID.reset(token)`` should - raise ``RuntimeError("Token has already been used")`` - — confirming the merged entry round-trips both the - value and the consumed-token id to the worker, where - the wire-promoted Token correctly reports as already - used + The outer routine should return "beta" — the inner worker + observed the outer's mid-routine mutation, not the + caller's pre-dispatch value — and the caller should + observe "beta" after the dispatch returns, completing the + bidirectional propagation chain across two worker hops """ # Arrange, act, & assert async def body(): - scenario = default_scenario() + scenario = default_scenario( + shape=RoutineShape.NESTED_COROUTINE, + pool_mode=PoolMode.EPHEMERAL, + ) async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("X") - routines.TENANT_ID.reset(token) - y_token = routines.TENANT_ID.set("Y") + token = routines.TENANT_ID.set("alpha") try: - value, reset_outcome = await routines.read_value_and_attempt_reset( - token + inner_observed = await routines.mutate_then_nested_get_tenant_id( + "beta" ) + caller_after = routines.TENANT_ID.get() finally: - routines.TENANT_ID.reset(y_token) - assert value == "Y" - assert "Token has already been used" in reset_outcome + 
routines.TENANT_ID.reset(token) + assert inner_observed == "beta" + assert caller_after == "beta" await retry_grpc_internal(body) + +@pytest.mark.integration +class TestWorkerSideContextDecodeFailure: @pytest.mark.asyncio - async def test_consumed_token_carries_across_two_sequential_dispatches( + async def test_worker_side_decode_failure_should_drop_var_and_complete_dispatch( self, credentials_map, retry_grpc_internal ): - """Test two sequential dispatches forward the same merged - entry to the worker on each frame. + """Test a worker-side chain-manifest decode failure degrades gracefully. Given: - A caller that ran ``token = TENANT_ID.set("X")``, then - ``TENANT_ID.reset(token)``, then ``TENANT_ID.set("Y")``, - with a strong reference held to the consumed token + A caller that sets ``REGION`` to a value that pickles + cleanly on the caller but raises when unpickled on the + worker (a version-skew shape), plus a worker-known + ``TENANT_ID``, dispatched through an EPHEMERAL pool under + the default warning filter When: - The caller dispatches ``get_tenant_id`` first (which - takes no arguments), then dispatches - ``accept_token_and_reset`` passing the consumed token + The caller dispatches a routine that reads ``TENANT_ID`` — + ``decode_context`` on the worker cannot decode the + ``REGION`` entry in the dispatch's initial frame Then: - The first dispatch returns "Y" — the value rode forward - in the merged entry — and the second dispatch raises - RuntimeError ("Token has already been used") on the - worker, confirming the consumed-token id rode forward - in the same merged entry on both dispatches + The dispatch should complete, the worker should drop the + offending ``REGION`` entry and emit a + ``SerializationWarning``, and the routine should still + observe its own ``TENANT_ID`` — the version-skew shape + degrades gracefully under default filters rather than + failing the whole dispatch. 
""" # Arrange, act, & assert async def body(): - scenario = default_scenario() + scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("X") - routines.TENANT_ID.reset(token) - y_token = routines.TENANT_ID.set("Y") + tenant_token = routines.TENANT_ID.set("decode-fail-tenant") + # REGION carries a value that pickles fine caller-side + # but detonates when the worker unpickles it. + region_token = routines.REGION.set(routines.DecodeBomb()) # pyright: ignore[reportArgumentType] try: - first_value = await routines.get_tenant_id() - assert first_value == "Y" - with pytest.raises(RuntimeError, match="already been used"): - await routines.accept_token_and_reset(token) + with warnings.catch_warnings(record=True): + warnings.simplefilter( + "always", category=wool.SerializationWarning + ) + observed = await routines.read_tenant_id_only() finally: - routines.TENANT_ID.reset(y_token) + routines.REGION.reset(region_token) + routines.TENANT_ID.reset(tenant_token) + # The undecodable REGION entry was dropped on the worker; + # the decodable TENANT_ID still reached the routine. + assert observed == "decode-fail-tenant" await retry_grpc_internal(body) + +@pytest.mark.integration +class TestWorkerCrashMidDispatch: @pytest.mark.asyncio - async def test_merged_entry_rides_forward_across_async_gen_frames( + async def test_worker_crash_mid_dispatch_should_leave_caller_context_intact( self, credentials_map, retry_grpc_internal ): - """Test per-frame forward propagation preserves the merged - entry's value across every ``__anext__`` boundary of an - async-generator routine. + """Test a worker crash mid-dispatch leaves the caller's context intact. 
Given: - A caller that ran ``token = TENANT_ID.set("X")``, then - ``TENANT_ID.reset(token)``, then ``TENANT_ID.set("Y")`` - — same setup as the single-dispatch case but the routine - is an async generator that yields ``TENANT_ID.get()`` on - each iteration + An armed caller that set ``TENANT_ID`` and an EPHEMERAL + pool whose worker hard-exits its process mid-routine after + mutating its own copy of ``TENANT_ID`` When: - The caller iterates ``stream_tenant_id(count=3)`` to - completion + The caller dispatches the crashing routine Then: - Every yield equals "Y" — the value rides the merged - entry on each per-frame request — and a subsequent - local ``TENANT_ID.reset(token)`` still raises - RuntimeError, confirming the streaming back-propagation - preserved the caller's consumed-token state rather than - silently clobbering it + The caller should observe a dispatch error (a broken + stream surfaces as a gRPC / wool RpcError or an + UnexpectedResponse), and its own ``TENANT_ID`` must still + equal the value it set — no half-merged back-propagation + rode back from the dead worker. """ # Arrange, act, & assert async def body(): scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("X") - routines.TENANT_ID.reset(token) - y_token = routines.TENANT_ID.set("Y") + token = routines.TENANT_ID.set("caller-pre-crash") try: - yielded: list[str] = [] - async for value in routines.stream_tenant_id(3): - yielded.append(value) - assert yielded == ["Y", "Y", "Y"] - with pytest.raises(RuntimeError, match="already been used"): - routines.TENANT_ID.reset(token) + with pytest.raises((grpc.RpcError, RpcError, UnexpectedResponse)): + await routines.set_tenant_then_crash_worker("worker-mutation") + # The crashed worker's partial mutation must not + # have merged into the caller's context. 
+ caller_value = routines.TENANT_ID.get() finally: - routines.TENANT_ID.reset(y_token) + routines.TENANT_ID.reset(token) + assert caller_value == "caller-pre-crash" await retry_grpc_internal(body) @pytest.mark.integration -class TestExplicitWoolContextBindingAcrossWorkers: +class TestCancellationWithArmedContext: @pytest.mark.asyncio - async def test_explicit_wool_context_binding_propagates_var_to_worker( - self, credentials_map, retry_grpc_internal + async def test_cancel_armed_dispatch_should_leave_caller_context_intact( + self, credentials_map, retry_grpc_internal, tmp_path ): - """Test wool.create_task with an explicit wool.Context binds and dispatches. + """Test cancelling a dispatch that armed a context preserves caller state. Given: - A pre-populated wool.Context seeded with a TENANT_ID - value via Context.run, a child task created with - wool.create_task(coro, context=ctx), and a routine that - reads TENANT_ID on the worker + An armed caller and an EPHEMERAL pool running a routine + that sets ``TENANT_ID`` to a worker-side value, then + sleeps — the routine has a live, mutated ``wool.ContextVar`` + when the cancellation arrives When: - The child task awaits the dispatched routine + The caller cancels the dispatch task after the routine has + suspended on its sleep Then: - The routine should observe the explicitly bound - wool.Context's TENANT_ID value, proving the wool task - factory routes the explicit Context across the wire - independently of the caller's implicit current Context + The caller's ``await`` should raise + ``asyncio.CancelledError``, the worker-side routine should + run its ``except`` arm (sentinel ``"cancelled"``), and the + caller's own ``TENANT_ID`` must equal the value it set — + the partial worker mutation is cleanly dropped under + cancellation, not half-merged back. 
""" # Arrange, act, & assert async def body(): - scenario = default_scenario() + scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) + sentinel = tmp_path / "armed_cancel_sentinel.txt" async with build_pool_from_scenario(scenario, credentials_map): - outer_token = routines.TENANT_ID.set("outer-caller-value") + token = routines.TENANT_ID.set("caller-armed-value") try: - bound_ctx = wool.Context() - bound_ctx.run(lambda: routines.TENANT_ID.set("explicit-bound-value")) - - async def _child(): - return await routines.get_tenant_id() - - task = wool.create_task(_child(), context=bound_ctx) - result = await task + task = asyncio.create_task( + routines.set_tenant_then_sleep( + "worker-armed-mutation", str(sentinel), 30.0 + ) + ) + # Wait deterministically for the routine to arm its + # context and suspend on the sleep. + for _ in range(150): + if sentinel.exists() and sentinel.read_text() == "started": + break + await asyncio.sleep(0.1) + else: + raise AssertionError( + "routine never wrote ``started`` — worker " + "startup or dispatch handshake hung" + ) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + # Poll for the worker's cancel arm to run. + for _ in range(150): + if sentinel.read_text() == "cancelled": + break + await asyncio.sleep(0.1) + caller_value = routines.TENANT_ID.get() finally: - routines.TENANT_ID.reset(outer_token) - assert result == "explicit-bound-value" + routines.TENANT_ID.reset(token) + assert sentinel.read_text() == "cancelled" + # The cancelled dispatch's partial worker mutation must not + # have corrupted the caller's own armed value. 
+ assert caller_value == "caller-armed-value" await retry_grpc_internal(body) + +@pytest.mark.integration +class TestMidStreamContextDecodeFailure: @pytest.mark.asyncio - async def test_concurrent_dispatch_under_same_wool_context_raises( + async def test_mid_stream_decode_failure_should_drop_var_and_continue_stream( self, credentials_map, retry_grpc_internal ): - """Test two concurrent tasks bound to the same wool.Context fail. + """Test a malformed context on a mid-stream frame degrades gracefully. Given: - A single wool.Context and two child tasks each created - via wool.create_task(coro, context=same_ctx) that await - a remote routine + An async-generator routine echoing ``TENANT_ID`` per + ``asend``, and a caller that — after the first frame — + additionally sets ``REGION`` to a value that pickles + cleanly but fails to decode on the worker When: - Both tasks are gathered concurrently + The caller drives the generator with ``asend``, so the + malformed ``REGION`` rides the mid-stream request frame + and ``_step``'s per-step ``decode_context`` cannot decode + it Then: - One task should complete successfully and the other - should raise RuntimeError because at most one task may - run inside a given wool.Context at a time — the wool - task factory's _context_scope first-task-wins guard fires - before the second task acquires _guard + The worker should drop the undecodable ``REGION`` entry + and continue the stream — the echoed value still tracks + the decodable ``TENANT_ID``, proving the mid-stream decode + path degrades gracefully under default filters. 
""" # Arrange, act, & assert async def body(): - scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) + scenario = default_scenario(shape=RoutineShape.ASYNC_GEN_ASEND) async with build_pool_from_scenario(scenario, credentials_map): - shared_ctx = wool.Context() - shared_ctx.run(lambda: routines.TENANT_ID.set("shared-context-value")) - - async def _slow_dispatch(): - # An ``asyncio.sleep(0)`` lets the scheduler park - # the first task before the second is created, so - # both create_task calls execute while the first - # is still mid-dispatch and the second hits the - # first-task-wins guard. - await asyncio.sleep(0) - return await routines.get_tenant_id() - - first_coro = _slow_dispatch() - second_coro = _slow_dispatch() - first = wool.create_task(first_coro, context=shared_ctx) - with warnings.catch_warnings(record=True) as captured: - warnings.simplefilter("always", category=RuntimeWarning) + gen = routines.echo_tenant_id_on_send(2) + try: + first = await gen.__anext__() + assert first == "ready" + tenant_token = routines.TENANT_ID.set("mid-stream-tenant") + # REGION now carries a worker-undecodable value; + # it rides the asend frame's context. + region_token = routines.REGION.set(routines.DecodeBomb()) # pyright: ignore[reportArgumentType] try: - second = wool.create_task(second_coro, context=shared_ctx) - except RuntimeError as exc: - # First-task-wins guard fired synchronously inside - # the factory before the second task was even - # scheduled. Close the unawaited coroutine and - # await the first to get its successful result. 
- second_coro.close() - successes = [await first] - failures = [exc] - else: - outcomes = await asyncio.gather( - first, second, return_exceptions=True - ) - successes = [ - o for o in outcomes if not isinstance(o, BaseException) - ] - failures = [o for o in outcomes if isinstance(o, BaseException)] - # Force collection in this frame so any "coroutine - # was never awaited" warning surfaces inside the - # catch_warnings scope rather than at teardown. - import gc - - gc.collect() - leaked = [ - w - for w in captured - if issubclass(w.category, RuntimeWarning) - and "was never awaited" in str(w.message) - ] - assert leaked == [], ( - "Guard-rejected coroutine must be closed by _context_scope, " - f"not leaked at GC; saw: {[str(w.message) for w in leaked]}" - ) - - assert len(failures) == 1 - assert isinstance(failures[0], RuntimeError) - assert "first-task-wins" in str(failures[0]) - assert successes == ["shared-context-value"] + with warnings.catch_warnings(record=True): + warnings.simplefilter( + "always", category=wool.SerializationWarning + ) + echoed = await gen.asend(None) + finally: + routines.REGION.reset(region_token) + routines.TENANT_ID.reset(tenant_token) + finally: + await gen.aclose() + # The malformed REGION was dropped; the decodable + # TENANT_ID still reached the mid-stream worker frame. + assert echoed == "mid-stream-tenant" await retry_grpc_internal(body) @pytest.mark.integration -class TestRuntimeContextDispatchTimeoutAcrossWorkers: +class TestConcurrentDispatchesShareParentVar: @pytest.mark.asyncio - async def test_caller_runtime_context_dispatch_timeout_visible_on_worker( + async def test_concurrent_fan_out_should_read_parent_set_var_without_contamination( self, credentials_map, retry_grpc_internal ): - """Test caller-side dispatch_timeout overrides ride the wire. + """Test concurrent fan-out dispatches read a parent-set var cleanly. 
Given: - A caller that wraps a dispatch in - ``with wool.RuntimeContext(dispatch_timeout=X):`` and a - routine that returns the worker-side value of - ``dispatch_timeout.get()`` + A parent that sets ``TENANT_ID`` once (the request-scoped + tenant-id pattern) and an EPHEMERAL pool, then fans out + two concurrent child-task dispatches that each read the + var and mutate their own copy When: - The caller dispatches the routine inside the override block + Both dispatches are gathered Then: - The routine should observe the caller's override value, - proving the RuntimeContext snapshot rode through the - Task.runtime_context wire field and was restored on the - worker before the routine body executed + Each dispatch should observe the parent-set value, and + after the gather the parent's own ``TENANT_ID`` must be + unchanged — neither child's back-propagated mutation + cross-contaminates the other or the parent. """ # Arrange, act, & assert async def body(): - scenario = default_scenario() + scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) async with build_pool_from_scenario(scenario, credentials_map): - with wool.RuntimeContext(dispatch_timeout=12.5): - observed = await routines.read_dispatch_timeout() - assert observed == 12.5 + token = routines.TENANT_ID.set("request-scoped-tenant") + try: + + async def _fan_out(suffix: str) -> tuple[str, str]: + # Each child reads the parent-set value, then + # mutates its own forked copy. + before = await routines.get_tenant_id() + worker_after = await routines.mutate_and_read_tenant_id() + return before, worker_after + + results = await asyncio.gather( + asyncio.create_task(_fan_out("a")), + asyncio.create_task(_fan_out("b")), + ) + parent_after = routines.TENANT_ID.get() + finally: + routines.TENANT_ID.reset(token) + # Both children saw the parent's request-scoped value. + assert results[0][0] == "request-scoped-tenant" + assert results[1][0] == "request-scoped-tenant" + # Each child's worker mutation is its own. 
+ assert results[0][1] == "mutated_on_worker" + assert results[1][1] == "mutated_on_worker" + # The parent's var is untouched by either child's + # back-propagation — the child tasks fork the chain. + assert parent_after == "request-scoped-tenant" await retry_grpc_internal(body) + +@pytest.mark.integration +class TestAsyncGenSetAndResetAcrossYield: @pytest.mark.asyncio - async def test_caller_dispatch_timeout_var_propagates_without_runtime_context( + async def test_async_gen_set_then_reset_should_back_propagate_per_yield( self, credentials_map, retry_grpc_internal ): - """Test the ambient dispatch_timeout var alone propagates. + """Test an async-gen set+reset across a yield back-propagates per frame. Given: - A caller that sets the module-level ``dispatch_timeout`` - stdlib ContextVar (no explicit RuntimeContext block) and a - routine that returns the worker-side value + An async-generator routine that sets ``TENANT_ID`` and + yields, then resets the var via the set's own Token and + yields again When: - The caller dispatches the routine + The caller iterates the generator, reading its own + ``TENANT_ID`` after each yield Then: - The routine should observe the caller's set value because - ``RuntimeContext.get_current`` captures the ambient - ``dispatch_timeout`` at Task construction time and the - captured snapshot rides the wire + After the first yield the caller should observe the + worker's set value, and after the second yield it should + observe the var reverted to its default — per-yield + back-propagation carries both the set and the reset. 
""" # Arrange, act, & assert async def body(): - scenario = default_scenario() + scenario = default_scenario(shape=RoutineShape.ASYNC_GEN_ANEXT) async with build_pool_from_scenario(scenario, credentials_map): - token = dispatch_timeout.set(7.25) + caller_reads: list[str] = [] + gen = routines.set_and_reset_tenant_across_yield("worker-set-value") try: - observed = await routines.read_dispatch_timeout() + async for marker in gen: + caller_reads.append(routines.TENANT_ID.get()) + assert marker in ("set", "reset") finally: - dispatch_timeout.reset(token) - assert observed == 7.25 + await gen.aclose() + # After the first yield the set is visible; after the + # second the reset reverted the var to its default. + assert caller_reads == ["worker-set-value", "unknown"] await retry_grpc_internal(body) @pytest.mark.integration -class TestContextDecodeWarningAcrossWorkers: +class TestCopyContextWidthAcrossWorkers: @pytest.mark.asyncio - async def test_unpicklable_var_value_emits_decode_warning_and_dispatch_completes( + async def test_copy_context_should_enumerate_one_plus_n_after_dispatch( self, credentials_map, retry_grpc_internal ): - """Test unpicklable wool.ContextVar value is dropped and dispatch survives. + """Test copy_context enumerates 1+N wool variables after a dispatch. 
Given: - A caller that sets ``TENANT_ID`` to a known value and a - second wool.ContextVar (REGION) to an unpicklable lambda, - and a routine that reads ``TENANT_ID`` + A caller with an unarmed context and a routine that + counts wool-owned ``contextvars.ContextVar``s visible in a + worker-side ``contextvars.copy_context()`` When: - The caller dispatches the routine under default warning - filters + The caller arms its context by binding two + ``wool.ContextVar``s and dispatches the counting routine Then: - ``wool.ContextDecodeWarning`` should be emitted on the - caller side for the unpicklable var; the dispatch should - still complete; and the routine should observe the - propagated ``TENANT_ID`` value — primary signal preserved, - ancillary failure surfaced as a warning + The unarmed caller's own copy_context should enumerate + zero wool variables, and the worker — running with two + bound variables forward-propagated — should report + ``1 + 2`` (the context variable plus one backing + variable per bound var). """ # Arrange, act, & assert async def body(): scenario = default_scenario() async with build_pool_from_scenario(scenario, credentials_map): - tenant_token = routines.TENANT_ID.set("survives-encode") - # Local lambdas are not picklable across processes - # (cloudpickle handles many cases, but a closure over - # a local non-importable scope is rejected by the - # default pickle protocol via cloudpickle.dumps when - # paired with a file-local symbol that has no qualname - # path the worker can resolve). Use an open file - # handle as a robust unpicklable sentinel. - import socket + # Unarmed: no wool-owned variables in a copy. 
+ unarmed = [ + var + for var in contextvars.copy_context() + if var.name.startswith("__wool") + ] - unpicklable = socket.socket() + tenant_token = routines.TENANT_ID.set("width-tenant") + region_token = routines.REGION.set("width-region") try: - region_token = routines.REGION.set(unpicklable) # pyright: ignore[reportArgumentType] - try: - with warnings.catch_warnings(record=True) as captured: - warnings.simplefilter( - "always", category=wool.ContextDecodeWarning - ) - result = await routines.read_tenant_id_only() - finally: - routines.REGION.reset(region_token) + worker_count = await routines.count_wool_context_vars() finally: - unpicklable.close() + routines.REGION.reset(region_token) routines.TENANT_ID.reset(tenant_token) - decode_warnings = [ - w for w in captured if issubclass(w.category, wool.ContextDecodeWarning) - ] - assert decode_warnings, ( - f"Expected at least one ContextDecodeWarning, got {captured!r}" - ) - assert any("region" in str(w.message) for w in decode_warnings), ( - f"Expected the warning to name the offending var; " - f"got {[str(w.message) for w in decode_warnings]!r}" - ) - assert result == "survives-encode" + assert unarmed == [] + # 1 context variable + 2 backing variables for the two + # bound wool.ContextVars propagated to the worker. + assert worker_count == 3 await retry_grpc_internal(body) + +@pytest.mark.integration +class TestRoutineLookupErrorBackPropagation: @pytest.mark.asyncio - async def test_unpicklable_var_value_under_strict_mode_raises_group( + async def test_get_on_unbound_default_less_var_should_surface_lookup_error( self, credentials_map, retry_grpc_internal ): - """Test caller-side strict mode aggregates encode failures into a group. + """Test a routine LookupError on an unbound var surfaces to the caller. Given: - A caller that sets a wool.ContextVar to an unpicklable - value, with ``warnings.simplefilter("error", - category=wool.ContextDecodeWarning)`` active for the - duration of the dispatch attempt. 
+ A routine that declares a default-less ``wool.ContextVar`` + and calls ``get()`` on it while it is unbound When: - The caller dispatches a routine — Task.to_protobuf - invokes Context.to_protobuf which discovers the - unencodable var. + The caller dispatches the routine Then: - ``Context.to_protobuf`` should raise a - ``BaseExceptionGroup`` whose peers are - ``wool.ContextDecodeWarning`` instances naming the - offending var, and the dispatch must NOT leave the - caller — strict mode promotes the warning before the - wire frame is constructed, and the load balancer's - worker-health contract treats only ``RpcError`` as a - health concern, so the group propagates unwrapped to the - caller rather than triggering worker eviction and a - ``NoWorkersAvailable`` fallback. + The ``LookupError`` raised inside the routine should + surface to the caller's ``await`` through the exception + back-propagation path. """ # Arrange, act, & assert async def body(): scenario = default_scenario() async with build_pool_from_scenario(scenario, credentials_map): - import socket - - unpicklable = socket.socket() - try: - region_token = routines.REGION.set(unpicklable) # pyright: ignore[reportArgumentType] - try: - with warnings.catch_warnings(): - warnings.simplefilter( - "error", category=wool.ContextDecodeWarning - ) - with pytest.raises(BaseExceptionGroup) as exc_info: - await routines.read_tenant_id_only() - finally: - routines.REGION.reset(region_token) - finally: - unpicklable.close() - peers = exc_info.value.exceptions - assert all(isinstance(p, wool.ContextDecodeWarning) for p in peers), ( - f"Expected only ContextDecodeWarning peers, got {peers!r}" - ) - assert any("region" in str(p) for p in peers), ( - f"Expected the offending var to be named in a peer; " - f"got {[str(p) for p in peers]!r}" - ) + with pytest.raises(LookupError): + await routines.read_unbound_default_less_var( + "synthetic_unbound_ns", "never_bound_key" + ) await retry_grpc_internal(body) 
@pytest.mark.integration -class TestWoolCopyContextWithDispatch: +class TestMultiWorkerFanOutWithContext: @pytest.mark.asyncio - async def test_wool_copy_context_seeded_value_propagates_under_attach( + async def test_same_value_should_fan_out_to_distinct_workers_independently( self, credentials_map, retry_grpc_internal ): - """Test wool.copy_context.run sets a value the dispatch ships. + """Test one armed caller fans the same value to distinct workers. Given: - A caller that calls ``wool.copy_context()`` to fork the - current wool.Context, runs a setter inside the forked - Context to seed a TENANT_ID value, and dispatches the - routine while ``attached`` to the forked Context + An armed caller that set ``TENANT_ID`` once and an + EPHEMERAL pool with two worker processes When: - The dispatched routine reads TENANT_ID + The caller fans out two concurrent dispatches that each + mutate their worker-side copy of the var Then: - The routine should return the seeded value, not the - outer caller's value, proving the forked copy is the - active source of truth and ships its bindings to the - worker independently of the implicit current Context + Both dispatches should observe the caller's value as + their starting point and each should report its own + worker-side mutation — each worker mounts the caller's + context independently with no shared mutable state. 
""" # Arrange, act, & assert async def body(): - scenario = default_scenario() + scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) async with build_pool_from_scenario(scenario, credentials_map): - outer_token = routines.TENANT_ID.set("outer-original") + token = routines.TENANT_ID.set("fan-out-seed") try: - forked = wool.copy_context() - forked.run(lambda: routines.TENANT_ID.set("forked-seed")) - - with attached(forked): - result = await routines.get_tenant_id() + seen = await asyncio.gather( + routines.get_tenant_id(), + routines.get_tenant_id(), + ) + mutated = await asyncio.gather( + routines.mutate_and_read_tenant_id(), + routines.mutate_and_read_tenant_id(), + ) finally: - routines.TENANT_ID.reset(outer_token) - assert result == "forked-seed" + routines.TENANT_ID.reset(token) + # Both workers mounted the caller's seed independently. + assert seen == ["fan-out-seed", "fan-out-seed"] + assert mutated == ["mutated_on_worker", "mutated_on_worker"] await retry_grpc_internal(body) + +@pytest.mark.integration +class TestChainContentionAcrossDispatch: @pytest.mark.asyncio - async def test_wool_copy_context_has_distinct_id_from_source( + async def test_off_owner_thread_var_access_should_raise_chain_contention( self, credentials_map, retry_grpc_internal ): - """Test wool.copy_context produces a fresh logical chain id. + """Test a worker re-entering its armed chain off-thread raises. 
Given: - A caller that captures ``wool.current_context().id`` then - calls ``wool.copy_context()`` and dispatches a routine - that returns the worker-side ``current_context().id.hex`` - inside the forked Context + An EPHEMERAL pool running a routine that sets a + ``wool.ContextVar`` — arming its chain on the worker loop + thread — then reads the same var from a worker thread via + ``asyncio.to_thread``, which copies the armed chain into + the executor thread When: - The dispatch runs under ``with attached(forked):`` + The caller dispatches the routine Then: - The worker-observed id should equal the forked Context's - id and differ from the outer caller's captured id — - ``copy_context`` mints a fresh chain id rather than - reusing the source's + The off-owner-thread ``get()`` should raise + ``wool.ChainContention`` and that exception should + surface to the caller's ``await`` through the exception + back-propagation path. """ # Arrange, act, & assert async def body(): - scenario = default_scenario() + scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) async with build_pool_from_scenario(scenario, credentials_map): - outer_id = wool.current_context().id - forked = wool.copy_context() - - with attached(forked): - observed_hex = await routines.return_current_context_id_hex() - assert observed_hex == forked.id.hex - assert forked.id != outer_id + with pytest.raises(wool.ChainContention): + await routines.reenter_armed_chain_off_owner_thread( + "armed-on-loop-thread" + ) await retry_grpc_internal(body) - -@pytest.mark.integration -class TestNestedDispatchMidChainMutation: @pytest.mark.asyncio - async def test_outer_worker_mid_routine_mutation_reaches_nested_inner_worker( + async def test_wool_to_thread_should_fork_armed_chain_off_owner_thread( self, credentials_map, retry_grpc_internal ): - """Test outer routine mutation propagates to a nested dispatch. + """Test wool.to_thread forks a worker's armed chain off-thread cleanly. 
Given: - A caller that sets TENANT_ID to "alpha", an EPHEMERAL pool - sized to permit two distinct workers, and an outer routine - that mutates TENANT_ID to "beta" before dispatching - ``get_tenant_id`` to a nested worker + An EPHEMERAL pool running a routine that sets a + ``wool.ContextVar`` — arming its chain on the worker loop + thread — then reads the same var from a worker thread via + ``wool.to_thread``, the supported context-propagating + offload that forks the chain onto a fresh, detached chain + owned by the worker thread. When: - The caller dispatches the outer routine + The caller dispatches the routine. Then: - The outer routine should return "beta" — the inner worker - observed the outer's mid-routine mutation, not the - caller's pre-dispatch value — and the caller should - observe "beta" after the dispatch returns, completing the - bidirectional propagation chain across two worker hops + The off-thread ``get()`` should return the value the + routine set — the fork copies the caller's bindings under a + new chain UUID owned by the worker thread, so the read + re-arms cleanly rather than tripping + ``wool.ChainContention``. 
""" # Arrange, act, & assert async def body(): - scenario = default_scenario( - shape=RoutineShape.NESTED_COROUTINE, - pool_mode=PoolMode.EPHEMERAL, - ) + scenario = default_scenario(pool_mode=PoolMode.EPHEMERAL) async with build_pool_from_scenario(scenario, credentials_map): - token = routines.TENANT_ID.set("alpha") - try: - inner_observed = await routines.mutate_then_nested_get_tenant_id( - "beta" - ) - caller_after = routines.TENANT_ID.get() - finally: - routines.TENANT_ID.reset(token) - assert inner_observed == "beta" - assert caller_after == "beta" + result = await routines.read_var_off_thread_via_wool_to_thread( + "forked-off-thread" + ) + assert result == "forked-off-thread" await retry_grpc_internal(body) - -def _tenant_aware_backpressure_hook(ctx): - """Module-level (picklable) sync hook that rejects when TENANT_ID == "reject-me". - - Reads ``routines.TENANT_ID`` to verify the caller's wire-shipped - wool.ContextVar snapshot has been applied to the worker's - handler context before the hook runs (per the dispatch - handler's documented ordering). - """ - return routines.TENANT_ID.get() == "reject-me" - - -@pytest.mark.integration -class TestBackpressureReadsCallerShippedContextVar: @pytest.mark.asyncio - async def test_backpressure_hook_observes_caller_tenant_var( - self, retry_grpc_internal + async def test_interleaved_async_gen_dispatches_should_share_caller_chain( + self, credentials_map, retry_grpc_internal ): - """Test backpressure hook reads the caller's wool.ContextVar value. + """Test interleaving two async-generator dispatches on one chain. Given: - A single-worker pool whose backpressure hook returns True - (reject) when ``routines.TENANT_ID.get() == "reject-me"``, - and accepts otherwise. + An armed caller whose context carries a single chain id and + an EPHEMERAL pool, with two concurrently-open async-generator + dispatches of a routine that yields the worker-side chain id. 
When: - Two coroutine dispatches run sequentially under different - caller-side TENANT_ID values: first "reject-me", then - "ok". + The caller advances the two generators in strict alternation + with ``anext`` from its own single task, never advancing + both at once. Then: - The first dispatch should raise NoWorkersAvailable - (RESOURCE_EXHAUSTED from the hook), and the second should - succeed — proving the hook observes the caller's wire- - shipped TENANT_ID, not a stale or default value. + It should drive both to exhaustion without raising + wool.ChainContention, and every yielded value should + equal the caller's chain id — serialized interleaving never + runs the shared chain from two runners at once. """ - from functools import partial - from wool.runtime.loadbalancer.base import NoWorkersAvailable - from wool.runtime.loadbalancer.roundrobin import RoundRobinLoadBalancer - from wool.runtime.worker.local import LocalWorker - from wool.runtime.worker.pool import WorkerPool - - # Arrange, act, & assert async def body(): - pool = WorkerPool( - size=1, - loadbalancer=RoundRobinLoadBalancer, - worker=partial( - LocalWorker, backpressure=_tenant_aware_backpressure_hook - ), + scenario = default_scenario( + shape=RoutineShape.ASYNC_GEN_ANEXT, + pool_mode=PoolMode.EPHEMERAL, ) - - async with pool: - reject_token = routines.TENANT_ID.set("reject-me") + async with build_pool_from_scenario(scenario, credentials_map): + # Arrange + routines.TENANT_ID.set("armed") + caller = wool.__chain__.get(None) + assert caller is not None + a = routines.stream_chain_id_hex(3) + b = routines.stream_chain_id_hex(3) + + # Act + collected: list[str] = [] try: - with pytest.raises(NoWorkersAvailable): - await routines.add(1, 2) + for _ in range(3): + collected.append(await anext(a)) + collected.append(await anext(b)) finally: - routines.TENANT_ID.reset(reject_token) + await a.aclose() + await b.aclose() - accept_token = routines.TENANT_ID.set("ok") - try: - result = await routines.add(1, 2) 
- finally: - routines.TENANT_ID.reset(accept_token) - assert result == 3 + # Assert + assert collected == [caller.id.hex] * 6 await retry_grpc_internal(body) diff --git a/wool/tests/integration/test_pool_composition.py b/wool/tests/integration/test_pool_composition.py index b2099549..05a58df1 100644 --- a/wool/tests/integration/test_pool_composition.py +++ b/wool/tests/integration/test_pool_composition.py @@ -5,6 +5,7 @@ import pytest +import wool from wool.runtime.discovery.local import LocalDiscovery from wool.runtime.loadbalancer.base import NoWorkersAvailable from wool.runtime.loadbalancer.roundrobin import RoundRobinLoadBalancer @@ -33,7 +34,7 @@ @pytest.mark.integration class TestPoolComposition: @pytest.mark.asyncio - async def test_build_pool_from_scenario_with_default_mode( + async def test_build_pool_from_scenario_should_return_result_when_default_mode( self, credentials_map, retry_grpc_internal ): """Test building a pool with DEFAULT mode. @@ -76,7 +77,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_build_pool_from_scenario_with_ephemeral_mode( + async def test_build_pool_from_scenario_should_return_result_when_ephemeral_mode( self, credentials_map, retry_grpc_internal ): """Test building a pool with EPHEMERAL mode and size=2. @@ -119,7 +120,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_build_pool_from_scenario_with_durable_mode( + async def test_build_pool_from_scenario_should_return_result_when_durable_mode( self, credentials_map, retry_grpc_internal ): """Test building a pool with DURABLE mode. @@ -162,7 +163,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_build_pool_from_scenario_with_hybrid_mode( + async def test_build_pool_from_scenario_should_return_result_when_hybrid_mode( self, credentials_map, retry_grpc_internal ): """Test building a pool with HYBRID mode and LOCAL_CALLABLE discovery. 
@@ -205,7 +206,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_build_pool_from_scenario_with_durable_joined_local( + async def test_build_pool_from_scenario_should_return_result_when_durable_joined( self, credentials_map, retry_grpc_internal ): """Test building a pool with DURABLE_JOINED mode and LOCAL_CALLABLE. @@ -250,7 +251,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_build_pool_from_scenario_with_restrictive_opts( + async def test_build_pool_from_scenario_should_return_result_when_restrictive_opts( self, credentials_map, retry_grpc_internal ): """Test building a pool with restrictive message size options. @@ -293,7 +294,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_build_pool_from_scenario_with_keepalive_opts( + async def test_build_pool_from_scenario_should_return_result_when_keepalive_opts( self, credentials_map, retry_grpc_internal ): """Test building a pool with keepalive worker options. @@ -337,18 +338,24 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_build_pool_from_scenario_with_dispatch_timeout( + async def test_build_pool_from_scenario_should_propagate_dispatch_timeout_to_worker( self, credentials_map, retry_grpc_internal ): - """Test building a pool with dispatch_timeout set in the context. + """Test the VIA_DISPATCH_TIMEOUT_VAR dimension propagates to the worker. Given: - A complete scenario using VIA_DISPATCH_TIMEOUT_VAR timeout. + A complete scenario using VIA_DISPATCH_TIMEOUT_VAR timeout, + which makes ``build_pool_from_scenario`` set the ambient + ``dispatch_timeout`` var before the dispatch. When: - A pool is built with dispatch_timeout set in the ambient - context and a coroutine is dispatched. + A pool is built and a coroutine is dispatched, then a + routine that returns the worker-side ``dispatch_timeout`` + value is dispatched. 
Then: - It should return the correct result with the timeout active. + The first dispatch should return its result and the + second should report the builder's ``dispatch_timeout`` + value — proving the dimension's ambient var rides the wire + and is restored on the worker. """ async def body(): @@ -373,14 +380,72 @@ async def body(): # Act async with build_pool_from_scenario(scenario, credentials_map): result = await invoke_routine(scenario) + worker_timeout = await routines.read_dispatch_timeout() # Assert assert result == 3 + # ``build_pool_from_scenario`` sets dispatch_timeout=30.0 + # for the VIA_DISPATCH_TIMEOUT_VAR dimension; the worker + # must observe that value through the wire-shipped + # RuntimeContext. + assert worker_timeout == 30.0 await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_build_pool_from_scenario_with_shared_discovery( + async def test_runtime_context_manager_should_propagate_dispatch_timeout( + self, credentials_map, retry_grpc_internal + ): + """Test a caller-side wool.RuntimeContext rides the dispatch wire. + + Given: + A default-timeout scenario pool and a caller that wraps a + dispatch in ``with wool.RuntimeContext(dispatch_timeout=X)``. + When: + A routine reading the worker-side ``dispatch_timeout`` is + dispatched inside the block. + Then: + The worker should observe ``X`` — ``RuntimeContext.__enter__`` + sets the ambient timeout, which rides the wire — and the + ambient ``dispatch_timeout`` should be restored to its prior + value once the block exits (``__exit__``). 
+ """ + from wool.runtime.context.runtime import dispatch_timeout + + async def body(): + # Arrange + scenario = Scenario( + shape=RoutineShape.COROUTINE, + pool_mode=PoolMode.DEFAULT, + discovery=DiscoveryFactory.NONE, + lb=LbFactory.CLASS_REF, + credential=CredentialType.INSECURE, + options=WorkerOptionsKind.DEFAULT, + timeout=TimeoutKind.NONE, + binding=RoutineBinding.MODULE_FUNCTION, + lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, + ctx_var_1=ContextVarPattern.NONE, + ctx_var_2=ContextVarPattern.NONE, + ctx_var_3=ContextVarPattern.NONE, + quorum=QuorumMode.DEFAULT, + ) + before = dispatch_timeout.get() + + # Act + async with build_pool_from_scenario(scenario, credentials_map): + with wool.RuntimeContext(dispatch_timeout=7.5): + worker_timeout = await routines.read_dispatch_timeout() + restored = dispatch_timeout.get() + + # Assert + assert worker_timeout == 7.5 + assert restored == before + + await retry_grpc_internal(body) + + @pytest.mark.asyncio + async def test_build_pool_from_scenario_should_return_result_when_shared_discovery( self, credentials_map, retry_grpc_internal ): """Test two pools sharing the same discovery subscriber. @@ -426,7 +491,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_build_pool_from_scenario_with_sync_backpressure( + async def test_build_pool_from_scenario_should_return_result_when_sync_backpressure( self, credentials_map, retry_grpc_internal ): """Test building a pool with a sync backpressure accept hook. @@ -470,7 +535,7 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_build_pool_from_scenario_with_async_backpressure( + async def test_build_pool_from_scenario_should_return_result_when_async_backpressure( self, credentials_map, retry_grpc_internal ): """Test building a pool with an async backpressure accept hook. 
@@ -527,7 +592,9 @@ async def _async_reject_hook(ctx): @pytest.mark.integration class TestBackpressureRejection: @pytest.mark.asyncio - async def test_sync_backpressure_rejection(self, retry_grpc_internal): + async def test_sync_backpressure_should_raise_no_workers_available( + self, retry_grpc_internal + ): """Test sync backpressure hook rejects task end-to-end. Given: @@ -556,7 +623,9 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_async_backpressure_rejection(self, retry_grpc_internal): + async def test_async_backpressure_should_raise_no_workers_available( + self, retry_grpc_internal + ): """Test async backpressure hook rejects task end-to-end. Given: @@ -585,7 +654,9 @@ async def body(): await retry_grpc_internal(body) @pytest.mark.asyncio - async def test_backpressure_fallback_to_accepting_worker(self, retry_grpc_internal): + async def test_backpressure_should_fall_through_to_accepting_worker( + self, retry_grpc_internal + ): """Test load balancer falls through to an accepting worker. Given: diff --git a/wool/tests/integration/test_unified_driver.py b/wool/tests/integration/test_unified_driver.py index fc0ea5ba..63288920 100644 --- a/wool/tests/integration/test_unified_driver.py +++ b/wool/tests/integration/test_unified_driver.py @@ -38,7 +38,7 @@ class TestUnifiedDriverShape: async def test_coroutine_dispatch_completes_across_process_boundary( self, credentials_map, retry_grpc_internal ): - """Test a single coroutine dispatch round-trips across the gRPC + worker-process boundary. + """Test a single coroutine dispatch round-trips across the worker boundary. 
Given: A coroutine routine with no caller-side wool.ContextVar @@ -69,7 +69,7 @@ async def body(): # Assert assert result == 5 # The routine never set TENANT_ID/REGION on the worker, so - # the caller's snapshot must equal its pre-dispatch value + # the caller's context must equal its pre-dispatch value # — the back-propagation path is silent when the worker # made no mutations. assert routines.TENANT_ID.get() == tenant_before @@ -81,7 +81,7 @@ async def body(): async def test_async_gen_dispatch_exhausts_after_single_yield( self, credentials_map, retry_grpc_internal ): - """Test an async-generator that yields exactly once terminates with StopAsyncIteration. + """Test an async-generator yielding once terminates with StopAsyncIteration. Given: An async-generator routine that yields a single value and @@ -385,13 +385,13 @@ async def test_dispatch_completes_under_strict_mode_with_decodable_vars( A coroutine routine with caller-side wool.ContextVars set to fully-decodable values and a DEFAULT pool, with ``warnings.simplefilter("error", - category=wool.ContextDecodeWarning)`` active for the + category=wool.SerializationWarning)`` active for the dispatch. When: The caller dispatches the routine. Then: The dispatch should complete without raising and without - emitting a single ``wool.ContextDecodeWarning``, since + emitting a single ``wool.SerializationWarning``, since every shipped var is decodable on the worker. 
""" @@ -410,7 +410,7 @@ async def body(): try: with warnings.catch_warnings(record=True) as captured: warnings.simplefilter( - "error", category=wool.ContextDecodeWarning + "error", category=wool.SerializationWarning ) result = await routines.get_tenant_id() finally: @@ -420,10 +420,10 @@ async def body(): # Assert assert result == "strict-tenant" decode_warnings = [ - w for w in captured if issubclass(w.category, wool.ContextDecodeWarning) + w for w in captured if issubclass(w.category, wool.SerializationWarning) ] assert decode_warnings == [], ( - f"Expected no ContextDecodeWarning under strict mode " + f"Expected no SerializationWarning under strict mode " f"with decodable vars, got {decode_warnings!r}" ) @@ -444,7 +444,7 @@ async def test_nested_async_gen_dispatch_keeps_context_live_across_yields( Then: Each yielded value should equal the per-step value the outer worker set, proving ``_current_task`` and - ``wool.Context`` stay live across the streaming routine's + ``wool.Chain`` stay live across the streaming routine's lifespan and every nested dispatch finds the streaming task as the caller. """ diff --git a/wool/tests/protocol/test_wire.py b/wool/tests/protocol/test_wire.py index da5c8b4a..0784119e 100644 --- a/wool/tests/protocol/test_wire.py +++ b/wool/tests/protocol/test_wire.py @@ -4,7 +4,7 @@ EXPECTED_MESSAGE_EXPORTS = [ "Ack", - "Context", + "ChainManifest", "Message", "Nack", "Request", @@ -91,27 +91,29 @@ def test_task_fields(self): assert task.kwargs == b"kwargs-bytes" assert task.timeout == 30 - def test_context_fields(self): - """Test Context message field round-trip. + def test_chain_manifest_fields(self): + """Test ChainManifest message field round-trip. Given: An id hex string and a list of ContextVar entries. When: - A Context message is constructed. + A ChainManifest message is constructed. Then: Both fields round-trip correctly and each ContextVar - entry preserves its namespace, name, value, and - consumed_tokens. 
+ entry preserves its namespace, name, and value. The + previous ``consumed_tokens`` repeated field is removed + from the wire schema (cross-process token transport is + deferred to a separate Wool Token wrapper ride; see + issue #231). """ # Arrange, act, & assert - ctx = protocol.Context( + ctx = protocol.ChainManifest( id="abc", vars=[ protocol.ContextVar( namespace="ns", name="key", value=b"value", - consumed_tokens=["abc123"], ) ], ) @@ -120,7 +122,6 @@ def test_context_fields(self): entry = ctx.vars[0] assert (entry.namespace, entry.name) == ("ns", "key") assert entry.value == b"value" - assert list(entry.consumed_tokens) == ["abc123"] def test_runtime_context_fields_with_dispatch_timeout(self): """Test RuntimeContext exposes dispatch_timeout when supplied. @@ -400,6 +401,6 @@ def test_protobuf_import_error_class_exists(self): Then: It should be a subclass of ImportError. """ - from wool.protocol.exception import ProtobufImportError + from wool.protocol.exceptions import ProtobufImportError assert issubclass(ProtobufImportError, ImportError) diff --git a/wool/tests/runtime/context/conftest.py b/wool/tests/runtime/context/conftest.py index 1888c2d7..5bc9ac76 100644 --- a/wool/tests/runtime/context/conftest.py +++ b/wool/tests/runtime/context/conftest.py @@ -5,6 +5,11 @@ @pytest.fixture(autouse=True) def isolated_context(): - """Install a fresh wool.Context for the duration of each test.""" + """Run each test under a fresh, unarmed Wool context. + + Resets the wool-owned context ``contextvars.ContextVar`` so a + :meth:`wool.ContextVar.set` in one test does not leak its armed + context into the next. 
+ """ with scoped_context(): yield diff --git a/wool/tests/runtime/context/test_base.py b/wool/tests/runtime/context/test_base.py deleted file mode 100644 index 4e16afa0..00000000 --- a/wool/tests/runtime/context/test_base.py +++ /dev/null @@ -1,2965 +0,0 @@ -import asyncio -import gc -import uuid -from types import SimpleNamespace -from typing import Any -from typing import cast - -import cloudpickle -import pytest - -import wool -from tests.helpers import scoped_context -from wool.runtime.context import Context -from wool.runtime.context import ContextVar -from wool.runtime.context import RuntimeContext -from wool.runtime.context import Token -from wool.runtime.context import attached -from wool.runtime.context import copy_context -from wool.runtime.context import current_context -from wool.runtime.context import dispatch_timeout -from wool.runtime.serializer import Serializer - -dumps = wool.__serializer__.dumps -loads = cloudpickle.loads - - -class TestRuntimeContext: - def test___init___with_default_sentinel(self): - """Test RuntimeContext entered with the default sentinel skips - installing the stdlib dispatch_timeout value. - - Given: - A RuntimeContext constructed with no dispatch_timeout - argument and a stdlib dispatch_timeout value already - installed in the current scope - When: - The RuntimeContext is entered as a context manager and the - stdlib var is read inside the block - Then: - It should observe the prior value unchanged — the - "no-override" usage leaves the live var alone - """ - # Arrange - prior_token = dispatch_timeout.set(7.5) - try: - # Act - with RuntimeContext(): - observed = dispatch_timeout.get() - - # Assert - assert observed == 7.5 - finally: - dispatch_timeout.reset(prior_token) - - def test___enter___with_explicit_value(self): - """Test RuntimeContext entered with an explicit value overrides - and restores the stdlib dispatch_timeout. 
- - Given: - A RuntimeContext constructed with ``dispatch_timeout=2.5`` - and a different prior value installed in the current scope - When: - The RuntimeContext is entered as a context manager, the - stdlib var is read inside the block, then the block exits - and the var is read again - Then: - The in-block read should return ``2.5`` and the post-exit - read should return the prior value — __enter__/__exit__ - performs a scoped override - """ - # Arrange - prior_token = dispatch_timeout.set(1.0) - try: - inside: list[float | None] = [] - - # Act - with RuntimeContext(dispatch_timeout=2.5): - inside.append(dispatch_timeout.get()) - after = dispatch_timeout.get() - - # Assert - assert inside == [2.5] - assert after == 1.0 - finally: - dispatch_timeout.reset(prior_token) - - def test___enter___with_explicit_none(self): - """Test RuntimeContext entered with explicit ``None`` overrides - the stdlib dispatch_timeout to ``None`` (distinct from the - Undefined skip path). - - Given: - A RuntimeContext constructed with ``dispatch_timeout=None`` - and a non-None prior value installed in the current scope - When: - The RuntimeContext is entered as a context manager and the - stdlib var is read inside the block - Then: - It should observe ``None`` inside — explicit None is a - real override that __enter__ applies, separate from the - default Undefined sentinel that skips - """ - # Arrange - prior_token = dispatch_timeout.set(4.0) - try: - inside: list[float | None] = [] - - # Act - with RuntimeContext(dispatch_timeout=None): - inside.append(dispatch_timeout.get()) - - # Assert - assert inside == [None] - finally: - dispatch_timeout.reset(prior_token) - - def test_get_current_with_live_value(self): - """Test RuntimeContext.get_current snapshots the live stdlib - dispatch_timeout value. 
- - Given: - A stdlib ``dispatch_timeout`` set to ``4.0`` in the current - scope - When: - ``RuntimeContext.get_current`` is called - Then: - The returned RuntimeContext's ``to_protobuf`` should emit - ``dispatch_timeout=4.0`` — get_current captures the live - value rather than carrying the Undefined sentinel forward - """ - # Arrange - prior_token = dispatch_timeout.set(4.0) - try: - # Act - captured = RuntimeContext.get_current() - wire = captured.to_protobuf() - - # Assert - assert wire.HasField("dispatch_timeout") is True - assert wire.dispatch_timeout == 4.0 - finally: - dispatch_timeout.reset(prior_token) - - def test_from_protobuf_with_dispatch_timeout_set(self): - """Test RuntimeContext.from_protobuf decodes the dispatch_timeout - wire value. - - Given: - A ``protocol.RuntimeContext`` message with - ``dispatch_timeout=12.5`` - When: - ``RuntimeContext.from_protobuf`` decodes it and the result - is entered as a context manager - Then: - The stdlib var inside the block should observe ``12.5`` - """ - # Arrange - from wool import protocol - - wire = protocol.RuntimeContext(dispatch_timeout=12.5) - - # Act - decoded = RuntimeContext.from_protobuf(wire) - observed: list[float | None] = [] - with decoded: - observed.append(dispatch_timeout.get()) - - # Assert - assert observed == [12.5] - - def test_from_protobuf_without_dispatch_timeout(self): - """Test RuntimeContext.from_protobuf treats an absent wire - dispatch_timeout as an explicit ``None`` override. 
- - Given: - A ``protocol.RuntimeContext`` message with no - ``dispatch_timeout`` field set - When: - ``RuntimeContext.from_protobuf`` decodes it and the result - is re-emitted via ``to_protobuf`` - Then: - The re-emitted wire message should have - ``HasField('dispatch_timeout')`` False — the absent wire - field decodes to explicit ``None``, which on re-emission - skips the field rather than substituting the live scope - value - """ - # Arrange - from wool import protocol - - wire = protocol.RuntimeContext() - - # Act - decoded = RuntimeContext.from_protobuf(wire) - re_emitted = decoded.to_protobuf() - - # Assert - assert re_emitted.HasField("dispatch_timeout") is False - - def test_to_protobuf_with_explicit_value(self): - """Test RuntimeContext.to_protobuf emits an explicitly-set - dispatch_timeout value. - - Given: - A RuntimeContext constructed with ``dispatch_timeout=3.0`` - When: - ``to_protobuf`` is called - Then: - The returned message should have - ``HasField('dispatch_timeout')`` True and - ``dispatch_timeout == 3.0`` - """ - # Arrange - ctx = RuntimeContext(dispatch_timeout=3.0) - - # Act - wire = ctx.to_protobuf() - - # Assert - assert wire.HasField("dispatch_timeout") is True - assert wire.dispatch_timeout == 3.0 - - def test_to_protobuf_with_explicit_none(self): - """Test RuntimeContext.to_protobuf skips emission when the value - is explicit ``None``. 
- - Given: - A RuntimeContext constructed with ``dispatch_timeout=None`` - When: - ``to_protobuf`` is called - Then: - The returned message should have - ``HasField('dispatch_timeout')`` False — explicit None - means "no timeout", which the wire shape encodes as - absence so the receiver inherits its own scope's default - """ - # Arrange - ctx = RuntimeContext(dispatch_timeout=None) - - # Act - wire = ctx.to_protobuf() - - # Assert - assert wire.HasField("dispatch_timeout") is False - - def test_to_protobuf_with_default_sentinel_substitutes_live_value(self): - """Test RuntimeContext.to_protobuf substitutes the live scope - value when constructed with the default Undefined sentinel. - - Given: - A RuntimeContext constructed with no dispatch_timeout - argument (Undefined sentinel) and a stdlib - ``dispatch_timeout`` value of ``9.5`` set in the current - scope - When: - ``to_protobuf`` is called - Then: - The returned message should emit ``dispatch_timeout=9.5`` - — the encode-time live-value substitution lets a bare - ``RuntimeContext()`` ride the wire with the encoder's - effective timeout - """ - # Arrange - prior_token = dispatch_timeout.set(9.5) - try: - ctx = RuntimeContext() - - # Act - wire = ctx.to_protobuf() - - # Assert - assert wire.HasField("dispatch_timeout") is True - assert wire.dispatch_timeout == 9.5 - finally: - dispatch_timeout.reset(prior_token) - - def test_to_protobuf_with_default_sentinel_and_no_live_value(self): - """Test RuntimeContext.to_protobuf skips emission when the - sentinel resolves to a live ``None``. 
- - Given: - A RuntimeContext constructed with no dispatch_timeout - argument and the stdlib ``dispatch_timeout`` at its - default (None) in the current scope - When: - ``to_protobuf`` is called - Then: - The returned message should have - ``HasField('dispatch_timeout')`` False — Undefined - substituted to None, which then skips emission, mirroring - the explicit-None branch - """ - # Arrange - ctx = RuntimeContext() - - # Act - wire = ctx.to_protobuf() - - # Assert - assert wire.HasField("dispatch_timeout") is False - - def test___exit___when_called_after_unentered_context(self): - """Test RuntimeContext.__exit__ is a no-op when no token was - captured. - - Given: - A RuntimeContext constructed with the default Undefined - sentinel and never entered (so no token was captured) - When: - ``__exit__`` is invoked directly on the instance - Then: - It should return without error and without mutating the - stdlib var — the no-token branch is reached - """ - # Arrange - prior_token = dispatch_timeout.set(5.0) - try: - ctx = RuntimeContext() - - # Act - ctx.__exit__(None, None, None) - - # Assert - assert dispatch_timeout.get() == 5.0 - finally: - dispatch_timeout.reset(prior_token) - - -class TestContext: - def test___new___with_direct_instantiation(self): - """Test Context() constructs an empty Context with a fresh id. - - Given: - The Context class - When: - It is instantiated directly - Then: - The result should have a fresh id and no captured vars - """ - # Act - ctx = Context() - - # Assert - assert ctx.id is not None - assert len(ctx) == 0 - - def test___bool___is_false_for_fresh_context(self): - """Test bool(Context()) is False when no state has been captured. 
- - Given: - A freshly constructed empty Context - When: - bool() is invoked on it - Then: - The result should be False so callers can use - ``if not ctx:`` as a fast-path gate - """ - # Act - ctx = Context() - - # Assert - assert bool(ctx) is False - - def test___bool___when_var_is_set(self): - """Test bool(ctx) is True once a ContextVar has been set in it. - - Given: - A Context with a var bound via ContextVar.set - When: - bool() is invoked on it - Then: - The result should be True - """ - # Arrange - var = ContextVar("bool_var_set", default="initial") - var.set("value") - ctx = current_context() - - # Act - result = bool(ctx) - - # Assert - assert result is True - - def test_has_state_is_true_when_wire_carried_consumed_tokens(self): - """Test ctx.has_state() is True when _data is empty but the - wire Context carried a non-empty consumed_tokens list. - - Given: - A wire protocol.Context with no var bindings but a - non-empty consumed_tokens list (modeling a back-prop - response that only carries used-token state) - When: - Context.from_protobuf reconstructs it and has_state() is - invoked on the result - Then: - The result should be True — the ids must be applied on - merge so the corresponding live tokens flip their _used - flags. ``bool()`` follows ``__len__`` (var bindings only) - so it remains False; ``has_state()`` is the wire-truthy - predicate - """ - # Arrange - from wool import protocol - - pb = protocol.Context(id=uuid.uuid4().hex) - pb.vars.add( - namespace="", - name="", - consumed_tokens=[uuid.uuid4().hex], - ) - - # Act - reconstructed = Context.from_protobuf(pb) - - # Assert - assert len(reconstructed) == 0 - assert bool(reconstructed) is False - assert reconstructed.has_state() is True - - def test_has_state_is_true_when_live_consumed_token_is_present(self): - """Test ctx.has_state() is True when _data is empty but a - live consumed Token populated by ``ContextVar.reset`` is - still reachable. 
- - Given: - A Context that ran a set+reset cycle on a ContextVar - and a strong reference to the consumed Token kept alive - outside the Context — the var binding cleared back to - MISSING but the Token sits in ``_used_tokens`` for - outbound wire emission - When: - has_state() is invoked on the Context - Then: - The result should be True — the Context still carries - wire-shippable state via ``Context.to_protobuf``'s - ``consumed_tokens`` field, so ``has_state()`` aligns - with what the wire emission produces. ``bool()`` follows - ``__len__`` (var bindings only) and remains False - """ - # Arrange - var = ContextVar(f"used_token_id_truthy_{uuid.uuid4().hex}") - ctx = Context() - captured: list[Token[Any]] = [] - - def consume(): - t = var.set("x") - var.reset(t) - captured.append(t) - - ctx.run(consume) - - # Act - result = ctx.has_state() - - # Assert - assert len(ctx) == 0 - assert bool(ctx) is False - assert result is True - - def test_consumed_tokens_are_bounded_by_live_token_count(self): - """Test the consumed-token tracking set does not grow - unboundedly across many set/reset cycles whose Tokens are - not retained. 
- - Given: - A Context that performs 1000 set/reset cycles on the - same ContextVar, with no strong reference kept to any - Token after its reset - When: - ``gc.collect`` runs and the Context's emitted - ``consumed_tokens`` is inspected - Then: - The emitted list has length 0 — the auto-pruning - ``weakref.WeakSet`` reclaimed all Tokens whose role - (double-reset detection) has nothing left to bind to, - so the chain forwards no stale UUIDs onward - """ - # Arrange - var = ContextVar(f"bounded_consumed_tokens_{uuid.uuid4().hex}") - ctx = Context() - - def churn(): - for _ in range(1000): - t = var.set("x") - var.reset(t) - - ctx.run(churn) - gc.collect() - - # Act - emitted = ctx.to_protobuf() - - # Assert - emitted_token_ids = [ - tid for entry in emitted.vars for tid in entry.consumed_tokens - ] - assert emitted_token_ids == [] - - def test___bool___when_consumed_token_is_collected(self): - """Test bool(ctx) is False after the only reference to a - locally-consumed Token is dropped. - - Given: - A Context that ran a set+reset cycle whose Token was - never propagated outward and whose only strong reference - (the function-local ``t``) has gone out of scope, so the - ``weakref.WeakSet`` tracking entry has been pruned by GC - When: - bool() is invoked on the Context - Then: - The result should be False — there is no live Token - anywhere whose double-reset detection could be triggered, - so the consumed-state record has nothing left to forward; - the new ownership-handoff design prefers reclamation over - indefinite UUID retention - """ - # Arrange - var = ContextVar(f"used_token_id_collected_{uuid.uuid4().hex}") - ctx = Context() - - def consume(): - t = var.set("x") - var.reset(t) - - ctx.run(consume) - gc.collect() - - # Act - result = bool(ctx) - - # Assert - assert len(ctx) == 0 - assert result is False - - def test_update_with_live_token_when_wire_carries_id(self): - """Test Context.update flips a live unused Token's used flag - when the incoming wire Context 
lists its id. - - Given: - A live Token whose used flag is False, and a temp - Context reconstructed from a wire message whose - consumed_tokens list contains that Token's id - When: - current_context().update(temp) is called - Then: - Token.used on the live Token flips to True — the merge - resolves the incoming id through the process-wide token - registry and applies cross-process consumption to the - local instance - """ - # Arrange - from wool import protocol - - var = ContextVar("update_flip_used", default="d") - token = var.set("x") - assert token.used is False - - pb = protocol.Context(id=uuid.uuid4().hex) - pb.vars.add( - namespace=var.namespace, - name=var.name, - consumed_tokens=[token.id.hex], - ) - incoming = Context.from_protobuf(pb) - - # Act - current_context().update(incoming) - - # Assert - assert token.used is True - - def test_update_with_incoming_id_and_no_live_token(self): - """Test Context.update ignores incoming ids with no live Token. - - Given: - A current Context and a temp Context reconstructed from - a wire message whose consumed_tokens list contains a - UUID that matches no live Token in the process - When: - current_context().update(temp) is called - Then: - The merge completes without raising — unregistered - incoming ids are silently dropped because no peer Token - object exists to flip - """ - # Arrange - from wool import protocol - - pb = protocol.Context(id=uuid.uuid4().hex) - pb.vars.add( - namespace="", - name="", - consumed_tokens=[uuid.uuid4().hex], - ) - incoming = Context.from_protobuf(pb) - - # Act & assert - current_context().update(incoming) - - def test_run_seeds_vars_and_scopes_mutations(self): - """Test Context.run seeds vars in a fresh stdlib Context and scopes mutations. 
- - Given: - A Context captured with an initial var value - When: - Context.run() runs a function that mutates the var - Then: - Mutations should be visible inside the run and captured on exit - """ - # Arrange - var = ContextVar("run_seed", default="initial") - var.set("seeded") - ctx = current_context() - - def body(): - assert var.get() == "seeded" - var.set("mutated") - return var.get() - - # Act - result = ctx.run(body) - - # Assert - assert result == "mutated" - assert ctx[var] == "mutated" - - def test_run_binds_context_for_sync_callers(self): - """Test Context.run makes self.id the active Context id inside fn. - - Given: - A Context constructed directly (sync caller, no asyncio task) - When: - Context.run invokes a function that reads current_context().id - Then: - The reported context id equals the Context's own id, not the - process-default id - """ - # Arrange - ctx = Context() - - # Act - observed = ctx.run(lambda: current_context().id) - - # Assert - assert observed == ctx.id - - def test_run_snapshot_with_unset_or_reset_vars(self): - """Test Context.run's snapshot excludes vars that look unset at exit. - - Given: - One var never set inside the run, one var that is set and - then reset inside the run - When: - Context.run() returns and captures the post-run snapshot - Then: - Neither var should appear in the Context's captured vars - """ - # Arrange - untouched = ContextVar("untouched", default="x") - set_then_reset = ContextVar("set_then_reset") - ctx = Context() - - def body(): - token = set_then_reset.set("temp") - set_then_reset.reset(token) - - # Act - ctx.run(body) - - # Assert - assert untouched not in ctx - assert set_then_reset not in ctx - - def test_run_with_re_entry_on_same_task(self): - """Test Context.run rejects a nested run() call on the same task. 
- - Given: - A Context already executing a synchronous run() body - When: - A second run() is attempted from inside the body, on - the same thread/task - Then: - The single-task guard rejects the inner call — - re-entry from the owning execution scope is treated the - same as concurrent entry from another scope. Cross-thread - concurrency is exercised separately by - ``test_attached_with_concurrent_entry``. - """ - # Arrange - ctx = Context() - - def outer(): - with pytest.raises(RuntimeError): - ctx.run(lambda: None) - - # Act & assert - ctx.run(outer) - - @pytest.mark.asyncio - async def test_attached_keeps_context_across_await(self): - """Test attached(ctx) keeps the supplied Context installed - across coroutine suspension points. - - Given: - A Context populated with a ContextVar binding and an - async function that suspends via ``await asyncio.sleep(0)`` - before reading the var - When: - The async function is awaited inside a ``with attached(ctx)`` - block - Then: - The post-suspension read should observe the Context's - value — the attach scope spans the entire coroutine - body, not just the synchronous frame that constructed - the coroutine - """ - # Arrange - var = ContextVar("attached_keeps_context", default="default") - ctx = Context() - ctx.run(lambda: var.set("inside")) - observed: list[str] = [] - - async def read_after_suspend(): - await asyncio.sleep(0) - observed.append(var.get()) - - # Act - with attached(ctx): - await read_after_suspend() - - # Assert - assert observed == ["inside"] - - @pytest.mark.asyncio - async def test_attached_with_concurrent_entry(self): - """Test attached(ctx) raises when another task is already - running inside the same supplied Context. 
- - Given: - A Context with one task suspended inside an - ``attached(ctx)`` block (holding the single-task guard - across an await) - When: - A concurrent task attempts ``attached(ctx)`` on the same - Context - Then: - It should raise RuntimeError — the single-task - invariant holds across the await window, not just the - synchronous portion - """ - # Arrange - ctx = Context() - first_entered = asyncio.Event() - release_first = asyncio.Event() - outcomes: list[str] = [] - - async def first(): - with attached(ctx): - first_entered.set() - await release_first.wait() - - async def second(): - try: - with attached(ctx): - pass - except RuntimeError: - outcomes.append("second-rejected") - - first_task = asyncio.create_task(first()) - await first_entered.wait() - - # Act - await second() - release_first.set() - await first_task - - # Assert - assert outcomes == ["second-rejected"] - - def test_attached_installs_and_restores_binding(self): - """Test attached(ctx) makes ctx the current Context and restores - the prior binding on exit. - - Given: - A Context that is not the active scope's Context - When: - ``with attached(ctx):`` is entered, current_context() - is read inside, and the block exits - Then: - current_context() inside the block is ctx; after exit, - current_context() resolves to the prior binding - """ - # Arrange - prior = current_context() - target = Context() - - # Act & assert - assert current_context() is prior - with attached(target): - assert current_context() is target - assert current_context() is prior - - def test_attached_unguarded_is_reentrant_against_running_context(self): - """Test attached(ctx, guarded=False) does not acquire the - single-task guard and so may be used while ctx is already - running a routine. 
- - Given: - A Context already executing a routine via ``Context.run`` - (holding the single-task guard) - When: - ``with attached(ctx, guarded=False):`` is entered against - the same ctx from inside the running routine - Then: - It does not raise — guarded=False opts out of the - single-task claim so deserialization scopes nested inside - a running routine remain valid - """ - # Arrange - ctx = Context() - observed: list[Context] = [] - - def inside_run(): - with attached(ctx, guarded=False): - observed.append(current_context()) - - # Act - ctx.run(inside_run) - - # Assert - assert observed == [ctx] - - def test_copy_with_populated_source(self): - """Test Context.copy returns a sibling Context with the same - var bindings but a fresh logical-chain id. - - Given: - A Context populated with one or more var bindings via - ``Context.run`` - When: - ``source.copy()`` is invoked on the populated Context, - and a second ``source.copy()`` is invoked alongside - Then: - Each copy holds the same var bindings as the source, - each copy's id differs from the source's id, and the - two copies' ids differ from each other — mirrors - ``contextvars.Context.copy`` semantics with wool's - chain-id contract that copies are new chains in the - tree, not aliases of the source - """ - # Arrange - var = ContextVar(f"copy_source_{uuid.uuid4().hex}") - source = Context() - source.run(lambda: var.set("seed-value")) - - # Act - sibling_a = source.copy() - sibling_b = source.copy() - - # Assert - assert sibling_a.id != source.id - assert sibling_b.id != source.id - assert sibling_a.id != sibling_b.id - assert sibling_a[var] == "seed-value" - assert sibling_b[var] == "seed-value" - - def test_iter_yields_captured_vars(self): - """Test Context iterates over captured ContextVar instances. 
- - Given: - A Context with multiple captured vars - When: - It is iterated - Then: - The iterator should yield each captured var - """ - # Arrange - a = ContextVar("iter_a", default=0) - b = ContextVar("iter_b", default=0) - a.set(1) - b.set(2) - - # Act - ctx = current_context() - - # Assert - assert set(iter(ctx)) == {a, b} - - def test_getitem_with_captured_value(self): - """Test Context[var] returns the captured value. - - Given: - A Context with a captured var - When: - The Context is indexed by the var - Then: - The captured value should be returned - """ - # Arrange - var = ContextVar("get_item", default=0) - var.set(1) - ctx = current_context() - - # Act & assert - assert ctx[var] == 1 - - def test_contains_reports_membership(self): - """Test `var in ctx` reports whether the var was captured. - - Given: - A Context with one captured var and one uncaptured var - When: - Membership is tested for each - Then: - The captured var should be present and the uncaptured absent - """ - # Arrange - in_ctx = ContextVar("present_var", default=0) - out_ctx = ContextVar("absent_var", default=0) - in_ctx.set(1) - - # Act - ctx = current_context() - - # Assert - assert in_ctx in ctx - assert out_ctx not in ctx - - def test_len_with_captured_vars(self): - """Test len(ctx) returns the number of captured vars. - - Given: - A Context with two captured vars - When: - len() is called on the Context - Then: - It should return 2 - """ - # Arrange - a = ContextVar("len_a", default=0) - b = ContextVar("len_b", default=0) - a.set(1) - b.set(2) - - # Act - ctx = current_context() - - # Assert - assert len(ctx) == 2 - - def test_keys_values_items_expose_captured_pairs(self): - """Test Context keys/values/items expose the captured mapping. 
- - Given: - A Context with two captured vars - When: - keys(), values(), items() are called - Then: - Each accessor should return the expected captured pairs - """ - # Arrange - a = ContextVar("kvitems_a", default="") - b = ContextVar("kvitems_b", default="") - a.set("x") - b.set("y") - - # Act - ctx = current_context() - - # Assert - assert set(ctx.keys()) == {a, b} - assert set(ctx.values()) == {"x", "y"} - assert dict(ctx.items()) == {a: "x", b: "y"} - - def test_get_with_set_value_or_default(self): - """Test Context.get returns the set value or the supplied default. - - Given: - A Context with one var set and another var that was never - set in this Context - When: - get(var) and get(var, default) are called - Then: - The set var returns its value; the unset var returns the - supplied default (or None if no default is given) - """ - # Arrange - set_var = ContextVar("get_set", default="class-default") - unset_var = ContextVar("get_unset", default="class-default") - set_var.set("value") - ctx = current_context() - - # Act & assert - assert ctx.get(set_var) == "value" - assert ctx.get(set_var, "fallback") == "value" - assert ctx.get(unset_var) is None - assert ctx.get(unset_var, "fallback") == "fallback" - - def test_repr_includes_id_and_var_count(self): - """Test Context repr mentions id and number of vars. - - Given: - A Context with one captured var - When: - repr() is called on it - Then: - The repr should contain "id=" and "vars=1" - """ - # Arrange - var = ContextVar("repr_var", default=0) - var.set(1) - ctx = current_context() - - # Act - text = repr(ctx) - - # Assert - assert "id=" in text - assert "vars=1" in text - - def test___reduce_ex___under_pickle_copy_and_deepcopy(self): - """Test wool.Context refuses pickle, copy.copy, and copy.deepcopy. - - Given: - A live wool.Context - When: - pickle.dumps, cloudpickle.dumps, wool.__serializer__.dumps, - copy.copy, and copy.deepcopy are each invoked on it - Then: - All five raise TypeError. 
Wool's own pickler rejects too — - Context has no __wool_reduce__ — so a snapshot disconnected - from live state cannot leak through any pickling path. - Callers must use Context.copy() explicitly for in-process - duplication; cross-process propagation rides - Context.to_protobuf and Context.from_protobuf instead. - """ - # Arrange - import copy as _copy - import pickle - - var = ContextVar("ctx_unpicklable", default="zero") - var.set("one") - ctx = current_context() - - # Act & assert - with pytest.raises(TypeError, match="wool.Context"): - pickle.dumps(ctx) - with pytest.raises(TypeError, match="wool.Context"): - cloudpickle.dumps(ctx) - with pytest.raises(TypeError, match="wool.Context"): - wool.__serializer__.dumps(ctx) - with pytest.raises(TypeError, match="wool.Context"): - _copy.copy(ctx) - with pytest.raises(TypeError, match="wool.Context"): - _copy.deepcopy(ctx) - - def test_to_protobuf_with_unpicklable_value(self): - """Test Context.to_protobuf emits a ContextDecodeWarning for an - unpicklable var and skips that entry. - - Given: - A ContextVar set to an unpicklable value (a local generator - function object) alongside a ContextVar set to a - serializable value - When: - current_context().to_protobuf() is called to snapshot the - current vars - Then: - A ContextDecodeWarning is emitted naming the offending var - key, the returned wire context contains only the - serializable var, and the snapshot does not preempt the - primary signal — mirroring from_protobuf's per-entry - decode resilience. 
- """ - # Arrange - from wool import ContextDecodeWarning - - good = ContextVar("ctx001_serializable", default="default") - bad = ContextVar("ctx001_unpicklable") - - def _local_gen(): - yield 1 - - good.set("kept") - bad.set(_local_gen()) - - # Act - with pytest.warns(ContextDecodeWarning, match=bad.name): - wire_ctx = current_context().to_protobuf() - - # Assert - emitted_keys = {(e.namespace, e.name) for e in wire_ctx.vars} - assert (good.namespace, good.name) in emitted_keys, ( - "Serializable var should ride the wire alongside the skipped entry" - ) - assert (bad.namespace, bad.name) not in emitted_keys, ( - "Unpicklable entry should be skipped, not partially encoded" - ) - - def test_to_protobuf_with_multiple_unpicklable_values_strict_mode(self): - """Test Context.to_protobuf aggregates strict-mode encode - failures into a single BaseExceptionGroup naming every - offending var. - - Given: - Two ContextVars each set to a value whose ``__reduce__`` - raises, with a strict warnings filter promoting - ContextDecodeWarning to an exception - When: - current_context().to_protobuf() is called to snapshot the - current vars - Then: - A BaseExceptionGroup is raised whose peers are - ContextDecodeWarning instances — one per offending var, - each naming the corresponding key — so callers learn - about every bad var on a single dispatch attempt. 
- """ - import warnings as _warnings - - from wool import ContextDecodeWarning - - # Arrange - first = ContextVar("ctx001_strict_a") - second = ContextVar("ctx001_strict_b") - - class _Unpicklable: - def __reduce__(self): - raise TypeError("synthetic unpicklable") - - first.set(_Unpicklable()) - second.set(_Unpicklable()) - - # Act & assert - with _warnings.catch_warnings(): - _warnings.simplefilter("error", category=ContextDecodeWarning) - with pytest.raises(BaseExceptionGroup) as exc_info: - current_context().to_protobuf() - - peers = exc_info.value.exceptions - assert all(isinstance(p, ContextDecodeWarning) for p in peers), ( - "Every peer should be a ContextDecodeWarning" - ) - peer_messages = [str(p) for p in peers] - assert any(first.name in msg for msg in peer_messages), ( - "First offending var should be named in a peer" - ) - assert any(second.name in msg for msg in peer_messages), ( - "Second offending var should be named in a peer" - ) - assert len(peers) == 2, "Group should hold one peer per offending var, no more" - - def test_to_protobuf_omits_entry_when_value_serialization_fails_with_consumed_tokens( - self, - ): - """Test Context.to_protobuf omits the wire entry entirely when a - var's value fails to serialize, even when the var carries - consumed tokens. - - Given: - A Context where a var was set, reset (producing a consumed - token held by a strong reference), and re-set to an - unserializable value — the var carries both a consumed - token in :attr:`_used_tokens` and a current binding in - :attr:`_data` whose value cannot be serialized. - When: - ``ctx.to_protobuf()`` is called. - Then: - A :class:`ContextDecodeWarning` is emitted, and no wire - entry is produced for the offending var — neither the - value nor the consumed-token list rides the wire. A - half-encoded entry (consumed tokens but no value) would - propagate a phantom reset to the receiver. 
- """ - # Arrange - from wool import ContextDecodeWarning - - var = ContextVar(f"halfencoded_{uuid.uuid4().hex}") - ctx = Context() - captured: list[Token[Any]] = [] - - def _local_gen(): - yield 1 - - def cycle() -> None: - captured.append(var.set("serializable")) - var.reset(captured[0]) - var.set(_local_gen()) - - ctx.run(cycle) - - # Act - with pytest.warns(ContextDecodeWarning, match=var.name): - emitted = ctx.to_protobuf() - - # Assert - emitted_keys = {(e.namespace, e.name) for e in emitted.vars} - assert (var.namespace, var.name) not in emitted_keys, ( - "A var whose value failed to serialize must be suppressed " - "entirely; emitting consumed tokens without a value would " - "propagate a phantom reset to the receiver" - ) - - def test_update_does_not_spuriously_reset_on_sender_serialization_failure(self): - """Test Context.update preserves the receiver's binding when - the sender's wire frame omits a var due to a serialization - failure on the sender. - - Given: - A receiver Context populated from a clean initial frame - binding ``var`` to ``"serializable"``; a sender Context - that subsequently ran a reset+re-set cycle re-setting the - var to an unserializable value, then produced a second - wire frame. - When: - ``receiver.update(Context.from_protobuf(second_frame))`` - is called. - Then: - The receiver still observes ``var == "serializable"`` — - the sender's serialization failure does not propagate a - phantom reset that would clobber the receiver's prior - binding. 
- """ - # Arrange - from wool import ContextDecodeWarning - - var = ContextVar(f"no_phantom_reset_{uuid.uuid4().hex}") - sender = Context() - captured: list[Token[Any]] = [] - - def _local_gen(): - yield 1 - - def setup() -> None: - captured.append(var.set("serializable")) - - sender.run(setup) - receiver = Context.from_protobuf(sender.to_protobuf()) - assert receiver[var] == "serializable", ( - "Sanity: initial frame should carry the set value" - ) - - def cycle() -> None: - var.reset(captured[0]) - var.set(_local_gen()) - - sender.run(cycle) - - with pytest.warns(ContextDecodeWarning, match=var.name): - second_frame = sender.to_protobuf() - - # Act - receiver.update(Context.from_protobuf(second_frame)) - - # Assert - assert var in receiver, ( - "A serialization failure on the sender must not propagate " - "a phantom reset to the receiver" - ) - assert receiver[var] == "serializable", ( - "The receiver's prior binding must survive the sender's failed re-set" - ) - - def test_update_with_empty_context_is_noop(self): - """Test update applied with an empty peer leaves state unchanged. - - Given: - A Context with a var set and an empty peer Context - When: - current.update(empty) is called - Then: - The current context is unchanged and no exception is - raised. - """ - # Arrange - from wool import protocol - - var = ContextVar("ctx002_seed", default="default") - var.set("before") - before = current_context() - empty = Context.from_protobuf(protocol.Context()) - - # Act - before.update(empty) - - # Assert - after = current_context() - assert after[var] == before[var] - assert set(after.keys()) == set(before.keys()) - - def test_update_propagates_caller_side_reset(self): - """Test Context.update unsets a var on the receiver when the - sender's wire frame consumed the corresponding token without - re-setting the var. 
- - Given: - A receiver Context with var bound to a value (modeling a - worker that received the caller's initial frame) and a - second wire frame from the caller showing the var was - reset between dispatches — the sender's vars map no - longer carries the var, but its consumed_tokens list - does - When: - receiver.update(Context.from_protobuf(reset_frame)) is - called - Then: - The receiver no longer sees the var — the reset signal - propagates through the merge so mid-stream resets reach - the worker, not just the initial-dispatch state - """ - # Arrange - var = ContextVar(f"reset_propagation_{uuid.uuid4().hex}") - sender = Context() - - def set_var() -> Token[str]: - return var.set("caller-value") - - token = sender.run(set_var) - receiver = Context.from_protobuf(sender.to_protobuf()) - assert receiver[var] == "caller-value", ( - "Sanity: initial frame should carry the set value" - ) - - def reset_var() -> None: - var.reset(token) - - sender.run(reset_var) - reset_frame = sender.to_protobuf() - - # Act - receiver.update(Context.from_protobuf(reset_frame)) - - # Assert - assert var not in receiver, ( - "Reset should propagate via update so the worker observes " - "the caller's post-reset state" - ) - - def test_update_preserves_re_set_after_reset(self): - """Test Context.update keeps a re-set value when the sender's - wire frame carries both the consumed token and a fresh value - for the same var. - - Given: - A sender Context that ran var.set('A'), var.reset(token), - then var.set('B') in sequence — its data carries the new - 'B' value and its used-token set carries the consumed - token. A receiver Context populated from the sender's - initial frame. - When: - receiver.update(Context.from_protobuf(sender.to_protobuf())) - is called with the post-re-set frame. 
- Then: - ``receiver[var] == 'B'`` — the consumed token's var key - matches an entry in the sender's vars map, so the merge - keeps the re-set value rather than treating reset - propagation as a blanket pop. - """ - # Arrange - var = ContextVar(f"reset_then_reset_{uuid.uuid4().hex}") - sender = Context() - - def set_a() -> Token[str]: - return var.set("A") - - token = sender.run(set_a) - receiver = Context.from_protobuf(sender.to_protobuf()) - assert receiver[var] == "A" - - def reset_then_set_b() -> None: - var.reset(token) - var.set("B") - - sender.run(reset_then_set_b) - - # Act - receiver.update(Context.from_protobuf(sender.to_protobuf())) - - # Assert - assert receiver[var] == "B", ( - "Re-set after reset should win over the consumed-token's " - "pop signal — the var key is in both the consumed list " - "and the vars map, so the vars value is authoritative" - ) - - def test_update_propagates_external_token_reset_through_transit_hop(self): - """Test Context.update unsets a var on the receiver when the - wire frame's consumed token arrives without a live token - instance — the transit-hop case. - - Given: - A receiver Context with var bound to ``"caller-value"``; - a wire ``protocol.Context`` whose ``vars`` list carries a - single ``protocol.ContextVar`` entry under the var's - ``(namespace, name)`` identity, with no current value but - with a consumed-token id whose UUID was never registered - in this process's token registry — modeling a frame - relayed through a hop that never reconstituted the live - Token instance. - When: - receiver.update(Context.from_protobuf(pb)) is called. - Then: - ``var not in receiver`` — the external-token entry's - ``(namespace, name)`` identity threads through - ``_external_used_tokens`` and propagates the reset signal - to the receiver even when the live Token instance is - absent. 
- """ - # Arrange - from wool import protocol - - var = ContextVar(f"transit_reset_{uuid.uuid4().hex}") - with scoped_context() as receiver: - var.set("caller-value") - assert receiver[var] == "caller-value" - - pb = protocol.Context() - pb.vars.add( - namespace=var.namespace, - name=var.name, - consumed_tokens=[uuid.uuid4().hex], - ) - - # Act - receiver.update(Context.from_protobuf(pb)) - - # Assert - assert var not in receiver, ( - "External-token consumed entry with var_key should " - "propagate the reset to the receiver even without a " - "live Token instance" - ) - - def test_to_protobuf_emits_consumed_token_under_var_entry(self): - """Test Context.to_protobuf emits each consumed token inside - the wire entry for the var that minted it. - - Given: - A Context that ran a set+reset cycle on a ContextVar with - a strong reference held to the resulting Token so the - ``weakref.WeakSet`` does not prune it before emission. - When: - ``ctx.to_protobuf()`` is called. - Then: - The wire frame contains one :class:`protocol.ContextVar` - entry whose ``namespace`` and ``name`` match the owning - var and whose ``consumed_tokens`` list carries the - token's id hex — the wire shape colocates the consumed - token with its var so transit hops can propagate the - reset signal. - """ - # Arrange - var = ContextVar(f"emit_var_key_{uuid.uuid4().hex}") - ctx = Context() - captured: list[Token[str]] = [] - - def consume() -> None: - t = var.set("x") - var.reset(t) - captured.append(t) - - ctx.run(consume) - - # Act - emitted = ctx.to_protobuf() - - # Assert - entries = list(emitted.vars) - assert len(entries) == 1 - assert (entries[0].namespace, entries[0].name) == (var.namespace, var.name) - assert list(entries[0].consumed_tokens) == [captured[0].id.hex] - - def test_to_protobuf_emits_value_and_consumed_tokens_in_one_entry(self): - """Test Context.to_protobuf collates a current value and a - consumed-token id under the same wire entry when both belong - to the same var. 
- - Given: - A Context that ran ``var.set("A")`` capturing the token, - then ``var.reset(token)``, then ``var.set("B")`` — the var - now carries a current value alongside a consumed-token - for the same identity, with a strong reference held to - the captured token. - When: - ``ctx.to_protobuf()`` is called. - Then: - Exactly one ``protocol.ContextVar`` entry is emitted; its - ``(namespace, name)`` matches the owning var, its - ``value`` field is set (``HasField('value') is True``), - and its ``consumed_tokens`` list contains the captured - token's id hex — confirming the merged wire shape - colocates set state and consumed-token state for one - var into a single entry. - """ - # Arrange - var = ContextVar(f"value_and_token_{uuid.uuid4().hex}") - ctx = Context() - captured: list[Token[str]] = [] - - def churn() -> None: - t = var.set("A") - var.reset(t) - captured.append(t) - var.set("B") - - ctx.run(churn) - - # Act - emitted = ctx.to_protobuf() - - # Assert - entries = list(emitted.vars) - assert len(entries) == 1 - entry = entries[0] - assert (entry.namespace, entry.name) == (var.namespace, var.name) - assert entry.HasField("value") - assert cloudpickle.loads(entry.value) == "B" - assert list(entry.consumed_tokens) == [captured[0].id.hex] - - def test_to_protobuf_folds_repeated_consumed_tokens_for_one_var(self): - """Test Context.to_protobuf groups multiple consumed tokens - for the same var into a single wire entry's - ``consumed_tokens`` list. - - Given: - A Context that ran two set+reset cycles against the same - var — producing two distinct consumed Tokens under the - same ``(namespace, name)`` — with strong references held - to both Tokens so the ``weakref.WeakSet`` does not prune - them before emission. - When: - ``ctx.to_protobuf()`` is called. 
- Then: - Exactly one ``protocol.ContextVar`` entry is emitted for - the var, and its ``consumed_tokens`` list contains both - token id hexes — the encode loop deduplicates by - ``(namespace, name)`` rather than emitting one entry per - consumed token. - """ - # Arrange - var = ContextVar(f"two_tokens_one_var_{uuid.uuid4().hex}") - ctx = Context() - captured: list[Token[str]] = [] - - def churn() -> None: - t1 = var.set("first") - var.reset(t1) - captured.append(t1) - t2 = var.set("second") - var.reset(t2) - captured.append(t2) - - ctx.run(churn) - - # Act - emitted = ctx.to_protobuf() - - # Assert - entries = list(emitted.vars) - assert len(entries) == 1 - entry = entries[0] - assert (entry.namespace, entry.name) == (var.namespace, var.name) - assert set(entry.consumed_tokens) == {captured[0].id.hex, captured[1].id.hex} - - def test_from_protobuf_decodes_value_and_consumed_tokens_in_one_entry(self): - """Test Context.from_protobuf decodes a wire entry that - carries both an optional value and a consumed-token id under - the same var identity. - - Given: - A wire ``protocol.Context`` whose single - ``protocol.ContextVar`` entry carries a registered var's - ``(namespace, name)``, a serialized value, and a - consumed-token id hex. - When: - ``Context.from_protobuf(pb)`` is called and the result is - re-emitted via ``to_protobuf``. - Then: - The reconstructed Context binds the var to the - deserialized value, and the re-emitted wire frame carries - the consumed-token id under the same merged entry — - decode handles colocated value and consumed_tokens - symmetrically with encode. 
- """ - # Arrange - from wool import protocol - - var = ContextVar(f"merged_decode_{uuid.uuid4().hex}", default="initial") - token_id = uuid.uuid4() - pb = protocol.Context( - id=uuid.uuid4().hex, - vars=[ - protocol.ContextVar( - namespace=var.namespace, - name=var.name, - value=cloudpickle.dumps("decoded"), - consumed_tokens=[token_id.hex], - ), - ], - ) - - # Act - reconstructed = Context.from_protobuf(pb) - re_emitted = reconstructed.to_protobuf() - - # Assert - assert reconstructed[var] == "decoded" - re_entries = list(re_emitted.vars) - assert len(re_entries) == 1 - re_entry = re_entries[0] - assert (re_entry.namespace, re_entry.name) == (var.namespace, var.name) - assert re_entry.HasField("value") - assert cloudpickle.loads(re_entry.value) == "decoded" - assert list(re_entry.consumed_tokens) == [token_id.hex] - - def test_from_protobuf_with_unknown_keys_alongside_known_ones(self): - """Test Context.from_protobuf stubs unknown keys and applies their - values, while still deserializing known keys as normal. - - Given: - A wire-form protocol.Context carrying a mix of keys — one - registered on this process, one not - When: - Context.from_protobuf is invoked with the payload - Then: - The registered key is deserialized as before, and the - unregistered key results in a stub entry so the receiver - can observe the propagated value when the var is later - declared (rolling-deploy / lazy-import scenario). 
- """ - # Arrange - from wool import protocol - - known_var = ContextVar("ctx003_known", default="initial") - unknown_ns = f"ctx003_unknown_{uuid.uuid4().hex}" - pb = protocol.Context( - vars=[ - protocol.ContextVar( - namespace=unknown_ns, - name="missing", - value=dumps("propagated"), - ), - protocol.ContextVar( - namespace=known_var.namespace, - name=known_var.name, - value=dumps("applied"), - ), - ] - ) - - # Act - reconstructed = Context.from_protobuf(pb) - current_context().update(reconstructed) - - # Assert - assert reconstructed[known_var] == "applied" - late_declared: ContextVar[str] = ContextVar("missing", namespace=unknown_ns) - assert late_declared.get() == "propagated" - - def test_from_protobuf_with_unregistered_key_then_later_var_declaration( - self, - ): - """Test wire ingress of an unregistered var matches the pickle-path - stub-promotion semantics. - - Given: - A wire protocol.Context carrying a var key that is not yet - registered on this process, and the corresponding - wool.ContextVar declaration arrives later (lazy-import on - the receiver) - When: - Context.from_protobuf reconstructs the payload, - current_context().update merges it, and the user then - declares the ContextVar under the same key - Then: - ContextVar.get should return the wire-propagated value — - the wire-ingress path creates and pins a stub the same way - the pickled-ContextVar-instance path does, so lazy-import - receivers converge after one dispatch rather than needing - a second one that carries a ContextVar instance in-args. 
- """ - # Arrange - from wool import protocol - - unique_ns = f"wire_stub_{uuid.uuid4().hex}" - pb = protocol.Context(id=uuid.uuid4().hex) - pb.vars.add( - namespace=unique_ns, - name="tenant_id", - value=dumps("acme-corp"), - ) - - incoming = Context.from_protobuf(pb) - current_context().update(incoming) - - # Act - var: ContextVar[str] = ContextVar("tenant_id", namespace=unique_ns) - - # Assert - assert var.get() == "acme-corp" - - def test_from_protobuf_with_corrupt_value(self): - """Test Context.from_protobuf emits a ContextDecodeWarning - naming a corrupt key and skips that entry instead of - aborting the whole decode. - - Given: - A registered wool.ContextVar and a wire-form - protocol.Context containing that key mapped to bytes that - are not a valid pickle stream - When: - Context.from_protobuf is invoked with the payload - Then: - It emits a ContextDecodeWarning naming the offending - key and returns a Context with the corrupt entry - skipped — surviving entries decode normally - """ - # Arrange - from wool import ContextDecodeWarning - from wool import protocol - - var = ContextVar("ctx003_corrupt", default="initial") - pb = protocol.Context( - vars=[ - protocol.ContextVar( - namespace=var.namespace, - name=var.name, - value=b"\x00not a valid pickle stream\x00", - ) - ] - ) - - # Act - with pytest.warns(ContextDecodeWarning, match=var.name): - ctx = Context.from_protobuf(pb) - - # Assert - assert var not in ctx, "Corrupt entry should be skipped, not partially decoded" - - def test_from_protobuf_with_malformed_id_emits_warning(self): - """Test Context.from_protobuf emits a ContextDecodeWarning when - the wire context's chain id cannot be parsed, falling back to - a freshly-minted chain id. 
- - Given: - A wire-form protocol.Context whose ``id`` field is a - non-empty string that does not parse as a valid UUID - When: - Context.from_protobuf is invoked under default warning - filters - Then: - A ContextDecodeWarning is emitted naming the malformed - id, and the returned Context carries a freshly-minted - chain id rather than propagating the bad value - """ - # Arrange - from wool import ContextDecodeWarning - from wool import protocol - - pb = protocol.Context(id="not-a-uuid") - - # Act - with pytest.warns(ContextDecodeWarning, match="not-a-uuid"): - ctx = Context.from_protobuf(pb) - - # Assert - assert isinstance(ctx.id, uuid.UUID), ( - "Fallback chain id should be a freshly-minted UUID" - ) - assert ctx.id.hex != "not-a-uuid", "Malformed id should not be propagated" - - def test_from_protobuf_with_malformed_id_strict_mode(self): - """Test Context.from_protobuf aggregates a strict-mode - catastrophic id-parse failure into a BaseExceptionGroup. - - Given: - A wire-form protocol.Context with a malformed ``id`` - field, with a strict warnings filter promoting - ContextDecodeWarning to an exception - When: - Context.from_protobuf is invoked - Then: - A BaseExceptionGroup is raised with a single - ContextDecodeWarning peer naming the malformed id — - catastrophic decode failures ride the same group - channel as per-var failures - """ - import warnings as _warnings - - from wool import ContextDecodeWarning - from wool import protocol - - # Arrange - pb = protocol.Context(id="not-a-uuid") - - # Act & assert - with _warnings.catch_warnings(): - _warnings.simplefilter("error", category=ContextDecodeWarning) - with pytest.raises(BaseExceptionGroup) as exc_info: - Context.from_protobuf(pb) - - peers = exc_info.value.exceptions - assert len(peers) == 1, "Group should hold one peer for the id failure" - assert isinstance(peers[0], ContextDecodeWarning) - assert "not-a-uuid" in str(peers[0]) - - def 
test_from_protobuf_aggregates_id_and_var_failures_strict_mode(self): - """Test Context.from_protobuf folds catastrophic id-parse - failures and per-var decode failures into a single - BaseExceptionGroup under strict mode. - - Given: - A wire-form protocol.Context with both a malformed - ``id`` field and a registered var bound to bytes that - are not a valid pickle stream, with strict mode active - When: - Context.from_protobuf is invoked - Then: - A BaseExceptionGroup is raised whose peers include one - ContextDecodeWarning naming the malformed id and one - naming the corrupt var key — confirming the unified - decode-failure channel covers both axes simultaneously - """ - import warnings as _warnings - - from wool import ContextDecodeWarning - from wool import protocol - - # Arrange - var = ContextVar("ctx003_strict_combined", default="initial") - pb = protocol.Context( - id="not-a-uuid", - vars=[ - protocol.ContextVar( - namespace=var.namespace, - name=var.name, - value=b"\x00not a valid pickle stream\x00", - ) - ], - ) - - # Act & assert - with _warnings.catch_warnings(): - _warnings.simplefilter("error", category=ContextDecodeWarning) - with pytest.raises(BaseExceptionGroup) as exc_info: - Context.from_protobuf(pb) - - peers = exc_info.value.exceptions - assert all(isinstance(p, ContextDecodeWarning) for p in peers), ( - "Every peer should be a ContextDecodeWarning" - ) - peer_messages = [str(p) for p in peers] - assert any("not-a-uuid" in msg for msg in peer_messages), ( - "Malformed id should be named in a peer" - ) - assert any(var.name in msg for msg in peer_messages), ( - "Corrupt var should be named in a peer" - ) - assert len(peers) == 2, "Group should hold one peer per failure axis (id + var)" - - def test_from_protobuf_with_multiple_corrupt_values_strict_mode(self): - """Test Context.from_protobuf aggregates strict-mode decode - failures into a single BaseExceptionGroup naming every - offending var. 
- - Given: - A wire-form protocol.Context carrying two registered keys - mapped to bytes that are not a valid pickle stream, with a - strict warnings filter promoting ContextDecodeWarning to - an exception - When: - Context.from_protobuf is invoked with the payload - Then: - A BaseExceptionGroup is raised whose peers are - ContextDecodeWarning instances — one per offending var, - each naming the corresponding key — so callers learn - about every bad var on a single decode attempt. - """ - import warnings as _warnings - - from wool import ContextDecodeWarning - from wool import protocol - - # Arrange - first = ContextVar("ctx003_strict_a", default="a-default") - second = ContextVar("ctx003_strict_b", default="b-default") - pb = protocol.Context( - vars=[ - protocol.ContextVar( - namespace=first.namespace, - name=first.name, - value=b"\x00not a valid pickle stream\x00", - ), - protocol.ContextVar( - namespace=second.namespace, - name=second.name, - value=b"\x01also not valid\x01", - ), - ] - ) - - # Act & assert - with _warnings.catch_warnings(): - _warnings.simplefilter("error", category=ContextDecodeWarning) - with pytest.raises(BaseExceptionGroup) as exc_info: - Context.from_protobuf(pb) - - peers = exc_info.value.exceptions - assert all(isinstance(p, ContextDecodeWarning) for p in peers), ( - "Every peer should be a ContextDecodeWarning" - ) - peer_messages = [str(p) for p in peers] - assert any(first.name in msg for msg in peer_messages), ( - "First offending var should be named in a peer" - ) - assert any(second.name in msg for msg in peer_messages), ( - "Second offending var should be named in a peer" - ) - assert len(peers) == 2, "Group should hold one peer per offending var, no more" - - def test_from_protobuf_with_malformed_consumed_token_ids( - self, - ): - """Test Context.from_protobuf tolerates a single malformed - consumed-token id without aborting the whole frame decode. 
- - Given: - A wire-form protocol.Context carrying a valid var - binding, a valid consumed-token hex id, and a malformed - consumed-token hex id (not a UUID) - When: - Context.from_protobuf is invoked with the payload - Then: - The valid var binding is applied, the valid consumed-token - id lands in the reconstructed Context's incoming buffer, - the malformed id is skipped with a ContextDecodeWarning - naming it, and no ValueError propagates — matching the - per-var log-and-skip policy already in place for var - values. - """ - # Arrange - from wool import ContextDecodeWarning - from wool import protocol - - var = ContextVar("partial_decode_with_invalid_token_id", default="d") - valid_id = uuid.uuid4() - pb = protocol.Context( - id=uuid.uuid4().hex, - vars=[ - protocol.ContextVar( - namespace=var.namespace, - name=var.name, - value=dumps("applied"), - consumed_tokens=[valid_id.hex, "not-a-uuid"], - ), - ], - ) - - # Act - with pytest.warns(ContextDecodeWarning, match="not-a-uuid"): - reconstructed = Context.from_protobuf(pb) - - # Assert - assert reconstructed[var] == "applied" - current_context().update(reconstructed) - emitted = current_context().to_protobuf() - emitted_token_ids = { - tid for entry in emitted.vars for tid in entry.consumed_tokens - } - assert valid_id.hex in emitted_token_ids - - def test_from_protobuf_strict_mode_aggregates_malformed_token_id(self): - """Test Context.from_protobuf folds a malformed consumed-token - id into the strict-mode BaseExceptionGroup with the owning - var's identity in the warning message. - - Given: - A wire ``protocol.Context`` whose single - ``protocol.ContextVar`` entry carries a registered var's - ``(namespace, name)``, a valid pickled value, and a - malformed consumed-token hex ("not-a-uuid"); a strict - warnings filter promotes ContextDecodeWarning to an - exception. - When: - Context.from_protobuf is invoked with the payload. 
- Then: - A BaseExceptionGroup is raised whose peers include - exactly one ContextDecodeWarning whose message names - both the malformed hex and the var's namespace and name — - confirming malformed-token-id failures route through the - same strict-mode aggregation channel as value-decode and - id-parse failures, and that the error message uses the - var-key context the merged wire shape exposes. - """ - import warnings as _warnings - - from wool import ContextDecodeWarning - from wool import protocol - - # Arrange - var = ContextVar(f"strict_token_id_{uuid.uuid4().hex}", default="d") - pb = protocol.Context( - id=uuid.uuid4().hex, - vars=[ - protocol.ContextVar( - namespace=var.namespace, - name=var.name, - value=cloudpickle.dumps("applied"), - consumed_tokens=["not-a-uuid"], - ), - ], - ) - - # Act & assert - with _warnings.catch_warnings(): - _warnings.simplefilter("error", category=ContextDecodeWarning) - with pytest.raises(BaseExceptionGroup) as exc_info: - Context.from_protobuf(pb) - - peers = exc_info.value.exceptions - assert all(isinstance(p, ContextDecodeWarning) for p in peers) - token_id_peers = [p for p in peers if "not-a-uuid" in str(p)] - assert len(token_id_peers) == 1 - message = str(token_id_peers[0]) - assert var.namespace in message - assert var.name in message - - def test_from_protobuf_with_stub_pinning(self): - """Test Context.from_protobuf attaches resolved stubs to the - Context it constructs and returns, not to the caller's - currently-active Context. - - Given: - A wire-form protocol.Context carrying an unregistered - namespaced key, decoded from inside an outer - scoped_context block. 
The returned Context is then - dropped while the outer Context is still in scope, so - only the pin anchor's keep-alive can preserve the stub - once the outer block subsequently exits - When: - The returned Context is dropped, the outer scope exits, - and gc.collect runs - Then: - The stub should NOT be discoverable in the process-wide - var registry — the pin attribution lives on the - returned Context (now gone), so the stub is reclaimed. - If the pin had attached to the outer Context the stub - would have outlived the returned Context (an attribution - inversion). - """ - # Arrange - from wool import protocol - from wool.runtime.context import var_registry - - key = ("pin_attribution", uuid.uuid4().hex) - pb = protocol.Context(id=uuid.uuid4().hex) - pb.vars.add( - namespace=key[0], - name=key[1], - value=cloudpickle.dumps("propagated"), - ) - - # Act - with scoped_context(): - incoming = Context.from_protobuf(pb) - del incoming - gc.collect() - # Window: returned Context dropped, outer scope still - # active. With the pin on incoming this releases the - # stub immediately; with the pin on outer the stub - # would survive until outer also dies. - in_registry_after_incoming_dies = key in var_registry - - # Assert - assert in_registry_after_incoming_dies is False - - @pytest.mark.asyncio - async def test_from_protobuf_in_caller_task(self): - """Test Context.from_protobuf does not lazy-stamp a Context on - the calling task's scope as a side effect of decoding a wire - frame. - - Given: - An asyncio task with no wool.Context bound to it — - a fresh ``loop.create_task`` child created without - wool's task factory installed, so no Context is - auto-bound to the task identity. - When: - Context.from_protobuf is called from inside that task. - Then: - The wool registry slot for the task remains ``None`` — - decoding a wire frame must not materialize a Context on - the decoding task's scope. 
- """ - # Arrange - from wool import protocol - from wool.runtime.context.registry import context_registry - - loop = asyncio.get_running_loop() - # Bypass the wool task factory so no Context is auto-bound - # to the child task's identity. - loop.set_task_factory(None) - - pb = protocol.Context(id=uuid.uuid4().hex) - pb.vars.add( - namespace="no_side_effect", - name=uuid.uuid4().hex, - value=cloudpickle.dumps("v"), - ) - observed: list[Context | None] = [] - - async def body(): - Context.from_protobuf(pb) - observed.append(context_registry.get()) - - # Act - await loop.create_task(body()) - - # Assert - assert observed == [None] - - def test_update_merges_consumed_tokens(self): - """Test Context.update folds the source's consumed-token ids - into the destination's outbound wire emission. - - Given: - A secondary :class:`Context` reconstructed from a wire - payload whose ``consumed_tokens`` lists a token id, and - a primary :class:`Context` that has not yet observed - that id - When: - ``primary.update(secondary)`` is called - Then: - ``primary.to_protobuf().consumed_tokens`` contains the - merged id — observable via the wire shape rather than - via a registry-aliased Token whose ``used`` flag is - trivially True regardless of whether ``update`` did - anything - """ - # Arrange - from wool import protocol - - token_id = uuid.uuid4() - wire = protocol.Context(id=uuid.uuid4().hex) - wire.vars.add(namespace="", name="", consumed_tokens=[token_id.hex]) - - secondary = Context.from_protobuf(wire) - primary = Context() - - # Act - primary.update(secondary) - - # Assert - emitted = primary.to_protobuf() - emitted_token_ids = { - tid for entry in emitted.vars for tid in entry.consumed_tokens - } - assert token_id.hex in emitted_token_ids - - def test_to_protobuf_with_locally_reset_token(self): - """Test Context.to_protobuf emits consumed-token ids scoped to - this Context's logical chain, excluding tokens reset under a - different chain. 
- - Given: - Two Contexts A and B with distinct ids, a Token minted - and reset under A, and a Token minted and reset under B - When: - A.to_protobuf() is called - Then: - The resulting ``consumed_tokens`` list contains A's - token id but not B's — the per-lineage scoping of - wire emission holds regardless of whether it is derived - from a global scan or from per-Context bookkeeping - """ - # Arrange - var_a = ContextVar("pin_scope_a", default="d") - var_b = ContextVar("pin_scope_b", default="d") - - a_tokens: list[Token] = [] - b_tokens: list[Token] = [] - - def consume_in_a() -> None: - a_tokens.append(var_a.set("ax")) - var_a.reset(a_tokens[-1]) - - def consume_in_b() -> None: - b_tokens.append(var_b.set("bx")) - var_b.reset(b_tokens[-1]) - - ctx_a = Context() - ctx_b = Context() - ctx_a.run(consume_in_a) - ctx_b.run(consume_in_b) - - # Act - a_pb = ctx_a.to_protobuf() - - # Assert - a_hex = {tid for entry in a_pb.vars for tid in entry.consumed_tokens} - assert a_tokens[0].id.hex in a_hex - assert b_tokens[0].id.hex not in a_hex - - def test_update_with_flipped_token_ids_then_to_protobuf( - self, - ): - """Test Context.update makes a merged used-token id visible in a - subsequent Context.to_protobuf on the same Context. 
- - Given: - A live Token minted under the current Context, and a - wire protocol.Context whose ``consumed_tokens`` lists - that Token's id (modeling a back-prop frame from a peer) - When: - The wire Context is merged via current_context().update - and current_context().to_protobuf() is called - Then: - The resulting ``consumed_tokens`` list contains the - merged token id — forwarding the used-state onward - through this Context's wire emissions - """ - # Arrange - from wool import protocol - - var = ContextVar("pin_forward", default="d") - token = var.set("x") - - pb = protocol.Context(id=uuid.uuid4().hex) - pb.vars.add( - namespace=var.namespace, - name=var.name, - consumed_tokens=[token.id.hex], - ) - - incoming = Context.from_protobuf(pb) - current_context().update(incoming) - - # Act - emitted = current_context().to_protobuf() - - # Assert - emitted_token_ids = { - tid for entry in emitted.vars for tid in entry.consumed_tokens - } - assert token.id.hex in emitted_token_ids - - def test_to_protobuf_roundtrips_consumed_tokens(self): - """Test Context.to_protobuf followed by Context.from_protobuf - preserves consumed-token state end-to-end. 
- - Given: - A Context that has consumed a Token via ContextVar.reset - and a strong reference to the consumed Token kept alive - outside the Context so the ``weakref.WeakSet`` tracking - entry survives until ``to_protobuf`` runs - When: - The Context is serialized via to_protobuf and a new - Context is reconstructed via from_protobuf - Then: - The reconstructed Context emits the same consumed-token - id when serialized again — the round-trip is observable - via the wire shape rather than via a registry-aliased - Token whose ``used`` flag is trivially True regardless - of whether ``from_protobuf`` populated the destination - """ - # Arrange - var = ContextVar("proto_consumed", default="d") - origin = Context() - captured: list[Token[Any]] = [] - - def consume(): - t = var.set("x") - var.reset(t) - captured.append(t) - - origin.run(consume) - - # Act - roundtripped = Context.from_protobuf(origin.to_protobuf()) - - # Assert - emitted = roundtripped.to_protobuf() - emitted_token_ids = { - tid for entry in emitted.vars for tid in entry.consumed_tokens - } - assert captured[0].id.hex in emitted_token_ids - - def test_to_protobuf_forwards_wire_consumed_tokens_without_local_mutation( - self, - ): - """Test from_protobuf → to_protobuf round-trip preserves - wire-supplied consumed-token ids when no local mutation - occurred between the two operations. 
- - Given: - A Context built via Context.from_protobuf from a wire - payload whose consumed_tokens field carries a token id, - and no local ContextVar.reset or update call against the - resulting Context after construction - When: - to_protobuf is called on the resulting Context - Then: - The emitted wire payload's consumed_tokens contains the - original id — a participant in a dispatch chain must - forward upstream-told consumed-token state, not silently - drop it on the second hop - """ - # Arrange - from wool import protocol - - token_id = uuid.uuid4() - wire_in = protocol.Context(id=uuid.uuid4().hex) - wire_in.vars.add(namespace="", name="", consumed_tokens=[token_id.hex]) - - # Act - ctx = Context.from_protobuf(wire_in) - wire_out = ctx.to_protobuf() - - # Assert - emitted_token_ids = { - tid for entry in wire_out.vars for tid in entry.consumed_tokens - } - assert token_id.hex in emitted_token_ids, ( - "Context built from a wire payload must forward " - "upstream consumed-token ids on subsequent to_protobuf " - "emission; otherwise nested dispatches lose chain state" - ) - - def test_to_protobuf_with_custom_dumps(self): - """Test Context.to_protobuf returns a Context with custom-serialized vars. - - Given: - A ContextVar with a value set and a custom dumps function - When: - current_context().to_protobuf(serializer=custom) is called - Then: - It should return a protocol.Context whose vars map was - produced by the custom serializer and whose id is a 32-char - UUID hex string. 
- """ - # Arrange - var = ContextVar("tpb_custom_dumps", namespace="tpb") - var.set(42) - - calls: list[object] = [] - - def custom_dumps(value: object) -> bytes: - calls.append(value) - return b"tpb:" + str(value).encode() - - # Act - pb = current_context().to_protobuf( - serializer=cast(Serializer, SimpleNamespace(dumps=custom_dumps)) - ) - - # Assert - emitted = {(e.namespace, e.name): e.value for e in pb.vars} - assert (var.namespace, var.name) in emitted - assert emitted[(var.namespace, var.name)] == b"tpb:42" - assert isinstance(pb.id, str) - assert len(pb.id) == 32 - assert calls == [42] - - def test_from_protobuf_with_custom_loads(self): - """Test Context.from_protobuf deserializes values via a custom loads callable. - - Given: - A ContextVar registered in the process and a wire-form - protocol.Context with a custom-encoded value - When: - Context.from_protobuf is called with a custom loads function - Then: - It should deserialize each value through the custom callable. - """ - # Arrange - from wool import protocol - - var = ContextVar("fpb_custom_loads", namespace="fpb") - pb = protocol.Context( - vars=[ - protocol.ContextVar( - namespace=var.namespace, - name=var.name, - value=b"custom-payload", - ) - ] - ) - - calls: list[bytes] = [] - - def custom_loads(data: bytes) -> object: - calls.append(data) - return "decoded-" + data.decode() - - # Act - reconstructed = Context.from_protobuf( - pb, serializer=cast(Serializer, SimpleNamespace(loads=custom_loads)) - ) - - # Assert - assert reconstructed[var] == "decoded-custom-payload" - assert calls == [b"custom-payload"] - - -def test_copy_context_with_set_vars(): - """Test copy_context() snapshots vars and assigns a fresh chain id. 
- - Given: - A ContextVar with an explicit value set in the live Context - When: - copy_context() is called - Then: - The snapshot contains the var's value but its id differs - from the live Context's id — the copy is an independent - logical chain - """ - # Arrange - var = ContextVar("copy_ctx", default=0) - var.set(1) - live_id = current_context().id - - # Act - snapshot = copy_context() - - # Assert - assert snapshot[var] == 1 - assert snapshot.id != live_id - - -def test_copy_context_chain_id_uniqueness(): - """Test successive copy_context() calls each get a distinct id. - - Given: - No mutation to the live Context between calls - When: - copy_context() is called twice - Then: - The two snapshots have distinct ids - """ - # Act - a = copy_context() - b = copy_context() - - # Assert - assert a.id != b.id - - -@pytest.mark.asyncio -async def test_create_task_with_explicit_wool_context_skips_fork(): - """Test passing a :class:`wool.Context` to :func:`wool.create_task` - pre-binds the given Context and bypasses the copy-on-fork path - the wool task factory would otherwise take. - - Given: - An event loop with wool's task factory installed and a - parent scope holding a ContextVar binding that the factory's - fork path would normally propagate to child tasks. - When: - :func:`wool.create_task` is called with a fresh (empty) - target wool.Context distinct from the parent's. - Then: - The child sees the target Context directly — same identity, - and crucially without the parent's var binding — proving - the fork path was skipped rather than taken. This is the - canonical wool task-binding idiom: stdlib's ``context=`` - kwarg is intercepted by wool's task factory and routed - through wool's per-task registry, with - :func:`wool.create_task` providing the typing shim around - the duck-typed wool.Context payload. 
- """ - from wool.runtime.context import create_task - from wool.runtime.context import install_task_factory - - # Arrange - loop = asyncio.get_running_loop() - install_task_factory(loop) - - sentinel_var = ContextVar("bound_task_fork_sentinel", default="default") - sentinel_var.set("parent-value") - - target = Context() - observed_ctx: list[Context] = [] - observed_value: list[str] = [] - - async def body(): - observed_ctx.append(current_context()) - observed_value.append(sentinel_var.get()) - - # Act - task = create_task(body(), context=target) - await task - - # Assert - assert observed_ctx == [target] - # A forked child would have inherited "parent-value" via - # parent.copy(); the bound child sees the fresh target's empty - # state and falls through to the var's default. - assert observed_value == ["default"] - - -@pytest.mark.asyncio -async def test_create_task_inside_parent_context_scope(): - """Test the child task spawned via asyncio.create_task inherits a fork - of the parent's current_context (not an empty Context). 
- - Given: - An event loop with wool's task factory installed, a parent - task that holds a non-empty Context carrying a ContextVar - binding - When: - The parent calls asyncio.create_task to spawn a child - Then: - The child's current_context should be a distinct Context - (fresh chain id) but carrying the same var binding the - parent set — locking in the fork-on-spawn contract that - depends on ``asyncio.current_task()`` resolving to the - parent inside the task-factory callback - """ - from wool.runtime.context import install_task_factory - - loop = asyncio.get_running_loop() - install_task_factory(loop) - - var = ContextVar("fork_invariant_probe", default="d") - var.set("parent-value") - parent_ctx = current_context() - - captured: list[tuple[uuid.UUID, str]] = [] - - async def child() -> None: - child_ctx = current_context() - captured.append((child_ctx.id, var.get())) - - await asyncio.create_task(child()) - - assert len(captured) == 1 - child_id, child_value = captured[0] - assert child_id != parent_ctx.id # fork mints a fresh chain id - assert child_value == "parent-value" # but carries parent's var state - - -@pytest.mark.asyncio -async def test_install_task_factory_idempotent(): - """Test install_task_factory is a no-op when already installed. - - Given: - install_task_factory has been called on the running loop - When: - install_task_factory is called again - Then: - It should return without error (idempotent) - """ - from wool.runtime.context import install_task_factory - - # Arrange - install_task_factory() - - # Act & assert — no error - install_task_factory() - - -@pytest.mark.asyncio -async def test_install_task_factory_with_existing_factory(): - """Test install_task_factory wraps an existing factory. 
- - Given: - A custom task factory already set on the loop - When: - install_task_factory is called - Then: - It should wrap the existing factory, creating tasks via the - original while also seeding wool Context on the child - """ - from wool.runtime.context import install_task_factory - - # Arrange - loop = asyncio.get_running_loop() - calls = [] - - def custom_factory(loop, coro, **kwargs): - calls.append("custom") - return asyncio.Task(coro, loop=loop, **kwargs) - - loop.set_task_factory(custom_factory) - - # Act - install_task_factory() - - var = ContextVar("compose_test", namespace="test_compose") - var.set("parent_value") - - async def child(): - return var.get("missing") - - result = await asyncio.create_task(child()) - - # Assert - assert len(calls) > 0 # custom factory was called - assert result == "parent_value" # wool context inherited - - # Cleanup - loop.set_task_factory(None) - - -@pytest.mark.asyncio -async def test_install_task_factory_idempotent_over_composed(): - """Test install_task_factory is a no-op when a wool-composed factory is installed. - - Given: - A user factory was set on the loop and install_task_factory - composed around it - When: - install_task_factory is called again - Then: - The second call recognizes the _wool_wrapped marker and - returns without replacing the composed factory - """ - from wool.runtime.context import install_task_factory - - # Arrange - loop = asyncio.get_running_loop() - - def custom_factory(loop, coro, **kwargs): - return asyncio.Task(coro, loop=loop, **kwargs) - - loop.set_task_factory(custom_factory) - install_task_factory() - composed = loop.get_task_factory() - - # Act - install_task_factory() - - # Assert - assert loop.get_task_factory() is composed - - # Cleanup - loop.set_task_factory(None) - - -@pytest.mark.asyncio -async def test_install_task_factory_on_fresh_loop(caplog): - """Test install_task_factory logs a fresh-install message on an empty loop. 
- - Given: - A running event loop with no task factory set - When: - install_task_factory runs once - Then: - A debug record naming the "installed" path is emitted - """ - from wool.runtime.context import install_task_factory - - # Arrange - loop = asyncio.get_running_loop() - loop.set_task_factory(None) - - # Act - with caplog.at_level("DEBUG", logger="wool.runtime.context"): - install_task_factory() - - # Assert - assert any("wool task factory installed" in r.message for r in caplog.records) - - -@pytest.mark.asyncio -async def test_install_task_factory_when_recalled(caplog): - """Test a second install on the same loop logs the already-installed path. - - Given: - A running event loop with wool's factory already installed - When: - install_task_factory runs a second time - Then: - A debug record naming the "already installed" path is emitted - """ - from wool.runtime.context import install_task_factory - - # Arrange - loop = asyncio.get_running_loop() - loop.set_task_factory(None) - install_task_factory() - - # Act - with caplog.at_level("DEBUG", logger="wool.runtime.context"): - install_task_factory() - - # Assert - assert any("already installed" in r.message for r in caplog.records) - - -@pytest.mark.asyncio -async def test_install_task_factory_with_user_factory_present(caplog): - """Test install over a non-wool user factory logs the compose path. 
- - Given: - A running event loop with a non-wool user task factory in place - When: - install_task_factory runs - Then: - A debug record naming the "composed with existing factory" - path is emitted - """ - from wool.runtime.context import install_task_factory - - # Arrange - loop = asyncio.get_running_loop() - loop.set_task_factory(None) - - def custom_factory(loop, coro, **kwargs): - return asyncio.Task(coro, loop=loop, **kwargs) - - loop.set_task_factory(custom_factory) - - # Act - with caplog.at_level("DEBUG", logger="wool.runtime.context"): - install_task_factory() - - # Assert - assert any("composed with existing factory" in r.message for r in caplog.records) - - -@pytest.mark.asyncio -async def test_install_task_factory_when_recalled_over_composed( - caplog, -): - """Test a second install over a composed factory logs the already-composed path. - - Given: - A running event loop with wool's factory already composed - over a user factory - When: - install_task_factory runs a second time - Then: - A debug record naming the "composed task factory already - installed" path is emitted - """ - from wool.runtime.context import install_task_factory - - # Arrange - loop = asyncio.get_running_loop() - loop.set_task_factory(None) - - def custom_factory(loop, coro, **kwargs): - return asyncio.Task(coro, loop=loop, **kwargs) - - loop.set_task_factory(custom_factory) - install_task_factory() - - # Act - with caplog.at_level("DEBUG", logger="wool.runtime.context"): - install_task_factory() - - # Assert - assert any( - "composed task factory already installed" in r.message for r in caplog.records - ) - - -def test_update_with_external_uuid_resolved_to_reloaded_token(): - """Test Context.update flips the used flag on a live Token - re-registered via cloudpickle reload after the wire snapshot - was decoded. 
- - Given: - A wire :class:`protocol.Context` whose ``consumed_tokens`` - carries a UUID for which no live :class:`Token` exists at - decode time — the originating Token was cloudpickled and - then released. After ``Context.from_protobuf`` records the - UUID, the pickle is reloaded so a fresh live Token is in - the registry under the same UUID. - When: - ``current_context().update(secondary)`` runs against that - wire-derived Context. - Then: - The reloaded Token's ``used`` flag flips to True — the - merge resolves the incoming UUID through the live token - registry rather than parking it as a bare id. - """ - # Arrange - from wool import protocol - - var = ContextVar(f"update_external_uuid_{uuid.uuid4().hex}") - token = var.set("x") - token_id = token.id - pickled = dumps(token) - del token - gc.collect() - - pb = protocol.Context(id=uuid.uuid4().hex) - pb.vars.add( - namespace=var.namespace, - name=var.name, - consumed_tokens=[token_id.hex], - ) - secondary = Context.from_protobuf(pb) - - restored = loads(pickled) - assert restored.used is False - - # Act - current_context().update(secondary) - - # Assert - assert restored.used is True - - -@pytest.mark.asyncio -async def test_create_task_with_context_already_bound_to_another_running_task(): - """Test :func:`wool.create_task` rejects a second concurrent task - pinned to a :class:`wool.Context` already running another task. - - Given: - A running event loop with wool's task factory installed and a - :class:`wool.Context` bound to a first task whose coroutine is - suspended (the bound-task slot points at a still-running task). - When: - :func:`wool.create_task` schedules a second coroutine targeting - the same Context while the first task is still alive, and the - second task is awaited. - Then: - It should raise :class:`RuntimeError` naming the "first-task- - wins for the routine's lifetime" invariant. 
- """ - # Arrange - from wool.runtime.context import create_task as wool_create_task - from wool.runtime.context import install_task_factory - - install_task_factory() - ctx = Context() - started = asyncio.Event() - release = asyncio.Event() - - async def first(): - started.set() - await release.wait() - - async def second(): - return "should-not-run" - - first_task = wool_create_task(first(), context=ctx) - await started.wait() - second_task = wool_create_task(second(), context=ctx) - - # Act & assert - try: - with pytest.raises(RuntimeError, match="bound to another live task"): - await second_task - finally: - release.set() - await first_task - - -@pytest.mark.asyncio -async def test_create_task_with_copy_context_inherits_and_forks(): - """Test passing a :func:`contextvars.copy_context` to - ``create_task`` does not break wool's fork-on-task semantics — - the child still inherits the parent's wool.Context state under - a fresh chain id via the wool task factory's copy-at-creation - path. - - Given: - A running event loop with wool's task factory installed and - a parent scope whose live wool.Context carries an explicit - ContextVar binding. - When: - A child coroutine is scheduled with - ``context=contextvars.copy_context()`` — exercising the - stdlib ``context=`` forwarding path of the wool factory. - Then: - The child's :func:`current_context` carries the parent's - ContextVar binding (inherited by wool's fork-on-task) under - a fresh chain id, and child mutations do not leak back to - the parent — the stdlib ``context=`` argument is forwarded - verbatim to asyncio without disturbing wool's parallel - registry. 
- """ - # Arrange - import contextvars - - from wool.runtime.context import install_task_factory - - install_task_factory() - - var = ContextVar("copy_context_inherit_probe", default="default") - var.set("parent-value") - parent_ctx = current_context() - - captured: list[tuple[uuid.UUID, str]] = [] - - async def child() -> None: - captured.append((current_context().id, var.get())) - # Mutate to verify isolation back toward the parent. - var.set("child-mutated") - - # Act - await asyncio.get_running_loop().create_task( - child(), context=contextvars.copy_context() - ) - - # Assert - assert len(captured) == 1 - child_id, child_value = captured[0] - assert child_value == "parent-value" - assert child_id != parent_ctx.id - # Parent observes its own original binding — the child's mutation - # rode the forked wool.Context, independent of stdlib Context. - assert var.get() == "parent-value" - assert current_context() is parent_ctx - - -@pytest.mark.asyncio -async def test_current_context_self_installs_task_factory(): - """Test :func:`current_context` self-installs the wool task - factory on the running loop the first time it is called, so - user code that touches wool without an explicit - :func:`install_task_factory` call still gets fork-on-task - semantics for tasks created afterward. - - Given: - A running event loop on which the wool task factory has not - been installed (the factory slot is empty). - When: - :func:`current_context` is called inside a task on that loop. - Then: - The loop's task factory is set to wool's wrapped factory, - and a child task subsequently created with - ``context=contextvars.copy_context()`` observes the - copy-on-fork contract — the child's wool.Context inherits - the parent's ContextVar bindings but carries a fresh chain - id, and child mutations do not leak back to the parent. 
- """ - # Arrange - import contextvars - - loop = asyncio.get_running_loop() - # Reset to a clean slate: any factory the test harness installed - # for prior tests is removed so the auto-install path is exercised. - loop.set_task_factory(None) - - # Act — first wool API touch on this loop should install the factory. - parent_ctx = current_context() - - # Assert — factory is now wool's wrapped factory. - factory = loop.get_task_factory() - assert factory is not None - assert getattr(factory, "__wool_wrapped__", False) is True - - # And the wrapper is doing its job: a child task with - # copy_context inherits state but forks the chain. - var = ContextVar("auto_install_probe", default="default") - var.set("parent-value") - - captured: list[tuple[uuid.UUID, str]] = [] - - async def child() -> None: - captured.append((current_context().id, var.get())) - var.set("child-mutated") - - await loop.create_task(child(), context=contextvars.copy_context()) - - assert len(captured) == 1 - child_id, child_value = captured[0] - assert child_value == "parent-value" - assert child_id != parent_ctx.id - assert var.get() == "parent-value" diff --git a/wool/tests/runtime/context/test_chain.py b/wool/tests/runtime/context/test_chain.py new file mode 100644 index 00000000..c9297dd5 --- /dev/null +++ b/wool/tests/runtime/context/test_chain.py @@ -0,0 +1,423 @@ +"""Unit tests for Chain — the live, immutable chain-state model.""" + +import asyncio +import contextvars +import threading +from uuid import uuid4 + +import pytest +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +import wool +from tests.helpers import _unique +from tests.helpers import scoped_context +from wool.runtime.context.chain import Chain +from wool.runtime.context.var import ContextVar + + +def _count_wool_vars_in_a_fresh_context(work) -> int: + """Run *work* in a brand-new Chain and count Wool-owned variables. 
+ + A fresh `contextvars.Context` carries no backing variables + leaked from earlier work on the running thread, so the count is + exactly the Wool-owned variables *work* itself binds. + """ + holder: list[int] = [] + + def _runner() -> None: + work() + copied = contextvars.copy_context() + holder.append(len([v for v in copied if v.name.startswith("__wool")])) + + contextvars.Context().run(_runner) + return holder[0] + + +class TestChain: + def test___init___should_default_collections_empty_when_required_fields_only(self): + """Test Chain construction with only the required fields. + + Given: + A fresh chain id and an owner thread id. + When: + A Chain is constructed with only chain_id and owner. + Then: + It should expose empty data, resets, and stubs + collections by default. + """ + # Arrange + chain_id = uuid4() + owner = threading.get_ident() + + # Act + context = Chain(id=chain_id, thread=owner) + + # Assert + assert context.id == chain_id + assert context.thread == owner + assert context.vars == frozenset() + assert context.resets == frozenset() + assert context.stubs == frozenset() + + def test___init___should_expose_supplied_collections_when_all_fields(self): + """Test Chain construction with every field supplied. + + Given: + Explicit data, resets, and stubs collections. + When: + A Chain is constructed with all fields. + Then: + It should expose each supplied collection verbatim. + """ + # Arrange + var = ContextVar(_unique("snap_init")) + bound = frozenset({var}) + resets = frozenset({("ns", "name")}) + + # Act + context = Chain( + id=uuid4(), + thread=threading.get_ident(), + vars=bound, + resets=resets, + stubs=frozenset({var}), + ) + + # Assert + assert context.vars == bound + assert context.resets == resets + assert context.stubs == frozenset({var}) + + def test_mount_should_install_field_replacing_copy(self): + """Test mount installs a field-replacing copy of the chain. 
+ + Given: + A directly-constructed chain with a known chain id and one + variable binding. + When: + mount is called on it. + Then: + It should install a new chain carrying the same chain id and + bindings while leaving the original instance untouched — the + mount is a field-replacing copy, not an in-place mutation. + """ + # Arrange + var = ContextVar(_unique("snap_evolve")) + chain_id = uuid4() + original = Chain(id=chain_id, vars=frozenset({var})) + + # Act + with scoped_context(): + installed = original.mount() + + # Assert — the installed copy preserves the chain id and bindings + assert installed.id == chain_id + assert installed.vars == frozenset({var}) + # Assert — the mount is a copy, not an in-place mutation + assert installed is not original + + def test___post_init___should_coerce_iterables_to_frozensets(self): + """Test Chain coerces non-frozenset iterables on construction. + + Given: + Plain sets, lists, and tuples supplied for the data, + resets, and stubs fields — collections that + satisfy the typing intent of the field but are not + frozensets at the call site. + When: + A Chain is constructed with those iterables. + Then: + All three fields should expose frozenset views after + construction — the post-init coerces non-frozenset + iterables so the dataclass invariant (hashable + immutable + container shape) holds regardless of what the caller + passed. + """ + # Arrange + var = ContextVar(_unique("post_init_coerce")) + + # Act — supply a plain set, list, and tuple respectively. 
+ context = Chain( + id=uuid4(), + thread=threading.get_ident(), + vars={var}, # set, not frozenset + resets=[("ns", "name")], # list, not frozenset + stubs=(var,), # tuple, not frozenset + ) + + # Assert + assert isinstance(context.vars, frozenset) + assert isinstance(context.resets, frozenset) + assert isinstance(context.stubs, frozenset) + assert context.vars == frozenset({var}) + assert context.resets == frozenset({("ns", "name")}) + assert context.stubs == frozenset({var}) + + def test_equality_should_be_identity_based(self): + """Test Chain equality is identity-based. + + Given: + Two chains constructed with identical field values. + When: + They are compared for equality. + Then: + They should be unequal — Chain is declared eq=False so + distinct instances never compare equal. + """ + # Arrange + chain_id = uuid4() + owner = threading.get_ident() + + # Act + first = Chain(id=chain_id, thread=owner) + second = Chain(id=chain_id, thread=owner) + + # Assert + assert first != second + assert first == first + + @given( + ops=st.lists( + st.tuples( + st.sampled_from(["set", "reset"]), + st.integers(min_value=0, max_value=2), + st.integers(), + ), + max_size=30, + ) + ) + @settings(max_examples=50, deadline=None) + def test_bookkeeping_should_match_oracle_when_arbitrary_set_reset_sequences( + self, ops + ): + """Test chain bookkeeping tracks an oracle over set/reset sequences. + + Given: + Three fresh wool.ContextVars in an unarmed scoped context, + an arbitrary sequence of set/reset operations over them, + and a naive per-var oracle replaying each operation (set + binds and clears any pending reset; reset spends the most + recent unspent token, restoring its prior state — a prior + of unbound leaves the key reset-pending). + When: + Each operation is applied through the public set/reset API. 
+ Then: + After every operation the live chain's vars index and + resets keys should match the oracle, every var's get() + should observe the oracle's value or default, and vars and + resets should stay disjoint. + """ + unset = object() + + def _check() -> None: + targets = [ContextVar(_unique(f"oracle_{i}")) for i in range(3)] + value_of: dict[int, object] = {i: unset for i in range(3)} + pending: set[tuple[str, str]] = set() + tokens: dict[int, list] = {i: [] for i in range(3)} + + for op, slot, value in ops: + var = targets[slot] + key = (var.namespace, var.name) + if op == "set": + tokens[slot].append((var.set(value), value_of[slot])) + value_of[slot] = value + pending.discard(key) + else: + if not tokens[slot]: + continue # no unspent token — nothing to reset + token, prior = tokens[slot].pop() + var.reset(token) + value_of[slot] = prior + if prior is unset: + pending.add(key) + else: + pending.discard(key) + + chain = wool.__chain__.get(None) + chain_vars = ( + {(v.namespace, v.name) for v in chain.vars} + if chain is not None + else set() + ) + chain_resets = set(chain.resets) if chain is not None else set() + expected_bound = { + (targets[i].namespace, targets[i].name) + for i in range(3) + if value_of[i] is not unset + } + assert chain_vars == expected_bound + assert chain_resets == pending + assert not (chain_vars & chain_resets) + for i, target in enumerate(targets): + expected = "" if value_of[i] is unset else value_of[i] + assert target.get("") == expected + + contextvars.Context().run(_check) + + def test_to_manifest_should_snapshot_bound_values_inline(self): + """Test to_manifest captures each bound variable's live value. + + Given: + An armed chain with two distinct variable bindings. + When: + to_manifest is called on the active chain. + Then: + The returned manifest should map each bound variable to the + value read from its backing in the calling context. 
+ """ + # Arrange + var_a = ContextVar(_unique("snap_a")) + var_b = ContextVar(_unique("snap_b")) + + with scoped_context(): + var_a.set(1) + var_b.set(2) + + # Act + manifest = wool.__chain__.get().to_manifest() + + # Assert + assert manifest.vars == {var_a: 1, var_b: 2} + + def test_to_manifest_should_skip_variable_when_backing_undefined(self): + """Test to_manifest omits a variable whose backing is Undefined. + + Given: + A chain whose vars index names a variable whose backing + resolves to the Undefined sentinel in the active context. + When: + to_manifest is called inside that context. + Then: + The returned manifest should carry no entry for that variable. + """ + # Arrange + var = ContextVar(_unique("snap_desync")) + + with scoped_context(): + token = var.set("v") + var.reset(token) + context = Chain( + id=uuid4(), + thread=threading.get_ident(), + vars=frozenset({var}), + ) + + # Act + manifest = context.to_manifest() + + # Assert + assert var not in manifest.vars + + def test_to_manifest_should_carry_id_resets_and_stubs(self): + """Test to_manifest carries the chain id and reset signals through. + + Given: + A chain with a known id and a reset-pending variable key. + When: + to_manifest is called. + Then: + The returned manifest should preserve the chain id and the + reset signal verbatim, with no inline value for the reset key. + """ + # Arrange + var = ContextVar(_unique("snap_reset")) + chain_id = uuid4() + context = Chain( + id=chain_id, + thread=threading.get_ident(), + resets=frozenset({(var.namespace, var.name)}), + ) + + # Act + manifest = context.to_manifest() + + # Assert + assert manifest.id == chain_id + assert manifest.resets == frozenset({(var.namespace, var.name)}) + assert manifest.vars == {} + + @pytest.mark.asyncio + async def test_child_task_should_fork_fresh_chain_copying_bindings(self): + """Test a child task forks a fresh chain that copies bindings and drops resets. 
+ + Given: + An armed chain carrying one variable binding and one + reset-pending key (a variable set then reset to no value). + When: + A child task is created under Wool's task factory. + Then: + The child's chain should carry a different chain id, inherit + the bound variable, and start with empty resets — copy-on- + fork mints a fresh chain id, copies the bindings, and drops + the parent's reset signals. + """ + # Arrange + bound_var = ContextVar(_unique("fork_bound")) + reset_var = ContextVar(_unique("fork_reset")) + + with scoped_context(): + bound_var.set("bound") + token = reset_var.set("transient") + reset_var.reset(token) + parent = wool.__chain__.get(None) + assert parent is not None + assert (reset_var.namespace, reset_var.name) in parent.resets + + async def child() -> Chain: + forked = wool.__chain__.get(None) + assert forked is not None + return forked + + # Act + forked = await asyncio.create_task(child()) + + # Assert + assert forked.id != parent.id + assert bound_var in forked.vars + assert forked.resets == frozenset() + + +def test_copy_context_should_carry_no_wool_variables_when_unarmed(): + """Test a copy_context of an unarmed context carries no Wool variables. + + Given: + A fresh, unarmed Wool context where no wool.ContextVar has + been set. + When: + contextvars.copy_context enumerates its variables. + Then: + No Wool-owned contextvars.ContextVar should appear — an + unarmed context is indistinguishable from a plain + contextvars.Context. + """ + # Arrange, act, & assert + assert _count_wool_vars_in_a_fresh_context(lambda: None) == 0 + + +@pytest.mark.parametrize("n", [1, 2, 3]) +def test_copy_context_should_carry_one_plus_n_wool_variables_when_armed(n): + """Test a copy_context of an armed context carries 1 + N Wool variables. + + Given: + A context armed with N bound wool.ContextVars. + When: + contextvars.copy_context enumerates its variables. 
+ Then: + Exactly 1 + N Wool-owned contextvars.ContextVars should appear + — the one context variable plus one backing variable per + bound wool.ContextVar — the explicit context-audit contract. + """ + # Arrange + bound = [ContextVar(_unique("width")) for _ in range(n)] + + def _arm() -> None: + for i, var in enumerate(bound): + var.set(i) + + # Act + count = _count_wool_vars_in_a_fresh_context(_arm) + + # Assert + assert count == 1 + n diff --git a/wool/tests/runtime/context/test_exceptions.py b/wool/tests/runtime/context/test_exceptions.py new file mode 100644 index 00000000..b96e72af --- /dev/null +++ b/wool/tests/runtime/context/test_exceptions.py @@ -0,0 +1,457 @@ +import pickle +import warnings +from uuid import uuid4 + +import pytest +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +import wool +from wool.runtime.context.exceptions import ChainContention +from wool.runtime.context.exceptions import ChainSerializationError +from wool.runtime.context.exceptions import ContextVarCollision +from wool.runtime.context.exceptions import SerializationError +from wool.runtime.context.exceptions import SerializationWarning +from wool.runtime.context.exceptions import TaskFactoryDisplaced + + +class TestSerializationError: + def test___init___should_default_structured_fields_to_none_when_message_only(self): + """Test SerializationError construction with only a message. + + Given: + The SerializationError class. + When: + An instance is constructed with only a message argument. + Then: + It should expose None for cause and value_repr, and carry + the message on args. + """ + # Act + error = SerializationError("encode failed") + + # Assert + assert error.cause is None + assert error.value_repr is None + assert error.args == ("encode failed",) + + def test___init___should_expose_structured_fields_when_provided(self): + """Test SerializationError construction with structured fields. 
+ + Given: + An underlying cause exception and a value repr preview. + When: + An instance is constructed with the cause and value_repr + keyword fields alongside a message. + Then: + It should expose both structured fields unchanged and + carry the message on args. + """ + # Arrange + cause = TypeError("cannot pickle lock") + + # Act + error = SerializationError( + "encode failed", + cause=cause, + value_repr="", + ) + + # Assert + assert error.cause is cause + assert error.value_repr == "" + assert "encode failed" in error.args + + def test_catchability_should_catch_as_wool_error(self): + """Test SerializationError is catchable as a WoolError. + + Given: + A SerializationError instance. + When: + It is raised inside a try block with an except clause for + wool.WoolError. + Then: + It should be caught by that clause and be the raised + instance. + """ + # Arrange + error = SerializationError("boom") + + # Act + try: + raise error + except wool.WoolError as caught_error: + caught = caught_error + + # Assert + assert caught is error + + +class TestChainSerializationError: + def test_warnings_should_keep_warnings_and_summarize_message(self): + """Test ChainSerializationError keeps warnings and summarizes them. + + Given: + Two SerializationWarning instances interleaved with a + plain string argument. + When: + A ChainSerializationError is constructed from them. + Then: + It should keep exactly the two warnings on the warnings + attribute and carry a synthesized summary message on args + instead of the raw arguments tuple. 
+ """ + # Arrange + warning_a = SerializationWarning("first failure") + warning_b = SerializationWarning("second failure") + + # Act + error = ChainSerializationError(warning_a, "plain", warning_b) + + # Assert + assert error.warnings == (warning_a, warning_b) + assert error.args == (str(error),) + assert "plain" not in str(error) + + def test_warnings_should_return_empty_tuple_when_no_warning_args(self): + """Test ChainSerializationError.warnings with no warning args. + + Given: + A ChainSerializationError constructed without any + SerializationWarning arguments. + When: + Its warnings property is read. + Then: + It should return an empty tuple. + """ + # Act + error = ChainSerializationError("no warnings here") + + # Assert + assert error.warnings == () + + def test_catchability_should_catch_as_serialization_error(self): + """Test ChainSerializationError is catchable as SerializationError. + + Given: + A ChainSerializationError instance. + When: + It is raised inside a try block with an except clause for + wool.SerializationError. + Then: + It should be caught by that clause and be the raised + instance. + """ + # Arrange + error = ChainSerializationError(SerializationWarning("promoted")) + + # Act + try: + raise error + except wool.SerializationError as caught_error: + caught = caught_error + + # Assert + assert caught is error + + def test_roots_should_subclass_wool_error_not_runtime_error(self): + """Test ChainSerializationError subclasses WoolError, not RuntimeError. + + Given: + The ChainSerializationError aggregator class. + When: + Its position in the exception hierarchy is checked. + Then: + It should subclass wool.WoolError (catchable via + wool.SerializationError) and must not subclass RuntimeError, + so a documented ``except RuntimeError`` would not catch it. 
+ """ + # Arrange, act, & assert + assert issubclass(ChainSerializationError, wool.WoolError) + assert issubclass(ChainSerializationError, wool.SerializationError) + assert not issubclass(ChainSerializationError, RuntimeError) + + +class TestSerializationWarning: + def test___init___should_default_structured_fields_to_none_when_message_only(self): + """Test SerializationWarning construction with only a message. + + Given: + The SerializationWarning class. + When: + An instance is constructed with only a message argument. + Then: + It should expose None for every structured field — cause, + var_key, direction, original_type — and carry the + message on args. + """ + # Act + warning = SerializationWarning("something failed to serialize") + + # Assert + assert warning.cause is None + assert warning.var_key is None + assert warning.direction is None + assert warning.original_type is None + assert warning.args == ("something failed to serialize",) + + def test___init___should_expose_structured_fields_when_provided(self): + """Test SerializationWarning construction with structured fields. + + Given: + A cause exception, a (namespace, name) variable key, and a + direction literal. + When: + An instance is constructed with all keyword fields — + cause, var_key, direction, and original_type. + Then: + It should expose each structured field unchanged as a + public attribute. + """ + # Arrange + cause = TypeError("cannot pickle lock") + var_key = ("test_ns", "test_var") + + # Act + warning = SerializationWarning( + "value failed to encode", + cause=cause, + var_key=var_key, + direction="encode", + original_type=ValueError, + ) + + # Assert + assert warning.cause is cause + assert warning.var_key == var_key + assert warning.direction == "encode" + assert warning.original_type is ValueError + + def test_serialization_warning_should_subclass_wool_warning(self): + """Test SerializationWarning is a WoolWarning subclass. + + Given: + The SerializationWarning class. 
+ When: + Its subclass relationship to WoolWarning is checked. + Then: + It should be a subclass of WoolWarning — so callers can + promote every Wool warning category to an error with a + single filter on the umbrella. + """ + # Arrange, act, & assert + assert issubclass(SerializationWarning, wool.WoolWarning) + + def test_serialization_warning_should_be_re_exported_from_wool(self): + """Test SerializationWarning is re-exported on the wool package. + + Given: + The wool package and the SerializationWarning class. + When: + wool.SerializationWarning is accessed. + Then: + It should be the same class as the one defined in + wool.runtime.context.exceptions. + """ + # Arrange, act, & assert + assert wool.SerializationWarning is SerializationWarning + + def test_promotion_should_raise_warning_via_umbrella_filter(self): + """Test the WoolWarning umbrella filter promotes the warning. + + Given: + A warnings filter promoting wool.WoolWarning to an error. + When: + A SerializationWarning is emitted via warnings.warn. + Then: + It should raise the emitted SerializationWarning instance + as an exception — the umbrella's single-filter strict-mode + recipe works through the base class, not just the leaf. + """ + # Arrange + emitted = SerializationWarning("bad var", direction="encode") + + # Act & assert + with warnings.catch_warnings(): + warnings.filterwarnings("error", category=wool.WoolWarning) + with pytest.raises(SerializationWarning) as exc_info: + warnings.warn(emitted, stacklevel=2) + assert exc_info.value is emitted + + @given( + message=st.text(), + var_key=st.one_of(st.none(), st.tuples(st.text(), st.text())), + direction=st.sampled_from(["encode", "decode", None]), + ) + @settings(max_examples=100, deadline=None) + def test_serializer_roundtrip_should_preserve_structured_fields( + self, message, var_key, direction + ): + """Test the serializer round-trips a warning's structured fields. 
+ + Given: + Any combination of message text, optional (namespace, + name) variable key, and optional direction literal. + When: + The warning is serialized and deserialized through + wool.__serializer__. + Then: + The restored instance should preserve args, var_key, and + direction. + """ + # Arrange + warning = SerializationWarning( + message, + var_key=var_key, + direction=direction, + ) + + # Act + restored = wool.__serializer__.loads(wool.__serializer__.dumps(warning)) + + # Assert + assert restored.args == warning.args + assert restored.var_key == var_key + assert restored.direction == direction + + +class TestChainContention: + def test_chain_contention_should_subclass_wool_error(self): + """Test ChainContention is a WoolError subclass. + + Given: + The wool.ChainContention exception class. + When: + Its subclass relationship to WoolError is checked. + Then: + It should be a subclass of WoolError — catchable under the + single Wool-domain umbrella. + """ + # Arrange, act, & assert + assert issubclass(ChainContention, wool.WoolError) + + def test_unknown_kind_should_raise_value_error(self): + """Test ChainContention rejects unknown ``kind`` values. + + Given: + A ``kind`` value that is not one of the recognised + literals (``"thread"``, ``"task"``, ``"create_task"``). + When: + ChainContention is constructed with that kind. + Then: + It should raise ValueError naming the unknown kind and + listing the recognised set — this guards dynamic call + sites (notably ``__reduce__``-driven cross-process + reconstruction where a forward-compat receiver might + decode a ``kind`` value it does not know) so the + exception fails loud at construction rather than later + with a KeyError during message formatting. 
+ """ + # Arrange, act, & assert + with pytest.raises(ValueError, match="unknown ChainContention kind"): + ChainContention(chain_id=uuid4(), kind="bogus") # type: ignore[arg-type] + + def test_chain_contention_should_round_trip_through_pickle(self): + """Test ChainContention survives a pickle round-trip. + + Given: + A ChainContention raised with structured kwargs. + When: + The exception is pickled and unpickled. + Then: + The restored instance should carry the same chain id, kind, + and identity fields, and reconstruct an equivalent message. + """ + # Arrange + chain_id = uuid4() + exc = ChainContention( + chain_id=chain_id, + kind="thread", + owning_thread=12345, + current_thread=67890, + ) + + # Act + restored = pickle.loads(pickle.dumps(exc)) + + # Assert + assert isinstance(restored, ChainContention) + assert restored.chain_id == chain_id + assert restored.kind == "thread" + assert restored.owning_thread == 12345 + assert restored.current_thread == 67890 + assert str(restored) == str(exc) + + def test_chain_contention_should_interpolate_chain_id_when_create_task_kind(self): + """Test the create_task kind interpolates the chain id. + + Given: + A ChainContention with kind="create_task" raised by the task + factory when an armed context is re-passed to create_task. + When: + The exception's string form is inspected. + Then: + The message should mention create_task and interpolate the + chain id, and the exception's structured fields should + reflect the create_task kind. + """ + # Arrange + chain_id = uuid4() + + # Act + exc = ChainContention(chain_id=chain_id, kind="create_task") + + # Assert + assert exc.kind == "create_task" + assert exc.chain_id == chain_id + assert "create_task" in str(exc) + assert str(chain_id) in str(exc) + + +class TestContextVarCollision: + def test_context_var_collision_should_subclass_wool_error(self): + """Test ContextVarCollision is a WoolError subclass. + + Given: + The wool.ContextVarCollision exception class. 
+ When: + Its subclass relationship to WoolError is checked. + Then: + It should be a subclass of WoolError — catchable under the + single Wool-domain umbrella. + """ + # Arrange, act, & assert + assert issubclass(ContextVarCollision, wool.WoolError) + + +class TestTaskFactoryDisplaced: + def test_task_factory_displaced_should_subclass_wool_error(self): + """Test TaskFactoryDisplaced is a WoolError subclass. + + Given: + The TaskFactoryDisplaced exception class. + When: + Its subclass relationship to WoolError is checked. + Then: + It should be a subclass of WoolError — displacement is an + unconditional structural failure, not a tunable warning; + this lets callers catch it under the single Wool-domain + umbrella. + """ + # Arrange, act, & assert + assert issubclass(TaskFactoryDisplaced, wool.WoolError) + + def test_task_factory_displaced_should_be_re_exported_from_wool(self): + """Test TaskFactoryDisplaced is re-exported on the wool package. + + Given: + The wool package and the TaskFactoryDisplaced class. + When: + wool.TaskFactoryDisplaced is accessed. + Then: + It should be the same class as the one defined in + wool.runtime.context.exceptions. 
+ """ + # Arrange, act, & assert + assert wool.TaskFactoryDisplaced is TaskFactoryDisplaced diff --git a/wool/tests/runtime/context/test_factory.py b/wool/tests/runtime/context/test_factory.py new file mode 100644 index 00000000..9ec99c2b --- /dev/null +++ b/wool/tests/runtime/context/test_factory.py @@ -0,0 +1,1233 @@ +import asyncio +import contextvars +import gc +import logging +import uuid + +import pytest +import pytest_asyncio +from hypothesis import HealthCheck +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +import wool +from tests.helpers import scoped_context +from wool.runtime.context.exceptions import ChainContention +from wool.runtime.context.exceptions import TaskFactoryDisplaced +from wool.runtime.context.factory import context_is_armed +from wool.runtime.context.factory import install_task_factory +from wool.runtime.context.threading import to_thread +from wool.runtime.context.var import ContextVar + + +def _unique(stem: str) -> str: + """Return a process-unique variable name to avoid registry collisions.""" + return f"{stem}_{uuid.uuid4().hex}" + + +@pytest_asyncio.fixture(autouse=True) +async def _reset_task_factory(): + """Restore the running loop's task factory after each test. + + Tests here install Wool's task factory on the event loop; clearing + it on teardown keeps a failed assertion mid-test from leaving a + factory installed for any later test that shares the loop. + """ + yield + asyncio.get_running_loop().set_task_factory(None) + + +@pytest.mark.asyncio +async def test_install_task_factory_should_install_factory_when_none_exists(): + """Test install_task_factory installs a factory on the running loop. + + Given: + A running loop with no task factory installed. + When: + install_task_factory is called with no loop argument. + Then: + It should install a non-None task factory on the running loop. 
+ """ + # Arrange + loop = asyncio.get_running_loop() + loop.set_task_factory(None) + + # Act + install_task_factory() + + # Assert + assert loop.get_task_factory() is not None + + +@pytest.mark.asyncio +async def test_install_task_factory_should_not_double_wrap_when_already_installed(): + """Test install_task_factory does not double-wrap an installed factory. + + Given: + A running loop where install_task_factory has already run. + When: + install_task_factory is called a second time. + Then: + It should leave the task factory object unchanged. + """ + # Arrange + loop = asyncio.get_running_loop() + loop.set_task_factory(None) + install_task_factory() + first = loop.get_task_factory() + + # Act + install_task_factory() + + # Assert + assert loop.get_task_factory() is first + + +def test_install_task_factory_should_raise_runtime_error_when_outside_running_loop(): + """Test install_task_factory(loop=None) outside a loop raises clearly. + + Given: + No running event loop in the calling scope. + When: + install_task_factory is called with no loop argument. + Then: + It should raise a RuntimeError whose message names + install_task_factory and directs the caller to either run + inside a running loop or pass ``loop=`` explicitly. + """ + # Arrange, act & assert + with pytest.raises(RuntimeError, match="install_task_factory"): + install_task_factory() + + +@pytest.mark.asyncio +async def test_first_set_should_self_install_the_task_factory(): + """Test a bare ContextVar.set self-installs the task factory. + + Given: + A running loop with no task factory installed and no + explicit install_task_factory call. + When: + A wool.ContextVar is set for the first time, arming the + chain, then a child task is created with + asyncio.create_task. + Then: + The child should fork onto a chain id distinct from the + parent's — the first set self-installed the factory, so + copy-on-fork works without an explicit install_task_factory + call. 
+ """ + # Arrange + loop = asyncio.get_running_loop() + loop.set_task_factory(None) + var = ContextVar(_unique("self_install")) + + async def child() -> uuid.UUID: + context = wool.__chain__.get(None) + assert context is not None + return context.id + + # Act + var.set("x") # No explicit install_task_factory() call. + parent_context = wool.__chain__.get(None) + assert parent_context is not None + child_chain = await asyncio.create_task(child()) + + # Assert + assert child_chain != parent_context.id + + +@pytest.mark.asyncio +@given( + parent_value=st.one_of(st.integers(), st.text(), st.lists(st.integers())), + child_value=st.one_of(st.integers(), st.text(), st.lists(st.integers())), +) +@settings( + max_examples=25, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +async def test_install_task_factory_should_inherit_then_isolate_state_on_fork( + parent_value, child_value +): + """Test a child task forks the parent chain, inheriting then isolating state. + + Given: + An armed parent chain with Wool's task factory installed + and a variable set to an arbitrary parent value. + When: + A child task created with asyncio.create_task reads the + inherited value, reads its own chain id, then mutates the + variable to an arbitrary child value. + Then: + The child should observe the parent's value, run on a chain + id distinct from the parent's, and its mutation should not + leak back — copy-on-fork inherits, forks, and isolates. 
+ """ + # Arrange + var = ContextVar(_unique("copy_on_fork")) + + async def child() -> tuple[object, uuid.UUID]: + inherited = var.get() + context = wool.__chain__.get(None) + assert context is not None + var.set(child_value) + return inherited, context.id + + install_task_factory() + var.set(parent_value) + parent_context = wool.__chain__.get(None) + assert parent_context is not None + + # Act + inherited, child_chain = await asyncio.create_task(child()) + + # Assert + assert inherited == parent_value + assert child_chain != parent_context.id + assert var.get() == parent_value + + +@pytest.mark.asyncio +async def test_install_task_factory_should_stay_unarmed_when_parent_unarmed(): + """Test a child task of an unarmed chain stays unarmed. + + Given: + An unarmed chain with Wool's task factory installed. + When: + A child task reads current_context. + Then: + It should observe None — an unarmed fork is a dormant + no-op. + """ + + async def child() -> object: + return wool.__chain__.get(None) + + install_task_factory() + + # Act + observed = await asyncio.create_task(child()) + + # Assert + assert observed is None + + +@pytest.mark.asyncio +async def test_install_task_factory_should_raise_when_child_resets_parent_token(): + """Test a child task cannot reset a Token minted in the parent chain. + + Given: + An armed parent chain whose set produced a Token. + When: + A child task attempts to reset that parent Token. + Then: + Stdlib's :meth:`contextvars.ContextVar.reset` raises + :class:`ValueError` naming the different + :class:`contextvars.Context` — the parent and the + copy-on-fork child run in distinct contexts. 
+ """ + # Arrange + var = ContextVar(_unique("child_token")) + install_task_factory() + token = var.set("x") + + async def child() -> None: + var.reset(token) + + # Act & assert + with pytest.raises(ValueError, match="different Context"): + await asyncio.create_task(child()) + + +@pytest.mark.asyncio +async def test_install_task_factory_should_wrap_an_existing_user_factory(): + """Test install_task_factory wraps an existing user factory. + + Given: + A running loop with a user-supplied task factory that + increments a counter. + When: + install_task_factory is called and a child task is created + from an armed parent. + Then: + It should fork the child onto a fresh chain UUID (copy-on- + fork still works) and increment the user factory's counter + (the original factory still participated). + """ + # Arrange + counter = [0] + + def user_factory( + loop: asyncio.AbstractEventLoop, + coro, + **kwargs, + ) -> asyncio.Task: + counter[0] += 1 + return asyncio.Task(coro, loop=loop, **kwargs) + + loop = asyncio.get_running_loop() + loop.set_task_factory(user_factory) + + var = ContextVar(_unique("compose")) + + async def child() -> uuid.UUID: + context = wool.__chain__.get(None) + assert context is not None + return context.id + + # Act + install_task_factory() + var.set("x") + parent_context = wool.__chain__.get(None) + assert parent_context is not None + child_chain = await asyncio.create_task(child()) + + # Assert + assert child_chain != parent_context.id + assert counter[0] == 1 + + +@pytest.mark.asyncio +async def test_install_task_factory_should_drop_fork_when_user_factory_installed_after(): + """Test that installing a user factory after Wool's drops copy-on-fork. + + Given: + Wool's task factory installed on the running loop and a + ContextVar armed in the parent chain. + When: + A plain pass-through user factory replaces Wool's factory + and a child task is created. 
+ Then: + It should observe the parent's chain UUID — confirming that + copy-on-fork is gone once Wool's factory is no longer last. + """ + # Arrange + var = ContextVar(_unique("install_order")) + + def user_factory( + loop: asyncio.AbstractEventLoop, + coro, + **kwargs, + ) -> asyncio.Task: + return asyncio.Task(coro, loop=loop, **kwargs) + + async def child() -> uuid.UUID | None: + context = wool.__chain__.get(None) + return context.id if context is not None else None + + install_task_factory() + var.set("x") + parent_context = wool.__chain__.get(None) + assert parent_context is not None + + # Act + loop = asyncio.get_running_loop() + loop.set_task_factory(user_factory) + child_chain = await asyncio.create_task(child()) + + # Assert + assert child_chain == parent_context.id + + +@pytest.mark.asyncio +async def test_install_task_factory_should_skip_when_wool_buried_under_third_party(): + """Test the idempotency check sees Wool buried under a third-party. + + Given: + A loop where Wool's factory was installed first and a + third-party factory was installed *over* it (a known + ordering hazard). + When: + install_task_factory is called again. + Then: + It should detect the buried Wool layer via the + ``__wool_inner__`` chain walk and skip the install rather + than wrap into a ``wool → third-party → wool`` composition. + """ + # Arrange — install Wool first, then a third-party factory + # over it that explicitly preserves the inner chain attribute. + loop = asyncio.get_running_loop() + loop.set_task_factory(None) + install_task_factory() + wool_factory = loop.get_task_factory() + + def third_party( + loop: asyncio.AbstractEventLoop, + coro, + **kwargs, + ) -> asyncio.Task: + return wool_factory(loop, coro, **kwargs) # pyright: ignore[reportCallIssue] + + # Expose the inner chain so the walk can find Wool buried below. 
+ third_party.__wool_inner__ = wool_factory # type: ignore[attr-defined] + loop.set_task_factory(third_party) + + # Act — calling install_task_factory again should be a no-op. + install_task_factory() + + # Assert — the third-party factory still sits on the loop, + # unchanged. A naive outer-only check would have re-wrapped. + assert loop.get_task_factory() is third_party + + +@pytest.mark.asyncio +async def test_install_task_factory_should_close_coroutine_when_inner_raises(recwarn): + """Test ``inner`` raising closes the wrapper and user coroutine. + + Given: + Wool's factory installed and an inner factory that raises + unconditionally; an armed parent so the child coroutine + is wrapped in ``_forked_scope``. + When: + asyncio.create_task is called and the inner factory raises. + Then: + The exception should propagate, no "coroutine was never + awaited" RuntimeWarning should be emitted at GC, and the + user coroutine should be closed. + """ + + # Arrange + def failing_inner( + loop: asyncio.AbstractEventLoop, + coro, + **kwargs, + ) -> asyncio.Task: + raise RuntimeError("inner refused the kwargs") + + loop = asyncio.get_running_loop() + loop.set_task_factory(failing_inner) + install_task_factory() + + var = ContextVar(_unique("inner_raises")) + var.set("armed") + + async def user_coro() -> None: # pragma: no cover — never awaited + return None + + coro = user_coro() + + # Act + try: + with pytest.raises(RuntimeError, match="inner refused"): + loop.create_task(coro) + + # Assert — collect to force any pending RuntimeWarning. + del coro + gc.collect() + leaks = [ + w + for w in recwarn.list + if issubclass(w.category, RuntimeWarning) + and "never awaited" in str(w.message) + ] + assert not leaks, f"unexpected coroutine-never-awaited warnings: {leaks}" + finally: + # Restore so teardown asyncgen shutdown can create tasks. 
+ loop.set_task_factory(None) + + +@pytest.mark.asyncio +async def test_install_task_factory_should_close_coroutine_when_inner_raises_unarmed( + recwarn, +): + """Test ``inner`` raising in the unarmed branch closes the user coroutine. + + Given: + Wool's factory installed over an inner factory that + raises unconditionally, and an unarmed parent so the + child coroutine is NOT wrapped in ``_forked_scope`` — + the unarmed branch of the factory body, separate from + the already-covered armed branch. + When: + asyncio.create_task is called from the unarmed parent + and the inner factory raises. + Then: + The exception should propagate, no "coroutine was never + awaited" RuntimeWarning should be emitted at GC, and the + user coroutine should be closed by the unarmed-branch + cleanup arm. + """ + + # Arrange + def failing_inner( + loop: asyncio.AbstractEventLoop, + coro, + **kwargs, + ) -> asyncio.Task: + raise RuntimeError("inner refused the kwargs") + + loop = asyncio.get_running_loop() + loop.set_task_factory(failing_inner) + install_task_factory() + + async def user_coro() -> None: # pragma: no cover — never awaited + return None + + coro = user_coro() + + # Act — no var.set, so the parent (and therefore the child) + # is unarmed. The factory takes the no-wrap branch. + try: + with pytest.raises(RuntimeError, match="inner refused"): + loop.create_task(coro) + + # Assert — collect to force any pending RuntimeWarning. + del coro + gc.collect() + leaks = [ + w + for w in recwarn.list + if issubclass(w.category, RuntimeWarning) + and "never awaited" in str(w.message) + ] + assert not leaks, f"unexpected coroutine-never-awaited warnings: {leaks}" + finally: + # Restore so teardown asyncgen shutdown can create tasks. + loop.set_task_factory(None) + + +@pytest.mark.asyncio +async def test_install_task_factory_should_close_coroutine_when_cancelled_before_step( + recwarn, +): + """Test an armed task cancelled before its first step leaks no warning. 
+ + Given: + An armed parent chain with Wool's task factory installed, + and a child coroutine wrapped by the factory's forked + scope. + When: + The child task is cancelled before the loop ever steps it, + then awaited to completion. + Then: + No "coroutine was never awaited" RuntimeWarning should + leak — the factory's release callback closes the un-stepped + wrapped coroutine. + """ + # Arrange + var = ContextVar(_unique("cancel_before_step")) + stepped = [False] + + async def child() -> None: + stepped[0] = True + + install_task_factory() + var.set("x") + + # Act — cancel before yielding control, so the task is never + # stepped, then await it so its done-callback runs. + child_task = asyncio.create_task(child()) + child_task.cancel() + with pytest.raises(asyncio.CancelledError): + await child_task + # Force a GC cycle so any un-awaited coroutine would surface. + gc.collect() + + # Assert + assert stepped[0] is False + assert not [ + w + for w in recwarn.list + if issubclass(w.category, RuntimeWarning) and "never awaited" in str(w.message) + ] + + +@pytest.mark.asyncio +async def test_install_task_factory_should_raise_when_context_shared_across_live_tasks(): + """Test the factory rejects one contextvars.Context shared by two tasks. + + Given: + An armed chain with Wool's task factory installed and a + live task created with an explicit contextvars.Context. + When: + A second task is created with that same context object + while the first is still running. + Then: + It should raise wool.ChainContention — two tasks + cannot interleave on one context's chain context. + """ + # Arrange + var = ContextVar(_unique("shared_ctx")) + release = asyncio.Event() + + async def body() -> None: + # Block on an event the test controls so the first task is + # provably still live when the second is created. 
+ await release.wait() + + install_task_factory() + var.set("x") + loop = asyncio.get_running_loop() + shared = contextvars.copy_context() + first = loop.create_task(body(), context=shared) + + # Act & assert + try: + with pytest.raises(ChainContention): + loop.create_task(body(), context=shared) + finally: + release.set() + await first + + +@pytest.mark.asyncio +async def test_install_task_factory_should_not_raise_when_unarmed_context_shared(): + """Test the factory allows one unarmed context shared by two live tasks. + + Given: + Wool's task factory installed and a live task created with + an explicit, unarmed contextvars.Context. + When: + A second task is created with that same context object while + the first is still running. + Then: + It should not raise — an unarmed context carries no chain to + corrupt, so sharing it across concurrently-live tasks is + permitted exactly as stdlib asyncio permits it. + """ + # Arrange + release = asyncio.Event() + + async def body() -> None: + await release.wait() + + install_task_factory() + loop = asyncio.get_running_loop() + shared = contextvars.copy_context() + first = loop.create_task(body(), context=shared) + + # Act + try: + second = loop.create_task(body(), context=shared) + finally: + release.set() + + # Assert + await asyncio.gather(first, second) + + +@pytest.mark.asyncio +async def test_install_task_factory_should_raise_when_shared_context_armed_late(): + """Test arming a shared unarmed context trips the owning-task guard. + + Given: + Wool's task factory installed and one unarmed + contextvars.Context shared by two concurrently-live tasks, + one of which arms it with a wool.ContextVar.set. + When: + The other task touches a wool.ContextVar on that now-armed + chain. 
+ Then: + It should raise wool.ChainContention — the owning-task + guard catches the second task entering a chain another live + task owns, the case the creation-time rejection cannot see + because the context was unarmed when both tasks were created. + """ + # Arrange + var = ContextVar(_unique("late_arm")) + armed = asyncio.Event() + release = asyncio.Event() + + async def arming_task() -> None: + var.set("owner") + armed.set() + await release.wait() + + async def entering_task() -> str: + await armed.wait() + try: + return var.get("fallback") + finally: + release.set() + + install_task_factory() + loop = asyncio.get_running_loop() + shared = contextvars.copy_context() + owner = loop.create_task(arming_task(), context=shared) + intruder = loop.create_task(entering_task(), context=shared) + + # Act & assert + with pytest.raises(ChainContention): + await intruder + await owner + + +@pytest.mark.asyncio +async def test_install_task_factory_should_not_raise_when_context_reused_after_task(): + """Test a contextvars.Context is reusable once its task has finished. + + Given: + An armed chain with Wool's task factory installed and a + task created with an explicit contextvars.Context that has + run to completion. + When: + A new task is created with that same context object. + Then: + It should not raise — the context is free once no live + task holds it. + """ + # Arrange + var = ContextVar(_unique("reuse_ctx")) + + async def body() -> None: + await asyncio.sleep(0) + + install_task_factory() + var.set("x") + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + + # Act & assert + await loop.create_task(body(), context=ctx) + await loop.create_task(body(), context=ctx) + + +@pytest.mark.asyncio +async def test_install_task_factory_should_not_raise_when_tasks_use_default_contexts(): + """Test concurrent tasks with default per-task contexts are not rejected. + + Given: + An armed chain with Wool's task factory installed. 
+ When: + Two child tasks are created concurrently without an + explicit context argument. + Then: + It should not raise — the guard fires only on a shared + explicit context, and asyncio copies the context per task. + """ + # Arrange + var = ContextVar(_unique("default_ctx")) + + async def body() -> str: + await asyncio.sleep(0) + return var.get() + + install_task_factory() + var.set("x") + + # Act + results = await asyncio.gather( + asyncio.create_task(body()), + asyncio.create_task(body()), + ) + + # Assert + assert results == ["x", "x"] + + +@pytest.mark.asyncio +async def test_install_task_factory_should_raise_when_displaced_by_later_factory(): + """Test a displaced Wool factory raises TaskFactoryDisplaced. + + Given: + A running loop where Wool's task factory was self-installed + by a first wool.ContextVar.set, then displaced by a + third-party task factory installed after it. + When: + Wool's self-install path runs again — triggered by a first + wool.ContextVar.set in a fresh unarmed context. + Then: + It should raise TaskFactoryDisplaced — copy-on-fork is + silently lost for tasks created since the displacement, so + the displacement is surfaced rather than passing unnoticed. + """ + # Arrange + loop = asyncio.get_running_loop() + + def user_factory( + loop: asyncio.AbstractEventLoop, + coro, + **kwargs, + ) -> asyncio.Task: + return asyncio.Task(coro, loop=loop, **kwargs) + + # A first set self-installs Wool's factory and records the loop. + with scoped_context(): + ContextVar(_unique("displace_seed")).set("x") + # A third-party factory installed afterwards displaces Wool's. + loop.set_task_factory(user_factory) + + # Act & assert — a later first-set re-enters the self-install + # path, which finds the loop recorded but the factory no longer + # Wool-wrapped. 
+ with scoped_context(): + with pytest.raises(TaskFactoryDisplaced, match="displaced"): + ContextVar(_unique("displace_trigger")).set("y") + + +@pytest.mark.asyncio +async def test_install_task_factory_should_monitor_displacement_when_direct_install(): + """Test a direct install_task_factory call registers the loop for displacement. + + Given: + A running loop where Wool's task factory was installed by a + direct install_task_factory() call — not a self-install from + wool.ContextVar.set — then displaced by a third-party factory + installed after it. + When: + Wool's self-install path runs again, triggered by a first + wool.ContextVar.set in a fresh unarmed context. + Then: + It should raise TaskFactoryDisplaced — a direct install + also records the loop, so direct-install (e.g. worker) + loops are displacement-monitored like self-installed ones. + """ + # Arrange + loop = asyncio.get_running_loop() + + def user_factory( + loop: asyncio.AbstractEventLoop, + coro, + **kwargs, + ) -> asyncio.Task: + return asyncio.Task(coro, loop=loop, **kwargs) + + # A direct install records the loop for displacement monitoring. + install_task_factory(loop) + # A third-party factory installed afterwards displaces Wool's. + loop.set_task_factory(user_factory) + + # Act & assert — a first-set re-enters the self-install path, + # which finds the loop recorded but the factory not Wool-wrapped. + with scoped_context(): + with pytest.raises(TaskFactoryDisplaced, match="displaced"): + ContextVar(_unique("direct_displace_trigger")).set("x") + + +def test_install_task_factory_should_not_warn_when_finalized_on_idle_loop(caplog): + """Test the finalizer treats a non-running loop as teardown. + + Given: + Wool's task factory installed on an event loop that is never + started. + When: + The factory is replaced and dropped so it is garbage-collected + while the loop is not running. 
+ Then: + It should not warn that the factory was displaced — a finalizer + firing on a non-running loop is normal teardown, not the + displacement case that warns and poisons the loop. + """ + # Arrange + loop = asyncio.new_event_loop() + try: + install_task_factory(loop) + + # Act — drop Wool's factory so it is collected and its finalizer + # fires while the loop is not running. + with caplog.at_level(logging.WARNING, logger="wool.runtime.context.factory"): + loop.set_task_factory(None) + gc.collect() + + # Assert — no displacement warning was emitted for the teardown. + assert not [r for r in caplog.records if "displaced" in r.getMessage().lower()] + finally: + loop.close() + + +@pytest.mark.asyncio +async def test_install_task_factory_should_raise_when_displacer_keeps_factory_alive(): + """Test displacement is caught at the next set when the finalizer cannot fire. + + Given: + Wool's task factory installed, then displaced by a third-party + factory that keeps the Wool factory object alive — so its + finalizer never fires and the loop is not pre-flagged as + displaced. + When: + A wool.ContextVar value is next set. + Then: + It should raise TaskFactoryDisplaced — the set-time factory + inspection catches the displacement the finalizer missed. + """ + # Arrange + loop = asyncio.get_running_loop() + install_task_factory(loop) + stashed = loop.get_task_factory() # strong ref so the finalizer cannot fire + + def user_factory(loop, coro, **kwargs) -> asyncio.Task: + return asyncio.Task(coro, loop=loop, **kwargs) + + loop.set_task_factory(user_factory) + + # Act & assert + with scoped_context(): + with pytest.raises(TaskFactoryDisplaced, match="displaced"): + ContextVar(_unique("kept_alive_trigger")).set("x") + assert stashed is not None # keep the factory alive through the assert + + +@pytest.mark.asyncio +async def test_install_task_factory_should_raise_when_completion_sees_displacement(): + """Test a completing Wool task detects displacement as a backstop. 
+ + Given: + Wool's task factory installed with one Wool task in flight, + then displaced by a third-party factory whose installation + keeps the Wool factory object alive (so its finalizer never + fires and the loop is not pre-flagged as displaced). + When: + The in-flight task completes. + Then: + A subsequent wool.ContextVar set should raise + TaskFactoryDisplaced — the completing task is the backstop that + notices the loop's factory is no longer Wool's. + """ + # Arrange + loop = asyncio.get_running_loop() + install_task_factory(loop) + stashed = loop.get_task_factory() # keep Wool's factory alive + + def user_factory(loop, coro, **kwargs) -> asyncio.Task: + return asyncio.Task(coro, loop=loop, **kwargs) + + release = asyncio.Event() + + async def in_flight() -> None: + await release.wait() + + # A Wool-created task in flight; its done-callback is the backstop. + # Left unarmed so the task body touches no wool.ContextVar and can + # complete cleanly after the displacement, letting the done-callback + # observe it rather than tripping the set-time check mid-task. + task = loop.create_task(in_flight()) + + # Act — displace Wool while the task is in flight, then let it finish + # so its done-callback runs the displacement backstop. + loop.set_task_factory(user_factory) + release.set() + await task + await asyncio.sleep(0) + + # Assert + with scoped_context(): + with pytest.raises(TaskFactoryDisplaced, match="displaced"): + ContextVar(_unique("backstop_trigger")).set("y") + assert stashed is not None # keep the factory alive through the assert + + +@pytest.mark.asyncio +async def test_install_task_factory_should_detect_buried_wool_via_cached_layer(): + """Test the idempotency walk short-circuits on an already-cached layer. + + Given: + Wool's factory buried beneath nested third-party wrappers, + after Wool has already been detected once on the loop (so the + inner wrapper is already memoized). 
+ When: + install_task_factory is called again with a further wrapper + stacked on top. + Then: + Wool should still be detected through the wrappers and the call + should be idempotent — the loop's factory is left unchanged. + """ + # Arrange — install Wool, bury it under a third party, and prime the + # detection cache by detecting that buried layer once. + loop = asyncio.get_running_loop() + loop.set_task_factory(None) + install_task_factory() + wool_factory = loop.get_task_factory() + + def third_party(loop, coro, **kwargs) -> asyncio.Task: + return wool_factory(loop, coro, **kwargs) # pyright: ignore[reportCallIssue] + + third_party.__wool_inner__ = wool_factory # type: ignore[attr-defined] + loop.set_task_factory(third_party) + install_task_factory() # memoizes the third-party layer + + def outer(loop, coro, **kwargs) -> asyncio.Task: + return third_party(loop, coro, **kwargs) + + outer.__wool_inner__ = third_party # type: ignore[attr-defined] + loop.set_task_factory(outer) + + # Act — the walk now hits the cached third-party layer mid-chain. + install_task_factory() + + # Assert — Wool detected via the cache; the install was skipped. + assert loop.get_task_factory() is outer + + +@pytest.mark.asyncio +async def test_install_task_factory_should_clear_reservation_when_armed_inner_raises( + recwarn, +): + """Test a failed armed task creation releases its context reservation. + + Given: + Wool's factory composed over an inner factory that raises, and + an armed Wool context passed explicitly to task creation. + When: + A task is created with that context and the inner factory + raises. + Then: + The error should propagate, no "coroutine was never awaited" + warning should be emitted, and the same context should remain + reusable for a later task — its pending reservation was cleared + rather than pinned. 
+ """ + + # Arrange + def failing_inner(loop, coro, **kwargs) -> asyncio.Task: + raise RuntimeError("inner refused the kwargs") + + loop = asyncio.get_running_loop() + loop.set_task_factory(failing_inner) + install_task_factory() + + var = ContextVar(_unique("armed_reservation")) + var.set("armed") + armed_ctx = contextvars.copy_context() + + async def user_coro() -> None: # pragma: no cover — never awaited + return None + + first, second = user_coro(), user_coro() + + # Act & assert + try: + with pytest.raises(RuntimeError, match="inner refused"): + loop.create_task(first, context=armed_ctx) + # The reservation must be released: a second creation with the + # same context reaches the (failing) inner factory again rather + # than tripping the contention guard on a pinned slot. + with pytest.raises(RuntimeError, match="inner refused"): + loop.create_task(second, context=armed_ctx) + + del first, second + gc.collect() + leaks = [ + w + for w in recwarn.list + if issubclass(w.category, RuntimeWarning) + and "never awaited" in str(w.message) + ] + assert not leaks, f"unexpected coroutine-never-awaited warnings: {leaks}" + finally: + loop.set_task_factory(None) + + +@pytest.mark.asyncio +async def test_to_thread_should_return_result_when_positional_args(): + """Test wool.to_thread runs the callable and returns its result. + + Given: + A blocking callable returning a value. + When: + wool.to_thread offloads it. + Then: + It should return the callable's result. + """ + + # Arrange + def work(a: int, b: int) -> int: + return a + b + + # Act + result = await to_thread(work, 2, 3) + + # Assert + assert result == 5 + + +@pytest.mark.asyncio +async def test_to_thread_should_forward_keyword_args(): + """Test wool.to_thread forwards keyword arguments to the callable. + + Given: + A callable accepting a keyword argument. + When: + wool.to_thread is called with that keyword argument. + Then: + It should forward the keyword argument correctly. 
+ """ + + # Arrange + def work(*, label: str) -> str: + return label.upper() + + # Act + result = await to_thread(work, label="hi") + + # Assert + assert result == "HI" + + +@pytest.mark.asyncio +async def test_to_thread_should_carry_caller_value_into_thread_when_armed_context(): + """Test wool.to_thread carries the caller's ContextVar value into the thread. + + Given: + An armed chain with a ContextVar set. + When: + wool.to_thread offloads a function that reads the variable. + Then: + It should observe the caller's value. + """ + # Arrange + var = ContextVar(_unique("to_thread_value")) + + def read() -> str: + return var.get() + + var.set("carried") + + # Act + observed = await to_thread(read) + + # Assert + assert observed == "carried" + + +@pytest.mark.asyncio +async def test_to_thread_should_not_trip_contention_guard_when_armed_context(): + """Test wool.to_thread does not trip the chain-contention guard. + + Given: + An armed chain. + When: + wool.to_thread offloads a function that touches a + ContextVar. + Then: + It should not raise wool.ChainContention. + """ + # Arrange + var = ContextVar(_unique("to_thread_guard")) + + def touch() -> str: + return var.get() + + var.set("ok") + + # Act + result = await to_thread(touch) + + # Assert + assert result == "ok" + + +@pytest.mark.asyncio +async def test_to_thread_should_not_arm_when_unarmed_context(): + """Test wool.to_thread on an unarmed chain offloads without arming. + + Given: + An unarmed chain. + When: + wool.to_thread offloads a function that reads + current_context. + Then: + It should observe None. + """ + + def read() -> object: + return wool.__chain__.get(None) + + # Act + observed = await to_thread(read) + + # Assert + assert observed is None + + +@pytest.mark.asyncio +async def test_to_thread_should_run_on_fresh_chain_when_armed_parent_chain(): + """Test wool.to_thread runs the offloaded function on a fresh chain. + + Given: + An armed chain whose chain id is known. 
+ When: + wool.to_thread offloads a function reading + wool.__chain__.get(None). + Then: + It should differ from the caller's chain id. + """ + # Arrange + var = ContextVar(_unique("to_thread_chain")) + + def read_chain() -> uuid.UUID: + context = wool.__chain__.get(None) + assert context is not None + return context.id + + var.set("x") + caller_context = wool.__chain__.get(None) + assert caller_context is not None + + # Act + offloaded_chain = await to_thread(read_chain) + + # Assert + assert offloaded_chain != caller_context.id + + +@pytest.mark.asyncio +async def test_to_thread_should_not_propagate_mutation_from_worker_thread(): + """Test mutations made inside wool.to_thread do not propagate back. + + Given: + An armed chain with a ContextVar set. + When: + wool.to_thread offloads a function that resets the + variable. + Then: + It should leave the caller observing its own value — the + offloaded chain is detached, with no merge-back. + """ + # Arrange + var = ContextVar(_unique("to_thread_no_merge")) + + def mutate() -> None: + var.set("thread-value") + + var.set("caller-value") + + # Act + await to_thread(mutate) + + # Assert + assert var.get() == "caller-value" + + +def test_context_is_armed_should_return_true_when_armed_context(): + """Test context_is_armed returns True for a context carrying a context. + + Given: + A contextvars.Context in which a wool.ContextVar has been set. + When: + context_is_armed is called on it. + Then: + It should return True — the context carries a non-None Wool + context. + """ + # Arrange + var = ContextVar(_unique("armed_probe")) + + def _arm() -> None: + var.set("x") + + armed_context = contextvars.copy_context() + armed_context.run(_arm) + + # Act & assert + assert context_is_armed(armed_context) is True + + +def test_context_is_armed_should_return_false_when_unarmed_context(): + """Test context_is_armed returns False for a context with no context. 
+ + Given: + A fresh contextvars.Context in which no wool.ContextVar has + been set. + When: + context_is_armed is called on it. + Then: + It should return False — an unarmed context is + indistinguishable from a plain contextvars.Context. + """ + # Arrange + unarmed_context = contextvars.Context() + + # Act & assert + assert context_is_armed(unarmed_context) is False diff --git a/wool/tests/runtime/context/test_guard.py b/wool/tests/runtime/context/test_guard.py new file mode 100644 index 00000000..a107241f --- /dev/null +++ b/wool/tests/runtime/context/test_guard.py @@ -0,0 +1,386 @@ +import asyncio +import contextvars +import threading +import uuid + +import pytest + +import wool +from tests.helpers import scoped_context +from wool.runtime.context.exceptions import ChainContention +from wool.runtime.context.factory import install_task_factory +from wool.runtime.context.var import ContextVar + + +def _unique(stem: str) -> str: + """Return a process-unique variable name to avoid registry collisions.""" + return f"{stem}_{uuid.uuid4().hex}" + + +def test_get_should_not_raise_when_owner_task_is_pending_off_loop(): + """Test reading an armed var off any loop on the owning thread is a no-op. + + Given: + An armed context whose chain is owned by a still-pending asyncio + task, with its contextvars.Context captured while that task is + live, and the driving loop subsequently stopped. + When: + A wool.ContextVar is read inside that captured context from + synchronous code on the owning thread, with no running event + loop. + Then: + It should return the armed value without raising — with no + running loop there is no concurrent task to arbitrate against. + """ + # Arrange — arm the chain inside a task that stays pending so its + # owner reference is live (not done) when the off-loop read runs. + # A dedicated loop is driven to the arming point and then left idle + # (not closed) so the off-loop read runs on the owning thread with + # no loop running in its frame. 
+ var = ContextVar(_unique("pending_off_loop")) + captured: dict[str, object] = {} + loop = asyncio.new_event_loop() + + async def _arm_and_block(armed: asyncio.Event) -> None: + var.set("armed") + captured["context"] = contextvars.copy_context() + armed.set() + await asyncio.Future() # never completes — keep the owner live + + async def _drive() -> asyncio.Task[None]: + armed = asyncio.Event() + task = loop.create_task(_arm_and_block(armed)) + await armed.wait() + return task + + owning_task = loop.run_until_complete(_drive()) + context = captured["context"] + assert isinstance(context, contextvars.Context) + + # Act — the loop has stopped (run_until_complete returned) but the + # owner task is still pending; read synchronously inside the + # captured armed context with no running loop in this frame. + try: + result = context.run(var.get) + finally: + # Drain the pending task so its coroutine is closed and no + # "coroutine was never awaited" RuntimeWarning leaks. + owning_task.cancel() + with pytest.raises(asyncio.CancelledError): + loop.run_until_complete(owning_task) + loop.close() + + # Assert + assert result == "armed" + + +def test_get_should_not_raise_when_owner_task_is_done(): + """Test reading an armed var whose owner task has finished is a no-op. + + Given: + An armed context whose owner task has already run to completion, + with its contextvars.Context captured. + When: + A wool.ContextVar is read inside that captured context from + synchronous code on the owning thread. + Then: + It should return the armed value without raising — a finished + owner is no longer a live concurrent runner. + """ + # Arrange — drive an arming coroutine to completion so its owner + # task is done, then capture its armed context. 
+ var = ContextVar(_unique("done_owner")) + captured: dict[str, object] = {} + loop = asyncio.new_event_loop() + + async def _arm() -> None: + var.set("armed") + captured["context"] = contextvars.copy_context() + + try: + loop.run_until_complete(_arm()) + finally: + loop.close() + context = captured["context"] + assert isinstance(context, contextvars.Context) + + # Act — the owner task is now done; read synchronously inside the + # captured armed context on the owning thread. + result = context.run(var.get) + + # Assert + assert result == "armed" + + +class TestChainContention: + @pytest.mark.asyncio + async def test_get_should_raise_contention_when_plain_to_thread_from_armed_context( + self, + ): + """Test plain asyncio.to_thread touching a ContextVar trips the guard. + + Given: + An armed context whose chain is owned by the loop thread. + When: + asyncio.to_thread offloads a function that reads a + wool.ContextVar from a different OS thread. + Then: + It should raise wool.ChainContention directing the + caller to wool.to_thread. + """ + # Arrange + var = ContextVar(_unique("plain_to_thread")) + + def touch() -> str: + return var.get() + + # Act & assert + with scoped_context(): + var.set("armed") + with pytest.raises(ChainContention, match="wool.to_thread"): + await asyncio.to_thread(touch) + + @pytest.mark.asyncio + async def test_set_should_raise_contention_when_plain_to_thread_from_armed_context( + self, + ): + """Test plain asyncio.to_thread setting a ContextVar trips the guard. + + Given: + An armed context owned by the loop thread. + When: + asyncio.to_thread offloads a function that sets a + wool.ContextVar from a worker thread. + Then: + It should raise wool.ChainContention directing the + caller to wool.to_thread. 
+ """ + # Arrange + var = ContextVar(_unique("plain_to_thread_set")) + + def mutate() -> None: + var.set("from-thread") + + # Act & assert + with scoped_context(): + var.set("armed") + with pytest.raises(ChainContention, match="wool.to_thread"): + await asyncio.to_thread(mutate) + + @pytest.mark.asyncio + async def test_get_should_carry_structured_fields_when_cross_thread_contention(self): + """Test a cross-thread ChainContention carries structured identity. + + Given: + An armed context owned by the loop thread. + When: + asyncio.to_thread offloads a function that reads a + wool.ContextVar from a worker thread, tripping the guard. + Then: + The exception should expose kind="thread", the chain id, the + owning thread ident, and the offending worker-thread ident; + the message should interpolate the chain id and both thread + identities for diagnostics. + """ + # Arrange + var = ContextVar(_unique("thread_fields")) + observed: dict[str, int] = {} + + def touch() -> str: + observed["worker_thread"] = threading.get_ident() + return var.get() + + # Act + with scoped_context(): + var.set("armed") + owning_thread = threading.get_ident() + chain = wool.__chain__.get(None) + with pytest.raises(ChainContention) as excinfo: + await asyncio.to_thread(touch) + + # Assert + exc = excinfo.value + assert chain is not None + assert exc.kind == "thread" + assert exc.chain_id == chain.id + assert exc.owning_thread == owning_thread + assert exc.current_thread == observed["worker_thread"] + message = str(exc) + assert str(chain.id) in message + assert str(owning_thread) in message + assert str(observed["worker_thread"]) in message + + @pytest.mark.asyncio + async def test_get_should_carry_structured_fields_when_cross_task_contention(self): + """Test a cross-task ChainContention carries structured identity. 
+ + Given: + One unarmed contextvars.Context handed to two tasks; the + first arms it with a wool.ContextVar.set and the second + then touches a wool.ContextVar — the cross-task path the + task factory's copy-on-fork cannot catch. + When: + The second task touches the armed chain it does not own. + Then: + It should raise wool.ChainContention exposing kind="task", + the chain id, the owning (first) task, and the current + (second) task; the message should interpolate the chain id + and both task identities for diagnostics. + """ + # Arrange — share one unarmed context between two tasks. Sharing + # an unarmed context is permitted exactly as stdlib allows; the + # first task arms it, then the second task fails loud. + install_task_factory() + var = ContextVar(_unique("task_fields")) + shared = contextvars.copy_context() + loop = asyncio.get_running_loop() + armed = asyncio.Event() + first_can_finish = asyncio.Event() + identities: dict[str, object] = {} + + async def first() -> None: + var.set("armed-by-first") + identities["owning_task"] = asyncio.current_task() + identities["chain"] = wool.__chain__.get(None) + armed.set() + await first_can_finish.wait() + + async def second() -> BaseException | None: + await armed.wait() + identities["current_task"] = asyncio.current_task() + try: + var.get("fallback") + except ChainContention as exc: + return exc + finally: + first_can_finish.set() + return None + + # Act — both tasks are handed the SAME context object. 
+ first_task = loop.create_task(first(), context=shared) + second_task = loop.create_task(second(), context=shared) + observed, _ = await asyncio.gather(second_task, first_task) + + # Assert + assert isinstance(observed, ChainContention) + chain = identities["chain"] + assert chain is not None + assert observed.kind == "task" + assert observed.chain_id == chain.id + assert observed.owning_task is identities["owning_task"] + assert observed.current_task is identities["current_task"] + message = str(observed) + assert str(chain.id) in message + + @pytest.mark.asyncio + async def test_get_should_not_raise_when_plain_to_thread_from_unarmed_context(self): + """Test plain asyncio.to_thread touching a ContextVar in an unarmed context. + + Given: + An unarmed context — no chain, no guard. + When: + asyncio.to_thread offloads a function that reads a + wool.ContextVar with a default. + Then: + It should not raise — an unarmed context behaves as plain + contextvars. + """ + # Arrange + var = ContextVar(_unique("plain_unarmed"), default="d") + + def touch() -> str: + return var.get() + + # Act + with scoped_context(): + result = await asyncio.to_thread(touch) + + # Assert + assert result == "d" + + @pytest.mark.asyncio + async def test_get_should_not_raise_when_wool_to_thread_from_armed_context(self): + """Test wool.to_thread offloading is the supported alternative. + + Given: + An armed context. + When: + wool.to_thread offloads a function that reads a + wool.ContextVar. + Then: + It should not raise — wool.to_thread forks a fresh, + detached chain owned by the worker thread. 
+ """ + # Arrange + var = ContextVar(_unique("wool_to_thread_ok")) + + def touch() -> str: + return var.get() + + # Act + with scoped_context(): + var.set("armed") + result = await wool.to_thread(touch) + + # Assert + assert result == "armed" + + @pytest.mark.asyncio + async def test_get_should_not_raise_when_callback_on_owning_thread(self): + """Test event-loop callbacks on the owning thread never trip the guard. + + Given: + An armed context owned by the loop thread. + When: + A loop.call_soon callback reads a wool.ContextVar. + Then: + It should not raise — cooperative work on the owning thread + shares the chain but never runs in parallel. + """ + # Arrange + var = ContextVar(_unique("callback_owner")) + loop = asyncio.get_running_loop() + observed: asyncio.Future[str] = loop.create_future() + + def callback() -> None: + observed.set_result(var.get()) + + # Act + with scoped_context(): + var.set("armed") + loop.call_soon(callback) + result = await observed + + # Assert + assert result == "armed" + + @pytest.mark.asyncio + async def test_get_should_not_raise_when_timer_callback_on_owning_thread(self): + """Test event-loop timers on the owning thread never trip the guard. + + Given: + An armed context owned by the loop thread. + When: + A loop.call_later timer callback reads a wool.ContextVar. + Then: + It should not raise — the guard exemption covers timers as + well as immediate callbacks; a cooperatively-scheduled timer + on the owning thread shares the chain but never runs in + parallel with it. 
+ """ + # Arrange + var = ContextVar(_unique("timer_owner")) + loop = asyncio.get_running_loop() + observed: asyncio.Future[str] = loop.create_future() + + def callback() -> None: + observed.set_result(var.get()) + + # Act + with scoped_context(): + var.set("armed") + loop.call_later(0, callback) + result = await observed + + # Assert + assert result == "armed" diff --git a/wool/tests/runtime/context/test_manifest.py b/wool/tests/runtime/context/test_manifest.py new file mode 100644 index 00000000..047fba80 --- /dev/null +++ b/wool/tests/runtime/context/test_manifest.py @@ -0,0 +1,1170 @@ +"""Unit tests for the wire-abstraction types — ChainManifest (the decoded +chain snapshot) and ContextVarManifest (the per-variable identity layer).""" + +import asyncio +import contextvars +import threading +import uuid +import warnings +from uuid import UUID +from uuid import uuid4 + +import pytest +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +import wool +from tests.helpers import _unique +from tests.helpers import scoped_context +from wool import protocol +from wool.runtime.context.chain import Chain +from wool.runtime.context.exceptions import TaskFactoryDisplaced +from wool.runtime.context.manifest import ChainManifest +from wool.runtime.context.manifest import ContextVarManifest +from wool.runtime.context.manifest import resolve_stub +from wool.runtime.context.var import ContextVar +from wool.runtime.typing import Undefined + + +def _decode_manifest(wire: protocol.ChainManifest) -> ChainManifest: + """Decode a wire `protocol.ChainManifest` into a + `ChainManifest` using the default serializer. + + Test-helper shortcut over the public + `ChainManifest.from_protobuf` entry point — keeps the + decode/mount tests below readable. Raises ``ChainSerializationError`` + under strict mode, just as the production decode does. 
+ """ + return ChainManifest.from_protobuf(wire, serializer=wool.__serializer__) + + +def _mount_manifest(manifest: ChainManifest) -> None: + """Inline the receive-time mount pattern for tests. + + Mirrors `wool.runtime.worker.frame.Frame.mount`: routes through + `wool.runtime.context.chain.Chain.from_manifest` with the + current Chain as the merge target (None when unarmed, picking up the + fresh-install branch automatically). Used by the tests that exercise + the decode-then-mount round trip via the public API. + """ + if not (manifest.vars or manifest.resets): + return + Chain.from_manifest( + manifest, + owned=True, + merge_with=wool.__chain__.get(None), + ) + + +def _adopt_chain(chain_id: UUID) -> None: + """Arm or re-arm the current context with *chain_id*. + + Mirrors the worker-side receive-time invariant: a worker that arms + onto a caller's chain and then receives further chain-manifest frames + on the same chain. Tests that exercise the armed-receiver merge path + call this before `_mount_manifest` so the manifest's chain id + matches the receiver's. + """ + current = wool.__chain__.get(None) + if current is None: + wool.__chain__.set(Chain(id=chain_id, thread=threading.get_ident())) + else: + wool.__chain__.set(current._evolve(id=chain_id)) + + +class TestChainManifest: + def test_from_protobuf_should_round_trip_value(self): + """Test from_protobuf recovers an encoded variable binding. + + Given: + A chain manifest produced by encoding an armed chain + with one variable binding. + When: + from_protobuf is called on that chain manifest. + Then: + The decoded chain manifest should index the same variable and the + decoded values map should carry the same value. 
+ """ + # Arrange + var = ContextVar(_unique("decode_value")) + + with scoped_context(): + var.set("hello") + context = wool.__chain__.get(None) + assert context is not None + wire = context.to_manifest().to_protobuf() + + # Act + decoded = _decode_manifest(wire) + + # Assert + assert var in decoded.vars + assert decoded.vars[var] == "hello" + assert decoded.id == context.id + + def test_from_protobuf_should_record_reset_var_from_no_value_entry(self): + """Test from_protobuf reads a no-value wire entry into resets. + + Given: + A chain manifest entry that carries no value field. + When: + from_protobuf is called. + Then: + The decoded chain manifest's resets set should name that + variable's key. + """ + # Arrange + var = ContextVar(_unique("decode_reset_var")) + wire = protocol.ChainManifest(id=uuid4().hex) + wire.vars.add(namespace=var.namespace, name=var.name) + + # Act + decoded = _decode_manifest(wire) + + # Assert + assert (var.namespace, var.name) in decoded.resets + + def test_from_protobuf_should_produce_manifest_not_live_context(self): + """Test from_protobuf produces a `ChainManifest` and not a + live `Chain`. + + Given: + A chain manifest with a variable binding. + When: + `ChainManifest.from_protobuf` is called. + Then: + The returned object is a `ChainManifest` (unmounted + wire state, no owner stamping) — owner stamping is the + `Chain.mount` site and a freshly decoded manifest + carries no owner. + """ + # Arrange + var = ContextVar(_unique("decode_owner")) + with scoped_context(): + var.set(1) + wire = wool.__chain__.get().to_manifest().to_protobuf() + + # Act + decoded = _decode_manifest(wire) + + # Assert + assert isinstance(decoded, ChainManifest) + + def test_from_protobuf_should_raise_when_malformed_chain_id(self): + """Test from_protobuf raises on a malformed chain id. + + Given: + A chain manifest whose id is not a valid UUID hex string. + When: + from_protobuf is called. 
+ Then: + It should raise ChainSerializationError unconditionally — + chain-id parse failure is a structural protocol error + distinct from per-var data errors, so it is fatal regardless + of the strict-mode warning filter. A silently-replaced + ``uuid4()`` would route follow-up frames to a fresh cached + Chain (silent state loss), so the decode fails loud instead. + The aggregated warning carries the decode direction, no + var_key (the failure is structural, not per-variable), and + the underlying ValueError as cause — the same ValueError + chained on the error's __cause__. + """ + # Arrange + wire = protocol.ChainManifest(id="not-a-uuid") + + # Act & assert + with pytest.raises(wool.ChainSerializationError) as exc_info: + _decode_manifest(wire) + warning = exc_info.value.warnings[0] + assert warning.direction == "decode" + assert warning.var_key is None + assert isinstance(warning.cause, ValueError) + assert exc_info.value.__cause__ is warning.cause + + def test_from_protobuf_should_register_stub_when_undeclared_variable(self): + """Test from_protobuf pins a stub for an undeclared wire variable. + + Given: + A chain manifest referencing a variable key that was never + declared as a ContextVar in this process. + When: + from_protobuf is called. + Then: + The decoded chain manifest should pin a stub variable for that key. + """ + # Arrange + key = ("undeclared_ns", _unique("decode_stub")) + wire = protocol.ChainManifest(id=uuid4().hex) + wire.vars.add(namespace=key[0], name=key[1], value=wool.__serializer__.dumps(1)) + + # Act + decoded = _decode_manifest(wire) + + # Assert + pinned_keys = {(var.namespace, var.name) for var in decoded.stubs} + assert key in pinned_keys + + def test_from_protobuf_should_warn_and_skip_unserializable_value(self): + """Test from_protobuf warns and skips a value it cannot deserialize. + + Given: + A chain manifest entry whose value bytes are not valid pickle + data. + When: + from_protobuf is called. 
+ Then: + It should emit a SerializationWarning carrying the failed + variable's key, the decode direction, and the underlying + cause, and omit that variable from the decoded data. + """ + # Arrange + var = ContextVar(_unique("decode_bad_value")) + wire = protocol.ChainManifest(id=uuid4().hex) + wire.vars.add(namespace=var.namespace, name=var.name, value=b"not-pickle") + + # Act + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + decoded = _decode_manifest(wire) + + # Assert + assert var not in decoded.vars + emitted = [ + w for w in caught if issubclass(w.category, wool.SerializationWarning) + ] + assert emitted + warning = emitted[0].message + assert isinstance(warning, wool.SerializationWarning) + assert warning.var_key == (var.namespace, var.name) + assert warning.direction == "decode" + assert warning.cause is not None + + def test_from_protobuf_should_aggregate_failures_when_strict_mode(self): + """Test from_protobuf raises ChainSerializationError under strict mode. + + Given: + A chain manifest with two unparseable values (per-var data + errors), under strict SerializationWarning filtering. + Chain-id is intentionally well-formed here — a chain-id parse + failure is fatal *independently* of the strict-mode + aggregator, so a malformed chain id never feeds into the + per-var aggregator. + When: + from_protobuf is called. + Then: + It should raise a ChainSerializationError aggregating both + failures on .warnings, each carrying its variable's key, the + decode direction, and the underlying cause. 
+ """ + # Arrange + var_a = ContextVar(_unique("decode_strict_a")) + var_b = ContextVar(_unique("decode_strict_b")) + wire = protocol.ChainManifest(id=uuid4().hex) + wire.vars.add(namespace=var_a.namespace, name=var_a.name, value=b"bad") + wire.vars.add(namespace=var_b.namespace, name=var_b.name, value=b"bad") + + # Act & assert + with warnings.catch_warnings(): + warnings.filterwarnings("error", category=wool.SerializationWarning) + with pytest.raises(wool.ChainSerializationError) as exc_info: + _decode_manifest(wire) + assert len(exc_info.value.warnings) == 2 + assert all( + isinstance(e, wool.SerializationWarning) for e in exc_info.value.warnings + ) + assert {w.var_key for w in exc_info.value.warnings} == { + (var_a.namespace, var_a.name), + (var_b.namespace, var_b.name), + } + assert all(w.direction == "decode" for w in exc_info.value.warnings) + assert all(w.cause is not None for w in exc_info.value.warnings) + + def test_from_protobuf_should_populate_vars(self): + """Test ChainManifest.from_protobuf carries decoded values + keyed by their variable. + + Given: + A chain manifest carrying one value-bearing variable entry. + When: + ChainManifest.from_protobuf decodes it. + Then: + The resulting manifest should expose the decoded value in + ``vars``, keyed by the variable. + """ + # Arrange + var = ContextVar(_unique("manifest_carry")) + with scoped_context(): + var.set("v") + wire = wool.__chain__.get().to_manifest().to_protobuf() + + # Act + decoded = _decode_manifest(wire) + + # Assert + assert var in decoded.vars + assert decoded.vars[var] == "v" + + def test_from_protobuf_should_warn_when_duplicate_var_keys(self): + """Test ChainManifest.from_protobuf warns on duplicate var keys. + + Given: + A chain manifest whose ``vars`` list carries two entries + with the same ``(namespace, name)`` key — a malformed + sender that duplicated an entry rather than overwriting. + When: + ChainManifest.from_protobuf decodes it. 
+ Then: + A SerializationWarning is emitted naming the duplicate + key — carrying the duplicated var_key, the decode + direction, and no cause — and the first occurrence wins + in the decoded manifest (the second is dropped). The + decode otherwise succeeds — duplicate keys are + recoverable; the decode does not raise outside strict + mode. + """ + # Arrange + var = ContextVar(_unique("manifest_dup")) + wire = protocol.ChainManifest(id=uuid4().hex) + wire.vars.add( + namespace=var.namespace, + name=var.name, + value=wool.__serializer__.dumps("first"), + ) + wire.vars.add( + namespace=var.namespace, + name=var.name, + value=wool.__serializer__.dumps("second-ignored"), + ) + + # Act + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + decoded = _decode_manifest(wire) + + # Assert — duplicate-key warning was emitted naming the key, + # and the first occurrence won in the decoded manifest. + dup_warnings = [ + w for w in caught if issubclass(w.category, wool.SerializationWarning) + ] + assert dup_warnings, "expected a SerializationWarning on duplicate key" + assert any("Duplicate" in str(w.message) for w in dup_warnings) + duplicate = next( + w.message for w in dup_warnings if "Duplicate" in str(w.message) + ) + assert isinstance(duplicate, wool.SerializationWarning) + assert duplicate.var_key == (var.namespace, var.name) + assert duplicate.direction == "decode" + assert duplicate.cause is None + assert decoded.vars[var] == "first" + + @given( + values=st.lists( + st.text() | st.integers() | st.lists(st.integers()), + min_size=0, + max_size=5, + ), + reset_count=st.integers(min_value=0, max_value=3), + ) + @settings(max_examples=50) + def test_from_protobuf_should_round_trip_arbitrary_context( + self, values, reset_count + ): + """Test to_protobuf/from_protobuf round-trips data and reset signals. 
+ + Given: + An armed chain binding zero or more variables to arbitrary + serializable values, plus zero or more reset-and-not-re-set + signals. + When: + The chain is encoded to the wire and decoded back. + Then: + The decoded chain manifest should carry every variable binding, every + resets key, and the chain id. + """ + # Arrange + bound_vars = [ContextVar(_unique("rt_context")) for _ in values] + reset_targets = [ContextVar(_unique("rt_reset")) for _ in range(reset_count)] + resets = frozenset((v.namespace, v.name) for v in reset_targets) + chain_id = uuid4() + + with scoped_context(): + # Bind the value-bearing vars so to_protobuf reads their + # live backing-variable values. + for var, value in zip(bound_vars, values): + var.set(value) + live = wool.__chain__.get(None) + bound = live.vars if live is not None else frozenset() + context = Chain( + id=chain_id, + thread=threading.get_ident(), + vars=bound, + resets=resets, + ) + wire = context.to_manifest().to_protobuf() + + # Act + decoded = _decode_manifest(wire) + + # Assert + expected = dict(zip(bound_vars, values)) + assert {var: decoded.vars[var] for var in bound_vars} == expected + assert decoded.resets == resets + assert decoded.id == chain_id + + +class TestChainManifestToProtobuf: + def test_to_protobuf_should_return_empty_manifest_when_none(self): + """Test the wire manifest standing in for a None chain is empty. + + Given: + A None chain (an unarmed chain), so there is nothing to encode. + When: + A protocol.ChainManifest is default-constructed. + Then: + It should be an empty protocol.ChainManifest: no vars, empty id. + """ + # Act + wire = protocol.ChainManifest() + + # Assert + assert isinstance(wire, protocol.ChainManifest) + assert len(wire.vars) == 0 + assert wire.id == "" + + def test_to_protobuf_should_carry_chain_id(self): + """Test to_protobuf writes the chain id to the chain manifest. + + Given: + A chain with a known chain id. + When: + to_protobuf is called. + Then: + The chain manifest id should equal the chain id's hex form. 
+ """ + # Arrange + chain_id = uuid4() + context = Chain(id=chain_id, thread=threading.get_ident()) + + # Act + wire = context.to_manifest().to_protobuf() + + # Assert + assert wire.id == chain_id.hex + + def test_to_protobuf_should_emit_one_entry_per_variable(self): + """Test to_protobuf emits one wire entry per bound variable. + + Given: + An armed chain with two distinct variable bindings. + When: + to_protobuf is called on the active chain. + Then: + The chain manifest should carry one entry per variable, each + with a populated value field. + """ + # Arrange + var_a = ContextVar(_unique("encode_a")) + var_b = ContextVar(_unique("encode_b")) + + with scoped_context(): + var_a.set(1) + var_b.set(2) + + # Act + wire = wool.__chain__.get().to_manifest().to_protobuf() + + # Assert + keys = {(entry.namespace, entry.name) for entry in wire.vars} + assert keys == { + (var_a.namespace, var_a.name), + (var_b.namespace, var_b.name), + } + assert all(entry.HasField("value") for entry in wire.vars) + + def test_to_protobuf_should_emit_no_entries_when_empty_data(self): + """Test to_protobuf produces an empty chain manifest when data is empty. + + Given: + A chain with no variable bindings in its data map. + When: + to_protobuf is called. + Then: + The chain manifest should carry no entries. + """ + # Arrange + context = Chain(id=uuid4(), thread=threading.get_ident()) + + # Act + wire = context.to_manifest().to_protobuf() + + # Assert + assert len(wire.vars) == 0 + + def test_to_protobuf_should_emit_reset_var_entry(self): + """Test to_protobuf emits a no-value entry for a reset variable. + + Given: + A chain whose resets set names a variable that has no + current binding in data (reset to no value, not re-set). + When: + to_protobuf is called. + Then: + The chain manifest should carry an entry for that variable with + no value field set. 
+ """ + # Arrange + var = ContextVar(_unique("encode_reset_var")) + context = Chain( + id=uuid4(), + thread=threading.get_ident(), + resets=frozenset({(var.namespace, var.name)}), + ) + + # Act + wire = context.to_manifest().to_protobuf() + + # Assert + entry = next( + e for e in wire.vars if (e.namespace, e.name) == (var.namespace, var.name) + ) + assert not entry.HasField("value") + + def test_to_protobuf_should_warn_and_skip_unserializable_value(self): + """Test to_protobuf warns and skips a variable it cannot serialize. + + Given: + An armed chain carrying an unpicklable value alongside a + normal binding. + When: + to_protobuf is called. + Then: + It should emit a SerializationWarning carrying the failed + variable's key, the encode direction, and the underlying + cause, and emit only the serializable variable on the wire. + """ + # Arrange + good = ContextVar(_unique("encode_good")) + bad = ContextVar(_unique("encode_bad")) + + with scoped_context(): + good.set("ok") + bad.set(threading.Lock()) + + # Act + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + wire = wool.__chain__.get().to_manifest().to_protobuf() + + # Assert + keys = {(entry.namespace, entry.name) for entry in wire.vars} + assert keys == {(good.namespace, good.name)} + emitted = [ + w for w in caught if issubclass(w.category, wool.SerializationWarning) + ] + assert emitted + warning = emitted[0].message + assert isinstance(warning, wool.SerializationWarning) + assert warning.var_key == (bad.namespace, bad.name) + assert warning.direction == "encode" + assert warning.cause is not None + + def test_to_protobuf_should_aggregate_failures_when_strict_mode(self): + """Test to_protobuf raises ChainSerializationError under strict mode. + + Given: + An armed chain with two unserializable values and strict + warning filtering for SerializationWarning. + When: + to_protobuf is called. 
+ Then: + It should raise a ChainSerializationError aggregating both + per-variable failures on .warnings, each carrying its + variable's key, the encode direction, and the underlying + cause. + """ + # Arrange + bad_a = ContextVar(_unique("encode_strict_a")) + bad_b = ContextVar(_unique("encode_strict_b")) + + with scoped_context(): + bad_a.set(threading.Lock()) + bad_b.set(threading.Lock()) + context = wool.__chain__.get(None) + assert context is not None + + # Act & assert + with warnings.catch_warnings(): + warnings.filterwarnings("error", category=wool.SerializationWarning) + with pytest.raises(wool.ChainSerializationError) as exc_info: + context.to_manifest().to_protobuf() + assert len(exc_info.value.warnings) == 2 + assert all( + isinstance(e, wool.SerializationWarning) for e in exc_info.value.warnings + ) + assert {w.var_key for w in exc_info.value.warnings} == { + (bad_a.namespace, bad_a.name), + (bad_b.namespace, bad_b.name), + } + assert all(w.direction == "encode" for w in exc_info.value.warnings) + assert all(w.cause is not None for w in exc_info.value.warnings) + + def test_to_protobuf_should_skip_variable_when_backing_undefined(self): + """Test to_protobuf skips a data entry whose backing is Undefined. + + Given: + A Chain whose data index names a variable, but the + variable's backing contextvars.ContextVar resolves to the + Undefined sentinel in the active context — a data-membership + desync. + When: + to_protobuf is called inside that context. + Then: + The chain manifest should carry no entry for that variable — + encode skips it defensively rather than ship a phantom value. + """ + # Arrange + var = ContextVar(_unique("encode_desync")) + + with scoped_context(): + # Drive the backing to the Undefined sentinel through the + # public set/reset cycle — a first-set reset rewinds the + # backing to its unset default — then index the variable in + # data, producing the desynced state. 
+ token = var.set("v") + var.reset(token) + context = Chain( + id=uuid4(), + thread=threading.get_ident(), + vars=frozenset({var}), + ) + + # Act + wire = context.to_manifest().to_protobuf() + + # Assert + keys = {(entry.namespace, entry.name) for entry in wire.vars} + assert (var.namespace, var.name) not in keys + + +class TestChainMount: + def test_mount_should_restamp_owner_and_owning_task(self): + """Test Chain.mount re-stamps the context with the calling owner. + + Given: + A ChainManifest decoded from a wire chain manifest — owner + stamping happens at mount time, not decode time. + When: + mount applies it onto the current contextvars.Context. + Then: + The installed chain should be stamped with the calling + thread as owner and a None _owning_task should be replaced by + whatever the calling task carries (None here, off-loop). + """ + # Arrange + var = ContextVar(_unique("mount_owner")) + with scoped_context(): + var.set("v") + decoded = _decode_manifest(wool.__chain__.get().to_manifest().to_protobuf()) + + # Act & assert + with scoped_context(): + _mount_manifest(decoded) + mounted = wool.__chain__.get(None) + assert mounted is not None + assert mounted.thread == threading.get_ident() + + def test_mount_should_apply_manifest_values_to_backing_variables(self): + """Test Chain.from_manifest writes decoded values into the + backing vars. + + Given: + A ChainManifest decoded from a wire chain manifest that bound + a variable to a value. + When: + mount applies it in a fresh contextvars.Context. + Then: + get() on the variable should return the decoded value — + mount drains the manifest into the backing variables. 
+ """ + # Arrange + var = ContextVar(_unique("mount_value")) + with scoped_context(): + var.set("mounted-value") + decoded = _decode_manifest(wool.__chain__.get().to_manifest().to_protobuf()) + + # Act & assert + with scoped_context(): + _mount_manifest(decoded) + assert var.get() == "mounted-value" + + def test_mount_should_apply_manifest_via_unified_install_pipeline(self): + """Test the unified install pipeline transfers values to backings. + + Given: + A manifest decoded from a wire chain manifest bound to a value. + When: + The receive-time pipeline routes via + `Chain.from_manifest` from a fresh `contextvars.Context`. + Then: + ``var.get()`` returns the decoded value — the install + pipeline drained the manifest into the backing variable + and installed a fresh Chain with the matching ``data`` + index. + """ + # Arrange + var = ContextVar(_unique("mount_install")) + with scoped_context(): + var.set("v") + decoded = _decode_manifest(wool.__chain__.get().to_manifest().to_protobuf()) + + # Act + with scoped_context(): + _mount_manifest(decoded) + + # Assert: the installed live chain resolves the var + # through its backing. + assert var.get() == "v" + + def test_mount_should_fold_data_into_active_context_when_armed(self): + """Test mount folds incoming data into the active chain. + + Given: + An armed chain carrying one variable and a decoded chain + manifest carrying a different variable, both on the same chain + id (as production receive sites always are — the worker adopts + the caller's chain on first arm). + When: + mount is called with the decoded chain manifest. + Then: + The active chain should index both variables, both values + should be observable, and the chain id is unchanged. 
+ """ + # Arrange + existing = ContextVar(_unique("merge_existing")) + incoming_var = ContextVar(_unique("merge_incoming")) + + with scoped_context(): + incoming_var.set("b") + decoded = _decode_manifest(wool.__chain__.get().to_manifest().to_protobuf()) + + with scoped_context(): + existing.set("a") + _adopt_chain(decoded.id) + armed = wool.__chain__.get(None) + assert armed is not None + original_chain = armed.id + + # Act + _mount_manifest(decoded) + merged = wool.__chain__.get(None) + + # Assert + assert merged is not None + assert existing in merged.vars + assert incoming_var in merged.vars + assert existing.get() == "a" + assert incoming_var.get() == "b" + assert merged.id == original_chain + + def test_mount_should_let_incoming_win_when_overlap(self): + """Test mount lets the incoming chain manifest win overlapping keys. + + Given: + An armed chain and a decoded chain manifest that both bind the + same variable to different values. + When: + mount is called. + Then: + The merged chain should carry the incoming value. + """ + # Arrange + var = ContextVar(_unique("merge_overlap")) + + with scoped_context(): + var.set("remote") + decoded = _decode_manifest(wool.__chain__.get().to_manifest().to_protobuf()) + + # Act + with scoped_context(): + var.set("local") + _adopt_chain(decoded.id) + _mount_manifest(decoded) + + # Assert + assert var.get() == "remote" + + def test_mount_should_arm_unarmed_context(self): + """Test mount arms an unarmed chain. + + Given: + An unarmed chain and a decoded chain manifest with state. + When: + mount is called. + Then: + The chain should become armed and observe the merged value. 
+ """ + # Arrange + var = ContextVar(_unique("merge_arm")) + + def _build_decoded(): + var.set("armed") + return _decode_manifest(wool.__chain__.get().to_manifest().to_protobuf()) + + decoded = contextvars.Context().run(_build_decoded) + + # Act + assert wool.__chain__.get(None) is None + _mount_manifest(decoded) + + # Assert + assert wool.__chain__.get(None) is not None + assert var.get() == "armed" + + def test_mount_should_propagate_reset_signal(self): + """Test mount removes a variable that the incoming chain manifest reset. + + Given: + An armed chain binding a variable, and a decoded chain manifest + whose resets set names that variable. + When: + mount is called. + Then: + The variable should be removed from the active chain. + """ + # Arrange + var = ContextVar(_unique("merge_reset")) + sender = Chain( + id=uuid4(), + thread=threading.get_ident(), + resets=frozenset({(var.namespace, var.name)}), + ) + decoded = _decode_manifest(sender.to_manifest().to_protobuf()) + + # Act + with scoped_context(): + var.set("present") + _adopt_chain(decoded.id) + _mount_manifest(decoded) + + # Assert + with pytest.raises(LookupError): + var.get() + + def test_mount_should_preserve_re_set_after_reset(self): + """Test mount keeps a variable the incoming chain manifest re-set. + + Given: + A receiver chain where a variable was reset to no value, and + a decoded chain manifest that re-binds that same variable in data. + When: + mount is called. + Then: + The variable should remain bound to the incoming value and be + absent from the merged resets. 
+ """ + # Arrange + var = ContextVar(_unique("merge_reset_rebind")) + + with scoped_context(): + var.set("new") + decoded = _decode_manifest(wool.__chain__.get().to_manifest().to_protobuf()) + + with scoped_context(): + token = var.set("old") + var.reset(token) + _adopt_chain(decoded.id) + + # Act + _mount_manifest(decoded) + + # Assert + assert var.get() == "new" + merged = wool.__chain__.get(None) + assert merged is not None + assert (var.namespace, var.name) not in merged.resets + + def test_mount_should_preserve_untouched_receiver_resets(self): + """Test mount keeps receiver resets the incoming chain manifest ignored. + + Given: + A receiver chain with one variable reset to no value, and a + decoded chain manifest that touches a different variable only. + When: + mount is called. + Then: + The receiver's reset for the untouched variable should survive + the merge — the merge is one-way, the incoming chain manifest wins + only for the keys it touched. + """ + # Arrange + receiver_reset_var = ContextVar(_unique("merge_keep_reset")) + incoming_var = ContextVar(_unique("merge_incoming_only")) + + with scoped_context(): + incoming_var.set("incoming") + decoded = _decode_manifest(wool.__chain__.get().to_manifest().to_protobuf()) + + # Act & assert + with scoped_context(): + token = receiver_reset_var.set("present") + receiver_reset_var.reset(token) + _adopt_chain(decoded.id) + _mount_manifest(decoded) + merged = wool.__chain__.get(None) + assert merged is not None + assert ( + receiver_reset_var.namespace, + receiver_reset_var.name, + ) in merged.resets + + @pytest.mark.asyncio + async def test_mount_should_self_install_task_factory_when_arming(self): + """Test a mount that arms an unarmed chain self-installs the task factory. + + Given: + An async unarmed chain with no Wool task factory installed, + and a decoded chain manifest with state. + When: + mount arms the chain, then a child task is created. 
+ Then: + The child task should fork onto a chain id distinct from the + merged chain — the arming merge self-installed the task + factory so copy-on-fork engages. + """ + # Arrange + var = ContextVar(_unique("merge_self_install")) + loop = asyncio.get_running_loop() + loop.set_task_factory(None) + with scoped_context(): + var.set("armed") + decoded = _decode_manifest(wool.__chain__.get().to_manifest().to_protobuf()) + + # Act + with scoped_context(): + _mount_manifest(decoded) + merged = wool.__chain__.get(None) + assert merged is not None + + async def child() -> uuid.UUID: + context = wool.__chain__.get(None) + assert context is not None + return context.id + + try: + child_chain = await asyncio.create_task(child()) + finally: + loop.set_task_factory(None) + + # Assert + assert child_chain != merged.id + + @pytest.mark.asyncio + async def test_mount_should_surface_displacement_when_unowned(self): + """Test an unowned mount surfaces a displaced task factory. + + Given: + An armed chain on a loop whose Wool task factory was + displaced by a third-party factory installed after it. + When: + The chain is mounted with owned=False (the worker-side + per-step driver path). + Then: + It should raise TaskFactoryDisplaced — the unowned mount + ensures the factory unconditionally, so displacement + surfaces at the mount rather than only on the next + wool.ContextVar set. 
+ """ + # Arrange + var = ContextVar(_unique("unowned_displaced")) + loop = asyncio.get_running_loop() + loop.set_task_factory(None) + with scoped_context(): + var.set("armed") + chain = wool.__chain__.get() + + def third_party_factory(loop, coro, **kwargs): + return asyncio.Task(coro, loop=loop, **kwargs) + + loop.set_task_factory(third_party_factory) + + # Act & assert + try: + with pytest.raises(TaskFactoryDisplaced, match="displaced"): + chain.mount(owned=False) + finally: + loop.set_task_factory(None) + + +class TestChainManifestVars: + def test_vars_should_expose_decoded_values_when_pre_mount_manifest(self): + """Test ChainManifest.vars surfaces decoded values. + + Given: + A ChainManifest decoded from a wire chain manifest that carried + a value for one variable. + When: + ``manifest.vars`` is read. + Then: + The mapping should expose the variable's decoded value + keyed by the wool.ContextVar singleton — the same instance + a caller obtains by constructing the variable with the + matching (namespace, name). + """ + # Arrange + var = ContextVar(_unique("vars_pre_mount")) + with scoped_context(): + var.set("shipped") + decoded = _decode_manifest(wool.__chain__.get().to_manifest().to_protobuf()) + + # Act + snapshot = decoded.vars + + # Assert + assert snapshot[var] == "shipped" + assert dict(snapshot) == {var: "shipped"} + + def test_vars_should_reflect_manifest_state(self): + """Test ChainManifest.vars exposes the decoded values. + + Given: + A ChainManifest decoded from a wire chain manifest — its + vars is populated. + When: + ``manifest.vars`` is read before and after mount. + Then: + Pre-mount, the variable's decoded value is exposed via the + mapping. Mount does not mutate the manifest in place — the + mapping remains populated after mount, by design (the + manifest is a record of the wire decode, not the drain + transient). The drained values live on the backing + variables, reachable via wool.ContextVar.get inside the + mounted chain. 
+ """ + # Arrange + var = ContextVar(_unique("vars_drain")) + with scoped_context(): + var.set("shipped") + decoded = _decode_manifest(wool.__chain__.get().to_manifest().to_protobuf()) + assert decoded.vars[var] == "shipped" + + # Act + with scoped_context(): + _mount_manifest(decoded) + + # Assert + assert decoded.vars == {var: "shipped"} + assert var.get() == "shipped" + + +def test_stub_backing_variable_value_should_survive_promotion_to_real_variable(): + """Test a wire value decoded into a stub survives stub-to-real promotion. + + Given: + A wire chain manifest carrying a value — chosen to differ from the + constructor default — for an undeclared variable key, decoded + and mounted so the value lands in the stub's backing variable. + When: + The real ContextVar is declared for that key, with a different + constructor default, and get() is called. + Then: + get() should return the wire value, not the constructor + default — the backing variable survives the stub-to-real + promotion. + """ + # Arrange + namespace = uuid.uuid4().hex + name = "stub_promote_value" + wire = protocol.ChainManifest(id=uuid4().hex) + wire.vars.add( + namespace=namespace, + name=name, + value=wool.__serializer__.dumps("wire-value"), + ) + decoded = _decode_manifest(wire) + + # Act & assert + with scoped_context(): + _mount_manifest(decoded) + # Declare the real variable with a default that differs from + # the wire value, promoting the stub in place. + real = ContextVar(name, namespace=namespace, default="ctor-default") + assert real.get() == "wire-value" + + +class TestResolveStub: + def test_resolve_stub_should_create_data_only_placeholder_when_key_unregistered( + self, + ): + """Test resolving an unregistered key yields a bare data placeholder. + + Given: + A (namespace, name) key no variable has been declared or + resolved for. + When: + resolve_stub is called with that key. 
+ Then: + It should return a placeholder flagged as a stub and carrying + no constructor default — exposing the identity surface + (namespace, name) but none of the live get/set/reset + behavior of a declared variable. + """ + # Arrange + key = ("pkg", _unique("bare")) + + # Act + manifest = resolve_stub(key) + + # Assert + assert type(manifest) is ContextVarManifest + assert not isinstance(manifest, ContextVar) + assert manifest._stub is True + assert manifest._default is Undefined + assert (manifest.namespace, manifest.name) == key + assert not hasattr(manifest, "set") + assert not hasattr(manifest, "get") + + def test_resolve_stub_should_return_same_instance_when_key_already_resolved(self): + """Test repeated resolution of one key converges on a single placeholder. + + Given: + A key already resolved once into a registered placeholder. + When: + resolve_stub is called again with the same key. + Then: + It should return the very same instance — the registry holds + one placeholder per key. + """ + # Arrange + key = ("pkg", _unique("singleton")) + first = resolve_stub(key) + + # Act + second = resolve_stub(key) + + # Assert + assert first is second + + def test_resolve_stub_should_seed_default_into_a_default_less_placeholder(self): + """Test a later resolution folds a default into a default-less placeholder. + + Given: + A key first resolved with no default — a default-less + placeholder. + When: + resolve_stub is called again for the key with a default. + Then: + The same placeholder should adopt the default rather than + discard it — whichever ingress carries the default wins. 
+ """ + # Arrange + key = ("pkg", _unique("fold")) + manifest = resolve_stub(key) + assert manifest._default is Undefined + + # Act + again = resolve_stub(key, default="seeded") + + # Assert + assert again is manifest + assert manifest._default == "seeded" diff --git a/wool/tests/runtime/context/test_registry.py b/wool/tests/runtime/context/test_registry.py deleted file mode 100644 index 492e47c6..00000000 --- a/wool/tests/runtime/context/test_registry.py +++ /dev/null @@ -1,71 +0,0 @@ -import asyncio - -import pytest - -from wool.runtime.context import Context -from wool.runtime.context import ContextVar -from wool.runtime.context import current_context - - -def test_current_context_with_set_vars(): - """Test current_context() returns the live Context for the scope. - - Given: - A ContextVar with an explicit value set - When: - current_context() is called twice from the same scope - Then: - The returned Context contains the var with its value and - both calls yield the same Context instance (idempotent - within a scope) - """ - # Arrange - var = ContextVar("cur_ctx", default=0) - var.set(1) - - # Act - ctx = current_context() - - # Assert - assert ctx[var] == 1 - assert current_context() is ctx - - -@pytest.mark.asyncio -async def test_context_registry_get_when_scope_has_no_bound_ctx(): - """Test :meth:`_ContextRegistry.get` returns ``None`` when the - current scope has no bound :class:`Context`. - - Given: - A scope (an asyncio task) running with the asyncio default - task factory — no wool task factory is installed, so no - :class:`Context` is auto-bound to the task identity. - When: - :meth:`context_registry.get` is called twice from inside - that scope. - Then: - Both reads observe ``None`` — the non-creating accessor - does not materialize a Context, while - :func:`current_context` does (covered separately by - :func:`test_current_context_with_set_vars`). 
- """ - # Arrange - from wool.runtime.context.registry import context_registry - - loop = asyncio.get_running_loop() - # Ensure no wool factory is installed on this loop so the child - # task does not auto-bind a Context to its identity. - loop.set_task_factory(None) - - observed: list[Context | None] = [] - - async def body(): - observed.append(context_registry.get()) - observed.append(context_registry.get()) - - # Act — run under the asyncio default task factory so no wool - # binding leaks in. - await loop.create_task(body()) - - # Assert - assert observed == [None, None] diff --git a/wool/tests/runtime/context/test_runtime.py b/wool/tests/runtime/context/test_runtime.py new file mode 100644 index 00000000..b2fe9f35 --- /dev/null +++ b/wool/tests/runtime/context/test_runtime.py @@ -0,0 +1,205 @@ +"""Unit tests for RuntimeContext — block-scoped runtime option overrides.""" + +import pytest + +from wool import protocol +from wool.runtime.context.runtime import RuntimeContext +from wool.runtime.context.runtime import dispatch_timeout + + +class TestRuntimeContext: + def test___enter___should_install_dispatch_timeout(self): + """Test RuntimeContext.__enter__ installs an explicit dispatch timeout. + + Given: + A RuntimeContext constructed with an explicit + dispatch_timeout. + When: + It is entered as a context manager. + Then: + The ambient dispatch_timeout variable should report the + supplied value for the duration of the block. + """ + # Arrange + context = RuntimeContext(dispatch_timeout=7.5) + + # Act & assert + with context: + assert dispatch_timeout.get() == 7.5 + + def test___exit___should_restore_prior_dispatch_timeout(self): + """Test RuntimeContext.__exit__ restores the prior dispatch timeout. + + Given: + A RuntimeContext entered over an explicit dispatch_timeout. + When: + The context-manager block exits. + Then: + The ambient dispatch_timeout variable should revert to the + value it held before the block. 
+ """ + # Arrange + token = dispatch_timeout.set(2.0) + + # Act + with RuntimeContext(dispatch_timeout=9.0): + pass + + # Assert + assert dispatch_timeout.get() == 2.0 + dispatch_timeout.reset(token) + + def test___enter___should_leave_dispatch_timeout_when_no_override(self): + """Test a bare RuntimeContext does not touch the dispatch timeout. + + Given: + A bare RuntimeContext() constructed without an explicit + dispatch_timeout — the default sentinel. + When: + It is entered and exited as a context manager over a scope + with a pre-set dispatch_timeout. + Then: + The ambient dispatch_timeout should be unchanged throughout + — the no-override path skips setting the variable. + """ + # Arrange + token = dispatch_timeout.set(4.0) + + # Act & assert + with RuntimeContext(): + assert dispatch_timeout.get() == 4.0 + assert dispatch_timeout.get() == 4.0 + dispatch_timeout.reset(token) + + def test_get_current_should_capture_live_dispatch_timeout(self): + """Test RuntimeContext.get_current snapshots the live dispatch timeout. + + Given: + A scope with the ambient dispatch_timeout set to a value. + When: + RuntimeContext.get_current is called. + Then: + The captured RuntimeContext should re-install that value + when entered. + """ + # Arrange + token = dispatch_timeout.set(3.25) + + # Act + captured = RuntimeContext.get_current() + + # Assert + dispatch_timeout.reset(token) + with captured: + assert dispatch_timeout.get() == 3.25 + + def test_from_protobuf_should_reconstruct_dispatch_timeout(self): + """Test RuntimeContext.from_protobuf reconstructs an explicit timeout. + + Given: + A protocol.RuntimeContext message carrying a + dispatch_timeout field. + When: + RuntimeContext.from_protobuf is called on it. + Then: + Entering the reconstructed RuntimeContext should install + the message's timeout value. 
+ """ + # Arrange + message = protocol.RuntimeContext(dispatch_timeout=6.0) + + # Act + context = RuntimeContext.from_protobuf(message) + + # Assert + with context: + assert dispatch_timeout.get() == 6.0 + + def test_to_protobuf_should_omit_field_when_explicit_none(self): + """Test RuntimeContext.to_protobuf omits the field for an explicit None. + + Given: + A RuntimeContext constructed with an explicit + dispatch_timeout of None. + When: + to_protobuf is called. + Then: + The message should not carry the dispatch_timeout field — + an explicit None lets the receiver inherit its own scope's + default. + """ + # Arrange + context = RuntimeContext(dispatch_timeout=None) + + # Act + message = context.to_protobuf() + + # Assert + assert message.HasField("dispatch_timeout") is False + + def test_to_protobuf_should_capture_live_value_when_default_sentinel(self): + """Test RuntimeContext.to_protobuf captures the live timeout when unset. + + Given: + A bare RuntimeContext() — the default sentinel — encoded + inside a scope whose ambient dispatch_timeout is set. + When: + to_protobuf is called. + Then: + The message should carry the scope's live dispatch_timeout + value — a bare RuntimeContext propagates the encoder's + effective timeout to the receiver. + """ + # Arrange + token = dispatch_timeout.set(11.5) + + # Act + message = RuntimeContext().to_protobuf() + + # Assert + dispatch_timeout.reset(token) + assert message.HasField("dispatch_timeout") is True + assert message.dispatch_timeout == 11.5 + + def test_to_protobuf_should_round_trip_explicit_timeout(self): + """Test a RuntimeContext round-trips an explicit timeout through protobuf. + + Given: + A RuntimeContext with an explicit dispatch_timeout. + When: + It is encoded with to_protobuf and decoded with + from_protobuf. + Then: + Entering the decoded RuntimeContext should install the + original timeout value. 
+ """ + # Arrange + original = RuntimeContext(dispatch_timeout=8.0) + + # Act + restored = RuntimeContext.from_protobuf(original.to_protobuf()) + + # Assert + with restored: + assert dispatch_timeout.get() == 8.0 + + def test_double_enter_should_raise(self): + """Test re-entering an already-active RuntimeContext raises. + + Given: + A RuntimeContext currently inside an active ``with`` + block. + When: + The same instance is entered as a context manager a + second time. + Then: + It should raise RuntimeError — RuntimeContext instances + are block-scoped and single-use. + """ + # Arrange + rc = RuntimeContext(dispatch_timeout=2.0) + + # Act & assert + with rc: + with pytest.raises(RuntimeError, match="already active"): + rc.__enter__() diff --git a/wool/tests/runtime/context/test_token.py b/wool/tests/runtime/context/test_token.py index b52abced..0eee1a4f 100644 --- a/wool/tests/runtime/context/test_token.py +++ b/wool/tests/runtime/context/test_token.py @@ -1,437 +1,150 @@ -import gc import uuid -import weakref +from contextvars import Token import cloudpickle import pytest import wool from tests.helpers import scoped_context -from wool.runtime.context import Context -from wool.runtime.context import ContextVar -from wool.runtime.context import Token -from wool.runtime.context import attached +from wool.runtime.context.var import ContextVar +# ``dumps`` is the wool serializer — ContextVar serialises only through +# it. ``loads`` is plain ``cloudpickle.loads``: wool-encoded bytes are +# cloudpickle-loadable, so a separate wool deserializer is not needed +# here. dumps = wool.__serializer__.dumps loads = cloudpickle.loads -class TestToken: - def test_pickle_roundtrip_with_var_reference(self): - """Test Token pickle roundtrip carries its owning ContextVar by key. 
- - Given: - A ContextVar and a Token produced by set() - When: - The Token is pickled and unpickled - Then: - The restored token should reference the same ContextVar instance - """ - # Arrange - var = ContextVar("tokened") - token = var.set("x") - - # Act - restored = loads(dumps(token)) - - # Assert - assert restored.var is var - assert restored.old_value is Token.MISSING - - def test_pickle_roundtrip_in_same_process(self): - """Test same-process pickle of a Token returns the same instance. - - Given: - A live Token minted via ContextVar.set — strongly - referenced so its entry in the process-wide token - registry stays alive - When: - The Token is pickled and unpickled in the same process - Then: - The restored token should be the same Python object as - the original — the registry lookup in Token._reconstitute - resolves the id back to the live instance so mutations - to ``_used`` stay visible across all references - """ - # Arrange - var = ContextVar("pickle_identity") - token = var.set("x") - - # Act - restored = loads(dumps(token)) - - # Assert - assert restored is token - - def test_pickle_roundtrip_with_used_flag(self): - """Test Token.__wool_reduce__ serializes the _used flag. 
- - Given: - A Token minted and then consumed via ContextVar.reset, - pickled after consumption - When: - The pickled bytes are loaded in a context where the - original Token is no longer reachable (registry miss - forces _reconstitute to build a fresh stub) - Then: - The restored stub should have used=True — the flag - rides the pickle tuple so a cross-process copy cannot - silently attempt reset - """ - # Arrange - var = ContextVar("pickle_used_flag") - token = var.set("x") - var.reset(token) - pickled = dumps(token) - original_ref = weakref.ref(token) - del token - gc.collect() - assert original_ref() is None, ( - "Original Token must be collected before the load for the " - "registry-miss path to fire" - ) - - # Act - restored = loads(pickled) +def _unique(stem: str) -> str: + """Return a process-unique variable name to avoid registry collisions.""" + return f"{stem}_{uuid.uuid4().hex}" - # Assert - assert restored.used is True - def test_pickle_roundtrip_in_active_receiver_context(self): - """Test pickling a Token does not mutate the receiver's - :class:`Context` via an embedded var value. +class TestToken: + def test___repr___should_reference_backing_contextvar(self): + """Test Token repr references the backing ContextVar. Given: - A wool.ContextVar bound to value ``"A"`` in one Context, - with a Token pickled while that binding is active. The - receiver later enters a different Context where the same - var has been bound to ``"B"``. + A Token produced by set() on a wool.ContextVar. When: - The pickled Token bytes are loaded under the receiver's - Context + repr() is called on it. 
Then: - The receiver's Context should still report ``"B"`` for - the var — a Token is a reset receipt, not a value- - bearing wire payload, so its pickle round-trip must not - transitively propagate the originating Context's binding - through the owning ContextVar's reduce path + It should include the backing stdlib ContextVar's + qualified name — wool encodes ``(namespace, name)`` into + the backing's name so the repr names the originating + wool variable in a stdlib-shaped form. """ # Arrange - var = ContextVar(f"token_no_value_leak_{uuid.uuid4().hex}") - token = var.set("A") - pickled = dumps(token) - var.reset(token) - observed: list[str] = [] + var = ContextVar(_unique("repr_token_var")) # Act with scoped_context(): - var.set("B") - loads(pickled) - observed.append(var.get()) - - # Assert - assert observed == ["B"] - - def test_set_reset_loop_token_lifecycle(self): - """Test a tight set/reset loop does not accumulate per-iteration state. - - Given: - A ContextVar in a fresh Context scope and a weakref to a - sampled iteration's Token - When: - set(value) followed by reset(token) runs many times - with no lingering strong reference to any iteration's - Token, followed by gc.collect() - Then: - The sampled Token's weakref should resolve to None — - each iteration's Token becomes unreachable once the - loop's local binding is overwritten, so no lingering - per-iteration state keeps it alive - """ - # Arrange - var = ContextVar("loop_no_leak") - sampled_ref: weakref.ref[Token] | None = None - - # Act - for i in range(200): token = var.set("x") - if i == 100: - sampled_ref = weakref.ref(token) - var.reset(token) - del token - gc.collect() - - # Assert - assert sampled_ref is not None - assert sampled_ref() is None - - def test_repr_includes_var_key(self): - """Test Token repr includes the owning var's key. 
- - Given: - A Token produced by set() on a ContextVar - When: - repr() is called on it - Then: - The repr should include the var's full key - """ - # Arrange - var = ContextVar("repr_token_var") - token = var.set("x") - - # Act - text = repr(token) + text = repr(token) # Assert - assert repr((var.namespace, var.name)) in text + assert var.namespace in text + assert var.name in text - def test_missing_pickle_roundtrip_singleton_identity(self): - """Test pickling and unpickling Token.MISSING returns the same - singleton instance. + def test_old_value_should_report_missing_when_var_previously_unset(self): + """Test Token.old_value is MISSING when the var was previously unset. Given: - Token.MISSING — a singleton-by-construction sentinel - whose ``__new__`` caches the lone instance and whose - ``__reduce__`` rebuilds via the same constructor + A previously-unset ContextVar. When: - Token.MISSING is pickled with cloudpickle and the bytes - are loaded back + set() mints a Token and its old_value is read. Then: - The reloaded value is identical to the original — - ``loaded is Token.MISSING`` — so callers comparing - ``token.old_value is Token.MISSING`` after a wire round- - trip still hit the identity check. + It should be Token.MISSING. """ # Arrange - original = Token.MISSING + var = ContextVar(_unique("old_value_missing")) # Act - pickled = dumps(original) - loaded = loads(pickled) + with scoped_context(): + token = var.set("x") # Assert - assert loaded is original - assert loaded is Token.MISSING - - def test_used_is_false_before_reset_and_true_after(self): - """Test Token.used flips from False to True when the owning var is reset. 
- - Given: - A ContextVar and a Token produced by var.set() - When: - var.reset(token) is called - Then: - Token.used should be False before the reset and True - after — single-process sanity for the lifecycle flag - """ - # Arrange - var = ContextVar("used_flag", default="d") - token = var.set("x") - - # Act & assert - assert token.used is False - var.reset(token) - assert token.used is True + assert token.old_value is Token.MISSING - def test_out_of_order_pickle_loads_sync_used_state(self): - """Test loading an older Token snapshot before a newer one - leaves the registered instance synced to the most-progressed - ``_used`` value witnessed across the round-trips. + def test_old_value_should_report_prior_value_when_nested_set(self): + """Test Token.old_value reports the prior value for a nested set. Given: - A user pickles a Token before reset and again after - reset, drops their reference to the original so the - weak registry entry is collected, and then loads the - two pickled snapshots in chronological order - (``data_pre`` first, ``data_post`` second) + A ContextVar already set to a value. When: - The second load lands a registry hit on the stub - registered by the first load and observes the wire - payload's ``used=True`` + A second set mints a Token whose old_value is read. Then: - Both reloaded references report ``used=True`` — the - registry instance is monotonically advanced via - :meth:`Token._sync_state` so a subsequent - :meth:`ContextVar.reset` sees the consumed state and - cannot silently double-reset against an older snapshot + It should equal the first value. """ # Arrange - var = ContextVar(f"sync_used_{uuid.uuid4().hex}", default="d") - token = var.set("x") - data_pre = dumps(token) - var.reset(token) - data_post = dumps(token) - - # Drop the original strong reference so the WeakValueDictionary - # entry can be collected; force a GC to ensure the entry is - # gone before the first load fires. 
- del token - gc.collect() + var = ContextVar(_unique("old_value_prior")) # Act - loaded_pre = loads(data_pre) - loaded_post = loads(data_post) + with scoped_context(): + var.set("first") + token = var.set("second") # Assert - assert loaded_post is loaded_pre, ( - "Same-id pickle round-trips should converge on a single registry instance" - ) - assert loaded_pre.used is True, ( - "Older snapshot followed by newer snapshot must leave the " - "registry instance reflecting the consumed state, not the " - "pre-reset snapshot" - ) + assert token.old_value == "first" - def test_vanilla_pickle_copy_and_deepcopy_are_rejected(self): - """Test wool.Token rejects pickle, cloudpickle, copy.copy, - and copy.deepcopy. + def test_token_should_reject_pickle_cloudpickle_and_copy(self): + """Test wool.Token rejects pickle, cloudpickle, copy.copy, and deepcopy. Given: - A live wool.Token + A live wool.Token (stdlib contextvars.Token under the alias). When: pickle.dumps, cloudpickle.dumps, copy.copy, and - copy.deepcopy are each invoked on it + copy.deepcopy are each invoked on it. Then: - All four raise TypeError. Tokens are bound to the - process-wide registry and a live wool.Context, neither of - which is reconstructible outside Wool's dispatch path — - wool.__serializer__ remains the only valid serialization - channel. + All four raise TypeError — stdlib :class:`contextvars.Token` + is not picklable, and a chain-bound Token has no meaningful + copy semantics. Cross-process Token transport is deferred; + see issue #231. 
""" # Arrange import copy as _copy import pickle - match = "wool.Token cannot be pickled" - var = ContextVar("token_pickle_rejection") - token = var.set("x") + var = ContextVar(_unique("token_reject_channels")) # Act & assert - with pytest.raises(TypeError, match=match): - pickle.dumps(token) - with pytest.raises(TypeError, match=match): - cloudpickle.dumps(token) - with pytest.raises(TypeError, match=match): - _copy.copy(token) - with pytest.raises(TypeError, match=match): - _copy.deepcopy(token) - - def test_wool_serializer_encodes_token_that_vanilla_cloudpickle_rejects(self): - """Test wool.__serializer__.dumps encodes a wool.Token where - bare cloudpickle.dumps raises. - - Given: - A live wool.Token whose __reduce_ex__ guard rejects - vanilla pickle / cloudpickle. - When: - The same Token is fed to wool.__serializer__.dumps and to - cloudpickle.dumps. - Then: - wool.__serializer__.dumps should produce bytes that - round-trip through cloudpickle.loads back to the same - Token, while cloudpickle.dumps raises TypeError — - documenting the worker dispatch handler's contract that - payload serialization MUST go through wool.__serializer__ - (whose internal _WoolPickler honors __wool_reduce__) and - NOT bare cloudpickle (which trips the __reduce_ex__ guard). 
- """ - # Arrange - var = ContextVar("wool_serializer_token_round_trip") - token = var.set("x") - - # Act - encoded = wool.__serializer__.dumps(token) - restored = cloudpickle.loads(encoded) - - # Assert - assert restored is token, "Same-process round-trip preserves identity" - with pytest.raises(TypeError, match="wool.Token cannot be pickled"): - cloudpickle.dumps(token) + with scoped_context(): + token = var.set("x") + for channel in ( + pickle.dumps, + cloudpickle.dumps, + _copy.copy, + _copy.deepcopy, + ): + with pytest.raises(TypeError, match="cannot pickle"): + channel(token) -def test_token_reconstitute_with_uuid_in_active_external_set(): - """Test cross-process pickle reload of a Token whose UUID arrived - earlier on the wire as a consumed-token id converges to a - consumed Token instance. +def test_token_should_pin_context_var_in_weak_registry(): + """Test a live Token keeps its backing variable reachable. Given: - A :class:`Token` minted under a sender's :class:`Context`, - cloudpickled, the original released so the process-wide - registry no longer carries it, and a wire - :class:`protocol.Context` whose ``consumed_tokens`` lists - that token's id — reconstituted to a Context that is then - attached as the active scope. + A Token minted by ContextVar.set, with no other strong + reference to the wool variable. When: - The pickle is loaded under that active Context. + A garbage collection runs while the token is held. Then: - ``restored.used`` is True — the reconstitute path observes - the matching id in the active Context's external set and - adopts the consumed state, so a subsequent - :meth:`ContextVar.reset` raises rather than silently - succeeding. + The token's ``var`` attribute — the backing stdlib + :class:`contextvars.ContextVar` whose qualified name encodes + the wool ``(namespace, name)`` identity — should still + resolve. 
""" - # Arrange - from wool import protocol - - var = ContextVar(f"reconstitute_promote_{uuid.uuid4().hex}") - token = var.set("x") - token_id = token.id - pickled = dumps(token) - del token - gc.collect() + import gc - pb = protocol.Context(id=uuid.uuid4().hex) - pb.vars.add( - namespace=var.namespace, - name=var.name, - consumed_tokens=[token_id.hex], - ) - secondary = Context.from_protobuf(pb) - - # Act - with attached(secondary, guarded=False): - restored = loads(pickled) - - # Assert - assert restored.used is True - - -def test_token_does_not_pin_originating_context(): - """Test a live Token does not keep its originating wool.Context - alive once the user drops their own references to the Context. - - Given: - A wool.Context held via a weakref and a Token minted under - that Context (via ContextVar.set inside Context.run) retained - in a strong-reffed local — modeling a user who stashes a - Token on a long-lived service object. - When: - The original strong reference to the Context is dropped and - ``gc.collect()`` runs. - Then: - The weakref to the Context resolves to ``None`` — the Token - does not pin the Context, so other vars and values bound in - that Context are eligible for collection. A long-lived Token - no longer leaks its originating Context's state. - """ # Arrange - var = ContextVar(f"no_pin_{uuid.uuid4().hex}") - ctx = Context() - ctx_ref = weakref.ref(ctx) - - def mint() -> Token: - return var.set("x") - - token = ctx.run(mint) + namespace = uuid.uuid4().hex + with scoped_context(): + token = ContextVar("pinned_var", namespace=namespace).set("x") # Act - del ctx gc.collect() - # Assert - assert ctx_ref() is None, ( - "Originating Context should be collectible once the user " - "drops their reference, even with the Token still alive" - ) - # Sanity: the token is still alive and reports unused. - assert token.used is False + # Assert — the backing carries the qualified key in its name. 
+ assert namespace in token.var.name + assert "pinned_var" in token.var.name diff --git a/wool/tests/runtime/context/test_var.py b/wool/tests/runtime/context/test_var.py index 0ad5c8d2..36c11838 100644 --- a/wool/tests/runtime/context/test_var.py +++ b/wool/tests/runtime/context/test_var.py @@ -1,96 +1,135 @@ +import asyncio import contextvars +import copy import gc +import pickle +import threading import uuid +from contextvars import Token import cloudpickle import pytest +from hypothesis import HealthCheck from hypothesis import given +from hypothesis import settings from hypothesis import strategies as st import wool +from tests.helpers import _unique from tests.helpers import scoped_context -from wool.runtime.context import Context -from wool.runtime.context import ContextVar -from wool.runtime.context import ContextVarCollision -from wool.runtime.context import Token -from wool.runtime.context import current_context +from wool import protocol + +# ``Chain`` is imported so the task-adoption test can seed an ownerless +# armed chain directly through the public ``wool.__chain__`` context var +# (rather than ``Chain.mount``, which restamps the owner), verifying the +# adoption path on first `ContextVar.set`. +from wool.runtime.context.chain import Chain +from wool.runtime.context.exceptions import ContextVarCollision +from wool.runtime.context.manifest import ChainManifest +from wool.runtime.context.manifest import resolve_stub +from wool.runtime.context.var import ContextVar dumps = wool.__serializer__.dumps loads = cloudpickle.loads +# Hypothesis sentinel for the default-ladder test — distinguishes "no +# argument supplied" from a supplied ``None``, which is itself a valid +# default value. +_NOTHING = object() + class TestContextVar: - def test___init___with_name_only(self): + def test___new___should_raise_lookup_error_on_get_when_name_only(self): """Test ContextVar initialization with a name and no default. Given: - A name string + A name string. 
When: - ContextVar is instantiated with the name + ContextVar is instantiated with the name. Then: - It should expose the name and raise LookupError on get() with no value set + It should expose the name and raise LookupError on get() + with no value set. """ # Act - var = ContextVar("init_nameonly") + var = ContextVar(_unique("init_nameonly")) # Assert - assert var.name == "init_nameonly" + assert var.name.startswith("init_nameonly") with pytest.raises(LookupError): var.get() - def test___init___with_default(self): + def test___new___should_return_default_on_get_when_default_given(self): """Test ContextVar initialization with a default value. Given: - A name string and a default value + A name string and a default value. When: - ContextVar is instantiated with both + ContextVar is instantiated with both. Then: - get() should return the default when no value is set + get() should return the default when no value is set. """ # Arrange - var = ContextVar("init_withdefault", default=42) + var = ContextVar(_unique("init_withdefault"), default=42) # Act & assert assert var.get() == 42 - def test___init___infers_namespace_from_caller(self): + def test___new___should_infer_namespace_from_caller_package(self): """Test ContextVar infers namespace from the caller's top-level package. Given: - A ContextVar constructed from this test module + A ContextVar constructed from this test module. When: - No explicit namespace is provided + No explicit namespace is provided. Then: - The namespace should be the top-level package of ``__name__`` + The namespace should be the top-level package of __name__. 
""" # Arrange expected_ns = __name__.partition(".")[0] # Act - var = ContextVar("inferred") + var = ContextVar(_unique("inferred")) # Assert assert var.namespace == expected_ns - assert var.name == "inferred" - def test___init___accepts_explicit_namespace(self): + def test___new___should_use_explicit_namespace_when_provided(self): """Test ContextVar uses an explicit namespace when provided. Given: - A ContextVar constructed with namespace='myapp' + A ContextVar constructed with an explicit namespace. When: - No implicit inference is needed + No implicit inference is needed. Then: - The key should combine the explicit namespace with the name + The key should combine the explicit namespace with the name. """ # Act - var = ContextVar("explicit", namespace="myapp") + var = ContextVar("explicit", namespace=_unique("ns")) # Assert - assert var.namespace == "myapp" + assert var.namespace.startswith("ns") assert var.name == "explicit" + def test___new___should_raise_collision_when_key_already_registered(self): + """Test re-declaring an already-registered key raises ContextVarCollision. + + Given: + A ContextVar already registered under a key. + When: + ContextVar is invoked again with the identical key. + Then: + It should raise ContextVarCollision — keys must be unique + within a namespace. 
+ """ + # Arrange + namespace = _unique("dup_same") + first = ContextVar("v", namespace=namespace, default=1) + + # Act & assert + with pytest.raises(ContextVarCollision): + ContextVar("v", namespace=namespace, default=1) + assert first.name == "v" + @given( name=st.text( alphabet=st.characters(min_codepoint=ord("a"), max_codepoint=ord("z")), @@ -100,540 +139,861 @@ def test___init___accepts_explicit_namespace(self): default_a=st.one_of(st.integers(), st.text(), st.booleans()), default_b=st.one_of(st.integers(), st.text(), st.booleans()), ) - def test___init___with_duplicate_key(self, name, default_a, default_b): + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test___new___should_raise_collision_when_duplicate_key_regardless_of_default( + self, name, default_a, default_b + ): """Test duplicate-key construction raises regardless of default match. Given: - Any non-empty name and pair of default values + Any non-empty name and pair of default values. When: - A ContextVar is constructed, then a second with the identical key + A ContextVar is constructed, then a second with the + identical key. Then: - The second construction raises ContextVarCollision whether or - not the two defaults are equal + The second construction raises ContextVarCollision whether + or not the two defaults are equal — distinct declarations + of the same key collide. """ # A unique namespace per example guarantees the registry slot - # is free without needing teardown between Hypothesis examples - # (which share a single test-function invocation). The - # collision still fires within the example because both - # ContextVar calls use the same namespace + name. + # is free without needing teardown between Hypothesis examples. 
namespace = f"test_duplicate_{uuid.uuid4().hex}" # The registry holds vars weakly, so we hold ``first`` for the - # duration of the second construction; without that pin the - # weakref would drop before the second call and no collision - # would fire. + # duration of the second construction. first = ContextVar(name, namespace=namespace, default=default_a) with pytest.raises(ContextVarCollision): ContextVar(name, namespace=namespace, default=default_b) - # Hold ``first`` to the end so the collision actually had - # something to collide with. assert first.namespace == namespace assert first.name == name - def test___init___with_same_name_in_different_namespace(self): + def test___new___should_register_both_when_same_name_in_different_namespace(self): """Test duplicate names across different namespaces are allowed. Given: - A ContextVar 'shared' in namespace 'lib_a' + A ContextVar in one namespace. When: - A second ContextVar 'shared' is constructed in namespace 'lib_b' + A second ContextVar with the same name is constructed in a + different namespace. Then: - Both should register without collision + Both should register without collision. """ + # Arrange + ns_a = _unique("lib_a") + ns_b = _unique("lib_b") + # Act - a = ContextVar("shared", namespace="lib_a") - b = ContextVar("shared", namespace="lib_b") + a = ContextVar("shared", namespace=ns_a) + b = ContextVar("shared", namespace=ns_b) # Assert assert a is not b - assert (a.namespace, a.name) == ("lib_a", "shared") - assert (b.namespace, b.name) == ("lib_b", "shared") + assert (a.namespace, a.name) == (ns_a, "shared") + assert (b.namespace, b.name) == (ns_b, "shared") - def test_get_with_explicit_default_fallback(self): - """Test ContextVar.get returns the supplied fallback when unset. + def test___new___should_seed_supplied_default_when_promoting_stub(self): + """Test ContextVar.__new__ promotes a stub and seeds the supplied default. 
Given: - A ContextVar with no class-level default and no value set + A wire-decoded context that creates a default-less stub for + a key, then a ContextVar construction for the same key with + an explicit default. When: - get() is called with a fallback argument + ContextVar is constructed with default="promoted" for the + stub's key. Then: - It should return the fallback argument + get() should return "promoted" — the stub is promoted in + place and the supplied default is adopted. """ # Arrange - var = ContextVar("no_default") + var_namespace = uuid.uuid4().hex + pb = protocol.ChainManifest(id=uuid.uuid4().hex) + pb.vars.add( + namespace=var_namespace, + name="stub_promote", + value=dumps("wire-value"), + ) + # Decoding seeds a default-less stub for the key. + _decoded = ChainManifest.from_protobuf(pb, serializer=wool.__serializer__) # noqa: F841 # Act - value = var.get("fallback") + var = ContextVar("stub_promote", namespace=var_namespace, default="promoted") # Assert - assert value == "fallback" + assert var.get() == "promoted" - def test_set_token_when_reset_from_first_set(self): - """Test the first set's Token restores the var to its default on reset. + def test_contextvar_should_declare_empty_slots(self): + """Test ContextVar declares no instance slots of its own. Given: - A ContextVar with a constructor default and a single set - that captures a Token + The ContextVar type, whose declared instances must reuse the + exact storage a pre-declaration placeholder already holds. + When: + Its own __slots__ are inspected. + Then: + They should be empty — the identity and data fields live on + the base, and an empty subclass layout is what lets a + placeholder be upgraded to a full variable in place when the + variable is declared. + """ + # Act & assert + assert ContextVar.__slots__ == () + + def test_declaration_should_keep_placeholder_object_and_value_in_place(self): + """Test declaring a wire-seeded variable keeps its object and value. 
+ + Given: + A variable whose value was applied to its placeholder backing + before the receiver declared it — modelling a propagated + value that arrived ahead of the declaration. When: - reset(token) is called + The receiver declares the variable for that key. Then: - var.get() should return the constructor default + Declaration should hand back the same object the placeholder + occupied, now usable, and reading it should yield the value + applied before declaration — the upgrade happens in place, so + nothing the placeholder held is lost. """ # Arrange - var = ContextVar("restore_default", default="initial") - token = var.set("x") + namespace = uuid.uuid4().hex + name = "inplace_promote" + placeholder = resolve_stub((namespace, name)) + placeholder_id = id(placeholder) + placeholder._backing.set("pre-declaration") # Act - var.reset(token) + var = ContextVar(name, namespace=namespace) # Assert - assert var.get() == "initial" + assert id(var) == placeholder_id + assert var.get() == "pre-declaration" - def test_set_token_restores_outer_value_when_reset_from_nested_set(self): - """Test a nested set's Token restores the outer set's value on reset. + @pytest.mark.parametrize( + "non_str_name", + [42, None, ("ns", "name"), b"bytes-name"], + ids=["int", "none", "tuple", "bytes"], + ) + def test___new___should_raise_type_error_when_name_not_str(self, non_str_name): + """Test ContextVar construction rejects a non-str name. Given: - A ContextVar set to "outer" and then set again to "inner" - capturing the inner Token + A name argument that is not a str — an int, None, a tuple, + or bytes. When: - reset(inner_token) is called + ContextVar is constructed with that argument. Then: - var.get() should return "outer" — Tokens stack, each - restoring only the value replaced by its own set + It should raise TypeError — the documented contract is that + a context variable name must be a str. 
+ """ + # Act & assert + with pytest.raises(TypeError, match="name must be a str"): + ContextVar(non_str_name) + + def test_repr_should_include_namespace_and_name(self): + """Test ContextVar repr includes the namespace and name. + + Given: + A ContextVar with a name. + When: + repr() is called on it. + Then: + It should include the namespace and the name. """ # Arrange - var = ContextVar("restore_outer", default="d") - var.set("outer") - inner_token = var.set("inner") + var = ContextVar(_unique("repr_cv")) # Act - var.reset(inner_token) + text = repr(var) # Assert - assert var.get() == "outer" + assert f"namespace={var.namespace!r}" in text + assert f"name={var.name!r}" in text - def test_reset_with_used_token(self): - """Test ContextVar.reset raises on a token already consumed. + @given( + ctor_default=st.one_of(st.none(), st.just(_NOTHING), st.integers(), st.text()), + get_arg=st.one_of(st.none(), st.just(_NOTHING), st.integers(), st.text()), + value_set=st.booleans(), + armed=st.booleans(), + set_value=st.one_of(st.integers(), st.text()), + ) + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_get_should_resolve_the_default_ladder( + self, ctor_default, get_arg, value_set, armed, set_value + ): + """Test ContextVar.get walks the bound-value, get-arg, ctor-default ladder. Given: - A ContextVar and a Token that has already been used + Any combination of a constructor default (present or + absent), a get fallback argument (present or absent), + whether the variable itself was set, and whether the + context was independently armed by a sibling variable. When: - reset() is called with the same token again + get() is called — with or without the fallback argument. 
Then: - It should raise RuntimeError + It should return the highest-priority rung available: a set + value first, then the supplied fallback, then the + constructor default, raising LookupError only when none of + the three is present — and an unset variable on an armed + context still falls through the ladder rather than raising. """ # Arrange - var = ContextVar("once", default=0) - token = var.set(1) - var.reset(token) + ctor_kwargs = {} if ctor_default is _NOTHING else {"default": ctor_default} + var = ContextVar(_unique("ladder"), **ctor_kwargs) # Act & assert - with pytest.raises(RuntimeError): - var.reset(token) + with scoped_context(): + if armed: + # Arm the context via a sibling so ``var`` itself stays + # unset on an armed chain — the armed-unset rung. + ContextVar(_unique("ladder_sibling")).set("arms-it") + if value_set: + var.set(set_value) + args = () if get_arg is _NOTHING else (get_arg,) + if value_set: + assert var.get(*args) == set_value + elif get_arg is not _NOTHING: + assert var.get(*args) == get_arg + elif ctor_default is not _NOTHING: + assert var.get(*args) == ctor_default + else: + with pytest.raises(LookupError): + var.get(*args) + + def test_get_should_return_none_when_default_is_none(self): + """Test ContextVar.get treats a None default as a real default, not unset. - def test_reset_with_token_for_different_var(self): - """Test ContextVar.reset rejects tokens minted by a different var. + Given: + One ContextVar with no constructor default and one with an + explicit default=None, both unset. + When: + get(None) is called on the first and get() on the second. + Then: + Both return None rather than raising LookupError — a + supplied or constructor None is a real default, distinct + from "no default supplied", mirroring + contextvars.ContextVar.get. 
+ """ + # Arrange + no_ctor_default = ContextVar(_unique("supplied_none")) + ctor_none_default = ContextVar(_unique("ctor_none"), default=None) + + # Act & assert + assert no_ctor_default.get(None) is None + assert ctor_none_default.get() is None + + def test_get_should_raise_type_error_when_more_than_one_positional_argument(self): + """Test ContextVar.get rejects more than one positional argument. Given: - Two distinct ContextVar instances, each with a set value + An armed ContextVar. When: - reset() is called on one with the other's token + get is called with two positional arguments. Then: - It should raise ValueError + It should raise TypeError — get accepts at most one + fallback argument, mirroring contextvars.ContextVar.get. """ # Arrange - a = ContextVar("reset_a", default=0) - b = ContextVar("reset_b", default=0) - token = a.set(1) + var = ContextVar(_unique("get_too_many_args")) # Act & assert - with pytest.raises(ValueError): - b.reset(token) + with scoped_context(): + with pytest.raises(TypeError, match="at most 1 argument"): + var.get("a", "b") - def test_reset_restores_old_value_in_different_context_scope(self): - """Test reset restores old_value when invoked in a different Context scope. + def test_get_should_return_set_value(self): + """Test ContextVar.get returns the value installed by set. Given: - A ContextVar set twice in the original Context, then a - different wool.Context scope entered via Context.run + A ContextVar with a value set in the active context. When: - reset(token) is invoked inside the scoped Context + get() is called. Then: - The var should revert to the value captured in the token + It should return the set value. 
""" # Arrange - var = ContextVar("reset_fallback", default="initial") - var.set("first") - token = var.set("second") - seeded = contextvars.copy_context() + var = ContextVar(_unique("get_set")) - def body(): - var.set("outer-most") - var.reset(token) - return var.get() + # Act + with scoped_context(): + var.set("value") + + # Assert + assert var.get() == "value" + + def test_get_should_not_arm_an_unarmed_context(self): + """Test ContextVar.get on an unarmed context leaves it unarmed. + + Given: + A fresh unarmed context and a ContextVar with a default. + When: + get() is called. + Then: + current_context should still be None — a read never arms a + context. + """ + # Arrange + var = ContextVar(_unique("get_no_arm"), default="d") # Act - result = seeded.run(body) + with scoped_context(): + var.get() - # Assert - assert result == "first" + # Assert + assert wool.__chain__.get(None) is None - def test_reset_with_token_from_different_in_process_context(self): - """Test reset rejects a token minted in a different in-process Context. + def test_set_should_arm_an_unarmed_context(self): + """Test the first set arms the context with a context. Given: - A ContextVar and a Token minted by var.set(...) inside - ctx_a.run(...) — the token still holds its in-memory - Context reference (distinct from the cross-process - scenario covered elsewhere, which tests the UUID fallback - after pickling) + A fresh, unarmed context where current_context is None. When: - var.reset(token) is called inside ctx_b.run(...) where - ctx_b is a different wool.Context in the same process + ContextVar.set is called for the first time. Then: - ValueError is raised with a message naming the different - wool.Context — the in-process identity check fires, not - the UUID fallback + current_context should return a Chain. 
""" # Arrange - var = ContextVar("inprocess_reset", default="d") - ctx_a = Context() - ctx_b = Context() - captured: list[Token] = [] + var = ContextVar(_unique("set_arms")) + with scoped_context(): + assert wool.__chain__.get(None) is None - def inside_a(): - captured.append(var.set("a_value")) + # Act + var.set("x") - ctx_a.run(inside_a) - token = captured[0] + # Assert + assert wool.__chain__.get(None) is not None - def inside_b(): - var.reset(token) + def test_set_should_return_token(self): + """Test ContextVar.set returns a Token. - # Act & assert - with pytest.raises(ValueError, match="different wool.Context"): - ctx_b.run(inside_b) + Given: + A ContextVar in an armed context. + When: + set is called. + Then: + It should return a Token instance. + """ + # Arrange + var = ContextVar(_unique("set_token")) + + # Act + with scoped_context(): + token = var.set("x") + + # Assert + assert isinstance(token, Token) - def test_reset_with_reconstituted_token_when_chain_id_differs(self): - """Test reset raises ValueError for a reconstituted Token whose - originating chain id does not match the current Context. + def test_set_should_return_token_whose_var_is_not_the_wool_var(self): + """Test the returned token does not reference the wool ContextVar. Given: - A ContextVar, a Token minted in Context A and pickled, - the original Token then released and garbage-collected - so the process-wide token registry's weak entry drops, - and a fresh scoped_context(B) block whose id differs - from A + A ContextVar in an armed context. When: - The pickled bytes are loaded (producing a reconstituted - Token with _context=None and _context_id=A) and - var.reset(restored) is called under Context B + set is called and the returned token's ``var`` is inspected. 
Then: - ValueError is raised with a message naming the different - wool.Context — the UUID fallback check on the - reconstituted branch fires because object-identity - comparison isn't available + ``token.var`` should not be the wool ContextVar — the + deliberate, disclosed divergence from stdlib, where + ``var.set(x).var is var`` holds. The supported reset path is + ``var.reset(token)``. """ # Arrange - var = ContextVar("reconstituted_chain_mismatch", default="d") - pickled: bytes + var = ContextVar(_unique("set_token_var")) - def mint_and_release() -> bytes: - with scoped_context(): - token = var.set("x") - return wool.__serializer__.dumps(token) + # Act + with scoped_context(): + token = var.set("x") - pickled = mint_and_release() - gc.collect() + # Assert + assert token.var is not var + assert token.var is var._backing - # Act & assert + def test_set_should_rebuild_the_context_on_each_call(self): + """Test each set rebuilds a new immutable context. + + Given: + A wool.ContextVar set once on an armed context. + When: + set is called a second time. + Then: + It should install a context distinct by identity from the + first — set rebuilds the immutable context rather than + mutating it in place. + """ + # Arrange + var = ContextVar(_unique("set_cow")) + + # Act with scoped_context(): - restored = cloudpickle.loads(pickled) - with pytest.raises(ValueError, match="different wool.Context"): - var.reset(restored) + var.set("first") + first_context = wool.__chain__.get(None) + var.set("second") + second_context = wool.__chain__.get(None) - def test_reset_restores_unset_in_different_context_scope(self): - """Test reset restores unset state when old_value was MISSING. + # Assert + assert first_context is not None + assert second_context is not None + assert second_context is not first_context + + def test_set_should_clear_a_prior_reset_signal(self): + """Test re-setting a variable clears its reset-to-no-value signal. 
Given: - A previously-unset ContextVar whose set() Token captured - MISSING as the old_value, then a different wool.Context - scope entered via Context.run + A ContextVar reset to no value, leaving its key in the + active context's resets set. When: - reset(token) is invoked in the scoped Context + The variable is set again. Then: - The var should revert to unset; get(fallback) returns the - supplied fallback + Its key should be absent from the context's resets — + a re-set variable is no longer in a reset state. """ # Arrange - var = ContextVar("reset_unset_fallback") - token = var.set("briefly") - seeded = contextvars.copy_context() + var = ContextVar(_unique("set_clears_reset")) - def body(): - var.set("nested") + # Act & assert + with scoped_context(): + token = var.set("first") var.reset(token) - return var.get("") + after_reset = wool.__chain__.get(None) + assert after_reset is not None + assert (var.namespace, var.name) in after_reset.resets + var.set("second") + after_set = wool.__chain__.get(None) + assert after_set is not None + assert (var.namespace, var.name) not in after_set.resets + + @given(payload=st.text() | st.integers() | st.lists(st.integers())) + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_set_should_round_trip_with_arbitrary_value(self, payload): + """Test set followed by get round-trips an arbitrary value. + + Given: + Any text, integer, or list-of-integers value. + When: + The value is set on a ContextVar and read back via get. + Then: + get should return a value equal to the one set. + """ + # Arrange + var = ContextVar(_unique("rt_set_get")) + + # Act & assert + with scoped_context(): + var.set(payload) + assert var.get() == payload + + @pytest.mark.asyncio + async def test_set_should_adopt_an_armed_chain_with_no_live_owning_task(self): + """Test ContextVar.set adopts ownership of an ownerless armed chain. 
+ + Given: + An armed chain with no owner task — the shape a wire- + decoded context has before its worker mounts it — entered + by an asyncio task created with no Wool task factory + installed (so the chain is not forked). + When: + The running task calls set on a wool.ContextVar. + Then: + The active context's owning task (``Chain.task``) should + become the running task — a new runner adopts an ownerless + armed chain so a later concurrent task still fails loud via + the owning-task guard. + """ + # Arrange — clear any task factory so the child task does not + # fork the chain onto a fresh owner. + loop = asyncio.get_running_loop() + loop.set_task_factory(None) + var = ContextVar(_unique("set_adoption")) + + with scoped_context(): + # An armed chain with task=None — the shape of a context + # decoded off the wire before a worker mounts it. + wool.__chain__.set( + Chain( + id=uuid.uuid4(), + thread=threading.get_ident(), + task=None, + ) + ) + armed = wool.__chain__.get(None) + assert armed is not None + assert armed.task is None + + async def adopt() -> object: + # Act — the running task sets the var on the ownerless + # armed chain. + var.set("adopted") + context = wool.__chain__.get(None) + assert context is not None + ref = context.task + return None if ref is None else ref() + + try: + task = asyncio.ensure_future(adopt()) + owning_task = await task + finally: + loop.set_task_factory(None) + + # Assert — the running task adopted ownership. + assert owning_task is task + + def test_set_should_propagate_value_into_seeded_contextvars_context(self): + """Test a set value is visible inside a stdlib Context copied after the set. + + Given: + A ContextVar set to a value, then a stdlib contextvars.Context + copied from the live context. + When: + The var is read inside that copied Context via Context.run. + Then: + It should observe the set value — Wool chain state rides in a + stdlib contextvars.ContextVar and copies with stdlib semantics. 
+ """ + # Arrange + var = ContextVar(_unique("propagate_seeded")) # Act - result = seeded.run(body) + with scoped_context(): + var.set("seeded") + seeded = contextvars.copy_context() + observed = seeded.run(var.get) # Assert - assert result == "" + assert observed == "seeded" - def test_pickle_roundtrip_when_var_registered(self): - """Test cloudpickle roundtrips a ContextVar to the same registered instance. + def test_reset_should_restore_default_when_first_set_token(self): + """Test the first set's Token restores the var to its default on reset. Given: - A ContextVar registered in the process-wide registry + A ContextVar with a constructor default and a single set + that captures a Token. When: - It is pickled and unpickled via dumps / loads + reset(token) is called. Then: - The unpickled instance should be the same object as the original + var.get() should return the constructor default. """ # Arrange - var = ContextVar("shipped") + var = ContextVar(_unique("restore_default"), default="initial") # Act - restored = loads(dumps(var)) + with scoped_context(): + token = var.set("x") + var.reset(token) - # Assert - assert restored is var + # Assert + assert var.get() == "initial" - def test_pickle_roundtrip_when_key_unregistered_on_receiver(self): - """Test loads() creates a stub when the var key is not registered. + def test_reset_should_restore_outer_value_when_nested_set_token(self): + """Test a nested set's Token restores the outer set's value on reset. Given: - Pickle bytes of a ContextVar whose registry slot has been - reclaimed via GC (the var instance was constructed inside - a nested Context.run block, pickled while bound to a - value, and released on block exit so the - WeakValueDictionary entry can drop) + A ContextVar set to "outer" and then set again to "inner" + capturing the inner Token. When: - loads(pickled) is called inside a fresh Context scope + reset(inner_token) is called. 
Then: - The restored var's key and default match the original, - but the receiver's Context contains no entry for the var - and get(fallback) returns the fallback — pickle preserves - identity only and never writes to the receiver's Context + var.get() should return "outer" — Tokens stack, each + restoring only the value replaced by its own set. """ - # Arrange — pickle inside a nested Context.run so the - # ephemeral var's strong ref in that Context's data dict - # drops with the scope, letting the WeakValueDictionary slot - # clear on gc.collect(). - captured: list[tuple[bytes, tuple[str, str]]] = [] + # Arrange + var = ContextVar(_unique("restore_outer"), default="d") - def pickle_ephemeral(): - ephemeral = ContextVar( - "stub_unknown", - namespace="test_stub_create", - default="d", - ) - ephemeral.set("sender_value") - captured.append((dumps(ephemeral), (ephemeral.namespace, ephemeral.name))) + # Act + with scoped_context(): + var.set("outer") + inner_token = var.set("inner") + var.reset(inner_token) - Context().run(pickle_ephemeral) - gc.collect() + # Assert + assert var.get() == "outer" - pickled, original_key = captured[0] + def test_reset_should_raise_runtime_error_when_token_already_used(self): + """Test ContextVar.reset raises on a token already consumed. - observed: list[tuple[ContextVar, Context]] = [] + Given: + A ContextVar and a Token that has already been used. + When: + reset() is called with the same token again. + Then: + It should raise RuntimeError — the reset routes through the + backing variable's native single-use bit. 
+ """ + # Arrange + var = ContextVar(_unique("once"), default=0) - def in_fresh(): - restored = loads(pickled) - observed.append((restored, current_context())) + # Act & assert + with scoped_context(): + token = var.set(1) + var.reset(token) + with pytest.raises(RuntimeError, match="has already been used"): + var.reset(token) - # Act - Context().run(in_fresh) + def test_reset_should_raise_value_error_when_token_for_different_var(self): + """Test ContextVar.reset rejects tokens minted by a different var. - # Assert - assert len(observed) == 1 - restored, receiver_ctx = observed[0] - assert (restored.namespace, restored.name) == original_key - assert restored.get("fallback") == "fallback" - assert restored not in receiver_ctx + Given: + Two distinct ContextVar instances, each with a set value. + When: + reset() is called on one with the other's token. + Then: + Stdlib's `contextvars.ContextVar.reset` raises + `ValueError` naming the different ContextVar. + """ + # Arrange + a = ContextVar(_unique("reset_a"), default=0) + b = ContextVar(_unique("reset_b"), default=0) - def test_pickle_roundtrip_does_not_write_to_receiver_context(self): - """Test pickling a ContextVar does not propagate its value via reduce. + # Act & assert + with scoped_context(): + token = a.set(1) + with pytest.raises( + ValueError, match="was created by a different ContextVar" + ): + b.reset(token) + + @pytest.mark.parametrize( + "not_a_token", + ["a-string", 42, None], + ids=["str", "int", "none"], + ) + def test_reset_should_raise_type_error_when_argument_not_token(self, not_a_token): + """Test ContextVar.reset rejects an argument that is not a Token. Given: - A ContextVar set to a specific value in the current - context + An armed ContextVar and an argument that is not a + wool.Token — a string, an int, or None. When: - The var is pickled and unpickled inside a fresh - wool.Context + reset() is called with that argument. 
Then: - The receiver's Context contains no binding for the var - and var.get() falls back through the receiver Context's - empty data to the constructor default — pickle is a key- - only payload; state propagation rides on the wire-context - path, not the reduce tuple + Stdlib's `contextvars.ContextVar.reset` raises + `TypeError` naming Token — the type check guards + reset before any token-state inspection. """ # Arrange - var = ContextVar("pickle_with_value", default="default_value") - var.set("sender_value") - pickled = dumps(var) + var = ContextVar(_unique("reset_not_token")) - # Act - observed_get: list[object] = [] - observed_ctx: list[Context] = [] + # Act & assert + with scoped_context(): + var.set("x") + with pytest.raises(TypeError, match="instance of Token"): + var.reset(not_a_token) - def in_fresh(): - restored = loads(pickled) - observed_get.append(restored.get()) - observed_ctx.append(current_context()) + def test_reset_should_raise_runtime_error_when_used_token_for_different_var(self): + """Test ContextVar.reset rejects an already-used token even on a different var. - Context().run(in_fresh) + Given: + A Token that has already been consumed by a reset on its own + ContextVar, passed to reset on a different ContextVar — so + the token is BOTH already-used AND created by a different + var. + When: + reset() is called with that token on the other var. + Then: + Stdlib's `contextvars.ContextVar.reset` raises + `RuntimeError` for the used token — the used-token + check runs first, before the different-ContextVar check. 
+ """ + # Arrange + a = ContextVar(_unique("reset_used_a"), default=0) + b = ContextVar(_unique("reset_used_b"), default=0) - # Assert - assert observed_get == ["default_value"] - assert var not in observed_ctx[0] + # Act & assert + with scoped_context(): + token = a.set(1) + a.reset(token) + with pytest.raises(RuntimeError, match="has already been used"): + b.reset(token) - def test_pickle_roundtrip_when_var_unset(self): - """Test pickling a never-set ContextVar leaves the receiver clean. + def test_reset_should_restore_unset_state_when_old_value_missing(self): + """Test reset restores unset state when old_value was MISSING. Given: - A ContextVar that has never been set in the current - context + A previously-unset ContextVar whose set() Token captured + MISSING as the old_value. When: - The var is pickled and unpickled inside a fresh - wool.Context + reset(token) is invoked. Then: - The receiver's var.get() returns the class-level default - and the receiver's Context does not contain the var + The var should revert to unset; get(fallback) returns the + supplied fallback. """ # Arrange - var = ContextVar("pickle_no_value", default="default_value") - pickled = dumps(var) + var = ContextVar(_unique("reset_unset")) # Act - observed_get: list[object] = [] - observed_ctx: list[Context] = [] + with scoped_context(): + token = var.set("briefly") + var.reset(token) - def in_fresh(): - restored = loads(pickled) - observed_get.append(restored.get()) - observed_ctx.append(current_context()) + # Assert + assert var.get("") == "" - Context().run(in_fresh) + def test_reset_should_raise_value_error_when_token_reset_in_copied_context(self): + """Test resetting a token in a copy_context copy is rejected. - # Assert - assert observed_get == ["default_value"] - assert var not in observed_ctx[0] + Given: + A ContextVar set in one context, and a copy of that context + taken via contextvars.copy_context. 
+ When: + The token is reset inside the copy, then reset again in the + original context. + Then: + The reset in the copy should raise ValueError — the reset + delegates to the backing variable's native reset, and stdlib + rejects a token reset in a different contextvars.Context + than the one it was minted in — while the reset in the + original context still succeeds. + """ + # Arrange + var = ContextVar(_unique("reset_copy_rejects")) - def test_later_declaration_after_preloaded_var(self): - """Test a later ContextVar declaration promotes a pre-existing stub. + # Act & assert + with scoped_context(): + token = var.set("x") + copy_ctx = contextvars.copy_context() + with pytest.raises(ValueError, match="different Context"): + copy_ctx.run(var.reset, token) + var.reset(token) + assert var.get("") == "" + + def test_reset_should_not_raise_when_chain_unarmed_after_native_reset(self, mocker): + """Test reset no-ops when the active context reads unarmed post-native-reset. Given: - A stub produced by unpickling an unknown-key pickle — - the registry holds a stub after loads() + A ContextVar armed by a set that captures a Token, with the + active Wool context patched to report the chain unarmed. When: - ContextVar(name, namespace=ns, default=) is constructed - with the same (namespace, name) + reset(token) is called — the native backing reset succeeds, + then the Wool-bookkeeping branch observes no armed chain. Then: - The declaration returns the same instance as the restored - stub (no ContextVarCollision raised), and the declaration's - default value is observed via var.get() in a fresh Context - distinct from the stub's wire-value receiver Context + reset should return without raising — the defensive no-op + guarding the post-native-reset path that is otherwise + unreachable, since stdlib's native reset would itself have + raised before an armed chain could vanish. 
""" - # Arrange — pickle inside a nested Context.run, let gc clear - # the slot, then unpickle to create a stub. - captured: list[bytes] = [] + # Arrange + var = ContextVar(_unique("reset_unarmed_noop"), default="d") - def pickle_ephemeral(): - ephemeral = ContextVar( - "stub_promoted", - namespace="test_stub_promote", - default="d", - ) - ephemeral.set("wire_value") - captured.append(dumps(ephemeral)) + # Act & assert + unset = object() + with scoped_context(): + token = var.set("x") + mocker.patch.object(wool, "__chain__").get.return_value = None + var.reset(token) + assert var.get(unset) is unset - Context().run(pickle_ephemeral) - gc.collect() + def test_pickle_should_roundtrip_to_same_instance_when_var_registered(self): + """Test the wool serializer roundtrips a ContextVar to the same instance. - pickled = captured[0] + Given: + A ContextVar registered in the process-wide registry. + When: + It is dumped via the wool serializer and loaded back. + Then: + The unpickled instance should be the same object as the + original. + """ + # Arrange + var = ContextVar(_unique("shipped")) - # Hold a strong reference to the receiver Context so it - # outlives the pickle load, ensuring the later - # ContextVar(...) declaration runs against the registry - # state the load produced rather than landing on state - # already reclaimed by GC. - receiver = Context() - stub_capture: list[ContextVar] = [] + # Act + restored = loads(dumps(var)) - def in_receiver(): - stub_capture.append(loads(pickled)) + # Assert + assert restored is var - receiver.run(in_receiver) - restored = stub_capture[0] + def test_pickle_should_not_write_to_receiver_context_on_roundtrip(self): + """Test the wool serializer roundtrip of a ContextVar carries identity only. - # Act — authoritative declaration with a new default. - promoted = ContextVar( - "stub_promoted", - namespace="test_stub_promote", - default="auth_default", - ) + Given: + A ContextVar set to a specific value in the current context. 
+ When: + The var is dumped and loaded inside a fresh, unarmed + contextvars.Context. + Then: + The receiver's context stays unarmed and var.get falls back + to the constructor default — the pickle path is a key-only + payload, never a value context, so the value never reaches + the receiver's backing variable. + """ + # Arrange + var = ContextVar(_unique("pickle_with_value"), default="default_value") - # Assert — same instance (stub was promoted, not replaced). - assert promoted is restored + # Act + with scoped_context(): + var.set("sender_value") + pickled = dumps(var) - # And the promoted default is observed in a fresh Context - # (distinct from the receiver Context that holds wire_value). - observed_default: list[object] = [] + observed: list[object] = [] - def in_fresh(): - observed_default.append(promoted.get()) + def _receive() -> None: + restored = loads(pickled) + observed.append(restored.get()) + observed.append(wool.__chain__.get(None)) - Context().run(in_fresh) + # A fresh contextvars.Context carries neither the wool context + # nor the variable's backing contextvars.ContextVar value. + contextvars.Context().run(_receive) - assert observed_default == ["auth_default"] + # Assert + assert observed[0] == "default_value" + assert observed[1] is None - def test_repr_includes_namespace_and_name(self): - """Test ContextVar repr includes the namespace and name. + def test_pickle_should_seed_default_into_stub_when_reloaded_after_wire_decode(self): + """Test loading a cloudpickled ContextVar with a default seeds an existing stub. Given: - A ContextVar with a name + A receiver process where a wire protocol.ChainManifest carrying a + value for a ContextVar key is decoded first — creating a + default-less stub — and a cloudpickle dump of a ContextVar + for the same key with a non-Undefined default is loaded + second, after the originating ContextVar has been released. When: - repr() is called on it + cloudpickle.loads runs over the pickle. 
Then: - The repr should include the namespace and the name + It should return a ContextVar for that key whose get falls + back to the pickled default — the second ingress folded the + default into the stub rather than discarding it. """ # Arrange - var = ContextVar("repr_cv") + var_namespace = uuid.uuid4().hex + var = ContextVar( + "wire_then_pickle", namespace=var_namespace, default="from-pickle" + ) + pickled_var = dumps(var) + del var + gc.collect() + + pb = protocol.ChainManifest(id=uuid.uuid4().hex) + pb.vars.add( + namespace=var_namespace, + name="wire_then_pickle", + value=dumps("from-wire"), + ) + # Bound to keep the decoded manifest alive so its stubs anchor + # holds the registered stub through the loads() call below. + _decoded = ChainManifest.from_protobuf(pb, serializer=wool.__serializer__) # noqa: F841 # Act - text = repr(var) + restored_var = loads(pickled_var) # Assert - assert f"namespace={var.namespace!r}" in text - assert f"name={var.name!r}" in text + assert restored_var.get() == "from-pickle" - def test_vanilla_pickle_copy_and_deepcopy_are_rejected(self): - """Test wool.ContextVar rejects pickle, cloudpickle, copy.copy, - and copy.deepcopy. + def test_pickle_should_reject_vanilla_pickle_copy_and_deepcopy(self): + """Test wool.ContextVar rejects pickle, cloudpickle, copy.copy, and deepcopy. Given: - A wool.ContextVar + A wool.ContextVar. When: pickle.dumps, cloudpickle.dumps, copy.copy, and - copy.deepcopy are each invoked on it + copy.deepcopy are each invoked on it. Then: - All four raise TypeError. ContextVar identity lives in - the process-wide var_registry, and a vanilla restore - outside Wool's dispatch path bypasses the stub-promotion - and collision-detection that ContextVar._reconstitute - relies on — wool.__serializer__ remains the only valid - serialization channel. + All four raise TypeError — wool.__serializer__ remains the + only valid serialization channel. 
""" # Arrange - import copy as _copy - import pickle - match = "wool.ContextVar cannot be pickled" - var = ContextVar("var_pickle_rejection") + var = ContextVar(_unique("var_pickle_rejection")) # Act & assert with pytest.raises(TypeError, match=match): @@ -641,149 +1001,6 @@ def test_vanilla_pickle_copy_and_deepcopy_are_rejected(self): with pytest.raises(TypeError, match=match): cloudpickle.dumps(var) with pytest.raises(TypeError, match=match): - _copy.copy(var) + copy.copy(var) with pytest.raises(TypeError, match=match): - _copy.deepcopy(var) - - -@pytest.mark.asyncio -async def test_contextvar_pickle_reload_after_wire_decode_seeds_default(): - """Test loading a cloudpickled :class:`ContextVar` with a default - seeds the receiver's stub when an earlier wire decode created it - without one. - - Given: - A receiver process where a wire :class:`protocol.Context` - carrying a value for a :class:`ContextVar` key is decoded - first — creating a default-less stub via - :meth:`Context.from_protobuf` — and a cloudpickle dump of a - :class:`ContextVar` for the same key with a non-Undefined - default is loaded second, after the originating - :class:`ContextVar` has been released so the registry no - longer holds it. - When: - :func:`cloudpickle.loads` runs over the pickle. - Then: - It should return a :class:`ContextVar` for that key whose - :meth:`get` falls back to the pickled default — the second - ingress folded the default into the stub rather than silently - discarding it. - """ - # Arrange - from wool import protocol - - var_namespace = uuid.uuid4().hex - - var = ContextVar("wire_then_pickle", namespace=var_namespace, default="from-pickle") - pickled_var = dumps(var) - del var - gc.collect() - - pb = protocol.Context(id=uuid.uuid4().hex) - pb.vars.add( - namespace=var_namespace, - name="wire_then_pickle", - value=dumps("from-wire"), - ) - # Bound to keep the wire-side Context alive so its stub anchor - # holds the registered stub through the loads() call below. 
- _secondary = Context.from_protobuf(pb) # noqa: F841 - - # Act - restored_var = loads(pickled_var) - - # Assert - assert restored_var.get() == "from-pickle" - - -def test_no_wool_contextvar_constructions_outside_runtime_context(): - """Test wool source has no ``wool.ContextVar(...)`` constructor - calls outside the ``wool.runtime.context`` subpackage. - - Given: - The wool source tree under ``src/wool/`` - When: - Each module is parsed and walked for ``Call`` nodes whose - target resolves to ``wool.ContextVar`` (via any import alias - — ``wool.ContextVar``, ``from wool import ContextVar``, - ``from wool.runtime.context import ContextVar``, or the - deeper ``wool.runtime.context.var.ContextVar``) - Then: - No constructor call is found outside ``wool/runtime/context/`` - — ``_infer_namespace`` only skips frames within that - submodule, so a wool-internal construction beyond it would - silently misattribute the var into ``wool``'s namespace - rather than the user's package. Adding such a construction - without first broadening the inference (or passing - ``namespace=`` explicitly) would be a latent footgun this - test catches at test time - """ - import ast - import importlib.util - from pathlib import Path - - # Arrange - spec = importlib.util.find_spec("wool") - assert spec is not None and spec.origin is not None - src_root = Path(spec.origin).resolve().parent - - def _resolves_to_contextvar(name: str, imports: dict[str, str]) -> bool: - target = imports.get(name) - return target in { - "wool.ContextVar", - "wool.runtime.context.ContextVar", - "wool.runtime.context.var.ContextVar", - } - - def _check_call( - call: ast.Call, imports: dict[str, str], offending: list[int] - ) -> None: - func = call.func - if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name): - # ``wool.ContextVar(...)`` style — bare alias resolves to wool. 
- if imports.get(func.value.id) == "wool" and func.attr == "ContextVar": - offending.append(call.lineno) - elif isinstance(func, ast.Name) and _resolves_to_contextvar(func.id, imports): - offending.append(call.lineno) - - def _scan(file: Path) -> list[int]: - tree = ast.parse(file.read_text(), filename=str(file)) - # Flow-sensitive: walk the top-level body in source order, mutating - # the import map as Imports are encountered so that calls earlier - # in the file resolve against earlier imports — a later import that - # rebinds the same alias (e.g. ``from contextvars import ContextVar`` - # after a wool import of the same name) does not retroactively mask - # an earlier wool construction. - imports: dict[str, str] = {} - offending: list[int] = [] - for stmt in tree.body: - if isinstance(stmt, ast.Import): - for alias in stmt.names: - imports[alias.asname or alias.name.partition(".")[0]] = alias.name - continue - if isinstance(stmt, ast.ImportFrom) and stmt.module: - for alias in stmt.names: - imports[alias.asname or alias.name] = f"{stmt.module}.{alias.name}" - continue - for node in ast.walk(stmt): - if isinstance(node, ast.Call): - _check_call(node, imports, offending) - return offending - - # Act - offenders: list[tuple[Path, int]] = [] - runtime_context = src_root / "runtime" / "context" - for py_file in src_root.rglob("*.py"): - if py_file.is_relative_to(runtime_context): - continue - for lineno in _scan(py_file): - offenders.append((py_file.relative_to(src_root), lineno)) - - # Assert - assert not offenders, ( - "wool.ContextVar constructed outside wool.runtime.context — " - "_infer_namespace's skip predicate (`wool.runtime.context.*`) " - "would misattribute the namespace. 
Either pass `namespace=` " - "explicitly or broaden the skip predicate before adding such " - "a construction:\n" + "\n".join(f" {f}:{ln}" for f, ln in offenders) - ) + copy.deepcopy(var) diff --git a/wool/tests/runtime/discovery/test_base.py b/wool/tests/runtime/discovery/test_base.py index ee1befb9..59a80b1e 100644 --- a/wool/tests/runtime/discovery/test_base.py +++ b/wool/tests/runtime/discovery/test_base.py @@ -7,7 +7,7 @@ from hypothesis import given from hypothesis import strategies as st -from wool.protocol import WorkerMetadata as WorkerMetadataProtobuf +from wool import protocol as wire from wool.runtime.discovery.base import Discovery from wool.runtime.discovery.base import DiscoveryEvent from wool.runtime.discovery.base import DiscoveryEventType @@ -42,7 +42,7 @@ def metadata_message(): Creates a protobuf WorkerMetadata message with typical field values for use in tests that need to deserialize protobuf messages. """ - return WorkerMetadataProtobuf( + return wire.WorkerMetadata( uid="12345678-1234-5678-1234-567812345678", address="localhost:50051", pid=12345, @@ -169,7 +169,7 @@ def test_from_protobuf_invalid_uuid(self): It should raise ValueError """ # Arrange - protobuf = WorkerMetadataProtobuf( + protobuf = wire.WorkerMetadata( uid="invalid-uuid", address="localhost:50051", pid=12345, diff --git a/wool/tests/runtime/discovery/test_local.py b/wool/tests/runtime/discovery/test_local.py index 1d809ae2..812f6312 100644 --- a/wool/tests/runtime/discovery/test_local.py +++ b/wool/tests/runtime/discovery/test_local.py @@ -1004,6 +1004,29 @@ def contending_lock(fh, flags): assert attempt == 2 +class TestWorkerReference: + """Tests for the internal _WorkerReference value object.""" + + def test_is_hashable_by_its_uuid(self): + """Test a _WorkerReference hashes by its UUID. + + Given: + A _WorkerReference wrapping a UUID. + When: + It is hashed. 
+ Then: + Its hash should equal the UUID's hash — references are + usable as dict keys / set members keyed by worker identity. + """ + # Arrange + from wool.runtime.discovery.local import _WorkerReference + + uid = uuid.uuid4() + + # Act & assert + assert hash(_WorkerReference(uid)) == hash(uid) + + class TestLocalDiscoverySubscriber: """Tests for LocalDiscovery.Subscriber class. diff --git a/wool/tests/runtime/loadbalancer/test_roundrobin.py b/wool/tests/runtime/loadbalancer/test_roundrobin.py index 4acc636d..3c8dcf29 100644 --- a/wool/tests/runtime/loadbalancer/test_roundrobin.py +++ b/wool/tests/runtime/loadbalancer/test_roundrobin.py @@ -1,4 +1,5 @@ import asyncio +from unittest.mock import MagicMock from uuid import uuid4 import pytest @@ -20,6 +21,17 @@ from wool.runtime.worker.metadata import WorkerMetadata +def make_task(callable): + """Build a `wool.Task` wrapping *callable* with a throwaway worker proxy.""" + return Task( + id=uuid4(), + callable=callable, + args=(), + kwargs={}, + proxy=MagicMock(spec=WorkerProxyLike, id="mock-proxy"), + ) + + @st.composite def dispatch_side_effects( draw, min_size: int, max_size: int, include_transient: bool = True @@ -28,10 +40,10 @@ def dispatch_side_effects( Generates a list of failure side effects drawn from the load-balancer's worker-health exception contract: only - :class:`RpcError` and its transient subclass - :class:`TransientRpcError` (the LB's documented catch surface). - Exceptions outside this contract — e.g. raw :class:`Exception`, - :class:`BaseExceptionGroup`, or local :class:`asyncio.TimeoutError` + `RpcError` and its transient subclass + `TransientRpcError` (the LB's documented catch surface). + Exceptions outside this contract — e.g., raw `Exception`, + `BaseExceptionGroup`, or local `asyncio.TimeoutError` — propagate past the LB to the caller and are exercised separately. 
@@ -79,11 +91,32 @@ def test_isinstance_satisfies_protocol(self): # Act & assert assert isinstance(RoundRobinLoadBalancer(), LoadBalancerLike) + def test_pickle_round_trip_yields_a_fresh_balancer(self): + """Test a RoundRobinLoadBalancer round-trips through pickle. + + Given: + A RoundRobinLoadBalancer instance. + When: + It is pickled and unpickled. + Then: + The result should be a fresh RoundRobinLoadBalancer — the + balancer reduces to its bare class, dropping the unpicklable + per-process rotation state and lock. + """ + # Arrange + import pickle + + balancer = RoundRobinLoadBalancer() + + # Act + restored = pickle.loads(pickle.dumps(balancer)) + + # Assert + assert isinstance(restored, RoundRobinLoadBalancer) + assert restored is not balancer + @pytest.mark.asyncio - async def test_dispatch_with_empty_context( - self, - mocker: MockerFixture, - ): + async def test_dispatch_with_empty_context(self): """Test dispatch raises NoWorkersAvailable with empty context. @@ -101,15 +134,7 @@ async def test_dispatch_with_empty_context( async def routine(): return "Hello world!" - mock_proxy = mocker.MagicMock(spec=WorkerProxyLike, id="mock-proxy") - - task = Task( - id=uuid4(), - callable=routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + task = make_task(routine) # Act & assert with pytest.raises(NoWorkersAvailable): @@ -163,15 +188,7 @@ async def test_dispatch_with_healthy_workers( async def routine(): return "Hello world!" - mock_proxy = mocker.MagicMock(spec=WorkerProxyLike, id="mock-proxy") - - task = Task( - id=uuid4(), - callable=routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + task = make_task(routine) # Track dispatch attempts to verify round-robin behavior tasks_dispatched = [] @@ -294,15 +311,7 @@ async def test_dispatch_with_all_workers_failing( async def routine(): return "Hello world!" 
- mock_proxy = mocker.MagicMock(spec=WorkerProxyLike, id="mock-proxy") - - task = Task( - id=uuid4(), - callable=routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + task = make_task(routine) # Track dispatch attempts to verify round-robin behavior tasks_dispatched = [] @@ -410,18 +419,7 @@ async def dispatch_fn(task, *, timeout=None, m=metadata): async def routine(): return "Hello world!" - mock_proxy = mocker.MagicMock(spec=WorkerProxyLike, id="mock-proxy") - - tasks = [ - Task( - id=uuid4(), - callable=routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) - for _ in range(4) - ] + tasks = [make_task(routine) for _ in range(4)] # Act results = await asyncio.gather( @@ -478,22 +476,8 @@ async def test_dispatch_with_lock_release_on_success( async def routine(): return "Hello world!" - mock_proxy = mocker.MagicMock(spec=WorkerProxyLike, id="mock-proxy") - - task1 = Task( - id=uuid4(), - callable=routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) - task2 = Task( - id=uuid4(), - callable=routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + task1 = make_task(routine) + task2 = make_task(routine) # Act await lb.dispatch(task1, context=ctx) @@ -547,15 +531,7 @@ async def dispatch_with_transient_then_success(task, *, timeout=None): async def routine(): return "Hello world!" - mock_proxy = mocker.MagicMock(spec=WorkerProxyLike, id="mock-proxy") - - task = Task( - id=uuid4(), - callable=routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + task = make_task(routine) # Act & assert with pytest.raises(NoWorkersAvailable): @@ -614,15 +590,7 @@ async def test_dispatch_with_worker_removal( async def routine(): return "Hello world!" 
- mock_proxy = mocker.MagicMock(spec=WorkerProxyLike, id="mock-proxy") - - task = Task( - id=uuid4(), - callable=routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + task = make_task(routine) # Act result = await lb.dispatch(task, context=ctx) @@ -640,17 +608,17 @@ async def test_dispatch_propagates_non_rpc_error_without_worker_removal( Given: A load balancer with one worker whose dispatch raises a - non-:class:`RpcError` exception (modelling e.g. a strict- - mode :class:`BaseExceptionGroup` of - :class:`wool.ContextDecodeWarning` peers from - :meth:`Context.to_protobuf`, or a programming-error - :class:`ValueError`). + non-`RpcError` exception (modelling e.g., a strict- + mode `BaseExceptionGroup` of + `wool.SerializationWarning` peers from + `ChainManifest.to_protobuf`, or a programming-error + `ValueError`). When: ``await lb.dispatch(...)`` is called. Then: The exception should propagate unwrapped to the caller and the worker should remain in the context — the LB's - worker-health contract treats only :class:`RpcError` + worker-health contract treats only `RpcError` instances as worker-health concerns, so a fault that has nothing to do with worker health does not evict the pool. """ @@ -675,14 +643,7 @@ async def test_dispatch_propagates_non_rpc_error_without_worker_removal( async def routine(): return "Hello world!" - mock_proxy = mocker.MagicMock(spec=WorkerProxyLike, id="mock-proxy") - task = Task( - id=uuid4(), - callable=routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + task = make_task(routine) # Act & assert with pytest.raises(BaseExceptionGroup) as exc_info: @@ -741,18 +702,7 @@ async def dispatch_fn(task, *, timeout=None, m=metadata): async def routine(): return "Hello world!" 
- mock_proxy = mocker.MagicMock(spec=WorkerProxyLike, id="mock-proxy") - - tasks = [ - Task( - id=uuid4(), - callable=routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) - for _ in range(8) - ] + tasks = [make_task(routine) for _ in range(8)] # Act dispatch_futures = [ diff --git a/wool/tests/runtime/routine/test_task.py b/wool/tests/runtime/routine/test_task.py index cc18f8dd..b7278c8d 100644 --- a/wool/tests/runtime/routine/test_task.py +++ b/wool/tests/runtime/routine/test_task.py @@ -16,8 +16,8 @@ import wool from wool import protocol -from wool.runtime.context import RuntimeContext -from wool.runtime.context import dispatch_timeout +from wool.runtime.context.runtime import RuntimeContext +from wool.runtime.context.runtime import dispatch_timeout from wool.runtime.routine.task import Task from wool.runtime.routine.task import TaskException from wool.runtime.routine.task import current_task @@ -85,7 +85,7 @@ def __reduce_ex__(self, *_): class TestWorkerProxyLike: """Tests for :py:class:`WorkerProxyLike` protocol conformance.""" - def test_positive_conformance(self, sample_task): + def test_positive_conformance_should_instantiate_task(self, sample_task): """Test that a conforming proxy is accepted by Task. Given: @@ -119,7 +119,7 @@ async def dispatch(self, task, *, timeout=None): assert hasattr(task.proxy, "id") assert callable(task.proxy.dispatch) - def test_negative_conformance(self, sample_async_callable): + def test_negative_conformance_should_raise(self, sample_async_callable): """Test that a non-conforming proxy is rejected by Task. Given: @@ -152,7 +152,7 @@ def id(self): ) -def test_do_dispatch_with_default_context(): +def test_do_dispatch_should_return_true_when_no_active_context(): """Test do_dispatch returns True with no active context. 
Given: @@ -166,7 +166,7 @@ def test_do_dispatch_with_default_context(): assert do_dispatch() is True -def test_do_dispatch_with_false_flag(): +def test_do_dispatch_should_return_false_when_inside_false_context(): """Test do_dispatch returns False inside a False context. Given: @@ -183,7 +183,7 @@ def test_do_dispatch_with_false_flag(): assert do_dispatch() -def test_do_dispatch_with_nested_contexts(): +def test_do_dispatch_should_return_innermost_value_when_nested(): """Test do_dispatch tracks the innermost nested context. Given: @@ -205,7 +205,9 @@ def test_do_dispatch_with_nested_contexts(): @pytest.mark.asyncio -async def test_current_task_inside_task_context(sample_task, mock_worker_proxy_cache): +async def test_current_task_should_return_current_task_when_inside_task_context( + sample_task, mock_worker_proxy_cache +): """Test current_task returns the active Task during dispatch. Given: @@ -230,7 +232,7 @@ async def test_callable(): assert result == task -def test_current_task_outside_task_context(): +def test_current_task_should_return_none_when_outside_task_context(): """Test current_task returns None outside any task context. Given: @@ -248,7 +250,7 @@ def test_current_task_outside_task_context(): @pytest.mark.asyncio -async def test_current_task_with_nested_task_contexts(sample_task): +async def test_current_task_should_set_caller_to_outer_task_when_nested(sample_task): """Test nested task contexts set caller to the outer task. Given: @@ -277,7 +279,9 @@ async def test_current_task_with_nested_task_contexts(sample_task): ) @given(depth=st.integers(min_value=2, max_value=5)) @pytest.mark.asyncio -async def test_current_task_with_variable_nesting_depth(depth, sample_task): +async def test_current_task_should_track_caller_when_variable_nesting_depth( + depth, sample_task +): """Test nested task context tracking at variable depth. 
Given: @@ -318,7 +322,9 @@ class TestTask: """Tests for :py:class:`Task`.""" @pytest.mark.asyncio - async def test___post_init___inside_task_context(self, sample_task): + async def test___post_init___should_set_caller_when_inside_task_context( + self, sample_task + ): """Test post-init sets caller inside a task context. Given: @@ -339,7 +345,9 @@ async def test___post_init___inside_task_context(self, sample_task): # Assert assert inner_task.caller == outer_task.id - def test___post_init___outside_task_context(self, sample_task): + def test___post_init___should_leave_caller_none_when_outside_task_context( + self, sample_task + ): """Test post-init leaves caller as None without context. Given: @@ -355,7 +363,9 @@ def test___post_init___outside_task_context(self, sample_task): # Assert assert task.caller is None - def test___post_init___without_explicit_context(self, sample_async_callable): + def test___post_init___should_capture_caller_runtime_context( + self, sample_async_callable + ): """Test post-init auto-captures the caller's RuntimeContext. Given: @@ -388,7 +398,9 @@ def test___post_init___without_explicit_context(self, sample_async_callable): assert dispatch_timeout.get() == 1.25 @pytest.mark.asyncio - async def test___enter___with_coroutine_callable(self, sample_task): + async def test___enter___should_bind_current_task_when_coroutine_callable( + self, sample_task + ): """Test __enter__ returns a callable for coroutine tasks. Given: @@ -408,7 +420,9 @@ async def test___enter___with_coroutine_callable(self, sample_task): assert current_task() is task @pytest.mark.asyncio - async def test___enter___with_async_generator(self, sample_task): + async def test___enter___should_bind_current_task_when_async_generator( + self, sample_task + ): """Test __enter__ binds ``_current_task`` for an async-gen task. 
Given: @@ -432,7 +446,7 @@ async def test_generator(): assert current_task() is task @pytest.mark.asyncio - async def test___exit___without_exception(self, sample_task): + async def test___exit___should_exit_cleanly_when_no_exception(self, sample_task): """Test __exit__ completes cleanly without exceptions. Given: @@ -449,8 +463,55 @@ async def test___exit___without_exception(self, sample_task): with task: pass + def test___enter___should_raise_when_already_active(self, sample_task): + """Test re-entering an already-active Task raises. + + Given: + A :py:class:`Task` currently inside an active ``with`` + block. + When: + The same instance is entered as a context manager a + second time. + Then: + It should raise RuntimeError — Task instances are + block-scoped and single-use as context managers. + """ + # Arrange + task = sample_task() + + # Act & assert + with task: + with pytest.raises(RuntimeError, match="already active"): + with task: + pass + + def test___exit___should_return_false_when_never_entered(self, sample_task): + """Test __exit__ on a never-entered Task is a no-op. + + Given: + A :py:class:`Task` whose ``__enter__`` was never invoked. + When: + ``__exit__`` is invoked directly with no exception + propagating. + Then: + It should return False (exceptions propagate) and leave + state untouched — symmetric with ``__enter__``'s double- + entry guard so misuse is well-defined rather than + crashing on the underlying ``Token.reset(None)``. + """ + # Arrange + task = sample_task() + + # Act + # No ``with`` idiom can invoke ``__exit__`` without a matching + # ``__enter__``, so the misuse guard is exercised by a direct call. + result = task.__exit__(None, None, None) + + # Assert + assert result is False + @pytest.mark.asyncio - async def test___exit___with_value_error(self, sample_task): + async def test___exit___should_capture_exception_when_value_error(self, sample_task): """Test __exit__ captures ValueError as TaskException. 
Given: @@ -479,7 +540,9 @@ async def run_with_exception(): assert any("test error" in line for line in task.exception.traceback) @pytest.mark.asyncio - async def test___exit___with_runtime_error(self, sample_task): + async def test___exit___should_capture_exception_when_runtime_error( + self, sample_task + ): """Test __exit__ captures RuntimeError as TaskException. Given: @@ -516,7 +579,7 @@ async def run_with_exception(): tag=st.one_of(st.none(), st.text(min_size=1, max_size=100)), ) @pytest.mark.asyncio - async def test_to_protobuf_with_picklable_proxy( + async def test_to_protobuf_should_preserve_all_attributes_when_round_tripped( self, task_id, timeout, @@ -568,7 +631,7 @@ async def test_callable(): assert deserialized_task.timeout == original_task.timeout assert deserialized_task.tag == original_task.tag - def test_to_protobuf_round_trip_produces_distinct_objects(self): + def test_to_protobuf_should_produce_distinct_objects_when_round_tripped(self): """Test the protobuf round-trip returns independent copies. Given: @@ -607,7 +670,7 @@ async def test_callable(): assert restored.kwargs is not original.kwargs assert restored.proxy is not original.proxy - def test_to_protobuf_round_trip_copies_mutable_args(self): + def test_to_protobuf_should_copy_mutable_args_when_round_tripped(self): """Test the protobuf round-trip yields independent mutable args. Given: @@ -641,7 +704,7 @@ async def test_callable(): assert restored.args[0] == [1, 2, 3, 4] assert original_list == [1, 2, 3] - def test_to_protobuf_with_uncloudpicklable_arg_fails(self, picklable_proxy): + def test_to_protobuf_should_raise_when_arg_uncloudpicklable(self, picklable_proxy): """Test to_protobuf raises for a non-cloudpicklable positional arg. 
Given: @@ -669,7 +732,7 @@ async def test_callable(): with pytest.raises((TypeError, pickle.PicklingError)): task.to_protobuf() - def test_to_protobuf_with_uncloudpicklable_kwarg_fails(self, picklable_proxy): + def test_to_protobuf_should_raise_when_kwarg_uncloudpicklable(self, picklable_proxy): """Test to_protobuf raises for a non-cloudpicklable keyword arg. Given: @@ -697,7 +760,7 @@ async def test_callable(): with pytest.raises((TypeError, pickle.PicklingError)): task.to_protobuf() - def test_to_protobuf_omits_serializer_field( + def test_to_protobuf_should_omit_serializer_field( self, sample_async_callable, picklable_proxy ): """Test to_protobuf produces a message without a serializer field. @@ -732,7 +795,7 @@ def test_to_protobuf_omits_serializer_field( @settings(max_examples=50, deadline=None) @given(payload=_arbitrary_payloads()) - def test_to_protobuf_round_trip_copies_arbitrary_payloads(self, payload): + def test_to_protobuf_should_copy_arbitrary_payloads(self, payload): """Test the protobuf round-trip copies arbitrary nested payloads. Given: @@ -777,7 +840,7 @@ async def test_callable(): suppress_health_check=[HealthCheck.function_scoped_fixture], ) @given(value_count=st.integers(min_value=0, max_value=10)) - async def test_dispatch_with_async_generator( + async def test_dispatch_should_yield_all_values_when_async_generator( self, value_count, mock_worker_proxy_cache, @@ -818,7 +881,7 @@ async def test_generator(): assert results == list(range(value_count)) @pytest.mark.asyncio - async def test_from_protobuf_all_fields( + async def test_from_protobuf_should_deserialize_all_fields( self, sample_async_callable, picklable_proxy ): """Test from_protobuf deserializes all fields correctly. 
@@ -865,7 +928,7 @@ async def test_from_protobuf_all_fields( assert task.tag == "test_tag" @pytest.mark.asyncio - async def test_from_protobuf_empty_optionals( + async def test_from_protobuf_should_default_optionals_when_empty( self, sample_async_callable, picklable_proxy ): """Test from_protobuf handles empty optional fields. @@ -905,7 +968,7 @@ async def test_from_protobuf_empty_optionals( assert task.tag is None @pytest.mark.asyncio - async def test_from_protobuf_with_runtime_context( + async def test_from_protobuf_should_read_runtime_context( self, sample_async_callable, picklable_proxy ): """Test from_protobuf reads the RuntimeContext submessage. @@ -944,7 +1007,7 @@ async def test_from_protobuf_with_runtime_context( assert dispatch_timeout.get() == 12.5 @pytest.mark.asyncio - async def test_dispatch_with_dispatch_timeout_on_coroutine( + async def test_dispatch_should_apply_dispatch_timeout_when_coroutine( self, mock_worker_proxy_cache ): """Test coroutine dispatch applies context dispatch_timeout. @@ -982,7 +1045,7 @@ async def capture_timeout(): assert captured == [7.5] @pytest.mark.asyncio - async def test_dispatch_with_dispatch_timeout_on_async_generator( + async def test_dispatch_should_apply_dispatch_timeout_each_iteration( self, mock_worker_proxy_cache ): """Test async-gen dispatch applies context dispatch_timeout each iteration. @@ -1023,7 +1086,9 @@ async def capture_timeout_stream(): # Assert assert captured == [3.0, 3.0] - def test_to_protobuf_all_fields(self, sample_async_callable, picklable_proxy): + def test_to_protobuf_should_serialize_all_fields( + self, sample_async_callable, picklable_proxy + ): """Test to_protobuf serializes all fields correctly. 
Given: @@ -1066,7 +1131,9 @@ def test_to_protobuf_all_fields(self, sample_async_callable, picklable_proxy): assert task_msg.timeout == 30 assert task_msg.tag == "test_tag" - def test_to_protobuf_none_optionals(self, sample_async_callable, picklable_proxy): + def test_to_protobuf_should_serialize_defaults_when_optionals_none( + self, sample_async_callable, picklable_proxy + ): """Test to_protobuf serializes None optionals as defaults. Given: @@ -1098,7 +1165,7 @@ def test_to_protobuf_none_optionals(self, sample_async_callable, picklable_proxy assert task_msg.timeout == 0 assert task_msg.tag == "" - def test_to_protobuf_with_version_field( + def test_to_protobuf_should_include_version( self, sample_async_callable, picklable_proxy ): """Test to_protobuf includes the protocol version. @@ -1133,7 +1200,7 @@ def test_to_protobuf_with_version_field( @given( version=st.from_regex(r"\d{1,3}\.\d{1,3}(rc\d{1,3}|\.\d{1,3})", fullmatch=True), ) - def test_from_protobuf_with_version_roundtrip(self, version): + def test_from_protobuf_should_preserve_version_when_round_tripped(self, version): """Test protobuf round-trip preserves the version field. Given: @@ -1172,7 +1239,7 @@ async def test_callable(): assert parsed.version == version @pytest.mark.asyncio - async def test_dispatch_successful_execution( + async def test_dispatch_should_return_result( self, sample_task, mock_worker_proxy_cache, @@ -1205,7 +1272,7 @@ async def test_callable(x, y=0): assert result == 8 @pytest.mark.asyncio - async def test_dispatch_without_proxy_pool_raises_error(self, sample_task): + async def test_dispatch_should_raise_when_no_proxy_pool(self, sample_task): """Test dispatch raises RuntimeError without a proxy pool. 
Given: @@ -1228,7 +1295,7 @@ async def test_dispatch_without_proxy_pool_raises_error(self, sample_task): async with routine_scope(task) as routine: await cast(Coroutine, routine) - def test_to_protobuf_with_unpicklable_callable_fails(self, picklable_proxy): + def test_to_protobuf_should_raise_when_callable_unpicklable(self, picklable_proxy): """Test to_protobuf fails with an unpicklable callable. Given: @@ -1260,7 +1327,7 @@ async def unpicklable_callable(): task.to_protobuf() @pytest.mark.asyncio - async def test_dispatch_with_async_generator_callable( + async def test_dispatch_should_yield_all_values_when_async_generator_callable( self, sample_task, mock_worker_proxy_cache, @@ -1294,7 +1361,7 @@ async def test_generator(): assert results == ["value_0", "value_1", "value_2"] @pytest.mark.asyncio - async def test_dispatch_with_coroutine_callable( + async def test_dispatch_should_return_result_when_coroutine_callable( self, sample_task: Callable[..., Task], mock_worker_proxy_cache, @@ -1324,7 +1391,7 @@ async def test_coroutine(): assert result == "coroutine_result" @pytest.mark.asyncio - async def test_routine_scope_with_invalid_callable( + async def test_routine_scope_should_raise_when_invalid_callable( self, sample_task, mock_worker_proxy_cache, @@ -1355,7 +1422,7 @@ def not_async(): pass @pytest.mark.asyncio - async def test_dispatch_async_generator_without_proxy_pool_raises_error( + async def test_dispatch_should_raise_when_async_generator_no_proxy_pool( self, sample_task, ): @@ -1387,7 +1454,7 @@ async def test_generator(): pass @pytest.mark.asyncio - async def test_dispatch_async_generator_raises_during_iteration( + async def test_dispatch_should_propagate_exception_when_generator_raises( self, sample_task, mock_worker_proxy_cache, @@ -1421,7 +1488,7 @@ async def failing_generator(): assert results == ["first"] @pytest.mark.asyncio - async def test_dispatch_async_generator_early_termination( + async def test_dispatch_should_stop_when_early_break( self, 
sample_task, mock_worker_proxy_cache, @@ -1455,7 +1522,7 @@ async def test_generator(): assert results == ["value_0", "value_1"] @pytest.mark.asyncio - async def test_dispatch_async_generator_multiple_values( + async def test_dispatch_should_yield_multiple_values_in_order( self, sample_task, mock_worker_proxy_cache, @@ -1489,7 +1556,7 @@ async def multi_value_generator(): assert results == [0, 10, 20, 30, 40] @pytest.mark.asyncio - async def test_dispatch_async_generator_empty( + async def test_dispatch_should_yield_nothing_when_generator_empty( self, sample_task, mock_worker_proxy_cache, @@ -1522,7 +1589,7 @@ async def empty_generator(): # Assert assert results == [] - def test_to_protobuf_with_guarded_proxy(self, sample_async_callable): + def test_to_protobuf_should_serialize_guarded_proxy(self, sample_async_callable): """Test to_protobuf serializes a guarded proxy via wool.__serializer__. Given: @@ -1551,7 +1618,7 @@ def test_to_protobuf_with_guarded_proxy(self, sample_async_callable): assert task_msg.proxy assert task_msg.proxy_id == str(proxy.id) - def test_from_protobuf_with_guarded_proxy(self, sample_async_callable): + def test_from_protobuf_should_restore_guarded_proxy(self, sample_async_callable): """Test the to_protobuf / from_protobuf round-trip with a guarded proxy. Given: @@ -1581,7 +1648,7 @@ def test_from_protobuf_with_guarded_proxy(self, sample_async_callable): assert isinstance(restored.proxy, _PicklableProxy) assert restored.proxy.id == proxy.id - def test_from_protobuf_with_cloudpickle_fields( + def test_from_protobuf_should_deserialize_cloudpickle_fields( self, sample_async_callable, picklable_proxy ): """Test from_protobuf deserializes cloudpickle-encoded payload fields. 
@@ -1619,7 +1686,7 @@ def test_from_protobuf_with_cloudpickle_fields( assert task.args == args assert task.kwargs == kwargs - def test_from_protobuf_with_invalid_payload_fails(self, picklable_proxy): + def test_from_protobuf_should_raise_when_invalid_payload(self, picklable_proxy): """Test from_protobuf raises for a non-cloudpickle payload. Given: @@ -1654,7 +1721,7 @@ class TestRoutineScope: """Tests for :func:`wool.runtime.routine.task.routine_scope`.""" @pytest.mark.asyncio - async def test_routine_scope_with_null_runtime_context_asserts( + async def test_routine_scope_should_raise_when_null_runtime_context( self, sample_task, mock_worker_proxy_cache ): """Test :func:`routine_scope` asserts a non-None runtime context. @@ -1688,7 +1755,7 @@ async def trivial_routine(): pass @pytest.mark.asyncio - async def test_routine_scope_without_proxy_pool(self, sample_task): + async def test_routine_scope_should_raise_when_no_proxy_pool(self, sample_task): """Test routine_scope raises when wool.__proxy_pool__ is unset. Given: @@ -1711,7 +1778,7 @@ async def test_routine_scope_without_proxy_pool(self, sample_task): pass @pytest.mark.asyncio - async def test_routine_scope_with_coroutine_callable( + async def test_routine_scope_should_yield_coroutine_when_coroutine_callable( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope yields an awaitable coroutine for coroutine tasks. @@ -1742,7 +1809,7 @@ async def coro_callable(): assert result == "coro_result" @pytest.mark.asyncio - async def test_routine_scope_with_async_generator_callable( + async def test_routine_scope_should_yield_async_generator_when_generator_callable( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope yields an iterable async generator for stream tasks. 
@@ -1776,7 +1843,7 @@ async def gen_callable(): assert results == [1, 2, 3] @pytest.mark.asyncio - async def test_routine_scope_establishes_task_scope( + async def test_routine_scope_should_bind_task_scope( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope binds current_task and disables dispatch routing. @@ -1813,7 +1880,7 @@ async def record_scope(): assert do_dispatch() is outer_dispatch_before @pytest.mark.asyncio - async def test_routine_scope_resets_proxy_token_on_exit( + async def test_routine_scope_should_restore_proxy_on_exit( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope restores wool.__proxy__ on exit. @@ -1848,7 +1915,7 @@ async def record_proxy(): assert wool.__proxy__.get() is None @pytest.mark.asyncio - async def test_routine_scope_applies_runtime_context( + async def test_routine_scope_should_apply_runtime_context( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope applies the Task's RuntimeContext. @@ -1882,7 +1949,7 @@ async def record_timeout(): assert observed["dispatch_timeout"] == 4.5 @pytest.mark.asyncio - async def test_routine_scope_aclose_unconsumed_async_gen( + async def test_routine_scope_should_aclose_when_async_gen_unconsumed( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope acloses an async generator never iterated. @@ -1913,7 +1980,7 @@ async def gen_callable(): assert captured_routine.ag_frame is None @pytest.mark.asyncio - async def test_routine_scope_swallows_generator_exit_during_aclose( + async def test_routine_scope_should_exit_cleanly_when_routine_reraises_ge( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope exits cleanly when the routine reacts to GeneratorExit. 
@@ -1948,7 +2015,7 @@ async def naughty_gen(): assert routine.ag_frame is None @pytest.mark.asyncio - async def test_routine_scope_propagates_internal_cancelled_during_aclose( + async def test_routine_scope_should_propagate_cancellation_when_routine_internal( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope propagates routine-internal CancelledError on aclose. @@ -1958,7 +2025,7 @@ async def test_routine_scope_propagates_internal_cancelled_during_aclose( :class:`asyncio.CancelledError` during its cleanup, the exception propagates from :func:`routine_scope`'s exit handler unchanged. Paired stdlib parity test - :meth:`test_stdlib_aclose_propagates_internal_cancelled` + :meth:`test_stdlib_aclose_should_propagate_internal_cancelled` pins the stdlib behavior so a future stdlib change signals that wool's parity needs revisiting. @@ -1995,7 +2062,7 @@ async def naughty_gen(): await it.__anext__() @pytest.mark.asyncio - async def test_routine_scope_propagates_external_cancellation_during_aclose( + async def test_routine_scope_should_propagate_cancellation_when_externally_cancelled( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope re-raises CancelledError when externally cancelled. @@ -2045,7 +2112,7 @@ async def body(): await wrapped @pytest.mark.asyncio - async def test_routine_scope_propagates_runtime_error_when_routine_yields_during_ge( + async def test_routine_scope_should_propagate_runtime_error_when_yields_during_ge( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope propagates the synthesized RuntimeError when @@ -2057,7 +2124,7 @@ async def test_routine_scope_propagates_runtime_error_when_routine_yields_during ``RuntimeError("async generator ignored GeneratorExit")`` from ``aclose``. Wool propagates this unchanged. 
Paired stdlib parity test - :meth:`test_stdlib_aclose_raises_runtime_error_when_yielding_during_ge` + :meth:`test_stdlib_aclose_should_raise_runtime_error_when_yields_during_ge` pins the stdlib behavior. Given: @@ -2088,7 +2155,7 @@ async def yielding_gen(): await it.__anext__() @pytest.mark.asyncio - async def test_routine_scope_with_coroutine_does_not_aclose( + async def test_routine_scope_should_not_aclose_when_coroutine( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope does not invoke aclose teardown for coroutines. @@ -2128,7 +2195,7 @@ async def coro_callable(): assert events == ["enter", "finally"] @pytest.mark.asyncio - async def test_routine_scope_propagates_caller_body_exception( + async def test_routine_scope_should_propagate_exception_when_caller_body_raises( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope propagates exceptions raised in the caller body. @@ -2167,7 +2234,7 @@ async def coro_callable(): assert aexit_mock.await_count >= 1 @pytest.mark.asyncio - async def test_routine_scope_propagates_routine_exception_transparently( + async def test_routine_scope_should_propagate_exception_when_routine_raises( self, sample_task, mock_worker_proxy_cache ): """Test routine_scope propagates routine-raised exceptions unchanged. @@ -2199,10 +2266,83 @@ async def gen_callable(): assert results == [1] +class TestAsyncGenAcloseParity: + """Stdlib parity pins for ``async-generator.aclose`` semantics. + + These tests assert observations about CPython's own + ``await agen.aclose()`` behavior. They are intentionally NOT tests + of any wool code — they pin stdlib semantics so that a future + change in CPython's async-generator close protocol fails here + first, signaling that the paired :class:`TestRoutineScope` + regression tests (and :func:`routine_scope`'s contract) may need + to be revisited. 
+ """ + + @pytest.mark.asyncio + async def test_stdlib_aclose_should_propagate_internal_cancelled(self): + """Test ``aclose`` propagates internal CancelledError. + + Given: + A direct ``asyncio`` async generator that raises + :class:`asyncio.CancelledError` during aclose unwind + while the awaiting task's ``cancelling()`` count is 0. + When: + ``await agen.aclose()`` is invoked after one iteration. + Then: + It should raise :class:`asyncio.CancelledError`. + """ + + # Arrange + async def naughty_gen(): + try: + yield 1 + yield 2 + except GeneratorExit: + raise asyncio.CancelledError() + + agen = naughty_gen() + await agen.__anext__() + + # Act & assert + with pytest.raises(asyncio.CancelledError): + await agen.aclose() + + @pytest.mark.asyncio + async def test_stdlib_aclose_should_raise_runtime_error_when_yields_during_ge(self): + """Test ``aclose`` raises RuntimeError when the routine yields + during ``GeneratorExit``. + + Given: + A direct ``asyncio`` async generator that catches + :class:`GeneratorExit` and yields a value (a PEP 525 + protocol violation). + When: + ``await agen.aclose()`` is invoked after one iteration. + Then: + It should raise + ``RuntimeError("async generator ignored GeneratorExit")``. + """ + + # Arrange + async def yielding_gen(): + try: + yield 1 + yield 2 + except GeneratorExit: + yield "rude" + + agen = yielding_gen() + await agen.__anext__() + + # Act & assert + with pytest.raises(RuntimeError, match="ignored GeneratorExit"): + await agen.aclose() + + class TestRuntimeContext: """Tests for :py:class:`RuntimeContext`.""" - def test___enter___and___exit___apply_inner_and_restore_outer_dispatch_timeout(self): + def test___enter___should_apply_inner_and_restore_outer_dispatch_timeout(self): """Test RuntimeContext sets and restores dispatch_timeout. 
Given: @@ -2224,7 +2364,7 @@ def test___enter___and___exit___apply_inner_and_restore_outer_dispatch_timeout(s finally: dispatch_timeout.reset(outer_token) - def test___enter___when_dispatch_timeout_unset(self): + def test___enter___should_not_touch_dispatch_timeout_when_unset(self): """Test an empty RuntimeContext does not touch dispatch_timeout. Given: @@ -2245,7 +2385,7 @@ def test___enter___when_dispatch_timeout_unset(self): finally: dispatch_timeout.reset(outer_token) - def test_get_current_with_dispatch_timeout_set(self): + def test_get_current_should_snapshot_dispatch_timeout(self): """Test get_current snapshots the current dispatch_timeout. Given: @@ -2268,7 +2408,7 @@ def test_get_current_with_dispatch_timeout_set(self): with captured: assert dispatch_timeout.get() == 4.0 - def test_to_protobuf_with_dispatch_timeout_set(self): + def test_to_protobuf_should_serialize_dispatch_timeout(self): """Test to_protobuf serializes dispatch_timeout on the wire. Given: @@ -2286,7 +2426,7 @@ def test_to_protobuf_with_dispatch_timeout_set(self): assert pb.HasField("dispatch_timeout") assert pb.dispatch_timeout == 6.0 - def test_to_protobuf_when_dispatch_timeout_unset(self): + def test_to_protobuf_should_fall_back_to_current_var_when_unset(self): """Test to_protobuf falls back to the current var when unset. Given: @@ -2310,7 +2450,7 @@ def test_to_protobuf_when_dispatch_timeout_unset(self): assert pb.HasField("dispatch_timeout") assert pb.dispatch_timeout == 9.25 - def test_to_protobuf_without_value(self): + def test_to_protobuf_should_omit_dispatch_timeout_when_none(self): """Test to_protobuf omits dispatch_timeout when it is None. Given: @@ -2326,7 +2466,7 @@ def test_to_protobuf_without_value(self): # Assert assert not pb.HasField("dispatch_timeout") - def test_from_protobuf_roundtrip(self): + def test_from_protobuf_should_reconstruct_runtime_context(self): """Test from_protobuf reconstructs a usable RuntimeContext. 
Given: @@ -2351,7 +2491,7 @@ def test_from_protobuf_roundtrip(self): class TestTaskException: """Tests for :py:class:`TaskException`.""" - def test___init___with_type_and_traceback(self): + def test___init___should_store_type_and_traceback(self): """Test TaskException stores type and traceback correctly. Given: diff --git a/wool/tests/runtime/test_resourcepool.py b/wool/tests/runtime/test_resourcepool.py index 0b533480..5f18f509 100644 --- a/wool/tests/runtime/test_resourcepool.py +++ b/wool/tests/runtime/test_resourcepool.py @@ -263,7 +263,7 @@ async def setup(): @pytest.mark.asyncio @given(setup=setup()) - async def test_get_returns_resource_instance(self, setup): + async def test_get_should_return_resource_instance(self, setup): """Test that get returns a Resource instance. Given: @@ -283,8 +283,8 @@ async def test_get_returns_resource_instance(self, setup): assert isinstance(resource_acquisition, Resource) @pytest.mark.asyncio - @pytest.mark.dependency("TestResourcePool::test_get_returns_resource_instance") - async def test_release_decrements_reference_counts(self): + @pytest.mark.dependency("TestResourcePool::test_get_should_return_resource_instance") + async def test_release_should_decrement_reference_counts(self): """Test releasing resources decrements reference counts properly. Given: @@ -323,8 +323,12 @@ async def test_release_decrements_reference_counts(self): assert pool.stats.referenced_entries == 0 @pytest.mark.asyncio - @pytest.mark.dependency("TestResourcePool::test_release_decrements_reference_counts") - async def test_release_nonexistent_key_raises_error(self, counting_factory): + @pytest.mark.dependency( + "TestResourcePool::test_release_should_decrement_reference_counts" + ) + async def test_release_should_not_affect_existing_resources_when_key_nonexistent( + self, counting_factory + ): """Test releasing nonexistent key raises KeyError. 
Given: @@ -355,8 +359,10 @@ async def test_release_nonexistent_key_raises_error(self, counting_factory): assert pool.stats.referenced_entries == 0 @pytest.mark.asyncio - @pytest.mark.dependency("TestResourcePool::test_release_decrements_reference_counts") - async def test_release_zero_reference_count_raises_error(self): + @pytest.mark.dependency( + "TestResourcePool::test_release_should_decrement_reference_counts" + ) + async def test_release_should_raise_value_error_when_zero_reference_count(self): """Test releasing key with zero ref count raises ValueError. Given: @@ -392,8 +398,148 @@ async def test_release_zero_reference_count_raises_error(self): await ttl_pool.release(unique_key) @pytest.mark.asyncio - @pytest.mark.dependency("TestResourcePool::test_get_returns_resource_instance") - async def test_clear_finalizes_all_resources(self): + async def test_finalizer_should_still_evict_entry_when_raising_base_exception(self): + """Test a cancelled finalizer still evicts the cache entry. + + Given: + A ``ttl=0`` pool whose finalizer raises + ``CancelledError`` — a ``BaseException``, not an + ``Exception`` — on its first call, modelling cleanup that + runs under a cancelled teardown + When: + A resource is acquired and released, driving immediate + cleanup whose finalizer raises + Then: + The ``CancelledError`` propagates, but the torn-down entry + is still evicted, so the next acquire is a cache miss that + builds a fresh resource via the factory rather than handing + back the finalized one + """ + + # Arrange + finalizer_calls = {"count": 0} + + async def finalizer(obj): + finalizer_calls["count"] += 1 + if finalizer_calls["count"] == 1: + # First cleanup runs under cancellation. 
+ raise asyncio.CancelledError() + + factory = Mock( + side_effect=[ + SimpleNamespace(name="first"), + SimpleNamespace(name="second"), + ] + ) + pool = ResourcePool(factory=factory, finalizer=finalizer, ttl=0) + + # Act + # Acquire then release: rc -> 0 drives immediate cleanup, whose + # finalizer raises CancelledError out of the release. + with pytest.raises(asyncio.CancelledError): + async with pool.get("key"): + pass + + # Assert + # The finalized resource must not survive in the cache. + assert pool.stats.total_entries == 0 + # The next acquire is therefore a miss that builds a fresh + # resource, never the torn-down one. + async with pool.get("key") as resource: + assert resource.name == "second" + assert factory.call_count == 2 + + def test_acquire_should_cancel_cross_loop_cleanup_task_threadsafe(self): + """Test acquire cancels a foreign-loop cleanup task cross-loop. + + Given: + A pool whose cached entry has a pending TTL cleanup task + scheduled on one event loop, which has since closed. + When: + The same key is re-acquired from a different event loop. + Then: + The cleanup should be cancelled on its own (now-closed) loop + via call_soon_threadsafe rather than awaited cross-loop, the + resulting RuntimeError should be swallowed, the cached object + should be returned, and the pending cleanup should be + cleared. + """ + # Arrange + pool = ResourcePool(factory=Mock(return_value="obj"), ttl=60) + + # Acquire and release on a dedicated loop so a TTL cleanup task is + # scheduled and bound to that loop, then close it. The pending task is + # held by reference so it survives until the cross-loop cancel runs. 
+ foreign_loop = asyncio.new_event_loop() + + async def schedule_cleanup_on_foreign_loop(): + async with pool.get("key"): + pass + + foreign_loop.run_until_complete(schedule_cleanup_on_foreign_loop()) + pending_cleanup = pool.pending_cleanup["key"] + foreign_loop.close() + + acquiring_loop = asyncio.new_event_loop() + try: + # Act + acquired = acquiring_loop.run_until_complete(pool.acquire("key")) + + # Assert + assert acquired == "obj" + assert pool.pending_cleanup == {} + finally: + acquiring_loop.close() + del pending_cleanup + + def test_clear_should_cancel_cross_loop_cleanup_task_threadsafe(self): + """Test clear cancels a foreign-loop cleanup task cross-loop. + + Given: + A pool whose cached entry has a pending TTL cleanup task + scheduled on one event loop, which has since closed. + When: + The key is cleared from a different event loop. + Then: + The cleanup should be cancelled on its own (now-closed) loop + via call_soon_threadsafe, the resulting RuntimeError should be + swallowed, the finalizer should still run, and the entry + should be evicted. + """ + # Arrange + finalizer = AsyncMock() + pool = ResourcePool( + factory=Mock(return_value="obj"), finalizer=finalizer, ttl=60 + ) + + # Acquire and release on a dedicated loop so a TTL cleanup task is + # scheduled and bound to that loop, then close it. The pending task is + # held by reference so it survives until the cross-loop cancel runs. 
+ foreign_loop = asyncio.new_event_loop() + + async def schedule_cleanup_on_foreign_loop(): + async with pool.get("key"): + pass + + foreign_loop.run_until_complete(schedule_cleanup_on_foreign_loop()) + pending_cleanup = pool.pending_cleanup["key"] + foreign_loop.close() + + clearing_loop = asyncio.new_event_loop() + try: + # Act + clearing_loop.run_until_complete(pool.clear("key")) + + # Assert + finalizer.assert_awaited_once_with("obj") + assert pool.stats.total_entries == 0 + finally: + clearing_loop.close() + del pending_cleanup + + @pytest.mark.asyncio + @pytest.mark.dependency("TestResourcePool::test_get_should_return_resource_instance") + async def test_clear_should_finalize_all_resources(self): """Test clearing the pool calls finalizer on all resources. Given: @@ -431,8 +577,10 @@ async def test_clear_finalizes_all_resources(self): assert mock_finalizer.call_count == 3 @pytest.mark.asyncio - @pytest.mark.dependency("TestResourcePool::test_get_returns_resource_instance") - async def test_clear_key_removes_specific_resource(self, mock_finalizer): + @pytest.mark.dependency("TestResourcePool::test_get_should_return_resource_instance") + async def test_clear_should_remove_specific_resource_when_key_given( + self, mock_finalizer + ): """Test clearing a specific key from the pool. Given: @@ -477,7 +625,7 @@ async def test_clear_key_removes_specific_resource(self, mock_finalizer): mock_finalizer.assert_called_once_with(mock_resource1) @pytest.mark.asyncio - async def test_clear_nonexistent_key_raises_error(self): + async def test_clear_should_raise_key_error_when_key_nonexistent(self): """Test clearing a non-existent key raises KeyError. 
Given: @@ -513,8 +661,8 @@ async def test_clear_nonexistent_key_raises_error(self): mock_finalizer.assert_not_called() @pytest.mark.asyncio - @pytest.mark.dependency("TestResourcePool::test_get_returns_resource_instance") - async def test_ttl_cleanup_schedules_resource_removal(self): + @pytest.mark.dependency("TestResourcePool::test_get_should_return_resource_instance") + async def test_ttl_cleanup_should_schedule_resource_removal(self): """Test TTL-based cleanup schedules and executes properly. Given: @@ -573,8 +721,8 @@ async def mock_sleep(_delay): mock_finalizer.assert_called_once_with(mock_resource) @pytest.mark.asyncio - @pytest.mark.dependency("TestResourcePool::test_get_returns_resource_instance") - async def test_ttl_cleanup_cancelled_on_reacquire(self): + @pytest.mark.dependency("TestResourcePool::test_get_should_return_resource_instance") + async def test_ttl_cleanup_should_be_cancelled_when_reacquired(self): """Test TTL cleanup is cancelled when resource is reacquired. Given: @@ -619,8 +767,8 @@ async def test_ttl_cleanup_cancelled_on_reacquire(self): assert pool.stats.referenced_entries == 0 @pytest.mark.asyncio - @pytest.mark.dependency("TestResourcePool::test_get_returns_resource_instance") - async def test_stats_returns_accurate_counts(self): + @pytest.mark.dependency("TestResourcePool::test_get_should_return_resource_instance") + async def test_stats_should_return_accurate_counts(self): """Test stats method returns accurate cache statistics. 
Given: @@ -656,8 +804,8 @@ async def test_stats_returns_accurate_counts(self): assert stats.pending_cleanup == 0 # None scheduled yet @pytest.mark.asyncio - @pytest.mark.dependency("TestResourcePool::test_get_returns_resource_instance") - async def test_async_context_manager_clears_resources(self): + @pytest.mark.dependency("TestResourcePool::test_get_should_return_resource_instance") + async def test_async_context_manager_should_clear_resources(self): """Test ResourcePool as async context manager clears all on exit. Given: @@ -685,7 +833,7 @@ async def test_async_context_manager_clears_resources(self): @pytest.mark.asyncio @pytest.mark.parametrize("ttl", [0, 0.1, 1, 1.1, 10, 10.1]) - async def test_ttl_specific_behavior_with_mocked_sleep(self, ttl): + async def test_ttl_should_schedule_cleanup_based_on_value(self, ttl): """Test specific TTL values with controlled sleep mocking. Given: @@ -754,7 +902,7 @@ def mock_finalizer_func(obj): ) @pytest.mark.asyncio - async def test_finalizer_exception_handling_catches_errors(self): + async def test_finalizer_should_catch_exception_and_remove_resource(self): """Test finalizer exceptions are caught and logged. Given: @@ -786,7 +934,7 @@ async def failing_finalizer(_): assert pool.stats.total_entries == 0 @pytest.mark.asyncio - async def test_clear_with_nonexistent_key_raises_error(self): + async def test_clear_should_raise_key_error_when_key_absent(self): """Test clear with non-existent key raises appropriate error. Given: @@ -810,7 +958,9 @@ async def test_clear_with_nonexistent_key_raises_error(self): await pool.clear("nonexistent-key") @pytest.mark.asyncio - async def test_concurrent_acquire_release_same_key(self, counting_factory): + async def test_concurrent_should_maintain_consistency_when_acquire_release_same_key( + self, counting_factory + ): """Test concurrent operations on same key maintain consistency. 
Given: @@ -842,7 +992,7 @@ async def acquire_release_worker(): assert pool.stats.total_entries <= 1 # 0 or 1 depending on TTL timing @pytest.mark.asyncio - async def test_resource_pool_with_zero_ttl_immediate_cleanup( + async def test_resource_pool_should_cleanup_immediately_when_zero_ttl( self, resource_pool_immediate_cleanup ): """Test TTL=0 performs immediate cleanup as expected. @@ -874,7 +1024,7 @@ async def test_resource_pool_with_zero_ttl_immediate_cleanup( pool._finalizer.assert_called_once_with(mock_resource) @pytest.mark.asyncio - async def test_get_with_none_key_handles_gracefully(self): + async def test_get_should_handle_none_key(self): """Test resource pool handles None key appropriately. Given: @@ -903,7 +1053,7 @@ class TestResource: """Test suite for the Resource class.""" @pytest.mark.asyncio - async def test_context_manager_auto_releases(self): + async def test_context_manager_should_auto_release(self): """Test Resource as async context manager. Given: @@ -932,7 +1082,7 @@ async def test_context_manager_auto_releases(self): assert pool.stats.total_entries == 0 @pytest.mark.asyncio - async def test_resource_has_no_manual_release_method(self): + async def test_resource_should_have_no_manual_release_method(self): """Test Resource has no manual release method. Given: @@ -956,7 +1106,7 @@ async def test_resource_has_no_manual_release_method(self): assert not hasattr(resource_acquisition, "release") @pytest.mark.asyncio - async def test_resource_lifecycle_with_ttl(self): + async def test_resource_should_stay_cached_when_ttl_set(self): """Test Resource lifecycle with TTL keeps resource in cache. Given: @@ -985,7 +1135,7 @@ async def test_resource_lifecycle_with_ttl(self): assert pool.stats.referenced_entries == 0 @pytest.mark.asyncio - async def test_context_manager_only_usage_handles_lifecycle(self): + async def test_context_manager_should_handle_lifecycle(self): """Test using Resource only as context manager. 
Given: @@ -1012,7 +1162,7 @@ async def test_context_manager_only_usage_handles_lifecycle(self): assert pool.stats.total_entries == 0 @pytest.mark.asyncio - async def test_acquire_twice(self): + async def test_acquire_should_raise_runtime_error_when_acquired_twice(self): """Test that re-acquiring the same Resource instance raises error. Given: @@ -1039,7 +1189,7 @@ async def test_acquire_twice(self): pass @pytest.mark.asyncio - async def test_resource_context_acquire_exception(self): + async def test_resource_context_should_propagate_acquire_exception(self): """Test Resource context manager handles acquire exceptions properly. Given: @@ -1066,7 +1216,7 @@ async def test_resource_context_acquire_exception(self): assert resource._acquired is False @pytest.mark.asyncio - async def test_resource_context_release_not_acquired(self): + async def test_resource_context_should_raise_runtime_error_when_not_acquired(self): """Test Resource release when not acquired raises RuntimeError. Given: @@ -1089,7 +1239,9 @@ async def test_resource_context_release_not_acquired(self): await resource.__aexit__(None, None, None) @pytest.mark.asyncio - async def test_resource_context_release_already_released(self): + async def test_resource_context_should_raise_runtime_error_when_already_released( + self, + ): """Test Resource release when already released raises RuntimeError. 
Given: diff --git a/wool/tests/runtime/worker/conftest.py b/wool/tests/runtime/worker/conftest.py index 8d7318ec..0820722e 100644 --- a/wool/tests/runtime/worker/conftest.py +++ b/wool/tests/runtime/worker/conftest.py @@ -1,5 +1,6 @@ import asyncio import datetime +import multiprocessing.shared_memory import threading import uuid from types import MappingProxyType @@ -18,7 +19,7 @@ import wool.runtime.worker.pool as wp from tests.helpers import scoped_context -from wool.runtime.context import install_task_factory +from wool.runtime.context.factory import install_task_factory from wool.runtime.discovery.base import DiscoveryEvent from wool.runtime.worker.auth import WorkerCredentials from wool.runtime.worker.metadata import WorkerMetadata @@ -37,12 +38,12 @@ def __reduce__(self): @pytest.fixture(autouse=True) def _isolate_wool_context(): - """Install a fresh wool.Context for the duration of the test. + """Install a fresh, unarmed Wool context for the duration of the test. - Each test runs under its own scoped Context so var values set in - one test do not leak into subsequent tests via the per-task data - map. The process-wide var_registry is not reset; tests SHOULD - use unique key namespaces (e.g. via uuid suffix) to avoid + Each test runs under its own unarmed context so var values set + in one test do not leak into subsequent tests via the chain + context. The process-wide var_registry is not reset; tests + SHOULD use unique key namespaces (e.g. via uuid suffix) to avoid cross-test collisions on shared keys. 
""" with scoped_context(): @@ -358,7 +359,9 @@ def mock_shared_memory(mocker: MockerFixture): mock_memory.buf = bytearray(1024) mock_memory.close = mocker.MagicMock() mock_memory.unlink = mocker.MagicMock() - mocker.patch("multiprocessing.shared_memory.SharedMemory", return_value=mock_memory) + mocker.patch.object( + multiprocessing.shared_memory, "SharedMemory", return_value=mock_memory + ) return mock_memory diff --git a/wool/tests/runtime/worker/test_auth.py b/wool/tests/runtime/worker/test_auth.py index 3c208dfe..88bcdcc0 100644 --- a/wool/tests/runtime/worker/test_auth.py +++ b/wool/tests/runtime/worker/test_auth.py @@ -15,6 +15,7 @@ from hypothesis import settings from hypothesis import strategies as st +from wool.runtime.worker.auth import CredentialContext from wool.runtime.worker.auth import WorkerCredentials @@ -118,7 +119,7 @@ def temp_cert_files(test_certificates, tmp_path): class TestWorkerCredentials: """Test suite for WorkerCredentials credential management.""" - def test___init___with_mtls(self, test_certificates): + def test___init___should_set_all_fields_when_mtls(self, test_certificates): """Test basic instantiation with mTLS. Given: @@ -142,7 +143,7 @@ def test___init___with_mtls(self, test_certificates): assert creds.worker_cert == cert_pem assert creds.mutual is True - def test___init___with_one_way_tls(self, test_certificates): + def test___init___should_set_mutual_false_when_one_way_tls(self, test_certificates): """Test instantiation with one-way TLS. Given: @@ -163,7 +164,7 @@ def test___init___with_one_way_tls(self, test_certificates): # Assert assert creds.mutual is False - def test___init___frozen_dataclass_immutability(self, test_certificates): + def test___init___should_raise_when_field_mutated(self, test_certificates): """Test immutability via frozen dataclass. 
Given: @@ -184,7 +185,7 @@ def test___init___frozen_dataclass_immutability(self, test_certificates): with pytest.raises((FrozenInstanceError, AttributeError)): creds.mutual = False - def test_from_files_with_mtls(self, temp_cert_files): + def test_from_files_should_load_bytes_when_mtls(self, temp_cert_files): """Test from_files classmethod with mTLS. Given: @@ -209,7 +210,7 @@ def test_from_files_with_mtls(self, temp_cert_files): assert len(creds.worker_cert) > 0 assert creds.mutual is True - def test_from_files_with_one_way_tls(self, temp_cert_files): + def test_from_files_should_set_mutual_false_when_one_way_tls(self, temp_cert_files): """Test from_files classmethod with one-way TLS. Given: @@ -230,7 +231,7 @@ def test_from_files_with_one_way_tls(self, temp_cert_files): # Assert assert creds.mutual is False - def test_from_files_default_mutual_parameter(self, temp_cert_files): + def test_from_files_should_default_mutual_to_true(self, temp_cert_files): """Test default mutual=True parameter. Given: @@ -251,7 +252,7 @@ def test_from_files_default_mutual_parameter(self, temp_cert_files): # Assert assert creds.mutual is True - def test_from_files_missing_ca_cert(self, temp_cert_files): + def test_from_files_should_raise_when_ca_cert_missing(self, temp_cert_files): """Test missing CA file error handling. Given: @@ -270,7 +271,7 @@ def test_from_files_missing_ca_cert(self, temp_cert_files): ca_path="/nonexistent/ca.pem", key_path=key_path, cert_path=cert_path ) - def test_from_files_missing_key(self, temp_cert_files): + def test_from_files_should_raise_when_key_missing(self, temp_cert_files): """Test missing key file error handling. Given: @@ -289,7 +290,7 @@ def test_from_files_missing_key(self, temp_cert_files): ca_path=ca_path, key_path="/nonexistent/key.pem", cert_path=cert_path ) - def test_from_files_missing_cert(self, temp_cert_files): + def test_from_files_should_raise_when_cert_missing(self, temp_cert_files): """Test missing cert file error handling. 
Given: @@ -308,7 +309,9 @@ def test_from_files_missing_cert(self, temp_cert_files): ca_path=ca_path, key_path=key_path, cert_path="/nonexistent/cert.pem" ) - def test_from_files_permission_error(self, temp_cert_files, tmp_path): + def test_from_files_should_raise_when_permission_denied( + self, temp_cert_files, tmp_path + ): """Test permission error handling. Given: @@ -334,7 +337,9 @@ def test_from_files_permission_error(self, temp_cert_files, tmp_path): # Restore permissions for cleanup restricted_file.chmod(0o644) - def test_server_credentials_with_mtls(self, test_certificates): + def test_server_credentials_should_return_server_credentials_when_mtls( + self, test_certificates + ): """Test server credentials property for mTLS. Given: @@ -356,7 +361,9 @@ def test_server_credentials_with_mtls(self, test_certificates): # Assert assert isinstance(server_creds, grpc.ServerCredentials) - def test_server_credentials_with_one_way_tls(self, test_certificates): + def test_server_credentials_should_return_server_credentials_when_one_way_tls( + self, test_certificates + ): """Test server credentials property for one-way TLS. Given: @@ -378,7 +385,9 @@ def test_server_credentials_with_one_way_tls(self, test_certificates): # Assert assert isinstance(server_creds, grpc.ServerCredentials) - def test_client_credentials_with_mtls(self, test_certificates): + def test_client_credentials_should_return_channel_credentials_when_mtls( + self, test_certificates + ): """Test client credentials property for mTLS. Given: @@ -400,7 +409,9 @@ def test_client_credentials_with_mtls(self, test_certificates): # Assert assert isinstance(client_creds, grpc.ChannelCredentials) - def test_client_credentials_with_one_way_tls(self, test_certificates): + def test_client_credentials_should_return_channel_credentials_when_one_way_tls( + self, test_certificates + ): """Test client credentials property for one-way TLS. 
Given: @@ -422,7 +433,7 @@ def test_client_credentials_with_one_way_tls(self, test_certificates): # Assert assert isinstance(client_creds, grpc.ChannelCredentials) - def test_server_credentials_and_client_credentials_bidirectional( + def test_server_and_client_credentials_should_both_build_valid_credentials( self, test_certificates ): """Test bidirectional credential generation. @@ -450,7 +461,9 @@ def test_server_credentials_and_client_credentials_bidirectional( assert isinstance(server_creds, grpc.ServerCredentials) assert isinstance(client_creds, grpc.ChannelCredentials) - def test_server_credentials_idempotent_access(self, test_certificates): + def test_server_credentials_should_return_credentials_on_repeated_access( + self, test_certificates + ): """Test server credentials method idempotency. Given: @@ -475,7 +488,9 @@ def test_server_credentials_idempotent_access(self, test_certificates): assert isinstance(server_creds_1, grpc.ServerCredentials) assert isinstance(server_creds_2, grpc.ServerCredentials) - def test_client_credentials_idempotent_access(self, test_certificates): + def test_client_credentials_should_return_credentials_on_repeated_access( + self, test_certificates + ): """Test client credentials method idempotency. Given: @@ -502,7 +517,7 @@ def test_client_credentials_idempotent_access(self, test_certificates): @given(mutual=st.booleans()) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) - def test_server_credentials_and_client_credentials_type_consistency( + def test_server_credentials_and_client_credentials_should_return_consistent_types( self, mutual, test_certificates ): """Test credential method idempotency across mutual flag values. 
@@ -539,7 +554,9 @@ def test_server_credentials_and_client_credentials_type_consistency( @given(mutual=st.booleans()) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) - def test_pickle_roundtrip(self, mutual, test_certificates): + def test_pickle_roundtrip_should_produce_equal_instance( + self, mutual, test_certificates + ): """Test WorkerCredentials survives pickle roundtrip. Given: @@ -565,7 +582,9 @@ def test_pickle_roundtrip(self, mutual, test_certificates): assert isinstance(restored.server_credentials(), grpc.ServerCredentials) assert isinstance(restored.client_credentials(), grpc.ChannelCredentials) - def test___enter___not_supported(self, test_certificates): + def test___enter___should_raise_when_used_as_context_manager( + self, test_certificates + ): """Test WorkerCredentials does not support context manager protocol. Given: @@ -586,7 +605,7 @@ def test___enter___not_supported(self, test_certificates): with creds: pass - def test_current_not_supported(self): + def test_current_should_raise_attribute_error(self): """Test WorkerCredentials does not expose current() classmethod. Given: @@ -599,3 +618,29 @@ def test_current_not_supported(self): # Act & assert with pytest.raises(AttributeError): WorkerCredentials.current() + + +class TestCredentialContext: + """Test suite for the internal CredentialContext manager.""" + + def test___exit___should_raise_when_invoked_without_enter(self): + """Test exit guards against running without a matching enter. + + Given: + A CredentialContext that was never entered, so it holds no + reset token. + When: + Its exit is invoked directly. + Then: + It should raise RuntimeError — the credential reset is + guarded against running without a token from a matching + enter. 
+ """ + # Arrange + context = CredentialContext(credentials=None) # type: ignore[arg-type] + + # Act & assert + # No ``with`` idiom can invoke ``__exit__`` without a matching + # ``__enter__``, so the misuse guard is exercised by a direct call. + with pytest.raises(RuntimeError, match="without matching __enter__"): + context.__exit__(None, None, None) diff --git a/wool/tests/runtime/worker/test_connection.py b/wool/tests/runtime/worker/test_connection.py index 546e8ff4..1934eec5 100644 --- a/wool/tests/runtime/worker/test_connection.py +++ b/wool/tests/runtime/worker/test_connection.py @@ -1,4 +1,5 @@ import asyncio +import logging from typing import Callable from typing import Coroutine from uuid import uuid4 @@ -15,8 +16,8 @@ import wool from wool import protocol -from wool.runtime.context import ContextDecodeWarning -from wool.runtime.context import ContextVar +from wool.runtime.context.exceptions import SerializationWarning +from wool.runtime.context.var import ContextVar from wool.runtime.routine.task import Task from wool.runtime.routine.task import WorkerProxyLike from wool.runtime.worker.base import ChannelOptions @@ -38,52 +39,6 @@ class on deserialization in tests that round-trip user-defined """ -class _StrictRejectingException(Exception): - """Module-level exception that rejects every best-effort write - the dispatch handler's strict-mode side channels attempt. - - Defined at module scope so cloudpickle can resolve the class on - deserialization when this exception ships across the wire on - :class:`protocol.Response`'s ``exception`` field. - - ``add_note`` raises :class:`AttributeError` so the PEP 678 note - path inside :meth:`WorkerConnection._read_next`'s exception arm - hits the ``except (AttributeError, TypeError)`` swallow. - - Arbitrary attribute writes — including - ``__wool_context_warnings__`` — raise :class:`AttributeError` so - the programmatic side-channel write hits the ``except - AttributeError`` swallow. 
The ``args``/``__cause__``/ - ``__context__``/``__traceback__``/``__suppress_context__``/ - ``__notes__`` slots are explicitly allowed so the standard - ``BaseException`` machinery (and cloudpickle's restore via - ``__setstate__``) keeps working. - """ - - _ALLOWED = frozenset( - { - "args", - "__cause__", - "__context__", - "__traceback__", - "__suppress_context__", - "__notes__", - } - ) - - def __setattr__(self, name, value): - if name in self._ALLOWED: - object.__setattr__(self, name, value) - else: - raise AttributeError( - f"{type(self).__name__!r} object does not accept " - f"arbitrary attribute writes: {name!r}" - ) - - def add_note(self, _note): - raise AttributeError(f"{type(self).__name__!r} object does not accept add_note") - - @pytest.fixture def sample_task(mocker: MockerFixture): """Provides a mock :class:`Task` for testing. @@ -165,7 +120,7 @@ def create_call(stream_iterator, cancel_raises=False): class TestWorkerConnection: @pytest.mark.asyncio - async def test_dispatch_task_that_returns( + async def test_dispatch_should_yield_return_value_when_task_returns( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test task dispatch with successful acknowledgment and result. @@ -209,7 +164,7 @@ async def test_dispatch_task_that_returns( assert results[0] == "test_result" @pytest.mark.asyncio - async def test_dispatch_task_that_raises( + async def test_dispatch_should_raise_exception_when_task_raises( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test task dispatch when task raises an exception. @@ -252,7 +207,7 @@ async def test_dispatch_task_that_raises( mock_stub.dispatch.assert_called_once() @pytest.mark.asyncio - async def test_dispatch_no_ack( + async def test_dispatch_should_raise_unexpected_response_when_not_acked( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test task dispatch when acknowledgment is not received. 
@@ -297,7 +252,7 @@ async def test_dispatch_no_ack( [False, True], ids=["cancel_succeeds", "cancel_raises"], ) - async def test_dispatch_unexpected_response( + async def test_dispatch_should_raise_unexpected_response_when_unexpected_after_ack( self, mocker: MockerFixture, sample_task, @@ -354,7 +309,7 @@ async def test_dispatch_unexpected_response( grpc.StatusCode.UNAVAILABLE, ], ) - async def test_dispatch_transient_rpc_error( + async def test_dispatch_should_raise_transient_rpc_error_when_stub_raises_transient( self, mocker: MockerFixture, sample_task, status_code: grpc.StatusCode ): """Test task dispatch when stub raises a transient RPC error. @@ -402,7 +357,7 @@ def details(self): grpc.StatusCode.UNIMPLEMENTED, ], ) - async def test_dispatch_nontransient_rpc_error( + async def test_dispatch_should_raise_rpc_error_when_stub_raises_nontransient( self, mocker: MockerFixture, sample_task, status_code: grpc.StatusCode ): """Test task dispatch when stub raises a non-transient RPC error. @@ -442,7 +397,9 @@ def details(self): @pytest.mark.asyncio @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) @given(timeout=st.floats(max_value=0.0, allow_nan=False, allow_infinity=False)) - async def test_dispatch_invalid_timeout(self, mocker: MockerFixture, timeout: float): + async def test_dispatch_should_raise_when_timeout_not_positive( + self, mocker: MockerFixture, timeout: float + ): """Test task dispatch with invalid dispatch timeout value. Given: @@ -476,7 +433,7 @@ async def test_callable(): pass @pytest.mark.asyncio - async def test_dispatch_exceeds_timeout( + async def test_dispatch_should_raise_deadline_when_timeout_exceeded( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test task dispatch when dispatch timeout is exceeded. 
@@ -514,7 +471,7 @@ async def test_dispatch_exceeds_timeout( mock_call.cancel.assert_called() @pytest.mark.asyncio - async def test_dispatch_exceeds_limit( + async def test_dispatch_should_raise_deadline_when_concurrency_limit_reached( self, mocker: MockerFixture, sample_task, mock_grpc_call, async_stream ): """Test task dispatch when concurrency limit is reached. @@ -571,7 +528,7 @@ async def consume_slot(): pass @pytest.mark.asyncio - async def test_dispatch_releases_semaphore_when_handshake_fails( + async def test_dispatch_should_release_semaphore_when_handshake_fails( self, mocker: MockerFixture, sample_task, mock_grpc_call, async_stream ): """Test :meth:`WorkerConnection.dispatch` releases the @@ -629,7 +586,7 @@ def details(self): ) @pytest.mark.asyncio - async def test_dispatch_releases_semaphore_when_stream_acloses_before_iteration( + async def test_dispatch_should_release_semaphore_when_stream_acloses_early( self, mocker: MockerFixture, sample_task, mock_grpc_call, async_stream ): """Test that closing a primed stream before iterating any @@ -677,7 +634,7 @@ async def test_dispatch_releases_semaphore_when_stream_acloses_before_iteration( ) @pytest.mark.asyncio - async def test_dispatch_propagates_task_encode_failure_unwrapped( + async def test_dispatch_should_propagate_unwrapped_when_task_encode_fails( self, mocker: MockerFixture, sample_task ): """Test that a caller-side task encode failure propagates @@ -725,7 +682,7 @@ class EncodeError(Exception): [False, True], ids=["cancel_succeeds", "cancel_raises"], ) - async def test_dispatch_cancelled_during_dispatch( + async def test_dispatch_should_cancel_call_and_raise_when_cancelled_during_dispatch( self, mocker: MockerFixture, sample_task, @@ -784,7 +741,7 @@ async def run_dispatch(): [False, True], ids=["cancel_succeeds", "cancel_raises"], ) - async def test_dispatch_cancelled_during_execution( + async def test_dispatch_should_cancel_call_and_raise_when_cancelled_during_execution( self, mocker: 
MockerFixture, sample_task, @@ -843,7 +800,7 @@ async def run_dispatch(): assert mock_call.cancel.called @pytest.mark.asyncio - async def test_dispatch_cancelled_during_teardown_releases_channel_ref( + async def test_dispatch_should_release_channel_ref_when_cancelled_during_teardown( self, mocker: MockerFixture, sample_task, mock_grpc_call, async_stream ): """Test external cancellation during teardown releases the @@ -906,7 +863,7 @@ async def consume(): assert pool.stats.referenced_entries == 0 @pytest.mark.asyncio - async def test_dispatch_response_exception_with_cancelled_error_releases_channel_ref( + async def test_dispatch_should_release_channel_ref_when_worker_cancels( self, mocker: MockerFixture, sample_task, mock_grpc_call, async_stream ): """Test a worker-side CancelledError releases the pooled @@ -972,7 +929,112 @@ async def consume(): assert pool.stats.referenced_entries == 0 @pytest.mark.asyncio - async def test_close_idempotent(self, mocker: MockerFixture): + async def test_dispatch_should_reraise_signal_when_teardown_raises_process_signal( + self, mocker: MockerFixture, sample_task, mock_grpc_call, async_stream + ): + """Test a process signal raised during teardown reaches the caller. + + Given: + A dispatched task run to completion where a resource-release + step performed during teardown raises a process-exit signal + (``SystemExit``). + When: + The caller awaits the stream's completion. + Then: + The ``SystemExit`` should propagate to the caller and the + connection's pooled channel resources should still be + released — the signal is captured off the teardown task and + re-raised without leaking the pooled reference. 
+ """ + # Arrange + from wool.runtime.worker import connection as connection_module + + responses = ( + protocol.Response(ack=protocol.Ack()), + protocol.Response(result=protocol.Message(dump=cloudpickle.dumps("done"))), + ) + mock_call = mock_grpc_call(async_stream(responses)) + # A teardown-side release step (the gRPC call cancel) raises a + # process-exit signal; the swallow guard only catches Exception, + # so it escapes the teardown task. + mock_call.cancel.side_effect = SystemExit("teardown signal") + mock_stub = mocker.MagicMock() + mock_stub.dispatch = mocker.MagicMock(return_value=mock_call) + mocker.patch.object(protocol, "WorkerStub", return_value=mock_stub) + connection = WorkerConnection( + "localhost:50051", options=ChannelOptions(max_concurrent_streams=10) + ) + pool = connection_module._channel_pool + + # Act & assert + with pytest.raises(SystemExit, match="teardown signal"): + async for _ in await connection.dispatch(sample_task): + pass + + # The pooled reference is still released despite the signal. + assert pool.stats.referenced_entries == 0 + + @pytest.mark.asyncio + async def test_dispatch_should_detach_teardown_when_release_exceeds_timeout( + self, mocker: MockerFixture, sample_task, mock_grpc_call, async_stream, caplog + ): + """Test a wedged teardown unblocks the caller after the timeout. + + Given: + A dispatched task whose teardown cannot complete promptly + because a resource-release step is blocked. + When: + The caller awaits the stream's completion. + Then: + The caller should unblock within the teardown timeout rather + than hang, the blocked release should be left to a detached + task, and a timeout warning should be logged. + """ + # Arrange + from wool.runtime.worker import connection as connection_module + + # Contend the pool lock on a fresh instance, and shorten the + # teardown budget so the wedge resolves quickly. 
+ mocker.patch.object(connection_module._channel_pool, "_lock", asyncio.Lock()) + mocker.patch.object(connection_module, "_TEARDOWN_TIMEOUT", 0.1) + responses = ( + protocol.Response(ack=protocol.Ack()), + protocol.Response(result=protocol.Message(dump=cloudpickle.dumps("done"))), + ) + mock_call = mock_grpc_call(async_stream(responses)) + mock_stub = mocker.MagicMock() + mock_stub.dispatch = mocker.MagicMock(return_value=mock_call) + mocker.patch.object(protocol, "WorkerStub", return_value=mock_stub) + connection = WorkerConnection( + "localhost:50051", options=ChannelOptions(max_concurrent_streams=10) + ) + stream = await connection.dispatch(sample_task) + pool = connection_module._channel_pool + + # Act — hold the pool lock so the release wedges; the caller must + # unblock at the timeout instead of hanging. + await pool._lock.acquire() + try: + + async def consume(): + async for _ in stream: + pass + + with caplog.at_level( + logging.WARNING, logger="wool.runtime.worker.connection" + ): + await asyncio.wait_for(consume(), timeout=5) + finally: + pool._lock.release() + # Let the detached teardown task finish now the lock is free. + for _ in range(5): + await asyncio.sleep(0) + + # Assert — the caller unblocked and the timeout was reported. + assert any("teardown exceeded" in r.getMessage() for r in caplog.records) + + @pytest.mark.asyncio + async def test_close_should_be_idempotent(self, mocker: MockerFixture): """Test closing a connection is idempotent. 
Given: @@ -992,7 +1054,7 @@ async def test_close_idempotent(self, mocker: MockerFixture): await connection.close() @pytest.mark.asyncio - async def test_close_called_twice_after_uds_self_dispatch( + async def test_close_should_not_raise_when_called_twice_after_uds_self_dispatch( self, mocker: MockerFixture, sample_task, @@ -1051,7 +1113,7 @@ async def test_close_called_twice_after_uds_self_dispatch( await connection.close() @pytest.mark.asyncio - async def test_dispatch_task_that_yields_multiple_results( + async def test_dispatch_should_yield_all_results_in_order_when_task_yields_multiple( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test task dispatch with multiple streaming results. @@ -1100,7 +1162,54 @@ async def test_dispatch_task_that_yields_multiple_results( assert results == ["result_1", "result_2", "result_3"] @pytest.mark.asyncio - async def test_dispatch_stream_early_close( + async def test_dispatch_should_swallow_cancel_error_when_stream_closes( + self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call + ): + """Test the dispatch stream's close swallows a cancel-time error. + + Given: + A dispatch in flight whose underlying gRPC call's + ``cancel()`` is patched to raise an Exception. + When: + The result iterator is broken out of, triggering the + ``_DispatchStream.close()`` path on unwind. + Then: + The break and the surrounding teardown should complete + without surfacing the cancel-time exception — close is + a cleanup site, so a raise from ``call.cancel()`` is + swallowed defensively (cleanup-during-cleanup). 
+ """ + # Arrange + responses = ( + protocol.Response(ack=protocol.Ack()), + protocol.Response(result=protocol.Message(dump=cloudpickle.dumps("first"))), + protocol.Response(result=protocol.Message(dump=cloudpickle.dumps("second"))), + ) + mock_call = mock_grpc_call(async_stream(responses)) + # The ``cancel`` swallow is on the inner ``_DispatchStream`` + # path; ``_DispatchStream.close()`` invokes ``self._call.cancel()`` + # whose raise is what we want surfaced into the except arm. + mock_call.cancel = mocker.MagicMock(side_effect=RuntimeError("cancel boom")) + + mock_stub = mocker.MagicMock() + mock_stub.dispatch = mocker.MagicMock(return_value=mock_call) + mocker.patch.object(protocol, "WorkerStub", return_value=mock_stub) + + connection = WorkerConnection( + "localhost:50051", options=ChannelOptions(max_concurrent_streams=10) + ) + + # Act — break early to drive the close() path. + results = [] + async for result in await connection.dispatch(sample_task): + results.append(result) + break + + # Assert — early break completed without surfacing the cancel raise. + assert results == ["first"] + + @pytest.mark.asyncio + async def test_dispatch_should_stop_stream_when_iterator_breaks_early( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test iterator closed via break before completion. @@ -1146,7 +1255,7 @@ async def test_dispatch_stream_early_close( assert results == ["result_1", "result_2"] @pytest.mark.asyncio - async def test_dispatch_with_version_in_ack( + async def test_dispatch_should_accept_ack_when_version_present( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch accepts Ack with version field. 
@@ -1184,7 +1293,7 @@ async def test_dispatch_with_version_in_ack( assert results == ["test_result"] @pytest.mark.asyncio - async def test_dispatch_nack_with_exception_reraises_original_class( + async def test_dispatch_should_reraise_original_class_when_nack_carries_exception( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch re-raises the worker's exception class on Nack. @@ -1226,7 +1335,7 @@ async def test_dispatch_nack_with_exception_reraises_original_class( mock_call.cancel.assert_called() @pytest.mark.asyncio - async def test_dispatch_nack_with_exception_preserves_subclass_identity( + async def test_dispatch_should_preserve_subclass_identity_when_nack_has_exception( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch preserves the exact subclass of the worker exception. @@ -1270,7 +1379,7 @@ async def test_dispatch_nack_with_exception_preserves_subclass_identity( assert type(excinfo.value) is MyAppError @pytest.mark.asyncio - async def test_dispatch_nack_with_exception_suppresses_implicit_chaining( + async def test_dispatch_should_suppress_implicit_chaining_when_reraising_nack( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch suppresses chaining when re-raising the worker exception. @@ -1320,7 +1429,7 @@ async def test_dispatch_nack_with_exception_suppresses_implicit_chaining( assert "cloudpickle" not in ctx_module @pytest.mark.asyncio - async def test_dispatch_nack_with_unpicklable_exception_falls_back_to_rpc_error( + async def test_dispatch_should_raise_rpc_error_when_nack_dump_unpicklable( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch falls back to RpcError when the Nack dump is malformed. 
@@ -1358,7 +1467,7 @@ async def test_dispatch_nack_with_unpicklable_exception_falls_back_to_rpc_error( pass @pytest.mark.asyncio - async def test_dispatch_nack_with_non_exception_payload_falls_back_to_rpc_error( + async def test_dispatch_should_raise_rpc_error_when_nack_payload_not_exception( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch falls back to RpcError when the Nack dump is not an exception. @@ -1398,7 +1507,7 @@ async def test_dispatch_nack_with_non_exception_payload_falls_back_to_rpc_error( pass @pytest.mark.asyncio - async def test_dispatch_nack_with_base_exception_falls_back_to_rpc_error( + async def test_dispatch_should_raise_rpc_error_when_nack_payload_base_exception( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch degrades non-Exception BaseException Nacks to RpcError. @@ -1454,14 +1563,14 @@ async def test_dispatch_nack_with_base_exception_falls_back_to_rpc_error( OSError, ArithmeticError, # PR #205 explicitly names ``ImportError`` and - # ``ContextDecodeWarning`` as parse-phase rejection + # ``SerializationWarning`` as parse-phase rejection # classes that ride the Nack-with-exception channel - # (unloadable callable / strict-mode context decode). + # (unloadable callable / strict-mode chain-manifest decode). # Both must round-trip class+message intact so the # caller's ``except ImportError`` / ``except - # ContextDecodeWarning`` keeps matching. + # SerializationWarning`` keeps matching. 
ImportError, - ContextDecodeWarning, + SerializationWarning, ) ), message=st.text( @@ -1469,7 +1578,7 @@ async def test_dispatch_nack_with_base_exception_falls_back_to_rpc_error( max_size=64, ), ) - async def test_dispatch_nack_with_arbitrary_exception_roundtrips( + async def test_dispatch_should_roundtrip_exception_when_nack_carries_arbitrary( self, mocker: MockerFixture, sample_task, @@ -1484,7 +1593,7 @@ async def test_dispatch_nack_with_arbitrary_exception_roundtrips( A Hypothesis-generated typed exception instance drawn from a representative sampling of parse-phase rejection classes (Exception subclasses including ImportError and - ContextDecodeWarning) paired with arbitrary printable + SerializationWarning) paired with arbitrary printable text messages, dumped via cloudpickle.dumps and shipped as a Nack.exception. When: @@ -1530,7 +1639,7 @@ async def test_dispatch_nack_with_arbitrary_exception_roundtrips( await connection.close() @pytest.mark.asyncio - async def test_dispatch_with_secure_channel( + async def test_dispatch_should_use_secure_channel_when_credentials_provided( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch uses secure channel when credentials are provided. @@ -1545,7 +1654,9 @@ async def test_dispatch_with_secure_channel( """ # Arrange mock_channel = mocker.AsyncMock() - mock_secure = mocker.patch("grpc.aio.secure_channel", return_value=mock_channel) + mock_secure = mocker.patch.object( + grpc.aio, "secure_channel", return_value=mock_channel + ) responses = ( protocol.Response(ack=protocol.Ack()), @@ -1574,7 +1685,7 @@ async def test_dispatch_with_secure_channel( await connection.close() @pytest.mark.asyncio - async def test_dispatch_athrow_propagates_to_worker( + async def test_dispatch_should_propagate_athrow_to_worker_and_return_recovery( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test throwing an exception into the dispatch generator. 
@@ -1617,7 +1728,7 @@ async def test_dispatch_athrow_propagates_to_worker( assert mock_call.write.call_count == 3 @pytest.mark.asyncio - async def test_stream_usable_after_dispatch_returns( + async def test_stream_should_be_consumable_when_dispatch_returns( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch returns a usable stream after its scope exits. @@ -1655,7 +1766,7 @@ async def test_stream_usable_after_dispatch_returns( assert results == ["value"] @pytest.mark.asyncio - async def test_stream_consumption_releases_pool_ref( + async def test_stream_should_release_pool_ref_when_fully_consumed( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test consuming the full stream releases the channel. @@ -1671,7 +1782,7 @@ async def test_stream_consumption_releases_pool_ref( """ # Arrange mock_channel = mocker.AsyncMock() - mocker.patch("grpc.aio.insecure_channel", return_value=mock_channel) + mocker.patch.object(grpc.aio, "insecure_channel", return_value=mock_channel) responses = ( protocol.Response(ack=protocol.Ack()), @@ -1696,7 +1807,7 @@ async def test_stream_consumption_releases_pool_ref( mock_channel.close.assert_called_once() @pytest.mark.asyncio - async def test_error_mid_stream_releases_pool_ref( + async def test_error_should_release_pool_ref_when_raised_mid_stream( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test an error mid-stream releases the channel. 
@@ -1714,7 +1825,7 @@ async def test_error_mid_stream_releases_pool_ref( """ # Arrange mock_channel = mocker.AsyncMock() - mocker.patch("grpc.aio.insecure_channel", return_value=mock_channel) + mocker.patch.object(grpc.aio, "insecure_channel", return_value=mock_channel) responses = ( protocol.Response(ack=protocol.Ack()), @@ -1742,7 +1853,7 @@ async def test_error_mid_stream_releases_pool_ref( mock_channel.close.assert_called_once() @pytest.mark.asyncio - async def test_close_invokes_channel_finalizer( + async def test_close_should_invoke_channel_finalizer( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test that close() tears down the pooled gRPC channel. @@ -1758,7 +1869,7 @@ async def test_close_invokes_channel_finalizer( """ # Arrange mock_channel = mocker.AsyncMock() - mocker.patch("grpc.aio.insecure_channel", return_value=mock_channel) + mocker.patch.object(grpc.aio, "insecure_channel", return_value=mock_channel) responses = ( protocol.Response(ack=protocol.Ack()), @@ -1784,7 +1895,7 @@ async def test_close_invokes_channel_finalizer( mock_channel.close.assert_called_once() @pytest.mark.asyncio - async def test_two_dispatches_share_one_channel( + async def test_two_dispatches_should_share_one_channel_when_target_matches( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test two dispatches to the same target share one channel. @@ -1828,7 +1939,7 @@ def make_call(): assert protocol.WorkerStub.call_count == 1 @pytest.mark.asyncio - async def test_dispatch_with_default_options( + async def test_dispatch_should_use_default_message_sizes_when_no_options( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch creates gRPC channel with default ChannelOptions. 
@@ -1877,7 +1988,7 @@ async def test_dispatch_with_default_options( ) in call_options @pytest.mark.asyncio - async def test_dispatch_with_custom_options( + async def test_dispatch_should_use_custom_message_sizes_when_options_provided( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch creates gRPC channel with custom ChannelOptions. @@ -1928,7 +2039,7 @@ async def test_dispatch_with_custom_options( ) in call_options @pytest.mark.asyncio - async def test_dispatch_with_custom_keepalive_options( + async def test_dispatch_should_use_custom_keepalive_options_when_provided( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch creates gRPC channel with custom keepalive options. @@ -1976,7 +2087,7 @@ async def test_dispatch_with_custom_keepalive_options( assert ("grpc.keepalive_permit_without_calls", 0) in call_options @pytest.mark.asyncio - async def test_dispatch_with_default_keepalive_options( + async def test_dispatch_should_use_default_keepalive_options_when_no_options( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch includes default keepalive options in channel. @@ -2019,7 +2130,7 @@ async def test_dispatch_with_default_keepalive_options( assert ("grpc.keepalive_permit_without_calls", 1) in call_options @pytest.mark.asyncio - async def test_dispatch_with_custom_transport_options( + async def test_dispatch_should_use_custom_transport_options_when_provided( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch creates gRPC channel with custom transport options. 
@@ -2068,7 +2179,7 @@ async def test_dispatch_with_custom_transport_options( assert ("grpc.default_compression_algorithm", 2) in call_options @pytest.mark.asyncio - async def test_dispatch_with_self_dispatch( + async def test_dispatch_should_serialize_payload_with_cloudpickle_when_self_dispatch( self, mocker: MockerFixture, sample_task, @@ -2120,7 +2231,7 @@ async def test_dispatch_with_self_dispatch( assert asyncio.iscoroutinefunction(restored_callable) @pytest.mark.asyncio - async def test_dispatch_with_self_dispatch_over_uds( + async def test_dispatch_should_use_uds_channel_when_uds_address_set( self, mocker: MockerFixture, sample_task, @@ -2182,7 +2293,7 @@ async def test_dispatch_with_self_dispatch_over_uds( secure_spy.assert_not_called() @pytest.mark.asyncio - async def test_dispatch_with_address_mismatch( + async def test_dispatch_should_use_cross_process_path_when_address_mismatch( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch uses cloudpickle when addresses do not match. 
@@ -2227,7 +2338,7 @@ async def test_dispatch_with_address_mismatch( assert results == ["grpc_result"] @pytest.mark.asyncio - async def test_dispatch_self_dispatch_anext_sends_vars( + async def test_dispatch_should_cloudpickle_vars_when_self_dispatch_anext( self, mocker: MockerFixture, sample_task, @@ -2286,7 +2397,7 @@ async def test_dispatch_self_dispatch_anext_sends_vars( assert len(emitted[(var.namespace, var.name)]) > 16 @pytest.mark.asyncio - async def test_dispatch_self_dispatch_with_response_vars( + async def test_dispatch_should_apply_back_propagated_vars_when_response_carries_vars( self, mocker: MockerFixture, sample_task, @@ -2316,18 +2427,21 @@ async def test_dispatch_self_dispatch_with_response_vars( var = ContextVar("conn_e_var", namespace="conn_e") var.set("original") + caller_chain = wool.__chain__.get().id + responses = ( protocol.Response(ack=protocol.Ack()), protocol.Response( result=protocol.Message(dump=cloudpickle.dumps("result")), - context=protocol.Context( + context=protocol.ChainManifest( + id=caller_chain.hex, vars=[ protocol.ContextVar( namespace=var.namespace, name=var.name, value=cloudpickle.dumps("back_propagated"), ) - ] + ], ), ), ) @@ -2351,7 +2465,181 @@ async def test_dispatch_self_dispatch_with_response_vars( assert var.get() == "back_propagated" @pytest.mark.asyncio - async def test_dispatch_with_corrupt_response_context_and_worker_exception( + async def test_dispatch_should_back_propagate_routine_set_when_caller_unarmed( + self, mocker: MockerFixture, async_stream, mock_grpc_call + ): + """Test a routine's set on a wool.ContextVar reaches an unarmed caller. + + Given: + An unarmed caller (no wool.ContextVar bindings, no chain + installed) dispatching a coroutine routine, and a response + whose context carries a new value for a wool.ContextVar + and a fresh chain id — the shape of a routine that + performed the first ``var.set`` on the worker. + When: + The dispatch result is consumed. 
+ Then: + * Before the dispatch the caller has no value for the var + (``LookupError`` on bare ``get``) and no active chain. + * After the dispatch the caller observes the routine-set + value, and the caller's chain is armed with the worker's + chain id — back-propagation arms the previously-unarmed + caller via the apply-back leg of the wire. + """ + # Arrange + + var = ContextVar("conn_set_back_var", namespace="conn_set_back") + + async def routine() -> None: + return None + + wool_task = Task( + id=uuid4(), + callable=routine, + args=(), + kwargs={}, + proxy=PicklableMock(spec=WorkerProxyLike, id="test-proxy-id"), + ) + + worker_chain_id = uuid4() + responses = ( + protocol.Response(ack=protocol.Ack()), + protocol.Response( + result=protocol.Message(dump=cloudpickle.dumps("done")), + context=protocol.ChainManifest( + id=worker_chain_id.hex, + vars=[ + protocol.ContextVar( + namespace=var.namespace, + name=var.name, + value=cloudpickle.dumps("routine-set-value"), + ) + ], + ), + ), + ) + mock_call = mock_grpc_call(async_stream(responses)) + mock_stub = mocker.MagicMock() + mock_stub.dispatch = mocker.MagicMock(return_value=mock_call) + mocker.patch.object(protocol, "WorkerStub", return_value=mock_stub) + + connection = WorkerConnection( + "localhost:50051", options=ChannelOptions(max_concurrent_streams=10) + ) + + # Pre-dispatch baseline: caller is unarmed, var has no value. + assert wool.__chain__.get(None) is None + with pytest.raises(LookupError): + var.get() + + # Act + results = [] + async for result in await connection.dispatch(wool_task): + results.append(result) + + # Assert — back-propagation applied: value visible, caller armed. 
+ assert results == ["done"] + assert var.get() == "routine-set-value" + armed = wool.__chain__.get(None) + assert armed is not None + assert armed.id == worker_chain_id + + @pytest.mark.asyncio + async def test_async_gen_should_back_propagate_per_yield_when_caller_unarmed( + self, mocker: MockerFixture, async_stream, mock_grpc_call + ): + """Test an async-gen routine's per-yield mutations reach an unarmed caller. + + Given: + An unarmed caller dispatching an async-generator routine, + and a response stream where each yield carries an updated + chain manifest — the shape of an async generator that does + ``var.set("step-N")`` on every iteration. All response + frames carry the same chain id (the worker's cached + chain), so the caller's apply-back stays on one chain + across iterations. + When: + The caller iterates the result stream to exhaustion, + reading the var after every yield. + Then: + * Before the dispatch the caller has no value for the var + and no active chain. + * Each per-yield snapshot equals the worker's most-recent + ``var.set`` — the response-frame mount on every + ``__anext__`` applies the latest binding. + * After exhaustion the caller is armed with the worker's + chain id and the var holds the final yield's value. 
+ """ + # Arrange + + var = ContextVar("conn_set_back_agen_var", namespace="conn_set_back_agen") + + async def streaming_routine(): + for _ in range(3): + yield None + + wool_task = Task( + id=uuid4(), + callable=streaming_routine, + args=(), + kwargs={}, + proxy=PicklableMock(spec=WorkerProxyLike, id="test-proxy-id"), + ) + + worker_chain_id = uuid4() + + def _yield_response(value: str) -> protocol.Response: + """Build a yield-shaped response carrying the per-step var set.""" + return protocol.Response( + result=protocol.Message(dump=cloudpickle.dumps(value)), + context=protocol.ChainManifest( + id=worker_chain_id.hex, + vars=[ + protocol.ContextVar( + namespace=var.namespace, + name=var.name, + value=cloudpickle.dumps(value), + ) + ], + ), + ) + + responses = ( + protocol.Response(ack=protocol.Ack()), + _yield_response("step-0"), + _yield_response("step-1"), + _yield_response("step-2"), + ) + mock_call = mock_grpc_call(async_stream(responses)) + mock_stub = mocker.MagicMock() + mock_stub.dispatch = mocker.MagicMock(return_value=mock_call) + mocker.patch.object(protocol, "WorkerStub", return_value=mock_stub) + + connection = WorkerConnection( + "localhost:50051", options=ChannelOptions(max_concurrent_streams=10) + ) + + # Pre-dispatch baseline: caller is unarmed, var has no value. + assert wool.__chain__.get(None) is None + with pytest.raises(LookupError): + var.get() + + # Act — iterate the stream and snapshot var.get() after each yield. + snapshots: list[str] = [] + async for _ in await connection.dispatch(wool_task): + snapshots.append(var.get()) + + # Assert — every per-yield snapshot matches the worker's + # most-recent mutation; the caller is left armed on the + # worker's chain with the final binding. 
+ assert snapshots == ["step-0", "step-1", "step-2"] + assert var.get() == "step-2" + armed = wool.__chain__.get(None) + assert armed is not None + assert armed.id == worker_chain_id + + @pytest.mark.asyncio + async def test_dispatch_should_raise_and_warn_when_corrupt_context_with_worker_exc( self, mocker: MockerFixture, sample_task, @@ -2372,7 +2660,7 @@ async def test_dispatch_with_corrupt_response_context_and_worker_exception( The caller iterates the dispatch stream Then: The caller raises the worker's routine exception, and a - ContextDecodeWarning naming the corrupt var key is also + SerializationWarning naming the corrupt var key is also emitted — the corrupt var is skipped via the per-entry resilience contract; surviving context state still propagates and the worker's signal still surfaces @@ -2389,7 +2677,7 @@ async def test_dispatch_with_corrupt_response_context_and_worker_exception( exception=protocol.Message( dump=cloudpickle.dumps(ValueError("worker-side failure")) ), - context=protocol.Context( + context=protocol.ChainManifest( vars=[ protocol.ContextVar( namespace=var.namespace, @@ -2410,13 +2698,13 @@ async def test_dispatch_with_corrupt_response_context_and_worker_exception( ) # Act & assert - with pytest.warns(ContextDecodeWarning, match=var.name): + with pytest.warns(SerializationWarning, match=var.name): with pytest.raises(ValueError, match="worker-side failure"): async for _ in await connection.dispatch(sample_task): pass @pytest.mark.asyncio - async def test_dispatch_with_corrupt_response_context_and_result_frame( + async def test_dispatch_should_yield_result_and_warn_when_result_context_corrupt( self, mocker: MockerFixture, sample_task, @@ -2424,7 +2712,7 @@ async def test_dispatch_with_corrupt_response_context_and_result_frame( mock_grpc_call, ): """Test the caller-side response decoder delivers the routine's - return value and emits a ContextDecodeWarning when a result + return value and emits a SerializationWarning when a result frame's 
accompanying context payload fails to deserialize. Given: @@ -2435,7 +2723,7 @@ async def test_dispatch_with_corrupt_response_context_and_result_frame( The caller iterates the dispatch stream Then: The caller observes the routine's return value normally - and a ContextDecodeWarning is emitted — context + and a SerializationWarning is emitted — context propagation is ancillary state and a decode failure here never preempts the primary signal. Callers that prefer strict semantics can promote the warning to an exception @@ -2451,7 +2739,7 @@ async def test_dispatch_with_corrupt_response_context_and_result_frame( protocol.Response(ack=protocol.Ack()), protocol.Response( result=protocol.Message(dump=cloudpickle.dumps("worker_result")), - context=protocol.Context( + context=protocol.ChainManifest( vars=[ protocol.ContextVar( namespace=var.namespace, @@ -2473,7 +2761,7 @@ async def test_dispatch_with_corrupt_response_context_and_result_frame( # Act results: list[object] = [] - with pytest.warns(ContextDecodeWarning, match="Failed to deserialize"): + with pytest.warns(SerializationWarning, match="Failed to deserialize"): async for value in await connection.dispatch(sample_task): results.append(value) @@ -2481,7 +2769,7 @@ async def test_dispatch_with_corrupt_response_context_and_result_frame( assert results == ["worker_result"] @pytest.mark.asyncio - async def test_dispatch_with_corrupt_response_context_and_result_frame_strict( + async def test_dispatch_should_raise_serialization_error_when_corrupt_context_strict( self, mocker: MockerFixture, sample_task, @@ -2489,26 +2777,26 @@ async def test_dispatch_with_corrupt_response_context_and_result_frame_strict( mock_grpc_call, ): """Test that a caller can opt into strict semantics by promoting - ContextDecodeWarning to an error. + SerializationWarning to an error. 
Given: The same response shape as the lenient-mode test (result + corrupt context var) When: The caller has installed - ``warnings.filterwarnings("error", category=ContextDecodeWarning)`` + ``warnings.filterwarnings("error", category=SerializationWarning)`` for the duration of the dispatch Then: - Iterating the dispatch raises a :class:`BaseExceptionGroup` - whose sole peer is the promoted - :class:`ContextDecodeWarning` — wool emits decode failures - uniformly through the group shape so caller code stays - symmetric across single- and multi-peer cases (e.g. - decode failure alongside a worker exception). The opt-in - strict mode lets callers treat ancillary failures as - fatal without changing wool's wire-protocol defaults. + Iterating the dispatch raises a typed + :class:`wool.ChainSerializationError` aggregating the promoted + warnings on ``.warnings``. On a result frame the decode + error IS the primary — the routine's value is dropped + because a result cannot be trusted alongside a context + that failed to apply (strict-mode "fail loud" contract). 
""" # Arrange + import wool + target = "localhost:50051" var = ContextVar( "strict_corrupt_context_var", @@ -2518,7 +2806,7 @@ async def test_dispatch_with_corrupt_response_context_and_result_frame_strict( protocol.Response(ack=protocol.Ack()), protocol.Response( result=protocol.Message(dump=cloudpickle.dumps("worker_result")), - context=protocol.Context( + context=protocol.ChainManifest( vars=[ protocol.ContextVar( namespace=var.namespace, @@ -2542,15 +2830,15 @@ async def test_dispatch_with_corrupt_response_context_and_result_frame_strict( import warnings as _warnings with _warnings.catch_warnings(): - _warnings.simplefilter("error", category=ContextDecodeWarning) - with pytest.raises(BaseExceptionGroup) as exc_info: + _warnings.simplefilter("error", category=SerializationWarning) + with pytest.raises(wool.ChainSerializationError) as exc_info: async for _ in await connection.dispatch(sample_task): pass - assert len(exc_info.value.exceptions) == 1 - assert isinstance(exc_info.value.exceptions[0], ContextDecodeWarning) + assert len(exc_info.value.warnings) == 1 + assert isinstance(exc_info.value.warnings[0], SerializationWarning) @pytest.mark.asyncio - async def test_dispatch_without_serializer_uses_cloudpickle_for_vars( + async def test_dispatch_should_serialize_vars_with_cloudpickle_when_no_serializer( self, mocker: MockerFixture, sample_task, @@ -2563,10 +2851,14 @@ async def test_dispatch_without_serializer_uses_cloudpickle_for_vars( A WorkerConnection whose target does not match the current worker's address and a ContextVar with a value set When: - The dispatch stream writes requests + The dispatch stream writes requests — the initial + :class:`TaskRequestFrame` ships pure dispatch metadata + with no chain manifest, and the first + :class:`NextRequestFrame` (sent by ``__anext__`` to pull + the result) auto-captures the active chain Then: - The vars in each request should be serialized via - cloudpickle + The first mid-stream :class:`NextRequestFrame` should + 
carry the var with its value serialized via cloudpickle. """ # Arrange var = ContextVar("conn_f_var", namespace="conn_f") @@ -2592,31 +2884,38 @@ async def test_dispatch_without_serializer_uses_cloudpickle_for_vars( async for result in await connection.dispatch(sample_task): results.append(result) - # Assert — the initial request vars should be cloudpickle bytes + # Assert — the first NextRequest (write[1]) carries the var; + # the initial TaskRequest (write[0]) is pure dispatch metadata. assert results == ["result"] - initial_request = mock_call.write.call_args_list[0][0][0] - emitted = {(e.namespace, e.name): e.value for e in initial_request.context.vars} + next_request = mock_call.write.call_args_list[1][0][0] + emitted = {(e.namespace, e.name): e.value for e in next_request.context.vars} assert (var.namespace, var.name) in emitted assert len(emitted[(var.namespace, var.name)]) > 16 @pytest.mark.asyncio - async def test_dispatch_self_dispatch_initial_request_roundtrips_vars( + async def test_dispatch_should_roundtrip_vars_when_self_dispatch_mid_stream_request( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call, ): - """Test self-dispatch cloudpickle-encodes vars on the initial request. + """Test self-dispatch cloudpickle-encodes vars on the first mid-stream request. Given: A WorkerConnection whose target matches the current worker's address and a ContextVar with a value set When: - dispatch() sends the initial task request + dispatch() sends the initial task request (pure dispatch + metadata, no chain manifest) and the first + :class:`NextRequestFrame` (which auto-captures the active + chain) Then: - It should cloudpickle-encode the initial request's context - vars so they round-trip back to the original value. + The first mid-stream :class:`NextRequestFrame` should + cloudpickle-encode the var so it round-trips back to the + original value. 
Under the per-frame architecture the + initial Task frame is intentionally manifest-free — + mid-stream frames ship the per-step manifest. """ # Arrange target = "localhost:50051" @@ -2649,16 +2948,16 @@ async def test_dispatch_self_dispatch_initial_request_roundtrips_vars( async for result in await connection.dispatch(sample_task): results.append(result) - # Assert — the initial request (first write) carries the var - # cloudpickle-encoded; decoding it yields the original value. + # Assert — the first NextRequest (write[1]) carries the var; + # decoding it yields the original value. assert results == ["result"] - initial_request = mock_call.write.call_args_list[0][0][0] - emitted = {(e.namespace, e.name): e.value for e in initial_request.context.vars} + next_request = mock_call.write.call_args_list[1][0][0] + emitted = {(e.namespace, e.name): e.value for e in next_request.context.vars} assert (var.namespace, var.name) in emitted assert cloudpickle.loads(emitted[(var.namespace, var.name)]) == "roundtrip_value" @pytest.mark.asyncio - async def test_dispatch_cross_process_initial_request_encodes_callable( + async def test_dispatch_should_encode_callable_when_cross_process_initial_request( self, mocker: MockerFixture, sample_task, @@ -2715,7 +3014,7 @@ async def test_dispatch_cross_process_initial_request_encodes_callable( st.dictionaries(st.text(), st.integers()), ) ) - async def test_dispatch_self_dispatch_initial_request_var_roundtrip_property( + async def test_dispatch_should_roundtrip_arbitrary_var_when_self_dispatch_mid_stream( self, value, mocker: MockerFixture, @@ -2723,14 +3022,17 @@ async def test_dispatch_self_dispatch_initial_request_var_roundtrip_property( async_stream, mock_grpc_call, ): - """Test self-dispatch round-trips arbitrary ContextVar values. + """Test self-dispatch round-trips arbitrary ContextVar values + on the first mid-stream request. 
Given: A self-dispatch WorkerConnection and any picklable ContextVar value drawn from text, ints, tuples, or dicts When: - dispatch() writes the initial request and the var is - decoded from the emitted protocol.Context + dispatch() writes the initial task request (no wire + context) and the first NextRequest (which auto-captures + the chain), and the var is decoded from the emitted + ``next_request.context`` Then: It should decode to a value equal to the original. """ @@ -2772,13 +3074,13 @@ async def test_dispatch_self_dispatch_initial_request_var_roundtrip_property( # Assert assert results == ["result"] - initial_request = mock_call.write.call_args_list[0][0][0] - emitted = {(e.namespace, e.name): e.value for e in initial_request.context.vars} + next_request = mock_call.write.call_args_list[1][0][0] + emitted = {(e.namespace, e.name): e.value for e in next_request.context.vars} assert (var.namespace, var.name) in emitted assert cloudpickle.loads(emitted[(var.namespace, var.name)]) == value @pytest.mark.asyncio - async def test_dispatch_response_exception_with_non_exception_payload_falls_back_to_unexpected_response( + async def test_dispatch_should_raise_unexpected_response_when_exc_payload_non_exc( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch wraps a non-Exception ``Response.exception`` @@ -2831,7 +3133,55 @@ async def test_dispatch_response_exception_with_non_exception_payload_falls_back assert not isinstance(exc_info.value, RpcError) @pytest.mark.asyncio - async def test_dispatch_response_exception_with_cancelled_error_propagates_as_cancelled_error( + async def test_dispatch_should_preserve_base_exception_payload_on_context( + self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call + ): + """Test dispatch preserves a non-Exception BaseException payload + on the wrapper's ``__context__``. 
+ + Given: + A :class:`protocol.Response` whose ``exception`` field + carries a cloudpickle dump of a non-Exception + :class:`BaseException` (a ``KeyboardInterrupt``). + When: + ``dispatch(task)`` is awaited and the result iterator is + consumed. + Then: + It should raise :class:`UnexpectedResponse` (not + :class:`RpcError`) with the original ``KeyboardInterrupt`` + preserved on ``__context__`` — a process-level signal is not + smuggled across the wire as a raisable, but it is not lost. + """ + # Arrange + responses = ( + protocol.Response(ack=protocol.Ack()), + protocol.Response( + exception=protocol.Message( + dump=cloudpickle.dumps(KeyboardInterrupt("boom")), + ) + ), + ) + mock_call = mock_grpc_call(async_stream(responses)) + + mock_stub = mocker.MagicMock() + mock_stub.dispatch = mocker.MagicMock(return_value=mock_call) + mocker.patch.object(protocol, "WorkerStub", return_value=mock_stub) + + connection = WorkerConnection( + "localhost:50051", options=ChannelOptions(max_concurrent_streams=10) + ) + + # Act & assert + with pytest.raises( + UnexpectedResponse, match="non-Exception payload" + ) as exc_info: + async for _ in await connection.dispatch(sample_task): + pass + assert isinstance(exc_info.value.__context__, KeyboardInterrupt) + assert not isinstance(exc_info.value, RpcError) + + @pytest.mark.asyncio + async def test_dispatch_should_propagate_cancelled_error_raw_when_worker_cancels( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test dispatch propagates a worker-side @@ -2905,21 +3255,21 @@ async def body(): assert not isinstance(exc_info.value, RpcError) @pytest.mark.asyncio - async def test_dispatch_response_exception_with_cancelled_error_increments_caller_cancelling( + async def test_dispatch_should_not_increment_cancelling_when_worker_cancels( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): - """Test dispatch increments current_task().cancelling() on - a worker-side CancelledError. 
- - Mirrors stdlib's local-cancel state shape: a caller catching - :class:`asyncio.CancelledError` from a wool routine must - observe ``current_task().cancelling() > 0`` — the same shape - it would see for a local cancel — so idiomatic - ``if cancelling() > 0: raise`` re-raise gates and - ``current_task().uncancel()`` absorbers behave identically - regardless of whether the cancel originated locally or on - the worker. ``uncancel()`` must also decrement the count - back to zero per asyncio's documented contract. + """Test dispatch does NOT increment ``current_task().cancelling()`` + on a worker-side CancelledError. + + Matches stdlib ``await task`` semantics: when the awaitee + raises :class:`asyncio.CancelledError`, the awaiter's + ``cancelling()`` count is **not** bumped. A caller that + catches ``CancelledError`` and continues to ``await`` + something else (a recovery path) is therefore not + re-interrupted at the next checkpoint — the wool-naive + caller does not need to call + ``current_task().uncancel()`` to absorb a phantom bump + (F9). Given: A :class:`protocol.Response` whose ``exception`` field @@ -2929,9 +3279,9 @@ async def test_dispatch_response_exception_with_cancelled_error_increments_calle ``dispatch(task)`` is awaited and the result iterator is consumed, and the resulting ``CancelledError`` is caught Then: - It should observe ``current_task().cancelling() > 0`` - synchronously with the catch, and ``uncancel()`` should - decrement the count back to ``0``. + ``current_task().cancelling()`` should remain ``0`` — + the worker-shipped CancelledError propagates as-is and + no synchronous bump of the awaiter's state happens. 
""" # Arrange cancellation = asyncio.CancelledError("worker self-raised cancel") @@ -2951,18 +3301,8 @@ async def test_dispatch_response_exception_with_cancelled_error_increments_calle "localhost:50051", options=ChannelOptions(max_concurrent_streams=10) ) - observed: dict[str, int | None] = { - "cancelling": None, - "post_uncancel": None, - } + observed: dict[str, int | None] = {"cancelling": None} - # Run the cancellable consumption in an inner task so the - # observations happen on a task we control. On Python 3.11 - # ``Task.uncancel()`` only decrements the counter without - # clearing ``_must_cancel`` — the inner task therefore - # finalises as cancelled even though ``body()`` returned a - # value. The outer ``await wrapped`` consumes that - # cancellation, isolating the test runner's task. async def body(): try: async for _ in await connection.dispatch(sample_task): @@ -2971,26 +3311,18 @@ async def body(): current = asyncio.current_task() assert current is not None observed["cancelling"] = current.cancelling() - observed["post_uncancel"] = current.uncancel() - - wrapped = asyncio.ensure_future(body()) - # Act - try: - await wrapped - except asyncio.CancelledError: - # On Python 3.11 the scheduled cancel still fires after - # body() returns; on 3.12+ uncancel() suppresses it. - # Either outcome is fine — we assert on the observations - # captured inside the except arm. 
- pass + await asyncio.ensure_future(body()) # Assert - assert observed["cancelling"] is not None and observed["cancelling"] > 0 - assert observed["post_uncancel"] == 0 + assert observed["cancelling"] == 0, ( + "worker-shipped CancelledError must not bump the " + "awaiter's cancelling() count — stdlib ``await task`` " + "semantics keep it at 0" + ) @pytest.mark.asyncio - async def test_dispatch_response_exception_with_cancelled_error_propagates_to_task_cancelled_state( + async def test_dispatch_should_leave_task_cancelled_when_worker_cancels( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test re-raising a worker-side CancelledError ends the @@ -3044,104 +3376,31 @@ async def body(): assert wrapped.cancelled() @pytest.mark.asyncio - async def test_dispatch_response_exception_with_decode_failures_swallows_note_write_rejection( + async def test_dispatch_should_propagate_raw_when_result_payload_malformed( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): - """Test the strict-mode note/attribute write attempts are - swallowed for slotted exception classes that reject them. - - Given: - A :class:`Response` whose ``exception`` payload is a - ``_StrictRejectingException`` (rejects ``add_note`` with - :class:`AttributeError` and arbitrary attribute writes - including ``__wool_context_warnings__`` with - :class:`AttributeError`) AND whose ``context`` decode - raises a :class:`BaseExceptionGroup` of - :class:`ContextDecodeWarning` peers (strict-mode - promotion) - When: - ``dispatch(task)`` is awaited and the result iterator is - consumed - Then: - It should raise the ``_StrictRejectingException`` - unchanged — the failed note/attribute writes are - swallowed under ``except (AttributeError, TypeError)`` - and ``except AttributeError`` respectively, so the - routine's primary signal still ships. 
- """ - from wool.runtime.context import Context - - # Arrange — patch Context.from_protobuf to raise the - # strict-mode decode group on each call. The exception arm - # in _read_next gets ``decode_failures`` populated and then - # tries to attach them via add_note and a sidecar attribute. - peer = ContextDecodeWarning("var-1 unencodable") - - def encode_with_strict_failure(cls, *args, **kwargs): - raise BaseExceptionGroup("strict-mode encode group", [peer]) - - mocker.patch.object( - Context, - "from_protobuf", - classmethod(encode_with_strict_failure), - ) - - responses = ( - protocol.Response(ack=protocol.Ack()), - protocol.Response( - exception=protocol.Message( - dump=cloudpickle.dumps(_StrictRejectingException("primary signal")), - ) - ), - ) - mock_call = mock_grpc_call(async_stream(responses)) - - mock_stub = mocker.MagicMock() - mock_stub.dispatch = mocker.MagicMock(return_value=mock_call) - mocker.patch.object(protocol, "WorkerStub", return_value=mock_stub) - - connection = WorkerConnection( - "localhost:50051", options=ChannelOptions(max_concurrent_streams=10) - ) - - # Act & assert — the routine's primary signal type is - # preserved; no stray AttributeError leaks from the - # swallowed note/attribute writes. - with pytest.raises(_StrictRejectingException) as exc_info: - async for _ in await connection.dispatch(sample_task): - pass - - assert "primary signal" in str(exc_info.value) - # The sidecar attribute was never set because the class - # rejects arbitrary writes — the swallow is the assertion. 
- assert not hasattr(exc_info.value, "__wool_context_warnings__") - - @pytest.mark.asyncio - async def test_dispatch_response_result_with_malformed_payload_falls_back_to_unexpected_response( - self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call - ): - """Test dispatch wraps a malformed ``Response.result`` payload - as :class:`UnexpectedResponse` so the load balancer does not - evict the worker for what is typically caller-side version - skew on a shared result class. + """Test dispatch lets a malformed ``Response.result`` payload + deserialization error propagate with its original type. Given: A :class:`protocol.Response` whose ``result`` field carries bytes that cannot be deserialized - (b"not a valid pickle stream") + (b"not a valid pickle stream"). When: ``dispatch(task)`` is awaited and the result iterator is - consumed + consumed. Then: - It should raise :class:`UnexpectedResponse` whose - message names the malformed result payload, chained - from the underlying deserialization error via - ``__cause__``; :class:`UnexpectedResponse` is not an - :class:`RpcError` subclass so the load-balancer - classification treats it as a caller-fault and does - not evict the worker. + It should raise the underlying serializer error + (:class:`pickle.UnpicklingError` for cloudpickle's + default deserializer) raw — no wrapper class, no + indirection. The original exception type carries the + diagnostic detail. Since it isn't an :class:`RpcError` + subclass the load-balancer classification treats it as + a caller-fault and does not evict the worker. 
""" # Arrange + import pickle + responses = ( protocol.Response(ack=protocol.Ack()), protocol.Response( @@ -3159,46 +3418,41 @@ async def test_dispatch_response_result_with_malformed_payload_falls_back_to_une ) # Act & assert - with pytest.raises( - UnexpectedResponse, match="malformed result payload" - ) as exc_info: + with pytest.raises(pickle.UnpicklingError) as exc_info: async for _ in await connection.dispatch(sample_task): pass - # Belt-and-suspenders: the worker-eviction contract is - # carried by ``RpcError``; a malformed-result degradation - # must not surface as an ``RpcError`` subclass. + # The serializer error propagates raw — not wrapped in + # UnexpectedResponse, not an RpcError subclass. + assert not isinstance(exc_info.value, UnexpectedResponse) assert not isinstance(exc_info.value, RpcError) - # The underlying deserialization error is preserved on - # ``__cause__`` for diagnostic chains. - assert exc_info.value.__cause__ is not None @pytest.mark.asyncio - async def test_dispatch_response_exception_with_malformed_payload_falls_back_to_unexpected_response( + async def test_dispatch_should_propagate_raw_when_exception_payload_malformed( self, mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): - """Test dispatch wraps a malformed ``Response.exception`` - payload as :class:`UnexpectedResponse` so the load balancer - does not evict the worker for what is typically caller-side - version skew on a shared exception class. + """Test dispatch lets a malformed ``Response.exception`` + payload deserialization error propagate with its original + type. Given: A :class:`protocol.Response` whose ``exception`` field carries bytes that cannot be deserialized - (b"not a valid pickle stream") + (b"not a valid pickle stream"). When: ``dispatch(task)`` is awaited and the result iterator is - consumed + consumed. 
Then: - It should raise :class:`UnexpectedResponse` whose - message names the malformed exception payload, with the - underlying deserialization error preserved on - ``__cause__`` and ``__suppress_context__`` set so the - implicit context chain is suppressed. - :class:`UnexpectedResponse` is not an :class:`RpcError` - subclass so the load-balancer classification treats it - as a caller-fault and does not evict the worker. + It should raise the underlying serializer error + (:class:`pickle.UnpicklingError` for cloudpickle's + default deserializer) raw — no wrapper class, no + indirection. The original exception type carries the + diagnostic detail. Since it isn't an :class:`RpcError` + subclass the load-balancer classification treats it as + a caller-fault and does not evict the worker. """ # Arrange + import pickle + responses = ( protocol.Response(ack=protocol.Ack()), protocol.Response( @@ -3216,23 +3470,17 @@ async def test_dispatch_response_exception_with_malformed_payload_falls_back_to_ ) # Act & assert - with pytest.raises( - UnexpectedResponse, match="malformed exception payload" - ) as exc_info: + with pytest.raises(pickle.UnpicklingError) as exc_info: async for _ in await connection.dispatch(sample_task): pass - # Belt-and-suspenders: a routine-time decode mismatch must - # not surface as an ``RpcError`` subclass. + # The serializer error propagates raw — not wrapped in + # UnexpectedResponse, not an RpcError subclass. + assert not isinstance(exc_info.value, UnexpectedResponse) assert not isinstance(exc_info.value, RpcError) - # Manual ``__cause__`` chaining preserves the original - # pickle/import failure for diagnostics; the implicit - # context chain is suppressed via ``__suppress_context__``. 
- assert exc_info.value.__cause__ is not None - assert exc_info.value.__suppress_context__ is True @pytest.mark.asyncio -async def test_clear_channel_pool_tears_down_cached_channels( +async def test_clear_channel_pool_should_close_cached_channels( mocker: MockerFixture, sample_task, async_stream, mock_grpc_call ): """Test :func:`clear_channel_pool` closes every cached gRPC @@ -3278,7 +3526,7 @@ async def test_clear_channel_pool_tears_down_cached_channels( @pytest.mark.asyncio -async def test_clear_channel_pool_with_empty_pool_returns_without_raising(): +async def test_clear_channel_pool_should_not_raise_when_pool_empty(): """Test :func:`clear_channel_pool` is a no-op when the pool is empty. diff --git a/wool/tests/runtime/worker/test_frame.py b/wool/tests/runtime/worker/test_frame.py new file mode 100644 index 00000000..cc658de1 --- /dev/null +++ b/wool/tests/runtime/worker/test_frame.py @@ -0,0 +1,801 @@ +"""Unit tests for the Frame subtype hierarchy in :mod:`wool.runtime.worker.frame`.""" + +import threading +import warnings +from uuid import uuid4 + +import cloudpickle +import pytest +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +import wool +from tests.helpers import scoped_context +from wool import protocol +from wool.runtime.context.exceptions import SerializationWarning +from wool.runtime.worker.frame import AckResponseFrame +from wool.runtime.worker.frame import ExceptionResponseFrame +from wool.runtime.worker.frame import Frame +from wool.runtime.worker.frame import NackResponseFrame +from wool.runtime.worker.frame import NextRequestFrame +from wool.runtime.worker.frame import RequestFrame +from wool.runtime.worker.frame import ResponseFrame +from wool.runtime.worker.frame import ResultResponseFrame +from wool.runtime.worker.frame import SendRequestFrame +from wool.runtime.worker.frame import TaskRequestFrame +from wool.runtime.worker.frame import ThrowRequestFrame +from wool.runtime.worker.frame 
import _safely_serialize_exception + + +def _duplicate_key_wire(key: tuple[str, str]) -> protocol.ChainManifest: + """Build a wire ChainManifest carrying *key* twice. + + Duplicate ``(namespace, name)`` keys are undefined behaviour Wool's own + encoder never emits; this fabricates the malformed wire a buggy or + hostile peer could send so the strict-mode decode path raises. + """ + return protocol.ChainManifest( + id=uuid4().hex, + vars=[ + protocol.ContextVar(namespace=key[0], name=key[1]), + protocol.ContextVar(namespace=key[0], name=key[1]), + ], + ) + + +def _duplicate_key_response(key: tuple[str, str]) -> protocol.Response: + """Build a wire Response whose chain manifest carries *key* twice.""" + response = protocol.Response( + result=protocol.Message(dump=wool.__serializer__.dumps("value")) + ) + response.context.CopyFrom(_duplicate_key_wire(key)) + return response + + +class TestSafelySerializeException: + """Tests for the type-preserving exception serialization fallback.""" + + def test_safely_serialize_exception_should_ship_bare_type_when_unpicklable(self): + """Test the fallback ships the bare type when only attachments fail. + + Given: + A StopAsyncIteration whose __cause__ drags an un-picklable + object, so the primary serialization and the + attachment-carrying fallback both fail. + When: + It is serialized via the exception serialization helper. + Then: + The wire payload should deserialize to a StopAsyncIteration + with its original args — the bare reconstructed type is + shipped rather than demoted to a generic RuntimeError, so the + caller's except clause still matches. 
+ """ + + # Arrange + class _Unpicklable(Exception): + def __init__(self): + super().__init__() + self._lock = threading.Lock() + + exc = StopAsyncIteration("from coroutine") + exc.__cause__ = _Unpicklable() + + # Act + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + dump = _safely_serialize_exception(wool.__serializer__, exc) + restored = wool.__serializer__.loads(dump) + + # Assert + assert isinstance(restored, StopAsyncIteration) + assert restored.args == ("from coroutine",) + + def test_safely_serialize_exception_should_demote_to_runtime_error(self): + """Test the fallback demotes to RuntimeError when args cannot pickle. + + Given: + An exception whose own args carry an un-picklable object, so + no reconstruction of its type can be serialized. + When: + It is serialized via the exception serialization helper. + Then: + The wire payload should deserialize to a RuntimeError naming + the original class — the always-picklable last resort when + the bare type itself cannot survive. + """ + # Arrange + exc = ValueError(threading.Lock()) + + # Act + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + dump = _safely_serialize_exception(wool.__serializer__, exc) + restored = wool.__serializer__.loads(dump) + + # Assert + assert type(restored) is RuntimeError + assert "ValueError" in str(restored) + + def test_safely_serialize_exception_should_warn_fidelity_loss(self): + """Test the fallback reports fidelity loss for an unreconstructible cause. + + Given: + An un-picklable exception (so the primary serialization + fails) whose __cause__ is an exception that raises when its + type is re-constructed. + When: + It is serialized via the exception serialization helper. + Then: + The wire payload should still deserialize to the original + top-level type with the un-walkable cause level truncated + rather than taking the whole exception down, and a + SerializationWarning should report the fidelity loss. 
The + warning should name the original exception class and + carry the primary serialization failure as cause, with no + var_key or direction (the failure is exception-fidelity + loss, not a chain-manifest hop). + """ + + # Arrange + class _OnceConstructible(Exception): + _made = False + + def __init__(self, *args): + if type(self)._made: + raise RuntimeError("cannot rebuild") + type(self)._made = True + super().__init__(*args) + + bad_cause = _OnceConstructible("c") + bad_cause.__cause__ = ValueError("deeper") + bad_cause.__cause__.__cause__ = RuntimeError("deepest") + exc = ValueError("top") + exc._lock = threading.Lock() # un-picklable __dict__ → primary fails + exc.__cause__ = bad_cause + + # Act + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + dump = _safely_serialize_exception(wool.__serializer__, exc) + restored = wool.__serializer__.loads(dump) + + # Assert + assert type(restored) is ValueError + assert restored.args == ("top",) + assert restored.__cause__ is None # chain truncated at the bad level + fidelity = [ + w.message for w in caught if isinstance(w.message, SerializationWarning) + ] + assert fidelity + assert fidelity[0].original_type is ValueError + assert fidelity[0].cause is not None + assert fidelity[0].var_key is None + assert fidelity[0].direction is None + + def test_strict_mode_should_still_ship_exception_when_warnings_are_errors(self): + """Test fidelity loss stays non-fatal even when warnings are errors. + + Given: + An un-picklable exception (so the fallback runs) serialized + while SerializationWarning is promoted to an error. + When: + It is serialized via the exception serialization helper. + Then: + The promoted warning should be swallowed and the bare type + still shipped — fidelity loss must never deprive the caller + of the primary exception signal. 
+ """ + # Arrange + exc = ValueError("strict") + exc._lock = threading.Lock() # un-picklable __dict__ → primary fails + + # Act + with warnings.catch_warnings(): + warnings.simplefilter("error", SerializationWarning) + dump = _safely_serialize_exception(wool.__serializer__, exc) + restored = wool.__serializer__.loads(dump) + + # Assert + assert type(restored) is ValueError + assert restored.args == ("strict",) + + def test_safely_serialize_exception_should_cap_cause_chain(self): + """Test the fallback bounds an over-deep __cause__ chain. + + Given: + An un-picklable exception whose __cause__ chain is deeper + than the walk's depth limit. + When: + It is serialized via the exception serialization helper. + Then: + The wire payload should deserialize to the original + top-level type with the reconstructed cause chain capped + at the walk bound, and a SerializationWarning should + report the fidelity loss. + """ + # Arrange + top = ValueError("L0") + top._lock = threading.Lock() # un-picklable __dict__ → primary fails + node = top + for i in range(70): # exceeds the 64-level walk bound + nxt = ValueError(f"L{i + 1}") + node.__cause__ = nxt + node = nxt + + # Act + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + dump = _safely_serialize_exception(wool.__serializer__, top) + restored = wool.__serializer__.loads(dump) + + # Assert + assert type(restored) is ValueError + assert restored.args == ("L0",) + depth = 0 + link = restored.__cause__ + while link is not None: + depth += 1 + link = link.__cause__ + assert depth == 64 # capped at the walk bound, tail dropped + fidelity = [ + w.message for w in caught if isinstance(w.message, SerializationWarning) + ] + assert fidelity + + @given( + exc_type=st.sampled_from( + [ValueError, RuntimeError, KeyError, TypeError, StopAsyncIteration] + ), + message=st.text(), + ) + @settings(max_examples=100, deadline=None) + def test_exception_payload_should_round_trip_through_the_send_path( + self, 
exc_type, message + ): + """Test an exception round-trips through the exception-frame wire path. + + Given: + Any standard exception type constructed with arbitrary + message text. + When: + It is shipped via ``ExceptionResponseFrame.for_send`` and + decoded back with :meth:`Frame.from_protobuf`. + Then: + The restored payload should preserve the exception type and + args — the type-preserving serializer survives the wire + round trip through the public frame path. + """ + # Arrange + exc = exc_type(message) + + # Act + wire = ExceptionResponseFrame.for_send( + exc, wire_chain_manifest=None + ).to_protobuf() + restored = Frame.from_protobuf(wire) + + # Assert + assert isinstance(restored, ExceptionResponseFrame) + assert type(restored.payload) is exc_type + assert restored.payload.args == exc.args + + +class TestFrameHierarchy: + """Tests for the Frame / RequestFrame / ResponseFrame / leaf layout.""" + + def test_request_frame_should_declare_request_wire_type(self): + """Test RequestFrame declares the Request wire envelope. + + Given: + The RequestFrame intermediate class. + When: + ``_wire_type_name`` is inspected. + Then: + It should resolve to the ``Request`` protobuf message class. + """ + # Act + wire_cls = getattr(protocol, RequestFrame._wire_type_name) + + # Assert + assert wire_cls is protocol.Request + + def test_response_frame_should_declare_response_wire_type(self): + """Test ResponseFrame declares the Response wire envelope. + + Given: + The ResponseFrame intermediate class. + When: + ``_wire_type_name`` is inspected. + Then: + It should resolve to the ``Response`` protobuf message class. 
+ """ + # Act + wire_cls = getattr(protocol, ResponseFrame._wire_type_name) + + # Assert + assert wire_cls is protocol.Response + + @pytest.mark.parametrize( + ("leaf", "parent", "field"), + [ + (TaskRequestFrame, RequestFrame, "task"), + (NextRequestFrame, RequestFrame, "next"), + (SendRequestFrame, RequestFrame, "send"), + (ThrowRequestFrame, RequestFrame, "throw"), + (AckResponseFrame, ResponseFrame, "ack"), + (NackResponseFrame, ResponseFrame, "nack"), + (ResultResponseFrame, ResponseFrame, "result"), + (ExceptionResponseFrame, ResponseFrame, "exception"), + ], + ) + def test_leaf_should_declare_payload_field_and_inherit_intermediate( + self, leaf, parent, field + ): + """Test each leaf declares its payload-oneof field and inherits its envelope. + + Given: + A concrete frame leaf class. + When: + Its ``_payload_field`` and parentage are inspected. + Then: + ``_payload_field`` should match the protobuf oneof field + name, and ``parent`` should be one of the two abstract + intermediates (``RequestFrame`` / ``ResponseFrame``). + """ + # Assert + assert leaf._payload_field == field + assert issubclass(leaf, parent) + assert issubclass(leaf, Frame) + + def test_duplicate_payload_field_should_raise(self): + """Test a second leaf claiming an in-use payload field raises. + + Given: + A concrete leaf already registered under a payload-oneof field + (``TaskRequestFrame`` under ``"task"``). + When: + Another concrete subclass claims the same ``field``. + Then: + __init_subclass__ should raise TypeError naming the conflict, + so schema drift fails loudly instead of silently shadowing the + original leaf. + """ + # Act & assert + with pytest.raises(TypeError, match="already registered"): + + class _DuplicateTask(RequestFrame, field="task"): + pass + + def test_fieldless_subclass_should_not_register(self): + """Test a subclass that declares no field is not a dispatch target. 
+ + Given: + A Frame subclass defined without the ``field=`` class keyword, + as the abstract intermediates are. + When: + The class body is evaluated. + Then: + __init_subclass__ should neither raise nor register it in the + leaf dispatch table. + """ + # Act + + class _Intermediate(Frame): + pass + + # Assert + assert _Intermediate not in Frame._frame_by_field.values() + + def test_abstract_payload_hooks_should_raise_not_implemented(self): + """Test the base Frame payload hooks raise when not overridden. + + Given: + The abstract Frame base, whose ``_decode_payload`` / + ``_encode_payload`` are not overridden. + When: + Each default hook is invoked directly. + Then: + Both should raise NotImplementedError naming the hook — a + mistyped or abstract frame fails loudly rather than + silently mis-encoding. + """ + # Act & assert + with pytest.raises(NotImplementedError, match="_decode_payload"): + Frame._decode_payload(object(), serializer=wool.__serializer__) + with pytest.raises(NotImplementedError, match="_encode_payload"): + Frame(payload=None)._encode_payload(serializer=wool.__serializer__) + + +class TestFrameFromProtobuf: + """Tests for :meth:`Frame.from_protobuf` leaf dispatch.""" + + def test_from_protobuf_should_return_next_request_frame(self): + """Test from_protobuf returns a NextRequestFrame for a ``next`` request. + + Given: + A wire :class:`protocol.Request` carrying the ``next`` payload. + When: + :meth:`Frame.from_protobuf` decodes it. + Then: + The returned frame should be a :class:`NextRequestFrame` instance. + """ + # Arrange + wire = protocol.Request(next=protocol.Void()) + + # Act + frame = Frame.from_protobuf(wire) + + # Assert + assert isinstance(frame, NextRequestFrame) + assert frame.payload is None + + def test_from_protobuf_should_return_result_response_frame(self): + """Test from_protobuf returns a ResultResponseFrame for a ``result`` response. + + Given: + A wire :class:`protocol.Response` carrying a ``result`` payload. 
+ When: + :meth:`Frame.from_protobuf` decodes it. + Then: + The returned frame should be a :class:`ResultResponseFrame` + carrying the deserialised payload. + """ + # Arrange + wire = protocol.Response( + result=protocol.Message(dump=cloudpickle.dumps("value")) + ) + + # Act + frame = Frame.from_protobuf(wire) + + # Assert + assert isinstance(frame, ResultResponseFrame) + assert frame.payload == "value" + + def test_empty_request_should_raise_value_error(self): + """Test from_protobuf rejects a wire envelope with no payload. + + Given: + A wire :class:`protocol.Request` with no payload oneof set. + When: + :meth:`Frame.from_protobuf` decodes it. + Then: + It should raise ValueError — an envelope must carry exactly + one payload variant. + """ + # Act & assert + with pytest.raises(ValueError, match="no payload"): + Frame.from_protobuf(protocol.Request()) + + def test_from_protobuf_should_return_ack_response_frame(self): + """Test from_protobuf returns an AckResponseFrame for an ``ack`` response. + + Given: + A wire :class:`protocol.Response` carrying the ``ack`` payload. + When: + :meth:`Frame.from_protobuf` decodes it. + Then: + The returned frame should be an :class:`AckResponseFrame` + with no payload — Ack is a pure boundary signal. + """ + # Arrange + wire = protocol.Response(ack=protocol.Ack()) + + # Act + frame = Frame.from_protobuf(wire) + + # Assert + assert isinstance(frame, AckResponseFrame) + assert frame.payload is None + + def test_from_protobuf_should_return_nack_response_frame(self): + """Test from_protobuf returns a NackResponseFrame carrying the exception. + + Given: + A wire :class:`protocol.Response` carrying a ``nack`` payload + whose serialized message is a rejection exception. + When: + :meth:`Frame.from_protobuf` decodes it. + Then: + The returned frame should be a :class:`NackResponseFrame` + carrying the deserialized exception. 
+ """ + # Arrange + wire = protocol.Response( + nack=protocol.Nack( + exception=protocol.Message( + dump=wool.__serializer__.dumps(ValueError("rejected")) + ) + ) + ) + + # Act + frame = Frame.from_protobuf(wire) + + # Assert + assert isinstance(frame, NackResponseFrame) + assert isinstance(frame.payload, ValueError) + assert frame.payload.args == ("rejected",) + + def test_from_protobuf_should_leave_chain_manifest_none_when_no_wire_context(self): + """Test from_protobuf leaves chain_manifest None when no wire context. + + Given: + A wire :class:`protocol.Response` carrying a result payload + but no ``context`` field. + When: + :meth:`Frame.from_protobuf` decodes it. + Then: + The frame's ``chain_manifest`` should be ``None`` — there is + no chain state to mount. + """ + # Arrange + wire = protocol.Response(result=protocol.Message(dump=cloudpickle.dumps("v"))) + + # Act + frame = Frame.from_protobuf(wire) + + # Assert + assert frame.chain_manifest is None + + def test_from_protobuf_should_capture_error_when_strict_decode_fails(self): + """Test from_protobuf captures a strict-mode decode failure on the frame. + + Given: + A wire :class:`protocol.Response` whose context carries a + duplicated ``(namespace, name)`` key, decoded under strict + mode (SerializationWarning promoted to error). + When: + :meth:`Frame.from_protobuf` decodes it. + Then: + It should capture the :class:`ChainSerializationError` as the + frame's ``chain_manifest`` value rather than raising — the + failure is deferred to mount, and the aggregated warning + carries the duplicated var_key. 
+ """ + # Arrange + from wool.runtime.context.exceptions import ChainSerializationError + + key = ("dup_ns", "dup_var") + + # Act + with warnings.catch_warnings(): + warnings.simplefilter("error", SerializationWarning) + frame = Frame.from_protobuf(_duplicate_key_response(key)) + + # Assert + assert isinstance(frame.chain_manifest, ChainSerializationError) + assert frame.chain_manifest.warnings[0].var_key == key + + @given( + payload=st.one_of(st.text(), st.integers(), st.binary(), st.lists(st.integers())) + ) + @settings(max_examples=100, deadline=None) + def test_result_payload_should_round_trip_through_the_send_path(self, payload): + """Test a result payload round-trips through the result-frame wire path. + + Given: + Any picklable result payload. + When: + It is shipped via ``ResultResponseFrame.for_send`` and + decoded back with :meth:`Frame.from_protobuf`. + Then: + The restored frame should carry an equal payload. + """ + # Arrange & act + wire = ResultResponseFrame.for_send( + payload, wire_chain_manifest=None + ).to_protobuf() + restored = Frame.from_protobuf(wire) + + # Assert + assert isinstance(restored, ResultResponseFrame) + assert restored.payload == payload + + +class TestFrameToProtobuf: + """Tests for :meth:`Frame.to_protobuf` wire-envelope emission.""" + + def test_to_protobuf_should_copy_captured_wire_chain_manifest_into_envelope(self): + """Test to_protobuf copies a captured chain manifest into the envelope. + + Given: + A response frame built via ``for_send`` with an explicit + non-None chain manifest. + When: + :meth:`Frame.to_protobuf` encodes it. + Then: + The emitted wire envelope should carry the ``context`` field + — the armed chain rides out on the frame. 
+ """ + # Arrange + wire_chain_manifest = protocol.ChainManifest(id=uuid4().hex) + frame = ResultResponseFrame.for_send( + "value", wire_chain_manifest=wire_chain_manifest + ) + + # Act + wire = frame.to_protobuf() + + # Assert + assert wire.HasField("context") + assert wire.context.id == wire_chain_manifest.id + + def test_for_send_should_default_serializer_when_omitted(self): + """Test for_send falls back to the package serializer when none is given. + + Given: + A boundary frame (AckResponseFrame) that does not override + for_send. + When: + ``for_send`` is called with no serializer argument. + Then: + The frame should encode and decode intact — the omitted + serializer falls back to the package default, so the send + path still round-trips. + """ + # Act + wire = AckResponseFrame.for_send().to_protobuf() + restored = Frame.from_protobuf(wire) + + # Assert + assert isinstance(restored, AckResponseFrame) + + +class TestChainsDecodeErrorOntoPayload: + """Tests for the mixin's chain-walk into the payload exception.""" + + def _decode_failure(self): + """Build a synthetic strict-mode decode failure — the union's error arm. + + ``Frame.chain_manifest`` is typed + ``ChainManifest | ChainSerializationError | None``, so a failed + decode is represented by the error itself; the chaining tests pass + this directly as ``chain_manifest=``. + """ + from wool.runtime.context.exceptions import ChainSerializationError + from wool.runtime.context.exceptions import SerializationWarning + + return ChainSerializationError(SerializationWarning("synthetic decode failure")) + + def test_exception_frame_should_chain_decode_error_onto_payload(self): + """Test ExceptionResponseFrame.mount chains decode_error onto the payload. + + Given: + An :class:`ExceptionResponseFrame` whose chain manifest + carries a strict-mode :class:`ChainSerializationError`. + When: + ``frame.mount()`` runs. 
+ Then: + The decode error should land at the end of the payload + exception's ``__context__`` chain rather than propagating + out — the routine exception remains the primary signal. + """ + # Arrange + payload = RuntimeError("worker failed") + frame = ExceptionResponseFrame( + payload=payload, chain_manifest=self._decode_failure() + ) + + # Act + with scoped_context(): + frame.mount() + + # Assert + chained = payload.__context__ + from wool.runtime.context.exceptions import ChainSerializationError as _CSE + + assert isinstance(chained, _CSE) + + def test_nack_frame_should_chain_decode_error_onto_payload(self): + """Test NackResponseFrame.mount chains decode_error onto payload.__context__. + + Given: + A :class:`NackResponseFrame` whose context manifest carries + a strict-mode :class:`ChainSerializationError`. + When: + ``frame.mount()`` runs. + Then: + The decode error should land at the end of the rejection + exception's ``__context__`` chain. + """ + # Arrange + rejection = RuntimeError("rejected pre-routine") + frame = NackResponseFrame( + payload=rejection, chain_manifest=self._decode_failure() + ) + + # Act + with scoped_context(): + frame.mount() + + # Assert + chained = rejection.__context__ + from wool.runtime.context.exceptions import ChainSerializationError as _CSE + + assert isinstance(chained, _CSE) + + def test_throw_frame_should_chain_decode_error_onto_payload(self): + """Test ThrowRequestFrame.mount chains decode_error onto payload.__context__. + + Given: + A :class:`ThrowRequestFrame` whose context manifest carries + a strict-mode :class:`ChainSerializationError`. + When: + ``frame.mount()`` runs. + Then: + The decode error should land at the end of the thrown + exception's ``__context__`` chain so the worker-side + routine sees both signals on the eventual ``athrow``. 
+ """ + # Arrange + thrown = ValueError("caller-side throw payload") + frame = ThrowRequestFrame(payload=thrown, chain_manifest=self._decode_failure()) + + # Act + with scoped_context(): + frame.mount() + + # Assert + chained = thrown.__context__ + from wool.runtime.context.exceptions import ChainSerializationError as _CSE + + assert isinstance(chained, _CSE) + + def test_non_exception_leaf_should_raise_decode_error(self): + """Test ResultResponseFrame.mount raises decode_error (no mixin). + + Given: + A :class:`ResultResponseFrame` (no chain-walk mixin) whose + context manifest carries a strict-mode + :class:`ChainSerializationError`. + When: + ``frame.mount()`` runs. + Then: + The decode error should raise raw — non-exception-bearing + leaves don't chain on payload. + """ + # Arrange + from wool.runtime.context.exceptions import ChainSerializationError + + frame = ResultResponseFrame( + payload="value", chain_manifest=self._decode_failure() + ) + + # Act & assert + with scoped_context(): + with pytest.raises(ChainSerializationError): + frame.mount() + + def test_mount_should_append_decode_error_past_existing_context_chain(self): + """Test mount appends the decode error past a pre-existing __context__. + + Given: + An :class:`ExceptionResponseFrame` whose payload exception + already has a ``__context__`` chain, plus a strict-mode + decode error. + When: + ``frame.mount()`` runs. + Then: + The decode error should land at the *end* of the existing + chain — the pre-existing context is preserved, and the + decode error chains beyond it. 
+ """ + # Arrange + from wool.runtime.context.exceptions import ChainSerializationError as _CSE + + inner = ValueError("pre-existing context") + payload = RuntimeError("worker failed") + payload.__context__ = inner + frame = ExceptionResponseFrame( + payload=payload, chain_manifest=self._decode_failure() + ) + + # Act + with scoped_context(): + frame.mount() + + # Assert + assert payload.__context__ is inner # preserved, not overwritten + assert isinstance(inner.__context__, _CSE) # decode error chained past it diff --git a/wool/tests/runtime/worker/test_service.py b/wool/tests/runtime/worker/test_service.py index 9a2bf787..f9c56295 100644 --- a/wool/tests/runtime/worker/test_service.py +++ b/wool/tests/runtime/worker/test_service.py @@ -1,5 +1,4 @@ import asyncio -import pickle import threading from contextlib import asynccontextmanager from uuid import uuid4 @@ -19,7 +18,6 @@ from wool import protocol from wool.protocol import WorkerStub from wool.protocol import add_WorkerServicer_to_server -from wool.runtime.context import install_task_factory from wool.runtime.routine.task import Task from wool.runtime.routine.task import WorkerProxyLike from wool.runtime.worker.interceptor import VersionInterceptor @@ -29,6 +27,17 @@ from .conftest import PicklableMock +def make_task(callable, *, proxy_id="test-proxy-id"): + """Build a `wool.Task` wrapping *callable* with a throwaway worker proxy.""" + return Task( + id=uuid4(), + callable=callable, + args=(), + kwargs={}, + proxy=PicklableMock(spec=WorkerProxyLike, id=proxy_id), + ) + + @pytest.fixture(scope="function") def grpc_interceptors(): return [VersionInterceptor()] @@ -59,7 +68,7 @@ def grpc_stub_cls(): # # Assumes serial test execution within this module: each test that # uses the event takes responsibility for setting it to a fresh -# :class:`threading.Event` in its arrange phase and resetting it to +# `threading.Event` in its arrange phase and resetting it to # ``None`` in its finally clause. 
If the test suite is ever run with # parallel collection within this module, these globals must be # re-keyed (e.g., a dict keyed by ``task.id``) to avoid races. @@ -79,35 +88,37 @@ async def _controllable_task(): return "task_completed" -# Cross-loop side-channel for A1 regression: routine on the worker -# loop signals via this threading.Event when it observes -# CancelledError; the test asserts on it from the main loop. -_a1_cancellation_observed: threading.Event | None = None +# Cross-loop side-channel for the mid-stream cancellation regression: +# routine on the worker loop signals via this threading.Event when it +# observes CancelledError; the test asserts on it from the main loop. +_midstream_cancellation_observed: threading.Event | None = None -# Side-channel for the A1 regression test to confirm the worker -# routine is actually running before the test cancels the stream. +# Side-channel for the mid-stream cancellation regression test +# to confirm the worker routine is actually running before the +# test cancels the stream. # The dispatch ``ack`` only confirms the handler reached its # ``yield ack`` — the worker task is scheduled lazily on the # handler's first ``async for`` iteration. Cancelling before the -# routine starts races :meth:`DispatchSession._schedule_worker`, +# routine starts races `DispatchSession._schedule_worker`, # which short-circuits on ``_cancelled`` and never dispatches the # routine, leaving nothing for the cancellation to interrupt. The # routine sets this event as its first statement; the test waits # for it before cancelling. -_a1_routine_started: threading.Event | None = None +_midstream_routine_started: threading.Event | None = None # Cross-loop side-channel for the stop+cancel regression tests -# (``test_stop_and_cancel`` and ``test_stop_and_cancel_streaming_routine``). 
-# The routine on the worker loop signals via this :class:`threading.Event` -# when it observes :class:`asyncio.CancelledError`; the test asserts on it -# from the main loop. Separate from ``_a1_cancellation_observed`` so the +# (``test_stop_should_cancel_active_coroutine_routine`` and +# ``test_stop_should_cancel_active_streaming_routine``). +# The routine on the worker loop signals via this `threading.Event` +# when it observes `asyncio.CancelledError`; the test asserts on it +# from the main loop. Separate from ``_midstream_cancellation_observed`` so the # tests do not interfere when running concurrently or in arbitrary order. _stop_cancellation_observed: threading.Event | None = None # Side-channel used by the stop+cancel regression tests to confirm # the routine has actually started running before the test sends # ``stop``. Without this barrier the test races -# :meth:`DispatchSession.__aiter__`'s lazy worker scheduling: on +# `DispatchSession.__aiter__`'s lazy worker scheduling: on # slower Python versions/runtimes, ``stop`` can land before the # worker task is created, so ``session.cancel()`` has no # ``_worker_task`` to cancel and the routine never observes @@ -117,9 +128,9 @@ async def _controllable_task(): _stop_routine_started: threading.Event | None = None # Cross-loop side-channel for the issue #202 worker-loop drain test -# (``test_stop_with_orphaned_cleanup_chain``). The second-generation -# cleanup task runs on the worker loop and sets this -# :class:`threading.Event`; the test asserts on it from the main loop. +# (``test_stop_should_drain_every_generation_of_orphaned_cleanup_chain``). +# The second-generation cleanup task runs on the worker loop and sets this +# `threading.Event`; the test asserts on it from the main loop. # The probe routine schedules a two-generation cleanup chain whose # second generation is observed only when worker-loop teardown drains # every generation, not just the first. 
@@ -142,7 +153,7 @@ async def _drain_probe_routine(): async def _drain_probe_first_gen(): - """First-generation orphan scheduled by :func:`_drain_probe_routine`. + """First-generation orphan scheduled by `_drain_probe_routine`. Awaits indefinitely until worker-loop teardown cancels it, then schedules the second generation from its own ``finally`` clause. @@ -154,9 +165,9 @@ async def _drain_probe_first_gen(): async def _drain_probe_second_gen(): - """Second-generation orphan scheduled by :func:`_drain_probe_first_gen`. + """Second-generation orphan scheduled by `_drain_probe_first_gen`. - Sets :data:`_drain_cleanup_observed` from its ``finally`` clause. + Sets `_drain_cleanup_observed` from its ``finally`` clause. The event is set only when worker-loop teardown drains every generation, not just the first. """ @@ -168,9 +179,9 @@ async def _drain_probe_second_gen(): class _AttributeRejectingRoutineError(Exception): - """Module-level exception class for the A4 regression test. + """Module-level exception class for the attribute-rejection regression test. - Overrides ``__setattr__`` to raise :class:`AttributeError` for + Overrides ``__setattr__`` to raise `AttributeError` for arbitrary attribute writes — modeling exception types whose storage layout (e.g., ``__slots__`` derived from a slotted parent, or C-extension types with custom attribute machinery) @@ -181,8 +192,7 @@ class _AttributeRejectingRoutineError(Exception): The override forwards the standard ``args``/``__cause__``/ ``__context__``/``__traceback__``/``__notes__`` slots so ``BaseException`` machinery and PEP 678 ``add_note`` continue - to work — only arbitrary attribute writes (like - ``__wool_context_warnings__``) raise. + to work — only arbitrary attribute writes raise. """ _ALLOWED = frozenset( @@ -206,23 +216,23 @@ def __setattr__(self, name, value): ) -async def _a1_long_routine(): - """Module-level routine for the A1 regression test. 
+async def _midstream_long_routine(): + """Module-level routine for the mid-stream cancellation regression test. Defined at module level so cloudpickle can serialize the - callable for dispatch. Signals :data:`_a1_routine_started` as + callable for dispatch. Signals `_midstream_routine_started` as its first statement so the test can wait for the routine to be running before cancelling. Sleeps long enough that the test - will have given up; signals :data:`_a1_cancellation_observed` - if interrupted by :class:`asyncio.CancelledError`. + will have given up; signals `_midstream_cancellation_observed` + if interrupted by `asyncio.CancelledError`. """ - if _a1_routine_started is not None: - _a1_routine_started.set() + if _midstream_routine_started is not None: + _midstream_routine_started.set() try: await asyncio.sleep(30) except asyncio.CancelledError: - if _a1_cancellation_observed is not None: - _a1_cancellation_observed.set() + if _midstream_cancellation_observed is not None: + _midstream_cancellation_observed.set() raise return "should_not_complete" @@ -230,12 +240,12 @@ async def _a1_long_routine(): async def _stop_long_coroutine(): """Module-level coroutine for the stop+cancel regression test. - Signals :data:`_stop_routine_started` so the test can wait for + Signals `_stop_routine_started` so the test can wait for the routine to actually start before sending ``stop`` (avoids - racing :meth:`DispatchSession.__aiter__`'s lazy worker + racing `DispatchSession.__aiter__`'s lazy worker scheduling), then sleeps long enough that the test will have - given up; signals :data:`_stop_cancellation_observed` if - interrupted by :class:`asyncio.CancelledError`. Defined at + given up; signals `_stop_cancellation_observed` if + interrupted by `asyncio.CancelledError`. Defined at module level so cloudpickle can serialize the callable for dispatch. 
""" @@ -254,19 +264,19 @@ async def _stop_streaming_routine(): """Module-level async generator for the stop+cancel streaming regression test. - Signals :data:`_stop_routine_started`, yields one value, then + Signals `_stop_routine_started`, yields one value, then sleeps long enough that the test will have given up; signals - :data:`_stop_cancellation_observed` if interrupted by - :class:`asyncio.CancelledError` or :class:`GeneratorExit`. + `_stop_cancellation_observed` if interrupted by + `asyncio.CancelledError` or `GeneratorExit`. Defined at module level so cloudpickle can serialize the callable for dispatch. Both exception types signal cancellation from the worker side: operator-preempt cancels the worker driver task; depending on where the routine is suspended (mid-await vs at a yield) the - teardown path either propagates :class:`asyncio.CancelledError` - through the await or injects :class:`GeneratorExit` via - :func:`routine_scope`'s ``aclose``. Either is a valid + teardown path either propagates `asyncio.CancelledError` + through the await or injects `GeneratorExit` via + `routine_scope`'s ``aclose``. Either is a valid observation of cancellation reaching the routine. """ if _stop_routine_started is not None: @@ -298,15 +308,7 @@ async def service_fixture(mocker: MockerFixture, grpc_aio_stub): service = WorkerService() _control_event = threading.Event() - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=_controllable_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(_controllable_task) request = protocol.Request(task=wool_task.to_protobuf()) @@ -334,13 +336,13 @@ async def service_fixture(mocker: MockerFixture, grpc_aio_stub): class TestWorkerService: - def test___init___with_defaults(self): - """Test :class:`WorkerService` initialization. + def test___init___should_expose_unset_lifecycle_events(self): + """Test `WorkerService` initialization. 
Given: No preconditions When: - :class:`WorkerService` is instantiated + `WorkerService` is instantiated Then: It should initialize successfully and expose its stopping and stopped events """ @@ -354,14 +356,14 @@ def test___init___with_defaults(self): assert not service.stopped.is_set() @pytest.mark.asyncio - async def test_dispatch_task_that_returns( + async def test_dispatch_should_return_result_when_task_returns( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch successfully executes task that returns + """Test `WorkerService` dispatch successfully executes task that returns a value. Given: - A gRPC :class:`WorkerService` that is not stopping or stopped + A gRPC `WorkerService` that is not stopping or stopped When: Dispatch RPC is called with a task that returns a value Then: @@ -372,15 +374,7 @@ async def test_dispatch_task_that_returns( async def sample_task(): return "test_result" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) @@ -398,14 +392,14 @@ async def sample_task(): assert cloudpickle.loads(reponse.result.dump) == "test_result" @pytest.mark.asyncio - async def test_dispatch_task_that_raises( + async def test_dispatch_should_return_exception_when_task_raises( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch successfully executes task + """Test `WorkerService` dispatch successfully executes task that raises an exception. 
Given: - A gRPC :class:`WorkerService` that is not stopping or stopped + A gRPC `WorkerService` that is not stopping or stopped When: Dispatch RPC is called with a task that raises an exception Then: @@ -416,15 +410,7 @@ async def test_dispatch_task_that_raises( async def failing_task(): raise ValueError("test_exception") - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=failing_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(failing_task) request = protocol.Request(task=wool_task.to_protobuf()) @@ -439,23 +425,26 @@ async def failing_task(): ack, response = responses assert ack.HasField("ack") assert response.HasField("exception") - assert response.HasField("context") + # Under lazy-wire-frame semantics, an unarmed worker (this + # task never touched a wool.ContextVar) omits the optional + # context field — there's nothing to propagate. + assert not response.HasField("context") exception = cloudpickle.loads(response.exception.dump) assert isinstance(exception, ValueError) assert str(exception) == "test_exception" @pytest.mark.asyncio - async def test_dispatch_with_corrupt_context_under_strict_filter( + async def test_dispatch_should_nack_when_corrupt_context_under_strict_filter( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch under strict mode for - :class:`ContextDecodeWarning` when the caller's context + """Test `WorkerService` dispatch under strict mode for + `SerializationWarning` when the caller's context carries a corrupt var payload. 
Given: A dispatch Request whose ``context.vars`` map carries a corrupt byte payload, and the worker-side warning filter - promotes :class:`ContextDecodeWarning` to an exception + promotes `SerializationWarning` to an exception (modeling ``warnings.filterwarnings("error", category=...)`` set via ``PYTHONWARNINGS`` or programmatic config in the @@ -463,14 +452,14 @@ async def test_dispatch_with_corrupt_context_under_strict_filter( When: The dispatch RPC is invoked with that request Then: - It should reply with exactly one :class:`Nack` response + It should reply with exactly one `Nack` response (no preceding Ack) whose ``exception`` field decodes to - a :class:`BaseExceptionGroup` carrying the promoted - :class:`ContextDecodeWarning` as its sole peer — so - worker-side strict mode preserves the same uniform - group shape that caller-side strict mode produces, and - the leaf class identity remains addressable via - ``except*`` regardless of peer cardinality. + a typed `wool.ChainSerializationError` aggregating the + promoted warnings on ``.warnings``, each carrying the + corrupt variable's key and the decode direction across + the wire. The caller's existing + ``except wool.ChainSerializationError`` clause matches without + needing ``except*`` migration. 
""" import warnings as _warnings @@ -478,17 +467,10 @@ async def test_dispatch_with_corrupt_context_under_strict_filter( async def sample_task(): return "should_not_execute" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) namespace = f"strict_corrupt_{uuid4().hex}" var: wool.ContextVar[str] = wool.ContextVar("x", namespace=namespace) - context_pb = protocol.Context(id=uuid4().hex) + context_pb = protocol.ChainManifest(id=uuid4().hex) context_pb.vars.add( namespace=var.namespace, name=var.name, @@ -501,7 +483,7 @@ async def sample_task(): # Act with _warnings.catch_warnings(): - _warnings.simplefilter("error", category=wool.ContextDecodeWarning) + _warnings.simplefilter("error", category=wool.SerializationWarning) async with grpc_aio_stub() as stub: stream = stub.dispatch() await stream.write(request) @@ -514,29 +496,97 @@ async def sample_task(): assert nack.HasField("nack") assert nack.nack.HasField("exception") raised = cloudpickle.loads(nack.nack.exception.dump) - assert isinstance(raised, BaseExceptionGroup) - assert len(raised.exceptions) == 1 - peer = raised.exceptions[0] - assert isinstance(peer, wool.ContextDecodeWarning) - assert "Failed to deserialize" in str(peer) + assert isinstance(raised, wool.ChainSerializationError) + assert "failed to serialize across the wire" in str(raised) + assert "Failed to deserialize" in str(raised.warnings[0]) + # Warnings aggregated on .warnings, structured fields intact + # after the wire round trip. 
+ assert len(raised.warnings) == 1 + assert isinstance(raised.warnings[0], wool.SerializationWarning) + assert raised.warnings[0].var_key == (var.namespace, var.name) + assert raised.warnings[0].direction == "decode" @pytest.mark.asyncio - async def test_dispatch_with_malformed_task_id( + async def test_dispatch_should_nack_when_multiple_corrupt_vars_under_strict_filter( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch when the wire-shipped + """Test pre-Ack strict-mode attaches multiple peers symmetrically. + + Given: + A dispatch Request whose ``context.vars`` carries two + corrupt var payloads, and the worker-side warning filter + promotes `SerializationWarning` to an exception. + When: + The dispatch RPC is invoked with that request. + Then: + It should reply with exactly one `Nack` whose + ``exception`` decodes to a typed + `wool.ChainSerializationError` aggregating every + promoted warning on ``.warnings``, so callers handle + multi-failure decoding with one ``except + ChainSerializationError`` clause. 
+ """ + import warnings as _warnings + + # Arrange + async def sample_task(): + return "should_not_execute" + + wool_task = make_task(sample_task) + namespace = f"strict_multi_{uuid4().hex}" + var_a: wool.ContextVar[str] = wool.ContextVar("a", namespace=namespace) + var_b: wool.ContextVar[str] = wool.ContextVar("b", namespace=namespace) + context_pb = protocol.ChainManifest(id=uuid4().hex) + context_pb.vars.add( + namespace=var_a.namespace, + name=var_a.name, + value=b"\x00garbage-a\x00", + ) + context_pb.vars.add( + namespace=var_b.namespace, + name=var_b.name, + value=b"\x00garbage-b\x00", + ) + request = protocol.Request( + task=wool_task.to_protobuf(), + context=context_pb, + ) + + # Act + with _warnings.catch_warnings(): + _warnings.simplefilter("error", category=wool.SerializationWarning) + async with grpc_aio_stub() as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + responses = [r async for r in stream] + + # Assert + assert len(responses) == 1 + nack = responses[0] + assert nack.HasField("nack") + raised = cloudpickle.loads(nack.nack.exception.dump) + assert isinstance(raised, wool.ChainSerializationError) + assert len(raised.warnings) == 2 + assert all(isinstance(w, wool.SerializationWarning) for w in raised.warnings) + + @pytest.mark.asyncio + async def test_dispatch_should_nack_with_value_error_when_task_id_malformed( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test `WorkerService` dispatch when the wire-shipped task id cannot be parsed as a UUID. 
Given: A dispatch Request whose ``task.id`` field is a non-hex / non-UUID string, so ``UUID(request.task.id)`` - raises :class:`ValueError` inside the parse phase + raises `ValueError` inside the parse phase When: The dispatch RPC is invoked with that request Then: - It should reply with exactly one :class:`Nack` response + It should reply with exactly one `Nack` response (no preceding Ack) whose ``exception`` field decodes to - the original :class:`ValueError`, surfacing the actual + the original `ValueError`, surfacing the actual parse-failure class to the caller rather than an opaque gRPC error. """ @@ -545,14 +595,7 @@ async def test_dispatch_with_malformed_task_id( async def sample_task(): return "should_not_execute" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) task_pb = wool_task.to_protobuf() task_pb.id = "not-a-valid-uuid" request = protocol.Request(task=task_pb) @@ -573,20 +616,20 @@ async def sample_task(): assert isinstance(raised, ValueError) @pytest.mark.asyncio - async def test_dispatch_with_corrupt_task_callable( + async def test_dispatch_should_nack_with_unpickling_error_when_task_callable_corrupt( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch when the wire-shipped + """Test `WorkerService` dispatch when the wire-shipped task callable bytes cannot be deserialized by cloudpickle. 
Given: A dispatch Request whose ``task.callable`` field carries - corrupt bytes, so :meth:`Task.from_protobuf` raises + corrupt bytes, so `Task.from_protobuf` raises during cloudpickle.loads inside the parse phase When: The dispatch RPC is invoked with that request Then: - It should reply with exactly one :class:`Nack` response + It should reply with exactly one `Nack` response (no preceding Ack) whose ``exception`` field decodes to the underlying cloudpickle / unpickling error, surfacing the actual parse-failure class to the caller. @@ -596,14 +639,7 @@ async def test_dispatch_with_corrupt_task_callable( async def sample_task(): return "should_not_execute" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) task_pb = wool_task.to_protobuf() task_pb.callable = b"\x00not a valid pickle stream\x00" request = protocol.Request(task=task_pb) @@ -624,13 +660,13 @@ async def sample_task(): assert isinstance(raised, Exception) @pytest.mark.asyncio - async def test_dispatch_with_corrupt_context_var_value( + async def test_dispatch_should_run_routine_and_warn_when_context_var_value_corrupt( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch runs the routine when a + """Test `WorkerService` dispatch runs the routine when a caller-shipped ``request.context.vars`` entry cannot be deserialized, falling back to a fresh empty context and - emitting a :class:`ContextDecodeWarning`. + emitting a `SerializationWarning`. 
Given: A dispatch Request whose ``context.vars`` map carries a @@ -641,7 +677,7 @@ async def test_dispatch_with_corrupt_context_var_value( The dispatch RPC is invoked with that request Then: The routine still runs and returns its value, a - :class:`ContextDecodeWarning` is emitted on the worker, + `SerializationWarning` is emitted on the worker, and the response is delivered normally — context propagation is ancillary state and a decode failure here does not preempt the primary signal @@ -651,17 +687,10 @@ async def test_dispatch_with_corrupt_context_var_value( async def sample_task(): return "routine_ran" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) namespace = f"corrupt_val_{uuid4().hex}" var: wool.ContextVar[str] = wool.ContextVar("x", namespace=namespace) - context_pb = protocol.Context(id=uuid4().hex) + context_pb = protocol.ChainManifest(id=uuid4().hex) context_pb.vars.add( namespace=var.namespace, name=var.name, @@ -673,63 +702,7 @@ async def sample_task(): ) # Act - with pytest.warns(wool.ContextDecodeWarning, match="Failed to deserialize"): - async with grpc_aio_stub() as stub: - stream = stub.dispatch() - await stream.write(request) - await stream.done_writing() - responses = [r async for r in stream] - - # Assert - result_responses = [r for r in responses if r.HasField("result")] - assert len(result_responses) == 1 - assert cloudpickle.loads(result_responses[0].result.dump) == "routine_ran" - - @pytest.mark.asyncio - async def test_dispatch_with_malformed_context_id( - self, grpc_aio_stub, mock_worker_proxy_cache - ): - """Test :class:`WorkerService` dispatch runs the routine when the - caller's ``request.context.id`` is not a valid hex UUID, - falling back to a fresh empty context and emitting a - :class:`ContextDecodeWarning`. 
- - Given: - A dispatch Request whose ``context.id`` field is a - non-hex string (e.g., ``"not-a-uuid"``) - When: - The dispatch RPC is invoked with that request - Then: - The routine still runs and returns its value, a - :class:`ContextDecodeWarning` is emitted, and the - response is delivered normally — malformed wire context - is treated as ancillary state lost, not a request - rejection - """ - - # Arrange - async def sample_task(): - return "routine_ran" - - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) - # carries_state requires a non-empty vars list for - # from_protobuf to even attempt parsing the id, so seed a - # ContextVar entry carrying a consumed-token id alongside - # the malformed Context id. - bad_ctx = protocol.Context(id="not-a-uuid") - bad_ctx.vars.add(namespace="", name="", consumed_tokens=[uuid4().hex]) - task_pb = wool_task.to_protobuf() - request = protocol.Request(task=task_pb, context=bad_ctx) - - # Act - with pytest.warns(wool.ContextDecodeWarning): + with pytest.warns(wool.SerializationWarning, match="Failed to deserialize"): async with grpc_aio_stub() as stub: stream = stub.dispatch() await stream.write(request) @@ -742,12 +715,12 @@ async def sample_task(): assert cloudpickle.loads(result_responses[0].result.dump) == "routine_ran" @pytest.mark.asyncio - async def test_dispatch_streaming_with_mid_stream_corrupt_context( + async def test_dispatch_should_warn_when_mid_stream_context_corrupt( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch continues an async-generator + """Test `WorkerService` dispatch continues an async-generator iteration when a mid-stream frame carries a corrupt context, - emitting a :class:`ContextDecodeWarning` instead of failing + emitting a `SerializationWarning` instead of failing the dispatch. 
Given: @@ -760,7 +733,7 @@ async def test_dispatch_streaming_with_mid_stream_corrupt_context( stream Then: The generator's second yield is delivered as a normal - result frame and a :class:`ContextDecodeWarning` is + result frame and a `SerializationWarning` is emitted on the worker — the corrupt mid-stream context is treated as ancillary state lost rather than a terminal failure @@ -771,20 +744,13 @@ async def streamer(): for i in range(5): yield f"value_{i}" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=streamer, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(streamer) first_request = protocol.Request(task=wool_task.to_protobuf()) good_next = protocol.Request( next=protocol.Void(), - context=protocol.Context(id=uuid4().hex), + context=protocol.ChainManifest(id=uuid4().hex), ) - bad_ctx = protocol.Context(id=uuid4().hex) + bad_ctx = protocol.ChainManifest(id=uuid4().hex) bad_ctx.vars.add( namespace="test", name="corrupt_key", @@ -810,7 +776,7 @@ async def drive(): await stream.done_writing() return second - with pytest.warns(wool.ContextDecodeWarning): + with pytest.warns(wool.SerializationWarning): second = await asyncio.wait_for(drive(), timeout=5.0) # Assert @@ -818,28 +784,28 @@ async def drive(): assert cloudpickle.loads(second.result.dump) == "value_1" @pytest.mark.asyncio - async def test_dispatch_streaming_with_unpicklable_worker_mutation( + async def test_dispatch_should_warn_when_streaming_worker_mutation_unpicklable( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch delivers the next yield - when a worker-side snapshot serialization fails between - iterations, emitting a :class:`ContextDecodeWarning` instead + """Test `WorkerService` dispatch delivers the next yield + when a worker-side context serialization fails between + iterations, emitting a `SerializationWarning` instead of failing the dispatch. 
Given: An async-generator routine that, between yields, sets a - :class:`wool.ContextVar` to a value whose ``__reduce__`` - raises — the wool back-prop snapshot - (``Context.to_protobuf``) on the next iteration cannot + `wool.ContextVar` to a value whose ``__reduce__`` + raises — the wool back-prop encode + (`ChainManifest.to_protobuf`) on the next iteration cannot serialize the var When: The caller drives the generator past the unpicklable assignment Then: The next yield is delivered as a normal result frame - with an empty wire context, and a - :class:`ContextDecodeWarning` is emitted on the worker — - the snapshot failure is ancillary state and does not + with an empty chain manifest, and a + `SerializationWarning` is emitted on the worker — + the chain-manifest failure is ancillary state and does not preempt the routine's primary signal """ # Arrange @@ -855,14 +821,7 @@ async def streamer(): var.set(_Unpicklable()) yield "second" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=streamer, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(streamer) first_request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -884,7 +843,7 @@ async def drive(): await stream.done_writing() return second - with pytest.warns(wool.ContextDecodeWarning): + with pytest.warns(wool.SerializationWarning): second = await asyncio.wait_for(drive(), timeout=5.0) # Assert @@ -892,27 +851,27 @@ async def drive(): assert cloudpickle.loads(second.result.dump) == "second" @pytest.mark.asyncio - async def test_dispatch_with_unpicklable_worker_mutation( + async def test_dispatch_should_warn_when_coroutine_worker_mutation_unpicklable( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch delivers the routine's + """Test `WorkerService` dispatch delivers the routine's return value on the coroutine path when a worker-side 
- snapshot serialization fails, emitting a - :class:`ContextDecodeWarning` instead of failing the + context serialization fails, emitting a + `SerializationWarning` instead of failing the dispatch. Given: - A coroutine routine that sets a :class:`wool.ContextVar` + A coroutine routine that sets a `wool.ContextVar` to a value whose ``__reduce__`` raises before - returning — the wool back-prop snapshot - (``Context.to_protobuf``) in the done-callback cannot + returning — the wool back-prop encode + (`ChainManifest.to_protobuf`) in the done-callback cannot serialize the post-run state When: The caller dispatches the routine Then: The routine's return value is delivered as a normal - result frame with an empty wire context, and a - :class:`ContextDecodeWarning` is emitted on the worker + result frame with an empty chain manifest, and a + `SerializationWarning` is emitted on the worker """ # Arrange namespace = f"unpicklable_coro_{uuid4().hex}" @@ -926,14 +885,7 @@ async def coroutine(): var.set(_Unpicklable()) return "ok" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=coroutine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(coroutine) request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -944,7 +896,7 @@ async def drive(): await stream.done_writing() return [r async for r in stream] - with pytest.warns(wool.ContextDecodeWarning): + with pytest.warns(wool.SerializationWarning): responses = await asyncio.wait_for(drive(), timeout=5.0) # Assert @@ -953,35 +905,81 @@ async def drive(): assert cloudpickle.loads(result_responses[0].result.dump) == "ok" @pytest.mark.asyncio - async def test_dispatch_with_routine_raise_and_unpicklable_mutation( + async def test_dispatch_should_ship_serialization_error_when_result_unpicklable( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test `WorkerService` dispatch ships a typed + `wool.SerializationError` when the routine 
succeeds + but its return value cannot be encoded for the wire. + + Given: + A coroutine routine that succeeds but returns an + unpicklable value (a `threading.Lock`), with no + strict warning filter installed. + When: + The dispatch RPC is invoked and the response stream is + consumed. + Then: + The terminal exception frame should decode to a + `wool.SerializationError` — not the + `wool.ChainSerializationError` aggregator — + whose message names the result-payload encode failure, + with the underlying encode error attached as ``cause`` + and a ``value_repr`` preview of the offending value. + """ + + # Arrange + async def lock_returning_task(): + return threading.Lock() + + wool_task = make_task(lock_returning_task) + request = protocol.Request(task=wool_task.to_protobuf()) + + # Act + async with grpc_aio_stub() as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + responses = [r async for r in stream] + + # Assert + assert responses[0].HasField("ack") + exc_responses = [r for r in responses if r.HasField("exception")] + assert len(exc_responses) == 1 + raised = cloudpickle.loads(exc_responses[0].exception.dump) + assert isinstance(raised, wool.SerializationError) + assert not isinstance(raised, wool.ChainSerializationError) + assert "Failed to encode result payload" in str(raised) + assert raised.cause is not None + assert isinstance(raised.value_repr, str) + + @pytest.mark.asyncio + async def test_dispatch_should_chain_encode_failure_as_cause_when_routine_raises_and_mutation_unpicklable( # noqa: E501 self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch ships the routine - exception bare with the worker-side snapshot failure - attached as PEP 678 ``__notes__`` and a structured - ``__wool_context_warnings__`` attribute when both occur + """Test `WorkerService` dispatch ships the routine + exception bare with the worker-side chain failure + attached as PEP 678 ``__notes__`` when both 
occur in the same done-callback on the coroutine path under strict mode. Given: - A coroutine routine that sets a :class:`wool.ContextVar` + A coroutine routine that sets a `wool.ContextVar` to a value whose ``__reduce__`` raises and then itself raises an unrelated exception, with the worker-side - warnings filter promoting :class:`ContextDecodeWarning` + warnings filter promoting `SerializationWarning` to an exception — both the routine's failure and the - wool back-prop snapshot's failure occur in the same + wool back-prop chain's failure occur in the same done-callback When: The caller dispatches the routine Then: The dispatch ships the routine exception's type bare (so the caller's existing ``except RoutineError`` - keeps catching), with the snapshot encode failure - attached via PEP 678 ``__notes__`` (visible in - tracebacks) and a ``__wool_context_warnings__`` - attribute (programmatic access to the - :class:`ContextDecodeWarning` peers naming the - offending var) + keeps catching), with the strict-mode + `wool.ChainSerializationError` chained as + ``__cause__`` via ``raise routine_exc from encode_err`` + so the encode failure remains visible in the traceback. 
""" import warnings as _warnings @@ -997,14 +995,7 @@ async def coroutine(): var.set(_Unpicklable()) raise ValueError("routine failure") - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=coroutine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(coroutine) request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -1016,7 +1007,7 @@ async def drive(): return [r async for r in stream] with _warnings.catch_warnings(): - _warnings.simplefilter("error", category=wool.ContextDecodeWarning) + _warnings.simplefilter("error", category=wool.SerializationWarning) responses = await asyncio.wait_for(drive(), timeout=5.0) # Assert @@ -1032,63 +1023,65 @@ async def drive(): ) assert "routine failure" in str(raised) - # PEP 678 notes carry the warning(s) for traceback - # diagnostic. - assert hasattr(raised, "__notes__") - notes_text = "\n".join(raised.__notes__) - assert "synthetic unpicklable" in notes_text, ( - f"snapshot failure must appear in __notes__; observed: {raised.__notes__}" + # The strict-mode encode failure rides on ``__cause__`` + # via ``raise from`` chaining — visible in tracebacks and + # accessible programmatically as a typed + # ``wool.ChainSerializationError`` aggregating the warning(s). + cause = raised.__cause__ + assert isinstance(cause, wool.ChainSerializationError), ( + f"encode failure must appear on ``__cause__``; observed: " + f"{type(cause).__name__}" ) - - # __wool_context_warnings__ provides structured access. 
- warnings = raised.__wool_context_warnings__ - snapshot_warnings = [ - w - for w in warnings - if isinstance(w, wool.ContextDecodeWarning) - and "synthetic unpicklable" in str(w) - ] - assert len(snapshot_warnings) == 1, ( - "snapshot ContextDecodeWarning must appear in __wool_context_warnings__" + assert any("synthetic unpicklable" in str(w) for w in cause.warnings), ( + f"warning must name the offending var; got {cause.warnings!r}" ) @pytest.mark.asyncio - async def test_dispatch_with_attribute_rejecting_routine_exception_under_strict_mode( + async def test_dispatch_should_preserve_exception_type_when_exception_rejects_attribute_writes_under_strict_mode( # noqa: E501 self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch ships a routine - exception whose class rejects arbitrary attribute writes - unchanged under strict mode. - - Regression test for A4. Pre-fix, - ``e.__wool_context_warnings__ = warnings`` raised - :class:`AttributeError` for exception classes that reject - arbitrary attribute writes (e.g., overridden - ``__setattr__``, or layouts that disable ``__dict__``) — + """Test `WorkerService` dispatch ships an attribute- + rejecting routine exception with its type preserved under + strict mode. + + Regression test for the primary-class-preservation invariant: + the routine exception's class must reach the caller bare even + when wool's structured side-channels can't ride on it. + Attaching wool's structured side-channel onto an exception that + rejects arbitrary writes once raised `AttributeError`, converting the routine's primary signal into a stray - :class:`AttributeError` shipped to the caller. Post-fix, - the assignment is best-effort: PEP 678 ``__notes__`` - carries the warnings (always available) and the structured - attribute is silently skipped when the exception class - does not support the write. + `AttributeError` shipped to the caller. 
The attachment is + now best-effort: it falls through cleanly when the class rejects + it, and the routine exception type still ships. + + Under the new "raise from" chaining design the encode error's + wire survival depends on tblib's ``pickling_support`` finding + a class-compatible reduce path. For an + `Exception` subclass with no custom ``__init__`` that + rejects ``__dict__`` writes, tblib's + ``unpickle_exception_with_attrs`` path fails to restore the + chain — the class itself round-trips fine but ``__cause__`` + does not survive. This test pins the primary-class + invariant; chain survival is a best-effort observable on + cooperating classes (the common case — see the routine-raise + siblings). Given: - A coroutine routine that sets a :class:`wool.ContextVar` + A coroutine routine that sets a `wool.ContextVar` to a value whose ``__reduce__`` raises (forcing the - wool snapshot encode failure path) and then raises an + wool chain encode failure path) and then raises an exception whose ``__setattr__`` rejects arbitrary attribute writes, with worker-side strict mode - promoting :class:`wool.ContextDecodeWarning` to an + promoting `wool.SerializationWarning` to an exception. When: The caller dispatches the routine. Then: - The wire ships the routine exception type unchanged - with PEP 678 ``__notes__`` carrying the warnings; - ``__wool_context_warnings__`` is not present (the - best-effort attribute set silently skipped). Pre-fix - the wire shipped an :class:`AttributeError` from the - failed attribute write instead. + The wire ships the routine exception type unchanged. + The wire frame omits the post-run ``context`` field + because encode failed. The caller's existing + ``except`` against the routine class keeps matching — + no migration required. 
""" import warnings as _warnings @@ -1104,14 +1097,7 @@ async def coroutine(): var.set(_Unpicklable()) raise _AttributeRejectingRoutineError("attribute-rejecting routine failure") - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=coroutine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(coroutine) request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -1123,7 +1109,7 @@ async def drive(): return [r async for r in stream] with _warnings.catch_warnings(): - _warnings.simplefilter("error", category=wool.ContextDecodeWarning) + _warnings.simplefilter("error", category=wool.SerializationWarning) responses = await asyncio.wait_for(drive(), timeout=5.0) # Assert @@ -1141,52 +1127,34 @@ async def drive(): ) assert "attribute-rejecting routine failure" in str(raised) - # __notes__ carries the warning (always available — it's - # part of BaseException's API regardless of __slots__). - notes_text = "\n".join(getattr(raised, "__notes__", [])) - assert "synthetic unpicklable" in notes_text, ( - f"snapshot failure must appear in __notes__; observed: " - f"{getattr(raised, '__notes__', None)}" - ) - - # __wool_context_warnings__ is not set on - # attribute-rejecting exception types — the best-effort - # attribute write was skipped. 
- assert not hasattr(raised, "__wool_context_warnings__"), ( - "attribute-rejecting exception classes cannot accept " - "arbitrary attribute writes; the best-effort set " - "should be skipped" - ) - @pytest.mark.asyncio - async def test_dispatch_streaming_with_routine_raise_and_unpicklable_mutation( + async def test_dispatch_should_chain_encode_failure_as_cause_when_streaming_routine_raises_and_mutation_unpicklable( # noqa: E501 self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch ships the routine - exception bare with the worker-side snapshot failure - attached as PEP 678 ``__notes__`` and - ``__wool_context_warnings__`` when both occur in the + """Test `WorkerService` dispatch ships the routine + exception bare with the worker-side chain failure + chained as ``__cause__`` when both occur in the same iteration on the streaming path under strict mode. Symmetric with the coroutine path's contract. Given: An async-generator routine that yields once - successfully, then sets a :class:`wool.ContextVar` + successfully, then sets a `wool.ContextVar` to a value whose ``__reduce__`` raises and itself raises an unrelated exception on the next iteration, with the worker-side warnings filter promoting - :class:`ContextDecodeWarning` to an exception — both - the routine's failure and the back-prop snapshot's + `SerializationWarning` to an exception — both + the routine's failure and the back-prop chain's failure occur in the same iteration When: The caller drives the generator past the yielded value and into the failing iteration Then: The dispatch ships the routine exception type bare - with the snapshot encode failure attached via - PEP 678 ``__notes__`` and a structured - ``__wool_context_warnings__`` attribute, symmetric - with the coroutine path's contract + with the strict-mode `wool.ChainSerializationError` + chained as ``__cause__`` via ``raise routine_exc from + encode_err``, symmetric with the coroutine path's +
contract. """ import warnings as _warnings @@ -1203,14 +1171,7 @@ async def streamer(): var.set(_Unpicklable()) raise ValueError("routine failure") - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=streamer, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(streamer) first_request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -1232,7 +1193,7 @@ async def drive(): return [r async for r in stream] with _warnings.catch_warnings(): - _warnings.simplefilter("error", category=wool.ContextDecodeWarning) + _warnings.simplefilter("error", category=wool.SerializationWarning) responses = await asyncio.wait_for(drive(), timeout=5.0) # Assert @@ -1247,112 +1208,21 @@ async def drive(): ) assert "routine failure" in str(raised) - # PEP 678 notes carry the warning for traceback diagnostic. - assert hasattr(raised, "__notes__") - notes_text = "\n".join(raised.__notes__) - assert "synthetic unpicklable" in notes_text, ( - f"snapshot failure must appear in __notes__; observed: {raised.__notes__}" - ) - - # __wool_context_warnings__ provides structured access. - warnings = raised.__wool_context_warnings__ - snapshot_warnings = [ - w - for w in warnings - if isinstance(w, wool.ContextDecodeWarning) - and "synthetic unpicklable" in str(w) - ] - assert len(snapshot_warnings) == 1, ( - "snapshot ContextDecodeWarning must appear in __wool_context_warnings__" + # The strict-mode encode failure rides on ``__cause__`` via + # ``raise from`` chaining — typed + # ``wool.ChainSerializationError`` aggregating the warning. 
+ cause = raised.__cause__ + assert isinstance(cause, wool.ChainSerializationError), ( + f"encode failure must appear on ``__cause__``; observed: " + f"{type(cause).__name__}" ) + assert any("synthetic unpicklable" in str(w) for w in cause.warnings) @pytest.mark.asyncio - async def test_dispatch_streaming_when_update_raises( + async def test_dispatch_should_ship_exception_when_streaming_pre_loop_setup_fails( self, grpc_aio_stub, mock_worker_proxy_cache, mocker: MockerFixture ): - """Test :class:`WorkerService` dispatch surfaces unhandled - iteration-body errors as a terminal error rather than hanging - the streaming dispatch. - - Given: - An async-generator dispatch where the second ``next`` - frame carries a state-bearing context, but - ``Context.update`` is patched to raise on invocation — - the unprotected merge that would otherwise strand the - worker task - When: - The caller sends the second request and consumes the - stream - Then: - The dispatch must terminate within the asyncio timeout - window with the synthetic error surfaced as an exception - Response — the iteration-body catch-all guarantees - that any exception escaping the precise handlers is - still pushed to the result queue - """ - # Arrange - from wool.runtime.context import Context - - async def streamer(): - for i in range(5): - yield f"value_{i}" - - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=streamer, - args=(), - kwargs={}, - proxy=mock_proxy, - ) - first_request = protocol.Request(task=wool_task.to_protobuf()) - good_next = protocol.Request( - next=protocol.Void(), - context=protocol.Context(id=uuid4().hex), - ) - # State-bearing context (carries_state True) so the worker - # invokes update on receive. 
- bad_ctx = protocol.Context(id=uuid4().hex) - bad_ctx.vars.add(namespace="", name="", consumed_tokens=[uuid4().hex]) - bad_next = protocol.Request(next=protocol.Void(), context=bad_ctx) - - mocker.patch.object( - Context, - "update", - side_effect=RuntimeError("synthetic update failure"), - ) - - # Act - async def drive(): - async with grpc_aio_stub() as stub: - stream = stub.dispatch() - await stream.write(first_request) - ack = await anext(aiter(stream)) - assert ack.HasField("ack") - - await stream.write(good_next) - first = await anext(aiter(stream)) - assert first.HasField("result") - assert cloudpickle.loads(first.result.dump) == "value_0" - - await stream.write(bad_next) - await stream.done_writing() - return [r async for r in stream] - - responses = await asyncio.wait_for(drive(), timeout=5.0) - - # Assert - assert any(r.HasField("exception") for r in responses) - exc_response = next(r for r in responses if r.HasField("exception")) - raised = cloudpickle.loads(exc_response.exception.dump) - assert isinstance(raised, RuntimeError) - assert "synthetic update failure" in str(raised) - - @pytest.mark.asyncio - async def test_dispatch_streaming_surfaces_pre_loop_setup_failure( - self, grpc_aio_stub, mock_worker_proxy_cache, mocker: MockerFixture - ): - """Test :class:`WorkerService` streaming dispatch surfaces a + """Test `WorkerService` streaming dispatch surfaces a worker-task setup failure as a terminal exception frame rather than hanging the caller. @@ -1372,23 +1242,16 @@ async def test_dispatch_streaming_surfaces_pre_loop_setup_failure( silently swallowing it. 
""" # Arrange - from wool.runtime.context import RuntimeContext + from wool.runtime.context.runtime import RuntimeContext async def streamer(): yield "unreachable" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=streamer, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(streamer) first_request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request( next=protocol.Void(), - context=protocol.Context(id=uuid4().hex), + context=protocol.ChainManifest(id=uuid4().hex), ) mocker.patch.object( @@ -1422,10 +1285,10 @@ async def drive(): assert "synthetic pre-loop failure" in str(raised) @pytest.mark.asyncio - async def test_dispatch_streaming_with_teardown_failure_after_completion( + async def test_dispatch_should_not_append_exception_when_streaming_teardown_fails_after_completion( # noqa: E501 self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` streaming dispatch when the + """Test `WorkerService` streaming dispatch when the async generator's teardown raises after the primary signal has already reached the caller. 
@@ -1451,18 +1314,11 @@ async def streamer(): finally: raise RuntimeError("synthetic teardown failure") - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=streamer, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(streamer) first_request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request( next=protocol.Void(), - context=protocol.Context(id=uuid4().hex), + context=protocol.ChainManifest(id=uuid4().hex), ) # Act @@ -1490,13 +1346,13 @@ async def drive(): ) @pytest.mark.asyncio - async def test_dispatch_while_stopping( + async def test_dispatch_should_abort_unavailable_when_stopping( self, service_fixture, mock_worker_proxy_cache, mocker: MockerFixture ): - """Test :class:`WorkerService` dispatch aborts when stopping. + """Test `WorkerService` dispatch aborts when stopping. Given: - A :class:`WorkerService` with an active task, transitioning to stopping state + A `WorkerService` with an active task, transitioning to stopping state When: stop is called and another dispatch RPC is attempted Then: @@ -1515,15 +1371,7 @@ async def test_dispatch_while_stopping( async def another_task(): return "should_not_execute" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id-2") - - wool_task = Task( - id=uuid4(), - callable=another_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(another_task, proxy_id="test-proxy-id-2") request = protocol.Request(task=wool_task.to_protobuf()) @@ -1542,27 +1390,27 @@ async def another_task(): await stop_task @pytest.mark.asyncio - async def test_dispatch_with_stop_arriving_between_entry_gate_and_tracking( + async def test_dispatch_should_abort_unavailable_when_stop_races_entry_gate_and_tracking( # noqa: E501 self, grpc_aio_stub, grpc_servicer, mock_worker_proxy_cache, mocker: MockerFixture, ): - """Test :class:`WorkerService.dispatch` aborts ``UNAVAILABLE`` when + """Test 
`WorkerService.dispatch` aborts ``UNAVAILABLE`` when the ``_stopping`` event is set after the entry-gate check but before the session is registered in the docket. Regression test for the ``_tracked`` check-to-register race window. ``WorkerService.dispatch`` checks ``_stopping`` on entry and again on docket registration; without the second check, a concurrent - :meth:`_stop` between the gate and registration would admit a - session that :meth:`_preempt` never sees, leaving it to be torn + `_stop` between the gate and registration would admit a + session that `_preempt` never sees, leaving it to be torn down indirectly by loop-pool teardown rather than the explicit cancel path. Given: - A :class:`WorkerService` whose ``_stopping.is_set`` returns + A `WorkerService` whose ``_stopping.is_set`` returns ``False`` on the entry-gate check and ``True`` on the ``_tracked`` check, simulating a stop arrival in the race window @@ -1587,14 +1435,7 @@ async def test_dispatch_with_stop_arriving_between_entry_gate_and_tracking( async def sample_task(): return "should_not_execute" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) # Act & assert @@ -1612,11 +1453,13 @@ async def sample_task(): assert exc_info.value.code() == StatusCode.UNAVAILABLE @pytest.mark.asyncio - async def test_dispatch_while_stopped(self, service_fixture, mocker: MockerFixture): - """Test :class:`WorkerService` dispatch aborts when stopped. + async def test_dispatch_should_abort_unavailable_when_stopped( + self, service_fixture, mocker: MockerFixture + ): + """Test `WorkerService` dispatch aborts when stopped. 
Given: - A :class:`WorkerService` that has been stopped + A `WorkerService` that has been stopped When: dispatch RPC is called with a task request Then: @@ -1636,15 +1479,7 @@ async def test_dispatch_while_stopped(self, service_fixture, mocker: MockerFixtu async def another_task(): return "should_not_execute" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id-2") - - wool_task = Task( - id=uuid4(), - callable=another_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(another_task, proxy_id="test-proxy-id-2") request = protocol.Request(task=wool_task.to_protobuf()) @@ -1659,10 +1494,10 @@ async def another_task(): assert exc_info.value.code() == StatusCode.UNAVAILABLE @pytest.mark.asyncio - async def test_dispatch_with_sync_callable( + async def test_dispatch_should_nack_with_value_error_when_callable_synchronous( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` dispatch when the task's + """Test `WorkerService` dispatch when the task's callable is a plain synchronous function (not a coroutine function or async-generator function). @@ -1672,9 +1507,9 @@ async def test_dispatch_with_sync_callable( When: The dispatch RPC is invoked with that request Then: - It should reply with exactly one :class:`Nack` response + It should reply with exactly one `Nack` response (no preceding Ack) whose ``exception`` field decodes to - a :class:`ValueError` describing the routine-type + a `ValueError` describing the routine-type violation. 
""" @@ -1682,15 +1517,7 @@ async def test_dispatch_with_sync_callable( def sync_function(): return "sync_result" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=sync_function, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sync_function) request = protocol.Request(task=wool_task.to_protobuf()) @@ -1711,52 +1538,45 @@ def sync_function(): assert "coroutine function or async generator function" in str(raised) @pytest.mark.asyncio - async def test_stop_and_cancel( + async def test_stop_should_cancel_active_coroutine_routine( self, grpc_aio_stub, grpc_servicer, mocker: MockerFixture, mock_worker_proxy_cache, ): - """Test :class:`WorkerService` stop pre-empts an active + """Test `WorkerService` stop pre-empts an active coroutine routine. Verifies the operator-preempt contract on the routine side - via a side-channel :class:`threading.Event`. In production, + via a side-channel `threading.Event`. In production, the worker subprocess exits after stop and the gRPC connection drops — callers do not observe a terminal ``CancelledError`` wire frame, they observe an - :class:`RpcError` (transport-closed). Asserting on a wire + `RpcError` (transport-closed). Asserting on a wire frame after stop tests an in-process-only artifact (the gRPC server stays alive in the fixture); asserting on the routine's own observation of cancellation is the production-realistic check. Given: - A running :class:`WorkerService` with an active + A running `WorkerService` with an active coroutine awaiting a long sleep When: stop RPC is called with a timeout of 0 Then: - The routine should observe :class:`asyncio.CancelledError` + The routine should observe `asyncio.CancelledError` within the test's budget — operator-preempt cancels the worker driver task on its loop, propagating cancellation - into the routine's :func:`asyncio.sleep`. The service + into the routine's `asyncio.sleep`. 
The service should signal stopped state and call - :meth:`proxy_pool.clear`. + `proxy_pool.clear`. """ global _stop_cancellation_observed, _stop_routine_started _stop_cancellation_observed = threading.Event() _stop_routine_started = threading.Event() try: - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=_stop_long_coroutine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(_stop_long_coroutine) request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -1765,6 +1585,13 @@ async def test_stop_and_cancel( await stream.write(request) ack = await anext(aiter(stream)) assert ack.HasField("ack") + # Under the per-frame architecture, the coroutine + # branch reads its prime ``Next`` from the request + # iterator. Writing the Next here is what triggers + # the routine — without it the worker driver hangs + # waiting for the prime frame and the routine never + # starts. + await stream.write(protocol.Request(next=protocol.Void())) # Wait for the routine to actually start before # sending stop. Without this barrier the test races @@ -1772,7 +1599,7 @@ async def test_stop_and_cancel( # slower Python versions/CI runners, stop can land # before ``_worker_task`` is created, so # ``session.cancel()`` has nothing to cancel and - # the routine never observes :class:`CancelledError`. + # the routine never observes `CancelledError`. loop = asyncio.get_running_loop() started = await loop.run_in_executor( None, _stop_routine_started.wait, 10.0 @@ -1808,32 +1635,32 @@ async def test_stop_and_cancel( _stop_routine_started = None @pytest.mark.asyncio - async def test_stop_and_cancel_streaming_routine( + async def test_stop_should_cancel_active_streaming_routine( self, grpc_aio_stub, grpc_servicer, mocker: MockerFixture, mock_worker_proxy_cache, ): - """Test :class:`WorkerService` stop pre-empts an active + """Test `WorkerService` stop pre-empts an active async-generator routine mid-stream. 
Verifies the operator-preempt contract on the routine side - via a side-channel :class:`threading.Event`. See the + via a side-channel `threading.Event`. See the coroutine variant's docstring for why we assert on the routine's observation of cancellation rather than on a terminal wire frame. Given: - A running :class:`WorkerService` with an active + A running `WorkerService` with an active async-generator task suspended between yields When: stop RPC is called with a timeout of 0 Then: - The routine should observe :class:`asyncio.CancelledError` + The routine should observe `asyncio.CancelledError` within the test's budget — operator-preempt cancels the worker driver task on its loop, propagating cancellation - into the routine's :func:`asyncio.sleep` between yields. + into the routine's `asyncio.sleep` between yields. """ global _stop_cancellation_observed, _stop_routine_started _stop_cancellation_observed = threading.Event() @@ -1845,14 +1672,7 @@ async def test_stop_and_cancel_streaming_routine( # routine body. _stop_routine_started = threading.Event() try: - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=_stop_streaming_routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(_stop_streaming_routine) request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -1889,17 +1709,17 @@ async def test_stop_and_cancel_streaming_routine( _stop_routine_started = None @pytest.mark.asyncio - async def test_stop_and_wait( + async def test_stop_should_await_tasks_and_signal_stopped_when_positive_timeout( self, grpc_aio_stub, grpc_servicer, mocker: MockerFixture, mock_worker_proxy_cache, ): - """Test :class:`WorkerService` stop method gracefully shuts down. + """Test `WorkerService` stop method gracefully shuts down. 
Given: - A running :class:`WorkerService` with active tasks + A running `WorkerService` with active tasks When: stop RPC is called with a positive timeout ("wait") Then: @@ -1911,15 +1731,7 @@ async def test_stop_and_wait( async def quick_task(): return "completed" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=quick_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(quick_task) request = protocol.Request(task=wool_task.to_protobuf()) @@ -1950,16 +1762,16 @@ async def quick_task(): mock_worker_proxy_cache.clear.assert_called_once() @pytest.mark.asyncio - async def test_dispatch_task_that_self_cancels( + async def test_dispatch_should_return_cancelled_error_when_task_self_cancels( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache, ): - """Test :class:`WorkerService` dispatch handles task that cancels itself. + """Test `WorkerService` dispatch handles task that cancels itself. 
Given: - A gRPC :class:`WorkerService` that is not stopping or stopped + A gRPC `WorkerService` that is not stopping or stopped When: Dispatch RPC is called with a task that cancels itself on the worker loop Then: @@ -1972,15 +1784,7 @@ async def self_cancelling_task(): await asyncio.sleep(0) return "should_not_return" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=self_cancelling_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(self_cancelling_task) request = protocol.Request(task=wool_task.to_protobuf()) @@ -1999,24 +1803,24 @@ async def self_cancelling_task(): assert isinstance(exception, asyncio.CancelledError) @pytest.mark.asyncio - async def test_dispatch_client_cancellation_propagates_to_routine( + async def test_dispatch_should_cancel_routine_when_client_cancels_mid_stream( self, grpc_aio_stub, mock_worker_proxy_cache, ): - """Test :meth:`WorkerService.dispatch` cancels the worker + """Test `WorkerService.dispatch` cancels the worker routine when the client cancels mid-stream. - Regression test for A1. Pre-fix, - :meth:`DispatchSession.cancel` only set ``_cancelled = True`` + Regression test. Pre-fix, + `DispatchSession.cancel` only set ``_cancelled = True`` and pushed ``_EOS`` on the response queue; the worker driver task itself was never cancelled. A routine - mid-``_step`` (e.g. ``await asyncio.sleep(...)``) ran to + mid-``_step`` (e.g., ``await asyncio.sleep(...)``) ran to natural completion regardless of whether the caller had - gone away. Post-fix, :meth:`cancel` schedules + gone away. Post-fix, `cancel` schedules ``self._worker_task.cancel`` on the worker loop, so a compute-bound or sleeping routine receives a - :class:`asyncio.CancelledError` and unwinds rather than + `asyncio.CancelledError` and unwinds rather than holding the worker until shutdown. 
Given: @@ -2024,29 +1828,22 @@ async def test_dispatch_client_cancellation_propagates_to_routine( seconds — long enough that the test will have given up and asserted before it could complete naturally. The routine signals observation of - :class:`asyncio.CancelledError` via a cross-loop + `asyncio.CancelledError` via a cross-loop ``threading.Event``. When: The gRPC client cancels the dispatch stream while the routine is mid-``await asyncio.sleep``. Then: - The routine observes :class:`asyncio.CancelledError` + The routine observes `asyncio.CancelledError` within a short timeout — pre-fix this assertion timed - out because :meth:`cancel` left the worker driver + out because `cancel` left the worker driver task running. """ - global _a1_cancellation_observed, _a1_routine_started - _a1_cancellation_observed = threading.Event() - _a1_routine_started = threading.Event() + global _midstream_cancellation_observed, _midstream_routine_started + _midstream_cancellation_observed = threading.Event() + _midstream_routine_started = threading.Event() try: - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=_a1_long_routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(_midstream_long_routine) request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -2055,6 +1852,12 @@ async def test_dispatch_client_cancellation_propagates_to_routine( await stream.write(request) ack = await anext(aiter(stream)) assert ack.HasField("ack") + # Under the per-frame architecture, the coroutine + # branch reads its prime ``Next`` from the request + # iterator. Writing the Next here is what triggers + # the routine — without it the worker driver hangs + # waiting for the prime frame. + await stream.write(protocol.Request(next=protocol.Void())) # Barrier: wait until the worker routine is actually # running before cancelling. 
The ``ack`` only @@ -2062,7 +1865,7 @@ async def test_dispatch_client_cancellation_propagates_to_routine( # ``yield ack``; the worker task is scheduled lazily # on the handler's first ``async for`` iteration. # Cancelling before the routine starts races - # :meth:`DispatchSession._schedule_worker`, which + # `DispatchSession._schedule_worker`, which # short-circuits on ``_cancelled`` and never # dispatches the routine — leaving nothing for the # cancellation to interrupt and failing this test @@ -2070,7 +1873,7 @@ async def test_dispatch_client_cancellation_propagates_to_routine( # ``_stop_routine_started``. loop = asyncio.get_running_loop() started = await loop.run_in_executor( - None, _a1_routine_started.wait, 10.0 + None, _midstream_routine_started.wait, 10.0 ) assert started, ( "Worker routine did not start within 10s — " @@ -2105,7 +1908,7 @@ async def test_dispatch_client_cancellation_propagates_to_routine( # if the chain is actually broken (regression would # see the routine sleep the full 30s). observed = await loop.run_in_executor( - None, _a1_cancellation_observed.wait, 10.0 + None, _midstream_cancellation_observed.wait, 10.0 ) # Assert @@ -2118,17 +1921,17 @@ async def test_dispatch_client_cancellation_propagates_to_routine( "cancelled." ) finally: - _a1_cancellation_observed = None - _a1_routine_started = None + _midstream_cancellation_observed = None + _midstream_routine_started = None @pytest.mark.asyncio - async def test_stop_timeout_then_cancel( + async def test_stop_should_cancel_tasks_when_timeout_expires( self, service_fixture, mock_worker_proxy_cache ): - """Test :class:`WorkerService` stop cancels tasks after timeout expires. + """Test `WorkerService` stop cancels tasks after timeout expires. 
Given: - A :class:`WorkerService` with an active task that outlasts the stop timeout + A `WorkerService` with an active task that outlasts the stop timeout When: stop RPC is called with a small positive timeout Then: @@ -2148,13 +1951,13 @@ async def test_stop_timeout_then_cancel( assert service.stopped.is_set() @pytest.mark.asyncio - async def test_stop_while_idle( + async def test_stop_should_signal_stopped_when_idle( self, grpc_aio_stub, grpc_servicer, mock_worker_proxy_cache ): - """Test :class:`WorkerService` stop method gracefully shuts down. + """Test `WorkerService` stop method gracefully shuts down. Given: - A running :class:`WorkerService` with no active tasks + A running `WorkerService` with no active tasks When: stop RPC is called Then: @@ -2175,11 +1978,13 @@ async def test_stop_while_idle( mock_worker_proxy_cache.clear.assert_called_once() @pytest.mark.asyncio - async def test_stop_while_stopping(self, service_fixture, mock_worker_proxy_cache): - """Test :class:`WorkerService` stop is idempotent. + async def test_stop_should_return_immediately_when_already_stopping( + self, service_fixture, mock_worker_proxy_cache + ): + """Test `WorkerService` stop is idempotent. Given: - A :class:`WorkerService` that is already stopping + A `WorkerService` that is already stopping When: stop RPC is called again Then: @@ -2209,11 +2014,13 @@ async def test_stop_while_stopping(self, service_fixture, mock_worker_proxy_cach assert service.stopped.is_set() @pytest.mark.asyncio - async def test_stop_while_stopped(self, service_fixture, mock_worker_proxy_cache): - """Test :class:`WorkerService` stop when already stopped. + async def test_stop_should_return_immediately_when_already_stopped( + self, service_fixture, mock_worker_proxy_cache + ): + """Test `WorkerService` stop when already stopped. 
Given: - A :class:`WorkerService` that is already stopped + A `WorkerService` that is already stopped When: stop RPC is called again Then: @@ -2238,13 +2045,13 @@ async def test_stop_while_stopped(self, service_fixture, mock_worker_proxy_cache assert service.stopped.is_set() @pytest.mark.asyncio - async def test_stop_negative_timeout_waits_indefinitely( + async def test_stop_should_wait_indefinitely_when_timeout_negative( self, service_fixture, mock_worker_proxy_cache ): - """Test :class:`WorkerService` stop with negative timeout waits indefinitely. + """Test `WorkerService` stop with negative timeout waits indefinitely. Given: - A :class:`WorkerService` with an active task + A `WorkerService` with an active task When: stop RPC is called with a negative timeout Then: @@ -2271,10 +2078,10 @@ async def test_stop_negative_timeout_waits_indefinitely( assert service.stopped.is_set() @pytest.mark.asyncio - async def test_stop_with_orphaned_cleanup_chain( + async def test_stop_should_drain_every_generation_of_orphaned_cleanup_chain( self, grpc_aio_stub, grpc_servicer, mock_worker_proxy_cache ): - """Test :class:`WorkerService` stop drains every generation of + """Test `WorkerService` stop drains every generation of orphaned worker-loop tasks. 
Given: @@ -2292,14 +2099,7 @@ async def test_stop_with_orphaned_cleanup_chain( # Arrange _drain_cleanup_observed = threading.Event() try: - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=_drain_probe_routine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(_drain_probe_routine) request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -2318,7 +2118,7 @@ async def test_stop_with_orphaned_cleanup_chain( _drain_cleanup_observed = None @pytest.mark.asyncio - async def test_dispatch_async_generator_task( + async def test_dispatch_should_yield_results_in_order_when_async_generator( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() with async generator yields multiple results. @@ -2332,19 +2132,11 @@ async def test_dispatch_async_generator_task( """ # Arrange - async def test_generator(): + async def gen(): for i in range(3): yield f"gen_value_{i}" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=test_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(gen) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -2376,16 +2168,16 @@ async def test_generator(): assert len(remaining) == 0 @pytest.mark.asyncio - async def test_dispatch_streaming_with_dispatch_timeout( + async def test_dispatch_should_restore_dispatch_timeout_per_iteration_when_streaming( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` streaming dispatch restores the + """Test `WorkerService` streaming dispatch restores the caller's ``dispatch_timeout`` for every iteration of an async-generator routine. 
Given: - An async-generator :class:`Task` whose - :class:`RuntimeContext` carries a non-default + An async-generator `Task` whose + `RuntimeContext` carries a non-default ``dispatch_timeout`` and whose routine reads ``wool.runtime.context.dispatch_timeout.get()`` on each iteration @@ -2395,7 +2187,7 @@ async def test_dispatch_streaming_with_dispatch_timeout( Then: Every yielded value equals the caller-supplied ``dispatch_timeout`` — confirming that the unified - :class:`DispatchSession` driver enters + `DispatchSession` driver enters ``work_task.runtime_context`` once for the lifetime of the generator. Regression guard for #176, where the pre-#187 ``_stream_from_worker`` code path dropped the @@ -2403,10 +2195,10 @@ async def test_dispatch_streaming_with_dispatch_timeout( ``dispatch_timeout`` at its default on subsequent frames. """ # Arrange - from wool.runtime.context import RuntimeContext + from wool.runtime.context.runtime import RuntimeContext async def capture_timeout(): - from wool.runtime.context import dispatch_timeout + from wool.runtime.context.runtime import dispatch_timeout yield dispatch_timeout.get() yield dispatch_timeout.get() @@ -2450,7 +2242,7 @@ async def drive(): assert captured == [2.5, 2.5, 2.5] @pytest.mark.asyncio - async def test_dispatch_async_generator_raises_during_iteration( + async def test_dispatch_should_yield_exception_when_async_generator_raises( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() with async generator that raises yields exception. 
@@ -2469,15 +2261,7 @@ async def failing_generator(): yield "first_value" raise ValueError("Generator error") - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=failing_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(failing_generator) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -2501,13 +2285,78 @@ async def failing_generator(): await stream.write(next_request) response = await anext(aiter(stream)) assert response.HasField("exception") - assert response.HasField("context") + # Lazy-wire-frame: the unarmed worker omits the context + # field — the routine never touched a wool.ContextVar. + assert not response.HasField("context") exception = cloudpickle.loads(response.exception.dump) assert isinstance(exception, ValueError) assert str(exception) == "Generator error" @pytest.mark.asyncio - async def test_dispatch_async_generator_completes_normally( + async def test_dispatch_should_ship_exception_frame_when_mid_stream_decode_fails( + self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache + ): + """Test a mid-stream wire decode failure ships an exception terminal. + + Given: + A streaming dispatch whose first NextRequest succeeds and + yields a value, followed by a structurally-malformed wire + request (a `protocol.Request` with no payload oneof + set — `Frame.from_protobuf` rejects it with + ``ValueError``). + When: + The caller sends the malformed request and consumes the + stream. + Then: + The worker's per-frame decode catches the failure and + queues an `ExceptionResponseFrame` carrying the + decode error, then breaks out of the driver loop — the + dispatch is no longer serviceable after the framing + broke, so the typed terminal frame is the contract + instead of an opaque task death. 
+ """ + + # Arrange + async def gen(): + for i in range(2): + yield i + + wool_task = make_task(gen) + + task_request = protocol.Request(task=wool_task.to_protobuf()) + next_request = protocol.Request(next=protocol.Void()) + # Malformed: no payload oneof set at all. Frame.from_protobuf + # raises ValueError("wire envelope has no payload set"). + malformed_request = protocol.Request() + + # Act + async with grpc_aio_stub() as stub: + stream = stub.dispatch() + await stream.write(task_request) + ack = await anext(aiter(stream)) + assert ack.HasField("ack") + + await stream.write(next_request) + first = await anext(aiter(stream)) + assert first.HasField("result") + assert cloudpickle.loads(first.result.dump) == 0 + + await stream.write(malformed_request) + await stream.done_writing() + remaining = [r async for r in stream] + + # Assert — the terminal frame is an ExceptionResponseFrame + # carrying the decode-time failure; no further frames after. + terminals = [r for r in remaining if r.HasField("exception")] + assert len(terminals) == 1 + shipped = cloudpickle.loads(terminals[0].exception.dump) + # The decode error class is ValueError (the bare-Frame raise); + # the caller observes the exception type rather than an + # opaque RpcError. + assert isinstance(shipped, ValueError) + + @pytest.mark.asyncio + async def test_dispatch_should_end_stream_cleanly_when_async_generator_completes( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() with async generator completes cleanly. 
@@ -2521,19 +2370,11 @@ async def test_dispatch_async_generator_completes_normally( """ # Arrange - async def test_generator(): + async def gen(): for i in range(2): yield i - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=test_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(gen) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -2566,7 +2407,7 @@ async def test_generator(): assert len(remaining) == 0 @pytest.mark.asyncio - async def test_dispatch_async_generator_empty( + async def test_dispatch_should_end_stream_without_results_when_async_generator_empty( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() with empty async generator. @@ -2584,15 +2425,7 @@ async def empty_generator(): return yield # unreachable, but makes it a generator - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=empty_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(empty_generator) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -2615,7 +2448,7 @@ async def empty_generator(): assert len(remaining) == 0 @pytest.mark.asyncio - async def test_dispatch_coroutine_for_comparison( + async def test_dispatch_should_yield_single_result_when_coroutine( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() with coroutine task for comparison with async generator. 
@@ -2629,18 +2462,10 @@ async def test_dispatch_coroutine_for_comparison( """ # Arrange - async def test_coroutine(): + async def coro(): return "coroutine_result" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=test_coroutine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(coro) request = protocol.Request(task=wool_task.to_protobuf()) @@ -2659,17 +2484,17 @@ async def test_coroutine(): assert result == "coroutine_result" @pytest.mark.asyncio - async def test_stop_cancels_async_generator_task( + async def test_stop_should_cancel_active_async_generator_task( self, grpc_aio_stub, grpc_servicer, mocker: MockerFixture, mock_worker_proxy_cache, ): - """Test :class:`WorkerService` stop cancels an active async generator task. + """Test `WorkerService` stop cancels an active async generator task. Given: - A running :class:`WorkerService` with an active async generator task + A running `WorkerService` with an active async generator task When: stop RPC is called with a timeout of 0 Then: @@ -2682,15 +2507,7 @@ async def blocking_generator(): await asyncio.sleep(100) yield "should_not_reach" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=blocking_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(blocking_generator) request = protocol.Request(task=wool_task.to_protobuf()) @@ -2723,7 +2540,7 @@ async def blocking_generator(): assert grpc_servicer.stopped.is_set() @pytest.mark.asyncio - async def test_dispatch_with_version_in_ack( + async def test_dispatch_should_return_protocol_version_in_ack( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch returns Ack with protocol version. 
@@ -2740,15 +2557,7 @@ async def test_dispatch_with_version_in_ack( async def sample_task(): return "test_result" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) @@ -2765,7 +2574,7 @@ async def sample_task(): assert ack_response.ack.version == protocol.__version__ @pytest.mark.asyncio - async def test_dispatch_with_empty_client_version( + async def test_dispatch_should_abort_failed_precondition_when_client_version_empty( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch rejects tasks with empty version field. @@ -2783,15 +2592,7 @@ async def test_dispatch_with_empty_client_version( async def sample_task(): return "test_result" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) request.task.ClearField("version") @@ -2817,7 +2618,7 @@ async def sample_task(): client_major=st.integers(min_value=0, max_value=100), ) @pytest.mark.asyncio - async def test_dispatch_with_incompatible_major_version( + async def test_dispatch_should_abort_when_major_version_incompatible( self, grpc_aio_stub, mocker: MockerFixture, @@ -2845,15 +2646,7 @@ async def test_dispatch_with_incompatible_major_version( async def sample_task(): return "should_not_execute" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) # Override version field to simulate incompatible client @@ -2881,7 +2674,7 @@ async def 
sample_task(): client_minor=st.integers(min_value=1, max_value=100), ) @pytest.mark.asyncio - async def test_dispatch_with_newer_client_same_major( + async def test_dispatch_should_abort_when_client_newer_same_major( self, grpc_aio_stub, mocker: MockerFixture, @@ -2890,7 +2683,7 @@ async def test_dispatch_with_newer_client_same_major( local_minor, client_minor, ): - """Test dispatch aborts with FAILED_PRECONDITION when client is newer than worker. + """Test dispatch aborts with FAILED_PRECONDITION on newer client. Given: A worker with version X.a.0 and a client with version @@ -2909,15 +2702,7 @@ async def test_dispatch_with_newer_client_same_major( async def sample_task(): return "should_not_execute" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) request.task.version = f"{major}.{client_minor}.0" @@ -2934,7 +2719,7 @@ async def sample_task(): assert "Incompatible version" in (excinfo.value.details() or "") @pytest.mark.asyncio - async def test_dispatch_with_unparseable_client_version( + async def test_dispatch_should_abort_when_client_version_unparseable( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch rejects tasks with unparseable version. 
@@ -2952,15 +2737,7 @@ async def test_dispatch_with_unparseable_client_version( async def sample_task(): return "should_not_execute" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) request.task.version = "not-a-version" @@ -2977,7 +2754,7 @@ async def sample_task(): assert "Unparseable version" in (excinfo.value.details() or "") @pytest.mark.asyncio - async def test_dispatch_async_generator_with_send( + async def test_dispatch_should_forward_send_values_via_asend_when_async_generator( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() forwards send requests into async generator via asend(). @@ -2998,15 +2775,7 @@ async def echo_generator(): while value is not None: value = yield f"echo:{value}" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=echo_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(echo_generator) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -3054,7 +2823,7 @@ async def echo_generator(): assert len(remaining) == 0 @pytest.mark.asyncio - async def test_dispatch_async_generator_send_then_close( + async def test_dispatch_should_stop_advancing_when_client_closes_after_sends( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() handles client closing stream after sends. 
@@ -3078,15 +2847,7 @@ async def counting_generator(): else: count += 1 - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=counting_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(counting_generator) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -3121,7 +2882,7 @@ async def counting_generator(): assert len(remaining) == 0 @pytest.mark.asyncio - async def test_dispatch_pull_only_async_generator( + async def test_dispatch_should_yield_values_in_order_when_pull_only_generator( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() with a pull-only async generator (no send type). @@ -3140,15 +2901,7 @@ async def pull_only(): yield "beta" yield "gamma" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=pull_only, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(pull_only) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -3180,7 +2933,7 @@ async def pull_only(): assert len(remaining) == 0 @pytest.mark.asyncio - async def test_dispatch_pull_only_async_generator_partial_consumption( + async def test_dispatch_should_yield_only_consumed_values_when_pull_only_generator_partial( # noqa: E501 self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() with partial consumption of a pull-only async generator. 
@@ -3199,15 +2952,7 @@ async def five_values(): for i in range(5): yield i - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=five_values, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(five_values) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -3238,7 +2983,7 @@ async def five_values(): assert len(remaining) == 0 @pytest.mark.asyncio - async def test_dispatch_async_generator_interleaved_next_and_send( + async def test_dispatch_should_advance_correctly_when_next_and_send_interleaved( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() with interleaved next and send requests. @@ -3262,15 +3007,7 @@ async def accumulator(): else: total += 1 - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=accumulator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(accumulator) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -3309,7 +3046,7 @@ async def accumulator(): await stream.done_writing() @pytest.mark.asyncio - async def test_dispatch_async_generator_throw_terminates( + async def test_dispatch_should_yield_exception_when_throw_terminates_generator( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() with a throw request that terminates the generator. 
@@ -3328,15 +3065,7 @@ async def simple_generator(): yield "first" yield "second" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=simple_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(simple_generator) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -3368,7 +3097,7 @@ async def simple_generator(): assert str(exception) == "injected" @pytest.mark.asyncio - async def test_dispatch_streaming_with_proxy_and_dispatch_context( + async def test_dispatch_should_set_proxy_and_disable_dispatch_when_streaming( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() sets proxy and do_dispatch(False) for streaming. @@ -3394,15 +3123,7 @@ async def capturing_generator(): "has_proxy": wool.__proxy__.get() is not None, } - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=capturing_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(capturing_generator) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -3429,20 +3150,22 @@ async def capturing_generator(): assert result["has_proxy"] is True @pytest.mark.asyncio - async def test_dispatch_streaming_without_proxy_pool(self, grpc_aio_stub): - """Test :class:`WorkerService` streaming dispatch when - :data:`wool.__proxy_pool__` is not configured on the worker. + async def test_dispatch_should_ship_runtime_error_when_proxy_pool_unset( + self, grpc_aio_stub + ): + """Test `WorkerService` streaming dispatch when + `wool.__proxy_pool__` is not configured on the worker. 
Given: - A worker process where :data:`wool.__proxy_pool__` is + A worker process where `wool.__proxy_pool__` is unset — the worker cannot lease a proxy and therefore - cannot bind :data:`wool.__proxy__` for the routine + cannot bind `wool.__proxy__` for the routine When: The dispatch RPC is invoked Then: It should reply with a terminal exception Response - carrying the :class:`RuntimeError` raised by - :func:`routine_scope`'s precondition check — + carrying the `RuntimeError` raised by + `routine_scope`'s precondition check — proxy-less execution is broken by construction (no nested-dispatch capability) and the handler surfaces the precondition violation rather than silently @@ -3454,14 +3177,7 @@ async def test_dispatch_streaming_without_proxy_pool(self, grpc_aio_stub): async def capturing_generator(): yield {"has_proxy": wool.__proxy__.get() is not None} - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=capturing_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(capturing_generator) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -3485,7 +3201,7 @@ async def capturing_generator(): assert "wool.__proxy_pool__ is not initialized" in str(raised) @pytest.mark.asyncio - async def test_dispatch_streaming_proxy_cleanup_on_error( + async def test_dispatch_should_clean_up_proxy_context_when_generator_errors( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() cleans up proxy context on generator error. 
@@ -3505,15 +3221,7 @@ async def failing_generator(): yield "before_error" raise ValueError("generator_error") - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=failing_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(failing_generator) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -3547,7 +3255,7 @@ async def failing_generator(): mock_worker_proxy_cache.get.return_value.__aexit__.assert_called() @pytest.mark.asyncio - async def test_dispatch_streaming_asend_with_proxy_context( + async def test_dispatch_should_forward_asend_in_proxy_context_when_streaming( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() forwards asend() within do_dispatch(False) context. @@ -3573,15 +3281,7 @@ async def echo_with_capture(): while value is not None: value = yield {"echo": value, "do_dispatch": _dd()} - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=echo_with_capture, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(echo_with_capture) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -3629,7 +3329,7 @@ async def echo_with_capture(): assert result2 == {"echo": "world", "do_dispatch": False} @pytest.mark.asyncio - async def test_dispatch_streaming_athrow_with_proxy_context( + async def test_dispatch_should_forward_athrow_in_proxy_context_when_streaming( self, grpc_aio_stub, mocker: MockerFixture, mock_worker_proxy_cache ): """Test dispatch() forwards athrow() within do_dispatch(False) context. 
@@ -3657,15 +3357,7 @@ async def resilient_generator(): "has_proxy": wool.__proxy__.get() is not None, } - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - wool_task = Task( - id=uuid4(), - callable=resilient_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(resilient_generator) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -3702,7 +3394,7 @@ async def resilient_generator(): assert result["do_dispatch"] is False assert result["has_proxy"] is True - def test___init___with_backpressure_hook(self): + def test___init___should_expose_unset_lifecycle_events_when_backpressure_given(self): """Test WorkerService initialization with a backpressure hook. Given: @@ -3725,146 +3417,13 @@ def hook(ctx): assert not service.stopped.is_set() @pytest.mark.asyncio - async def test_dispatch_with_caller_context_var_and_backpressure_hook( - self, grpc_aio_stub, mock_worker_proxy_cache - ): - """Test the backpressure hook observes caller-shipped ContextVar - values when it evaluates admission. 
- - Given: - A :class:`WorkerService` whose backpressure hook reads a - :class:`wool.ContextVar` to decide admission, and a - dispatch Request whose ``context.vars`` carries a value - for that var - When: - The dispatch RPC is invoked - Then: - The hook should observe the caller-shipped value (not - ``LookupError``), because the dispatch handler scopes the - caller-state Context via ``Context.run`` for the duration - of hook evaluation — the handler does not install it - against the main-loop task and so does not leak ownership - across the main/worker boundary - """ - # Arrange - namespace = f"bp_ctxvar_{uuid4().hex}" - tenant = wool.ContextVar("tenant", namespace=namespace) - observed: list[str] = [] - - def hook(ctx): - observed.append(tenant.get("")) - return False - - async def sample_task(): - return "accepted" - - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) - context_pb = protocol.Context(id=uuid4().hex) - context_pb.vars.add( - namespace=tenant.namespace, - name=tenant.name, - value=cloudpickle.dumps("acme-corp"), - ) - request = protocol.Request(task=wool_task.to_protobuf(), context=context_pb) - - service = WorkerService(backpressure=hook) - - # Act - async with grpc_aio_stub(servicer=service) as stub: - stream = stub.dispatch() - await stream.write(request) - await stream.done_writing() - responses = [r async for r in stream] - - # Assert - ack, response = responses - assert ack.HasField("ack") - assert response.HasField("result") - assert cloudpickle.loads(response.result.dump) == "accepted" - assert observed == ["acme-corp"] - - @pytest.mark.asyncio - async def test_dispatch_async_backpressure_hook_observes_caller_context_vars( - self, grpc_aio_stub, mock_worker_proxy_cache - ): - """Test an async backpressure hook observes caller-shipped - ContextVar values across its await suspension. 
- - Given: - A :class:`WorkerService` whose backpressure hook is - ``async def`` and reads a :class:`wool.ContextVar` after - an ``await asyncio.sleep(0)`` checkpoint, and a dispatch - Request whose ``context.vars`` carries a value for that - var - When: - The dispatch RPC is invoked - Then: - The hook should observe the caller-shipped value after - the suspension — the dispatch handler must keep the - caller-state Context attached across the await of the - hook coroutine, not just the synchronous body that - constructs it - """ - # Arrange - namespace = f"async_bp_ctxvar_{uuid4().hex}" - tenant = wool.ContextVar("tenant", namespace=namespace) - observed: list[str] = [] - - async def hook(ctx): - await asyncio.sleep(0) - observed.append(tenant.get("")) - return False - - async def sample_task(): - return "accepted" - - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) - context_pb = protocol.Context(id=uuid4().hex) - context_pb.vars.add( - namespace=tenant.namespace, - name=tenant.name, - value=cloudpickle.dumps("acme-corp"), - ) - request = protocol.Request(task=wool_task.to_protobuf(), context=context_pb) - - service = WorkerService(backpressure=hook) - - # Act - async with grpc_aio_stub(servicer=service) as stub: - stream = stub.dispatch() - await stream.write(request) - await stream.done_writing() - responses = [r async for r in stream] - - # Assert - ack, response = responses - assert ack.HasField("ack") - assert response.HasField("result") - assert cloudpickle.loads(response.result.dump) == "accepted" - assert observed == ["acme-corp"] - - @pytest.mark.asyncio - async def test_dispatch_with_sync_backpressure_accepting( + async def test_dispatch_should_accept_task_when_sync_backpressure_returns_false( self, grpc_aio_stub, mock_worker_proxy_cache ): """Test dispatch succeeds when sync backpressure hook returns False. 
Given: - A :class:`WorkerService` with a sync backpressure hook that returns False + A `WorkerService` with a sync backpressure hook that returns False When: Dispatch RPC is called Then: @@ -3875,14 +3434,7 @@ async def test_dispatch_with_sync_backpressure_accepting( async def sample_task(): return "accepted" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) def hook(ctx): @@ -3904,13 +3456,13 @@ def hook(ctx): assert cloudpickle.loads(response.result.dump) == "accepted" @pytest.mark.asyncio - async def test_dispatch_with_sync_backpressure_rejecting( + async def test_dispatch_should_abort_when_sync_backpressure_returns_true( self, grpc_aio_stub, mock_worker_proxy_cache ): """Test dispatch aborts when sync backpressure hook returns True. Given: - A :class:`WorkerService` with a sync backpressure hook that returns True + A `WorkerService` with a sync backpressure hook that returns True When: Dispatch RPC is called Then: @@ -3921,14 +3473,7 @@ async def test_dispatch_with_sync_backpressure_rejecting( async def sample_task(): return "should_not_reach" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) def hook(ctx): @@ -3947,13 +3492,13 @@ def hook(ctx): assert exc_info.value.code() == StatusCode.RESOURCE_EXHAUSTED @pytest.mark.asyncio - async def test_dispatch_with_async_backpressure_rejecting( + async def test_dispatch_should_abort_when_async_backpressure_returns_true( self, grpc_aio_stub, mock_worker_proxy_cache ): """Test dispatch aborts when async backpressure hook returns True. 
Given: - A :class:`WorkerService` with an async backpressure hook that returns True + A `WorkerService` with an async backpressure hook that returns True When: Dispatch RPC is called Then: @@ -3964,14 +3509,7 @@ async def test_dispatch_with_async_backpressure_rejecting( async def sample_task(): return "should_not_reach" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) async def async_hook(ctx): @@ -3990,13 +3528,13 @@ async def async_hook(ctx): assert exc_info.value.code() == StatusCode.RESOURCE_EXHAUSTED @pytest.mark.asyncio - async def test_dispatch_with_async_backpressure_accepting( + async def test_dispatch_should_accept_task_when_async_backpressure_returns_false( self, grpc_aio_stub, mock_worker_proxy_cache ): """Test dispatch succeeds when async backpressure hook returns False. 
Given: - A :class:`WorkerService` with an async backpressure hook that returns False + A `WorkerService` with an async backpressure hook that returns False When: Dispatch RPC is called Then: @@ -4005,16 +3543,9 @@ async def test_dispatch_with_async_backpressure_accepting( # Arrange async def sample_task(): - return "async_accepted" - - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + return "async_accepted" + + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) async def async_hook(ctx): @@ -4036,18 +3567,18 @@ async def async_hook(ctx): assert cloudpickle.loads(response.result.dump) == "async_accepted" @pytest.mark.asyncio - async def test_dispatch_ships_stop_async_iteration_raw_for_coroutine_routine( + async def test_dispatch_should_ship_stop_async_iteration_raw_when_coroutine_raises( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test the wire surfaces :class:`StopAsyncIteration` raw + """Test the wire surfaces `StopAsyncIteration` raw when a coroutine routine raises it at the top level — matching stdlib ``await coro()`` semantics. - Regression test for F5. Pre-fix, the wire shipped - :class:`RuntimeError` because :meth:`DispatchSession._iterate` + Regression test. Pre-fix, the wire shipped + `RuntimeError` because `DispatchSession._iterate` is an async generator: when the worker's ``_ResponseQueue.get`` re-raised the routine's - :class:`StopAsyncIteration` inside _iterate's body, PEP + `StopAsyncIteration` inside _iterate's body, PEP 525 converted it to ``RuntimeError("async generator raised StopAsyncIteration")`` at the asyncgen boundary before the dispatch handler's terminal-exception clause @@ -4058,26 +3589,19 @@ async def test_dispatch_ships_stop_async_iteration_raw_for_coroutine_routine( coroutine had been local. 
Given: - A coroutine routine that raises :class:`StopAsyncIteration` + A coroutine routine that raises `StopAsyncIteration` When: The dispatch RPC ships its terminal-exception Response Then: - It should carry :class:`StopAsyncIteration`, not - :class:`RuntimeError`. + It should carry `StopAsyncIteration`, not + `RuntimeError`. """ async def coro_raising_sai(): raise StopAsyncIteration("from coroutine") - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=coro_raising_sai, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(coro_raising_sai) first_request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -4105,33 +3629,33 @@ async def coro_raising_sai(): ) @pytest.mark.asyncio - async def test_dispatch_ships_runtime_error_for_async_generator_raising_sai( + async def test_dispatch_should_ship_runtime_error_when_generator_raises_stop_async_iteration( # noqa: E501 self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test the wire surfaces :class:`RuntimeError` (not - :class:`StopAsyncIteration`) when an async generator - routine raises :class:`StopAsyncIteration` from its body + """Test the wire surfaces `RuntimeError` (not + `StopAsyncIteration`) when an async generator + routine raises `StopAsyncIteration` from its body — matching stdlib ``async for x in agen()`` semantics (PEP 525). Companion to the coroutine StopAsyncIteration regression - test (F5). The fix targeted the coroutine path (unwrap + test. The fix targeted the coroutine path (unwrap PEP 525's auto-conversion); the async generator path was already correct because the user's asyncgen runtime does the conversion before the worker ever sees SAI — the - dispatch handler observes a :class:`RuntimeError` from + dispatch handler observes a `RuntimeError` from ``gen.asend`` and ships it. This test pins the desired behavior so a future change to the unwrap logic does not accidentally widen and corrupt the asyncgen contract. 
Given: An async generator routine whose body raises - :class:`StopAsyncIteration` mid-iteration + `StopAsyncIteration` mid-iteration When: The dispatch RPC ships its terminal-exception Response Then: - It should carry :class:`RuntimeError` whose + It should carry `RuntimeError` whose ``__cause__`` preserves the original SAI. """ @@ -4139,18 +3663,11 @@ async def agen_raising_sai(): raise StopAsyncIteration("from agen") yield 1 - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=agen_raising_sai, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(agen_raising_sai) first_request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request( next=protocol.Void(), - context=protocol.Context(id=uuid4().hex), + context=protocol.ChainManifest(id=uuid4().hex), ) # Act @@ -4183,77 +3700,64 @@ async def agen_raising_sai(): assert shipped.__cause__.args == ("from agen",) @pytest.mark.asyncio - async def test_dispatch_attaches_strict_mode_context_warnings_as_notes( + async def test_dispatch_should_chain_strict_mode_encode_failure_as_cause( self, grpc_aio_stub, mock_worker_proxy_cache, mocker: MockerFixture ): - """Test the dispatch handler attaches strict-mode - :class:`ContextDecodeWarning` peers to the routine's - exception via PEP 678 ``__notes__`` and a - ``__wool_context_warnings__`` attribute, preserving the - routine exception's type. - - Regression test for the user-facing contract pinned by - F11's redesign. 
Pre-redesign, when a routine failed AND - ``handler.context.to_protobuf`` raised (only possible - when the operator promoted :class:`ContextDecodeWarning` - to an exception via - ``warnings.filterwarnings("error", - category=ContextDecodeWarning)``), :func:`merge_exceptions` - wrapped the routine failure and the encode peers in a - :class:`BaseExceptionGroup` — forcing strict-mode users - to migrate their existing ``except RoutineError`` clauses - to ``except*`` or ``except ExceptionGroup``. The redesign - attaches peers to the routine exception via PEP 678 - notes (visible in tracebacks) and a - ``__wool_context_warnings__`` attribute (programmatic - access), so existing exception-handling code keeps - working unchanged. + """Test the dispatch handler chains a strict-mode + `wool.ChainSerializationError` onto the routine's + exception as ``__cause__`` via ``raise from``, preserving + the routine exception's type. + + Regression test for the user-facing contract around a + strict-mode context-encode failure. Under the "fail loud" + design the routine exception keeps its primary class so + callers don't migrate their existing ``except RoutineError`` + clauses; the encode error rides on ``__cause__`` for + traceback visibility. Given: A coroutine routine that raises a custom exception, - and ``handler.context.to_protobuf`` patched to raise - a :class:`BaseExceptionGroup` of synthetic - :class:`ContextDecodeWarning` peers (simulating - strict-mode encode failure) + and the worker driver's final + `ChainManifest.to_protobuf` patched to raise a + `wool.ChainSerializationError` aggregating per-var + warnings (simulating strict-mode encode failure). When: The dispatch RPC ships its terminal-exception - Response + Response. Then: It should ship the routine's exception type bare - (not wrapped in any group), with the warnings - attached as ``__notes__`` and - ``__wool_context_warnings__``. 
+ (not wrapped in any group), with the + `wool.ChainSerializationError` chained on + ``__cause__``. The worker arms its post-run encode by + touching `wool.ContextVar` inside the routine + so the encode path runs. """ - from wool.runtime.context import Context - from wool.runtime.context import ContextDecodeWarning + from wool.runtime.context.manifest import ChainManifest as _Context - class _RoutineFailure(Exception): - pass + namespace = f"chain_cause_{uuid4().hex}" + arm_var: wool.ContextVar[str] = wool.ContextVar("arm", namespace=namespace) async def failing_task(): - raise _RoutineFailure("primary signal") + # Touch a ContextVar so the worker arms; the patched + # encode then fires on the post-run path. + arm_var.set("arm") + # Use a stdlib exception that tblib registered at wool + # import time — cause-chain preservation across the wire + # depends on the routine exception's class having a + # tblib-compatible reduce. + raise ValueError("primary signal") - original_to_protobuf = Context.to_protobuf + original_encode = _Context.to_protobuf def encode_with_strict_failure(self, *args, **kwargs): - raise BaseExceptionGroup( - "strict-mode context encode failure", - [ - ContextDecodeWarning("var-1 unencodable"), - ContextDecodeWarning("var-2 unencodable"), - ], + raise wool.ChainSerializationError( + wool.SerializationWarning("var-1 unencodable"), + wool.SerializationWarning("var-2 unencodable"), ) - mocker.patch.object(Context, "to_protobuf", encode_with_strict_failure) + mocker.patch.object(_Context, "to_protobuf", encode_with_strict_failure) - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=failing_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(failing_task) first_request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -4264,8 +3768,8 @@ def encode_with_strict_failure(self, *args, **kwargs): responses = [r async for r in stream] # Restore so the gRPC 
fixture's teardown does not - # explode on Context.to_protobuf calls during cleanup. - mocker.patch.object(Context, "to_protobuf", original_to_protobuf) + # explode on ChainManifest.to_protobuf calls during cleanup. + mocker.patch.object(_Context, "to_protobuf", original_encode) # Assert ack, terminal = responses @@ -4274,49 +3778,34 @@ def encode_with_strict_failure(self, *args, **kwargs): shipped = cloudpickle.loads(terminal.exception.dump) # The routine's exception type is preserved — caller's - # existing ``except _RoutineFailure`` continues to catch. - assert isinstance(shipped, _RoutineFailure), ( + # existing ``except ValueError`` continues to catch. + assert isinstance(shipped, ValueError), ( f"wire must ship the routine's exception type bare, " f"not a wrapper group — observed {type(shipped).__name__}" ) assert shipped.args == ("primary signal",) - # PEP 678 notes carry the warnings as human-readable - # diagnostic — they show up in tracebacks naturally. - assert hasattr(shipped, "__notes__"), ( - "shipped exception must have __notes__ populated" - ) - notes_text = "\n".join(shipped.__notes__) - assert "var-1 unencodable" in notes_text, ( - "first ContextDecodeWarning must appear in " - f"__notes__; observed: {shipped.__notes__}" - ) - assert "var-2 unencodable" in notes_text, ( - "second ContextDecodeWarning must appear in " - f"__notes__; observed: {shipped.__notes__}" - ) - - # __wool_context_warnings__ provides structured access - # for programmatic inspection. - assert hasattr(shipped, "__wool_context_warnings__"), ( - "shipped exception must carry __wool_context_warnings__" + # The strict-mode encode failure rides on ``__cause__`` + # via ``raise from`` chaining. 
+ cause = shipped.__cause__ + assert isinstance(cause, wool.ChainSerializationError), ( + f"encode failure must appear on ``__cause__``; observed: " + f"{type(cause).__name__}" ) - warnings = shipped.__wool_context_warnings__ - assert len(warnings) == 2 - assert all(isinstance(w, ContextDecodeWarning) for w in warnings) - assert {str(w) for w in warnings} == { + warning_messages = {str(w) for w in cause.warnings} + assert warning_messages == { "var-1 unencodable", "var-2 unencodable", } @pytest.mark.asyncio - async def test_dispatch_with_backpressure_receiving_context( + async def test_dispatch_should_pass_backpressure_context_to_hook( self, grpc_aio_stub, mock_worker_proxy_cache ): """Test backpressure hook receives correct context. Given: - A :class:`WorkerService` with a backpressure hook that captures its argument + A `WorkerService` with a backpressure hook that captures its argument When: Dispatch RPC is called Then: @@ -4328,14 +3817,7 @@ async def test_dispatch_with_backpressure_receiving_context( async def sample_task(): return "result" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) captured = [] @@ -4361,13 +3843,13 @@ def hook(ctx): assert ctx.task.id == wool_task.id @pytest.mark.asyncio - async def test_dispatch_with_backpressure_and_active_tasks( + async def test_dispatch_should_report_active_task_count_to_backpressure_hook( self, grpc_aio_stub, mock_worker_proxy_cache ): """Test backpressure hook sees correct active task count. 
Given: - A :class:`WorkerService` with one active task already dispatched + A `WorkerService` with one active task already dispatched When: A second dispatch RPC is called with a backpressure hook Then: @@ -4377,27 +3859,13 @@ async def test_dispatch_with_backpressure_and_active_tasks( global _control_event _control_event = threading.Event() - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - - first_task = Task( - id=uuid4(), - callable=_controllable_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + first_task = make_task(_controllable_task) first_request = protocol.Request(task=first_task.to_protobuf()) async def second_fn(): return "second" - second_task = Task( - id=uuid4(), - callable=second_fn, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + second_task = make_task(second_fn) second_request = protocol.Request(task=second_task.to_protobuf()) captured_count = [] @@ -4438,13 +3906,13 @@ def hook(ctx): assert captured_count == [0, 1] @pytest.mark.asyncio - async def test_dispatch_with_backpressure_hook_raising_exception( + async def test_dispatch_should_propagate_grpc_failure_when_backpressure_hook_raises( self, grpc_aio_stub, mock_worker_proxy_cache ): """Test dispatch surfaces error when backpressure hook raises. 
Given: - A :class:`WorkerService` with a backpressure hook that + A `WorkerService` with a backpressure hook that raises RuntimeError When: Dispatch RPC is called @@ -4456,14 +3924,7 @@ async def test_dispatch_with_backpressure_hook_raising_exception( async def sample_task(): return "should_not_reach" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) def hook(ctx): @@ -4481,21 +3942,25 @@ def hook(ctx): pass @pytest.mark.asyncio - async def test_dispatch_with_caller_vars_for_coroutine_task( + async def test_dispatch_should_apply_caller_vars_before_running_coroutine( self, grpc_aio_stub, mock_worker_proxy_cache ): """Test dispatch applies caller-side vars before running a coroutine. Given: - A Request carrying a non-empty ``vars`` map (a - wool.ContextVar set on the caller, serialized via _dumps) - and a coroutine Task that reads the var + A coroutine Task plus a first mid-stream + `NextRequest` carrying a non-empty ``vars`` map (a + ``wool.ContextVar`` set on the caller, serialized via + cloudpickle). Under the per-frame architecture the + initial `TaskRequest` ships only dispatch + metadata; per-step manifests ride on subsequent + mid-stream frames. When: The dispatch RPC is invoked end-to-end via the gRPC stub Then: The Response's result equals the caller-side var value — - the worker applied the wire-shipped snapshot before running - the task. + the worker mounted the wire-shipped manifest before + running the task. 
""" # Arrange var = wool.ContextVar("srv001_caller_var", namespace="test_srv_vars") @@ -4503,32 +3968,28 @@ async def test_dispatch_with_caller_vars_for_coroutine_task( async def reader_task(): return var.get() - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=reader_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(reader_task) - request = protocol.Request( - task=wool_task.to_protobuf(), - context=protocol.Context( + task_request = protocol.Request(task=wool_task.to_protobuf()) + next_request = protocol.Request( + next=protocol.Void(), + context=protocol.ChainManifest( + id=uuid4().hex, vars=[ protocol.ContextVar( namespace=var.namespace, name=var.name, value=cloudpickle.dumps("caller-side-value"), ) - ] + ], ), ) # Act async with grpc_aio_stub() as stub: stream = stub.dispatch() - await stream.write(request) + await stream.write(task_request) + await stream.write(next_request) await stream.done_writing() responses = [r async for r in stream] @@ -4539,108 +4000,27 @@ async def reader_task(): assert cloudpickle.loads(response.result.dump) == "caller-side-value" @pytest.mark.asyncio - async def test_dispatch_with_per_frame_caller_vars_for_async_gen( - self, grpc_aio_stub, mock_worker_proxy_cache - ): - """Test streaming dispatch applies per-frame vars before each asend. - - Given: - An async-generator Task whose frames yield the current - value of a caller-side wool.ContextVar, with subsequent - next Requests carrying updated ``vars`` maps each iteration - When: - The stream is iterated with changing ``vars`` on each frame - Then: - Each response's result reflects the per-frame ``vars`` - applied on the worker — forward-propagation is honored at - every streaming frame, not just the first. 
- """ - # Arrange - var = wool.ContextVar("srv002_frame_var", namespace="test_srv_vars") - - async def streaming_task(): - while True: - yield var.get() - - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=streaming_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) - - initial_request = protocol.Request( - task=wool_task.to_protobuf(), - context=protocol.Context( - vars=[ - protocol.ContextVar( - namespace=var.namespace, - name=var.name, - value=cloudpickle.dumps("first"), - ) - ] - ), - ) - - # Act - async with grpc_aio_stub() as stub: - stream = stub.dispatch() - await stream.write(initial_request) - - response = await anext(aiter(stream)) - assert response.HasField("ack") - - results: list = [] - for frame_value in ("first", "second", "third"): - await stream.write( - protocol.Request( - next=protocol.Void(), - context=protocol.Context( - vars=[ - protocol.ContextVar( - namespace=var.namespace, - name=var.name, - value=cloudpickle.dumps(frame_value), - ) - ], - ), - ) - ) - response = await anext(aiter(stream)) - assert response.HasField("result") - results.append(cloudpickle.loads(response.result.dump)) - - await stream.done_writing() - async for _ in stream: - pass - - # Assert - assert results == ["first", "second", "third"] - - @pytest.mark.asyncio - async def test_dispatch_with_routine_raising_cancelled_during_aclose( + async def test_dispatch_should_end_cleanly_when_routine_raises_cancelled_during_aclose( # noqa: E501 self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService` streaming dispatch ends + """Test `WorkerService` streaming dispatch ends cleanly when the routine raises CancelledError during ``aclose`` on a natural-end iteration. 
- :func:`routine_scope` propagates aclose-time exceptions + `routine_scope` propagates aclose-time exceptions (matching stdlib ``await agen.aclose()`` semantics — see the unit tests in ``tests/runtime/routine/test_task.py`` for direct coverage). For natural-end iteration, the consumer's ``_iterate`` has already returned by the time the worker - runs aclose, and :meth:`drain` swallows the worker-side - :class:`asyncio.CancelledError` when the dispatch task + runs aclose, and `drain` swallows the worker-side + `asyncio.CancelledError` when the dispatch task itself isn't being cancelled. Net wire-level result: clean stream end with no terminal exception response. Given: A streaming async-generator routine whose teardown - handler catches :class:`GeneratorExit` during aclose - and re-raises as :class:`asyncio.CancelledError`, plus + handler catches `GeneratorExit` during aclose + and re-raises as `asyncio.CancelledError`, plus a dispatch that ends naturally (caller closes the stream without invoking service.stop). When: @@ -4661,14 +4041,7 @@ async def teardown_cancelling_generator(): except GeneratorExit: raise asyncio.CancelledError() from None - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=teardown_cancelling_generator, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(teardown_cancelling_generator) request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -4697,7 +4070,7 @@ async def teardown_cancelling_generator(): assert remaining == [] @pytest.mark.asyncio - async def test_dispatch_streaming_applies_var_updates_per_frame( + async def test_dispatch_should_apply_caller_var_updates_per_frame_when_streaming( self, grpc_aio_stub, mock_worker_proxy_cache ): """Test a streaming dispatch applies caller var updates per frame. 
@@ -4718,25 +4091,23 @@ async def streaming_task(): while True: yield var.get() - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=streaming_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(streaming_task) + # Fix the chain id across every wire frame so the worker arms + # onto it on the initial mount and subsequent updates route to + # the same per-chain cached context keyed by that chain id. + caller_chain = uuid4() initial_request = protocol.Request( task=wool_task.to_protobuf(), - context=protocol.Context( + context=protocol.ChainManifest( + id=caller_chain.hex, vars=[ protocol.ContextVar( namespace=var.namespace, name=var.name, value=cloudpickle.dumps("alpha"), ) - ] + ], ), ) @@ -4753,7 +4124,8 @@ async def streaming_task(): await stream.write( protocol.Request( next=protocol.Void(), - context=protocol.Context( + context=protocol.ChainManifest( + id=caller_chain.hex, vars=[ protocol.ContextVar( namespace=var.namespace, @@ -4776,57 +4148,54 @@ async def streaming_task(): assert results == ["alpha", "bravo", "charlie"] @pytest.mark.asyncio - async def test_dispatch_attaches_strict_mode_warnings_for_single_peer( + async def test_dispatch_should_ship_strict_mode_encode_error_when_routine_succeeds( self, grpc_aio_stub, mock_worker_proxy_cache, mocker: MockerFixture ): - """Test :class:`WorkerService.dispatch` attaches a single bare - :class:`ContextDecodeWarning` to the routine's exception via - ``__notes__`` and ``__wool_context_warnings__``. + """Test `WorkerService.dispatch` ships a strict-mode + `wool.ChainSerializationError` as the terminal exception + when the routine succeeded but the per-step encode failed. Implementation note: the routine itself returns ``"ok"``, but - :meth:`Context.to_protobuf` is patched to raise on every - call. 
The per-step encode (which runs inside ``_step`` to - build the success :class:`_Response`) therefore raises the - warning, which routes through ``DispatchSession`` and surfaces - in :meth:`WorkerService.dispatch`'s terminal-exception clause - — the same code path that attaches strict-mode warnings as - ``__notes__`` / ``__wool_context_warnings__`` on the - exception before serializing it back to the caller. + `ChainManifest.to_protobuf` is patched to raise on every call. + The per-step encode (which runs inside ``_step`` to build + the success `_Response`) therefore raises a + `wool.ChainSerializationError`, which propagates through + `DispatchSession` and surfaces in + `WorkerService.dispatch`'s terminal-exception clause. + The worker arms its encode path by having the routine touch + a `wool.ContextVar`. Given: - A coroutine routine AND :meth:`Context.to_protobuf` patched - to raise a single bare :class:`ContextDecodeWarning` (not - a group) on every call + A coroutine routine that touches a wool ContextVar AND + `ChainManifest.to_protobuf` patched to raise a + `wool.ChainSerializationError` on every call. When: - The dispatch RPC ships its terminal-exception Response + The dispatch RPC ships its terminal-exception Response. Then: - It should ship an exception payload with ``__notes__`` - containing the single warning and - ``__wool_context_warnings__`` of length 1. + It should ship the `wool.ChainSerializationError` + directly as the routine's terminal failure — the + "fail loud" contract means the routine's success value + is dropped because the post-run chain manifest didn't ship. 
""" - from wool.runtime.context import Context - from wool.runtime.context import ContextDecodeWarning + from wool.runtime.context.manifest import ChainManifest as _Context # Arrange + namespace = f"single_encode_{uuid4().hex}" + arm_var: wool.ContextVar[str] = wool.ContextVar("arm", namespace=namespace) + async def succeeding_task(): + arm_var.set("arm") return "ok" - original_to_protobuf = Context.to_protobuf - single_warning = ContextDecodeWarning("single bare peer") + original_session_encode = _Context.to_protobuf + single_warning = wool.SerializationWarning("single bare peer") def encode_with_single_failure(self, *args, **kwargs): - raise single_warning + raise wool.ChainSerializationError(single_warning) - mocker.patch.object(Context, "to_protobuf", encode_with_single_failure) + mocker.patch.object(_Context, "to_protobuf", encode_with_single_failure) - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=succeeding_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(succeeding_task) request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -4837,86 +4206,84 @@ def encode_with_single_failure(self, *args, **kwargs): responses = [r async for r in stream] # Restore so the gRPC fixture's teardown does not explode on - # subsequent Context.to_protobuf calls during cleanup. - mocker.patch.object(Context, "to_protobuf", original_to_protobuf) + # subsequent ChainManifest.to_protobuf calls during cleanup. 
+ mocker.patch.object(_Context, "to_protobuf", original_session_encode) # Assert ack, terminal = responses assert ack.HasField("ack") assert terminal.HasField("exception") shipped = cloudpickle.loads(terminal.exception.dump) - assert hasattr(shipped, "__notes__") - notes_text = "\n".join(shipped.__notes__) - assert "single bare peer" in notes_text - assert hasattr(shipped, "__wool_context_warnings__") - warnings = shipped.__wool_context_warnings__ - assert len(warnings) == 1 - assert isinstance(warnings[0], ContextDecodeWarning) - assert str(warnings[0]) == "single bare peer" + assert isinstance(shipped, wool.ChainSerializationError) + assert len(shipped.warnings) == 1 + assert str(shipped.warnings[0]) == "single bare peer" @pytest.mark.asyncio - async def test_dispatch_attaches_strict_mode_warnings_on_async_generator_path( + async def test_dispatch_should_chain_encode_failure_as_cause_when_streaming( self, grpc_aio_stub, mock_worker_proxy_cache, mocker: MockerFixture ): - """Test :class:`WorkerService.dispatch` attaches strict-mode - warning peers to the routine's exception on the - async-generator path. + """Test `WorkerService.dispatch` chains the + strict-mode encode error onto the routine's exception via + ``__cause__`` on the async-generator path. Given: - An async-generator routine that raises a custom exception - mid-stream AND :meth:`Context.to_protobuf` patched to - raise a :class:`BaseExceptionGroup` of two - :class:`ContextDecodeWarning` peers + An async-generator routine that touches a wool + ContextVar (arming the worker) and raises a custom + exception mid-stream, AND `ChainManifest.to_protobuf` + patched to raise `wool.ChainSerializationError` on + the post-run encode. When: The dispatch RPC ships its terminal-exception Response - after one successful yield + after one successful yield. Then: It should ship the routine's exception type bare with - both warnings on ``__notes__`` and - ``__wool_context_warnings__`` of length 2. 
+ the `wool.ChainSerializationError` chained on + ``__cause__``. """ - from wool.runtime.context import Context - from wool.runtime.context import ContextDecodeWarning + from wool.runtime.context.manifest import ChainManifest as _Context # Arrange - class _RoutineFailure(Exception): - pass + namespace = f"agen_chain_{uuid4().hex}" + arm_var: wool.ContextVar[str] = wool.ContextVar("arm", namespace=namespace) async def streamer(): + arm_var.set("arm") # arm the worker so encode runs yield "first" - raise _RoutineFailure("mid-stream signal") + # Stdlib exception so tblib's reduce preserves the + # __cause__ chain across the wire. + raise RuntimeError("mid-stream signal") - original_to_protobuf = Context.to_protobuf + original_session_encode = _Context.to_protobuf # Let the first per-yield encode succeed so the streamer # delivers ``"first"`` over the wire; subsequent invocations - # (including the dispatch handler's terminal-exception - # snapshot) raise the strict-mode encode group. + # (including the worker driver's terminal-exception + # context) raise the strict-mode ChainSerializationError. 
call_count = {"n": 0} def encode_with_strict_failure(self, *args, **kwargs): call_count["n"] += 1 if call_count["n"] == 1: - return original_to_protobuf(self, *args, **kwargs) - raise BaseExceptionGroup( - "strict-mode encode failure", - [ - ContextDecodeWarning("agen-peer-1"), - ContextDecodeWarning("agen-peer-2"), - ], + return original_session_encode(self, *args, **kwargs) + raise wool.ChainSerializationError( + wool.SerializationWarning("agen-peer-1"), + wool.SerializationWarning("agen-peer-2"), ) - mocker.patch.object(Context, "to_protobuf", encode_with_strict_failure) + mocker.patch.object(_Context, "to_protobuf", encode_with_strict_failure) - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=streamer, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(streamer) first_request = protocol.Request(task=wool_task.to_protobuf()) - next_request = protocol.Request(next=protocol.Void()) + # Caller-supplied chain id propagated on every mid-stream + # frame — under the per-frame architecture the worker keys + # its cached `contextvars.Context` registry by chain + # id, so omitting the manifest on subsequent NextRequests + # would allocate a fresh ``copy_context()`` per iteration + # and lose the streamer's prior ``arm_var.set`` (which is + # exactly what arms the chain so the post-step encode runs). + chain_id = uuid4().hex + next_request = protocol.Request( + next=protocol.Void(), context=protocol.ChainManifest(id=chain_id) + ) # Act async with grpc_aio_stub() as stub: @@ -4935,38 +4302,35 @@ def encode_with_strict_failure(self, *args, **kwargs): remaining = [r async for r in stream] # Restore so the gRPC fixture's teardown does not explode on - # subsequent Context.to_protobuf calls during cleanup. - mocker.patch.object(Context, "to_protobuf", original_to_protobuf) + # subsequent ChainManifest.to_protobuf calls during cleanup. 
+ mocker.patch.object(_Context, "to_protobuf", original_session_encode) # Assert terminals = [r for r in remaining if r.HasField("exception")] assert len(terminals) == 1 shipped = cloudpickle.loads(terminals[0].exception.dump) - assert isinstance(shipped, _RoutineFailure) - assert hasattr(shipped, "__notes__") - notes_text = "\n".join(shipped.__notes__) - assert "agen-peer-1" in notes_text - assert "agen-peer-2" in notes_text - warnings = shipped.__wool_context_warnings__ - assert len(warnings) == 2 - assert all(isinstance(w, ContextDecodeWarning) for w in warnings) + assert isinstance(shipped, RuntimeError) + cause = shipped.__cause__ + assert isinstance(cause, wool.ChainSerializationError) + warning_messages = {str(w) for w in cause.warnings} + assert warning_messages == {"agen-peer-1", "agen-peer-2"} @pytest.mark.asyncio - async def test_dispatch_does_not_unwrap_runtime_error_with_unrelated_cause( + async def test_dispatch_should_ship_runtime_error_raw_when_cause_unrelated( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService.dispatch` ships - :class:`RuntimeError` raw when its ``__cause__`` is not a - :class:`StopAsyncIteration`. + """Test `WorkerService.dispatch` ships + `RuntimeError` raw when its ``__cause__`` is not a + `StopAsyncIteration`. Given: A coroutine routine raising ``RuntimeError("not async iter")`` whose ``__cause__`` is - NOT a :class:`StopAsyncIteration` + NOT a `StopAsyncIteration` When: The dispatch RPC ships its terminal-exception Response Then: - It should ship the :class:`RuntimeError` raw (not + It should ship the `RuntimeError` raw (not unwrapped to ``__cause__``). 
""" @@ -4977,14 +4341,7 @@ async def raising_task(): except ValueError as cause: raise RuntimeError("not async iter") from cause - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=raising_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(raising_task) request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -5004,24 +4361,24 @@ async def raising_task(): assert "not async iter" in str(shipped) @pytest.mark.asyncio - async def test_dispatch_does_not_unwrap_runtime_error_for_async_generator( + async def test_dispatch_should_ship_runtime_error_raw_when_generator_cause_stop_async_iteration( # noqa: E501 self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService.dispatch` keeps - :class:`RuntimeError` un-unwrapped on the async-generator + """Test `WorkerService.dispatch` keeps + `RuntimeError` un-unwrapped on the async-generator path even when ``__cause__`` is a - :class:`StopAsyncIteration`. + `StopAsyncIteration`. Given: An async-generator routine that raises a PEP 525 - :class:`RuntimeError` whose ``__cause__`` is a - :class:`StopAsyncIteration` + `RuntimeError` whose ``__cause__`` is a + `StopAsyncIteration` When: The dispatch RPC ships its terminal-exception Response Then: - It should ship the :class:`RuntimeError` un-unwrapped — - the wire payload is :class:`RuntimeError`, not - :class:`StopAsyncIteration`. + It should ship the `RuntimeError` un-unwrapped — + the wire payload is `RuntimeError`, not + `StopAsyncIteration`. 
""" # Arrange @@ -5029,14 +4386,7 @@ async def agen_raising_sai(): raise StopAsyncIteration("from agen") yield 1 - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=agen_raising_sai, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(agen_raising_sai) first_request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -5059,10 +4409,10 @@ async def agen_raising_sai(): assert not isinstance(shipped, StopAsyncIteration) @pytest.mark.asyncio - async def test_dispatch_drains_handler_on_terminal_exception_path_for_coroutine( + async def test_dispatch_should_drain_handler_on_terminal_exception_when_coroutine( self, grpc_aio_stub, mock_worker_proxy_cache, mocker: MockerFixture ): - """Test :class:`WorkerService.dispatch` drains the handler on + """Test `WorkerService.dispatch` drains the handler on the coroutine terminal-exception path before yielding the terminal Response. @@ -5074,35 +4424,34 @@ async def test_dispatch_drains_handler_on_terminal_exception_path_for_coroutine( When: The dispatch RPC reaches its terminal-exception clause Then: - It should call :meth:`DispatchSession.drain` at least twice + It should call `DispatchSession.drain` at least twice (verified via spy) before yielding the terminal Response. 
""" - from wool.runtime.worker import session as handler_module + from wool.runtime.worker.frame import Frame as _Frame + from wool.runtime.worker.frame import ResultResponseFrame as _ResultResponseFrame # Arrange async def succeeding_coroutine(): return "value" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=succeeding_coroutine, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(succeeding_coroutine) request = protocol.Request(task=wool_task.to_protobuf()) drain_spy = mocker.spy(DispatchSession, "drain") - def failing_to_protobuf(self, *, serializer): - raise RuntimeError("synthetic dump failure") + # Patch the result-frame encode to raise — the dispatch handler + # routes the result-bearing `ResultResponseFrame` from + # ``_step`` through ``frame.to_protobuf()`` to ship it, and a + # raise there exercises the terminal-exception clause path + # that re-encodes via `ExceptionResponseFrame.for_send`. + original_to_protobuf = _Frame.to_protobuf - mocker.patch.object( - handler_module._Response, - "to_protobuf", - failing_to_protobuf, - ) + def failing_to_protobuf(self): + if isinstance(self, _ResultResponseFrame): + raise RuntimeError("synthetic dump failure") + return original_to_protobuf(self) + + mocker.patch.object(_Frame, "to_protobuf", failing_to_protobuf) # Act async with grpc_aio_stub() as stub: @@ -5117,20 +4466,20 @@ def failing_to_protobuf(self, *, serializer): assert terminal.HasField("exception") assert drain_spy.call_count >= 2, ( f"Expected dispatch's terminal-exception clause to call " - f"handler.drain() before snapshotting handler.context " + f"handler.drain() before reading handler's final wire context " f"(plus __aexit__'s call); observed {drain_spy.call_count} " f"call(s)." 
) @pytest.mark.asyncio - async def test_dispatch_skips_backpressure_evaluation_when_no_hook( + async def test_dispatch_should_skip_backpressure_evaluation_when_no_hook( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService.dispatch` skips backpressure + """Test `WorkerService.dispatch` skips backpressure evaluation entirely when no hook is configured. Given: - A :class:`WorkerService` with no backpressure hook + A `WorkerService` with no backpressure hook When: ``dispatch`` is invoked with a normally-completing coroutine task @@ -5143,14 +4492,7 @@ async def test_dispatch_skips_backpressure_evaluation_when_no_hook( async def sample_task(): return "no_hook_result" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) service = WorkerService() assert service._backpressure is None # sanity: no hook configured @@ -5170,10 +4512,10 @@ async def sample_task(): assert cloudpickle.loads(result.result.dump) == "no_hook_result" @pytest.mark.asyncio - async def test_dispatch_with_rejecting_backpressure_leaves_no_docket_entry( + async def test_dispatch_should_leave_no_docket_entry_when_backpressure_rejects( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService.dispatch` leaves no docket entry + """Test `WorkerService.dispatch` leaves no docket entry when backpressure rejects the task. 
Given: @@ -5192,14 +4534,7 @@ async def test_dispatch_with_rejecting_backpressure_leaves_no_docket_entry( async def sample_task(): return "should_not_reach" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) def hook(ctx): @@ -5229,13 +4564,13 @@ def hook(ctx): assert service.stopping.is_set() assert service.stopped.is_set() - def test_stopping_and_stopped_reflect_lifecycle(self): - """Test :attr:`WorkerService.stopping` and - :attr:`WorkerService.stopped` reflect the service lifecycle + def test_stopping_should_reflect_lifecycle_via_is_set(self): + """Test `WorkerService.stopping` and + `WorkerService.stopped` reflect the service lifecycle through their ``is_set()`` accessor. Given: - A new :class:`WorkerService` and accesses to the + A new `WorkerService` and accesses to the ``stopping`` and ``stopped`` properties When: The properties are read pre/post-stop @@ -5261,19 +4596,19 @@ async def _drive_stop(): assert service.stopping.is_set() is True assert service.stopped.is_set() is True - def test_stopping_wrapper_does_not_expose_mutators(self): - """Test the :attr:`WorkerService.stopping` wrapper exposes + def test_stopping_should_raise_attribute_error_when_mutators_called(self): + """Test the `WorkerService.stopping` wrapper exposes only read access — calling mutators raises - :class:`AttributeError`. + `AttributeError`. Given: - A :class:`WorkerService` whose ``stopping`` accessor is + A `WorkerService` whose ``stopping`` accessor is exposed When: The caller attempts to call ``.set()`` or ``.clear()`` on the returned wrapper Then: - It should raise :class:`AttributeError` (the read-only + It should raise `AttributeError` (the read-only wrapper does not expose mutators). 
""" # Arrange @@ -5287,15 +4622,17 @@ def test_stopping_wrapper_does_not_expose_mutators(self): wrapper.clear() @pytest.mark.asyncio - async def test_stop_with_empty_docket_completes_cleanly(self, grpc_aio_stub): - """Test :meth:`WorkerService.stop` with ``timeout=0`` and an + async def test_stop_should_complete_cleanly_when_docket_empty_and_no_proxy_pool( + self, grpc_aio_stub + ): + """Test `WorkerService.stop` with ``timeout=0`` and an empty docket completes cleanly without a configured proxy pool. Given: - A :class:`WorkerService.stop` invocation with + A `WorkerService.stop` invocation with ``timeout=0`` while the docket is empty AND - :data:`wool.__proxy_pool__` is unset + `wool.__proxy_pool__` is unset When: The stop RPC is invoked Then: @@ -5308,14 +4645,7 @@ async def test_stop_with_empty_docket_completes_cleanly(self, grpc_aio_stub): async def sample_task(): return "should_not_reach" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) service = WorkerService() # Sanity: the autouse _clear_proxy_context fixture leaves the @@ -5343,15 +4673,15 @@ async def sample_task(): assert exc_info.value.code() == StatusCode.UNAVAILABLE @pytest.mark.asyncio - async def test_stop_clears_loop_pool_when_proxy_pool_clear_raises( + async def test_stop_should_set_stopped_when_proxy_pool_clear_raises( self, grpc_aio_stub, mocker: MockerFixture ): - """Test :meth:`WorkerService.stop` still sets - :attr:`stopped` when the proxy-pool's ``clear`` coroutine + """Test `WorkerService.stop` still sets + `stopped` when the proxy-pool's ``clear`` coroutine raises — the loop-pool clear runs in the ``finally``. 
Given: - A :class:`WorkerService.stop` invocation while the + A `WorkerService.stop` invocation while the proxy-pool ``clear`` coroutine raises When: The stop RPC is invoked @@ -5393,11 +4723,11 @@ async def test_stop_clears_loop_pool_when_proxy_pool_clear_raises( wool.__proxy_pool__.reset(token) @pytest.mark.asyncio - async def test_dispatch_ships_synthesized_runtime_error_for_unpicklable_routine_exception( # noqa: E501 + async def test_dispatch_should_ship_synthesized_runtime_error_when_routine_exception_unpicklable( # noqa: E501 self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService.dispatch` ships a synthesized - stdlib :class:`RuntimeError` for an un-picklable routine + """Test `WorkerService.dispatch` ships a synthesized + stdlib `RuntimeError` for an un-picklable routine exception. Given: @@ -5405,7 +4735,7 @@ async def test_dispatch_ships_synthesized_runtime_error_for_unpicklable_routine_ When: The dispatch RPC ships its terminal-exception Response Then: - It should ship a synthesized stdlib :class:`RuntimeError` + It should ship a synthesized stdlib `RuntimeError` whose message names the original exception type and args. 
""" @@ -5417,14 +4747,7 @@ def __reduce__(self): async def raising_task(): raise _UnpicklableError("unpicklable signal") - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=raising_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(raising_task) request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -5444,11 +4767,11 @@ async def raising_task(): assert "unpicklable signal" in str(shipped) @pytest.mark.asyncio - async def test_dispatch_streaming_ships_synthesized_runtime_error_for_unpicklable_exception( # noqa: E501 + async def test_dispatch_should_ship_synthesized_runtime_error_when_streaming_exception_unpicklable( # noqa: E501 self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService.dispatch` ships a synthesized - :class:`RuntimeError` and a valid context snapshot when an + """Test `WorkerService.dispatch` ships a synthesized + `RuntimeError` and a valid context when an async-generator raises an un-picklable exception after one yield. @@ -5460,7 +4783,7 @@ async def test_dispatch_streaming_ships_synthesized_runtime_error_for_unpicklabl after the first result Then: The terminal Response carries a synthesized - :class:`RuntimeError` and a valid context snapshot. + `RuntimeError` and a valid context. 
""" # Arrange @@ -5472,14 +4795,7 @@ async def streamer(): yield "first" raise _UnpicklableError("agen unpicklable signal") - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=streamer, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(streamer) first_request = protocol.Request(task=wool_task.to_protobuf()) next_request = protocol.Request(next=protocol.Void()) @@ -5503,18 +4819,20 @@ async def streamer(): terminals = [r for r in remaining if r.HasField("exception")] assert len(terminals) == 1 terminal = terminals[0] - assert terminal.HasField("context") + # Lazy-wire-frame: the unarmed worker omits the context + # field — the routine never touched a wool.ContextVar. + assert not terminal.HasField("context") shipped = cloudpickle.loads(terminal.exception.dump) assert type(shipped) is RuntimeError assert "_UnpicklableError" in str(shipped) assert "agen unpicklable signal" in str(shipped) @pytest.mark.asyncio - async def test_dispatch_with_unpicklable_exception_whose_str_raises( + async def test_dispatch_should_ship_class_name_only_when_exception_unpicklable_and_str_raises( # noqa: E501 self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService.dispatch` ships a synthesized - :class:`RuntimeError` containing only the exception class name + """Test `WorkerService.dispatch` ships a synthesized + `RuntimeError` containing only the exception class name when the routine's exception is un-picklable AND its ``__str__`` raises. 
@@ -5525,7 +4843,7 @@ async def test_dispatch_with_unpicklable_exception_whose_str_raises( When: The dispatch RPC ships its terminal-exception Response Then: - It should ship a synthesized :class:`RuntimeError` whose + It should ship a synthesized `RuntimeError` whose message carries only the class name (the message-with-args f-string fallback short-circuited because ``__str__`` raised), exercising the defensive @@ -5551,14 +4869,7 @@ def __str__(self): async def raising_task(): raise _BadStrError(_Unpicklable()) - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=raising_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(raising_task) request = protocol.Request(task=wool_task.to_protobuf()) # Act @@ -5584,14 +4895,14 @@ async def raising_task(): assert shipped.args == ("_BadStrError",) @pytest.mark.asyncio - async def test_backpressure_with_truthy_non_bool_return_rejects( + async def test_backpressure_should_abort_when_hook_returns_truthy_non_bool( self, grpc_aio_stub, mock_worker_proxy_cache ): - """Test :class:`WorkerService.dispatch` rejects the task when + """Test `WorkerService.dispatch` rejects the task when the backpressure hook returns a truthy non-bool value. 
Given: - A :class:`BackpressureLike` hook that returns a non-bool + A `BackpressureLike` hook that returns a non-bool truthy value (e.g., a non-empty string) When: The dispatch RPC is invoked @@ -5603,14 +4914,7 @@ async def test_backpressure_with_truthy_non_bool_return_rejects( async def sample_task(): return "should_not_reach" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) def truthy_string_hook(ctx): @@ -5629,19 +4933,19 @@ def truthy_string_hook(ctx): assert exc_info.value.code() == StatusCode.RESOURCE_EXHAUSTED @pytest.mark.asyncio - async def test_dispatch_rejects_empty_request_stream( + async def test_dispatch_should_nack_with_value_error_when_request_stream_empty( self, grpc_aio_stub, mock_worker_proxy_cache, mocker: MockerFixture ): - """Test :class:`WorkerService.dispatch` replies with a single - :class:`Nack` when the request stream is empty. + """Test `WorkerService.dispatch` replies with a single + `Nack` when the request stream is empty. Given: A dispatch call whose request stream yields no frames When: The dispatch RPC is consumed Then: - It should respond with a single :class:`Nack` whose - exception decodes to a :class:`ValueError` naming the + It should respond with a single `Nack` whose + exception decodes to a `ValueError` naming the empty-stream rejection. 
""" @@ -5671,11 +4975,11 @@ async def passthrough(self, continuation, handler_call_details): assert "empty" in str(raised).lower() @pytest.mark.asyncio - async def test_dispatch_rejects_first_frame_with_wrong_oneof( + async def test_dispatch_should_nack_with_value_error_when_first_frame_wrong_oneof( self, grpc_aio_stub, mock_worker_proxy_cache, mocker: MockerFixture ): - """Test :class:`WorkerService.dispatch` replies with a single - :class:`Nack` when the first frame's payload is the wrong + """Test `WorkerService.dispatch` replies with a single + `Nack` when the first frame's payload is the wrong oneof variant. Given: @@ -5684,8 +4988,8 @@ async def test_dispatch_rejects_first_frame_with_wrong_oneof( When: The dispatch RPC is consumed Then: - It should respond with a single :class:`Nack` whose - exception decodes to a :class:`ValueError` naming the + It should respond with a single `Nack` whose + exception decodes to a `ValueError` naming the payload oneof violation. """ @@ -5719,22 +5023,22 @@ async def passthrough(self, continuation, handler_call_details): assert "payload" in message and "task" in message @pytest.mark.asyncio - async def test_dispatch_nack_with_unpicklable_rejected_original( + async def test_dispatch_should_ship_synthesized_runtime_error_in_nack_when_rejected_unpicklable( # noqa: E501 self, grpc_aio_stub, mock_worker_proxy_cache, mocker: MockerFixture ): - """Test :class:`WorkerService.dispatch` ships a synthesized - :class:`RuntimeError` for the Nack ``exception`` payload when - :attr:`Rejected.original` is itself un-picklable. + """Test `WorkerService.dispatch` ships a synthesized + `RuntimeError` for the Nack ``exception`` payload when + `Rejected.original` is itself un-picklable. 
Given: - A :class:`WorkerService.dispatch` whose - :attr:`Rejected.original` is itself an un-picklable + A `WorkerService.dispatch` whose + `Rejected.original` is itself an un-picklable exception When: - The dispatch RPC ships its :class:`Nack` Response + The dispatch RPC ships its `Nack` Response Then: The ``Nack.exception`` decodes to the synthesized stdlib - :class:`RuntimeError`. + `RuntimeError`. """ from wool.runtime.worker import session as handler_module @@ -5746,14 +5050,7 @@ def __reduce__(self): async def sample_task(): return "should_not_reach" - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") - wool_task = Task( - id=uuid4(), - callable=sample_task, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + wool_task = make_task(sample_task) request = protocol.Request(task=wool_task.to_protobuf()) original_aenter = handler_module.DispatchSession.__aenter__ @@ -5790,34 +5087,27 @@ async def failing_aenter(self): class TestBackpressureContext: - """Tests for :class:`wool.runtime.worker.service.BackpressureContext`.""" + """Tests for `wool.runtime.worker.service.BackpressureContext`.""" - def test_backpressure_context_is_frozen(self): - """Test :class:`BackpressureContext` rejects mutation of its + def test_backpressure_context_should_reject_mutation_after_construction(self): + """Test `BackpressureContext` rejects mutation of its fields after construction. Given: - A :class:`BackpressureContext` instance with assigned + A `BackpressureContext` instance with assigned ``active_task_count`` and ``task`` When: The caller attempts to mutate ``ctx.active_task_count = 5`` Then: - It should raise :class:`dataclasses.FrozenInstanceError`. + It should raise `dataclasses.FrozenInstanceError`. 
""" import dataclasses from wool.runtime.worker.service import BackpressureContext # Arrange - mock_proxy = PicklableMock(spec=WorkerProxyLike, id="frozen-test") - task = Task( - id=uuid4(), - callable=lambda: None, - args=(), - kwargs={}, - proxy=mock_proxy, - ) + task = make_task(lambda: None, proxy_id="frozen-test") ctx = BackpressureContext(active_task_count=0, task=task) # Act & assert @@ -5826,11 +5116,11 @@ def test_backpressure_context_is_frozen(self): class TestBackpressureLike: - """Tests for :class:`wool.runtime.worker.service.BackpressureLike`.""" + """Tests for `wool.runtime.worker.service.BackpressureLike`.""" - def test_backpressure_like_runtime_checkable(self): - """Test :class:`BackpressureLike` accepts callables and - rejects non-callables under :func:`isinstance`. + def test_backpressure_like_should_accept_callables_and_reject_non_callables(self): + """Test `BackpressureLike` accepts callables and + rejects non-callables under `isinstance`. Given: A sync callable ``def hook(ctx): ...``, an async callable diff --git a/wool/tests/runtime/worker/test_session.py b/wool/tests/runtime/worker/test_session.py index 65fac393..73d2651a 100644 --- a/wool/tests/runtime/worker/test_session.py +++ b/wool/tests/runtime/worker/test_session.py @@ -2,7 +2,7 @@ These tests exercise :class:`Rejected` and :class:`DispatchSession` through their public surface (constructor, ``async with``, ``async for``, -:meth:`drain`, :meth:`cancel`, public attributes ``task``/``context``/ +:meth:`drain`, :meth:`cancel`, public attributes ``task``/``decoded``/ ``serializer``). Worker-loop interactions use the ``new_event_loop + threading.Thread + install_task_factory`` pattern so the handler can drive a real :func:`scoped` routine across loops. 
@@ -22,8 +22,8 @@ from wool import protocol from wool.protocol import WorkerStub from wool.protocol import add_WorkerServicer_to_server -from wool.runtime.context import Context -from wool.runtime.context import install_task_factory +from wool.runtime.context.factory import install_task_factory +from wool.runtime.context.manifest import ChainManifest from wool.runtime.routine.task import Task from wool.runtime.routine.task import WorkerProxyLike from wool.runtime.worker.interceptor import VersionInterceptor @@ -83,6 +83,25 @@ async def _slow_coro(): return "never" +# Module-level Event so the routine and the test share one instance +# under cloudpickle transport — the routine is pickled by reference and +# resolves this global on the worker side (a closure-captured Event +# would not pickle; see ``_slow_coro``). The routine sets it on entering +# its blocking step so a cancel-mid-step test knows the worker is +# suspended *inside* ``_drive_step``'s per-step ``await`` before it +# cancels. +_STEP_BLOCKING = threading.Event() + + +async def _gen_blocks_in_step(): + """Async generator whose first step signals then blocks forever, so + a cancel preempts the in-flight per-step task rather than racing the + routine's completion.""" + _STEP_BLOCKING.set() + await asyncio.sleep(3600) + yield "never" + + async def _gen_yielding_unpicklable(): """Async generator that yields a non-cloudpickle-serializable object so the dispatch handler's :meth:`_Response.to_protobuf` @@ -108,6 +127,31 @@ def _sync_callable(): return "not_async" +# Unarmed-worker regression fixture — a module-level wool.ContextVar +# with a default plus a coroutine routine that offloads, via *plain* +# asyncio.to_thread, a function that reads it. When the routine is +# dispatched with no caller Wool state, the worker chain stays +# unarmed, so the offload runs as a plain contextvars context and the +# read returns the default instead of tripping ChainContention. 
+# Module-level so the routine and the var are picklable for cloudpickle +# transport. +_UNARMED_WORKER_VAR: wool.ContextVar[str] = wool.ContextVar( + "unarmed_worker_var", default="unset-default" +) + + +def _read_unarmed_worker_var() -> str: + """Read the unarmed-worker regression var — runs in the offload thread.""" + return _UNARMED_WORKER_VAR.get() + + +async def _coro_offloads_plain_to_thread(): + """Coroutine routine that offloads a wool.ContextVar read via + plain asyncio.to_thread. + """ + return await asyncio.to_thread(_read_unarmed_worker_var) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -130,9 +174,19 @@ def _request_for(task: Task) -> protocol.Request: return protocol.Request(task=task.to_protobuf()) -def _next_request() -> protocol.Request: - """Build a follow-up ``next`` request frame.""" - return protocol.Request(next=protocol.Void()) +def _next_request(*, with_context: bool = False) -> protocol.Request: + """Build a follow-up ``next`` request frame. + + When *with_context* is True, sets an empty (but present) + ``context`` field on the wire so tests that exercise the + chain-manifest decode path see a present-but-empty wire frame + rather than the absent-field path that lazy-wire-frame semantics + short-circuit. + """ + request = protocol.Request(next=protocol.Void()) + if with_context: + request.context.CopyFrom(protocol.ChainManifest()) + return request async def _stream(*requests): @@ -179,7 +233,7 @@ def grpc_stub_cls(): class TestRejected: """Tests for :class:`Rejected`.""" - def test___init___with_arbitrary_exception(self): + def test___init___should_expose_original_and_stringify(self): """Test :class:`Rejected` wraps an arbitrary exception. 
Given: @@ -200,7 +254,9 @@ def test___init___with_arbitrary_exception(self): assert rejected.original is original assert str(rejected) == "ValueError: malformed task id" - def test___init___with_custom_exception_subclass(self): + def test___init___should_preserve_subclass_name_and_str_when_custom_exception( + self, + ): """Test :class:`Rejected` preserves a custom subclass and ``str()``. Given: @@ -237,7 +293,7 @@ class TestDispatchSession: # -- construction ----------------------------------------------------- - def test___init___with_request_stream_and_worker_loop( + def test___init___should_leave_public_attrs_unset( self, worker_loop, mock_worker_proxy_cache ): """Test :class:`DispatchSession` constructor leaves public attrs unset. @@ -248,7 +304,7 @@ def test___init___with_request_stream_and_worker_loop( :class:`DispatchSession` is constructed Then: It should be created without raising and ``task`` / - ``context`` / ``serializer`` should be unset before + ``decoded`` / ``serializer`` should be unset before ``__aenter__``. """ # Arrange @@ -260,15 +316,15 @@ def test___init___with_request_stream_and_worker_loop( # Assert assert isinstance(handler, DispatchSession) - # Class-level annotations declare ``task`` and ``context``; + # Class-level annotations declare ``task`` and ``decoded``; # they are only populated on enter so accessing them before # enter raises AttributeError. with pytest.raises(AttributeError): handler.task # noqa: B018 with pytest.raises(AttributeError): - handler.context # noqa: B018 + handler.decoded # noqa: B018 # All dispatch serializes through cloudpickle; this one - # serializer covers the payload, the context snapshots, and + # serializer covers the payload, the context contexts, and # any ``Rejected`` dumped exception (including pre-parse # failures such as StopAsyncIteration or a malformed frame). 
assert handler.serializer is wool.__serializer__ @@ -276,7 +332,7 @@ def test___init___with_request_stream_and_worker_loop( # -- streaming property ---------------------------------------------- @pytest.mark.asyncio - async def test_streaming_with_async_generator_task( + async def test_streaming_should_return_true_when_async_generator_task( self, worker_loop, mock_worker_proxy_cache ): """Test :attr:`streaming` reflects an async-generator task. @@ -299,7 +355,7 @@ async def test_streaming_with_async_generator_task( assert handler.streaming is True @pytest.mark.asyncio - async def test_streaming_with_coroutine_task( + async def test_streaming_should_return_false_when_coroutine_task( self, worker_loop, mock_worker_proxy_cache ): """Test :attr:`streaming` reflects a coroutine task. @@ -324,7 +380,7 @@ async def test_streaming_with_coroutine_task( # -- __aenter__ ------------------------------------------------------- @pytest.mark.asyncio - async def test___aenter___with_coroutine_task_and_cloudpickle( + async def test___aenter___should_populate_public_attrs_when_coroutine_task( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aenter__` populates public attrs for a @@ -336,7 +392,7 @@ async def test___aenter___with_coroutine_task_and_cloudpickle( When: The handler is entered via ``async with`` Then: - It should populate ``handler.task``, ``handler.context``, + It should populate ``handler.task``, ``handler.decoded``, and set ``handler.serializer`` to ``wool.__serializer__``. 
""" # Arrange @@ -348,11 +404,11 @@ async def test___aenter___with_coroutine_task_and_cloudpickle( # Assert assert isinstance(handler.task, Task) assert handler.task.callable is _coro_returning_default - assert isinstance(handler.context, Context) + assert isinstance(handler.decoded, ChainManifest) assert handler.serializer is wool.__serializer__ @pytest.mark.asyncio - async def test___aenter___with_async_generator_task_and_cloudpickle( + async def test___aenter___should_populate_public_attrs_when_async_generator_task( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aenter__` populates public attrs for a @@ -360,12 +416,12 @@ async def test___aenter___with_async_generator_task_and_cloudpickle( Given: A request stream whose first frame is a valid - async-generator :class:`Task` serialized via - ``to_protobuf()`` with no wire ``serializer`` field + async-generator :class:`Task` with cloudpickle + serialization When: The handler is entered via ``async with`` Then: - It should populate ``handler.task`` and ``handler.context``, + It should populate ``handler.task`` and ``handler.decoded``, set ``handler.serializer`` to ``wool.__serializer__``, and mark ``handler.streaming`` as ``True``. """ @@ -378,12 +434,12 @@ async def test___aenter___with_async_generator_task_and_cloudpickle( # Assert assert isinstance(handler.task, Task) assert handler.task.callable is _gen_default_two - assert isinstance(handler.context, Context) + assert isinstance(handler.decoded, ChainManifest) assert handler.serializer is wool.__serializer__ assert handler.streaming is True @pytest.mark.asyncio - async def test___aenter___with_sync_callable_task( + async def test___aenter___should_raise_rejected_when_sync_callable( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aenter__` rejects a non-async callable. 
@@ -413,7 +469,7 @@ async def test___aenter___with_sync_callable_task( ) @pytest.mark.asyncio - async def test___aenter___with_malformed_task_id( + async def test___aenter___should_raise_rejected_when_malformed_task_id( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aenter__` rejects a malformed task id. @@ -443,7 +499,7 @@ async def test___aenter___with_malformed_task_id( assert isinstance(exc_info.value.original, ValueError) @pytest.mark.asyncio - async def test___aenter___with_empty_request_stream( + async def test___aenter___should_raise_rejected_when_empty_request_stream( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aenter__` rejects an empty request stream. @@ -471,7 +527,7 @@ async def empty_stream(): assert "empty request stream" in str(exc_info.value.original) @pytest.mark.asyncio - async def test___aenter___with_first_frame_wrong_oneof( + async def test___aenter___should_raise_rejected_when_first_frame_not_task( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aenter__` rejects a first frame whose payload @@ -499,7 +555,7 @@ async def test___aenter___with_first_frame_wrong_oneof( assert "first request must carry a Task" in str(exc_info.value.original) @pytest.mark.asyncio - async def test___aenter___with_cancelled_request_iterator( + async def test___aenter___should_propagate_cancelled_when_iterator_cancelled( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aenter__` propagates ``CancelledError`` raw. @@ -525,42 +581,49 @@ async def cancelling_stream(): pass @pytest.mark.asyncio - async def test___aenter___with_decode_group_of_exceptions( + async def test___aenter___should_wrap_context_decode_error_in_rejected( self, worker_loop, mock_worker_proxy_cache, mocker: MockerFixture, ): - """Test :meth:`__aenter__` wraps an Exception-only decode - :class:`BaseExceptionGroup` as :class:`Rejected`. 
+ """Test :meth:`__aenter__` wraps a pre-Ack + :class:`ChainSerializationError` in :class:`Rejected` for Nack + transport. Given: A request stream whose first frame is well-formed but - :meth:`Context.from_protobuf` raises a - :class:`BaseExceptionGroup` of :class:`Exception` peers + :func:`Chain.from_protobuf` raises a + :class:`ChainSerializationError` aggregating per-var warnings + (strict mode). When: - The handler is entered via ``async with`` + The handler is entered via ``async with``. Then: - It should raise :class:`Rejected` whose ``.original`` is a - :class:`BaseExceptionGroup` labeled - "request context decode failed". + It should raise :class:`Rejected` whose ``.original`` is + the :class:`ChainSerializationError` itself — the routine + does not run; the dispatch handler ships the error via + the Nack channel; the caller catches the same error class + symmetrically. """ # Arrange + import wool + task = _make_task(_coro_returning_default) - stream = _stream(_request_for(task)) + # Carry a present (but minimal) chain manifest so the lazy-wire- + # frame receiver actually invokes + # ``ChainManifest.from_protobuf``. + request = _request_for(task) + request.context.CopyFrom(protocol.ChainManifest()) + stream = _stream(request) - # Patch Context.from_protobuf to raise an Exception-only group. - # The constructor downgrades Exception-only groups to - # ExceptionGroup so the parse-phase ``except Exception`` arm - # routes it through Rejected. 
- peer = ValueError("decode peer") - eg = BaseExceptionGroup("simulated decode failure", [peer]) - from wool.runtime.context import base as ctx_base + peer = wool.SerializationWarning("decode peer") + err = wool.ChainSerializationError(peer) + from wool.runtime.context.manifest import ChainManifest mocker.patch.object( - ctx_base.Context, + ChainManifest, "from_protobuf", - classmethod(lambda cls, *a, **kw: (_ for _ in ()).throw(eg)), + classmethod(lambda cls, *a, **kw: (_ for _ in ()).throw(err)), ) # Act & assert @@ -569,61 +632,63 @@ async def test___aenter___with_decode_group_of_exceptions( pass original = exc_info.value.original - assert isinstance(original, BaseExceptionGroup) - assert "request context decode failed" in original.message + assert original is err + assert isinstance(original, wool.ChainSerializationError) + assert len(original.warnings) == 1 + assert original.warnings[0] is peer @pytest.mark.asyncio - async def test___aenter___with_decode_group_with_non_exception_peer( + async def test___aenter___should_propagate_base_exception_when_decode_raises_base( self, worker_loop, mock_worker_proxy_cache, mocker: MockerFixture, ): - """Test :meth:`__aenter__` propagates a true - :class:`BaseExceptionGroup` raw. + """Test :meth:`__aenter__` propagates a non-Exception raise + from decode raw (no Rejected wrap). Given: - A request stream whose first frame triggers a - :class:`BaseExceptionGroup` containing a non-Exception peer - (e.g. :class:`asyncio.CancelledError`) from - :meth:`Context.from_protobuf` + A request stream whose first frame triggers an + :class:`asyncio.CancelledError` raised directly from + :meth:`ChainManifest.from_protobuf` (e.g. cancellation + arriving mid-decode). When: - The handler is entered via ``async with`` + The handler is entered via ``async with``. Then: - It should propagate the :class:`BaseExceptionGroup` raw - (not as :class:`Rejected`). 
+ It should propagate the :class:`asyncio.CancelledError` + raw — the parse-phase ``except Exception`` arm does not + catch :class:`BaseException` subclasses; Nack is the wrong + channel for cancellation/interrupt signals. """ # Arrange task = _make_task(_coro_returning_default) - stream = _stream(_request_for(task)) + # Carry a present chain manifest so the lazy-wire-frame receiver + # actually invokes ``ChainManifest.from_protobuf``. + request = _request_for(task) + request.context.CopyFrom(protocol.ChainManifest()) + stream = _stream(request) - # A group with a non-Exception peer stays a - # BaseExceptionGroup (no auto-downgrade) and so falls through - # the parse-phase ``except Exception`` arm. - peer = asyncio.CancelledError() - eg = BaseExceptionGroup("simulated decode failure", [peer]) - from wool.runtime.context import base as ctx_base + cancelled = asyncio.CancelledError() + from wool.runtime.context.manifest import ChainManifest mocker.patch.object( - ctx_base.Context, + ChainManifest, "from_protobuf", - classmethod(lambda cls, *a, **kw: (_ for _ in ()).throw(eg)), + classmethod(lambda cls, *a, **kw: (_ for _ in ()).throw(cancelled)), ) # Act & assert - with pytest.raises(BaseExceptionGroup) as exc_info: + with pytest.raises(asyncio.CancelledError) as exc_info: async with DispatchSession(stream, worker_loop): pass - # The raised group should carry the rewrap label and the - # original non-Exception peer, not be wrapped in Rejected. 
assert not isinstance(exc_info.value, Rejected) - assert "request context decode failed" in exc_info.value.message + assert exc_info.value is cancelled # -- __aiter__ -------------------------------------------------------- @pytest.mark.asyncio - async def test___aiter___with_coroutine_routine_yields_single_response( + async def test___aiter___should_yield_single_response_when_coroutine_routine( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aiter__` yields exactly one response for a @@ -631,13 +696,16 @@ async def test___aiter___with_coroutine_routine_yields_single_response( Given: A handler past :meth:`__aenter__` with a coroutine routine - returning a value + returning a value, dispatched with no caller Wool context + state When: The handler is iterated via ``async for`` Then: It should yield exactly one response whose result is the - coroutine's return value and whose context carries the - post-step :class:`protocol.Context` snapshot. + coroutine's return value and whose chain manifest is a stateless + :class:`protocol.ChainManifest` — a stateless dispatch leaves the + worker chain unarmed, so the post-step chain manifest carries no + variable entries. """ # Arrange task = _make_task(_coro_returning_default) @@ -649,18 +717,46 @@ async def test___aiter___with_coroutine_routine_yields_single_response( # Assert assert len(results) == 1 - assert results[0].result == "coroutine_value" - # The unified driver MUST emit a post-step context snapshot on - # every successful step (issue #187 motivation: "snapshot-encode - # duplicated across both paths"). The presence of an ``id`` on - # the response's :class:`protocol.Context` proves a snapshot - # was actually populated (a missing snapshot would surface as - # the empty default). 
- assert results[0].context is not None - assert results[0].context.id + assert results[0].payload == "coroutine_value" + # Under lazy-wire-frame semantics the unarmed worker omits + # the optional chain-manifest field entirely — the encode site + # reads wool.__chain__.get(), which raises LookupError on an + # unarmed worker, and maps that to a None wire chain manifest. + assert results[0].chain_manifest is None + + @pytest.mark.asyncio + async def test___aiter___should_yield_default_when_stateless_dispatch( + self, worker_loop, mock_worker_proxy_cache + ): + """Test a stateless dispatch leaves the worker chain unarmed. + + Given: + A coroutine routine dispatched with no caller wool.ContextVar + state, whose body offloads a wool.ContextVar read through + plain asyncio.to_thread + When: + The handler is iterated via ``async for`` + Then: + It should yield one response whose result is the variable's + constructor default — the worker chain stayed unarmed + (context_has_state gating), so the plain to_thread + offload ran as a bare contextvars context and never tripped + wool.ChainContention. 
+ """ + # Arrange + task = _make_task(_coro_offloads_plain_to_thread) + stream = _stream(_request_for(task)) + + # Act + async with DispatchSession(stream, worker_loop) as handler: + results = [r async for r in handler] + + # Assert + assert len(results) == 1 + assert results[0].payload == "unset-default" @pytest.mark.asyncio - async def test___aiter___with_async_generator_task_yields_per_request( + async def test___aiter___should_yield_response_per_request_when_async_generator( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aiter__` yields one response per caller request @@ -689,10 +785,10 @@ async def test___aiter___with_async_generator_task_yields_per_request( results = [r async for r in handler] # Assert - assert [r.result for r in results] == [1, 2, 3] + assert [r.payload for r in results] == [1, 2, 3] @pytest.mark.asyncio - async def test___aiter___with_async_generator_yields_per_caller_request( + async def test___aiter___should_drive_multi_step_generator_through_caller_requests( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aiter__` drives a multi-step async generator @@ -720,10 +816,10 @@ async def test___aiter___with_async_generator_yields_per_caller_request( results = [r async for r in handler] # Assert - assert [r.result for r in results] == ["a", "b"] + assert [r.payload for r in results] == ["a", "b"] @pytest.mark.asyncio - async def test___aiter___called_twice_returns_same_iterator( + async def test___aiter___should_return_same_iterator_when_called_twice( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aiter__` is idempotent across calls. 
@@ -753,7 +849,7 @@ async def test___aiter___called_twice_returns_same_iterator( pass @pytest.mark.asyncio - async def test___aiter___with_routine_raising_mid_stream( + async def test___aiter___should_propagate_exception_when_routine_raises_mid_stream( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aiter__` propagates a routine's mid-stream @@ -782,7 +878,7 @@ async def test___aiter___with_routine_raising_mid_stream( pass @pytest.mark.asyncio - async def test___aiter___with_streaming_eof_closes_request_queue( + async def test___aiter___should_exit_cleanly_when_request_stream_ends( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aiter__` exits cleanly when the request stream @@ -815,7 +911,7 @@ async def test___aiter___with_streaming_eof_closes_request_queue( # -- __aexit__ -------------------------------------------------------- @pytest.mark.asyncio - async def test___aexit___after_cancel_before_aiter( + async def test___aexit___should_not_raise_when_cancelled_before_aiter( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aexit__` does not raise after a pre-iter cancel. @@ -845,7 +941,7 @@ async def test___aexit___after_cancel_before_aiter( [r async for r in handler] @pytest.mark.asyncio - async def test___aexit___swallows_worker_exception( + async def test___aexit___should_not_reraise_worker_exception( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aexit__` does not re-raise a worker exception. 
@@ -883,7 +979,7 @@ async def test___aexit___swallows_worker_exception( assert isinstance(captured[0], _CustomRoutineError) @pytest.mark.asyncio - async def test___aexit___with_drain_raising_cancelled_error( + async def test___aexit___should_surface_cancelled_error_when_drain_raises( self, worker_loop, mock_worker_proxy_cache, mocker: MockerFixture ): """Test :meth:`__aexit__` unwinds the exit stack even when the @@ -934,7 +1030,7 @@ async def raising_drain(self): ) @pytest.mark.asyncio - async def test___aexit___with_caller_cancelled_mid_teardown( + async def test___aexit___should_run_drain_when_caller_cancelled_mid_teardown( self, worker_loop, mock_worker_proxy_cache, mocker: MockerFixture ): """Test :meth:`__aexit__` runs the registered :meth:`drain` @@ -1018,7 +1114,7 @@ async def caller(): # -- drain ------------------------------------------------------------ @pytest.mark.asyncio - async def test_drain_called_twice_is_idempotent( + async def test_drain_should_return_promptly_when_called_twice( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`drain` returns promptly across repeat calls. @@ -1043,7 +1139,7 @@ async def test_drain_called_twice_is_idempotent( # Assert — bounded wait_for ensures both returned promptly. @pytest.mark.asyncio - async def test_drain_propagates_external_cancellation( + async def test_drain_should_propagate_cancelled_error_when_awaiting_task_cancelled( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`drain` propagates an externally-injected @@ -1106,7 +1202,7 @@ async def driver(): ) @pytest.mark.asyncio - async def test_drain_swallows_worker_cancellation( + async def test_drain_should_swallow_worker_cancellation( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`drain` swallows a worker-side cancellation. @@ -1123,39 +1219,35 @@ async def test_drain_swallows_worker_cancellation( cleanly. 
""" # Arrange — drive a long-running generator, cancel the worker - # task on its own loop, then call drain from the main loop. + # via the public :meth:`cancel` (which cancels the worker + # task on its own loop), then call drain from the main loop. task = _make_task(_slow_gen) stream = _stream(_request_for(task), _next_request()) async with DispatchSession(stream, worker_loop) as handler: iterator = aiter(handler) - # Wait for the worker driver to be scheduled — _worker_done - # is the public-equivalent observable populated by - # _schedule_worker. We poll without referencing it in - # the assert body itself to keep the test focused on - # the drain behavior. + # Wait for the worker driver task to appear on the worker + # loop — ``asyncio.all_tasks`` is a public observable of + # the scheduled worker task, so ``cancel`` below has a + # live worker task to cancel on its own loop. for _ in range(200): - if handler._worker_done is not None: + if asyncio.all_tasks(loop=worker_loop): break await asyncio.sleep(0.01) - # Cancel every task running on the worker loop. The - # cancellation cascades into the routine task and the - # session's ``_on_done`` callback, which closes the - # response queue and settles ``worker_done``. - def _cancel_workers(): - for t in asyncio.all_tasks(loop=worker_loop): - t.cancel() - - worker_loop.call_soon_threadsafe(_cancel_workers) - - # Give the worker time to settle; the response queue gets - # closed by the done-callback so the iterator exits. - # Narrow the catch to ``CancelledError`` — the only - # exception this scenario can legitimately produce — so - # any unrelated regression that surfaces a different - # exception class is not silently absorbed. + # Cancel via the public surface. ``cancel`` schedules the + # worker task's cancellation on the worker loop; the + # cascade settles the worker-completion future with + # ``CancelledError`` and closes the response queue, so the + # iterator exits. 
+ await handler.cancel() + + # Drive the iterator to exit. Narrow the catch to + # ``CancelledError`` — the only exception this scenario can + # legitimately produce — so any unrelated regression that + # surfaces a different exception class is not silently + # absorbed. try: async for _ in iterator: pass @@ -1169,7 +1261,9 @@ def _cancel_workers(): # Assert — control reached here without raising. @pytest.mark.asyncio - async def test_drain_tolerates_closed_worker_loop(self, mock_worker_proxy_cache): + async def test_drain_should_return_when_worker_loop_closed( + self, mock_worker_proxy_cache + ): """Test :meth:`drain` returns when the worker loop is closed. Given: @@ -1193,8 +1287,8 @@ async def test_drain_tolerates_closed_worker_loop(self, mock_worker_proxy_cache) handler = DispatchSession(stream, loop) await handler.__aenter__() - # Close the worker loop. With no _worker_done assigned (no - # __aiter__ yet), drain's short-circuit path is exercised. + # Close the worker loop before any __aiter__, so drain's + # short-circuit path (no worker scheduled) is exercised. loop.call_soon_threadsafe(loop.stop) thread.join(timeout=5) loop.close() @@ -1204,20 +1298,20 @@ async def test_drain_tolerates_closed_worker_loop(self, mock_worker_proxy_cache) await asyncio.wait_for(handler.drain(), timeout=2.0) # Assert — drain returned without raising. finally: - # Best-effort cleanup of the exit stack. + # Best-effort public teardown. try: - await handler._stack.aclose() + await handler.__aexit__(None, None, None) except Exception: pass @pytest.mark.asyncio - async def test_drain_returns_when_worker_create_task_fails( + async def test_drain_should_return_when_worker_create_task_fails( self, mocker: MockerFixture, mock_worker_proxy_cache ): """Test :meth:`drain` returns when ``_start`` fails to create the worker task on the worker loop. - Regression test for A8. Pre-fix, ``_start`` (scheduled + Regression test. 
Pre-fix, ``_start`` (scheduled via :meth:`asyncio.AbstractEventLoop.call_soon_threadsafe` from :meth:`_schedule_worker`) called :meth:`asyncio.AbstractEventLoop.create_task` without a @@ -1237,7 +1331,7 @@ async def test_drain_returns_when_worker_create_task_fails( (simulating a late-loop-closure or task-factory failure). When: - :meth:`_schedule_worker` runs ``_start`` on the + The first :meth:`__aiter__` schedules ``_start`` on the worker loop and :meth:`drain` is awaited. Then: :meth:`drain` should return within a short timeout — @@ -1268,11 +1362,12 @@ async def test_drain_returns_when_worker_create_task_fails( ), ) - # Trigger scheduling. This calls - # ``call_soon_threadsafe(_start)``; _start later runs - # on the worker loop where the patched create_task - # raises. - handler._schedule_worker() + # Trigger lazy scheduling through the public iteration + # entry point. ``__aiter__`` calls ``_schedule_worker``, + # which schedules ``_start`` via + # ``call_soon_threadsafe``; _start later runs on the + # worker loop where the patched create_task raises. + aiter(handler) # Wait briefly for the worker loop to execute _start. await asyncio.sleep(0.1) @@ -1282,7 +1377,7 @@ async def test_drain_returns_when_worker_create_task_fails( await asyncio.wait_for(handler.drain(), timeout=2.0) finally: try: - await handler._stack.aclose() + await handler.__aexit__(None, None, None) except Exception: pass loop.call_soon_threadsafe(loop.stop) @@ -1293,7 +1388,7 @@ async def test_drain_returns_when_worker_create_task_fails( # -- cancel ----------------------------------------------------------- @pytest.mark.asyncio - async def test_cancel_before_aenter_is_safe( + async def test_cancel_should_not_raise_when_called_before_aenter( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`cancel` invoked before :meth:`__aenter__` is safe. 
@@ -1324,7 +1419,7 @@ async def test_cancel_before_aenter_is_safe( [r async for r in handler] @pytest.mark.asyncio - async def test_cancel_after_iteration_is_safe( + async def test_cancel_should_not_raise_when_called_after_iteration( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`cancel` invoked after iteration is a no-op. @@ -1350,7 +1445,7 @@ async def test_cancel_after_iteration_is_safe( # Assert — control reached here without raising. @pytest.mark.asyncio - async def test_cancel_during_iteration_from_different_task( + async def test_cancel_should_raise_cancelled_on_iterator_when_different_task( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`cancel` from another task surfaces @@ -1383,9 +1478,13 @@ async def driver(): # Act driver_task = asyncio.create_task(driver()) - # Wait until the iterator is suspended on response_queue.get. + # Wait until the worker driver task is scheduled on the worker + # loop — ``asyncio.all_tasks`` is a public observable of the + # running worker task, which (for the never-yielding + # ``_slow_gen``) implies the iterator is suspended on the + # response. for _ in range(500): - if "handler" in captured and captured["handler"]._worker_done is not None: + if "handler" in captured and asyncio.all_tasks(loop=worker_loop): break await asyncio.sleep(0.01) else: @@ -1405,7 +1504,49 @@ async def driver(): assert results == [] @pytest.mark.asyncio - async def test_cancel_during_suspended_iteration_unblocks_iterator( + async def test_cancel_should_raise_cancelled_error_when_racing_worker_scheduling( + self, worker_loop, mock_worker_proxy_cache + ): + """Test :meth:`cancel` racing worker scheduling still cancels + the worker. 
+ + Given: + A handler whose first :meth:`__aiter__` has queued the + worker driver onto the worker loop, with :meth:`cancel` + invoked in the same main-loop turn — before the worker + loop has run the queued driver + When: + The iterator's first :meth:`anext` is awaited + Then: + It should raise :class:`asyncio.CancelledError` and yield + no value — the driver's own start-time cancellation + re-check observes the flag and cancels the freshly + created worker task, so a routine never runs to natural + completion after a cancel that raced its scheduling. + """ + # Arrange — a long-running routine so a missed start-time + # cancellation would otherwise let the worker run for real. + task = _make_task(_slow_coro) + stream = _stream(_request_for(task)) + + async with DispatchSession(stream, worker_loop) as handler: + # ``__aiter__`` runs ``_schedule_worker`` synchronously, + # queuing the worker driver onto the worker loop via + # ``call_soon_threadsafe``. ``cancel`` is synchronous up to + # its return, so calling it before any ``await`` yields the + # main loop sets the cancel flag while the queued driver + # has not yet run on the worker loop — the race the driver + # guards against at start time. + iterator = aiter(handler) + await handler.cancel() + + # Act & assert — the iterator surfaces the cancellation and + # the routine never produces a value. + with pytest.raises(asyncio.CancelledError): + await anext(iterator) + + @pytest.mark.asyncio + async def test_cancel_should_raise_cancelled_error_when_iterator_suspended( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`cancel` unblocks an iterator suspended on a @@ -1439,9 +1580,12 @@ async def driver(): # Act driver_task = asyncio.create_task(driver()) - # Wait for the iterator to be suspended. 
+ # Wait for the worker driver task to be scheduled on the + # worker loop — ``asyncio.all_tasks`` is a public observable + # of the running worker task; for the never-yielding + # ``_slow_gen`` this implies the iterator is suspended. for _ in range(500): - if "handler" in captured and captured["handler"]._worker_done is not None: + if "handler" in captured and asyncio.all_tasks(loop=worker_loop): break await asyncio.sleep(0.01) else: @@ -1471,7 +1615,7 @@ async def driver(): assert captured.get("cancelled") is True @pytest.mark.asyncio - async def test_cancel_called_twice_is_idempotent( + async def test_cancel_should_not_raise_when_called_twice( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`cancel` is idempotent across repeat calls. @@ -1495,60 +1639,125 @@ async def test_cancel_called_twice_is_idempotent( # Assert — control reached here without raising. @pytest.mark.asyncio - async def test___aiter___streaming_rewraps_mid_stream_context_decode_failure( + async def test_cancel_should_preempt_routine_when_suspended_mid_step( + self, worker_loop, mock_worker_proxy_cache + ): + """Test :meth:`cancel` preempts a routine suspended inside a step. + + Given: + A handler whose streaming routine has been forwarded its + first request and is suspended inside ``_drive_step``'s + per-step ``await`` (a long sleep). + When: + :meth:`cancel` is awaited. + Then: + It should preempt the routine rather than let it run to + natural completion — the iterator surfaces + :class:`asyncio.CancelledError` and :meth:`drain` returns + promptly instead of awaiting the routine's full sleep. + """ + # Arrange — forward the first Next to the worker. ``anext`` + # forwards the request (the put precedes the response await) + # then suspends on the response that never arrives, so the + # worker enters its first step. The routine sets a module-level + # Event the moment it is inside that step. 
+ _STEP_BLOCKING.clear() + task = _make_task(_gen_blocks_in_step) + stream = _stream(_request_for(task), _next_request()) + + async with DispatchSession(stream, worker_loop) as handler: + iterator = aiter(handler) + pull = asyncio.ensure_future(anext(iterator)) + try: + # Wait until the routine signals it is suspended *inside* + # the step (it sets the Event right before its in-step + # ``await``), so the cancel below lands while the + # per-step task is in flight. Awaiting the blocking wait + # in the default executor keeps the main loop free to + # pump ``pull``'s request forward to the worker. + entered = await asyncio.get_running_loop().run_in_executor( + None, _STEP_BLOCKING.wait, 5.0 + ) + assert entered, "worker never entered the blocking step" + + # Act — cancel while the step task is still in flight. + await handler.cancel() + + # Assert — the suspended pull surfaces cancellation, and + # drain returns promptly because the step task was + # cancelled rather than awaited to its full sleep. + with pytest.raises(asyncio.CancelledError): + await pull + await asyncio.wait_for(handler.drain(), timeout=2.0) + finally: + if not pull.done(): + pull.cancel() + + @pytest.mark.asyncio + async def test___aiter___should_propagate_context_error_when_mid_stream_decode_fails( self, worker_loop, mock_worker_proxy_cache, mocker: MockerFixture ): - """Test mid-stream context decode failures are re-wrapped - with a labeled message so the dispatch handler can - distinguish them from initial-frame decode failures. + """Test mid-stream chain-manifest decode failures propagate as + :class:`ChainSerializationError` through the routine-exception + channel. 
Given: A streaming session whose first ``next`` request decodes - cleanly but whose second ``next`` request's context - decode raises a :class:`BaseExceptionGroup` (operator - promoted :class:`ContextDecodeWarning` to an exception) + cleanly but whose second ``next`` request's chain-manifest + decode raises a :class:`ChainSerializationError` aggregating + per-var warnings (operator promoted + :class:`SerializationWarning` to an exception). When: - The caller iterates the session past the first response + The caller iterates the session past the first response. Then: - The error should surface as a :class:`BaseExceptionGroup` - (or :class:`ExceptionGroup` after constructor downgrade) - labeled "mid-stream request context decode failed" so - the dispatch handler can distinguish it from - initial-frame decode failures that share the same peer - type. + The :class:`ChainSerializationError` should surface + unmolested out of the iterator — under the strict-mode + "fail loud" contract the worker ships it via the + routine-exception channel; the caller's existing + ``except ChainSerializationError`` catches without migrating + to ``except*``. """ - from wool.runtime.context import base as ctx_base + # Arrange — streaming routine that yields per ``next``. The + # mid-stream next-requests carry a present-but-empty chain manifest so + # ``RequestFrame.from_protobuf`` routes into + # ``ChainManifest.from_protobuf`` (and hence the patched + # mock) under lazy-wire-frame semantics — without a present + # field, the decode path short-circuits before the mock can + # fire. ``__aenter__``'s initial decode does not hit the mock + # either because the task-frame fixture omits the optional + # context field; so the first patched call corresponds to the + # first mid-stream decode. + import wool - # Arrange — streaming routine that yields per ``next``. 
task = _make_task(_gen_three) stream = _stream( _request_for(task), - _next_request(), - _next_request(), + _next_request(with_context=True), + _next_request(with_context=True), ) - # Counter-based patch: let the initial __aenter__ decode - # and the first per-step decode succeed; the third call - # (second mid-stream decode) raises a - # ``BaseExceptionGroup`` of Exception-only peers. - original = ctx_base.Context.from_protobuf + # Counter-based patch: let the first per-step decode succeed; + # the second mid-stream decode raises a ChainSerializationError. + from wool.runtime.context.manifest import ChainManifest + + original = ChainManifest.from_protobuf calls = {"n": 0} - peer = ValueError("decode peer") + peer = wool.SerializationWarning("decode peer") - def fake_from_protobuf(cls, proto_ctx, *, serializer): + def fake_from_protobuf(cls, wire, *, serializer=None): calls["n"] += 1 - if calls["n"] >= 3: - raise BaseExceptionGroup("simulated decode failure", [peer]) - return original.__func__(cls, proto_ctx, serializer=serializer) + if calls["n"] >= 2: + raise wool.ChainSerializationError(peer) + return original(wire, serializer=serializer) mocker.patch.object( - ctx_base.Context, + ChainManifest, "from_protobuf", classmethod(fake_from_protobuf), ) - # Act — iterate; the second step raises the re-wrapped - # group, which surfaces out of the iterator. + # Act — iterate; the second step raises ChainSerializationError, + # which surfaces out of the iterator. captured: list[BaseException] = [] async with DispatchSession(stream, worker_loop) as handler: try: @@ -1557,56 +1766,15 @@ def fake_from_protobuf(cls, proto_ctx, *, serializer): except BaseException as e: captured.append(e) - # Assert + # Assert: typed ChainSerializationError carrying the original + # warning on .warnings. 
assert len(captured) == 1 - eg = captured[0] - assert isinstance(eg, BaseExceptionGroup) - assert "mid-stream request context decode failed" in eg.message - - @pytest.mark.asyncio - async def test___aenter___propagates_keyboard_interrupt_during_aclose( - self, worker_loop, mock_worker_proxy_cache, mocker: MockerFixture - ): - """Test the safe-aclose helper does not swallow - :class:`KeyboardInterrupt` during cleanup. - - Given: - A handler whose ``__aenter__`` is failing (empty request - stream → :class:`Rejected`) and whose cleanup - ``_stack.aclose()`` raises :class:`KeyboardInterrupt` - (simulating a Ctrl-C landing mid-cleanup) - When: - The handler is entered - Then: - :class:`KeyboardInterrupt` should propagate raw out of - the helper rather than being swallowed by the - ``except Exception`` arm. - """ - - # Arrange — an empty stream forces __aenter__ to raise - # StopAsyncIteration and route through the safe-aclose - # error path. - async def empty_stream(): - if False: - yield # pragma: no cover - - handler = DispatchSession(empty_stream(), worker_loop) - - # Patch the exit stack's aclose so the cleanup raises - # KeyboardInterrupt — the safe-aclose helper must re-raise - # this rather than swallow it under ``except Exception``. 
- mocker.patch.object( - handler._stack, - "aclose", - side_effect=KeyboardInterrupt("simulated Ctrl-C"), - ) - - # Act & assert - with pytest.raises(KeyboardInterrupt): - await handler.__aenter__() + primary = captured[0] + assert isinstance(primary, wool.ChainSerializationError) + assert primary.warnings == (peer,) @pytest.mark.asyncio - async def test___aiter___streaming_breaks_on_cancel_between_requests( + async def test___aiter___should_break_pump_when_cancelled_between_requests( self, worker_loop, mock_worker_proxy_cache ): """Test the streaming pump observes a mid-pump cancel and @@ -1696,10 +1864,10 @@ async def driver(): # worker would advance the generator to ``3``, and the # response would appear in ``observed["responses"]``. assert len(observed.get("responses", [])) == 1 - assert observed["responses"][0].result == 1 + assert observed["responses"][0].payload == 1 @pytest.mark.asyncio - async def test_cancel_tolerates_closed_worker_loop( + async def test_cancel_should_return_when_worker_loop_closed( self, worker_loop, mock_worker_proxy_cache, mocker: MockerFixture ): """Test :meth:`cancel` swallows ``RuntimeError`` from a torn-down @@ -1717,93 +1885,43 @@ async def test_cancel_tolerates_closed_worker_loop( closed")`` is swallowed because the dispatch is no longer serviceable. """ - # Arrange + # Arrange — drive a real dispatch to completion so the worker + # driver task is genuinely scheduled (and recorded) on the + # worker loop. Once it has run, ``cancel`` still routes through + # the worker-task-cancel branch (the task reference is set), so + # the patched ``call_soon_threadsafe`` exercises the + # closed-loop swallow rather than the early-return path that a + # never-scheduled worker would take. 
task = _make_task(_coro_returning_default) stream = _stream(_request_for(task)) async with DispatchSession(stream, worker_loop) as handler: - # Bind a mock worker task without scheduling a real one - # so we exercise the ``call_soon_threadsafe`` branch of - # cancel() directly. - handler._worker_task = mocker.MagicMock() + # Consume the result so the worker driver task settles — + # this avoids the teardown blocking on a live routine while + # still leaving the worker-task reference in place for + # cancel to act on. + results = [r async for r in handler] + assert results[0].payload == "coroutine_value" + # Patch the worker loop's ``call_soon_threadsafe`` to - # simulate a torn-down loop ("Event loop is closed"). + # simulate a torn-down loop ("Event loop is closed") — the + # boundary cancel must schedule across. This affects only + # the worker-task-cancel path; the response queue close + # rides the main loop. mocker.patch.object( worker_loop, "call_soon_threadsafe", side_effect=RuntimeError("Event loop is closed"), ) - # Act + # Act & assert — cancel returns without raising; the + # ``call_soon_threadsafe`` RuntimeError is swallowed. await handler.cancel() - # Assert — cancel returned without raising; the - # RuntimeError was swallowed. - assert handler._cancelled is True - - # -- Migrated from test_service.py::TestWorkerService (F17) ----------- - - @pytest.mark.asyncio - async def test___aenter___preserves_parse_error_when_aclose_raises( - self, worker_loop, mock_worker_proxy_cache - ): - """Test :meth:`__aenter__` preserves the original parse error as - :class:`Rejected` even when ``_stack.aclose()`` raises during - cleanup. - - Regression test: pre-fix, the parse-phase ``except Exception as - e: await self._stack.aclose(); raise Rejected(e) from None`` - ran ``aclose`` un-guarded. If the stack's exit chain raised - (e.g. 
a registered resource's ``__aexit__`` failing), the new - exception replaced ``e`` and ``Rejected(e)`` was never - constructed — the dispatch handler's Nack-with-exception - channel observed the cleanup failure instead of the typed - parse error. The fix swallows aclose failures so the parse - error always reaches the caller. - - Given: - A request that fails parse-phase validation (a non-async - callable) AND a stack whose ``aclose`` raises during the - resulting cleanup - When: - :meth:`__aenter__` is invoked - Then: - It should raise :class:`Rejected` whose ``original`` - attribute carries the parse-phase :class:`ValueError`, not - the simulated aclose failure. - """ - # Arrange — a non-async callable triggers a ValueError in the - # __aenter__ validation step. ``_stack.aclose`` is patched - # below to raise, so the cleanup-failure path is exercised - # regardless of what the stack contains. - task = _make_task(_sync_callable) - stream = _stream(_request_for(task)) - handler = DispatchSession(stream, worker_loop) - - # Patch ``_stack.aclose`` directly: the underlying - # :class:`AsyncExitStack` is not exposed through any public - # hook, so the cleanup-failure path can only be exercised by - # replacing this attribute. The substitution mirrors the - # operational failure mode (a registered resource's - # ``__aexit__`` raising). 
- async def raising_aclose(): - raise RuntimeError("simulated aclose failure during cleanup") - - handler._stack.aclose = raising_aclose - - # Act + Assert - with pytest.raises(Rejected) as exc_info: - await handler.__aenter__() - - assert isinstance(exc_info.value.original, ValueError), ( - f"Rejected.original must carry the parse-phase ValueError, " - f"not the aclose failure — observed " - f"{type(exc_info.value.original).__name__}" - ) - assert "Expected coroutine function" in str(exc_info.value.original) + # -- Migrated from test_service.py::TestWorkerService ----------------- @pytest.mark.asyncio - async def test___aiter___defers_worker_scheduling( + async def test___aiter___should_defer_worker_scheduling( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`__aiter__` defers worker scheduling until the @@ -1811,15 +1929,14 @@ async def test___aiter___defers_worker_scheduling( :meth:`__aenter__`. Regression test for the race between dispatch's backpressure - hook and the worker for :meth:`Context._guard` ownership. - Pre-fix :meth:`__aenter__` scheduled the worker eagerly; with - a backpressure hook that yielded the main loop while holding - ``attached(handler.context)``, the worker thread would race - to acquire the same Context's guard and spuriously raise - ``RuntimeError("wool.Context is already running...")``. The - invariant tested here — :meth:`__aenter__` is parse-only — - guarantees no contention regardless of how long any - post-parse main-loop work holds the Context. + hook and the worker for chain ownership. Pre-fix + :meth:`__aenter__` scheduled the worker eagerly; with a + backpressure hook that yielded the main loop while reading the + decoded chain manifest, the worker thread would race to re-stamp the + same chain's owner. The invariant tested here — + :meth:`__aenter__` is parse-only — guarantees no contention + regardless of how long any post-parse main-loop work holds the + chain. 
Given: A :class:`DispatchSession` constructed around a parsed @@ -1838,114 +1955,128 @@ async def test___aiter___defers_worker_scheduling( stream = _stream(_request_for(task)) async with DispatchSession(stream, worker_loop) as handler: - # ``_worker_done`` is the private marker created by - # :meth:`_schedule_worker`. Probed directly because no - # public observable can witness "worker not yet - # scheduled" without producing or consuming a Response — - # which itself would force scheduling. The marker is the - # narrowest stand-in. - assert handler._worker_done is None, ( + # The worker driver runs as a task on the worker loop; + # ``asyncio.all_tasks`` is the public observable of + # whether it has been scheduled. After parse-only + # ``__aenter__`` (no ``__aiter__`` yet) the worker loop + # carries no driver task. + assert not asyncio.all_tasks(loop=worker_loop), ( "DispatchSession.__aenter__ must defer worker scheduling" ) # Act iterator = aiter(handler) - # Assert - assert handler._worker_done is not None, ( - "DispatchSession.__aiter__ must schedule the worker on first call" - ) + # Assert — the first ``__aiter__`` schedules the worker + # driver task on the worker loop. Scheduling crosses + # loops via ``call_soon_threadsafe``, so poll until the + # driver task appears. 
+ for _ in range(500): + if asyncio.all_tasks(loop=worker_loop): + break + await asyncio.sleep(0.01) + else: + pytest.fail( + "DispatchSession.__aiter__ must schedule the worker on first call" + ) response = await anext(iterator) - assert response.result == "coroutine_value" + assert response.payload == "coroutine_value" @pytest.mark.asyncio - async def test___aiter___swallows_request_queue_runtime_error_mid_stream( - self, worker_loop, mock_worker_proxy_cache, mocker: MockerFixture + async def test_streaming_should_end_stream_when_put_to_closed_worker_loop( + self, mock_worker_proxy_cache ): - """Test :meth:`_iterate` swallows :class:`RuntimeError` from - ``_RequestQueue.put`` so a closed worker loop mid-stream - terminates the iterator cleanly instead of surfacing the - loop-teardown error as a routine failure. - - Regression test: pre-fix, the streaming branch of - :meth:`_iterate` called ``request_queue.put(protobuf_request)`` - un-guarded. :meth:`_RequestQueue.put` schedules onto the - worker loop via ``call_soon_threadsafe``; if the worker loop - has been torn down (graceful shutdown teardown landing - between two main-loop pumps), put raises ``RuntimeError( - "Event loop is closed")``. The unguarded propagation - surfaced the runtime error out of :meth:`_iterate` — but the - routine never failed; the worker loop did. The fix mirrors - :meth:`drain`'s pattern: catch ``RuntimeError`` at the put - site and break cleanly. + """Test forwarding a request after the worker loop closed ends + the stream cleanly. Given: - A streaming session with :meth:`_RequestQueue.put` patched - to succeed on the first call and raise ``RuntimeError`` - on the second + A streaming dispatch whose worker loop is torn down after + the worker is scheduled but before the next request is + forwarded — the graceful-shutdown race where the loop pool + reclaims the worker loop between two main-loop pumps. 
When: - The iterator is driven for the second response — the - patched put triggers the runtime error mid-stream + The session forwards the next request onto the closed loop. Then: - The iterator should terminate cleanly without surfacing - a synthetic :class:`RuntimeError` as a routine failure. + ``_RequestQueue.put`` should surface the closed loop as the + typed ``_WorkerLoopClosed`` signal and the iterator should + end with no responses, rather than shipping the transport + teardown as a routine failure. """ - from wool.runtime.worker.session import _RequestQueue - - # Arrange — patching :class:`_RequestQueue.put` is justified - # here: the "real boundary" alternative (closing the worker - # loop mid-stream) has a semantic obstacle. Closing the - # worker loop while the worker task is suspended on - # ``request_queue.get`` leaves the worker-completion future - # ``_worker_done`` unresolved (the ``_on_done`` callback - # never fires), hanging the subsequent :meth:`drain` call - # from ``__aexit__``. The patch injects the exact failure - # mode the fix targets — a ``call_soon_threadsafe`` raise on - # a torn-down loop — without leaking the worker task. - call_count = 0 - original_put = _RequestQueue.put - - def patched_put(self, request): - nonlocal call_count - call_count += 1 - if call_count >= 2: - raise RuntimeError("simulated closed worker loop during put") - return original_put(self, request) - - mocker.patch.object(_RequestQueue, "put", patched_put) + # Arrange — a dedicated worker loop we can close + # deterministically. It is never run, so ``__aenter__`` (on the + # main loop) and the scheduling in ``__aiter__`` (a queued, + # never-executed ``call_soon`` on this loop) leave no worker + # task to orphan. 
The session is driven without ``async with``: + # closing the worker loop leaves the worker-completion future + # unresolved, so the registered ``drain`` teardown would block + # awaiting it — the production graceful-shutdown path tolerates + # this, and the closed loop plus pending future are GC-clean. + worker_loop = asyncio.new_event_loop() + try: + task = _make_task(_gen_default_two) + stream = _stream(_request_for(task), _next_request()) + handler = DispatchSession(stream, worker_loop) + await handler.__aenter__() + iterator = aiter(handler) - task = _make_task(_gen_default_two) - stream = _stream( - _request_for(task), - _next_request(), - _next_request(), - ) + # Act — tear the worker loop down before the next request is + # forwarded. + worker_loop.close() + responses = [response async for response in iterator] + finally: + if not worker_loop.is_closed(): + worker_loop.close() - # Act - responses: list = [] - async with DispatchSession(stream, worker_loop) as handler: - try: - async for response in handler: - responses.append(response) - except RuntimeError as e: - if "simulated closed worker loop" in str(e): - pytest.fail( - "_iterate must catch RuntimeError from " - "request_queue.put and terminate the stream " - "cleanly; the unguarded raise surfaced " - "the synthetic loop-teardown error as a " - "routine failure" - ) - raise + # Assert — the closed-loop put surfaced the typed signal and the + # stream ended cleanly with no responses. + assert responses == [] - # Assert — iterator terminated cleanly. The first put - # succeeded so one response was yielded; the second put - # raised and was swallowed. - assert len(responses) == 1 - assert responses[0].result == "a" + @pytest.mark.asyncio + async def test_coroutine_should_end_stream_when_put_to_closed_worker_loop( + self, mock_worker_proxy_cache + ): + """Test forwarding a coroutine's prime after the worker loop + closed ends the stream cleanly. 
+ + Given: + A coroutine dispatch whose worker loop is torn down after + the worker is scheduled but before the prime request is + forwarded — the graceful-shutdown race where the loop pool + reclaims the worker loop between two main-loop pumps. + When: + The session forwards the prime request onto the closed loop. + Then: + ``_RequestQueue.put`` should surface the closed loop as the + typed ``_WorkerLoopClosed`` signal and the iterator should + end with no responses, rather than shipping the transport + teardown as a routine failure. + """ + # Arrange — a dedicated worker loop we can close deterministically. + # As with the streaming twin, the session is driven without + # ``async with`` so closing the worker loop (leaving the + # worker-completion future unresolved) does not block the + # registered ``drain`` teardown. + worker_loop = asyncio.new_event_loop() + try: + task = _make_task(_coro_returning_default) + stream = _stream(_request_for(task), _next_request()) + handler = DispatchSession(stream, worker_loop) + await handler.__aenter__() + iterator = aiter(handler) + + # Act — tear the worker loop down before the prime is forwarded. + worker_loop.close() + responses = [response async for response in iterator] + finally: + if not worker_loop.is_closed(): + worker_loop.close() + + # Assert — the closed-loop put surfaced the typed signal and the + # stream ended cleanly with no responses. + assert responses == [] @pytest.mark.asyncio - async def test___aexit___drains_on_terminal_exception( + async def test___aexit___should_call_drain_twice_on_terminal_exception( self, grpc_aio_stub, mock_worker_proxy_cache, mocker: MockerFixture ): """Test :class:`DispatchSession.drain` is called both from the @@ -1957,10 +2088,11 @@ async def test___aexit___drains_on_terminal_exception( terminal-exception clause is reached while the worker is still alive. Main-loop handler-level failures (e.g. 
``response.to_protobuf`` raising on dump) reach the except - clause with the worker mid-``_step`` mutating ``work_ctx``. - Without an explicit drain before the snapshot, - ``handler.context.to_protobuf(...)`` reads ``_data`` while - the worker writes it. The fix calls + clause with the worker mid-``_step`` mutating the work + chain. Without an explicit drain before reading + ``session._final_wire_chain_manifest``, the handler would read the + worker-published chain manifest before the worker task has + finished encoding it inside its own Chain. The fix calls :meth:`DispatchSession.drain` from dispatch's terminal- exception clause; :meth:`__aexit__` also calls drain (via its exit stack, idempotent) so the spy observes two calls @@ -1986,7 +2118,7 @@ async def test___aexit___drains_on_terminal_exception( first_request = protocol.Request(task=task.to_protobuf()) next_request = protocol.Request( next=protocol.Void(), - context=protocol.Context(id=uuid4().hex), + context=protocol.ChainManifest(id=uuid4().hex), ) drain_spy = mocker.spy(DispatchSession, "drain") @@ -2012,16 +2144,17 @@ async def test___aexit___drains_on_terminal_exception( # dispatch handler's terminal-exception clause (the fix) and # once from :meth:`__aexit__`'s exit-stack unwind. Without # the fix, drain is called only from ``__aexit__`` and the - # context snapshot races the still-alive worker. + # read of ``session._final_wire_chain_manifest`` races the + # still-alive worker. assert drain_spy.call_count >= 2, ( f"Expected dispatch's terminal-exception clause to call " - f"DispatchSession.drain before snapshotting " - f"DispatchSession.context (plus __aexit__'s call); " - f"observed {drain_spy.call_count} call(s)." + f"DispatchSession.drain before reading " + f"DispatchSession._final_wire_chain_manifest (plus __aexit__'s " + f"call); observed {drain_spy.call_count} call(s)." 
) @pytest.mark.asyncio - async def test___aexit___does_not_mask_routine_exception_when_cancel_raises( + async def test___aexit___should_ship_routine_exception_when_cancel_raises( self, grpc_aio_stub, mock_worker_proxy_cache, @@ -2091,7 +2224,7 @@ async def raising_cancel(self): ) @pytest.mark.asyncio - async def test_drain_with_closed_worker_loop_pre_schedule( + async def test_drain_should_return_when_worker_loop_closed_pre_schedule( self, mock_worker_proxy_cache ): """Test :meth:`drain` returns promptly when the worker loop is @@ -2157,12 +2290,14 @@ async def test_drain_with_closed_worker_loop_pre_schedule( await asyncio.wait_for(handler.drain(), timeout=2.0) finally: try: - await handler._stack.aclose() + await handler.__aexit__(None, None, None) except Exception: pass @pytest.mark.asyncio - async def test_cancel_before_aiter(self, worker_loop, mock_worker_proxy_cache): + async def test_cancel_should_short_circuit_scheduling_when_called_before_aiter( + self, worker_loop, mock_worker_proxy_cache + ): """Test :meth:`cancel` invoked before :meth:`__aiter__` short- circuits worker scheduling and surfaces :class:`asyncio.CancelledError` on the iterator's first @@ -2186,10 +2321,9 @@ async def test_cancel_before_aiter(self, worker_loop, mock_worker_proxy_cache): :meth:`cancel` is called and then :meth:`__aiter__` is invoked Then: - It should not schedule the worker (``_request_queue`` / - ``_worker_done`` remain ``None``) and the iterator's - first ``anext`` should raise - :class:`asyncio.CancelledError`. + It should not schedule the worker driver task on the + worker loop and the iterator's first ``anext`` should + raise :class:`asyncio.CancelledError`. """ # Arrange task = _make_task(_coro_returning_default) @@ -2200,17 +2334,17 @@ async def test_cancel_before_aiter(self, worker_loop, mock_worker_proxy_cache): await handler.cancel() iterator = aiter(handler) - # Assert — worker was not scheduled. 
The two private - # markers stand in for "no scheduling occurred"; no - # public observable can witness the absence of - # scheduling without forcing it. - assert handler._request_queue is None, ( - "cancel() before __aiter__ must short-circuit " - "_schedule_worker — _request_queue should remain None" - ) - assert handler._worker_done is None, ( - "cancel() before __aiter__ must short-circuit " - "_schedule_worker — _worker_done should remain None" + # Give any cross-loop scheduling a chance to land so the + # absence below is meaningful rather than merely early. + await asyncio.sleep(0.05) + + # Assert — the cancel short-circuit means no worker driver + # task is ever scheduled on the worker loop. + # ``asyncio.all_tasks`` is the public observable of that + # absence. + assert not asyncio.all_tasks(loop=worker_loop), ( + "cancel() before __aiter__ must short-circuit worker " + "scheduling — no driver task should run on the worker loop" ) # The iterator surfaces the cancellation immediately. @@ -2218,7 +2352,7 @@ async def test_cancel_before_aiter(self, worker_loop, mock_worker_proxy_cache): await anext(iterator) @pytest.mark.asyncio - async def test_cancel_from_different_task( + async def test_cancel_should_not_raise_when_called_from_different_task( self, worker_loop, mock_worker_proxy_cache ): """Test :meth:`cancel` is safe to call from a task other than @@ -2262,13 +2396,12 @@ async def driver(): driver_task = asyncio.create_task(driver()) - # Wait until the worker has been scheduled — proves - # ``__aiter__`` ran and ``_iterate`` is suspended at - # ``response_queue.get`` waiting on the slow coroutine. - # Polling ``_worker_done`` matches the existing pattern in - # this test class for "iterator suspended" detection - # (see :func:`test_cancel_during_iteration_from_different_task`). 
- while "handler" not in captured or captured["handler"]._worker_done is None: + # Wait until the worker driver task is scheduled on the + # worker loop — ``asyncio.all_tasks`` is the public observable + # of the running worker task. For the never-returning + # ``_slow_coro`` this proves ``__aiter__`` ran and ``_iterate`` + # is suspended at ``response_queue.get``. + while "handler" not in captured or not asyncio.all_tasks(loop=worker_loop): await asyncio.sleep(0) # Act — cancel from this task while driver_task drives the diff --git a/wool/tests/stdlib_parity/conftest.py b/wool/tests/stdlib_parity/conftest.py new file mode 100644 index 00000000..e62fc478 --- /dev/null +++ b/wool/tests/stdlib_parity/conftest.py @@ -0,0 +1,78 @@ +import asyncio +import contextvars +import uuid + +import pytest +import pytest_asyncio + +from tests.helpers import scoped_context +from wool.runtime.context.var import ContextVar + + +@pytest.fixture( + params=["asyncio", "uvloop"], + ids=["asyncio", "uvloop"], +) +def event_loop_policy(request): + """Parametrize every parity test over the asyncio and uvloop event loops. + + Wool advertises context propagation "across every conformant event loop — + uvloop included". ``call_soon``/``call_later``, ``run_in_executor``, and + ``set_task_factory`` are loop-implemented, so a parity claim is only + substantiated when each test runs under both the default ``asyncio`` loop + and uvloop's Cython reimplementation. pytest-asyncio honors this fixture + when constructing the per-test loop. + """ + if request.param == "uvloop": + uvloop = pytest.importorskip("uvloop") + return uvloop.EventLoopPolicy() + return asyncio.DefaultEventLoopPolicy() + + +@pytest.fixture( + params=["stdlib", "wool"], + ids=["stdlib", "wool"], +) +def make_var(request): + """Return a factory constructing a stdlib or wool context variable. 
+ + Parametrizes a value-propagation parity test over both variable types so + a single assertion proves ``wool.ContextVar`` propagates identically to + ``contextvars.ContextVar`` across the scheduling edge under test. The + factory takes a name stem and appends a process-unique suffix to avoid + ``wool.ContextVar`` registry collisions across tests. + """ + if request.param == "wool": + return lambda stem: ContextVar(f"{stem}_{uuid.uuid4().hex}") + return lambda stem: contextvars.ContextVar(f"{stem}_{uuid.uuid4().hex}") + + +@pytest.fixture(autouse=True) +def isolated_context(): + """Run each parity test under a fresh, unarmed Wool context. + + Resets the wool-owned context ``contextvars.ContextVar`` so a + :meth:`wool.ContextVar.set` in one test does not leak its armed + context into the next. + """ + with scoped_context(): + yield + + +@pytest_asyncio.fixture(autouse=True) +async def reset_task_factory(): + """Clear the running loop's task factory after each parity test. + + Several parity tests install Wool's task factory, a user factory, + or a deliberately-broken legacy factory on the per-test event loop. + pytest-asyncio's loop teardown (``loop.shutdown_asyncgens``) routes + its own task creation through whatever factory is left installed — + a broken factory left in place crashes teardown. Clearing the + factory on teardown keeps a test's factory mutation from poisoning + loop teardown or any later test sharing the loop. + """ + yield + try: + asyncio.get_running_loop().set_task_factory(None) + except RuntimeError: # pragma: no cover — no running loop for a sync test + pass diff --git a/wool/tests/stdlib_parity/test_async_gen_aclose.py b/wool/tests/stdlib_parity/test_async_gen_aclose.py deleted file mode 100644 index 1388622a..00000000 --- a/wool/tests/stdlib_parity/test_async_gen_aclose.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Stdlib parity pins for ``async-generator.aclose`` semantics. 
- -These tests assert observations about CPython's own -``await agen.aclose()`` behavior. They are intentionally NOT tests of -any wool code — they pin stdlib semantics so that a future change in -CPython's async-generator close protocol fails here first, signaling -that the paired :func:`wool.runtime.routine.task.routine_scope` -regression tests (and the helper's contract) may need to be revisited. - -Companion to :class:`tests.runtime.routine.test_task.TestRoutineScope`, -which exercises ``routine_scope`` against the same parity assumptions. -""" - -import asyncio - -import pytest - - -class TestAsyncGenAcloseParity: - @pytest.mark.asyncio - async def test_aclose_propagates_internal_cancelled(self): - """Test ``aclose`` propagates internal CancelledError. - - Given: - A direct ``asyncio`` async generator that raises - :class:`asyncio.CancelledError` during aclose unwind - while the awaiting task's ``cancelling()`` count is 0. - When: - ``await agen.aclose()`` is invoked after one iteration. - Then: - It should raise :class:`asyncio.CancelledError`. - """ - - # Arrange - async def naughty_gen(): - try: - yield 1 - yield 2 - except GeneratorExit: - raise asyncio.CancelledError() - - agen = naughty_gen() - await agen.__anext__() - - # Act & assert - with pytest.raises(asyncio.CancelledError): - await agen.aclose() - - @pytest.mark.asyncio - async def test_aclose_raises_runtime_error_when_yielding_during_ge(self): - """Test ``aclose`` raises RuntimeError when the routine yields during GeneratorExit. - - Given: - A direct ``asyncio`` async generator that catches - :class:`GeneratorExit` and yields a value (a PEP 525 - protocol violation). - When: - ``await agen.aclose()`` is invoked after one iteration. - Then: - It should raise - ``RuntimeError("async generator ignored GeneratorExit")``. 
- """ - - # Arrange - async def yielding_gen(): - try: - yield 1 - yield 2 - except GeneratorExit: - yield "rude" - - agen = yielding_gen() - await agen.__anext__() - - # Act & assert - with pytest.raises(RuntimeError, match="ignored GeneratorExit"): - await agen.aclose() diff --git a/wool/tests/stdlib_parity/test_context_parity.py b/wool/tests/stdlib_parity/test_context_parity.py new file mode 100644 index 00000000..861206fe --- /dev/null +++ b/wool/tests/stdlib_parity/test_context_parity.py @@ -0,0 +1,403 @@ +"""Stdlib parity pins for the headline v2 ``wool.ContextVar`` guarantees. + +The v2 context model backs each :class:`wool.ContextVar` with its own +stdlib :class:`contextvars.ContextVar`, so the surface guarantees that +distinguished v2 — ``reset`` rejected across a different +:class:`contextvars.Context`, single-use ``Token`` enforcement, +``copy_context()`` propagation and copy-on-write isolation, and the +``1 + N`` armed-context width — are now *native* stdlib behaviors. + +These tests pin each of those side-by-side with a stdlib +:class:`contextvars.ContextVar` via the ``make_var`` fixture, so one +assertion proves the wool variable behaves identically. The +``1 + N``-width and async-generator value tests round out the v2 +contract. The loop-driven tests additionally run under both the +default ``asyncio`` loop and uvloop, via the ``event_loop_policy`` +fixture in ``conftest.py``; the ``1 + N``-width tests are pure +:class:`contextvars.Context` enumeration and need no running loop. +""" + +import contextvars +import uuid + +import pytest + +import wool +from wool.runtime.context.var import ContextVar + +pytestmark = pytest.mark.stdlib_parity + + +def _unique(stem: str) -> str: + """Return a process-unique variable name to avoid registry collisions.""" + return f"{stem}_{uuid.uuid4().hex}" + + +def _wool_owned(context: contextvars.Context) -> list[contextvars.ContextVar]: + """Return the wool-owned stdlib variables enumerated by *context*. 
+ + Every wool-owned :class:`contextvars.ContextVar` carries an + ``__wool`` name prefix (``__wool_chain__`` for the chain + context, ``__wool_var__:{ns}:{name}`` for each backing), so a + plain iteration over the context distinguishes them from + application variables. + """ + return [var for var in context if var.name.startswith("__wool")] + + +def _wool_owned_in_a_fresh_context(work) -> list[str]: + """Run *work* in a brand-new Chain and list the wool-owned variables. + + A fresh :class:`contextvars.Context` carries none of the backing + variables earlier work on the running thread may have left set, so + the result is exactly the wool-owned variables *work* itself binds — + making the count immune to cross-test backing-variable leakage. + """ + holder: list[list[str]] = [] + + def _runner() -> None: + work() + copied = contextvars.copy_context() + holder.append([v.name for v in _wool_owned(copied)]) + + contextvars.Context().run(_runner) + return holder[0] + + +class TestUnarmedChainLoudFail: + def test_get_should_raise_lookup_error_when_chain_unarmed(self): + """Test the Wool chain variable raises LookupError when unarmed. + + Given: + A fresh contextvars.Context in which Wool's chain variable was + never armed, mirroring a stdlib contextvars.ContextVar declared + without a default. + When: + wool.__chain__.get() is read with no default argument. + Then: + It should raise LookupError exactly as the defaultless stdlib + variable does — the chain variable lost its default, so unarmed + access fails loudly rather than returning None, the contract + behind every .get(None) read. + """ + # Arrange + stdlib_var: contextvars.ContextVar[object] = contextvars.ContextVar( + _unique("parity_no_default") + ) + + def _assert_loud() -> None: + with pytest.raises(LookupError): + stdlib_var.get() + with pytest.raises(LookupError): + wool.__chain__.get() + + # Act & assert — in a fresh context neither defaultless variable + # has a value, so a bare ``get()`` raises. 
+ contextvars.Context().run(_assert_loud) + + +class TestResetCrossContextRejection: + @pytest.mark.asyncio + async def test_reset_should_raise_value_error_when_in_foreign_context( + self, make_var + ): + """Test reset of a token in a different Chain raises ValueError. + + This is the headline v2 fix: a token is bound to the + :class:`contextvars.Context` it was minted in, and resetting it + elsewhere is rejected natively. + + Given: + A context variable set in the current context, yielding a + token. + When: + The token is reset inside a copy_context().run() — a + different contextvars.Context than the one it was minted + in. + Then: + It should raise :class:`ValueError`, identically for a + stdlib and a wool variable. + """ + # Arrange + var = make_var("reset_xctx") + token = var.set("value") + foreign = contextvars.copy_context() + + # Act & assert + with pytest.raises(ValueError): + foreign.run(lambda: var.reset(token)) + + @pytest.mark.asyncio + async def test_reset_should_restore_prior_value_when_in_minting_context( + self, make_var + ): + """Test reset of a token in its own Chain restores the prior value. + + Given: + A context variable set in the current context, yielding a + token. + When: + The token is reset in the same contextvars.Context it was + minted in. + Then: + It should restore the variable's prior state without + raising, identically for a stdlib and a wool variable. + """ + # Arrange + var = make_var("reset_same") + token = var.set("value") + + # Act + var.reset(token) + + # Assert + assert var.get("unset") == "unset" + + +class TestResetWrongVariableRejection: + @pytest.mark.asyncio + async def test_reset_should_raise_value_error_when_token_from_another_variable( + self, make_var + ): + """Test resetting variable B with variable A's token raises ValueError. + + Given: + Two distinct context variables, A set to yield a token. + When: + B.reset is called with A's token. 
+ Then: + It should raise :class:`ValueError` — a token is bound to + the variable that minted it, identically for a stdlib and a + wool variable. + """ + # Arrange + var_a = make_var("wrong_a") + var_b = make_var("wrong_b") + token_a = var_a.set("a") + + # Act & assert + with pytest.raises(ValueError): + var_b.reset(token_a) + + +class TestTokenSingleUse: + @pytest.mark.asyncio + async def test_resetting_a_token_twice_should_raise_runtime_error(self, make_var): + """Test resetting the same token twice raises RuntimeError. + + Given: + A context variable set once, yielding a token, then reset + with that token. + When: + The same token is reset a second time. + Then: + It should raise :class:`RuntimeError` — tokens are + single-use, identically for a stdlib and a wool variable. + """ + # Arrange + var = make_var("single_use") + token = var.set("value") + var.reset(token) + + # Act & assert + with pytest.raises(RuntimeError): + var.reset(token) + + +class TestCopyContextPropagation: + @pytest.mark.asyncio + async def test_value_set_before_copy_should_be_visible_in_ctx_run(self, make_var): + """Test a value set before copy_context() is visible inside ctx.run. + + Given: + A context variable set before a copy_context() is taken. + When: + A function is invoked through the copy's run(). + Then: + It should observe the pre-copy value — copy_context() + snapshots the bindings, identically for a stdlib and a wool + variable. + """ + # Arrange + var = make_var("copy_visible") + var.set("before-copy") + copy = contextvars.copy_context() + + # Act + observed = copy.run(var.get) + + # Assert + assert observed == "before-copy" + + @pytest.mark.asyncio + async def test_set_inside_ctx_run_should_not_leak_out(self, make_var): + """Test a set inside copy_context().run() does not leak to the caller. + + Given: + A context variable set in the caller, then a copy_context() + taken. + When: + The variable is set to a new value inside the copy's run(). 
+ Then: + The caller should still observe its original value — native + contextvars copy-on-write isolates the copy's write, + identically for a stdlib and a wool variable. + """ + # Arrange + var = make_var("copy_isolate") + var.set("caller") + copy = contextvars.copy_context() + + def mutate() -> str: + var.set("inside-run") + return var.get() + + # Act + inside = copy.run(mutate) + + # Assert + assert inside == "inside-run" + assert var.get() == "caller" + + +class TestCopyContextWidth: + def test_unarmed_context_should_carry_no_wool_variables(self): + """Test an unarmed context enumerates no wool-owned variables. + + Given: + A brand-new context in which a wool.ContextVar is declared + but never set. + When: + A copy_context() of that context is enumerated. + Then: + It should carry no wool-owned contextvars.ContextVar — an + unarmed context is indistinguishable from a plain + contextvars.Context. + """ + + # Arrange + def declare_but_do_not_set() -> None: + ContextVar(_unique("width_unarmed")) + + # Act + wool_owned = _wool_owned_in_a_fresh_context(declare_but_do_not_set) + + # Assert + assert wool_owned == [] + + @pytest.mark.parametrize("n", [1, 2, 3]) + def test_armed_context_should_enumerate_one_plus_n_variables(self, n): + """Test an armed context enumerates 1 + N wool-owned variables. + + Given: + A brand-new context armed with N distinct bound + wool.ContextVars. + When: + A copy_context() of that context is enumerated for its + wool-owned variables. + Then: + It should enumerate exactly 1 + N — the one context + variable plus one backing variable per bound wool.ContextVar. 
+ """ + # Arrange + bound = [ContextVar(_unique(f"width_armed_{index}")) for index in range(n)] + + def arm() -> None: + for index, var in enumerate(bound): + var.set(f"v{index}") + + # Act + wool_owned = _wool_owned_in_a_fresh_context(arm) + + # Assert + assert len(wool_owned) == 1 + n + assert wool_owned.count("__wool_chain__") == 1 + backing = [name for name in wool_owned if name != "__wool_chain__"] + assert len(backing) == n + + +class TestAsyncGeneratorValuePropagation: + @pytest.mark.asyncio + async def test_value_should_be_visible_across_asend_athrow_and_aclose( + self, make_var + ): + """Test a context variable is visible across every async-gen resumption. + + Given: + A context variable set before an async generator is driven. + When: + The generator is resumed via __anext__, asend, athrow, and + finally aclose. + Then: + It should observe the scope's value at every resumption + point — an async generator runs in the context active at + each step, identically for a stdlib and a wool variable. + """ + # Arrange + var = make_var("agen_value") + var.set("scope") + observed: list[tuple[str, str]] = [] + + async def generator(): + observed.append(("anext", var.get())) + yield 1 + observed.append(("asend", var.get())) + try: + yield 2 + except ValueError: + observed.append(("athrow", var.get())) + try: + yield 3 + except GeneratorExit: + observed.append(("aclose", var.get())) + raise + + gen = generator() + + # Act + await gen.__anext__() + await gen.asend("payload") + await gen.athrow(ValueError("boom")) + await gen.aclose() + + # Assert + assert observed == [ + ("anext", "scope"), + ("asend", "scope"), + ("athrow", "scope"), + ("aclose", "scope"), + ] + + @pytest.mark.asyncio + async def test_asend_should_return_next_yield_value(self, make_var): + """Test asend resumes the generator and returns the next yield. + + Given: + A context variable set before an async generator that + echoes the variable on each yield is driven. 
+ When: + The generator is resumed with asend. + Then: + It should return the scope's value as the next yielded + item — value propagation rides asend's resumption, + identically for a stdlib and a wool variable. + """ + # Arrange + var = make_var("agen_asend") + var.set("scope") + + async def generator(): + while True: + yield var.get() + + gen = generator() + await gen.__anext__() + + # Act + resumed = await gen.asend(None) + + # Assert + assert resumed == "scope" + + # Cleanup + await gen.aclose() diff --git a/wool/tests/stdlib_parity/test_executor_offload.py b/wool/tests/stdlib_parity/test_executor_offload.py new file mode 100644 index 00000000..fb82cb94 --- /dev/null +++ b/wool/tests/stdlib_parity/test_executor_offload.py @@ -0,0 +1,332 @@ +"""Stdlib parity pins for ``wool.ContextVar`` behavior across OS-thread offload. + +These tests pin the boundary between cooperative loop work (which +shares a chain safely) and genuine OS-thread parallelism (which must +not). They cover the offload edges: + +- :meth:`loop.run_in_executor` — carries NO ``contextvars.Context`` of + its own, matching stdlib; a bare executor callable sees neither the + caller's stdlib variables nor a Wool chain. This holds for a + thread-pool worker and for a freshly spawned process-pool worker; a + fork-started worker is a memory clone and inherits the parent's + ``contextvars`` regardless, so the process-pool test pins ``spawn`` + explicitly — the start method Wool's own ``WorkerProcess`` enforces. +- :func:`asyncio.to_thread` — copies the caller's + :class:`contextvars.Context` (chain UUID and all) into the worker + thread; touching a :class:`wool.ContextVar` from an armed context + then trips :class:`wool.ChainContention`. +- :func:`wool.to_thread` — the supported alternative; forks a fresh + detached chain owned by the worker thread, with no merge-back. + +Also pins armed-gating: an unarmed context behaves as a plain +:class:`contextvars.Context` and incurs no guard. 
Every test +additionally runs under both the default ``asyncio`` loop and uvloop, +via the ``event_loop_policy`` fixture in ``conftest.py``. +""" + +import asyncio +import contextvars +import multiprocessing +import threading +import uuid +from concurrent.futures import ProcessPoolExecutor + +import pytest + +import wool +from tests.helpers import context_is_unarmed +from wool.runtime.context.exceptions import ChainContention +from wool.runtime.context.var import ContextVar + +pytestmark = pytest.mark.stdlib_parity + + +def _unique(stem: str) -> str: + """Return a process-unique variable name to avoid registry collisions.""" + return f"{stem}_{uuid.uuid4().hex}" + + +class TestRunInExecutorParity: + @pytest.mark.asyncio + async def test_stdlib_contextvar_should_be_absent_in_run_in_executor(self): + """Test a bare run_in_executor callable does not see a stdlib ContextVar. + + Given: + A plain contextvars.ContextVar set in the caller scope. + When: + loop.run_in_executor offloads a bare callable that reads it + with a fallback. + Then: + It should observe the fallback — run_in_executor carries no + contextvars.Context of its own. + """ + # Arrange + var: contextvars.ContextVar[str] = contextvars.ContextVar(_unique("std_rie")) + var.set("caller") + loop = asyncio.get_running_loop() + + # Act + observed = await loop.run_in_executor(None, lambda: var.get("")) + + # Assert + assert observed == "" + + @pytest.mark.asyncio + async def test_wool_contextvar_should_be_absent_in_run_in_executor(self): + """Test a bare run_in_executor callable carries no Wool chain. + + Given: + A wool.ContextVar set in an armed caller scope. + When: + loop.run_in_executor offloads a bare callable reading + wool.__chain__.get(None). + Then: + It should observe None — bare run_in_executor carries no + Wool context, matching stdlib. + """ + # Arrange + var = ContextVar(_unique("wool_rie")) + var.set("caller") # Arm the context. 
+ loop = asyncio.get_running_loop() + + # Act + observed = await loop.run_in_executor(None, lambda: wool.__chain__.get(None)) + + # Assert + assert observed is None + + @pytest.mark.asyncio + async def test_run_in_executor_with_a_process_pool_should_carry_no_wool_context( + self, + ): + """Test a bare process-pool run_in_executor callable carries no Wool chain. + + Given: + A wool.ContextVar set in an armed caller scope. + When: + loop.run_in_executor offloads a bare callable to a + ProcessPoolExecutor backed by a freshly spawned worker + process that reads current_context. + Then: + It should observe no context in the worker process — a bare + run_in_executor carries no Wool context of its own, so a + spawned worker boots with no chain, matching the thread-pool + path and stdlib. + """ + # Arrange + var = ContextVar(_unique("wool_rie_proc")) + var.set("caller") # Arm the context. + loop = asyncio.get_running_loop() + + # Act + # Pin ``spawn``: a spawned worker boots a fresh interpreter, + # whereas a fork-started worker is a memory clone that inherits + # the parent's contextvars (the Wool context and plain stdlib + # vars alike), leaving no clean boundary to assert against. + # Wool's own WorkerProcess enforces ``spawn`` for this reason. + spawn = multiprocessing.get_context("spawn") + with ProcessPoolExecutor(max_workers=1, mp_context=spawn) as pool: + observed = await loop.run_in_executor(pool, context_is_unarmed) + + # Assert + assert observed is True + + +class TestAsyncioToThreadParity: + @pytest.mark.asyncio + async def test_stdlib_contextvar_should_be_visible_in_asyncio_to_thread(self): + """Test a stdlib ContextVar value is visible inside asyncio.to_thread. + + Given: + A plain contextvars.ContextVar set in the caller scope. + When: + asyncio.to_thread offloads a function reading it. + Then: + It should observe the caller's value — asyncio.to_thread + copies the caller's contextvars.Context. 
+ """ + # Arrange + var: contextvars.ContextVar[str] = contextvars.ContextVar(_unique("std_a2t")) + var.set("caller") + + # Act + observed = await asyncio.to_thread(var.get) + + # Assert + assert observed == "caller" + + @pytest.mark.asyncio + async def test_plain_to_thread_should_raise_chain_contention_when_from_armed_context( + self, + ): + """Test asyncio.to_thread touching a wool.ContextVar trips the guard. + + Given: + An armed context whose chain is owned by the loop thread. + When: + asyncio.to_thread offloads a function that reads a + wool.ContextVar from a worker thread. + Then: + It should raise wool.ChainContention — the copied chain + reaches an OS thread that does not own it. + """ + # Arrange + var = ContextVar(_unique("a2t_guard")) + var.set("armed") + + # Act & assert + with pytest.raises(ChainContention, match="cannot be shared across OS threads"): + await asyncio.to_thread(var.get) + + @pytest.mark.asyncio + async def test_asyncio_to_thread_should_not_raise_when_unarmed_context(self): + """Test asyncio.to_thread touching a wool.ContextVar in an unarmed context. + + Given: + An unarmed context — no chain, no guard. + When: + asyncio.to_thread offloads a function reading a + wool.ContextVar with a default. + Then: + It should not raise — an unarmed context behaves as a + plain contextvars.Context. + """ + # Arrange + var = ContextVar(_unique("a2t_unarmed"), default="d") + + # Act + observed = await asyncio.to_thread(var.get) + + # Assert + assert observed == "d" + + +class TestWoolToThreadParity: + @pytest.mark.asyncio + async def test_wool_to_thread_should_carry_value_without_guard(self): + """Test wool.to_thread carries the caller's value without tripping the guard. + + Given: + An armed context with a wool.ContextVar set. + When: + wool.to_thread offloads a function reading it. + Then: + It should observe the caller's value and not raise + wool.ChainContention. 
+ """ + # Arrange + var = ContextVar(_unique("w2t_value")) + var.set("caller") + + # Act + observed = await wool.to_thread(var.get) + + # Assert + assert observed == "caller" + + @pytest.mark.asyncio + async def test_wool_to_thread_should_not_raise_when_unarmed_context(self): + """Test wool.to_thread from an unarmed context does not raise. + + Given: + An unarmed context — no wool.ContextVar has been set. + When: + wool.to_thread offloads a function reading a wool.ContextVar + with a default. + Then: + It should return the default and not raise — an unarmed + context incurs no guard. + """ + # Arrange + var = ContextVar(_unique("w2t_unarmed"), default="default") + + # Act + observed = await wool.to_thread(var.get) + + # Assert + assert observed == "default" + + @pytest.mark.asyncio + async def test_wool_to_thread_should_run_on_fresh_detached_chain(self): + """Test wool.to_thread runs the offload on a fresh detached chain. + + Given: + An armed context whose chain id is known. + When: + wool.to_thread offloads a function reading + wool.__chain__.get().id. + Then: + It should differ from the caller's chain id. + """ + # Arrange + var = ContextVar(_unique("w2t_chain")) + var.set("x") # Arm the context. + caller = wool.__chain__.get(None) + assert caller is not None + + def read_chain() -> uuid.UUID: + context = wool.__chain__.get(None) + assert context is not None + return context.id + + # Act + offloaded_chain = await wool.to_thread(read_chain) + + # Assert + assert offloaded_chain != caller.id + + @pytest.mark.asyncio + async def test_wool_to_thread_should_run_on_worker_thread(self): + """Test wool.to_thread runs the offloaded function off the loop thread. + + Given: + An armed context and the running loop's thread id. + When: + wool.to_thread offloads a function reading its own thread + id. + Then: + It should differ from the loop thread's id. + """ + # Arrange + var = ContextVar(_unique("w2t_thread")) + var.set("x") # Arm the context. 
+ loop_thread = threading.get_ident() + + # Act + offloaded_thread = await wool.to_thread(threading.get_ident) + + # Assert + assert offloaded_thread != loop_thread + + @pytest.mark.asyncio + async def test_wool_to_thread_should_inherit_copy_of_caller_bindings(self): + """Test wool.to_thread inherits a copy of the caller's bindings. + + Given: + An armed context with a wool.ContextVar set. + When: + wool.to_thread offloads a function that reads the variable, + sets it, and reads it again. + Then: + It should observe the caller's value first (the detached + chain inherited a copy of the caller's bindings), then its + own mutation — while the caller's value stays unchanged, so + the mutation does not merge back. + """ + # Arrange + var = ContextVar(_unique("w2t_no_merge")) + var.set("caller") + + def mutate() -> tuple[str, str]: + before = var.get() + var.set("thread") + after = var.get() + return before, after + + # Act + before, after = await wool.to_thread(mutate) + + # Assert + assert before == "caller" + assert after == "thread" + assert var.get() == "caller" diff --git a/wool/tests/stdlib_parity/test_loop_callbacks.py b/wool/tests/stdlib_parity/test_loop_callbacks.py new file mode 100644 index 00000000..e3c60f88 --- /dev/null +++ b/wool/tests/stdlib_parity/test_loop_callbacks.py @@ -0,0 +1,411 @@ +"""Stdlib parity pins for ``wool.ContextVar`` propagation into loop callbacks. + +A :class:`wool.ContextVar` value rides in a single wool-owned stdlib +:class:`contextvars.ContextVar`, so it propagates with stdlib +visibility into every event-loop scheduling edge that captures a +:class:`contextvars.Context`: :meth:`loop.call_soon`, +:meth:`call_soon_threadsafe`, :meth:`call_later`, :meth:`call_at`, +:meth:`loop.add_reader` / :meth:`add_writer`, +:meth:`add_signal_handler`, and :meth:`Future.add_done_callback`. 
+ +The value-propagation tests take the ``make_var`` fixture and run once +per variable type (:class:`contextvars.ContextVar` and +:class:`wool.ContextVar`), so a single assertion proves the two behave +identically. They are additionally parametrized over the scheduling +edge, so one body pins every callback type. The ``wool``-only tests +additionally pin chain identity: unlike child tasks, callbacks share +the scheduling scope's chain (no fork) — they run cooperatively on the +owning thread. + +Every test additionally runs under both the default ``asyncio`` loop +and uvloop, via the ``event_loop_policy`` fixture in ``conftest.py``. +""" + +import asyncio +import contextlib +import os +import signal +import socket +import threading +import uuid +from collections.abc import Callable +from collections.abc import Iterator +from typing import TypeVar + +import pytest + +import wool +from wool.runtime.context.var import ContextVar + +pytestmark = pytest.mark.stdlib_parity + +_T = TypeVar("_T") + +# Loose upper bound for awaiting a callback's result. A conformant callback +# fires near-instantly; the timeout is only a backstop so that a callback +# which never resolves the future (for example, one that was never scheduled) +# fails the test fast instead of hanging the suite. +_CALLBACK_TIMEOUT = 5.0 + + +def _unique(stem: str) -> str: + """Return a process-unique variable name to avoid registry collisions.""" + return f"{stem}_{uuid.uuid4().hex}" + + +def _resolve(future: asyncio.Future[_T], produce: Callable[[], _T]) -> None: + """Route ``produce()``'s result or exception onto ``future``. + + Used as a loop callback so that a failure inside ``produce()`` — such + as a ``LookupError`` from a variable that failed to propagate — surfaces + as a failure on ``future`` instead of being swallowed by the loop's + exception handler, which would leave the awaiting test hanging. 
+ """ + if future.done(): + return + try: + future.set_result(produce()) + except BaseException as exc: # noqa: BLE001 — route every failure to the future + future.set_exception(exc) + + +def _scope_chain_id() -> uuid.UUID: + """Return the active context's chain id, asserting the scope is armed.""" + context = wool.__chain__.get(None) + assert context is not None + return context.id + + +class _CallbackScheduler: + """One event-loop scheduling edge wired to drive a single callback. + + Each instance knows how to register *callback* (a zero-argument + function) on its scheduling edge and how to tear that registration + down afterward. Built fresh per test by :func:`scheduler`, so a + single parametrized test body pins every callback type. + """ + + def __init__(self, name: str): + self.name = name + + @contextlib.contextmanager + def schedule( + self, loop: asyncio.AbstractEventLoop, callback: Callable[[], None] + ) -> Iterator[None]: + """Register *callback* on this edge; clean up on block exit.""" + raise NotImplementedError # pragma: no cover — overridden per edge + + +class _CallSoonScheduler(_CallbackScheduler): + @contextlib.contextmanager + def schedule(self, loop, callback): + loop.call_soon(callback) + yield + + +class _CallLaterScheduler(_CallbackScheduler): + @contextlib.contextmanager + def schedule(self, loop, callback): + loop.call_later(0, callback) + yield + + +class _CallAtScheduler(_CallbackScheduler): + @contextlib.contextmanager + def schedule(self, loop, callback): + loop.call_at(loop.time(), callback) + yield + + +class _AddReaderScheduler(_CallbackScheduler): + @contextlib.contextmanager + def schedule(self, loop, callback): + reader, writer = socket.socketpair() + writer.send(b"x") + + def on_readable() -> None: + loop.remove_reader(reader.fileno()) + callback() + + loop.add_reader(reader.fileno(), on_readable) + try: + yield + finally: + reader.close() + writer.close() + + +class _AddWriterScheduler(_CallbackScheduler): + 
@contextlib.contextmanager + def schedule(self, loop, callback): + left, right = socket.socketpair() + + def on_writable() -> None: + loop.remove_writer(left.fileno()) + callback() + + loop.add_writer(left.fileno(), on_writable) + try: + yield + finally: + left.close() + right.close() + + +class _AddSignalHandlerScheduler(_CallbackScheduler): + @contextlib.contextmanager + def schedule(self, loop, callback): + def on_signal() -> None: + loop.remove_signal_handler(signal.SIGUSR1) + callback() + + loop.add_signal_handler(signal.SIGUSR1, on_signal) + os.kill(os.getpid(), signal.SIGUSR1) + yield + + +class _DoneCallbackScheduler(_CallbackScheduler): + @contextlib.contextmanager + def schedule(self, loop, callback): + future: asyncio.Future[None] = loop.create_future() + future.add_done_callback(lambda _: callback()) + future.set_result(None) + yield + + +_SCHEDULERS = [ + _CallSoonScheduler("call_soon"), + _CallLaterScheduler("call_later"), + _CallAtScheduler("call_at"), + _AddReaderScheduler("add_reader"), + _AddWriterScheduler("add_writer"), + _AddSignalHandlerScheduler("add_signal_handler"), + _DoneCallbackScheduler("add_done_callback"), +] + + +@pytest.fixture(params=_SCHEDULERS, ids=lambda s: s.name) +def scheduler(request) -> _CallbackScheduler: + """Return one event-loop scheduling edge to pin under parity. + + Parametrizes a callback-propagation test over every loop edge that + captures a :class:`contextvars.Context` — + ``call_soon``/``call_later``/``call_at``, + ``add_reader``/``add_writer``, ``add_signal_handler``, and + ``Future.add_done_callback`` — so one test body pins them all. + """ + return request.param + + +class TestCallbackValuePropagationParity: + @pytest.mark.asyncio + async def test_callback_should_propagate_scoped_value(self, scheduler, make_var): + """Test a context variable propagates into a loop callback. + + Given: + A context variable set in the scheduling scope. + When: + A callback scheduled on the loop edge under test reads it. 
+ Then: + It should observe the scheduling scope's value, identically + for a stdlib and a wool variable and for every loop edge. + """ + # Arrange + var = make_var("cb_value") + var.set("scope") + loop = asyncio.get_running_loop() + done: asyncio.Future[str] = loop.create_future() + + # Act + with scheduler.schedule(loop, lambda: _resolve(done, var.get)): + observed = await asyncio.wait_for(done, timeout=_CALLBACK_TIMEOUT) + + # Assert + assert observed == "scope" + + @pytest.mark.asyncio + async def test_callback_should_observe_registration_time_value( + self, scheduler, make_var + ): + """Test a loop callback observes the registration-time value. + + Given: + A context variable set to one value before the callback is + registered, then mutated to a different value after + registration but before the callback fires. + When: + A callback on the loop edge under test reads the variable. + Then: + It should observe the registration-time value, not the + post-mutation value — every loop edge copies the context at + registration, identically for a stdlib and a wool variable. + """ + # Arrange + var = make_var("cb_snap") + var.set("at-registration") + loop = asyncio.get_running_loop() + done: asyncio.Future[str] = loop.create_future() + + # Act + with scheduler.schedule(loop, lambda: _resolve(done, var.get)): + var.set("after-mutation") + observed = await asyncio.wait_for(done, timeout=_CALLBACK_TIMEOUT) + + # Assert + assert observed == "at-registration" + + @pytest.mark.asyncio + async def test_callback_should_preserve_scope_chain(self, scheduler): + """Test a loop callback shares the scheduling scope's chain. + + Given: + A wool.ContextVar set in an armed scheduling scope. + When: + A callback on the loop edge under test reads + wool.__chain__.get().id. + Then: + It should observe the scheduling scope's chain id — a + cooperatively-scheduled callback shares the chain, it is + not forked onto a fresh one, for every loop edge. 
+ """ + # Arrange + var = ContextVar(_unique("cb_chain")) + var.set("scope") + scope_chain = _scope_chain_id() + loop = asyncio.get_running_loop() + done: asyncio.Future[uuid.UUID] = loop.create_future() + + # Act + with scheduler.schedule(loop, lambda: _resolve(done, _scope_chain_id)): + observed = await asyncio.wait_for(done, timeout=_CALLBACK_TIMEOUT) + + # Assert + assert observed == scope_chain + + @pytest.mark.asyncio + async def test_callback_set_should_not_leak_to_scheduling_scope( + self, scheduler, make_var + ): + """Test a callback's set does not leak back to the scheduling scope. + + This is Pitfall 2: the callback runs in a copy_context() copy, + so its write stays in that copy. + + Given: + A context variable set in the scheduling scope. + When: + A callback on the loop edge under test sets the variable to + a new value. + Then: + The scheduling scope should still observe its original + value — native contextvars copy-on-write isolates the + callback's write, identically for a stdlib and a wool + variable and for every loop edge. + """ + # Arrange + var = make_var("cb_isolate") + var.set("scope") + loop = asyncio.get_running_loop() + done: asyncio.Future[None] = loop.create_future() + + def mutate() -> None: + var.set("callback-local") + + # Act + with scheduler.schedule(loop, lambda: _resolve(done, mutate)): + await asyncio.wait_for(done, timeout=_CALLBACK_TIMEOUT) + + # Assert + assert var.get() == "scope" + + @pytest.mark.asyncio + async def test_callback_lookup_error_should_surface_on_future( + self, scheduler, make_var + ): + """Test a LookupError raised inside a callback surfaces correctly. + + Given: + A context variable with no value and no default. + When: + A callback on the loop edge under test reads it. + Then: + The read's :class:`LookupError` should surface as the + awaiting future's exception, identically for a stdlib and a + wool variable and for every loop edge. 
+ """ + # Arrange + var = make_var("cb_lookup") + loop = asyncio.get_running_loop() + done: asyncio.Future[object] = loop.create_future() + + # Act & assert + with scheduler.schedule(loop, lambda: _resolve(done, var.get)): + with pytest.raises(LookupError): + await asyncio.wait_for(done, timeout=_CALLBACK_TIMEOUT) + + +class TestCallSoonThreadsafeParity: + @pytest.mark.asyncio + async def test_call_soon_threadsafe_should_observe_fallback_when_from_foreign_thread( + self, make_var + ): + """Test call_soon_threadsafe does not capture the loop thread's context. + + Given: + A context variable set on the loop thread and a separate OS + thread that never entered that scope. + When: + The foreign thread schedules a callback via + loop.call_soon_threadsafe that reads the variable with a + fallback. + Then: + It should observe the fallback — call_soon_threadsafe + captures the scheduling thread's context, not the loop + thread's, identically for a stdlib and a wool variable. + """ + # Arrange + var = make_var("cst") + var.set("loop-scope") + loop = asyncio.get_running_loop() + done: asyncio.Future[str] = loop.create_future() + + def schedule_from_foreign_thread() -> None: + loop.call_soon_threadsafe(_resolve, done, lambda: var.get("fallback")) + + # Act + worker = threading.Thread(target=schedule_from_foreign_thread) + worker.start() + worker.join(timeout=_CALLBACK_TIMEOUT) + observed = await asyncio.wait_for(done, timeout=_CALLBACK_TIMEOUT) + + # Assert + assert observed == "fallback" + + @pytest.mark.asyncio + async def test_call_soon_threadsafe_should_observe_value_when_from_setting_thread( + self, make_var + ): + """Test call_soon_threadsafe from the setting thread sees the value. + + Given: + A context variable set on the loop thread. + When: + The loop thread itself schedules a callback via + loop.call_soon_threadsafe that reads the variable. 
+ Then: + It should observe the loop scope's value — scheduling from + the thread that set the variable carries that thread's + context, identically for a stdlib and a wool variable. + """ + # Arrange + var = make_var("cst_self") + var.set("loop-scope") + loop = asyncio.get_running_loop() + done: asyncio.Future[str] = loop.create_future() + + # Act + loop.call_soon_threadsafe(_resolve, done, var.get) + observed = await asyncio.wait_for(done, timeout=_CALLBACK_TIMEOUT) + + # Assert + assert observed == "loop-scope" diff --git a/wool/tests/stdlib_parity/test_task_creation.py b/wool/tests/stdlib_parity/test_task_creation.py new file mode 100644 index 00000000..bc803d3a --- /dev/null +++ b/wool/tests/stdlib_parity/test_task_creation.py @@ -0,0 +1,979 @@ +"""Stdlib parity pins for ``wool.ContextVar`` propagation into child tasks. + +These tests pin BOTH stdlib ``contextvars`` behavior and the Wool +parity: a :class:`wool.ContextVar` value set in a parent context must +be visible inside child tasks created through every asyncio task- +spawning edge — :func:`asyncio.create_task`, +:meth:`loop.create_task`, :func:`asyncio.ensure_future`, +:func:`asyncio.gather`, and :class:`asyncio.TaskGroup`. Unlike a plain +:class:`contextvars.ContextVar` (which a child task observes under the +*same* :class:`contextvars.Context` copy), a Wool child task runs on a +freshly-minted ``chain_id`` — copy-on-fork — while still inheriting the +parent's variable values. + +The value-propagation tests take the ``make_var`` fixture and run once +per variable type, so a single assertion proves the two propagate +identically; the copy-on-fork tests parametrize over the spawn +entrypoint, so one body pins every edge. The ``wool``-only tests +additionally pin copy-on-fork: each child, and each of two siblings, +forks onto a distinct chain. + +Test classes group by parity concern rather than by a single +production class — this is a cross-cutting parity suite with no one +class under test. 
Every test additionally runs under both the default +``asyncio`` loop and uvloop, via the ``event_loop_policy`` fixture in +``conftest.py``. + +A future change to CPython's task-context copy semantics, or to Wool's +task factory, fails here first. +""" + +import asyncio +import contextvars +import threading +import types +import uuid +import warnings +from collections.abc import Callable +from collections.abc import Coroutine +from typing import Any +from typing import TypeVar + +import pytest + +import wool +from wool.runtime.context.exceptions import ChainContention +from wool.runtime.context.exceptions import TaskFactoryDisplaced +from wool.runtime.context.factory import install_task_factory +from wool.runtime.context.var import ContextVar + +pytestmark = pytest.mark.stdlib_parity + +_T = TypeVar("_T") + +_CoroFactory = Callable[[], Coroutine[Any, Any, _T]] + + +def _unique(stem: str) -> str: + """Return a process-unique variable name to avoid registry collisions.""" + return f"{stem}_{uuid.uuid4().hex}" + + +class _Spawner: + """One asyncio task-spawning entrypoint, driven uniformly. + + Each instance knows how to run a batch of coroutine factories + through its entrypoint and return their results in order, so a + single parametrized test body pins every spawn edge. 
+ """ + + def __init__(self, name: str): + self.name = name + + async def run(self, factories: list[_CoroFactory[_T]]) -> list[_T]: + """Spawn one child per factory and return their results in order.""" + raise NotImplementedError # pragma: no cover — overridden per edge + + +class _CreateTaskSpawner(_Spawner): + async def run(self, factories): + tasks = [asyncio.create_task(factory()) for factory in factories] + return list(await asyncio.gather(*tasks)) + + +class _LoopCreateTaskSpawner(_Spawner): + async def run(self, factories): + loop = asyncio.get_running_loop() + tasks = [loop.create_task(factory()) for factory in factories] + return list(await asyncio.gather(*tasks)) + + +class _EnsureFutureSpawner(_Spawner): + async def run(self, factories): + futures = [asyncio.ensure_future(factory()) for factory in factories] + return list(await asyncio.gather(*futures)) + + +class _GatherSpawner(_Spawner): + async def run(self, factories): + return list(await asyncio.gather(*(factory() for factory in factories))) + + +class _TaskGroupSpawner(_Spawner): + async def run(self, factories): + tasks: list[asyncio.Task[Any]] = [] + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(factory()) for factory in factories] + return [task.result() for task in tasks] + + +_SPAWNERS = [ + _CreateTaskSpawner("create_task"), + _LoopCreateTaskSpawner("loop.create_task"), + _EnsureFutureSpawner("ensure_future"), + _GatherSpawner("gather"), + _TaskGroupSpawner("TaskGroup"), +] + + +@pytest.fixture(params=_SPAWNERS, ids=lambda s: s.name) +def spawner(request) -> _Spawner: + """Return one asyncio task-spawning entrypoint to pin under parity. + + Parametrizes a task-creation test over every spawn edge — + :func:`asyncio.create_task`, :meth:`loop.create_task`, + :func:`asyncio.ensure_future`, :func:`asyncio.gather`, and + :class:`asyncio.TaskGroup` — so one test body pins them all. 
+ """ + return request.param + + +class TestTaskCreationValuePropagationParity: + @pytest.mark.asyncio + async def test_child_should_observe_scoped_value(self, spawner, make_var): + """Test a context variable value is visible in a child task. + + Given: + A context variable set in the parent. + When: + A child created through the spawn edge under test reads it. + Then: + It should observe the parent's value, identically for a + stdlib and a wool variable and for every spawn edge. + """ + # Arrange + var = make_var("ct_value") + var.set("parent") + + async def child() -> str: + return var.get() + + # Act + observed = await spawner.run([child]) + + # Assert + assert observed == ["parent"] + + +class TestTaskCreationCopyOnFork: + @pytest.mark.asyncio + async def test_child_should_fork_fresh_chain(self, spawner): + """Test a child task forks onto a fresh chain. + + Given: + A wool.ContextVar set in an armed parent with the task + factory installed. + When: + A child created through the spawn edge under test reads the + value and wool.__chain__.get().id. + Then: + It should observe the parent's value on a chain id distinct + from the parent's, and leave the parent's own chain + unchanged — copy-on-fork, for every spawn edge. + """ + # Arrange + install_task_factory() + var = ContextVar(_unique("fork_chain")) + var.set("parent") + parent = wool.__chain__.get(None) + assert parent is not None + + async def child() -> tuple[str, uuid.UUID]: + context = wool.__chain__.get(None) + assert context is not None + return var.get(), context.id + + # Act + (observed,) = await spawner.run([child]) + + # Assert + assert observed[0] == "parent" + assert observed[1] != parent.id + after = wool.__chain__.get(None) + assert after is not None + assert after.id == parent.id + + @pytest.mark.asyncio + async def test_siblings_should_fork_distinct_chains(self, spawner): + """Test sibling child tasks each fork onto a distinct chain. 
+ + Given: + A wool.ContextVar set in an armed parent with the task + factory installed. + When: + Two sibling children are created through the spawn edge + under test. + Then: + It should give each sibling a chain id distinct from the + parent's and from the other sibling's — copy-on-fork mints a + fresh chain per task, never one chain shared across + siblings, for every spawn edge. + """ + # Arrange + install_task_factory() + var = ContextVar(_unique("fork_siblings")) + var.set("parent") + parent = wool.__chain__.get(None) + assert parent is not None + + async def child() -> uuid.UUID: + context = wool.__chain__.get(None) + assert context is not None + return context.id + + # Act + first, second = await spawner.run([child, child]) + + # Assert + assert first != parent.id + assert second != parent.id + assert first != second + + @pytest.mark.asyncio + async def test_factory_should_be_dormant_when_unarmed(self, spawner): + """Test the task factory is dormant when the context is unarmed. + + Given: + The Wool task factory installed but no wool.ContextVar set + (unarmed context). + When: + A child is created through the spawn edge under test. + Then: + It should observe no context inside the child — the factory + is dormant when unarmed, for every spawn edge. + """ + # Arrange + install_task_factory() + + async def child() -> bool: + return wool.__chain__.get(None) is None + + # Act + (context_is_none,) = await spawner.run([child]) + + # Assert + assert context_is_none + + @pytest.mark.asyncio + async def test_unarmed_factory_should_preserve_plain_contextvars(self, spawner): + """Test the installed task factory leaves plain contextvars intact. + + Given: + The Wool task factory installed, no wool.ContextVar ever set + (an unarmed context), and a plain contextvars.ContextVar set + in the parent. + When: + A child created through the spawn edge under test reads the + plain variable and current_context. 
+ Then: + It should observe the parent's plain value and no Wool + context — installing the factory costs an unarmed context + nothing, behaving as a plain contextvars.Context, for every + spawn edge. + """ + # Arrange + install_task_factory() + var: contextvars.ContextVar[str] = contextvars.ContextVar(_unique("unarmed")) + var.set("parent") + + async def child() -> tuple[str, bool]: + return var.get(), wool.__chain__.get(None) is None + + # Act + ((observed_value, context_is_none),) = await spawner.run([child]) + + # Assert + assert observed_value == "parent" + assert context_is_none + + +class TestTaskFactoryComposition: + @pytest.mark.asyncio + async def test_user_factory_should_fork_when_installed_before_wool(self): + """Test a user factory installed before Wool's still yields copy-on-fork. + + Given: + A user task factory installed on the loop, then Wool's task + factory installed after it (composing over the user one), + and an armed wool.ContextVar. + When: + A child task is created and reads the parent value and its + own chain id. + Then: + It should observe the parent's value on a forked chain, and + the user factory should have run — Wool composes over the + user factory rather than displacing it. 
+ """ + # Arrange + loop = asyncio.get_running_loop() + user_factory_calls: list[bool] = [] + + def user_factory(loop, coro, **kwargs): + user_factory_calls.append(True) + return asyncio.Task(coro, loop=loop, **kwargs) + + loop.set_task_factory(user_factory) + install_task_factory() + var = ContextVar(_unique("compose_value")) + var.set("parent") + parent = wool.__chain__.get(None) + assert parent is not None + + async def child() -> tuple[str, uuid.UUID]: + context = wool.__chain__.get(None) + assert context is not None + return var.get(), context.id + + # Act + observed_value, child_chain = await asyncio.create_task(child()) + + # Assert + assert observed_value == "parent" + assert child_chain != parent.id + assert user_factory_calls + + @pytest.mark.asyncio + async def test_stdlib_value_should_propagate_under_composition(self): + """Test a stdlib ContextVar propagates under a composed factory. + + Given: + A user task factory installed before Wool's, an armed + wool.ContextVar, and a plain contextvars.ContextVar set in + the parent. + When: + A child task created under the composed factory reads the + plain variable. + Then: + It should observe the parent's stdlib value — composition + preserves native contextvars propagation. + """ + # Arrange + loop = asyncio.get_running_loop() + + def user_factory(loop, coro, **kwargs): + return asyncio.Task(coro, loop=loop, **kwargs) + + loop.set_task_factory(user_factory) + install_task_factory() + wool_var = ContextVar(_unique("compose_arm")) + wool_var.set("arm") + std_var: contextvars.ContextVar[str] = contextvars.ContextVar( + _unique("compose_std") + ) + std_var.set("parent") + + async def child() -> str: + return std_var.get() + + # Act + observed = await asyncio.create_task(child()) + + # Assert + assert observed == "parent" + + @pytest.mark.asyncio + async def test_legacy_two_arg_inner_factory_should_raise_type_error(self): + """Test a legacy 2-arg inner factory raises TypeError under composition. 
+ + Given: + A legacy task factory with the two-argument ``(loop, coro)`` + signature installed, then Wool's factory composed over it, + and an armed wool.ContextVar so Wool wraps and forwards + ``context=`` to the inner factory. + When: + A child task is created. + Then: + It should raise :class:`TypeError` — Wool always forwards + ``context=`` and a legacy 2-arg factory cannot accept it. + """ + # Arrange + loop = asyncio.get_running_loop() + + def legacy_factory(loop, coro): + return asyncio.Task(coro, loop=loop) + + loop.set_task_factory(legacy_factory) + install_task_factory() + var = ContextVar(_unique("legacy_arm")) + var.set("arm") + + async def child() -> None: + return None + + # Act & assert — clear the broken composed factory before exit + # so the loop's own teardown does not route through it. The + # legacy factory rejects the ``context=`` Wool forwards, so the + # wrapped ``_forked_scope`` coroutine Wool built is never + # awaited; that leaked-coroutine RuntimeWarning is incidental to + # this error path and is filtered here so it does not mask the + # TypeError under assertion. + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "coroutine .* was never awaited", RuntimeWarning + ) + with pytest.raises(TypeError): + await asyncio.create_task(child()) + finally: + loop.set_task_factory(None) + + @pytest.mark.asyncio + async def test_factory_installed_after_wool_should_raise_displacement(self): + """Test displacing Wool's factory raises TaskFactoryDisplaced on next use. 
+ """ + # Arrange + loop = asyncio.get_running_loop() + install_task_factory() + + def third_party_factory(loop, coro, **kwargs): + return asyncio.Task(coro, loop=loop, **kwargs) + + loop.set_task_factory(third_party_factory) + var = ContextVar(_unique("displaced")) + + # Act & assert — the next Wool API contact self-checks the loop. + with pytest.raises(TaskFactoryDisplaced, match="displaced"): + var.set("arm") + + @pytest.mark.asyncio + async def test_install_task_factory_should_be_idempotent(self): + """Test install_task_factory is a no-op when already installed. + + Given: + Wool's task factory installed on the loop. + When: + install_task_factory is called a second time. + Then: + It should leave the same Wool-wrapped factory in place — the + second call is a no-op, not a re-wrap that would compose + Wool's factory over itself. + """ + # Arrange + loop = asyncio.get_running_loop() + install_task_factory() + first = loop.get_task_factory() + + # Act + with warnings.catch_warnings(): + warnings.simplefilter("error") + install_task_factory() + second = loop.get_task_factory() + + # Assert + assert first is second + assert getattr(second, "__wool_wrapped__", False) + + +class TestChainContentionGuardBoundary: + @pytest.mark.asyncio + async def test_two_tasks_should_raise_chain_contention_when_sharing_armed_context( + self, + ): + """Test two tasks handed one context that is later armed trip the guard. + + Pins ``assert_chain_owner``'s cross-task trigger boundary: the + task factory's copy-on-fork cannot catch a context shared while + still *unarmed*; the chain-owner guard catches the second runner + once it is armed. + + Given: + One unarmed contextvars.Context passed to two tasks (sharing + an unarmed context is permitted, exactly as stdlib allows). + When: + The first task arms the context with a wool.ContextVar.set + and the second task then touches a wool.ContextVar. 
+ Then: + The second task should raise :class:`ChainContention` — + two tasks cannot run one armed chain. + """ + # Arrange + install_task_factory() + var = ContextVar(_unique("guard_shared")) + shared = contextvars.copy_context() + loop = asyncio.get_running_loop() + armed = asyncio.Event() + first_can_finish = asyncio.Event() + + async def first() -> None: + var.set("armed-by-first") + armed.set() + await first_can_finish.wait() + + async def second() -> BaseException | None: + await armed.wait() + try: + var.get("fallback") + except BaseException as exc: # noqa: BLE001 — return it for assertion + return exc + finally: + first_can_finish.set() + return None + + # Act — both tasks are handed the SAME context object. + first_task = loop.create_task(first(), context=shared) + second_task = loop.create_task(second(), context=shared) + observed, _ = await asyncio.gather(second_task, first_task) + + # Assert + assert isinstance(observed, ChainContention) + + @pytest.mark.asyncio + async def test_repassing_a_live_armed_context_should_raise_chain_contention(self): + """Test re-passing a live armed context to create_task is rejected. + + Pins the factory's up-front rejection: an armed context already + driving a live task cannot be handed to a second create_task. + + Given: + An armed contextvars.Context owned by a running task. + When: + That same context is passed to loop.create_task while the + owning task is still live. + Then: + It should raise :class:`ChainContention` from the + factory itself, before the second task runs. 
+ """ + # Arrange + install_task_factory() + var = ContextVar(_unique("guard_repass")) + loop = asyncio.get_running_loop() + holder_running = asyncio.Event() + holder_can_finish = asyncio.Event() + + async def holder() -> None: + var.set("armed") + holder_running.set() + await holder_can_finish.wait() + + async def noop() -> None: + return None + + # Act & assert — the holder task arms the context it is handed, + # then that live, armed context is re-passed to a second + # create_task while the holder is still running. The context is + # threaded in explicitly rather than read back via + # Task.get_context(), which is Python 3.12+; the project + # supports 3.11. + holder_context = contextvars.copy_context() + holder_task = loop.create_task(holder(), context=holder_context) + await holder_running.wait() + try: + with pytest.raises(ChainContention, match="armed"): + loop.create_task(noop(), context=holder_context) + finally: + holder_can_finish.set() + await holder_task + + @pytest.mark.asyncio + async def test_callback_on_the_owning_thread_should_not_raise(self): + """Test a callback on the owning thread does not trip the guard. + + The negative boundary: cooperatively-scheduled work on the + chain's owning thread shares the chain serially and must not + raise. + + Given: + An armed wool.ContextVar owned by the loop thread. + When: + A loop.call_soon callback on that same thread reads the + variable. + Then: + It should observe the value without raising + :class:`ChainContention` — a callback shares the chain + but never runs concurrently with its owner. 
+ """ + # Arrange + install_task_factory() + var = ContextVar(_unique("guard_callback")) + var.set("scope") + loop = asyncio.get_running_loop() + done: asyncio.Future[str] = loop.create_future() + + def read() -> None: + if not done.done(): + try: + done.set_result(var.get()) + except BaseException as exc: # noqa: BLE001 + done.set_exception(exc) + + # Act + loop.call_soon(read) + observed = await asyncio.wait_for(done, timeout=5.0) + + # Assert + assert observed == "scope" + + @pytest.mark.asyncio + async def test_offloaded_code_that_never_touches_a_var_should_not_raise(self): + """Test offloaded code that touches no wool.ContextVar does not raise. + + The negative boundary: the guard fires when a wool.ContextVar is + read or written, not when a thread boundary is crossed. + + Given: + An armed wool.ContextVar owned by the loop thread. + When: + asyncio.to_thread offloads a function that does NOT touch + any wool.ContextVar. + Then: + It should return normally without raising + :class:`ChainContention` — offloaded code that never + enters the chain is never flagged. + """ + # Arrange + var = ContextVar(_unique("guard_untouched")) + var.set("scope") + + def offloaded() -> str: + return "no-var-touched" + + # Act + observed = await asyncio.to_thread(offloaded) + + # Assert + assert observed == "no-var-touched" + + +class TestNestedTaskDepth: + @pytest.mark.asyncio + async def test_grandchild_should_fork_chain_distinct_from_ancestors(self): + """Test a grandchild task forks a chain distinct from its ancestors. + + Given: + A wool.ContextVar set in an armed parent with the task + factory installed. + When: + The parent spawns a child task, which spawns a grandchild + task, each reading wool.__chain__.get().id. + Then: + It should give parent, child, and grandchild three distinct + chain ids — copy-on-fork mints a fresh chain at every + nesting level. 
+ """ + # Arrange + install_task_factory() + var = ContextVar(_unique("nested_chain")) + var.set("depth-0") + parent = wool.__chain__.get(None) + assert parent is not None + + async def grandchild() -> uuid.UUID: + context = wool.__chain__.get(None) + assert context is not None + return context.id + + async def child() -> tuple[uuid.UUID, uuid.UUID]: + context = wool.__chain__.get(None) + assert context is not None + grandchild_chain = await asyncio.create_task(grandchild()) + return context.id, grandchild_chain + + # Act + child_chain, grandchild_chain = await asyncio.create_task(child()) + + # Assert + assert len({parent.id, child_chain, grandchild_chain}) == 3 + + @pytest.mark.asyncio + async def test_depth_0_value_should_be_visible_at_depth_3(self): + """Test a value set at the root is visible three task levels deep. + + Given: + A wool.ContextVar set in an armed parent with the task + factory installed. + When: + The parent spawns a chain of child -> grandchild -> + great-grandchild tasks, the innermost reading the variable. + Then: + It should observe the depth-0 value at depth 3 — + copy-on-fork is transitive: each fork inherits its parent's + bindings. + """ + # Arrange + install_task_factory() + var = ContextVar(_unique("nested_value")) + var.set("root-value") + + async def great_grandchild() -> str: + return var.get() + + async def grandchild() -> str: + return await asyncio.create_task(great_grandchild()) + + async def child() -> str: + return await asyncio.create_task(grandchild()) + + # Act + observed = await asyncio.create_task(child()) + + # Assert + assert observed == "root-value" + + +class TestGeneratorBasedCoroutine: + @pytest.mark.asyncio + async def test_generator_based_coroutine_should_propagate_value(self): + """Test a generator-based coroutine task propagates a wool value. + + Exercises the ``Coroutine | Generator`` arm of ``wool_factory``: + a task built from a generator decorated with + :func:`asyncio.coroutine`-style ``@types.coroutine``. 
+ + Given: + A wool.ContextVar set in an armed parent with the task + factory installed, and a child built from a + ``@types.coroutine`` generator function. + When: + A child task created from that generator-based coroutine + reads the variable. + Then: + It should observe the parent's value — the task factory + forks the generator-based coroutine onto a copy-on-fork + chain exactly like a native ``async def`` coroutine. + """ + # Arrange + install_task_factory() + var = ContextVar(_unique("genbased_value")) + var.set("parent") + + @types.coroutine + def generator_based_child(): + # A bare ``yield`` makes this a generator; ``@types.coroutine`` + # marks it awaitable so asyncio's create_task path accepts it. + yield + return var.get() + + # Act + observed = await asyncio.create_task(generator_based_child()) + + # Assert + assert observed == "parent" + + +class TestRunCoroutineThreadsafe: + @pytest.mark.asyncio + async def test_run_coroutine_threadsafe_should_propagate_value(self, make_var): + """Test run_coroutine_threadsafe propagates a value from a foreign thread. + + Given: + A context variable set inside a coroutine that a foreign + thread submits via asyncio.run_coroutine_threadsafe. + When: + That coroutine reads the variable it set. + Then: + It should observe the value — the submitted coroutine runs + as a task on the loop and sees its own writes, identically + for a stdlib and a wool variable. 
+ """ + # Arrange + var = make_var("rct_value") + loop = asyncio.get_running_loop() + ready = asyncio.Event() + + async def submitted() -> str: + var.set("submitted-scope") + return var.get() + + result: list[str] = [] + error: list[BaseException] = [] + + def submit_from_foreign_thread() -> None: + future = asyncio.run_coroutine_threadsafe(submitted(), loop) + try: + result.append(future.result(timeout=5.0)) + except BaseException as exc: # noqa: BLE001 + error.append(exc) + loop.call_soon_threadsafe(ready.set) + + # Act + worker = threading.Thread(target=submit_from_foreign_thread) + worker.start() + await asyncio.wait_for(ready.wait(), timeout=5.0) + worker.join(timeout=5.0) + + # Assert + assert not error + assert result == ["submitted-scope"] + + @pytest.mark.asyncio + async def test_run_coroutine_threadsafe_should_fork_fresh_chain(self): + """Test a run_coroutine_threadsafe coroutine forks its own chain. + + Given: + An armed wool.ContextVar on the loop thread with the task + factory installed. + When: + A foreign thread submits a coroutine via + asyncio.run_coroutine_threadsafe that arms its own + wool.ContextVar and reads wool.__chain__.get().id. + Then: + It should run on a chain owned by the loop thread without + tripping :class:`ChainContention` — the coroutine is + scheduled as a task on the loop, not run on the foreign + thread. 
+ """ + # Arrange + install_task_factory() + loop = asyncio.get_running_loop() + loop_var = ContextVar(_unique("rct_loop")) + loop_var.set("loop-armed") + ready = asyncio.Event() + result: list[uuid.UUID] = [] + error: list[BaseException] = [] + + async def submitted() -> uuid.UUID: + var = ContextVar(_unique("rct_chain")) + var.set("submitted") + context = wool.__chain__.get(None) + assert context is not None + return context.id + + def submit_from_foreign_thread() -> None: + future = asyncio.run_coroutine_threadsafe(submitted(), loop) + try: + result.append(future.result(timeout=5.0)) + except BaseException as exc: # noqa: BLE001 + error.append(exc) + loop.call_soon_threadsafe(ready.set) + + # Act + worker = threading.Thread(target=submit_from_foreign_thread) + worker.start() + await asyncio.wait_for(ready.wait(), timeout=5.0) + worker.join(timeout=5.0) + + # Assert + assert not error + assert len(result) == 1 + + +class TestSchedulingEdgeExceptionPropagation: + @pytest.mark.asyncio + async def test_lookup_error_in_a_gather_child_should_surface(self): + """Test a LookupError raised in a gather child surfaces to the awaiter. + + Given: + A wool.ContextVar with no value and no default. + When: + A child coroutine run via asyncio.gather reads it. + Then: + The :class:`LookupError` should surface out of the gather + awaiter — an exception raised in a scheduling-edge child + propagates to the caller. + """ + # Arrange + var = ContextVar(_unique("gather_exc")) + + async def child() -> str: + return var.get() + + # Act & assert + with pytest.raises(LookupError): + await asyncio.gather(child()) + + @pytest.mark.asyncio + async def test_lookup_error_in_a_done_callback_should_surface(self): + """Test a LookupError raised in a done callback surfaces to the awaiter. + + Given: + A wool.ContextVar with no value and no default, and a future + whose done callback reads it. + When: + The future resolves and the done callback runs. 
+ Then: + The read's :class:`LookupError` should surface onto the + observer future — an exception raised on a done-callback + scheduling edge propagates rather than being swallowed by + the loop's exception handler. + """ + # Arrange + var = ContextVar(_unique("dc_exc")) + loop = asyncio.get_running_loop() + observed: asyncio.Future[object] = loop.create_future() + + def done_callback(_: object) -> None: + if observed.done(): + return + try: + observed.set_result(var.get()) + except BaseException as exc: # noqa: BLE001 — route it to the future + observed.set_exception(exc) + + future: asyncio.Future[None] = loop.create_future() + future.add_done_callback(done_callback) + + # Act & assert + future.set_result(None) + with pytest.raises(LookupError): + await asyncio.wait_for(observed, timeout=5.0) + + +class TestRunInExecutorWithExplicitContext: + @pytest.mark.asyncio + async def test_run_in_executor_with_ctx_run_should_raise_chain_contention(self): + """Test loop.run_in_executor with an explicit ctx.run raises ChainContention. + + Pins the raw ``loop.run_in_executor(None, ctx.run, fn)`` idiom: + a copy_context() run on an executor thread. + + Given: + An armed wool.ContextVar and a copy_context() of the armed + context. + When: + loop.run_in_executor offloads ``ctx.run`` over a function + that reads the variable. + Then: + It should raise :class:`ChainContention` — running the + armed chain's copied context on an executor thread enters + the chain off its owning thread. + """ + # Arrange + var = ContextVar(_unique("rie_ctxrun")) + var.set("armed") + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + + def read() -> str: + return var.get("fallback") + + # Act & assert + with pytest.raises(ChainContention): + await loop.run_in_executor(None, ctx.run, read) + + @pytest.mark.asyncio + async def test_run_in_executor_with_ctx_run_should_not_raise_when_unarmed(self): + """Test run_in_executor with ctx.run on an unarmed context does not raise. 
+ + Given: + An unarmed contextvars.Context (no wool.ContextVar set) and + a wool.ContextVar with a default. + When: + loop.run_in_executor offloads ``ctx.run`` over a function + that reads the variable. + Then: + It should return the default without raising — an unarmed + context carries no chain to enter. + """ + # Arrange + var = ContextVar(_unique("rie_unarmed"), default="d") + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + + def read() -> str: + return var.get() + + # Act + observed = await loop.run_in_executor(None, ctx.run, read) + + # Assert + assert observed == "d" diff --git a/wool/tests/test_exception.py b/wool/tests/test_exceptions.py similarity index 82% rename from wool/tests/test_exception.py rename to wool/tests/test_exceptions.py index 8d113aab..5a66ae6e 100644 --- a/wool/tests/test_exception.py +++ b/wool/tests/test_exceptions.py @@ -1,10 +1,10 @@ import pytest from wool import AdvertiseHostError -from wool import ContextDecodeWarning from wool import IneffectiveLeaseWarning from wool import IneffectiveQuorumTimeoutWarning from wool import LoopbackAdvertisementWarning +from wool import SerializationWarning from wool import WoolError from wool import WoolWarning @@ -12,10 +12,10 @@ class TestWoolError: """Tests for WoolError umbrella membership. - Fully qualified name: wool.exception.WoolError + Fully qualified name: wool.exceptions.WoolError """ - def test___init___with_wool_exception_types(self): + def test___init___should_subclass_wool_error(self): """Test wool's typed exceptions descend from WoolError. Given: @@ -33,19 +33,21 @@ def test___init___with_wool_exception_types(self): class TestWoolWarning: """Tests for WoolWarning umbrella membership. 
- Fully qualified name: wool.exception.WoolWarning + Fully qualified name: wool.exceptions.WoolWarning """ @pytest.mark.parametrize( "category", [ - ContextDecodeWarning, IneffectiveLeaseWarning, IneffectiveQuorumTimeoutWarning, LoopbackAdvertisementWarning, + SerializationWarning, ], ) - def test___init___with_wool_warning_categories(self, category): + def test___init___should_subclass_wool_warning_and_not_stdlib_warnings( + self, category + ): """Test wool's warning categories descend from WoolWarning. Given: diff --git a/wool/tests/test_public.py b/wool/tests/test_public.py index 248ede3b..d572aaba 100644 --- a/wool/tests/test_public.py +++ b/wool/tests/test_public.py @@ -1,7 +1,7 @@ import wool -def test_public_symbol_accessibility(): +def test_public_symbol_accessibility_should_import_every_symbol(): """Test public symbol accessibility from wool package. Given: @@ -19,7 +19,7 @@ def test_public_symbol_accessibility(): assert hasattr(wool, symbol), f"Symbol '{symbol}' not found in wool package" -def test_public_api_completeness(): +def test_public_api_completeness_should_match_expected_surface(): """Test public API completeness of wool package. 
Given: @@ -35,16 +35,19 @@ def test_public_api_completeness(): "TransientRpcError", "UnexpectedResponse", "WorkerConnection", - "Context", - "ContextAlreadyBound", - "ContextDecodeWarning", + "ChainContention", + "ChainSerializationError", + "TaskFactoryDisplaced", "ContextVar", "ContextVarCollision", "RuntimeContext", + "SerializationError", + "SerializationWarning", "Token", - "copy_context", - "create_task", - "current_context", + "WoolError", + "WoolWarning", + "install_task_factory", + "to_thread", "LoadBalancerContextLike", "LoadBalancerLike", "NoWorkersAvailable", @@ -82,6 +85,7 @@ def test_public_api_completeness(): "PredicateFunction", "WorkerMetadata", "Factory", + "UndefinedType", } # Act @@ -91,7 +95,50 @@ def test_public_api_completeness(): assert actual_public_api == expected_public_api -def test_serializer_singleton_identity(): +def test_removed_symbols_should_not_be_accessible(): + """Test that symbols removed in this PR are not accessible from wool. + + Given: + The wool package after the stdlib-context refactor. + When: + Checking for the presence of each removed symbol by name. + Then: + It should not expose any of the five removed symbols as attributes. + """ + # Arrange + removed_names = [ + "Chain", + "current_context", + "copy_context", + "create_task", + "ContextAlreadyBound", + ] + + # Act & assert + for name in removed_names: + assert not hasattr(wool, name), ( + f"Removed symbol '{name}' is still accessible on the wool package" + ) + + +def test_wool_error_should_subclass_exception_not_runtime_error(): + """Test WoolError descends from Exception and nothing narrower. + + Given: + The wool.WoolError umbrella class. + When: + Its position in the builtin exception hierarchy is checked. + Then: + It should subclass Exception directly — not RuntimeError — so + Wool-domain signals never match broad handlers written for + stdlib runtime errors. 
+ """ + # Arrange, act, & assert + assert issubclass(wool.WoolError, Exception) + assert not issubclass(wool.WoolError, RuntimeError) + + +def test_serializer_singleton_should_return_same_instance(): """Test wool.__serializer__ is a stable module-level singleton. Given: @@ -105,7 +152,7 @@ def test_serializer_singleton_identity(): assert wool.__serializer__ is wool.__serializer__ -def test_serializer_singleton_with_serializer_protocol(): +def test_serializer_singleton_should_satisfy_serializer_protocol(): """Test wool.__serializer__ satisfies the wool.Serializer protocol. Given: From 9ce6c50ceea2a62b410352c05b7b5ec9a4640032 Mon Sep 17 00:00:00 2001 From: Conrad Date: Sat, 27 Jun 2026 17:51:29 -0400 Subject: [PATCH 7/7] docs: Rewrite the READMEs for the stdlib-aligned chain model Rewrite the package, context, and worker READMEs to match the shipped surface: the stdlib-aligned public API, the ChainManifest.mount wire-apply pattern, and the Frame-based dispatch vocabulary. --- wool/README.md | 102 ++++++++++++----------- wool/src/wool/runtime/context/README.md | 104 ++++++++++++------------ wool/src/wool/runtime/worker/README.md | 28 +++---- 3 files changed, 123 insertions(+), 111 deletions(-) diff --git a/wool/README.md b/wool/README.md index 2a6e5cf0..ec199335 100644 --- a/wool/README.md +++ b/wool/README.md @@ -79,7 +79,7 @@ The decorated function, its arguments, returned or yielded values, and exception ### Dispatch gate -Under the hood, the `@wool.routine` decorator replaces the function with a wrapper that checks a `do_dispatch` context variable. This is a `ContextVar[bool]` that defaults to `True` and acts as a dispatch gate — when `True`, calling a routine packages the call into a task and sends it to a remote worker. Workers set `do_dispatch` to `False` before executing the function body, preventing infinite re-dispatch. The variable is restored to `True` for any nested `@wool.routine` calls within the function, so those dispatch normally to other workers. 
+Under the hood, the `@wool.routine` decorator replaces the function with a wrapper that checks an internal dispatch-routing flag. The flag defaults to dispatching — calling a routine packages the call into a task and sends it to a remote worker. Workers run the function body with the flag cleared, preventing infinite re-dispatch, and restore it for any nested `@wool.routine` calls within the function, so those dispatch normally to other workers. ### Coroutines vs. async generators @@ -144,27 +144,27 @@ Each var's namespace is inferred from the top-level package of the calling frame ### How propagation works -At dispatch time, Wool snapshots only the vars that have been explicitly `set()` in the current `wool.Context` — default-only values are not shipped. The snapshot is assembled in O(k) time by iterating the per-Context data dict (which contains only explicitly-set vars), not the full process-wide registry. It rides every dispatch frame as a `Context` protobuf message carrying a `map` keyed by each var's `":"`, alongside the active `wool.Context` id that identifies the logical chain. +At dispatch time, Wool encodes the active chain context — the `wool.ContextVar` bindings explicitly `set()` in the current chain, default-only values excluded — into a `ChainManifest` protobuf message. The message carries the chain id plus one entry per bound (or reset) variable, each holding the variable's `(namespace, name)` key and its cloudpickled value, and rides every dispatch frame. -`wool.ContextVar.__reduce__` embeds the var's current value directly in the reduce tuple, so when a `wool.ContextVar` appears anywhere in a pickled object graph its value travels with it. References across a task's args, kwargs, and ContextVar snapshot all land on the same local instance on the receiver. 
Unpickling goes through a strict construction path: if no var is yet registered under the key, a "stub" instance is registered through an internal back-door that bypasses the duplicate-key check, and the var's value is applied from the wire; when the worker's module-scope constructor later runs, it promotes the stub in place, preserving any wire state and reference identity. +A `wool.ContextVar` is identity, not storage: it serializes (through Wool's own pickler — vanilla `pickle` is rejected) as just its `(namespace, name)` key and constructor default, never a value. Values travel only in the wire context, so a `wool.ContextVar` referenced across a task's args, kwargs, and the context all reconstitutes to the same local instance on the receiver. If the receiver has not yet imported the module that declares the variable, a placeholder "stub" instance is registered under the key; when the declaring `wool.ContextVar(...)` call later runs, it promotes the stub in place, preserving reference identity and any propagated value. -On the worker, each task is activated in its own `wool.Context` carrying the caller's chain id and the caller's propagated values, distinct from any concurrent task's `wool.Context` on the same worker. When the worker returns (or yields), the final var state is attached to the gRPC response and applied on the caller side, so worker-side mutations flow back automatically. For async generators, the caller also attaches its current context to each iteration request, enabling bidirectional state exchange between caller and worker at every yield/next boundary. +On the worker, the routine runs under the caller's chain — a context decoded from the dispatch frame, carrying the caller's chain id and propagated values. When the worker returns (or yields), the resulting context is encoded onto the gRPC response and merged on the caller side, so worker-side mutations flow back automatically. 
For async generators, the caller also attaches its current context to each iteration request, enabling bidirectional state exchange between caller and worker at every yield/next boundary. ### Isolation -Each dispatched task runs inside its own `wool.Context`, carrying the caller's chain id and the caller's propagated values. Concurrent tasks on the same worker with different values for the same variable never interfere — each sees only its own propagated state. Worker-side mutations (via `set()`) are back-propagated to the caller when the task returns or yields, but they do not leak to other concurrent tasks: each dispatch activates its own `wool.Context` on the worker, and `asyncio.create_task` children fork a copy of the parent's `wool.Context` on creation (mirroring `contextvars.copy_context()` semantics), so concurrent execution paths do not share a mutable `wool.Context` and bidirectional value propagation stays coherent under the transparent-dispatch model. +Each dispatched routine runs under the caller's chain, and concurrent tasks on the same worker each carry their own context — concurrent tasks with different values for the same variable never interfere. Worker-side mutations (via `set()`) are back-propagated to the caller when the task returns or yields, but they do not leak to other concurrent tasks: each `asyncio.create_task` child forks a copy of the parent's context under a fresh chain id (copy-on-fork via Wool's task factory), so concurrent execution paths never share mutable Wool state and bidirectional value propagation stays coherent under the transparent-dispatch model. ### Decode failure semantics -Context propagation is **ancillary state** in wool's wire protocol — a separate channel from the routine's primary signal (its return value or raised exception). 
When a wire context fails to decode (cross-version pickle skew, custom class missing on the receiver, on-wire corruption of a single var value), wool never preempts the primary signal to surface the ancillary failure. The routine's outcome is delivered, and the failure is reported via Python's standard `warnings` mechanism with a `wool.ContextDecodeWarning` so callers can decide how to respond. +Context propagation is **ancillary state** in wool's wire protocol — a separate channel from the routine's primary signal (its return value or raised exception). When a wire context fails to decode (cross-version pickle skew, custom class missing on the receiver, on-wire corruption of a single var value), wool never preempts the primary signal to surface the ancillary failure. The routine's outcome is delivered, and the failure is reported via Python's standard `warnings` mechanism with a `wool.SerializationWarning` so callers can decide how to respond. -Three modes are available, and they compose with the standard Python warnings system rather than wool-specific API: +Each mode is configured through Python's standard warnings system: | Mode | How to enable | Behavior | | ---- | ------------- | -------- | -| Lenient (default) | _no opt-in_ | Decode failure emits `wool.ContextDecodeWarning`; primary signal returned. Caller-side exception frames also receive the failure on `__notes__`. | +| Lenient (default) | _no opt-in_ | Decode failure emits `wool.SerializationWarning`; primary signal returned. Caller-side exception frames also receive the failure on the exception's `__context__`. | | Inspect | `warnings.catch_warnings(record=True)` | Decode failure captured into a list; primary signal returned. Standard pattern for "best effort with audit trail". | -| Strict | `warnings.filterwarnings("error", category=wool.ContextDecodeWarning)` | Decode failure raises (the warning is promoted to an exception); primary signal lost. 
| +| Strict | `warnings.filterwarnings("error", category=wool.SerializationWarning)` | Decode failure raises (the warning is promoted to an exception); primary signal lost. | The lenient default keeps wool useful for callers that treat tracing-style state as advisory. Strict mode is for callers whose correctness depends on context state and prefer to fail fast. Inspect mode is the right choice when you want both the primary signal and visibility into ancillary failures: @@ -173,21 +173,21 @@ import warnings import wool with warnings.catch_warnings(record=True) as captured: - warnings.simplefilter("always", category=wool.ContextDecodeWarning) + warnings.simplefilter("always", category=wool.SerializationWarning) result = await some_routine() # always returns - decode_failures = [w for w in captured if issubclass(w.category, wool.ContextDecodeWarning)] + decode_failures = [w for w in captured if issubclass(w.category, wool.SerializationWarning)] if decode_failures: log.warning("context propagation degraded for %d frame(s)", len(decode_failures)) ``` -The same semantics apply on both sides of the wire: the worker emits `ContextDecodeWarning` when a request context fails to decode (and runs the routine with a fresh empty context as fallback), and the caller emits `ContextDecodeWarning` when a response context fails to decode (and delivers the result anyway). On the caller side, an exception-frame decode failure additionally rides on the routine's exception via `__notes__` so the failure surfaces in tracebacks. On the worker side, a snapshot encode failure that coincides with a routine exception rides similarly on the routine exception via `__notes__`. There is no `ExceptionGroup` chaining and no wrapper-exception API to learn — just a standard warning class and standard `try/except` around primary signals. 
+The same semantics apply on both sides of the wire: the worker emits `SerializationWarning` when a request context fails to decode (each unreadable entry is dropped individually so the routine still runs under whatever partial context decoded; only when the whole frame is unreadable does the worker context remain unarmed), and the caller emits `SerializationWarning` when a response context fails to decode (and delivers the result anyway). On the caller side, an exception-frame decode failure is additionally chained onto the routine exception's `__context__` so the failure surfaces in tracebacks. On the worker side, a context encode failure that coincides with a routine exception rides on the routine exception's `__cause__` via `raise ... from`. Callers need only a standard warning class and standard `try/except` around primary signals. #### Worker-side strict mode Strict mode applies symmetrically on the worker side via Python's standard `PYTHONWARNINGS` environment variable, which `multiprocessing` propagates to spawned worker subprocesses by default: ```bash -export PYTHONWARNINGS="error::wool.ContextDecodeWarning" +export PYTHONWARNINGS="error::wool.SerializationWarning" python my_app.py ``` @@ -195,7 +195,7 @@ Or programmatically before constructing the pool: ```python import os -os.environ["PYTHONWARNINGS"] = "error::wool.ContextDecodeWarning" +os.environ["PYTHONWARNINGS"] = "error::wool.SerializationWarning" import wool @@ -203,33 +203,43 @@ async with wool.WorkerPool(): ... # workers spawned now promote the warning to an exception ``` -When the worker promotes the warning to an exception, wool ships it back through the routine-exception channel, so the caller catches the exact same `wool.ContextDecodeWarning` class — symmetric with caller-side strict mode. No `RpcError` to special-case, no out-of-band wire metadata. 
+When the worker promotes the warning to an exception, wool ships it back through the routine-exception channel, so the caller catches the exact same `wool.SerializationWarning` class — symmetric with caller-side strict mode. The caller handles it with the same `try/except` it would use for any local exception. -### Binding a `wool.Context` to a task +### Task forking and thread offload -The canonical way to bind a `wool.Context` to a freshly-spawned `asyncio.Task` is `wool.create_task` (typed shim) or `asyncio.create_task` (or `loop.create_task`) directly with `context=wool_ctx`: +`wool.ContextVar` state rides in stdlib `contextvars`, so it propagates into child tasks, event-loop callbacks, timers, and `Future` done-callbacks with no special API — exactly as stdlib `contextvars` does. Wool's task factory (self-installed on the running loop the first time a `wool.ContextVar` is set, or explicitly via `wool.install_task_factory(loop)`) adds one thing: every child `asyncio.Task` created in an armed context is forked onto a fresh chain id, so concurrent tasks never share a mutable chain. Two caveats: if other libraries also install a task factory, Wool's must be installed *last* — a factory installed after Wool's silently drops fork-on-task, and a later `wool.ContextVar` access raises `wool.TaskFactoryDisplaced` (a `RuntimeError` subclass) once Wool detects it has been displaced — and an armed task's coroutine is wrapped, so `Task.get_coro()` and `repr()`'s coroutine field reflect the wrapper rather than the user coroutine. Every task the factory creates carries one Wool done-callback (visible as `cb=[...]` in `repr()`); a task created in an unarmed context is otherwise coroutine-identical to a plain `asyncio.Task` created on the same loop — its auto-generated `Task-N` name draws from that loop implementation's own counter. + +Offloading to another OS thread is the one case that needs care. 
Wool enforces a single runner per chain, so a plain `asyncio.to_thread()` from a context that has set a `wool.ContextVar` would place a second runner on the caller's chain in genuine parallelism — the first `wool.ContextVar` access in the worker thread raises `wool.ChainContention`. Use `wool.to_thread()` instead; it mirrors `asyncio.to_thread` but forks a fresh, detached chain for the worker thread: ```python -ctx = wool.copy_context() -task = wool.create_task(some_coro(), context=ctx) -# Equivalent at runtime: -task = asyncio.create_task(some_coro(), context=ctx) # type: ignore[arg-type] +result = await wool.to_thread(cpu_bound, payload) ``` -Both forms route through Wool's task factory, which self-installs on the running loop the first time any Wool API is touched (or on demand via `wool.install_task_factory(loop)`). The factory wraps the coroutine so the `wool.Context`'s single-task guard is held continuously across awaits — any concurrent attempt to bind a second task to the same `wool.Context` raises `RuntimeError` immediately when that task starts running. `wool.create_task` exists purely as a typing shim: stdlib's `context=` kwarg is typed for `contextvars.Context` and `wool.Context` cannot subclass it (the C type disallows subclassing), so the Wool helper hides the cast. - -When `context=` is omitted, the factory forks `wool.copy_context()` from the parent task and binds the fresh chain id to the child. This is the default `asyncio.create_task(coro)` path and matches stdlib's `contextvars.copy_context()` semantics with wool's chain-id contract layered on top. +A bare `loop.run_in_executor(None, func)` is different again: stdlib `run_in_executor` copies no context at all, so `func` runs with no Wool chain — no `wool.ChainContention`, but no propagation either, and `wool.ContextVar` reads in the worker fall through to their defaults. Reach for `wool.to_thread()` whenever the offloaded work needs the caller's bindings. 
### Backpressure hooks -`BackpressureLike` hooks run after the caller's propagated `wool.ContextVar` snapshot is applied to the worker's context, so a hook can read caller-provided values (e.g., a tenant id) to make admission decisions without the caller having to plumb them through the `BackpressureContext` explicitly. +`BackpressureLike` hooks run with the caller's propagated `wool.ContextVar` context installed, so a hook can read caller-provided values (e.g., a tenant id) to make admission decisions without the caller having to plumb them through the `BackpressureContext` explicitly. + +### Runtime options + +`wool.RuntimeContext` carries block-scoped runtime-option overrides — currently just the dispatch timeout — separate from the `wool.ContextVar` context. Used as a context manager it overrides the ambient timeout for every `@wool.routine` dispatch in the block, and it is auto-captured on every `Task` at construction so the worker restores the caller's effective timeout before running the routine: + +```python +import wool + +with wool.RuntimeContext(dispatch_timeout=30): + result = await my_routine() +``` + +The underlying `dispatch_timeout` is a plain stdlib `contextvars.ContextVar[float | None]` (`None` means no timeout), distinct from the Wool-owned context variable that carries `wool.ContextVar` state. ### Limitations - **Values must be _cloudpicklable_.** A `TypeError` naming the offending variable is raised at dispatch time if serialization fails. -- **Only explicitly set values propagate.** A variable that has never been `set()` (only has a class-level default) is not included in the snapshot — the worker falls through to its own default. -- **Receivers must eventually declare the var.** Until the worker imports the module that constructs the var, the wire-shipped value is held on a stub pinned to the receiver `wool.Context`; a later `wool.ContextVar(...)` declaration promotes the stub and the propagated value applies transparently. 
If the worker never declares the var, the stub is collected with its receiver `wool.Context` and the value is dropped. -- **Tokens are scoped to their originating `wool.Context`.** A `Token` minted inside a task cannot be reset from outside that `wool.Context` — including after crossing an `asyncio.create_task` fork boundary, since child tasks receive fresh `wool.Context` ids. Reset the token in the same logical chain that produced it, or use `var.set(...)` to install a new value without relying on the token. +- **Only explicitly set values propagate.** A variable that has never been `set()` (only has a class-level default) is not included in the context — the worker falls through to its own default. +- **Receivers must eventually declare the var.** Until the worker imports the module that constructs the var, the wire-shipped value is held on a stub kept alive by the decoded context; a later `wool.ContextVar(...)` declaration promotes the stub and the propagated value applies transparently. If the worker never declares the var, the stub is collected with that context and the value is dropped. +- **Tokens are scoped to their originating chain.** A `Token` minted inside a task cannot be reset from a different chain — including after crossing an `asyncio.create_task` fork boundary, since child tasks receive a fresh chain id. Reset the token in the same logical chain that produced it, or use `var.set(...)` to install a new value without relying on the token. - **Wire keys are tied to the top-level package name.** Renaming the top-level package (e.g., `myapp` → `myapp_v2`) changes every var's wire key, so a rolling deploy that has callers and workers on different top-level names will silently drop propagated values on the mismatched side. Keep the top-level package name stable across rolling deploys, or bridge the transition with explicit `namespace=` overrides. 
Moving a module deeper within the same top-level package is safe — the key is the package root, not the full module path. ## Worker pools @@ -416,13 +426,13 @@ A dispatch crosses two processes and several stages on each side; failures can o | ----- | ----------------------------------- | ------------------------------- | -------------------- | | Caller-side request encoding | n/a (no transport involved yet) | Original `Exception` (unwrapped) | None — no worker contacted | | gRPC handshake | `TransientRpcError` (transient codes) or `RpcError` (non-transient, incl. `FAILED_PRECONDITION` for version mismatch) | n/a | Skip on transient; evict on non-transient | -| Worker-side request decoding | `Rejected.original` re-raised on the caller (typed) | Strict-mode `wool.ContextDecodeWarning` re-raised on the caller | None — typed re-raise | -| Routine execution | n/a | Original routine exception (type and traceback preserved); ancillary warnings on `__notes__` | None | -| Worker-side response encoding | Routine exception with `__notes__` (strict-mode context encode) | n/a | None | +| Worker-side request decoding | `Rejected.original` re-raised on the caller (typed) | Strict-mode `wool.ChainSerializationError` re-raised on the caller | None — typed re-raise | +| Routine execution | n/a | Original routine exception (type and traceback preserved); ancillary `wool.ChainSerializationError` chained on `__cause__` | None | +| Worker-side response encoding | Routine exception chained from strict-mode `wool.ChainSerializationError` via `__cause__` | n/a | None | | Caller-side response decoding | `UnexpectedResponse` (malformed payload, missing class, version skew) | n/a | None — caller-fault, worker is healthy | | Post-execution teardown | Swallowed (the wire is already closed) | n/a | n/a | -In every case the caller sees a single exception — wool does not wrap routine failures in `ExceptionGroup` or any wool-specific wrapper class.
Ancillary signals (context decode failures, snapshot encode failures coincident with a routine exception) attach to the primary exception via PEP 678 `__notes__` and a `__wool_context_warnings__` attribute, so existing `except RoutineError:` clauses keep matching. +In every case the caller sees a single exception — wool does not wrap routine failures in `ExceptionGroup` or any wool-specific wrapper class. Ancillary signals ride on the primary exception through standard exception chaining — caller-side context decode failures on `__context__`, and worker-side context encode failures coincident with a routine exception on `__cause__` (via `raise primary from encode_err`) — so existing `except RoutineError:` clauses keep matching and the ancillary error remains visible in the traceback. ```python try: @@ -446,33 +456,33 @@ Version compatibility is checked by `VersionInterceptor` **before** the dispatch ### Worker-side request decoding -`DispatchSession.__aenter__` is the worker's parse phase: it reads the first request frame, decodes the caller's `wool.Context` snapshot and rebuilds the `wool.Task` (both via `cloudpickle`), and validates that the callable is an async function or async generator. +`DispatchSession.__aenter__` is the worker's parse phase: it reads the first request frame, decodes the caller's context and rebuilds the `wool.Task` (both via `cloudpickle`), and validates that the callable is an async function or async generator. Failures here wrap in `Rejected` and surface via a `Nack` frame whose `exception` payload carries the original failure (cloudpickle-dumped). The caller deserializes and re-raises, so the user observes the **actual failure class**, not an opaque RPC error: - Malformed task id, cloudpickle errors on the routine callable, ImportError on a missing module, non-async callable → original `Exception` re-raised on the caller.
-- Strict-mode promoted `wool.ContextDecodeWarning` (operator set `warnings.filterwarnings("error", category=wool.ContextDecodeWarning)` in the worker subprocess) → the warning ships through the same Nack-with-exception path and re-raises on the caller as `wool.ContextDecodeWarning`. The default lenient mode emits the warning and runs the routine against a fresh empty context (see Context propagation > Decode failure semantics). +- Strict-mode promoted `wool.SerializationWarning` (operator set `warnings.filterwarnings("error", category=wool.SerializationWarning)` in the worker subprocess) → the promoted warnings aggregate into a `wool.ChainSerializationError` that ships through the same Nack-with-exception path and re-raises on the caller as `wool.ChainSerializationError`. The default lenient mode emits the warning, drops each unreadable entry, and runs the routine under whatever partial context decoded (see Context propagation > Decode failure semantics). Parse-phase rejections reflect a user-code or version-skew issue, not a worker-health issue. The load balancer does not evict the worker. ### Routine execution -The worker drives one `_step` per request frame inside `routine_scope` (the worker-loop task that owns the routine's `wool.Context` guard). Three terminal shapes are possible: +The worker drives one `_drive_step` per request frame inside `routine_scope` (the worker-loop task that runs the routine under the work context). Three terminal shapes are possible: - **Clean completion** — the routine returns (coroutine) or raises `StopAsyncIteration` (async generator), and the response stream ends. -- **Routine exception** — the worker serializes the original exception with `cloudpickle` (using [tbpickle](https://github.com/wool-labs/tbpickle) to make stack frames picklable) and ships it on a terminal `Response.exception` frame. The caller deserializes and re-raises, **preserving the original type and traceback**.
The user's `except RoutineError:` clause matches as written; the load balancer takes no action. +- **Routine exception** — the worker serializes the original exception with `cloudpickle` (using [tblib](https://github.com/python-tblib/tblib) to make stack frames picklable) and ships it on a terminal `Response.exception` frame. The caller deserializes and re-raises, **preserving the original type and traceback**. The user's `except RoutineError:` clause matches as written; the load balancer takes no action. - **Operator pre-emption** — graceful shutdown cancels in-flight dispatches; the underlying `CancelledError` ships through the routine-exception channel. See _Cancellation_ below. If the routine raises an exception that drags an unpicklable object into its graph (e.g., a C-level frame reachable via `__traceback__`/`__cause__`), the worker's `_safely_serialize_exception` falls back to reconstructing the exception cleanly via `cls(*exc.args)` — **preserving the exception class** so the user's `except RoutineError:` clause still matches. Only if even the clean reconstruction fails to pickle does the fallback demote to a stdlib `RuntimeError`. ### Worker-side response encoding -After each successful step, the dispatch handler builds a `protocol.Response`: it dumps the result via `cloudpickle` and attaches the post-step `wool.Context` snapshot. +After each successful step, the dispatch handler builds a `protocol.Response`: it dumps the result via `cloudpickle` and attaches the post-step context. -- **Result dump fails** (un-picklable yielded value) — the failure surfaces as a handler-side exception during response encoding. The dispatch handler drains the worker, snapshots `session.context`, and ships a terminal `Response.exception` carrying the encode failure. Caller observes the dump exception; no worker eviction. 
-- **Strict-mode context encode failure during a routine exception** — `session.context.to_protobuf` raised a `BaseExceptionGroup` of `wool.ContextDecodeWarning` peers during the terminal-exception path. The handler attaches the warnings to the routine's exception via PEP 678 `__notes__` and a `__wool_context_warnings__` attribute, so the **routine exception's type is preserved**. The terminal response drops the `context` field. The caller's `except RoutineError:` clause still matches; the warnings remain visible in the traceback and accessible programmatically. +- **Result dump fails** (un-picklable yielded value) — the failure surfaces as a handler-side exception during response encoding. The dispatch handler drains the worker, reads the final wire context published by the worker task via `session._final_wire_context`, and ships a terminal `Response.exception` carrying the encode failure. Caller observes the dump exception; no worker eviction. +- **Strict-mode context encode failure during a routine exception** — the worker task's final-encode step raised a `wool.ChainSerializationError` aggregating per-var warnings during the terminal-exception path. The handler reads the encode failure from the worker (alongside `session._final_wire_context`) and chains it onto the routine's exception as `__cause__` via `raise routine_exc from encode_err`, so the **routine exception's type is preserved**. The terminal response drops the `context` field. The caller's `except RoutineError:` clause still matches; the encode error remains visible in the traceback through cause chaining. -`DispatchSession.__aexit__` registers `drain` on its exit stack precisely because of this path: a result-dump failure mid-stream leaves the worker still running, and drain must complete before `session.context.to_protobuf` snapshots state for the terminal frame — otherwise the snapshot races the worker's `_step` writing the same context.
+`DispatchSession.__aexit__` registers `drain` on its exit stack precisely because of this path: a result-dump failure mid-stream leaves the worker still running, and drain must complete before the handler reads `session._final_wire_context` for the terminal frame — otherwise the read races the worker task still publishing the final wire context. ### Caller-side response decoding @@ -499,7 +509,7 @@ Cancellation reaches the dispatch via three routes; all three resolve to the sam - **Caller cancels its `await routine()`** — the caller's `WorkerConnection` cancels the gRPC call on the way out; the worker observes the stream tear-down, the dispatch handler invokes `DispatchSession.cancel`, the worker task is cancelled on its own loop, and any routine suspended inside an `await` receives `CancelledError`. - **Routine self-raises `CancelledError`** — the cancellation ships through the routine-exception channel unchanged. -- **Operator preempts** (worker graceful shutdown) — `WorkerService._cancel` calls `DispatchSession.cancel` on every in-flight dispatch; same effect as caller-initiated. +- **Operator preempts** (worker graceful shutdown) — `WorkerService._preempt` calls `DispatchSession.cancel` on every in-flight dispatch; same effect as caller-initiated. In all three cases the caller's `await routine()` raises `asyncio.CancelledError`, matching stdlib's `await task` semantics where `task.cancel()` from any source produces the same observable. 
@@ -590,10 +600,10 @@ sequenceDiagram end Worker ->> Worker: DispatchSession.__aiter__ schedules worker driver lazily - Worker ->> Worker: routine_scope enters; routine runs under wool.Context + Worker ->> Worker: routine_scope enters; routine runs under the work context alt Coroutine (single synthesized "next" request) - Worker ->> Worker: _step advances coroutine, serializes result + context + Worker ->> Worker: _drive_step advances coroutine, serializes result + context Worker -->> Routine: Response.result + context Routine ->> Routine: deserialize, apply back-propagated context Routine -->> Client: return result @@ -601,14 +611,14 @@ sequenceDiagram loop Each iteration Client ->> Routine: next / send / throw Routine ->> Worker: iteration request frame - Worker ->> Worker: _step advances generator + Worker ->> Worker: _drive_step advances generator Worker -->> Routine: Response.result + context Routine ->> Routine: deserialize, apply back-propagated context Routine -->> Client: yield result end else Routine or encoding exception - Worker ->> Worker: drain worker, snapshot session.context - Worker -->> Routine: terminal Response.exception (cloudpickled, may include __notes__) + Worker ->> Worker: drain worker, read final wire context published by worker task + Worker -->> Routine: terminal Response.exception (cloudpickled; decode/encode failures ride on __context__/__cause__) Routine ->> Routine: deserialize exception (preserves type and traceback) Routine -->> Client: re-raise end diff --git a/wool/src/wool/runtime/context/README.md b/wool/src/wool/runtime/context/README.md index 6ab4e778..2c8fe78a 100644 --- a/wool/src/wool/runtime/context/README.md +++ b/wool/src/wool/runtime/context/README.md @@ -1,89 +1,78 @@ # Context -Wool maintains its own context system parallel to Python's `contextvars`. A `wool.Context` is a self-contained snapshot of `wool.ContextVar` bindings plus a UUID identifying the logical execution chain. 
The two systems do not share state: `contextvars.Context.run()` does not fork or clear a `wool.Context`, and `wool.Context.run()` does not touch the surrounding `contextvars.Context`. Wool's task factory is the boundary at which Wool's fork-on-task semantics engage; `contextvars` continues to work exactly as it would normally. +Wool's context model is founded directly on Python's stdlib `contextvars`. There is one context system, not two. `wool.ContextVar` is a `contextvars`-backed variable that additionally (a) propagates across worker boundaries and (b) opts its execution chain into Wool's linearity rules. Wool owns a single stdlib `contextvars.ContextVar` (`wool.__chain__`) that carries all `wool.ContextVar` state as an immutable `Chain`; everything else in this subsystem reads, rebuilds, or transports that chain. -## Per-scope binding +The goal is stdlib parity with context continuity across workers: a routine awaited on a worker behaves exactly as a process-local await would. Stdlib asyncio has two context behaviors — a plain `await` *shares* the caller's context (the callee's mutations are visible after the await), and `asyncio.create_task` *copies* it (the child's mutations are isolated). Wool extends both across the wire: a dispatch ships the caller's chain state out and merges the worker's mutations back, so a remote await is indistinguishable from a local one, while `create_task` forks — exactly as it would locally. -Each `wool.Context` is keyed in a process-wide registry by either the current `asyncio.Task` (for async code) or the current `threading.Thread` (for sync code). `wool.current_context()` returns the live binding for the current scope, creating one lazily if no `wool.Context` is bound yet. `wool.copy_context()` produces a shallow copy of the current scope's `wool.Context` with a fresh logical-chain UUID. 
+Because Wool state rides in stdlib `contextvars`, it propagates with stdlib visibility across every conformant event loop — uvloop included — and across every *cooperative* asyncio scheduling edge: `call_soon`/`call_later`/`call_at`, `add_reader`/`add_writer`/`add_signal_handler`, and `Future.add_done_callback`, each of which inherits the scheduling scope's chain unchanged. Child task creation propagates too, but as a *fork* — the child inherits the parent's variable values on a freshly minted chain id, not the parent's chain itself (see [Fork-on-task](#fork-on-task)). No event-loop interception is involved. Non-owner entry is the one deliberate exception — a chain entered by an OS thread or task that does not own it (a plain `asyncio.to_thread` offload, an armed context re-run on another thread via `Context.run`, an armed context handed to a directly instantiated `asyncio.Task`) fails loud rather than propagating silently; see [the chain-contention guard](#the-chain-contention-guard) below. -The first time any Wool API is touched on a running loop, the loop self-installs Wool's task factory. This is idempotent — repeated calls on the same loop are no-ops — and composes with any user-installed factory. Calling `wool.install_task_factory(loop)` explicitly is supported but rarely necessary; the only ordering hazard is a third-party factory installed *after* Wool's, which silently breaks copy-on-fork for subsequently-created tasks. +## The chain -## Fork-on-task +A chain is one serial branch of the program's async call tree — the logical call stack descending from the most recent `asyncio.create_task` fork, on which every frame executes strictly in sequence. Plain awaits, generator yields, callbacks, and worker dispatches all extend the branch; `create_task` always starts a new one. The chain is what survives the process hop: a worker arms onto the caller's chain id because it genuinely is the same branch, executing elsewhere for a while. 
-When a new `asyncio.Task` is created on a loop with Wool's task factory installed, the factory examines the `context=` kwarg: +Chain state is held as a single immutable `Chain` in the Wool-owned `wool.__chain__` variable: the chain id, the owner stamps the contention guard arbitrates, the index of bound variables (`vars`), pending reset signals (`resets`), and stub pins (`stubs`) — see the `Chain` class docstring for the field-by-field contract. The chain is an *index*, not a value store: each `wool.ContextVar`'s value lives in its own backing `contextvars.ContextVar`. -- `context=wool_ctx` — the new task is bound to *that* `wool.Context` for its lifetime; concurrent attempts to bind a second task to the same `wool.Context` raise `RuntimeError` as soon as the second task starts running. -- `context=stdlib_ctx` — the stdlib `contextvars.Context` is forwarded verbatim, and Wool forks the parent task's `wool.Context` (or constructs a fresh one if no parent is bound) to bind to the child. -- `context=None` — same fork behaviour as the stdlib path; the child receives a copy of the parent's `wool.Context` under a fresh chain UUID. +`wool.ContextVar.get` reads its key via the active chain; `set` and `reset` rebuild the chain and reinstall it. The chain is immutable, so it can be shared by reference across callbacks, timers, and child tasks without intercepting the loop — a callback that sets a variable produces a *new* chain scoped to its own `contextvars.Context` copy, leaving the scheduling scope's chain untouched. -The fresh chain UUID matters because Tokens are scoped by `wool.Context` UUID. A child task cannot reset a Token minted in its parent's chain — the same isolation that makes concurrent tasks on a worker safe. +## Armed-gating -To get fresh stdlib *and* fresh Wool state in a synchronous block, compose the two: +Enforcement is **armed-gated**. 
The Wool-owned context variable defaults to unset; while it is unset the surrounding context is *unarmed* and behaves as a plain `contextvars.Context` — no chain UUID and no guard. A context is armed by either of two events: the first `wool.ContextVar.set()` on it, or a merge of incoming wire state — a worker receiving a dispatch frame, or a caller merging a routine's response back. A process that neither sets a `wool.ContextVar` nor dispatches a routine that propagates chain state pays nothing at all: the task factory is never installed, so task creation is untouched. -```python -import contextvars -import wool - -wool.Context().run(contextvars.Context().run, fn, *args, **kwargs) -``` +Arming mints a fresh chain id, installs the first `Chain`, and self-installs the task factory on the running loop. Self-install is a no-op when there is no running loop — an armed context still works, but child tasks created on a later-started loop will not fork onto fresh chains until the factory is installed (call `wool.install_task_factory(loop)` explicitly in that case). Once installed, the factory is loop-wide: every task created on that loop thereafter — Wool-related or not — carries a small constant cost, its context registered for the chain-contention guard. Only a process that never arms any context pays nothing. ## `wool.ContextVar` -`ContextVar` mirrors `contextvars.ContextVar` at the surface — `get()`, `set()`, `reset()` — but its identity model is structural, not object-based. Every `ContextVar` has a `(namespace, name)` key. Two `ContextVar(name="foo", namespace="bar")` calls in the same process resolve to the *same* singleton via the process-wide var registry; collisions on the same key with mismatched declarations raise `ContextVarCollision`. The namespace is inferred from the top-level package of the calling frame when not passed explicitly, which keeps wire keys stable when modules are refactored deeper within the same package. 
+`ContextVar` mirrors `contextvars.ContextVar` at the surface — `get()`, `set()`, `reset()` — with one addition: an optional `namespace`. Its identity model is structural, not object-based. Every `ContextVar` has a `(namespace, name)` key; two `ContextVar(name="foo", namespace="bar")` calls in the same process resolve to the *same* singleton via the process-wide var registry, and collisions on the same key with mismatched declarations raise `ContextVarCollision`. The namespace is inferred from the top-level package of the calling frame when not passed explicitly, which keeps wire keys stable when modules are refactored deeper within the same package. The `(namespace, name)` identity is what survives pickle round-trips. A `wool.ContextVar` arriving on the wire from a worker reduces by `(namespace, name)` and re-resolves to the same instance on the receiver. Pickling is gated through Wool's pickler; vanilla `pickle.dumps` is rejected via `__reduce_ex__` to keep the registry-bound semantics from leaking into ad-hoc serialization. -Values are stored in the current `wool.Context` — one per `asyncio.Task`, one per thread for sync code. `get()` returns the value bound in the active `wool.Context`, falling back to the supplied default, then the constructor default, then raising `LookupError`. `set()` writes into the active `wool.Context` and returns a `Token`. +`get()` returns the value bound in the active context, falling back to the supplied default, then the constructor default, then raising `LookupError`. `set()` rebuilds the context with the new binding and returns a `Token`. ## `wool.Token` -Tokens carry a UUID and the `(namespace, name)` of the var they reset. Live tokens are deduplicated process-wide, so pickle round-trips within a process preserve identity — a token round-tripped through `cloudpickle` resolves back to the same instance. 
Across the wire, the consumed-state bit rides on the wire frame so a worker's `reset()` correctly observes consumption that originated upstream, and vice versa. +`wool.Token` *is* `contextvars.Token` — re-exported, not wrapped. `set()` returns the native token, `reset(token)` spends it natively, and `Token.MISSING` is stdlib's sentinel for no-prior-binding. Tokens are process-local handles and are not serializable: they never ride the wire, and cross-process token transport is deferred (see [#231](https://github.com/wool-labs/wool/issues/231)). -Tokens are scoped to the `wool.Context` UUID in which `set()` ran. `ContextVar.reset(token)` raises `ValueError` if the active `wool.Context` UUID differs from the one the token was minted under. This holds across pickle round-trips: a token round-tripped through the wire still carries its originating chain id, and reset only succeeds when the receiver is operating in the same logical chain. `Token.MISSING` is the singleton sentinel returned for `old_value` when the var had no prior binding. +`ContextVar.reset(token)` defers to native validation — stdlib raises `ValueError` for a token spent twice, minted by another variable, or minted in a different `contextvars.Context` — and that last rejection is what keeps concurrent tasks isolated, since a forked task runs in a context copy where the parent's tokens are foreign. Wool layers only its chain bookkeeping on top: a reset that rewinds to no prior value records the pending reset signal so the drop propagates over the wire (see [Wire propagation](#wire-propagation)). 
-## Wire propagation +## Fork-on-task -Each gRPC dispatch frame carries a `Context` protobuf message containing: +On every loop where Wool gets armed, it installs a task factory (composed with any user-installed factory; Wool's factory must be installed *last* — a third-party factory installed after it silently drops copy-on-fork for subsequently-created tasks, and the next `wool.ContextVar` access raises `wool.TaskFactoryDisplaced` once Wool detects the displacement). Every task the factory creates — armed or unarmed — carries one Wool `add_done_callback` (visible as `cb=[...]` in `repr()`) and has its `contextvars.Context` registered for the chain-contention guard. -- The `wool.Context` UUID (preserved across the call so dispatches under the same chain resolve to the same logical id). -- A list of `ContextVar` entries, each with `(namespace, name)`, an optional serialized value, and consumed-token UUIDs that have reset the var in the sender's chain. +When a new `asyncio.Task` is created in an *armed* context, the factory additionally wraps the child coroutine in a `_forked_scope` coroutine so the child runs under a *forked* context: it inherits the parent's variable bindings but receives a freshly minted chain id. The wrapper has visible consequences for an armed task: `Task.get_coro()` and `repr()`'s coroutine field reflect `_forked_scope` rather than the user coroutine, tracebacks gain one wrapper frame, and the wrapper coroutine surfaces in `TaskGroup`/`gather` exception output. A task created in an *unarmed* context is not wrapped — its coroutine, `get_coro()`, and auto-generated name match a plain `asyncio.Task` created on the same loop (the `Task-N` counter is per loop implementation) — but it still carries the Wool done-callback and guard registration noted above, so it is not byte-for-byte identical to a task created without Wool's factory. 
-`Context.to_protobuf()` snapshots the sender's bound vars and consumed-token store; only vars that are bound or have consumed tokens emit entries (default-only values are absent). `Context.from_protobuf()` merges each entry into the receiver's `wool.Context`: the var resolves through the var registry (or registers a stub if the receiver hasn't yet imported the declaring module), the value is deserialized into the binding, and consumed-token UUIDs either flip the corresponding live tokens to consumed or are stashed for later reset propagation if no live token has been seen yet. +The fresh chain UUID is correct fork semantics: it isolates concurrent child tasks from each other and from their parent, and it is why a child task cannot `reset` a `Token` minted in its parent's chain. -On the worker side, a fresh `wool.Context` is bound to the worker-loop task, the wire vars are merged in via `Context.update(...)`, and the routine runs inside that `wool.Context`. On completion, the resulting `wool.Context` is snapshotted via `to_protobuf()` and returned in the response so the caller sees mutations made on the worker. Async generators bidirectionally exchange context at every yield/next boundary using the same machinery. +## The chain-contention guard -## Single-task-per-Context +Wool enforces strictly serial execution within a chain: at most one OS thread *and* one `asyncio.Task` may run code under a given chain at a time. The guard has two dimensions. The thread dimension compares the accessing OS thread against the chain's owning thread; the task dimension compares the running `asyncio.Task` against the chain's owning task — both stamped at mount (the `Chain.thread` and `Chain.task` fields). Both are armed-gated — they engage only once a chain exists. -A `wool.Context` enforces "at most one task running inside it at a time." Two layers cooperate: +Cooperatively-scheduled work on the loop's thread never trips the guard. 
Event-loop callbacks and timers scheduled on the loop thread inherit the scheduling scope's context and run on the loop's own thread with no running task, so they share the chain but can never run concurrently with its owner. The exception is `loop.call_soon_threadsafe` called from another OS thread: it captures *that* thread's context, so a callback scheduled from a thread that armed its own chain runs the chain off its owner thread and raises `wool.ChainContention` — by design, since the chain genuinely spans two threads. Child tasks never trip the guard for the opposite reason: the task factory forks every child onto a fresh chain with its own owner stamps, so a child never enters the parent's chain at all. -- A re-entry check on `wool.Context` itself: any attempt to enter a `wool.Context` that is already running user code raises `RuntimeError`. `Context.run()` and `attached(ctx)` both gate user code through this check, so the discipline holds regardless of which entry point installed the `wool.Context`. -- A one-shot binding contract on the task factory: a task may be bound to a `wool.Context` exactly once at creation. A duplicate binding raises `ContextAlreadyBound` rather than silently stomping the prior chain UUID. +Genuine OS-thread parallelism does trip the thread dimension. Plain `asyncio.to_thread` from an armed context copies the surrounding `contextvars` context — chain UUID and all — into an executor thread, placing a second runner on one chain in parallel; the first `wool.ContextVar` access from that thread raises `wool.ChainContention`. The thread dimension keys on the chain's owning thread — captured when the chain was armed — not on detected concurrency: any `wool.ContextVar` access from a different OS thread raises, even a strictly sequential, never-concurrent resume of an armed context on a second thread by non-asyncio code. 
That is a deliberate divergence from plain `contextvars`, which permits sequential cross-thread reuse; `wool.to_thread` is the supported cross-thread path. -When a task is created with `context=wool_ctx`, the binding is pinned to that task for its lifetime. A second task targeting the same `wool.Context` while the first is still mid-flight raises immediately as it starts running — even when the two attempts span asyncio loop ticks. +Two `asyncio.Task` objects sharing one chain trip the task dimension. The task factory raises `wool.ChainContention` up front when an *armed* `contextvars.Context` already driving a live task is handed to a second `create_task` — that creation-time rejection is itself armed-gated, so sharing an *unarmed* context across tasks is permitted, exactly as stdlib asyncio permits it. If an unarmed shared context is armed *later*, the task dimension catches the second task the moment it touches a `wool.ContextVar` on a chain another live task already owns. The same access-time catch covers direct `asyncio.Task(...)` instantiation, which bypasses the loop's task factory entirely — no fork occurs, so the new task shares its context's chain and trips the guard on its first `wool.ContextVar` access. Each task should still run in its own context: omit `context=` (asyncio copies it per task by default) or pass a fresh `contextvars.copy_context()` to each. -## Binding APIs +`wool.to_thread(func, *args, **kwargs)` is the supported alternative. It mirrors `asyncio.to_thread` but forks Wool chain state: the worker thread runs under a freshly minted, **detached** chain (a copy of the caller's bindings under a new chain id owned by the worker thread, with no merge-back). Unlike a `@wool.routine` dispatch — which back-propagates the worker's `wool.ContextVar` mutations to the caller — a `wool.to_thread` offload is write-isolated: mutations the offloaded function makes are not visible to the caller, matching `asyncio.to_thread`'s copy-in semantics. 
Use it whenever Wool-aware work must run in another OS thread. -Two public entry points install a `wool.Context` for a scope: +## Wire propagation -- `attached(ctx)` — synchronous context manager. Installs `ctx` on the current scope; concurrent entry from another task raises `RuntimeError` immediately. Restores the previous binding on exit. -- `wool.create_task(coro, context=ctx)` — pre-binds `ctx` to a freshly-spawned `asyncio.Task` for the coroutine's lifetime. Equivalent at runtime to `asyncio.create_task(coro, context=ctx)` when Wool's task factory is installed; the helper exists purely as a typing shim because stdlib's `context=` is typed for `contextvars.Context` and `wool.Context` cannot subclass it. +Each gRPC dispatch frame carries a `ChainManifest` protobuf message containing the chain's id and a list of `ContextVar` entries — each with `(namespace, name)` and an optional serialized value. An absent value signals a reset-to-no-prior-value on the sender's chain. (Cross-process token transport is deferred — see [#231](https://github.com/wool-labs/wool/issues/231).) -Both routes enforce the single-task-per-Context rule. `Context.run(fn, *args, **kwargs)` is the sync-equivalent of `attached(self)`. +- `Chain.to_protobuf()` encodes the sender's chain: bound variables emit valued entries read live from their backings, pending resets emit value-less entries, and emission is sorted so identical chain state encodes byte-identically. Only variables that are bound or reset-pending emit entries (default-only values are absent). +- `ChainManifest.from_protobuf()` decodes a wire message into an inert manifest: each entry resolves through the var registry (or registers a stub if the receiver hasn't yet imported the declaring module) and values are deserialized onto the manifest — decoded but applied to no `contextvars.Context`, so a backpressure hook can inspect caller-shipped values without installing them. 
+- `ChainManifest.mount()` is the single install pipeline: it drains the manifest's values into backing variables, applies reset signals, and installs the resulting `Chain`. A live receiver merges — the manifest wins the keys it touches and the receiver keeps its chain id; an unarmed receiver, or the worker's initial dispatch frame, installs the manifest wholesale under the sender's chain id. -## Stub promotion +On the worker side, the caller's chain manifest is decoded, installed under the worker-loop task (re-stamped so the worker thread owns the chain), and the routine runs inside it. On completion the resulting chain manifest is encoded onto the response so the caller sees mutations made on the worker. Async generators bidirectionally exchange chain manifests at every yield/next boundary using the same machinery. -A `wool.ContextVar` may arrive on the wire — embedded in a pickled object graph or as an entry in a wire `Context` snapshot — before the receiver's process has imported the module that declares it. To preserve the propagated value across that gap, the receiver registers a placeholder `wool.ContextVar` under the wire key and pins it to the active `wool.Context` so it stays alive. - -Two ingress paths share the same placeholder slot: +## Stub promotion -- A pickled `wool.ContextVar` reconstituting on the receiver carries the original constructor default; the placeholder adopts it. -- A wire `Context` snapshot's var entry carries no default; the placeholder is created default-less and adopts a default later if the pickle path encounters the same key first. +A `wool.ContextVar` may arrive on the wire — embedded in a pickled object graph or as an entry in a wire `ChainManifest` message — before the receiver's process has imported the module that declares it. To preserve the propagated value across that gap, the receiver registers a placeholder `wool.ContextVar` (a *stub*) under the wire key. 
-When the receiver's user code eventually constructs `wool.ContextVar(...)` under the placeholder's key, the placeholder is promoted in place — the existing instance is returned to the caller, preserving any wire state and reference identity already accumulated. If the receiver never declares the var, the placeholder is collected with its pinning `wool.Context` and the propagated value is dropped. +Two ingress paths share the same placeholder slot via `resolve_stub`: a pickled `wool.ContextVar` reconstituting on the receiver (which carries the original constructor default), and a wire `ChainManifest` entry (which carries no default and is created default-less). A stub stays alive — kept by the embedding object graph for a pickled instance, or by the decoded manifest's `stubs` for a wire entry — until the receiver's user code constructs `wool.ContextVar(...)` under its key, at which point the placeholder is promoted in place. If the receiver never declares the variable, the stub is collected with whatever held it and the propagated value is dropped. ## `RuntimeContext` and `dispatch_timeout` -`RuntimeContext` carries block-scoped runtime option overrides — currently just `dispatch_timeout` — separate from the `ContextVar` snapshot in `wool.Context`. It auto-captures on every `Task` at construction time and rides the wire on every dispatch, encoded by `RuntimeContext.to_protobuf()`, so the worker restores the caller's effective `dispatch_timeout` before running the routine. +`RuntimeContext` carries block-scoped runtime option overrides — currently just `dispatch_timeout` — separate from the `wool.ContextVar` context. It auto-captures on every `Task` at construction time and rides the wire on every dispatch, encoded by `RuntimeContext.to_protobuf()`, so the worker restores the caller's effective `dispatch_timeout` before running the routine. 
```python import wool @@ -92,13 +81,11 @@ with wool.RuntimeContext(dispatch_timeout=30): result = await my_routine() ``` -`dispatch_timeout` itself is a stdlib `contextvars.ContextVar[float | None]`, exposed at module scope so `RuntimeContext.__enter__` / `__exit__` can `set` and `reset` it through the standard `contextvars` mechanism. This is the one place inside the context subpackage where stdlib `contextvars` is used directly: `RuntimeContext` is a stdlib-flavoured override surface, distinct from `wool.Context` (which carries `wool.ContextVar` snapshots). - -A bare `RuntimeContext()` with no explicit `dispatch_timeout` substitutes the live `dispatch_timeout` value at encode time, so a `Task` constructed for wire transport propagates the encoder's effective timeout without the caller having to materialize the override explicitly. +`dispatch_timeout` itself is a plain stdlib `contextvars.ContextVar[float | None]`, distinct from the Wool-owned context variable. `RuntimeContext.__enter__` / `__exit__` `set` and `reset` it through the standard `contextvars` mechanism. ## Decode failures -Wire-context decode is **ancillary state** in Wool's protocol contract: a failure to decode an incoming context never preempts the routine's primary signal (its return value or raised exception). Both `Context.to_protobuf` and `Context.from_protobuf` emit `wool.ContextDecodeWarning` for per-var encode/decode failures and skip the offending entry; surviving vars decode normally. A malformed wire context UUID falls back to a fresh UUID with the failure recorded as the same warning class. +Chain-manifest decode is **ancillary state** in Wool's protocol contract: a failure to decode an incoming chain manifest never preempts the routine's primary signal (its return value or raised exception). 
Both `Chain.to_protobuf` and `ChainManifest.from_protobuf` emit `wool.SerializationWarning` for per-variable encode/decode failures and skip the offending entry (a failed encode suppresses the variable's reset signal too, so the receiver cannot read a phantom reset); surviving variables decode normally. A malformed wire chain id is the exception: without it the receiver cannot correlate frames to a chain, so it raises `wool.ChainSerializationError` unconditionally — a structural protocol error, not ancillary per-variable state. Promote the warning to an error to opt into strict mode: @@ -106,7 +93,22 @@ Promote the warning to an error to opt into strict mode: import warnings import wool -warnings.filterwarnings("error", category=wool.ContextDecodeWarning) +warnings.filterwarnings("error", category=wool.SerializationWarning) ``` -Under strict mode, per-var failures aggregate into a single `BaseExceptionGroup` raised after the encode/decode loop completes — every bad var surfaces, not just the first. Worker-side strict mode is enabled via `PYTHONWARNINGS="error::wool.ContextDecodeWarning"`, which `multiprocessing` propagates to spawned subprocesses by default. The worker ships the promoted warning back through the routine-exception channel, so the caller catches the same `wool.ContextDecodeWarning` class symmetrically. See the top-level [`wool/README.md`](../../../../README.md#decode-failure-semantics) for the full lenient/inspect/strict modes. +Under strict mode, per-entry failures aggregate into a single `wool.ChainSerializationError` (a `wool.WoolError` subclass with the warnings on `.warnings`) raised after the encode/decode loop completes — every bad entry surfaces, not just the first. Strict mode is **fatal to the whole frame**: the partial chain manifest that `from_protobuf` accumulated (the surviving entries that decoded successfully) is *not* surfaced — every entry on the wire frame is discarded along with the failing one. 
Callers that need partial application semantics under strict mode must run a non-strict decode first. + +Worker-side strict mode is enabled via `PYTHONWARNINGS="error::wool.SerializationWarning"`, which `multiprocessing` propagates to spawned subprocesses by default. The worker ships the aggregated `wool.ChainSerializationError` back through the routine-exception channel, so the caller catches the same error class symmetrically. When a routine *also* raises a primary exception, the `wool.ChainSerializationError` rides on it as `__cause__` via `raise primary from decode_err`, preserving the primary's class so `except RoutineError:` keeps matching. See the top-level [`wool/README.md`](../../../../README.md#decode-failure-semantics) for the full lenient/inspect/strict modes. + +`wool.TaskFactoryDisplaced` follows different semantics: it is **not** ancillary state and **not** a tunable warning. Factory displacement is structurally fatal to chain propagation across every subsequent task on the loop, so `TaskFactoryDisplaced` is raised unconditionally — there is no strict-mode opt-in because there is no graceful-mode default. The exception escapes from the next user-invoked Wool operation that triggers detection (including the response-merge install path) and preempts the routine's primary signal. The primary signal remains visible via `__context__` chaining. This asymmetry is deliberate: factory displacement fails fast and loudly because it is infrastructure breakage, distinct from per-variable wire corruption (`wool.SerializationWarning`) where graceful degradation is the right default. + +## Exceptions and warnings + +The context subsystem publishes six typed signals on the `wool` barrel. They appear inline above; this list collects them so callers know what to catch and what to promote: + +- `wool.ChainContention` (exception) — raised when a chain is entered by a thread or asyncio task other than the one that owns it (see [the chain-contention guard](#the-chain-contention-guard) above). 
It is a `wool.WoolError` subclass carrying structured diagnostic fields (`chain_id`, `kind`, `owning_thread`/`current_thread`, `owning_task`/`current_task`) in addition to the human-readable message. +- `wool.ContextVarCollision` (exception) — raised on construction of a second `wool.ContextVar` under an existing `(namespace, name)` key. Library authors should pass `namespace=` explicitly when constructing variables from shared factory code; application code can rely on the implicit package-name inference. +- `wool.SerializationWarning` (warning) — emitted when a wire `ChainManifest` fails to encode or decode (see [Decode failures](#decode-failures) above). A `wool.WoolWarning` subclass; promote to an exception via `warnings.filterwarnings("error", ...)` (or `PYTHONWARNINGS="error::wool.SerializationWarning"` on workers) to opt into strict mode. +- `wool.SerializationError` (exception, `wool.WoolError` subclass) — the base raised when a single value fails to encode across the wire. Its strict-mode subclass `wool.ChainSerializationError` aggregates per-variable chain-manifest failures, so catching `wool.SerializationError` matches every wire serialization failure, atomic or aggregated. +- `wool.ChainSerializationError` (exception, `wool.WoolError` subclass) — raised when wire encode or decode fails under strict mode. Aggregates the promoted `wool.SerializationWarning` instances on `.warnings` (a tuple). Routine code typically does not see this directly: a result-frame decode failure raises it as the routine's primary; an exception-frame decode failure rides on the routine's exception as `__cause__` via `raise from`. +- `wool.TaskFactoryDisplaced` (exception, `wool.WoolError` subclass) — raised when Wool's task factory has been displaced from a loop it was previously installed on. 
Subsequent child tasks no longer fork onto fresh chains; a non-forked child inherits its parent's owning-task identity and trips `wool.ChainContention` on its first `wool.ContextVar` access. Detected reactively on the next `wool.ContextVar` access; raised unconditionally because displacement is structurally fatal to chain propagation across every subsequent task on the loop. Install Wool's factory last, or compose factories manually, to avoid this. diff --git a/wool/src/wool/runtime/worker/README.md b/wool/src/wool/runtime/worker/README.md index a59d1df1..518c3a5c 100644 --- a/wool/src/wool/runtime/worker/README.md +++ b/wool/src/wool/runtime/worker/README.md @@ -75,7 +75,7 @@ async with wool.WorkerPool(spawn=4, discovery=wool.LanDiscovery()): | `metadata` | `WorkerMetadata \| None` | Full metadata including address, tags, version, and transport options. `None` before `start()`. | | `tags` | `set[str]` | Capability tags for filtering and selection. | | `extra` | `dict[str, Any]` | Arbitrary key-value metadata. | -| `address` | `str \| None` | gRPC target address (e.g. `"host:port"`, `"unix:path"`). `None` before `start()`. | +| `address` | `str \| None` | gRPC target address (e.g., `"host:port"`, `"unix:path"`). `None` before `start()`. | `LocalWorker` is the built-in implementation: @@ -192,24 +192,24 @@ Each worker subprocess has a two-loop architecture: Per-dispatch, the gRPC handler instantiates a `DispatchSession` — an async context manager and async iterator that owns the dispatch's full worker-side lifecycle as a uniform driver for both coroutine and async-generator Wool routines. The session has four phases: -1. **Parsing** (`__aenter__`) — reads the first request frame, decodes the caller's `wool.Context` snapshot and rebuilds the `wool.Task` (both via `cloudpickle`), and validates the routine type. Failures wrap in `Rejected` and surface via a `Nack` carrying the typed exception (see _Exception flow_ below). -2. 
**Iteration** (`__aiter__`) — schedules the worker driver lazily on first call so the handler's pre-iteration decisions (i.e., backpressure hook) run before the worker takes the `wool.Context` guard. The worker driver enters a routine-scoped context manager for the parsed task and drives the routine iteratively, with the cross-loop bridge mediated by a pair of queues. Coroutine Wool routines synthesize a single `next` request internally; async-generator routines are driven by iteration commands issued by the client, mirroring the standard library's async-generator semantics. +1. **Parsing** (`__aenter__`) — reads the first request frame, decodes the caller's chain manifest and rebuilds the `wool.Task` (both via `cloudpickle`), and validates the routine type. Failures wrap in `Rejected` and surface via a `Nack` carrying the typed exception (see _Exception flow_ below). +2. **Iteration** (`__aiter__`) — schedules the worker driver lazily on first call so the handler's pre-iteration decisions (i.e., backpressure hook) run before the worker task installs the work context on its own thread. The worker driver enters a routine-scoped context manager for the parsed task and drives the routine iteratively, with the cross-loop bridge mediated by a pair of queues. Coroutine Wool routines synthesize a single `next` request internally; async-generator routines are driven by iteration commands issued by the client, mirroring the standard library's async-generator semantics. 3. **Teardown** (`__aexit__`) — drains the worker driver and unwinds the exit stack. Drain is registered as an exit-stack callback so resource release runs even if drain itself raises. 4. **Cancellation** (`cancel`) — sets a flag observed by both the iteration loop and the deferred scheduler, cancels the worker driver task on the worker loop so a routine suspended inside an `await` receives `CancelledError`, and pushes an end-of-stream frame onto the response queue. 
Cancellation is idempotent and cross-task safe (no `aclose` of the iterator, so a `cancel` call from any task — including the service's preemption path during graceful shutdown — does not race the driving task). -The dispatch handler decodes the wire context into a fresh `wool.Context` inside the session's parse phase and schedules the routine on the worker loop with that instance as the `context=` argument to `loop.create_task`. Wool's event loop task factory routes the explicit `wool.Context` through its scoped-binding path, registering the same instance against the worker task — not a copy — so mutations under the worker task are observable to the handler when it later snapshots the `wool.Context` for back-propagation. +The dispatch handler decodes the chain manifest into a `ChainManifest` inside the session's parse phase. The worker driver task's first action is to install that chain manifest — re-stamped so the worker-loop thread owns the chain — and the routine then runs under it. When the decoded caller frame carries no state, the mount is skipped — the worker task runs unarmed, matching the armed-gating contract documented in `context/README.md`. A later mid-stream frame with state arms lazily through the `ChainManifest.mount` pipeline inside `_drive_step`. After each step the worker publishes the post-step chain manifest onto the session so the handler can encode it for back-propagation; the handler reads it only after draining the worker, so the cross-thread read is race-free. -### Context decode failures +### Chain-manifest decode failures -Wire context is **ancillary state** in Wool's protocol contract: a failure to decode an incoming context — whether on the initial dispatch frame, a mid-stream frame, or a back-propagated response — never preempts the routine's primary signal (its return value or raised exception). 
The worker's contract on each side: +The chain manifest is **ancillary state** in Wool's protocol contract: in the default non-strict mode, a failure to decode an incoming chain manifest — whether on the initial dispatch frame, a mid-stream frame, or a back-propagated response — never preempts the routine's primary signal (its return value or raised exception); strict mode opts into the fatal behavior described per case below. The worker's contract on each side: -- **Initial-frame decode failure (request).** The routine still runs, with a fresh empty `wool.Context` as fallback. A `wool.ContextDecodeWarning` is emitted on the worker. -- **Mid-stream decode failure (request).** The current iteration continues without applying the upstream merge. A `wool.ContextDecodeWarning` is emitted on the worker. -- **Snapshot encode failure (response).** The back-propagated wire context is replaced with an empty context; the response still carries the routine's result or exception. A `wool.ContextDecodeWarning` is emitted on the worker. When the snapshot encode failure coincides with a routine exception, the snapshot failure additionally rides on the routine exception via `__notes__` so the caller's traceback shows both signals. +- **Initial-frame decode failure (request).** In non-strict mode each unreadable entry is dropped with a `wool.SerializationWarning` and the routine runs under whatever partial chain manifest decoded; an empty or entirely-unreadable frame leaves the worker context unarmed. In strict mode the promoted warnings aggregate into a `wool.ChainSerializationError` that is shipped via the `Nack` channel — the routine does not run. +- **Mid-stream decode failure (request).** In non-strict mode each unreadable entry is dropped with a `wool.SerializationWarning` and the surviving partial chain manifest is still merged into the active work context before the step proceeds.
In strict mode the promoted warnings aggregate into a `wool.ChainSerializationError` that propagates out of the step as the routine's terminal failure — shipped via the routine-exception channel, just like any other routine-time exception. +- **Chain-manifest encode failure (response).** The back-propagated chain manifest is replaced with an empty chain manifest; the response still carries the routine's result or exception. A `wool.SerializationWarning` is emitted on the worker. When a strict-mode encode failure coincides with a routine exception, the resulting `wool.ChainSerializationError` is chained onto the routine exception via `raise routine_exc from encode_err` so the caller's traceback shows both signals. -The caller side mirrors this contract: response-context decode failures emit `wool.ContextDecodeWarning` on the caller and never preempt the routine's outcome. See the top-level [`wool/README.md`](../../../../README.md#decode-failure-semantics) for the full lenient/inspect/strict modes. +The caller side mirrors this contract: response chain-manifest decode failures emit `wool.SerializationWarning` on the caller and never preempt the routine's outcome. See the top-level [`wool/README.md`](../../../../README.md#decode-failure-semantics) for the full lenient/inspect/strict modes. -Worker-side strict mode is enabled via Python's standard `PYTHONWARNINGS` environment variable (which `multiprocessing` propagates to spawned worker subprocesses by default). When the worker promotes the warning to an exception, the dispatch handler catches it before the routine starts and ships it via the routine-exception channel, so the caller observes a `wool.ContextDecodeWarning` raised — symmetric with caller-side strict mode rather than a generic gRPC error. Promotions raised after the routine starts surface through the existing routine-exception machinery. 
+Worker-side strict mode is enabled via Python's standard `PYTHONWARNINGS` environment variable (which `multiprocessing` propagates to spawned worker subprocesses by default). When the worker promotes the warning to an exception, the dispatch handler catches the aggregated `wool.ChainSerializationError` before the routine starts and ships it via the `Nack` channel, so the caller observes the same `wool.ChainSerializationError` raised — symmetric with caller-side strict mode rather than a generic gRPC error. Promotions raised after the routine starts surface through the existing routine-exception machinery. ### Exception flow @@ -218,14 +218,14 @@ Worker-side failures route through one of three exit channels. See the top-level | Source | Surface | Caller observes | | ------ | ------- | --------------- | | Parse-phase failure (`Rejected` from `__aenter__`) | `Nack` frame with cloudpickled `original` exception | Original exception re-raised, type and traceback preserved | -| Routine-time exception (raised inside `_step`) | Terminal `Response.exception` with cloudpickle-dumped exception + post-step context snapshot | Original exception re-raised, type and traceback preserved | -| Handler-level encoding failure (result dump fails, strict-mode context encode raises) | Same terminal `Response.exception` channel; either ships the encode failure directly (result dump) or attaches `wool.ContextDecodeWarning` peers to the routine exception via PEP 678 `__notes__` (context encode during routine exception) | Either the encode failure or the routine exception with notes attached | +| Routine-time exception (raised inside `_drive_step`) | Terminal `Response.exception` with cloudpickle-dumped exception + post-step chain manifest | Original exception re-raised, type and traceback preserved | +| Handler-level encoding failure (result dump fails, strict-mode chain-manifest encode raises) | Same terminal `Response.exception` channel; either ships the encode failure directly (result dump) or chains the resulting
`wool.ChainSerializationError` onto the routine exception's `__cause__` via `raise ... from` (chain-manifest encode during routine exception) | Either the encode failure or the routine exception with the encode failure on `__cause__` | `VersionInterceptor` aborts incoming requests with `FAILED_PRECONDITION` before the dispatch handler runs; that surfaces on the caller as a non-transient `RpcError` and is **not** routed through the `Nack` channel. The `Nack` frame's purpose is to ship a typed parse-phase exception so the caller observes the **actual failure class** rather than an opaque RPC error. A `Nack` only appears pre-Ack; once the dispatch handler yields an `Ack`, all further terminal signals ride on `Response.exception`. The dispatch FSM is `Ack? (Result* (Exception | ε)) | Nack`. -Operator-initiated cancellation (graceful shutdown) flows through the routine-exception channel: `WorkerService._cancel` invokes `DispatchSession.cancel` on every in-flight dispatch, the worker task is cancelled on the worker loop, and `CancelledError` rides on the terminal frame. The caller's `await routine()` raises `CancelledError` — indistinguishable from caller-initiated or routine-self-cancellation, matching stdlib's `await task` semantics. +Operator-initiated cancellation (graceful shutdown) flows through the routine-exception channel: `WorkerService._preempt` invokes `DispatchSession.cancel` on every in-flight dispatch, the worker task is cancelled on the worker loop, and `CancelledError` rides on the terminal frame. The caller's `await routine()` raises `CancelledError` — indistinguishable from caller-initiated or routine-self-cancellation, matching stdlib's `await task` semantics. ### Dispatch protocol