diff --git a/sdk/agentserver/.gitignore b/sdk/agentserver/.gitignore index 89f79044a692..5abe2bfded44 100644 --- a/sdk/agentserver/.gitignore +++ b/sdk/agentserver/.gitignore @@ -1,4 +1,11 @@ # Speckit / Specify - spec-driven development tooling specs/ -.specify/ +.specify/* +!.specify/memory/ +.specify/memory/* +!.specify/memory/constitution.md .github/ +.vscode/ + +# Demo session state — regenerated each time the demo runs +.demo-session diff --git a/sdk/agentserver/.specify/memory/constitution.md b/sdk/agentserver/.specify/memory/constitution.md new file mode 100644 index 000000000000..0879d3b2aa58 --- /dev/null +++ b/sdk/agentserver/.specify/memory/constitution.md @@ -0,0 +1,517 @@ +# Azure AI AgentServer SDK Constitution + +## Core Principles + +### I. Modular Package Architecture + +Every feature belongs to a clearly scoped package within the `sdk/agentserver` family. Packages are independently versioned, installable, and testable. The four packages form a layered architecture: + +- **azure-ai-agentserver-core** (v2.x) — Foundation utilities, ASGI host framework, config, tracing, middleware. +- **azure-ai-agentserver-invocations** (v1.x) — Invocation protocol (execute, poll, cancel). +- **azure-ai-agentserver-responses** (v1.x) — Responses protocol (streaming SSE, storage, models). +- **azure-ai-agentserver-githubcopilot** (v1.x) — GitHub Copilot SDK adapter layer. + +Dependencies flow downward only: `githubcopilot` → `responses` → `core`; `invocations` → `core`. No circular or lateral dependencies between protocol packages. Adding new cross-package dependencies requires justification and review. + +### II. Strong Type Safety (NON-NEGOTIABLE) + +All code must use precise, explicit type annotations. This is enforced by mypy (`disallow_untyped_defs: true`), pyright, and verifytypes. + +- **Prefer concrete types over `Any` and `dict`**. Use dataclasses, `TypedDict`, `NamedTuple`, `Protocol`, or custom model classes instead of raw `dict[str, Any]`. 
+- **Use `collections.abc` for abstract types**: `Callable`, `Awaitable`, `AsyncIterator`, `AsyncIterable`, `Sequence`, `Mapping` — not their mutable concrete counterparts unless mutation is required. +- **Use `str | None` (PEP 604)** over `Optional[str]` in new code. Both are acceptable in existing code. +- **All public functions, methods, and class attributes** must have complete type annotations including return types (use `-> None` for void). +- **Use `Literal[...]`** for fixed string values (status codes, mode flags, event types). +- **Use `TYPE_CHECKING` guards** only for circular import resolution or expensive imports — not as a general pattern. +- **Include `py.typed`** (PEP 561) marker in every package. +- **Type ignore comments** must include specific error codes and a brief justification: `# type: ignore[assignment] # reason`. +- **TypeVar naming**: Covariant suffixed `_co`, contravariant suffixed `_contra`. +- **Mark Protocols `@runtime_checkable`** when used for duck-typing checks. +- **PEP 484 inline style only**: Never use comment-style type hints (`# type:`). + +```python +# ✅ GOOD — precise types +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence +from typing import Literal + +Status = Literal["created", "in_progress", "completed", "failed"] + +class ResponseExecution: + status: Status + output_items: list[OutputItem] + +async def process(items: Sequence[InputItem]) -> AsyncIterator[Event]: ... + +# ❌ BAD — vague types +def process(items: list) -> dict: ... +def handle(data: Any) -> Any: ... +config: dict = {} +``` + +### III. Azure SDK Design Guidelines Compliance + +All packages follow the [Azure SDK Python Design Guidelines](https://azure.github.io/azure-sdk/python_design.html) and this repo's AGENTS.md / CONTRIBUTING.md conventions: + +- **Naming**: Packages use `azure-ai-agentserver-{component}` format. Namespace: `azure.ai.agentserver.{component}`. Namespace `__init__.py` files use `pkgutil.extend_path()`. 
+- **Versioning**: Semantic versioning (`MAJOR.MINOR.PATCH`). Preview: `X.Y.ZbN`. Version stored in `_version.py`, read dynamically by `pyproject.toml` via `[tool.setuptools.dynamic]`. + - `_version.py` must match the latest version in `CHANGELOG.md`. + - Preview packages: `is_stable = false` and classifier `Development Status :: 4 - Beta` in `pyproject.toml`. + - Stable packages: `is_stable = true` and classifier `Development Status :: 5 - Production/Stable`. +- **Line length**: 120 characters max. +- **Formatting**: Black-formatted (`azpysdk black .`). No exceptions. +- **Code style**: Follow [PEP 8](https://peps.python.org/pep-0008/). Naming: modules `snake_case`, classes `PascalCase`, functions/methods/variables `snake_case`, constants `UPPER_CASE`. +- **Imports**: Standard library → third-party → local (relative). Use `from __future__ import annotations` in modules with complex type annotations. No star imports except from `_generated` subpackages. +- **CHANGELOG**: Maintained per package. Unreleased section uses explicit version header (e.g., `## 1.0.0b5 (Unreleased)`) with standard subsections: `### Features Added`, `### Breaking Changes`, `### Bugs Fixed`, `### Other Changes`. +- **MANIFEST.in**: Must include `py.typed`, `azure/__init__.py`, and recursively include samples, tests, and docs. + +### IV. Async-First Design + +The AgentServer SDK is inherently asynchronous. All I/O-bound operations use `async def` / `await`. + +- **ASGI-native**: Server hosts are Starlette subclasses. Middleware must be pure ASGI (no `BaseHTTPMiddleware`). +- **Streaming**: Use `AsyncIterator` with `yield` for SSE event streams. Wrap with `StreamingResponse`. +- **Cancellation**: Use `asyncio.Event` for cooperative cancellation signals. +- **Background tasks**: Use `asyncio.Task` for fire-and-forget work with proper error logging. +- **Handler validation**: All registered handlers must be coroutine functions. 
Validate with `inspect.iscoroutinefunction()` and raise `TypeError` if not. +- **Context propagation**: Use `contextvars.ContextVar` for request-scoped state (request IDs, invocation IDs). + +### V. Fail-Fast Configuration, Graceful Runtime + +- **Startup**: Validate all required environment variables (`PORT`, `FOUNDRY_AGENT_NAME`, `FOUNDRY_AGENT_VERSION`, etc.) and configuration at initialization. Raise immediately on missing or invalid config — do not defer failures to request time. +- **Observability failures**: Log warnings but never crash the server. Tracing/telemetry is best-effort. +- **Handler errors**: Return structured error responses via `create_error_response(code=..., message=..., status_code=...)`. Never leak stack traces to clients. +- **Custom exceptions**: Define domain-specific exceptions (e.g., `FoundryStorageError`, `FoundryResourceNotFoundError`) with clear error codes. +- **Broad catches**: `except Exception` is permitted only at top-level dispatch boundaries with explicit `# pylint: disable=broad-exception-caught` and proper logging. +- **Azure Core exceptions**: Use `azure.core.exceptions` hierarchy (e.g., `HttpResponseError`) for client-facing errors where applicable. + +### VI. Observability & Correlation + +- **Logging**: Module-level logger via `logging.getLogger("azure.ai.agentserver.{component}")`. Use structured key-value logging. No print statements. +- **Tracing**: OpenTelemetry integration via `azure-ai-agentserver-core`. GenAI semantic conventions for spans (`gen_ai.system`, `gen_ai.operation.name`, `gen_ai.agent.name`). +- **Correlation**: Propagate `x-request-id` and `x-ms-client-request-id` headers. Auto-generate from trace ID, header, or UUID. Use `contextvars` for in-process correlation. +- **Metrics**: Export via Azure Monitor (`APPLICATIONINSIGHTS_CONNECTION_STRING`) or OTLP (`OTEL_EXPORTER_OTLP_ENDPOINT`). Expose health endpoints (`/health/live`, `/health/ready`). 
+- **Graceful shutdown**: Handle `SIGTERM` with configurable drain timeout (default 30s). + +### VII. Test-Driven Development (TDD) + +All new feature code follows test-driven development: + +- **Write tests first**: Before implementing any feature or fixing a bug, write a failing test that defines the expected behavior. +- **Red → Green → Refactor**: Tests must fail before implementation (Red), pass with minimal code (Green), then be cleaned up (Refactor). +- **Acceptance tests from spec**: User story acceptance scenarios in the spec translate directly into test cases during the tasks phase. These are written before implementation begins. +- **Contract tests for interfaces**: When a spec defines a new interface, protocol, or API surface, write contract tests that validate the interface shape before implementing the internals. +- **No untested features**: A feature is not complete until its tests pass. Code without corresponding tests is considered incomplete regardless of whether it "works." +- **Tests drive design**: Let the test-writing process inform API ergonomics. If something is hard to test, it's likely hard to use — simplify the design. + +```python +# ✅ GOOD — test written first, defines expected behavior +async def test_resilient_task_resumes_after_crash(): + """Handler is re-invoked with metadata intact after simulated crash.""" + app = create_test_app(resilient_background=True) + # ... setup, crash simulation, assertion ... + assert response.status == "completed" + assert response.output[0].content == "resumed result" + +# ❌ BAD — implementation without a test +# "I'll add tests later" → tests never get added +``` + +### VIII. Minimal Surface, Maximum Composability + +- **Decorator-based registration**: Handlers registered via `@app.invoke_handler`, `@app.response_handler`. Decorators return the function unmodified. 
+- **Cooperative MRO**: Multi-protocol hosts compose via multiple inheritance: `class MyHost(InvocationAgentServerHost, ResponsesAgentServerHost)`. Each protocol class merges its routes with `super().__init__()`. +- **Builder patterns**: Streaming APIs use fluent builders (`ResponseEventStream.emit_created().emit_in_progress()...`). +- **Lazy resolution**: Expensive computations (input resolution, history loading) use async-cached properties. +- **No unnecessary abstractions**: Prefer simple functions over class hierarchies. Use `Protocol` for structural typing rather than deep inheritance trees. + +### IX. Docs ↔ Samples Feedback Loop (NON-NEGOTIABLE) + +Developer-facing guides are the authoritative source of guidance — samples are validation that the guidance produces correct outcomes when followed mechanically. + +This principle is adjacent to TDD (Principle VII) but distinct: TDD validates behaviour via tests; this principle validates *guidance* via samples. + +**The loop:** + +1. **Write or update the guide first.** Before writing or rewriting a sample, write or update the relevant section of the developer guide (e.g. `handler-implementation-guide.md`, `resilient-responses-developer-guide.md`). The guide defines the mental model, rules, and layered responsibilities (library ↔ handler ↔ upstream framework). The guide does NOT teach individual upstream frameworks; it teaches the contract. +2. **Write the sample by mechanically applying the guide.** Pretend you are a developer reading the guide for the first time. Implement the sample using *only* the guidance in the guide. Do not import knowledge that isn't in the guide. +3. **If the sample comes out wrong, the guide is wrong.** Fix the guide first. Do not patch the sample to work around guide gaps. +4. **Re-derive the sample from the corrected guide.** Repeat until both guide and sample are internally consistent. +5. 
**Test the guide via samples.** Every guide section that prescribes a pattern must have at least one sample that demonstrates that pattern end-to-end, with an automated test asserting the prescribed outcome. +6. **Run the applicable review checklist.** Before marking a sample done, run the relevant checklist from `.specify/templates/` against it. For resilient response samples, that is `resilience-sample-checklist-template.md`. A sample with any failing checklist item is incomplete — triage the failure (guide gap / sample bug / test gap / spec gap) and loop back to the earliest applicable step. + +**Guide responsibilities:** + +- Define the mental model (what each layer owns). +- State the contract between layers (what each layer guarantees and requires). +- Prescribe patterns for the canonical cases. +- Document fallback behaviour for the no-opt-in case. +- **Stay framework-agnostic in the body.** Reference upstream frameworks (Claude SDK, Copilot SDK, LangGraph, etc.) only as concrete examples illustrating an already-stated rule. + +**Sample responsibilities:** + +- Demonstrate the guide's patterns end-to-end against a real upstream framework. +- Carry the framework-specific reconciliation steps the guide deliberately omits. +- Include an automated test that proves the prescribed outcome holds. +- Pass the applicable review checklist before being marked done. + +**Review checklists:** + +Mechanical review of samples uses checklists stored under `.specify/templates/`: + +- `resilience-sample-checklist-template.md` — for any resilient response handler sample (covers crash, shutdown, steering, client cancel). Required before any resilient sample is shipped. + +New canonical sample categories MUST get a matching checklist template. Each checklist item references the constitutional principle or spec FR it enforces, so a checklist failure is traceable to a specific contract. 
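The traceability described above can be sketched in a few lines. This is only an illustration under stated assumptions: the `ChecklistItem` shape, the helper names, and the example rows are hypothetical and do not reflect the real format of `resilience-sample-checklist-template.md`.

```python
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass


@dataclass(frozen=True)
class ChecklistItem:
    """One review-checklist row (hypothetical shape, not the template's real format)."""

    description: str
    enforces: str  # constitutional principle or spec FR this item traces to
    passed: bool


def failing_contracts(items: Sequence[ChecklistItem]) -> list[str]:
    """Return the contract reference of every failing item, in checklist order."""
    return [item.enforces for item in items if not item.passed]


def sample_is_done(items: Sequence[ChecklistItem]) -> bool:
    """A sample with any failing checklist item is incomplete."""
    return not failing_contracts(items)


# Hypothetical rows for illustration only.
checklist = [
    ChecklistItem("Guide section exists for the pattern", "Principle IX", passed=True),
    ChecklistItem("Crash path has an automated test", "Principle X", passed=False),
]
```

Because each item carries the reference it enforces, a failing row points directly at the contract to revisit instead of only reporting that the sample is incomplete.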
+ +**What this means for specs:** + +Every spec that touches developer-facing samples MUST include a "Docs ↔ Samples Loop" section spelling out: + +- Which guide(s) own the contract being specified. +- The sequence: guide changes first, then samples, then re-validation via the applicable checklist. +- The acceptance criterion: a developer following the guide alone (without reading framework source) can produce a sample that passes the checklist. + +```python +# ✅ GOOD — guide first, sample derived from guide, checklist closes the loop +# 1. handler-implementation-guide.md updated with recovery contract. +# 2. sample_17_resilient_claude.py implemented by following the guide. +# 3. Sample's test fails → guide is missing the "claude_query_in_flight watermark" pattern. +# 4. Guide updated with the watermark pattern. +# 5. Sample re-derived from updated guide → test passes. +# 6. resilience-sample-checklist run against sample → 30/30 pass → sample marked done. + +# ❌ BAD — sample first, guide retro-fitted, no checklist +# 1. sample_17 written by reading Claude SDK source. +# 2. Guide updated to vaguely match what the sample does. +# 3. A developer reading the guide cannot reproduce the sample's correctness. +# 4. Three weeks later, a different reviewer finds the same crash-recovery +# gap that was already "fixed" — because no checklist ever caught it. +``` + +### X. Resilience Contract Conformance (NON-NEGOTIABLE) + +The resilience behavior of `azure-ai-agentserver-responses` is specified in the source-of-truth resilience contract. Every row of its matrix has an observable contract; every contract MUST be backed by a behavioral test that exercises it end-to-end through real signals. + +**Why this principle exists**: the framework's documented resilience matrix once diverged silently from its implementation for three rows. 
Five overlapping failure modes let those divergences ship: tests asserted helper behavior instead of contract behavior, crash-injection tests were deferred and never picked up, helpers were built without wiring, no single contract validated the matrix as an end-to-end seam, and no structural guard required matrix coverage. This principle is the structural guard. + +**The rule:** + +1. **Every row of `resilience-contract.md` §The matrix MUST have a behavioral test in `tests/e2e/resilience_contract/` exercising every applicable termination path via real signals:** + - **Path A** (graceful shutdown, handler completes within grace): SIGTERM with grace period set sufficiently long for the handler to complete naturally. + - **Path B** (graceful shutdown, grace exhausted): SIGTERM with grace period set deliberately short so the handler is still running at grace expiry, forcing the in-process marker / hand-off to fire before subprocess exit. + - **Path C** (crash, or Path-B failure): SIGKILL via `_crash_harness` mid-handler, followed by subprocess restart. +2. **Where the matrix collapses `stream`, the test MUST run its assertions for both `stream=False` and `stream=True`** (parametrized). +3. **The `test_contract_completeness.py` meta-test** parses `resilience-contract.md` and fails CI if any (row, applicable path) is missing a paired test module, OR if any module is missing one of the parametrize ids the matrix requires. +4. **Any spec or pull request that affects code in the resilience surface** (orchestrator routing, in-process shutdown loop, resilient-task primitive integration, stream provider, response store terminal-persist hooks) **MUST land its conformance tests RED before the implementation commit goes green.** The reviewer verifies test-first ordering from the commit history. +5. **Synthetic-crash shortcuts are explicitly disallowed for conformance tests:** + - MUST NOT mock `_crash_harness`. + - MUST NOT fabricate a `ResilienceContext` to simulate recovery. 
+ - MUST NOT call internal failure-marker functions (e.g. `_persist_crash_failed`) directly to simulate Path B or Path C. + - MUST NOT use a test-only injection to control grace timing; use the framework's real `shutdown_grace_period_seconds` configuration. + +**Adding or modifying a row:** any spec that adds a new row to the matrix, or modifies the contract on an existing row, MUST follow `resilience-contract.md` §Change control: amend the contract doc, update the conformance suite (RED first, then GREEN after implementation), and update the dev guide / handler guide in the same PR as the implementation. + +**Reviewer checklist for PRs touching resilience:** + +- [ ] Which rows of `resilience-contract.md` §The matrix does this change affect? +- [ ] Are the conformance tests for those rows in the PR? +- [ ] Did those tests land RED before the implementation commit (verifiable from git history)? +- [ ] Did the dev guide / handler guide need updates? Are they in this PR? + +This principle is referenced by `resilience-contract.md` §Test discipline; the two stay in sync via cross-reference. The resilience test suite, meta-test, Constitution principle, and template gate implement the structural pieces. + +### XI. Contract-Surface Test Depth (NON-NEGOTIABLE) + +Conformance tests MUST verify the row's full contract surface, not just terminal status. Shape-only assertions (e.g. `response.status == "completed"`) are necessary but not sufficient; they pass whenever any code path reaches a terminal of the right type and miss content-level drift entirely. + +**Why this principle exists**: a streaming-recovery-continuity bug (fix `1e69dba385`) slipped through Principle X's structural gate. Every (row × path) cell had a paired test, all GREEN, but the tests asserted only on `terminal["status"]`. 
The bug — that pre-crash SSE events were being erased by the recovered handler's terminal-time `save_stream_events` — was invisible because: + +- The conformance handler emitted a single `"ok"` delta. Pre-crash content and recovered content were byte-identical, so cross-attempt drift was indistinguishable. +- The tests asked "did recovery happen?" (yes, `status="completed"`) but never asked "did the persisted stream contain the right events in the right order?". + +Principle X (every cell has a paired test) was satisfied. Principle XI is the depth complement. + +**The rule:** + +1. **Per-cell tests MUST verify the contract surface that the cell's mode flags expose to clients:** + - **For cells with `stream=true`:** event sequence ordering, per-event content (delta text, item shape, content-part fields), sequence-number monotonicity across recovery attempts, and the final terminal event's `response` payload. Pre-crash events MUST be verified to survive in the persisted stream for cells where the contract claims cross-attempt continuity (Row 1). + - **For cells with `stream=false`:** `response.status`, `response.output` (the assembled output items including their content text), and `response.error` (for failure cells). For polled / background cells, the polled snapshot IS the contract surface; the test MUST assert on its content, not just its terminal type. + +2. **The conformance test handler MUST emit per-lifetime-identifiable content** so cross-attempt assertions are sensitive to drift. The current handler at `tests/e2e/resilience_contract/_test_handler.py` tags every delta with `f"L{lifetime}_..."` and the final text with a composite `f"L{lifetime}_done|pre=N|chain=…|visited=…"` — tests parse these markers to confirm which lifetime produced which event. Content like `"ok"` that's identical across lifetimes is DISALLOWED in this handler. + +3. 
**The contract coverage matrix at `tests/e2e/resilience_contract/CONTRACT_COVERAGE.md` MUST map every normative clause in `resilience-contract.md` to the test(s) that verify it.** Cells marked `**GAP**` are explicit findings; they MUST be filled or explicitly justified (with a `n/a` rationale) before the next contract amendment ships. + +4. **The `test_contract_coverage_matrix_exists_and_is_non_trivial` meta-test** enforces that every conformance test file is referenced in the matrix. New tests added without a matrix entry fail CI. + +5. **The `test_per_cell_tests_assert_more_than_just_status` meta-test** is a SHOULD-gate (warning, not hard fail) that surfaces per-cell tests asserting only on `terminal["status"]` without any other depth signal (event content, response.output, sequence numbers, etc.). It guides reviewers toward adding depth assertions when the cross-cutting tests don't already cover them. + +**Adding a new contract clause** (per `resilience-contract.md` § Change control): + +1. Add the clause to the contract doc. +2. Add a coverage matrix entry mapping the clause to the test(s) that verify it. +3. Add or extend tests with the depth assertions the clause requires. +4. Land all three (contract + matrix + tests) in a single PR. + +This principle was added as a follow-up to the conformance-depth reflection. The reflection that motivated it is in `~/.copilot/session-state/.../files/conformance_gap_analysis.md` and summarized in the source-of-truth contract discussion of conformance test depth. + +### XII. Core-Primitive TDD Discipline (NON-NEGOTIABLE) + +The public surface of the core resilient-task primitive (`azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/`) is consumed by every higher layer (invocations samples, responses framework, future end-user resilient handlers). Drift between the primitive's documented contract and its actual behavior cascades silently into all consumers. This principle is the test-first gate against that drift. 
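The drift described above can be made concrete with a small surface check. The sketch below compares a module's `__all__` against a documented symbol set and reports differences in both directions. It is a minimal illustration, not the SDK's test suite: `fake_tasks`, `surface_drift`, `SurfaceDrift`, and the `resilient_task` name are hypothetical stand-ins.

```python
from __future__ import annotations

import types
from dataclasses import dataclass


@dataclass(frozen=True)
class SurfaceDrift:
    """Difference between a documented symbol set and a module's ``__all__``."""

    missing: frozenset[str]  # documented but not exported
    undocumented: frozenset[str]  # exported but not documented

    @property
    def is_clean(self) -> bool:
        return not self.missing and not self.undocumented


def surface_drift(module: types.ModuleType, documented: frozenset[str]) -> SurfaceDrift:
    """Compare ``module.__all__`` (empty if absent) with the documented names."""
    exported = frozenset(getattr(module, "__all__", ()))
    return SurfaceDrift(missing=documented - exported, undocumented=exported - documented)


# Hypothetical stand-in for the real tasks package; only __all__ matters here.
fake_tasks = types.ModuleType("fake_tasks")
fake_tasks.__all__ = ["TaskContext", "TaskSuspended"]

# "resilient_task" is an assumed decorator name used purely for illustration.
drift = surface_drift(fake_tasks, frozenset({"TaskContext", "resilient_task"}))
```

A real surface test would import the actual package and go further than name presence, for example by checking field types and exception raise sites through the framework's own wiring.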
+ +**Why this principle exists**: Principle X locks the responses-layer resilience matrix against drift. The core primitive has the same shape of problem one layer down — its `TaskContext` fields, decorator arguments, exception types, and metadata namespaces are a public contract whose drift produces silent miscompiles in consumer code. Prior hardening surfaced concrete examples: `run_attempt` semantics ambiguous between in-process retries and resilient failure-retry budget; `previous_input` shipped without being populated; `TaskSuspended` exported but unused; `_FilteredMetadata` filtering the wrong direction. None of these were caught by the existing suite because the suite asserted helper behavior, not the primitive's contract surface. This principle is the structural fix. + +**The rule:** + +1. **Every public symbol in `azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/__init__.py` MUST have at least one paired test in `azure-ai-agentserver-core/tests/tasks/` asserting:** + - The symbol's exact name, location, and presence in `__all__`. + - Each field's name, type, and behavior under the modes the contract documents (e.g. `TaskContext.retry_attempt` resilience across process restart; `TaskContext.recovery_count` increment-on-recovery semantics). + - Each decorator argument's behavior (accepted-and-honored vs rejected-with-TypeError). + - Each exception type's raise sites and message shape. + +2. **The `test_contract_completeness.py` meta-test** (in `tests/tasks/`) parses the consolidated developer guide for the resilient-task primitive AND the test directory, and fails CI if any documented contract clause lacks a paired test reference, OR if any public symbol lacks a surface-test entry. + +3. 
**Any spec or pull request that affects the public surface of the core resilient-task primitive** (decorator signature, `TaskContext` fields, exception types, metadata namespaces, retry policy) **MUST land its conformance tests RED before the implementation commit goes green.** The reviewer verifies test-first ordering from the commit history. + +4. **The non-duplication rule:** when an existing test in `tests/tasks/` already covers the surface area being changed, the new conformance test must EXTEND the existing test file rather than creating a parallel test file. A new test file is justified only when no existing home exists for the contract surface; the justification MUST be recorded in the conformance tracking document. + +5. **Synthetic-bypass shortcuts are explicitly disallowed for conformance tests:** + - MUST NOT monkey-patch `TaskContext` fields to simulate values that the runtime would produce. + - MUST NOT instantiate `TaskContext` directly outside the framework's wiring to test behavior that the framework provides. + - MUST NOT call internal `_`-prefixed APIs to bypass public-surface contract enforcement. + +**Adding or modifying a public-surface symbol:** any spec that adds, renames, drops, or changes the semantics of a public symbol in the core resilient-task primitive MUST: amend the consolidated dev guide, update the conformance suite (RED first, then GREEN after implementation), and update the spec template's exit checklist verification in the same PR as the implementation. + +**Reviewer checklist for PRs touching the core resilient-task primitive's public surface:** + +- [ ] Which public symbols (decorator args, `TaskContext` fields, exception types, metadata namespaces) does this change affect? +- [ ] Are the conformance tests for those symbols in the PR? +- [ ] Did those tests land RED before the implementation commit (verifiable from git history)? 
+- [ ] Was an existing test file extended (per non-duplication rule), or is the new file's justification recorded in the conformance tracking document? +- [ ] Did the consolidated dev guide need updates? Are they in this PR? + +This principle is the core-layer mirror of Principle X. The two stay in sync via cross-reference. The conformance tracking, non-duplication test discipline, and Constitution amendment implement the structural pieces. + +### XIII. Continuous Code Review Discipline (NON-NEGOTIABLE) + +Multi-phase implementations land hacks. Each phase, working in isolation, will accept a workaround that LOOKS LOCAL but degrades the overall code shape — a premature abstraction the next phase has to fight, an under-design that propagates scaffolding forward, a silent drift from the spec's design invariants that no per-phase reviewer would catch. This principle is the structural guard: code review is a sequencing fence, not an end-of-PR check. + +**Why this principle exists**: resilient-task primitive contract hardening surfaced this risk during task planning. The implementation had multiple user stories landing across many phases on one cohesive PR; the user observed that without continuous review, each phase would "just focus on solving its own problem" while collectively shipping a degraded surface. The fix — interleaved per-phase, cross-phase, and final reviews via the `code-review` agent — must apply to every multi-phase contract change. This principle is that generalization. + +**The rule:** + +1. **Every spec with three or more implementation phases (or three or more user stories) MUST include code review tasks in its task list.** The review tasks are sequencing fences interleaved with implementation, not a single end-of-PR step. + +2. **The review structure MUST include:** + - **Per-phase reviews** at the end of each implementation phase or user-story phase. 
Scope: catches phase-local quality issues (FR coverage, RED-first commit ordering, no hacks, no scope creep, no shape-only test assertions, dev-guide alignment for that phase's contracts). + - **Cross-phase seam reviews** at the boundary between any two implementation phases whose hand-off is architecturally significant (e.g., a phase that introduces an API surface another phase will consume; a phase that mutates a hot-path another phase will further mutate). Scope: catches premature abstraction, under-design, and seam quality issues that no single-phase review will catch. + - **Final whole-PR holistic review** at the end of the polish phase. Scope: catches end-to-end properties no per-phase review can verify alone — spec coverage symbol-for-symbol, documentation truth, plan-phase-decision resolution, constitution exit checklists complete, no regression, commit-history RED-first hygiene, lint/type/build clean. + +3. **Each review task dispatches the `code-review` agent (or equivalent) with a precise SCOPE statement tailored to the phase.** Generic "review this code" prompts are insufficient. The scope statement MUST name: (a) the specific FRs / SCs the phase implements; (b) the specific files and commits in the phase's diff; (c) the specific quality risks the phase is most likely to introduce; (d) the cross-phase coupling concerns the next phase will inherit; (e) constitution principles whose violation would be a BLOCKING finding. + +4. **Review tasks are blocking GATES.** A phase's review task MUST complete before the next phase begins. BLOCKING and HIGH findings MUST be addressed before the gate clears. MEDIUM and LOW findings MUST be logged to the conformance tracking artifact for the final-review sweep to verify they're either resolved or explicitly accepted with reviewer sign-off. + +5. 
**The `/speckit.tasks` template generates the review tasks automatically.** When the spec has three or more phases or stories, the tasks template MUST emit a "Continuous Code Review" phase as the last phase (with per-phase, cross-phase, and final review tasks), AND each Checkpoint marker in the intervening phases MUST be annotated with a `→ Run TXXX before moving to Phase Y` arrow pointing at its gating review task. The `/speckit.plan` template MUST include a "Code Review Cadence" subsection under the Constitution Check that names which review tasks the implementation will produce. + +**What review tasks catch (the recurring failure modes):** + +- **Phase-local hacks**: a `# TODO: revisit in next PR`-style shortcut, a one-off helper that should be generalized, a `# type: ignore` without justification, a `# pylint: disable` without justification, a test that monkey-patches an internal symbol to avoid wiring the public surface correctly. +- **Spec drift**: an FR partially implemented, an SC test that asserts shape instead of behavior, a new internal symbol introduced beyond what the spec / data-model authorized. +- **Premature abstraction**: a Phase A factory that the Phase B consumer doesn't actually need, a generic interface that papers over a single-concrete-use. +- **Under-design**: a Phase A seam that Phase B has to monkey-patch around because the original shape doesn't fit, an internal data-format choice that propagates into every later-phase test as a workaround. +- **Documentation drift**: a public-surface change without a corresponding dev guide update, a CHANGELOG entry that misrepresents the change, a docstring that contradicts the spec's contract claim. +- **Pre-existing test deletion**: a pre-existing test that exercised the surface this phase is changing was DELETED instead of PORTED per the spec's "Hardening pre-existing tests" subsection (deletion is allowed only with SOT conformance list justification). 
+- **RED-first violation**: an implementation commit precedes its paired conformance-test commit in git history (Constitution Principle XII §3 violation). + +**Reviewer checklist for PRs touching multi-phase spec implementations:** + +- [ ] Does the task list include a "Continuous Code Review" phase with per-phase, cross-phase, and final reviews? +- [ ] Did each per-phase review run at its Checkpoint and complete (with BLOCKING / HIGH findings addressed) before the next phase began? +- [ ] Did the cross-phase seam reviews run at the architectural boundaries the plan identified? +- [ ] Did the final holistic review verify all cross-cutting properties (spec coverage, public surface match, documentation truth, plan-phase-decision resolution, constitution exit checklists, no regression, commit-history RED-first, lint/type/build clean)? +- [ ] Were MEDIUM / LOW findings either resolved or accepted with reviewer sign-off in the conformance tracking artifact? + +This principle is referenced by `.specify/templates/plan-template.md` (Constitution Check gate for the Code Review Cadence subsection) and `.specify/templates/tasks-template.md` (auto-generated Phase N: Continuous Code Review section when the spec has ≥3 phases/stories). The two stay in sync via cross-reference. + +## Code Standards + +### File & Module Organization + +``` +azure/ai/agentserver/{component}/ +├── __init__.py # Public API exports only +├── _version.py # VERSION = "X.Y.ZbN" +├── _public_class.py # One primary class per module +├── _internal_helper.py # Underscore prefix = private +├── models/ # Data models (generated + runtime) +│ ├── _generated/ # Auto-generated — NEVER hand-edit +│ └── runtime.py # Runtime model extensions +├── py.typed # PEP 561 marker +└── tests/ # pytest-based tests +``` + +- **Public API**: Export only from `__init__.py`. Internal modules prefixed with `_`. +- **One concept per module**: Each `_*.py` file owns one class or closely related set of functions. 
+- **Generated code**: Lives in `models/_generated/` — never hand-edit. Runtime extensions in `models/runtime.py` or `models/_helpers.py`. + +### Docstrings (Sphinx RST Format) + +All public classes, methods, and functions require docstrings: + +```python +def create_response( + self, + input_items: Sequence[InputItem], + *, + mode: ResponseMode = "streaming", +) -> ResponseExecution: + """Create a new response execution. + + :param input_items: The input items to process. + :type input_items: ~collections.abc.Sequence[~azure.ai.agentserver.responses.InputItem] + :keyword mode: The response mode. Default is "streaming". + :paramtype mode: str + :return: The response execution object. + :rtype: ~azure.ai.agentserver.responses.ResponseExecution + :raises ValueError: If input_items is empty. + :raises ~azure.core.exceptions.HttpResponseError: If the service returns an error. + + .. versionadded:: 1.0.0b5 + """ +``` + +- Use `:param:` + `:type:` (two-line) or `:param type name:` (one-line) format. +- Use `:keyword:` + `:paramtype:` for keyword-only arguments. +- Use `~` prefix to shorten display paths in Sphinx output. +- Document all raised exceptions with `:raises ExceptionType: description`. +- Use `.. versionadded::` for new APIs. + +### Testing Requirements + +- **Framework**: pytest with pytest-asyncio (`asyncio_mode = "auto"`). +- **HTTP testing**: Use httpx `AsyncClient` with ASGI transport for in-process server testing. +- **Coverage**: All public APIs must have tests. All handler dispatch paths must be tested. +- **Test proxy**: Use the Azure SDK test proxy (`devtools_testutils`) for integration tests requiring live services. Inherit from `AzureRecordedTestCase` and use `@recorded_by_proxy` / `@recorded_by_proxy_async` decorators. +- **Recordings**: Stored in `tests/recordings/` or migrated to `azure-sdk-assets` repo. 
+- **No credentials in code**: Use environment variables, `self.get_credential()` from test base, or `devtools_testutils.fake_credentials` for CredScan compliance. +- **Samples testing**: Samples must be runnable (`python sample_name.py`). Async samples in `/samples/async_samples/` with `_async.py` suffix. +- **Sample E2E tests (NON-NEGOTIABLE)**: Every sample MUST have a corresponding end-to-end test that exercises the sample's handler/task logic programmatically. Tests replicate the sample logic inline (do NOT import from sample files), run the full lifecycle, and assert outputs. This follows the pattern established in `azure-ai-agentserver-responses/tests/e2e/test_sample_e2e.py`. A sample without an e2e test is considered incomplete. + +### Samples Conventions + +- **Location**: `/samples/` for sync, `/samples/async_samples/` for async. +- **Naming**: `sample_.py` and `sample__async.py`. +- **Snippet markers**: Use `# [START keyword]` and `# [END keyword]` for Sphinx `literalinclude` references. +- **Headers**: Each sample requires a docstring with description and setup instructions. +- **Dependencies**: Only OSI-approved licensed dependencies. Prefer permissive licenses (MIT, Apache 2). + +### Pylint Directives + +Allowed suppressions (with justification comments): +- `broad-exception-caught` — top-level dispatch only +- `too-many-instance-attributes` — large config/state objects +- `do-not-import-asyncio` — required for signal handling / tasks +- `logging-fstring-interpolation` — when performance is not critical + +Pylint design limits (from repo `pylintrc`): max-locals=25, max-branches=20, max-attributes=10, max-parents=15, min-similarity-lines=10. + +## Validation & Quality Gates + +### Pre-Push Validation (NON-NEGOTIABLE) + +**Before pushing any code to remote**, the following checks MUST be run locally on every modified package and MUST pass. Do not push code that fails any of these checks — fix issues locally first. 
+ +For each modified package under `sdk/agentserver/`, run from the repo root: + +```bash +# Release-blocking checks (MUST pass before push) +python -m azpysdk.main pylint sdk/agentserver/ +python -m azpysdk.main mypy sdk/agentserver/ +python -m azpysdk.main sphinx sdk/agentserver/ +cd sdk/agentserver/ && python -m pytest tests/ -x -q + +# Also recommended before push +python -m azpysdk.main pyright sdk/agentserver/ +python -m azpysdk.main black sdk/agentserver/ +``` + +If a change touches multiple packages, validate ALL of them. Do not assume a change to one package won't break another — especially when modifying `__init__.py` exports or shared types. + +### Required Checks (azpysdk) + +All checks run via `azpysdk` from the repo root (or `azpysdk .` from the package directory). Every check must pass before merge: + +| Check | Command | Purpose | +|-------|---------|---------| +| Pylint | `azpysdk pylint .` | Code quality + Azure SDK custom rules | +| MyPy | `azpysdk mypy .` | Type correctness | +| Pyright | `azpysdk pyright .` | Type completeness | +| Verifytypes | `azpysdk verifytypes .` | Public API type coverage | +| Sphinx | `azpysdk sphinx .` | Documentation builds cleanly | +| Bandit | `azpysdk bandit .` | Security analysis | +| Black | `azpysdk black .` | Code formatting | +| Verifywhl | `azpysdk verifywhl .` | Wheel packaging correctness | +| Verifysdist | `azpysdk verifysdist .` | Source dist packaging correctness | + +### Release Blocking Checks + +These four checks **must PASS** for any release: +1. **MyPy** — PASS +2. **Pylint** — PASS +3. **Sphinx** — PASS +4. **Tests - CI** — PASS + +Failure of any release-blocking check means the package cannot be published. 
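The release-blocking checks above can be scripted per package. The following is a minimal sketch, not part of `azpysdk`: the helper names `release_blocking_commands` and `run_release_blocking` are hypothetical, and the package directory is supplied by the caller. It only assembles the command lines listed in this section and stops at the first failure.

```python
from __future__ import annotations

import subprocess
import sys


def release_blocking_commands(package_dir: str) -> list[tuple[list[str], str | None]]:
    """Build (argv, cwd) pairs for the four release-blocking checks.

    Hypothetical helper: the argv shapes mirror the commands listed in this
    section; a cwd of None means "run from the repo root".
    """
    commands: list[tuple[list[str], str | None]] = [
        ([sys.executable, "-m", "azpysdk.main", check, package_dir], None)
        for check in ("pylint", "mypy", "sphinx")
    ]
    # Tests run from inside the package directory, as in the bash snippet.
    commands.append(([sys.executable, "-m", "pytest", "tests/", "-x", "-q"], package_dir))
    return commands


def run_release_blocking(package_dir: str) -> bool:
    """Run each check in order; return False as soon as one fails."""
    for argv, cwd in release_blocking_commands(package_dir):
        if subprocess.run(argv, cwd=cwd, check=False).returncode != 0:
            return False
    return True
```

Calling `run_release_blocking(...)` once per modified package keeps the "validate ALL of them" rule explicit rather than relying on memory.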
+ +### Fixing Guidelines + +When fixing validation warnings: +- ✅ Fix with 100% confidence using existing patterns in the codebase +- ✅ Reference [Azure pylint guidelines](https://github.com/Azure/azure-sdk-tools/blob/main/tools/pylint-extensions/azure-pylint-guidelines-checker/README.md) and [MyPy cheat sheet](https://github.com/Azure/azure-sdk-for-python/blob/main/doc/dev/static_type_checking_cheat_sheet.md) +- ✅ Make minimal, surgical changes +- ❌ Never fix warnings without complete confidence +- ❌ Never add new dependencies or imports to fix warnings +- ❌ Never create new files solely to fix warnings +- ❌ Never make large refactoring changes to fix warnings + +## Security + +- **No hardcoded secrets**: Never commit credentials, connection strings, SAS tokens, or API keys. +- **Bandit scanning**: All code must pass `azpysdk bandit .` static security analysis. +- **CredScan compliance**: Use `devtools_testutils.fake_credentials` in tests. Test proxy sanitizes secrets in recordings automatically. +- **Environment variables**: All credentials and connection strings via environment variables (`FOUNDRY_PROJECT_ENDPOINT`, `APPLICATIONINSIGHTS_CONNECTION_STRING`, etc.). 
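As an illustration of the environment-variable rule, the sketch below reads a required setting and fails fast when it is missing. The helper name `require_env` is hypothetical and not part of any package in this family; the error message deliberately names only the variable, never its value.

```python
import os


def require_env(name: str) -> str:
    """Return a required environment variable, or raise if it is unset or blank.

    Hypothetical helper for illustration; the value is never logged or echoed.
    """
    value = os.environ.get(name, "").strip()
    if not value:
        raise RuntimeError(f"Required environment variable {name} is not set")
    return value
```

A caller would then write `endpoint = require_env("FOUNDRY_PROJECT_ENDPOINT")` instead of embedding the endpoint in source.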
+ +## Automation Boundaries + +### Safe Operations (AI agents and automation) +✅ Generate SDK code from TypeSpec specifications +✅ Run linting and static analysis tools +✅ Fix code quality warnings (with high confidence) +✅ Update documentation (CHANGELOG, README) +✅ Create and update PRs in draft mode +✅ Run existing test suites + +### Restricted Operations (require review) +⚠️ Modifying generated code in `_generated/` +⚠️ Adding new dependencies +⚠️ Changing API signatures +⚠️ Disabling or removing tests +⚠️ Large-scale refactoring + +### Prohibited Operations +❌ Merging PRs without human review +❌ Releasing packages to PyPI without approval +❌ Committing secrets or credentials +❌ Force pushing to protected branches +❌ Modifying CI/CD pipeline definitions +❌ Changing security or authentication logic without security review + +## Governance + +This constitution governs all development within `sdk/agentserver`. All code changes (PRs, reviews, AI-generated code) must comply with these principles. Amendments require documentation and team review. + +- Principle II (Strong Type Safety) is non-negotiable — no exceptions for convenience. +- All release-blocking quality gates (pylint, mypy, sphinx, tests) must pass before merge. +- Breaking API changes require a version bump and CHANGELOG entry. +- Reference the [Azure SDK Python Design Guidelines](https://azure.github.io/azure-sdk/python_design.html) as the authoritative source for any questions not covered here. +- For detailed tooling instructions, see the [Tool Usage Guide](https://github.com/Azure/azure-sdk-for-python/blob/main/doc/tool_usage_guide.md) and [CONTRIBUTING.md](https://github.com/Azure/azure-sdk-for-python/blob/main/CONTRIBUTING.md). 
+ +**Version**: 1.7.0 | **Ratified**: 2026-05-22 | **Amended**: 2026-06-25 (Spec 034 — terminology reframe: Principle X renamed to "Resilience Contract Conformance", Principle XII prose reframed to task/resilience vocabulary, path references repointed; minor version bump for the renamed principle) diff --git a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md index 54b64fd3e6c1..3c0f3f4b9cd2 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md @@ -1,5 +1,320 @@ # Release History +## 2.0.0b7 (Unreleased) + +### Resilient-task primitive redesign + +The resilient-task primitive is reshaped on this release. The +authoritative behavior contract lives at +[`docs/task-and-streaming-spec.md`](docs/task-and-streaming-spec.md). +Highlights: + +- **Two decorators** — `@task` (one-shot) and `@multi_turn_task` (chain). + `@multi_turn_task` produces a distinct public `MultiTurnTask` class + (not a subclass of `Task`). Every `return X` is one turn (implicit + suspend); the chain stays alive in `suspended` between + turns until `MultiTurnTask.delete(task_id)` removes it. +- **`TaskRun` slim shape** — `task_id`, `input_id`, + `metadata`, `result()`, `cancel()`, `__await__`. `status`, `delete`, + `refresh`, `lease_expiry_count` are removed. +- **`TaskRun.result` returns raw `Output`**. The `TaskResult` + wrapper class is deleted. +- **`TaskContext.input_id`** — per-turn id for multi-turn, + defaults to `task_id` for one-shot 1:1 invariant. +- **New `TaskDeferred` exception** raised by + `ctx.exit_for_recovery()`. Semantically distinct from `TaskCancelled`. +- **Public exception taxonomy reshape**: exceptions no + longer carry `task_id`. `TaskFailed(error=...)`, + `TaskConflictError(current_status=...)`, + `LastInputIdPreconditionFailed(actual_last_input_id=...)` carry only + their respective field. 
`TaskCancelled`, `TaskDeferred`, + `SteeringQueueFull`, `InputTooLarge` are bare. +- **New typed-payload + value-type aliases**: `JSONValue` (recursive + Union for `TaskMetadata` values), `TaskErrorDict`, + `TaskExhaustedRetriesErrorDict`. +- **Auto-gen `task_id`** for one-shot `Task.start` / `Task.run` when + caller does not supply one. Multi-turn `task_id` remains + mandatory. +- **`if_last_input_id=`** precondition on both one-shot and + multi-turn `.start` / `.run`. Raises + `LastInputIdPreconditionFailed(actual_last_input_id=...)` on + mismatch. +- **Reserved metadata namespace**: `ctx.metadata("_X")` raises + `ValueError` (leading underscore reserved for the framework). +- **Handler signature validation**: first parameter MUST be + named `ctx`. +- **Structured failure log** — `resilient_task_handler_failure` + ERROR event with `task_id`/`input_id`/`error_type`/`error_message` + fields emitted on every handler failure. +- **Multi-turn raise → `suspended`** — chain stays + alive; queued steerers promote. +- **Multi-turn success → `suspended`** — `return X` is + implicit suspend; chain stays alive. + +### Removed from public surface + +- `TaskResult` wrapper class — deleted entirely. `await + run.result()` returns raw `Output`. +- `Suspended` sentinel — removed from public surface. Multi-turn + uses `return X` instead. +- `TaskSnapshot` + `Task.get(task_id)` — both removed. Use + `manager.provider.get(task_id)` directly for read-only inspection. +- `Task.options` — removed from public surface. +- Public `OutputTooLarge`, `TaskNotFound`, `TaskPreconditionFailed`, + `TaskStatus` — removed. The classes remain + internal-only in `_exceptions.py` for framework wiring. +- `TaskRun.delete()`, `.refresh()`, `.status`, `.lease_expiry_count` — + removed. For chain-level delete use + `MultiTurnTask.delete(task_id)`. +- `/tasks/resume` HTTP route + `TaskManager.handle_resume` — + resume happens via `.start()` / `.run()` against a suspended task. 
+- `payload["output"]` / `payload["error"]` writes — never persisted. + The framework no longer projects success/failure + state into the record's payload. +- `ephemeral=` decorator kwarg — one-shot is always ephemeral; + multi-turn never is. Transitionally emits a `DeprecationWarning`. +- `steerable=` on `@task` — same transitional warning. +- `ctx.suspend` — removed from the multi-turn contract. + Method body remains during the transition window for legacy callers. + + +### Features Added + +- **Unified local-development storage layout via + `azure.ai.agentserver.core.storage_paths`.** New public module + exposing `resolve_state_root()` and `resolve_state_subdir(kind)` + for the layout + `${AGENTSERVER_STATE_ROOT:-~/.agentserver}/{tasks,streams,responses}/`. + A single `AGENTSERVER_STATE_ROOT` env-var replaces the previous + per-subsystem path overrides; the per-subsystem env vars are gone. + Hosted environments are unaffected — the local-dev layout exists + to keep the development loop self-contained without external + dependencies. + +- **`AGENTSERVER_TASKS_BACKEND` operator override.** Setting this + env var to `local` or `hosted` forces the task provider regardless + of `AgentConfig.is_hosted` autodetection. Useful for debugging + hosted-only scenarios on a local workstation without standing up + the hosted task API, or for hosted environments where operators + want to opt out of the task-storage API in favour of on-disk + persistence. Unknown values raise `ValueError` at provider-create. + +- **Public read API: `Task.get(task_id) -> TaskSnapshot | None`** — + read-only introspection for any non-deleted task in any status + (pending, in_progress, suspended, completed). Returns ``None`` + for missing tasks (does NOT raise ``TaskNotFound``). Never + reclaims, never extends the lease, never PATCHes. Mirrors the + instance-method shape of ``Task.get_active_run`` as its + read-only sibling. 
+ + New public type ``TaskSnapshot`` exposes only developer-facing + fields (``task_id``, ``status``, ``created_at``, ``updated_at``, + ``started_at``, ``completed_at``, ``output``, ``error``, + ``suspension_reason``, ``metadata``, ``lease_expiry_count``). + Framework-internal storage details (lease, etag, raw payload, + raw attachments, source, tags) are deliberately excluded. + + ```python + snap = await my_task.get("task-123") + if snap is None: + ... # never existed or was deleted + else: + print(snap.status, snap.output, snap.error) + ``` + +- **Per-output payloads up to 2 MB** for both `return` values from + resilient-task handlers and `ctx.suspend(output=...)` values. Outputs + are stored entirely in a framework-managed attachment slot, so they + never compete with the shared 1 MB task-payload budget. New + developer-facing exception: + + | Limit | Value | Exception | + |---|---|---| + | Per-output maximum size (serialized JSON) | **2 MB** | `OutputTooLarge` | + + Like `InputTooLarge`, the check runs client-side **before** any + network call. If you have a use case that genuinely needs > 2 MB + per output, externalize it (write to blob storage, return a + reference). + +- **Per-input payloads up to 2 MB** for both the initial function + input and each queued steering input. Pass arbitrarily large input + values to `Task.start(...)` (up to the 2 MB ceiling) and the + framework handles persistence transparently. + + New limits + exceptions: + + | Limit | Value | Exception | + |---|---|---| + | Per-input maximum size (serialized JSON) | **2 MB** | `InputTooLarge` | + | Maximum queued steering inputs | **9** | `SteeringQueueFull` | + + All limits are enforced client-side **before** any network call, so + failures surface as typed Python exceptions, not opaque HTTP errors. + + Public API surface unchanged — handlers see `ctx.input` as the + deserialized value regardless of input size. 
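A client-side size check of the kind described above might look like the following sketch. It is not the framework's implementation: `InputTooLargeSketch` and `check_input_size` are hypothetical names, reading "2 MB" as 2 × 1024 × 1024 bytes is an assumption, and so is measuring the UTF-8 length of `json.dumps` output.

```python
import json
from typing import Any

# Assumption: "2 MB" is interpreted as 2 MiB of serialized UTF-8 JSON.
MAX_INPUT_BYTES = 2 * 1024 * 1024


class InputTooLargeSketch(ValueError):
    """Hypothetical stand-in for the framework's InputTooLarge exception."""


def check_input_size(value: Any, limit: int = MAX_INPUT_BYTES) -> int:
    """Serialize value to JSON and raise before any network call if it exceeds limit.

    Returns the serialized size in bytes when the value fits.
    """
    size = len(json.dumps(value).encode("utf-8"))
    if size > limit:
        raise InputTooLargeSketch(f"serialized input is {size} bytes; limit is {limit}")
    return size
```

Running the check before any I/O is what lets an oversized input surface as a typed Python exception rather than an HTTP error.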
+ +### Breaking Changes + +- **`EventStreamGoneError` removed** from + `azure.ai.agentserver.core.streaming`. + This release collapses the previously-distinct `Gone` (registered then + destroyed) and `NotFound` (never registered) error types into a + single `EventStreamNotFoundError`. Every "this id is not + currently a live stream" condition — never-registered, + explicitly-deleted, or close-clock-TTL elapsed — now raises + `EventStreamNotFoundError` and wire-maps to HTTP 404. The + previous distinction's actionable value at the consumer's layer + was zero (the right behavior is the same either way) and it leaked + the registry's internal tombstone bookkeeping. + +- **Replay-backing tombstone is now time-deterministic, not + buffer-state-driven.** This release replaces the previous + "Closed + buffer empty + had emit" auto-transition with a + close-clock model: when a replay backing (`ReplayEventStream` + or `FileBackedReplayEventStream` configured with `ttl_seconds`) + is closed, the registry tombstones the id at the wall-clock + moment `close_time + ttl_seconds`, regardless of who is + observing. Per-event TTL eviction continues to run during ACTIVE + to bound long-running stream memory. + +- `AttachmentTooLarge` and `AttachmentLimitExceeded` are no longer + exported from `azure.ai.agentserver.core.tasks`. Attachments are + a framework storage-layer concept that developers never name; + surfacing the attachment-vocabulary errors on the developer API + leaked the internal split between `payload` and `attachments`.
The + framework now catches the internal `_AttachmentTooLarge` raised by + a provider and re-raises a developer-facing exception based on + which channel the violation occurred on: + + - `payload["input"]` (or steering inputs) → `InputTooLarge` + - handler return / `ctx.suspend(output=...)` → `OutputTooLarge` + +- **Unified streaming primitive** — new `azure.ai.agentserver.core.streaming` + subpackage exposing a `streams` registry singleton + `EventStream` + Protocol + three exception types. The registry is the single + process-level lifecycle owner; pick a backing once at app startup + via one of three strongly-typed configurators: + + ```python + streams.use_in_memory_live() # default — multicast, no buffer + streams.use_in_memory_replay(cursor_fn=..., ttl_seconds=600) + streams.use_file_backed_replay(storage_dir=..., ttl_seconds=600) + ``` + + Then anywhere in the process: `stream = await streams.get_or_create(id)` + where `id` is the **per-turn / per-invocation identifier** + (`invocation_id` for invocations, `response_id` for responses). + Subscribers attach via `async for ev in stream.subscribe(after=N)`. + Streaming is now fully decoupled from `@task` — handlers explicitly + opt in by calling the registry. See + [`docs/streaming-guide.md`](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/agentserver/azure-ai-agentserver-core/docs/streaming-guide.md) + for the full developer guide, including tombstone retention, + per-turn id convention, and exception/wire mapping. + + Public surface = 5 exports: `streams`, `EventStream`, + `EventStreamError`, `EventStreamClosedError`, + `EventStreamNotFoundError`. (`EventStreamGoneError` + was removed; see Breaking Changes above.) The three
The three + SDK-bundled backings are selected at app startup via the + registry's `use_in_memory_live()` / + `use_in_memory_replay(...)` / `use_file_backed_replay(...)` config- + urators; external callers obtain stream instances exclusively via + `await streams.get_or_create(id)` and program against the Protocol. + +- **Resilient tasks** — new `@task` decorator and supporting types + (`TaskContext`, `TaskResult`, `TaskRun`, `RetryPolicy`, + `TaskConflictError`, `TaskFailed`, `TaskCancelled`) for + crash-resilient long-running agents. Tasks survive container + restarts, OOM kills, and redeployments; the framework re-enters the + handler with `ctx.entry_mode == "recovered"` and a populated + `ctx.metadata` after a crash. Supports multi-turn suspend/resume via + `ctx.suspend()`, cooperative cancel via `ctx.cancel`, per-turn + wall-clock timeout via `@task(timeout=...)`, and steering of in-flight + tasks via `@task(steerable=True)`. For streaming, handlers use the + new `streams` registry (above) — `@task` itself has no streaming- + related kwarg. See the + [developer guide](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/agentserver/azure-ai-agentserver-core/docs/tasks-guide.md) + for the full API and patterns reference. + +### Other Changes + +- **Local file provider parity with the hosted task service.** + The local file-backed task provider used in dev mode now enforces + the same validation, state machine, lease semantics, attachment + rules, and list-filter surface as the hosted task service. 
This + closes silent "works locally, fails in service" divergences: + + - Field validation: task id regex (`^[a-zA-Z0-9_-]{1,128}$`), + required `agent_name` / `session_id` / `title` on create, tag key + regex (`^[a-zA-Z0-9_.\-]{1,64}$`) + max 16 entries + max 256 char + values, payload ≤ 1 MB, error ≤ 64 KB, source ≤ 4 KB, + suspension_reason ≤ 256 chars, `source.type` required when source + supplied, `"failed"` status rejected, `"done"` legacy alias + normalized to `"completed"`, attachment key regex. + - State machine: full `pending` ⇄ `in_progress` ⇄ `suspended` → + `completed` transition matrix enforcement; terminal-task + immutability (PATCH on `completed` rejected except no-op + `completed → completed`); immutable fields on PATCH (`id`, + `agent_name`, `session_id`, `title`, `description`, `source`); + `suspension_reason` only allowed with `status=suspended`; DELETE + on non-terminal task without `force=true` rejected; DELETE honors + `If-Match`. + - Lease: duration must be 0 (force-expire) or 10..3600; + `(lease_owner, lease_instance_id, lease_duration_seconds)` are + all-or-nothing; different-owner takeover when the existing lease + is live is rejected; `in_progress → pending` requires matching + lease; lease renewal only allowed on `in_progress`; force-expire + cannot combine with status change and requires lease ownership + unless already expired; `expiry_count` bumps on different-owner + takeover when the prior lease was expired; `started_at` is + **immutable** after the first `in_progress` transition (lease + re-acquisition, recovery scanner takeover, and suspend/resume + cycles MUST all preserve the original value); new `heartbeat_at` + field stamped on every lease write. + - Status-transition side effects: transitions to / from each state + now clear / set the right combination of `lease`, + `suspension_reason`, `started_at`, `completed_at`. 
+ - PATCH semantics: `payload` patch branches on type (object → + shallow merge, non-object → full replace; previously assumed dict). + - Attachments: per-key null-as-delete (existing) plus new + top-level clear-all gesture via `TaskPatchRequest.clear_attachments` + flag (mirrors the service's `attachments: null` wire form). + - List filters: `has_error`, `lease_expired`, `omit_attachment_values` + added; pagination via `after` cursor + `limit` (default 20, max + 100); `order` accepts `"asc"` / `"desc"` by `created_at`; + `before` parameter rejected (forward-only cursor pagination); + status filter normalizes `"done"` → `"completed"`; `agent_name` + and `session_id` are now optional (workspace-wide listing). + +- **Hosted provider distinguishes service error codes internally.** + The hosted task service now returns distinct error + codes (`task_immutable`, `invalid_state_transition`, + `lease_held_by_another`, `task_already_exists`, + `lease_ownership_changed`, `etag_mismatch`, `invalid_request`). + The framework's response classifier now dispatches on these so + retry-able codes (`etag_mismatch`, `lease_ownership_changed`) + are retried transparently, while terminal conflicts surface as + the appropriate developer-facing `TaskConflictError` / + `TaskPreconditionFailed`. **No new developer-visible exception + types** — internal dispatch is fully absorbed inside the + framework. Existing `except TaskConflictError:` callers keep + working unchanged. + +- The hosted task-store transport is now built on + `azure.core.AsyncPipelineClient` instead of `httpx` / `aiohttp`; + neither `httpx` nor `aiohttp` is a production dependency of this + package anymore. + +- **Removed the `samples/` directory.** The standalone in-process + samples (`resilient_retry`, `resilient_streaming`, `selfhosted_invocation`) + have been deleted.
End-to-end usage of the `@task` and streaming + primitives is demonstrated in the runnable HTTP-host samples shipped + with `azure-ai-agentserver-invocations` and + `azure-ai-agentserver-responses`, which match how the primitives + are actually consumed in production. + ## 2.0.0b6 (2026-06-12) ### Bugs Fixed diff --git a/sdk/agentserver/azure-ai-agentserver-core/MANIFEST.in b/sdk/agentserver/azure-ai-agentserver-core/MANIFEST.in index 15a42f74dc4b..f5b3b843b000 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/MANIFEST.in +++ b/sdk/agentserver/azure-ai-agentserver-core/MANIFEST.in @@ -1,7 +1,6 @@ include *.md include LICENSE recursive-include tests *.py -recursive-include samples *.py *.md include azure/__init__.py include azure/ai/__init__.py include azure/ai/agentserver/__init__.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/README.md b/sdk/agentserver/azure-ai-agentserver-core/README.md index add29e0bb57b..aebd32783546 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/README.md +++ b/sdk/agentserver/azure-ai-agentserver-core/README.md @@ -113,6 +113,29 @@ export APPLICATIONINSIGHTS_CONNECTION_STRING="InstrumentationKey=..." python my_agent.py ``` +### Resilient long-running agents + +The `@task` decorator builds crash-resilient agents that survive container restarts, OOM kills, and redeployments. Task state is persisted to a task store, enabling automatic recovery and multi-turn suspend/resume patterns. + +```python +from azure.ai.agentserver.core.tasks import task, TaskContext + +@task +async def process_document(ctx: TaskContext[dict]) -> dict: + # ctx.entry_mode is "fresh" | "resumed" | "recovered". + # The framework re-invokes the handler from the top after a + # crash; ctx.input survives, so the handler picks up. 
+ summary = await analyze(ctx.input["document_url"]) + return {"summary": summary} + +result = await process_document.run( + task_id="doc-42", input={"document_url": "..."}, +) +print(result.output) # {"summary": "..."} +``` + +See the [Developer Guide](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/agentserver/azure-ai-agentserver-core/docs/tasks-guide.md) for streaming, multi-turn suspend/resume, retries, timeouts, steering, and the patterns reference. + ## Troubleshooting ### Logging @@ -130,6 +153,7 @@ To report an issue with the client library, or request additional features, plea ## Next steps - Install [`azure-ai-agentserver-invocations`](https://pypi.org/project/azure-ai-agentserver-invocations/) to add the invocation protocol endpoints. +- Read the [Resilient Task Developer Guide](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/agentserver/azure-ai-agentserver-core/docs/tasks-guide.md) for crash-resilient long-running agents. - See the [container image spec](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/agentserver) for the full hosted agent contract. 
## Contributing diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py index d360a00966a8..084b47871b31 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py @@ -17,17 +17,19 @@ end_span, flush_spans, record_error, + read_request_id, set_current_span, trace_stream, ) """ + __path__ = __import__("pkgutil").extend_path(__path__, __name__) from ._base import AgentServerHost from ._config import AgentConfig from ._errors import create_error_response from ._middleware import InboundRequestLoggingMiddleware -from ._request_id import RequestIdMiddleware +from ._request_id import RequestIdMiddleware, read_request_id from ._server_version import build_server_version from ._tracing import ( configure_observability, @@ -52,6 +54,7 @@ "end_span", "flush_spans", "record_error", + "read_request_id", "set_current_span", "trace_stream", ] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py index 84a7ccd06c24..1e67eed1d1cc 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py @@ -37,6 +37,27 @@ _NOT_SET = "(not set)" +def _read_task_manager_shutdown_grace() -> float: + """Return TaskManager shutdown grace in seconds (env-driven, default 25.0). + + Reads ``AGENTSERVER_SHUTDOWN_GRACE_SECONDS``. Defaults to 25.0 when + unset. Allows tests (and operators) to keep shutdown fast when no + long-running resilient handlers need to checkpoint — for example the + conformance suite runs with a 1s grace so the in-process shutdown + marker fires before the handler completes naturally. 
+ + :return: Grace period in seconds (non-negative). + :rtype: float + """ + raw = os.environ.get("AGENTSERVER_SHUTDOWN_GRACE_SECONDS") + if raw is None: + return 25.0 + try: + return max(0.0, float(raw)) + except ValueError: + return 25.0 + + def _mask_uri(uri: str) -> str: """Return only the scheme and host of a URI, hiding path/query/credentials. @@ -84,9 +105,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def _send_with_header(message: MutableMapping[str, Any]) -> None: if message["type"] == "http.response.start": headers = list(message.get("headers", [])) - headers.append( - (b"x-platform-server", self._get_server_version().encode()) - ) + headers.append((b"x-platform-server", self._get_server_version().encode())) message = {**message, "headers": headers} await send(message) @@ -160,7 +179,7 @@ class MyHost(InvocationAgentServerHost, ResponsesAgentServerHost): _DEFAULT_ACCESS_LOG_FORMAT = '%(h)s "%(r)s" %(s)s %(b)s %(D)sμs' - def __init__( + def __init__( # pylint: disable=too-many-statements self, *, applicationinsights_connection_string: Optional[str] = None, @@ -174,14 +193,20 @@ def __init__( ) -> None: # Shutdown handler slot (server-level lifecycle) ------------------- self._shutdown_fn: Optional[Callable[[], Awaitable[None]]] = None + # Pre-shutdown callbacks invoked SYNCHRONOUSLY from the + # SIGTERM signal handler — before Hypercorn's graceful drain + # begins. Used by responses to set ``_shutdown_requested`` early so + # foreground handlers' disconnect-poll loop sees the shutdown + # signal BEFORE Hypercorn waits for in-flight requests to complete. + # Callbacks must be non-blocking and thread-safe (they run in the + # signal handler, not on the event loop). + self._pre_shutdown_callbacks: list[Callable[[], None]] = [] # Server version segments for the x-platform-server header. # Protocol packages call register_server_version() to add their # own portion; the middleware joins them at response time. 
self._server_version_segments: list[str] = [] - self.register_server_version( - build_server_version("azure-ai-agentserver-core", _CORE_VERSION) - ) + self.register_server_version(build_server_version("azure-ai-agentserver-core", _CORE_VERSION)) # Resolved configuration (accessible as self.config) self.config: _config.AgentConfig = _config.AgentConfig.from_env() @@ -203,15 +228,11 @@ def __init__( logger.warning("Failed to initialize observability; continuing without it.", exc_info=True) # Access logging --------------------------------------------------- - self._access_log: Optional[logging.Logger] = ( - logger if access_log is _SENTINEL_ACCESS_LOG else access_log - ) + self._access_log: Optional[logging.Logger] = logger if access_log is _SENTINEL_ACCESS_LOG else access_log self._access_log_format: str = access_log_format or self._DEFAULT_ACCESS_LOG_FORMAT # Timeouts --------------------------------------------------------- - self._graceful_shutdown_timeout = _config.resolve_graceful_shutdown_timeout( - graceful_shutdown_timeout - ) + self._graceful_shutdown_timeout = _config.resolve_graceful_shutdown_timeout(graceful_shutdown_timeout) # Build lifespan context manager @contextlib.asynccontextmanager @@ -244,6 +265,27 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF protocols, ) + # --- Resilient task manager auto-initialization --- + task_manager = None + try: + from .tasks._manager import ( # pylint: disable=import-outside-toplevel + TaskManager, + set_task_manager, + ) + + task_manager = TaskManager( + config=cfg, + shutdown_event=asyncio.Event(), + shutdown_grace_seconds=_read_task_manager_shutdown_grace(), + ) + set_task_manager(task_manager) + await task_manager.startup() + logger.info("TaskManager initialized automatically") + except ImportError: + pass # resilient module not available + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Failed to initialize TaskManager", exc_info=True) + yield # 
--- SHUTDOWN: runs once when the server is stopping --- @@ -251,6 +293,14 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF "AgentServerHost shutting down (graceful timeout=%ss)", self._graceful_shutdown_timeout, ) + + # Run on_shutdown FIRST so the responses layer's + # ``handle_shutdown`` can set ``_shutdown_requested`` and signal + # cancellation BEFORE the TaskManager waits its grace period. + # Without this, Row 3 (foreground) handlers can race against + # Hypercorn's client-connection close — the disconnect-poll loop + # stamps ``CLIENT_CANCELLED`` instead of ``SHUTTING_DOWN`` and + # B11 emits a cancelled terminal instead of failed. if self._graceful_shutdown_timeout == 0: logger.info("Graceful shutdown drain period disabled (timeout=0)") else: @@ -267,6 +317,21 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF except Exception: # pylint: disable=broad-exception-caught logger.warning("Error in on_shutdown", exc_info=True) + # Shutdown task manager AFTER on_shutdown so resilient handlers + # have had time to checkpoint via the responses layer's + # ``handle_shutdown``. + if task_manager is not None: + try: + await task_manager.shutdown() + from .tasks._manager import ( # pylint: disable=import-outside-toplevel + set_task_manager as _clear_manager, + ) + + _clear_manager(None) + logger.info("TaskManager shut down") + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Error shutting down TaskManager", exc_info=True) + # Merge routes: subclass routes (if any) + health endpoint all_routes: list[Any] = list(routes or []) all_routes.append( @@ -293,6 +358,7 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF # (e.g. by MAF / agent-framework) are children of the caller's trace. # We do NOT create a SERVER span ourselves — we only propagate context. 
from azure.ai.agentserver.core._tracing import TraceContextMiddleware # pylint: disable=import-outside-toplevel + self.add_middleware(TraceContextMiddleware) # ------------------------------------------------------------------ @@ -352,6 +418,31 @@ def shutdown_handler(self, fn: Callable[[], Awaitable[None]]) -> Callable[[], Aw self._shutdown_fn = fn return fn + def register_pre_shutdown_callback(self, fn: Callable[[], None]) -> None: + """Register a synchronous callback to run on SIGTERM signal receipt. + + Callbacks run from inside the SIGTERM signal handler, + BEFORE Hypercorn begins its graceful drain. Use this to + set asyncio events that long-running request handlers observe via + their cancellation-polling loops, so they can return before + Hypercorn waits the full ``graceful_shutdown_timeout`` for the + request to complete. + + Callbacks MUST be non-blocking and signal-safe — they execute + synchronously on the main thread inside the signal handler. The + typical pattern is:: + + shutdown_event = asyncio.Event() + app.register_pre_shutdown_callback(shutdown_event.set) + + Note: ``asyncio.Event.set()`` is safe to call from a signal + handler when the event loop is running on the same thread. + + :param fn: A synchronous, non-blocking callable. + :type fn: Callable[[], None] + """ + self._pre_shutdown_callbacks.append(fn) + async def _dispatch_shutdown(self) -> None: """Dispatch to the registered shutdown handler, or no-op.""" if self._shutdown_fn is not None: @@ -403,23 +494,42 @@ def run(self, host: str = "0.0.0.0", port: Optional[int] = None) -> None: logger.info("AgentServerHost starting on %s:%s", host, resolved_port) config = self._build_hypercorn_config(host, resolved_port) - # Register SIGTERM handler to log the signal and initiate - # Hypercorn's graceful shutdown. 
- original_sigterm = signal.getsignal(signal.SIGTERM) - - def _handle_sigterm(_signum: int, _frame: Any) -> None: - logger.info("SIGTERM received, initiating graceful shutdown") - # Restore the original handler so the re-raised signal is not - # caught by this handler again (avoids infinite recursion). - signal.signal(signal.SIGTERM, original_sigterm) - os.kill(os.getpid(), signal.SIGTERM) - - signal.signal(signal.SIGTERM, _handle_sigterm) - - try: - asyncio.run(_hypercorn_serve(self, config)) # type: ignore[arg-type] - finally: - signal.signal(signal.SIGTERM, original_sigterm) + async def _serve_with_shutdown_trigger() -> None: + """Wrap hypercorn.serve with a custom shutdown_trigger. + + When Hypercorn's default ``shutdown_trigger=None`` + is used, Hypercorn registers its own SIGTERM/SIGINT handler + via ``loop.add_signal_handler`` and our ``signal.signal`` + handler is overridden. We register our own + ``loop.add_signal_handler`` here and pass the resulting wait + as ``shutdown_trigger`` so Hypercorn uses our event — and we + get to fire pre-shutdown callbacks synchronously on signal + receipt, before Hypercorn begins its graceful drain. + """ + loop = asyncio.get_event_loop() + signal_event = asyncio.Event() + + def _on_signal() -> None: + # Run pre-shutdown callbacks BEFORE setting the event so + # they fire before Hypercorn begins draining connections. + for cb in self._pre_shutdown_callbacks: + try: + cb() + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Pre-shutdown callback raised", exc_info=True) + signal_event.set() + + for signal_name in ("SIGINT", "SIGTERM", "SIGBREAK"): + if hasattr(signal, signal_name): + try: + loop.add_signal_handler(getattr(signal, signal_name), _on_signal) + except NotImplementedError: + # Windows fallback — install via signal.signal directly. 
+ signal.signal(getattr(signal, signal_name), lambda *_: _on_signal()) + + await _hypercorn_serve(self, config, shutdown_trigger=signal_event.wait) # type: ignore[arg-type] + + asyncio.run(_serve_with_shutdown_trigger()) async def run_async(self, host: str = "0.0.0.0", port: Optional[int] = None) -> None: """Start the server asynchronously (awaitable). diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py index 493f58794776..6399cbb74e35 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_config.py @@ -128,8 +128,7 @@ def from_env(cls) -> Self: project_id=os.environ.get(_ENV_FOUNDRY_PROJECT_ARM_ID, ""), session_id=os.environ.get(_ENV_FOUNDRY_AGENT_SESSION_ID, ""), port=resolve_port(None), - appinsights_connection_string=os.environ.get( - _ENV_APPLICATIONINSIGHTS_CONNECTION_STRING, ""), + appinsights_connection_string=os.environ.get(_ENV_APPLICATIONINSIGHTS_CONNECTION_STRING, ""), otlp_endpoint=os.environ.get(_ENV_OTEL_EXPORTER_OTLP_ENDPOINT, ""), sse_keepalive_interval=resolve_sse_keepalive_interval(None), ws_ping_interval=resolve_ws_ping_interval(), @@ -151,9 +150,7 @@ def _parse_int_env(var_name: str) -> Optional[int]: try: return int(raw) except ValueError as exc: - raise ValueError( - f"Invalid value for {var_name}: {raw!r} (expected an integer)" - ) from exc + raise ValueError(f"Invalid value for {var_name}: {raw!r} (expected an integer)") from exc def _require_int(name: str, value: object) -> int: @@ -168,9 +165,7 @@ def _require_int(name: str, value: object) -> int: :raises ValueError: If *value* is not an integer. 
""" if isinstance(value, bool) or not isinstance(value, int): - raise ValueError( - f"Invalid value for {name}: {value!r} (expected an integer)" - ) + raise ValueError(f"Invalid value for {name}: {value!r} (expected an integer)") return value @@ -186,9 +181,7 @@ def _validate_port(value: int, source: str) -> int: :raises ValueError: If the port is outside 1-65535. """ if not 1 <= value <= 65535: - raise ValueError( - f"Invalid value for {source}: {value} (expected 1-65535)" - ) + raise ValueError(f"Invalid value for {source}: {value} (expected 1-65535)") return value @@ -212,18 +205,32 @@ def resolve_port(port: Optional[int]) -> int: _DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT = 30 +_ENV_GRACEFUL_SHUTDOWN_TIMEOUT = "AGENTSERVER_GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS" def resolve_graceful_shutdown_timeout(timeout: Optional[int]) -> int: - """Resolve the graceful shutdown timeout from argument or default. + """Resolve the graceful shutdown timeout from argument, env var, or default. + + Resolution order: + 1. Explicit ``timeout`` argument (constructor / programmatic). + 2. ``AGENTSERVER_GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS`` env var. + 3. Default of 30 seconds. + + Lower values force Hypercorn to cancel in-flight connections sooner + on SIGTERM — useful for tests / operators that want shutdown handlers + (in-process markers, resilient task checkpoints) to fire before + long-running requests complete naturally. :param timeout: Explicitly requested timeout or None. :type timeout: Optional[int] - :return: The resolved timeout in seconds (default 30). + :return: The resolved timeout in seconds. 
:rtype: int """ if timeout is not None: return max(0, _require_int("graceful_shutdown_timeout", timeout)) + env_val = _parse_int_env(_ENV_GRACEFUL_SHUTDOWN_TIMEOUT) + if env_val is not None: + return max(0, env_val) return _DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT @@ -249,9 +256,7 @@ def resolve_appinsights_connection_string( """ if connection_string is not None: return connection_string - return os.environ.get( - _ENV_APPLICATIONINSIGHTS_CONNECTION_STRING - ) + return os.environ.get(_ENV_APPLICATIONINSIGHTS_CONNECTION_STRING) def resolve_log_level(level: Optional[str]) -> str: @@ -268,10 +273,7 @@ def resolve_log_level(level: Optional[str]) -> str: else: normalized = "INFO" if normalized not in _VALID_LOG_LEVELS: - raise ValueError( - f"Invalid log level: {normalized!r} " - f"(expected one of {', '.join(_VALID_LOG_LEVELS)})" - ) + raise ValueError(f"Invalid log level: {normalized!r} " f"(expected one of {', '.join(_VALID_LOG_LEVELS)})") return normalized @@ -409,12 +411,10 @@ def resolve_ws_ping_interval() -> float: resolved = float(env_raw) except ValueError as exc: raise ValueError( - f"Invalid value for {_ENV_WS_KEEPALIVE_INTERVAL}: " - f"{env_raw!r} (expected a non-negative number)" + f"Invalid value for {_ENV_WS_KEEPALIVE_INTERVAL}: " f"{env_raw!r} (expected a non-negative number)" ) from exc if math.isnan(resolved) or math.isinf(resolved) or resolved < 0.0: raise ValueError( - f"Invalid value for {_ENV_WS_KEEPALIVE_INTERVAL}: " - f"{env_raw!r} (expected a non-negative finite number)" + f"Invalid value for {_ENV_WS_KEEPALIVE_INTERVAL}: " f"{env_raw!r} (expected a non-negative finite number)" ) return resolved diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py index c5b1c9e01efe..9268e24df81c 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py +++ 
b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_errors.py @@ -58,6 +58,4 @@ def create_error_response( body["type"] = error_type if details is not None: body["details"] = details - return JSONResponse( - {"error": body}, status_code=status_code, headers=headers - ) + return JSONResponse({"error": body}, status_code=status_code, headers=headers) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py index 4fb3fe78a9cd..63b0d320a771 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py @@ -76,7 +76,9 @@ def _get_trace_id(headers: list[tuple[bytes, bytes]] | None = None) -> str | Non :rtype: str | None """ try: - from opentelemetry import trace as _trace # pylint: disable=import-outside-toplevel + from opentelemetry import ( + trace as _trace, + ) # pylint: disable=import-outside-toplevel span = _trace.get_current_span() ctx = span.get_span_context() @@ -147,7 +149,10 @@ async def _send_wrapper(message: MutableMapping[str, Any]) -> None: elapsed_ms = (time.monotonic() - start) * 1000 logger.warning( "Inbound %s %s failed with status 500 in %.1fms%s", - method, path, elapsed_ms, extra_str, + method, + path, + elapsed_ms, + extra_str, ) raise @@ -156,10 +161,18 @@ async def _send_wrapper(message: MutableMapping[str, Any]) -> None: if status_code is not None and status_code >= 400: logger.warning( "Inbound %s %s completed with status %d in %.1fms%s", - method, path, status_code, elapsed_ms, extra_str, + method, + path, + status_code, + elapsed_ms, + extra_str, ) else: logger.info( "Inbound %s %s completed with status %s in %.1fms%s", - method, path, status_code, elapsed_ms, extra_str, + method, + path, + status_code, + elapsed_ms, + extra_str, ) diff --git 
a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_request_id.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_request_id.py index 8c900ecb2320..95d87dfd35b0 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_request_id.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_request_id.py @@ -18,7 +18,7 @@ from __future__ import annotations import uuid -from typing import Any, MutableMapping +from typing import Any, Mapping, MutableMapping from starlette.types import ASGIApp, Receive, Scope, Send @@ -28,6 +28,25 @@ REQUEST_ID_STATE_KEY = "agentserver.request_id" +def read_request_id(scope: "Mapping[str, Any]") -> "str | None": + """Return the request ID resolved by :class:`RequestIdMiddleware`. + + Reads the value the middleware stored in the ASGI ``scope["state"]`` so + protocol packages can correlate a request without depending on the internal + state-key name. Returns ``None`` when the middleware is not installed or the + value is absent. + + :param scope: The ASGI scope (or any mapping carrying a ``state`` dict). + :type scope: Mapping[str, Any] + :return: The resolved ``x-request-id`` value, or ``None``. + :rtype: str | None + """ + state = scope.get("state") + if isinstance(state, dict): + return state.get(REQUEST_ID_STATE_KEY) + return None + + class RequestIdMiddleware: """Pure-ASGI middleware that sets ``x-request-id`` on every HTTP response. @@ -65,9 +84,7 @@ async def _send_with_request_id(message: MutableMapping[str, Any]) -> None: if message["type"] == "http.response.start": # Filter any existing x-request-id to avoid duplicates, then add ours. 
headers = [ - (name, value) - for name, value in message.get("headers", []) - if name.lower() != b"x-request-id" + (name, value) for name, value in message.get("headers", []) if name.lower() != b"x-request-id" ] headers.append((b"x-request-id", request_id.encode())) message = {**message, "headers": headers} diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py index b0ed26bbeda1..924ef06ad695 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py @@ -182,8 +182,10 @@ def _configure_tracing( span_processors = [ _FoundryEnrichmentSpanProcessor( - agent_name=agent_name, agent_version=agent_version, - agent_id=agent_id, project_id=project_id, + agent_name=agent_name, + agent_version=agent_version, + agent_id=agent_id, + project_id=project_id, agent_blueprint_id=agent_blueprint_id, agent_tenant_id=agent_tenant_id, ), @@ -247,10 +249,9 @@ def _setup_distro_export( kwargs["azure_monitor_connection_string"] = connection_string # A365 tracing export — enabled only in hosted environments. 
- if ( - os.environ.get("FOUNDRY_HOSTING_ENVIRONMENT", "") - and os.environ.get("FOUNDRY_AGENT365_TRACING_ENABLED", "").lower() in ("true", "1") - ): + if os.environ.get("FOUNDRY_HOSTING_ENVIRONMENT", "") and os.environ.get( + "FOUNDRY_AGENT365_TRACING_ENABLED", "" + ).lower() in ("true", "1"): kwargs["enable_a365"] = True kwargs["a365_use_s2s_endpoint"] = True kwargs["a365_enable_observability_exporter"] = True @@ -290,20 +291,20 @@ async def __call__(self, scope: Any, receive: Any, send: Any) -> None: # Build a simple dict of headers for the propagators raw_headers: list[tuple[bytes, bytes]] = scope.get("headers", []) - headers = { - k.decode("latin-1"): v.decode("latin-1") - for k, v in raw_headers - } + headers = {k.decode("latin-1"): v.decode("latin-1") for k, v in raw_headers} # Use the global propagator to extract trace context + baggage from opentelemetry.propagate import extract # pylint: disable=import-outside-toplevel + ctx = extract(carrier=headers) # Add x-request-id as baggage for downstream propagation x_request_id = headers.get("x-request-id") if x_request_id: ctx = _otel_baggage.set_baggage( - "x_request_id", x_request_id, context=ctx, + "x_request_id", + x_request_id, + context=ctx, ) token = _otel_context.attach(ctx) @@ -419,9 +420,7 @@ def detach_context(token: Any) -> None: ) -async def trace_stream( - iterator: AsyncIterable[_Content], span: Any -) -> AsyncIterator[_Content]: +async def trace_stream(iterator: AsyncIterable[_Content], span: Any) -> AsyncIterator[_Content]: """Wrap a streaming body so the span covers the full transmission. Yields chunks unchanged. 
Ends the span when the iterator is diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py index 2577b81a5658..369f0dcc3bea 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py @@ -2,4 +2,4 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -VERSION = "2.0.0b6" +VERSION = "2.0.0b7" diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/platform_headers.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/platform_headers.py new file mode 100644 index 000000000000..06411a3dc1fe --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/platform_headers.py @@ -0,0 +1,46 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Public platform HTTP header / wire-contract constants. + +These constants form the wire contract between the Foundry platform, agent +containers, and downstream storage services. They are shared across the +AgentServer protocol packages (e.g. ``azure-ai-agentserver-responses`` and +``azure-ai-agentserver-invocations``), which compose on top of this core +package; this module is the supported public surface for those constants. + +See the module-level documentation of each constant for its wire semantics. 
+""" +from __future__ import annotations + +from ._platform_headers import ( + APIM_REQUEST_ID, + CHAT_ISOLATION_KEY, + CLIENT_HEADER_PREFIX, + CLIENT_REQUEST_ID, + ERROR_DETAIL, + ERROR_SOURCE, + MAX_ERROR_DETAIL_LENGTH, + PLATFORM_ERROR_TAG, + REQUEST_ID, + SERVER_VERSION, + SESSION_ID, + TRACEPARENT, + USER_ISOLATION_KEY, +) + +__all__ = [ + "APIM_REQUEST_ID", + "CHAT_ISOLATION_KEY", + "CLIENT_HEADER_PREFIX", + "CLIENT_REQUEST_ID", + "ERROR_DETAIL", + "ERROR_SOURCE", + "MAX_ERROR_DETAIL_LENGTH", + "PLATFORM_ERROR_TAG", + "REQUEST_ID", + "SERVER_VERSION", + "SESSION_ID", + "TRACEPARENT", + "USER_ISOLATION_KEY", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/storage_paths.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/storage_paths.py new file mode 100644 index 000000000000..c7e13f6be1de --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/storage_paths.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +"""Unified storage paths for agentserver state subsystems. + +Public module — both ``azure-ai-agentserver-core`` (resilient tasks) and +``azure-ai-agentserver-responses`` (response store + stream store) resolve +their on-disk storage locations through this single helper. The unified +layout is:: + + <root>/ + tasks/ ← resilient task records (core) + streams/ ← SSE event store (responses) + responses/ ← response object store (responses) + +where ``<root>`` is ``${AGENTSERVER_STATE_ROOT:-~/.agentserver}``. + +The single env var ``AGENTSERVER_STATE_ROOT`` controls the root for +all three subdirectories — there is intentionally no per-subdir override. +Operators wanting per-subdir paths should symlink the desired locations +into the root.
+ +This layout replaces the pre-migration per-subsystem +env vars: + + - ``AGENTSERVER_STATE_TASKS_PATH`` (was: ``~/.agentserver-tasks/``) + - ``AGENTSERVER_STREAM_STORE_PATH`` (was: ``/agentserver_streams``) + - ``AGENTSERVER_RESPONSE_STORE_PATH`` (was: no default; required for non-mem store) + +All three legacy env vars are deleted (not deprecated). The unified +``AGENTSERVER_STATE_ROOT`` is the only operator knob. +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Literal + +# Public type alias for the kinds of storage subdirectories the agentserver +# state subsystems own. +StateSubdir = Literal["tasks", "streams", "responses"] + +# Default root when ``AGENTSERVER_STATE_ROOT`` is unset. +_DEFAULT_ROOT_RELATIVE = ".agentserver" + +# Env var that overrides the root. Single var covers all subdirs. +STATE_ROOT_ENV_VAR = "AGENTSERVER_STATE_ROOT" + +# The full set of valid subdirectory kinds. +_VALID_SUBDIRS: frozenset[str] = frozenset({"tasks", "streams", "responses"}) + + +def resolve_state_root() -> Path: + """Resolve the root directory for agentserver state storage. + + Returns ``Path(os.environ['AGENTSERVER_STATE_ROOT'])`` if the env + var is set; otherwise ``Path.home() / ".agentserver"``. + + :returns: The resolved root path. + :rtype: Path + """ + env_value = os.environ.get(STATE_ROOT_ENV_VAR) + if env_value: + return Path(env_value) + return Path.home() / _DEFAULT_ROOT_RELATIVE + + +def resolve_state_subdir(kind: StateSubdir) -> Path: + """Resolve the on-disk path for a specific state storage subdirectory. + + :param kind: One of ``"tasks"`` (core), ``"streams"`` (responses), + ``"responses"`` (responses). + :type kind: StateSubdir + :returns: The resolved path (relative if ``AGENTSERVER_STATE_ROOT`` is relative). Created lazily on first write + by the caller — this helper does not mkdir. + :rtype: Path + :raises ValueError: If ``kind`` is not one of the valid subdir kinds.
+ """ + if kind not in _VALID_SUBDIRS: + raise ValueError(f"Unknown resilient subdir kind: {kind!r}. " f"Valid kinds: {sorted(_VALID_SUBDIRS)}") + return resolve_state_root() / kind + + +__all__ = [ + "StateSubdir", + "STATE_ROOT_ENV_VAR", + "resolve_state_root", + "resolve_state_subdir", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/__init__.py new file mode 100644 index 000000000000..b68544c732bb --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/__init__.py @@ -0,0 +1,34 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unified streaming primitive — :class:`EventStream` Protocol + +``streams`` registry. + +Pick a backing once at app startup via one of the registry's three +``use_*`` configurators, then obtain stream instances anywhere in +your process via ``await streams.get_or_create(id)`` and program +against the :class:`EventStream` Protocol. + +See ``docs/streaming-guide.md`` for the developer guide (registry +API, backings, per-turn id convention, exception/wire mapping, +third-party-impl peer-registry pattern). 
+""" + +from __future__ import annotations + +from ._protocol import ( + EventStream, + EventStreamClosedError, + EventStreamError, + EventStreamNotFoundError, +) +from ._registry import streams + + +__all__ = [ + "streams", + "EventStream", + "EventStreamError", + "EventStreamClosedError", + "EventStreamNotFoundError", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/_concrete.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/_concrete.py new file mode 100644 index 000000000000..28f68f8ea8d8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/_concrete.py @@ -0,0 +1,762 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""SDK-bundled :class:`~._protocol.EventStream` implementations. + +This module is SDK-private (underscore-prefixed). External callers +obtain instances exclusively via the ``streams`` registry's three +``use_*`` configurators. This private import path is reserved for +SDK-internal tests (impl-specific assertions like file lock +detection, corruption recovery, per-event TTL eviction observability, +and broadcast no-buffer semantics). Consumer packages MUST NOT use +this private path. +""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +import json +import os +import time +from collections.abc import AsyncIterator, Callable +from pathlib import Path +from typing import Any, Optional + +from ._protocol import ( + EventStream, + EventStreamClosedError, + EventStreamNotFoundError, +) + +# Try POSIX fcntl; fall back to a lock-file scheme on platforms +# without it (Windows). Per streaming.md rule 32. 
+try: + import fcntl # type: ignore[import-not-found] + + _HAS_FCNTL = True +except ImportError: # pragma: no cover - windows + _HAS_FCNTL = False + + +# --------------------------------------------------------------- +# Internal sentinels + state markers +# --------------------------------------------------------------- + +_GONE_SENTINEL: object = object() +"""Pushed to subscriber queues to signal end-of-stream. + +Either close (drain remaining items then terminate cleanly) or +registry-driven delete (immediate cutoff — raise StopAsyncIteration +on next __anext__). The subscriber loop distinguishes by checking +self._state when it sees the sentinel. +""" + + +# --------------------------------------------------------------- +# Common base — state model + per-subscriber-queue fan-out +# --------------------------------------------------------------- + + +class _BaseEventStream: + """Shared state machine + subscriber fan-out for bundled impls. + + Concrete subclasses override ``emit`` / ``close`` / ``subscribe`` + / ``last_cursor`` and the private ``_on_delete`` cleanup hook. + + State model: per-instance + states are exactly ``ACTIVE`` and ``CLOSED``. The ``GONE`` value + is retained as an internal flag the registry sets when it + tombstones the id — operations on a stale instance reference + after registry tombstone raise :class:`EventStreamNotFoundError`. + Per-instance: ``construction → ACTIVE``, + ``close() from ACTIVE → CLOSED``, ``close()`` from ``CLOSED`` / + tombstoned → no-op (idempotent). + + Close-clock TTL tombstone: replay backings + with ``ttl_seconds`` configured record ``_close_time`` when + transitioning to ``CLOSED``. From that moment the SEMANTIC + tombstone deadline is ``close_time + ttl_seconds`` — operations + after the deadline raise :class:`EventStreamNotFoundError`.
+ Replaces the legacy "buffer empty + had emit" rule, which was + observer-driven and required a ``total_emit_count > 0`` + carve-out for never-emitted closed streams. + """ + + _STATE_ACTIVE = "ACTIVE" + _STATE_CLOSED = "CLOSED" + # Internal-only — set by the registry when it tombstones the id. + # External callers MUST NOT depend on this; the documented + # contract is "operation raises EventStreamNotFoundError". + _STATE_GONE = "GONE" + + def __init__(self) -> None: + self._state: str = self._STATE_ACTIVE + self._subscriber_queues: list[asyncio.Queue[Any]] = [] + self._lock = asyncio.Lock() + # Wall-clock time the stream transitioned + # to CLOSED; used by replay backings to compute the close-clock + # tombstone deadline (close_time + ttl_seconds). + self._close_time: Optional[float] = None + + async def _register_subscriber(self) -> asyncio.Queue[Any]: + q: asyncio.Queue[Any] = asyncio.Queue() + self._subscriber_queues.append(q) + return q + + def _remove_subscriber(self, q: asyncio.Queue[Any]) -> None: + # Best-effort removal; safe to call even if the queue is + # already absent (rule 15 — one event-loop-tick cleanup). + try: + self._subscriber_queues.remove(q) + except ValueError: + pass + + async def _fanout_emit(self, payload: Any) -> None: + """Push to every currently-attached subscriber queue.""" + for q in list(self._subscriber_queues): + await q.put(payload) + + async def _fanout_terminate(self) -> None: + """Push end-of-stream sentinel to every subscriber.""" + for q in list(self._subscriber_queues): + await q.put(_GONE_SENTINEL) + + +# --------------------------------------------------------------- +# BroadcastEventStream — live-only, no buffer +# --------------------------------------------------------------- + + +class BroadcastEventStream(_BaseEventStream): + """Multicast + no buffer + live-only. + + See ``streaming.md`` §5.1 +. Subscribers see only events emitted **after** they attach.
Constant memory overhead — only + the currently-attached subscriber list is retained. + + No ``cursor_fn``, no ``ttl_seconds``, no ``subscribe(after=...)`` + support (silently ignored). No CLOSED → GONE auto-transition + (nothing evicts). + """ + + async def emit(self, payload: Any, *, close: bool = False) -> None: + async with self._lock: + if self._state == self._STATE_GONE: + raise EventStreamNotFoundError("stream id is tombstoned") + if self._state == self._STATE_CLOSED: + raise EventStreamClosedError("stream is CLOSED") + await self._fanout_emit(payload) + if close: + self._state = self._STATE_CLOSED + self._close_time = time.time() + await self._fanout_terminate() + + async def close(self) -> None: + async with self._lock: + if self._state != self._STATE_ACTIVE: + return # idempotent no-op + self._state = self._STATE_CLOSED + self._close_time = time.time() + await self._fanout_terminate() + + def subscribe(self, *, after: Optional[int] = None) -> AsyncIterator[Any]: + del after # silently ignored per rule 17 — no buffer to seek + if self._state == self._STATE_GONE: + raise EventStreamNotFoundError("stream id is tombstoned") + return _BroadcastIterator(self, terminated=self._state == self._STATE_CLOSED) + + async def last_cursor(self) -> Optional[int]: + if self._state == self._STATE_GONE: + raise EventStreamNotFoundError("stream id is tombstoned") + return None # no cursor tracking + + async def _on_delete(self) -> None: + async with self._lock: + self._state = self._STATE_GONE + await self._fanout_terminate() + + +class _BroadcastIterator: + """Per-subscriber iterator for :class:`BroadcastEventStream`.""" + + def __init__(self, owner: BroadcastEventStream, *, terminated: bool = False) -> None: + self._owner = owner + self._queue: Optional[asyncio.Queue[Any]] = None + self._terminated = terminated + + def __aiter__(self) -> "_BroadcastIterator": + # Attach at __aiter__ so the subscriber is registered before + # the first __anext__ returns (rule for "attach" 
definition, + # / streaming.md §4.3). Skip if pre-terminated (stream + # was already CLOSED at subscribe() time). + if self._queue is None and not self._terminated: + q: asyncio.Queue[Any] = asyncio.Queue() + self._owner._subscriber_queues.append(q) + self._queue = q + return self + + async def __anext__(self) -> Any: + if self._terminated: + raise StopAsyncIteration + if self._queue is None: + self._queue = await self._owner._register_subscriber() + try: + item = await self._queue.get() + if item is _GONE_SENTINEL: + self._terminated = True + self._owner._remove_subscriber(self._queue) + raise StopAsyncIteration + return item + except (asyncio.CancelledError, GeneratorExit): + if self._queue is not None: + self._owner._remove_subscriber(self._queue) + raise + + def __del__(self) -> None: # rule 15 — subscriber cleanup on GC + if self._queue is not None: + try: + self._owner._remove_subscriber(self._queue) + except Exception: # pylint: disable=broad-except + pass + + +# --------------------------------------------------------------- +# Replay buffer entry — used by ReplayEventStream and +# FileBackedReplayEventStream +# --------------------------------------------------------------- + + +class _BufferedEvent: + """A buffered payload + its ``emit_time`` for TTL eviction.""" + + __slots__ = ("payload", "emit_time") + + def __init__(self, payload: Any, emit_time: float) -> None: + self.payload = payload + self.emit_time = emit_time + + +# --------------------------------------------------------------- +# ReplayEventStream — in-memory replay buffer + per-event TTL +# --------------------------------------------------------------- + + +class ReplayEventStream(_BaseEventStream): + """In-memory replay + optional cursor + optional per-event TTL. + + See ``streaming.md`` §5.2 +. Multi-subscriber. Buffers + every emit in memory subject to per-event TTL eviction. Supports + ``subscribe(after=...)`` iff ``cursor_fn`` is supplied. 
+ """ + + def __init__( + self, + *, + cursor_fn: Optional[Callable[[Any], int]] = None, + ttl_seconds: Optional[float] = None, + ) -> None: + super().__init__() + self._cursor_fn = cursor_fn + self._ttl_seconds = ttl_seconds + self._buffer: list[_BufferedEvent] = [] + self._highest_cursor: Optional[int] = None + + def _evict_expired(self, *, now: Optional[float] = None) -> None: + """Drop expired entries from the head of the buffer. + + Per-event TTL semantics: each event expires at + ``emit_time + ttl_seconds`` independently of close/open + state (rules 22-24). In-flight per-subscriber queue items + are NOT recalled (rule 24). + """ + if self._ttl_seconds is None: + return + if now is None: + now = time.time() + cutoff = now - self._ttl_seconds + i = 0 + while i < len(self._buffer) and self._buffer[i].emit_time < cutoff: + i += 1 + if i > 0: + del self._buffer[:i] + + def _maybe_auto_transition_to_gone(self) -> None: + """/ C-STR-TTL-2 — close-clock auto-tombstone. + + When the stream is ``CLOSED`` AND ``ttl_seconds`` is configured + AND ``now >= close_time + ttl_seconds``, transition to GONE. + Replaces the legacy "CLOSED + buffer empty + had emit" rule. + Deterministic and time-driven; NOT observer-driven or + buffer-state-driven. + + Called from operations that observe the transition + (``subscribe`` / ``emit``). Per spec §46 / C-STR-TTL-2, + ``last_cursor`` MUST NOT call this — see ``last_cursor``. 
+ """ + if ( + self._state == self._STATE_CLOSED + and self._ttl_seconds is not None + and self._close_time is not None + and time.time() >= self._close_time + self._ttl_seconds + ): + self._state = self._STATE_GONE + + async def emit(self, payload: Any, *, close: bool = False) -> None: + async with self._lock: + self._evict_expired() + self._maybe_auto_transition_to_gone() + if self._state == self._STATE_GONE: + raise EventStreamNotFoundError("stream id is tombstoned") + if self._state == self._STATE_CLOSED: + raise EventStreamClosedError("stream is CLOSED") + emit_time = time.time() + self._buffer.append(_BufferedEvent(payload, emit_time)) + if self._cursor_fn is not None: + cursor = self._cursor_fn(payload) + if self._highest_cursor is None or cursor > self._highest_cursor: + self._highest_cursor = cursor + await self._fanout_emit(payload) + if close: + self._state = self._STATE_CLOSED + self._close_time = time.time() + await self._fanout_terminate() + + async def close(self) -> None: + async with self._lock: + if self._state != self._STATE_ACTIVE: + return # idempotent + self._state = self._STATE_CLOSED + self._close_time = time.time() + await self._fanout_terminate() + + def subscribe(self, *, after: Optional[int] = None) -> AsyncIterator[Any]: + # rule 17: silently ignore `after` if no cursor_fn + if self._cursor_fn is None: + after = None + # Trigger eviction + GONE check before deciding whether to raise + self._evict_expired() + self._maybe_auto_transition_to_gone() + if self._state == self._STATE_GONE: + raise EventStreamNotFoundError("stream id is tombstoned") + return _ReplayIterator(self, after=after) + + async def last_cursor(self) -> Optional[int]: + # rule 8: do NOT trigger auto-transition; only evict-and-check + # whether the state has been changed by some prior call. 
+ if self._state == self._STATE_GONE: + raise EventStreamNotFoundError("stream id is tombstoned") + return self._highest_cursor + + async def _on_delete(self) -> None: + async with self._lock: + self._state = self._STATE_GONE + self._buffer.clear() + await self._fanout_terminate() + + +class _ReplayIterator: + """Per-subscriber iterator for :class:`ReplayEventStream`. + + Replays history (subject to ``after`` cursor + per-event TTL) on + first ``__anext__``, then yields live events from a per- + subscriber queue. + """ + + def __init__(self, owner: ReplayEventStream, *, after: Optional[int] = None) -> None: + self._owner = owner + self._after = after + self._queue: Optional[asyncio.Queue[Any]] = None + self._history_buffer: list[Any] = [] + self._history_index = 0 + self._attached = False + self._terminated = False + + def _attach(self) -> None: + # Snapshot history + register live subscriber atomically + # under the owner's lock context (we approximate by reading + # the buffer before adding the queue — subsequent emits land + # in our queue, NOT into our history snapshot, so we don't + # duplicate). 
+ owner = self._owner + owner._evict_expired() + if owner._cursor_fn is not None and self._after is not None: + for entry in owner._buffer: + if owner._cursor_fn(entry.payload) > self._after: + self._history_buffer.append(entry.payload) + else: + self._history_buffer = [e.payload for e in owner._buffer] + self._queue = asyncio.Queue() + owner._subscriber_queues.append(self._queue) + self._attached = True + + def __aiter__(self) -> "_ReplayIterator": + if not self._attached and not self._terminated: + self._attach() + return self + + async def __anext__(self) -> Any: + if not self._attached and not self._terminated: + self._attach() + if self._terminated: + raise StopAsyncIteration + # Check if owner has transitioned to GONE via registry-delete + # (immediate cutoff) + if self._owner._state == self._owner._STATE_GONE: + self._terminated = True + if self._queue is not None: + self._owner._remove_subscriber(self._queue) + raise StopAsyncIteration + # Drain history first + if self._history_index < len(self._history_buffer): + item = self._history_buffer[self._history_index] + self._history_index += 1 + return item + # If stream was already CLOSED at attach time and queue is + # empty, terminate cleanly + if ( + self._owner._state in (self._owner._STATE_CLOSED, self._owner._STATE_GONE) + and self._queue is not None + and self._queue.empty() + ): + self._terminated = True + self._owner._remove_subscriber(self._queue) + raise StopAsyncIteration + # Live phase + assert self._queue is not None + try: + item = await self._queue.get() + if item is _GONE_SENTINEL: + self._terminated = True + self._owner._remove_subscriber(self._queue) + raise StopAsyncIteration + return item + except (asyncio.CancelledError, GeneratorExit): + if self._queue is not None: + self._owner._remove_subscriber(self._queue) + raise + + def __del__(self) -> None: + if self._queue is not None: + try: + self._owner._remove_subscriber(self._queue) + except Exception: # pylint: disable=broad-except + pass + 
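The attach step in `_ReplayIterator` (evict expired entries, then snapshot history filtered by the `after` cursor) can be sketched in isolation. This is a minimal illustration under stated assumptions, not the SDK's code: `evict_expired` and `history_snapshot` are hypothetical helper names, and the buffer is modelled as plain `(emit_time, payload)` tuples rather than `_BufferedEvent` objects.

```python
from typing import Any, Callable, Optional


def evict_expired(
    buffer: list[tuple[float, Any]], ttl_seconds: Optional[float], now: float
) -> list[tuple[float, Any]]:
    """Return the buffer without head entries older than ``now - ttl_seconds``."""
    if ttl_seconds is None:
        return list(buffer)  # no TTL configured: nothing expires
    cutoff = now - ttl_seconds
    i = 0
    # Entries are appended in emit order, so expired ones sit at the head.
    while i < len(buffer) and buffer[i][0] < cutoff:
        i += 1
    return buffer[i:]


def history_snapshot(
    buffer: list[tuple[float, Any]],
    cursor_fn: Optional[Callable[[Any], int]],
    after: Optional[int],
) -> list[Any]:
    """Payloads a new subscriber replays before switching to live events."""
    if cursor_fn is None or after is None:
        # Without a cursor function, ``after`` is ignored: replay everything.
        return [payload for _, payload in buffer]
    # Strictly greater than ``after``, mirroring the comparison in ``_attach``.
    return [payload for _, payload in buffer if cursor_fn(payload) > after]


# Hypothetical sample data: three events emitted at t=100, 105, 110.
events = [(100.0, {"seq": 1}), (105.0, {"seq": 2}), (110.0, {"seq": 3})]
live = evict_expired(events, ttl_seconds=8.0, now=112.0)
replayed = history_snapshot(live, cursor_fn=lambda p: p["seq"], after=2)
```

With an 8-second TTL evaluated at `now=112.0`, the event emitted at `100.0` has already expired, so a subscriber resuming with `after=2` replays only the event whose cursor is 3.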
+ +# --------------------------------------------------------------- +# FileBackedReplayEventStream — resilient, jsonl, single-writer +# --------------------------------------------------------------- + + +_TERMINAL_MARKER = "__terminal__" +"""Field name signalling a terminal-record on disk (rule 27).""" + +_COMPACTION_INTERVAL = 1000 +"""Compact on-disk file after this many evictions (rule 30). Chosen +default; documented in Phase 1 PR per T028.""" + + +class FileBackedReplayEventStream(_BaseEventStream): + """File-backed multicast + replay + cursor + per-event TTL. + + See ``streaming.md`` §5.3 + + rules 26-32. Persists every + emit to ``path`` before fan-out (persist-before-publish). + Rehydrates from disk on construction. Single-writer-per-path + enforced via ``fcntl.flock``. + """ + + def __init__( + self, + *, + path: Path, + cursor_fn: Optional[Callable[[Any], int]] = None, + ttl_seconds: Optional[float] = None, + serializer: Optional[Callable[[Any], bytes]] = None, + deserializer: Optional[Callable[[bytes], Any]] = None, + ) -> None: + super().__init__() + self._path = Path(path) + self._cursor_fn = cursor_fn + self._ttl_seconds = ttl_seconds + self._serializer = serializer + self._deserializer = deserializer + self._buffer: list[_BufferedEvent] = [] + self._highest_cursor: Optional[int] = None + self._evictions_since_compaction = 0 + + # Acquire single-writer lock + open file for append (rule 32). + self._path.parent.mkdir(parents=True, exist_ok=True) + # Open in append+read mode; fcntl.flock on POSIX, lock-file fallback elsewhere. + self._file = open(self._path, "a+b") # pylint: disable=consider-using-with + if _HAS_FCNTL: + try: + fcntl.flock(self._file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError as exc: + self._file.close() + raise RuntimeError( + f"FileBackedReplayEventStream: another process holds the " f"lock on {self._path}" + ) from exc + else: + # Windows fallback: best-effort lock-file approach. 
+ lock_path = self._path.with_suffix(self._path.suffix + ".lock") + try: + self._lock_fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_RDWR) + self._lock_path = lock_path + except FileExistsError as exc: + self._file.close() + raise RuntimeError( + f"FileBackedReplayEventStream: another process holds the " f"lock-file on {self._path}" + ) from exc + + # Rehydrate from disk if file already had content (rule 28). + self._rehydrate() + + def _serialize(self, payload: Any, emit_time: float) -> bytes: + if self._serializer is not None: + inner = self._serializer(payload) + wrapper = {"emit_time": emit_time, "payload": inner.decode("utf-8") if isinstance(inner, bytes) else inner} + else: + wrapper = {"emit_time": emit_time, "payload": payload} + return (json.dumps(wrapper) + "\n").encode("utf-8") + + def _serialize_terminal(self, emit_time: float) -> bytes: + return (json.dumps({"emit_time": emit_time, _TERMINAL_MARKER: True}) + "\n").encode("utf-8") + + def _deserialize_record(self, line: bytes) -> dict: + record = json.loads(line.decode("utf-8")) + if self._deserializer is not None and "payload" in record: + record["payload"] = self._deserializer( + record["payload"].encode("utf-8") + if isinstance(record["payload"], str) + else json.dumps(record["payload"]).encode("utf-8") + ) + return record + + def _rehydrate(self) -> None: + self._file.seek(0) + data = self._file.read() + if not data: + return + lines = data.split(b"\n") + # Trailing partial: silent discard (rule 29). + if lines and lines[-1] != b"": + # Last line lacks \n — partial. Drop it. + lines = lines[:-1] + # Truncate the file to remove the partial trailing. 
+ self._file.seek(0, os.SEEK_END) + self._file.truncate(self._file.tell() - len(data) + sum(len(l) + 1 for l in lines)) + else: + lines = [l for l in lines if l] + had_terminal = False + terminal_seen_at: Optional[int] = None + records: list[dict] = [] + for idx, line in enumerate(lines): + try: + rec = self._deserialize_record(line) + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + # Mid-file malformed — RuntimeError at construction (rule 29). + self._cleanup_locks() + raise RuntimeError( + f"FileBackedReplayEventStream: malformed record at " f"line {idx} of {self._path}" + ) from exc + if "emit_time" not in rec: + self._cleanup_locks() + raise RuntimeError( + f"FileBackedReplayEventStream: record at line {idx} of " f"{self._path} missing 'emit_time' field" + ) + if rec.get(_TERMINAL_MARKER): + if had_terminal: + # Multiple terminals or terminal-not-at-EOF — malformed. + self._cleanup_locks() + raise RuntimeError( + f"FileBackedReplayEventStream: terminal marker not " f"at end-of-file in {self._path}" + ) + had_terminal = True + terminal_seen_at = idx + continue + if had_terminal: + # Records after terminal marker — malformed. + self._cleanup_locks() + raise RuntimeError( + f"FileBackedReplayEventStream: record at line {idx} of " f"{self._path} follows terminal marker" + ) + records.append(rec) + # Load into buffer, applying per-event TTL. + for rec in records: + entry = _BufferedEvent(rec["payload"], rec["emit_time"]) + self._buffer.append(entry) + if self._cursor_fn is not None: + cursor = self._cursor_fn(entry.payload) + if self._highest_cursor is None or cursor > self._highest_cursor: + self._highest_cursor = cursor + # Apply TTL eviction now (records may have expired since being written). + self._evict_expired() + if had_terminal: + self._state = self._STATE_CLOSED + # — close-clock is anchored at the + # terminal record's emit_time (the moment the prior + # process actually closed the stream). 
On rehydration we + # honor that wall-clock anchor so a process restart + # cannot extend the effective tombstone deadline. + if terminal_seen_at is not None: + self._close_time = json.loads(lines[terminal_seen_at])["emit_time"] + else: + self._close_time = time.time() + self._maybe_auto_transition_to_gone() + # Position file at end for subsequent appends. + self._file.seek(0, os.SEEK_END) + + def _cleanup_locks(self) -> None: + try: + if _HAS_FCNTL: + fcntl.flock(self._file.fileno(), fcntl.LOCK_UN) + else: + os.close(self._lock_fd) + self._lock_path.unlink(missing_ok=True) + except Exception: # pylint: disable=broad-except + pass + try: + self._file.close() + except Exception: # pylint: disable=broad-except + pass + + def _evict_expired(self) -> None: + if self._ttl_seconds is None: + return + now = time.time() + cutoff = now - self._ttl_seconds + i = 0 + while i < len(self._buffer) and self._buffer[i].emit_time < cutoff: + i += 1 + if i > 0: + del self._buffer[:i] + self._evictions_since_compaction += i + if self._evictions_since_compaction >= _COMPACTION_INTERVAL: + self._compact_on_disk() + self._evictions_since_compaction = 0 + + def _compact_on_disk(self) -> None: + """Rewrite the on-disk file to contain only surviving records. + + Lazy compaction (rule 30) — keeps the file bounded across + repeated process restarts. + """ + tmp_path = self._path.with_suffix(self._path.suffix + ".compact") + try: + with open(tmp_path, "wb") as tmp: + for entry in self._buffer: + tmp.write(self._serialize(entry.payload, entry.emit_time)) + if self._state == self._STATE_CLOSED: + tmp.write(self._serialize_terminal(self._close_time or time.time())) + # Atomic replace (POSIX guarantees atomicity on same fs).
+ os.replace(tmp_path, self._path) + # ``os.replace`` swapped ``self._path`` to a brand-new inode; our + # ``self._file`` handle still points at the old (now-unlinked) + # inode, so every subsequent ``emit``/``close`` write would land in + # the orphaned file and be lost on the next process lifetime (and + # the single-writer ``flock`` would be held on the dead inode). + # Reopen against the live path and re-acquire the lock. Open + lock + # the new handle BEFORE closing the old one so the single-writer + # guarantee is never released across the swap. + old_file = self._file + new_file = open(self._path, "a+b") # pylint: disable=consider-using-with + if _HAS_FCNTL: + fcntl.flock(new_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + new_file.seek(0, os.SEEK_END) + self._file = new_file + try: + old_file.close() + except Exception: # pylint: disable=broad-except + pass + except Exception: # pylint: disable=broad-except + try: + tmp_path.unlink(missing_ok=True) + except Exception: # pylint: disable=broad-except + pass + + def _maybe_auto_transition_to_gone(self) -> None: + """— close-clock auto-tombstone. + + Same rule as ReplayEventStream: CLOSED + ttl_seconds + configured + ``now >= close_time + ttl_seconds`` → GONE. + Replaces the legacy "CLOSED + buffer empty + had emit" rule. + """ + if ( + self._state == self._STATE_CLOSED + and self._ttl_seconds is not None + and self._close_time is not None + and time.time() >= self._close_time + self._ttl_seconds + ): + self._state = self._STATE_GONE + + async def emit(self, payload: Any, *, close: bool = False) -> None: + async with self._lock: + self._evict_expired() + self._maybe_auto_transition_to_gone() + if self._state == self._STATE_GONE: + raise EventStreamNotFoundError("stream id is tombstoned") + if self._state == self._STATE_CLOSED: + raise EventStreamClosedError("stream is CLOSED") + emit_time = time.time() + # Persist BEFORE fan-out (rule 26). For atomic emit+close + # (rule 14), write both records in one fsync. 
+ record_bytes = self._serialize(payload, emit_time) + if close: + record_bytes += self._serialize_terminal(emit_time) + self._file.write(record_bytes) + self._file.flush() + os.fsync(self._file.fileno()) + # Now update in-memory state + fan out + self._buffer.append(_BufferedEvent(payload, emit_time)) + if self._cursor_fn is not None: + cursor = self._cursor_fn(payload) + if self._highest_cursor is None or cursor > self._highest_cursor: + self._highest_cursor = cursor + await self._fanout_emit(payload) + if close: + self._state = self._STATE_CLOSED + self._close_time = time.time() + await self._fanout_terminate() + + async def close(self) -> None: + async with self._lock: + if self._state != self._STATE_ACTIVE: + return + self._file.write(self._serialize_terminal(time.time())) + self._file.flush() + os.fsync(self._file.fileno()) + self._state = self._STATE_CLOSED + self._close_time = time.time() + await self._fanout_terminate() + + def subscribe(self, *, after: Optional[int] = None) -> AsyncIterator[Any]: + if self._cursor_fn is None: + after = None + self._evict_expired() + self._maybe_auto_transition_to_gone() + if self._state == self._STATE_GONE: + raise EventStreamNotFoundError("stream id is tombstoned") + return _ReplayIterator(self, after=after) # same iterator shape works + + async def last_cursor(self) -> Optional[int]: + if self._state == self._STATE_GONE: + raise EventStreamNotFoundError("stream id is tombstoned") + return self._highest_cursor + + async def _on_delete(self) -> None: + async with self._lock: + self._state = self._STATE_GONE + self._buffer.clear() + await self._fanout_terminate() + self._cleanup_locks() + try: + self._path.unlink(missing_ok=True) + except Exception: # pylint: disable=broad-except + pass + + +__all__ = [ + "BroadcastEventStream", + "ReplayEventStream", + "FileBackedReplayEventStream", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/_protocol.py 
b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/_protocol.py new file mode 100644 index 000000000000..4d2f22560fe2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/_protocol.py @@ -0,0 +1,151 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""``EventStream`` Protocol and exception hierarchy. + +This module defines the data-flow surface only — lifecycle +(create / lookup / destroy) is the registry's responsibility +(``_registry.py``). See ``docs/streaming-guide.md`` for the developer +guide covering the registry API, backings, per-turn id convention, +and exception/wire mapping. +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any, Optional, Protocol, runtime_checkable + + +class EventStreamError(Exception): + """Base class for all ``EventStream``-raised exceptions. + + Lets callers use ``except EventStreamError`` to catch any of the + subclasses uniformly. + """ + + +class EventStreamClosedError(EventStreamError): + """Raised when ``emit()`` is called on an already-closed stream. + + The stream still exists; the caller cannot add more events. This + is a server-side bug (the producer kept emitting after closing) + and should be wire-mapped to 5xx, not 4xx. + """ + + +class EventStreamNotFoundError(EventStreamError): + """Raised when any operation references a stream id that is not + currently a live stream. + + This type unifies the previously distinct + ``EventStreamNotFoundError`` (never registered) and + ``EventStreamGoneError`` (registered then destroyed) into this + single error type.
Three independent reasons fire this: + + - the id was never registered (no ``get_or_create(id)`` ever ran) + - the id was explicitly ``streams.delete(id)``d + - the id's stream was Closed and its close-clock TTL + (``close_time + ttl_seconds``) elapsed, causing the registry + to auto-tombstone + + Collapsing the two error types simplifies the developer-facing + surface: either way, the right behavior is the same (subscribe to + a new id, or treat this id as missing). It also stops leaking the + registry's internal tombstone bookkeeping (whether an id was + "previously alive" or "never seen") into the public API. + + Wire-mapped to HTTP 404 Not Found. + """ + + +@runtime_checkable +class EventStream(Protocol): + """A multi-cast event stream. + + Four data-flow methods: :meth:`emit`, :meth:`close`, + :meth:`subscribe`, :meth:`last_cursor`. Lifecycle (create / + lookup / destroy) is the registry's job (``streams``); the + Protocol intentionally does NOT include a destructive method. + + See ``docs/streaming-guide.md`` for the developer guide. + """ + + async def emit(self, payload: Any, *, close: bool = False) -> None: + """Emit a payload to all currently-attached subscribers. + + :param payload: Opaque value. The framework never inspects, + validates, or rewrites it. + :param close: If ``True``, the emit and the close-of-stream + are observably atomic: every subscriber attached before + this call returns sees BOTH the payload AND the + end-of-stream signal; subscribers attached after see + neither. + + :raises EventStreamClosedError: If the stream has already + been closed. + :raises EventStreamNotFoundError: If the stream has been + destroyed. + """ + ... + + async def close(self) -> None: + """Transition the stream from active to closed. Idempotent. + + On an already-closed or destroyed stream, this is a no-op + (never raises). 
Subscribers attached at close time drain any + remaining queued items, then their iterators terminate + cleanly with ``StopAsyncIteration``. + """ + ... + + def subscribe(self, *, after: Optional[int] = None) -> AsyncIterator[Any]: + """Return an async iterator over emitted payloads. + + NOT a coroutine: call without ``await`` and immediately use + with ``async for`` / ``aiter()`` / ``anext()``. + + :param after: If supplied and the active backing supports + cursored replay, yield only payloads whose cursor value + is strictly greater than ``after``. Backings without + cursor support silently ignore non-``None`` values. + + :raises EventStreamNotFoundError: Raised synchronously at the + call site (before the iterator is returned) if the + stream has been destroyed. + """ + ... + + async def last_cursor(self) -> Optional[int]: + """Return the highest cursor seen so far, or ``None``. + + Semantics: + + - While the stream is active: the highest cursor value + persisted so far, or ``None`` if zero emits OR the active + backing has no cursor support. + - After the stream is closed: the last cursor the backing + ever saw, even if those events have since been evicted by + per-event TTL. ``last_cursor()`` is a read-only watermark + query and does not itself fire the close → destroy + auto-transition. This is load-bearing for the file-backed + replay rehydration path (handler reads ``last_cursor()`` + on entry to pick the next cursor). + - After the stream is destroyed (auto-transition has fired): + raises :class:`EventStreamNotFoundError`. + + ``last_cursor()`` is the **emitter's** recovery primitive. + It is NOT a workflow-recovery primitive — workflow + watermarks (what work is done) belong in ``ctx.metadata``, + batched per side-effecting operation. See + ``docs/streaming-guide.md`` for the metadata-vs-cursor + antipattern note. + """ + ... 
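The four-method contract above can be exercised with a small stand-in. The sketch below is illustrative only: `TinyLiveStream` and `_END` are hypothetical names, it does not import or reproduce the SDK's backings, and it registers the subscriber eagerly inside `subscribe()` for simplicity. It shows two points from the docstrings: `subscribe()` is called without `await`, and `emit(..., close=True)` delivers the payload followed by end-of-stream to subscribers that are already attached.

```python
import asyncio
from collections.abc import AsyncIterator
from typing import Any, Optional

_END = object()  # hypothetical end-of-stream sentinel for this sketch


class TinyLiveStream:
    """Hypothetical live-only stream shaped like the EventStream Protocol."""

    def __init__(self) -> None:
        self._queues: list[asyncio.Queue] = []
        self._closed = False

    async def emit(self, payload: Any, *, close: bool = False) -> None:
        if self._closed:
            raise RuntimeError("stream is closed")  # stand-in for EventStreamClosedError
        for q in list(self._queues):
            await q.put(payload)
        if close:
            await self.close()

    async def close(self) -> None:
        if self._closed:
            return  # idempotent no-op
        self._closed = True
        for q in list(self._queues):
            await q.put(_END)

    def subscribe(self, *, after: Optional[int] = None) -> AsyncIterator[Any]:
        del after  # no buffer, so there is nothing to seek
        q: asyncio.Queue = asyncio.Queue()
        if self._closed:
            q.put_nowait(_END)  # late subscriber: terminate immediately
        else:
            self._queues.append(q)

        async def _iterate() -> AsyncIterator[Any]:
            while True:
                item = await q.get()
                if item is _END:
                    return
                yield item

        return _iterate()

    async def last_cursor(self) -> Optional[int]:
        return None  # no cursor tracking in this sketch


async def demo() -> list[Any]:
    stream = TinyLiveStream()
    events = stream.subscribe()  # plain call, not awaited
    await stream.emit("a")
    await stream.emit("b", close=True)  # payload and close delivered together
    return [event async for event in events]
```

Running `asyncio.run(demo())` collects both payloads and then stops iterating, since the subscriber attached before the closing emit.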
+ + +__all__ = [ + "EventStream", + "EventStreamError", + "EventStreamClosedError", + "EventStreamNotFoundError", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/_registry.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/_registry.py new file mode 100644 index 000000000000..bfbf1ab93e8c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/streaming/_registry.py @@ -0,0 +1,243 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +""":data:`streams` registry — process-level lifecycle owner. + +Six methods: + +- Three async lifecycle methods: :meth:`_StreamsRegistry.get`, + :meth:`_StreamsRegistry.get_or_create`, + :meth:`_StreamsRegistry.delete`. +- Three sync configurators: :meth:`_StreamsRegistry.use_in_memory_live`, + :meth:`_StreamsRegistry.use_in_memory_replay`, + :meth:`_StreamsRegistry.use_file_backed_replay`. + +The registry is the lifecycle owner for the three SDK-bundled +backings. Third-party :class:`EventStream` impls do NOT plug into +this registry — they ship their own peer registry. + +``get(id)`` raises +:class:`EventStreamNotFoundError` for ANY id that is not currently +a live stream: never registered, explicitly deleted via :meth:`delete`, or +close-clock TTL elapsed. The registry retains tombstones for +deleted / auto-tombstoned ids primarily to support re-create-after- +delete semantics (a subsequent :meth:`get_or_create` clears the +tombstone and constructs a fresh stream), NOT to differentiate +the error type; there is only ONE error type for "id is missing". +All paths wire-map to HTTP 404.
+""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +from collections.abc import Callable +from pathlib import Path +from typing import Any, Optional, Union + +from ._concrete import ( + BroadcastEventStream, + FileBackedReplayEventStream, + ReplayEventStream, +) +from ._protocol import ( + EventStream, + EventStreamNotFoundError, +) + + +# Sentinel for tombstoned slots (rule 36a) +_TOMBSTONE: object = object() + + +class _StreamsRegistry: + """Implementation of the module-level :data:`streams` singleton. + + Do not instantiate directly — use the exported ``streams`` + instance. This is the SDK-private implementation type; the + public surface is the singleton + the six methods on it. + """ + + def __init__(self) -> None: + # Streams keyed by id; value is either an EventStream + # instance OR _TOMBSTONE for destroyed ids. + self._slots: dict[str, Union[EventStream, object]] = {} + # Per-id locks for get_or_create atomicity (rule 34). + self._id_locks: dict[str, asyncio.Lock] = {} + # Global lock guarding _slots + _id_locks structural mutations. + self._struct_lock = asyncio.Lock() + # Factory closure — set by use_* configurators. Default: + # use_in_memory_live per rule 37a (also). + self._factory: Callable[[str], EventStream] = lambda _id: BroadcastEventStream() + + # ----- Configurators (sync) ----- + + def use_in_memory_live(self) -> None: + """Configure the registry to construct in-memory **live** streams + (multicast, no replay buffer). Subscribers see events emitted + after they subscribe — late subscribers miss earlier events. + Suitable when consumers attach before the producer starts. + """ + self._factory = lambda _id: BroadcastEventStream() + + def use_in_memory_replay( + self, + *, + cursor_fn: Optional[Callable[[Any], int]] = None, + ttl_seconds: Optional[float] = None, + ) -> None: + """Configure the registry to construct in-memory **replay** streams. 
+ + Each stream retains its event history (subject to ``ttl_seconds`` + per-event TTL eviction, whether the stream is active or closed). Late + subscribers see the full retained history. Pass ``cursor_fn`` + to enable cursored re-subscription via ``subscribe(after=...)``. + """ + self._factory = lambda _id: ReplayEventStream(cursor_fn=cursor_fn, ttl_seconds=ttl_seconds) + + def use_file_backed_replay( + self, + *, + storage_dir: Path, + cursor_fn: Optional[Callable[[Any], int]] = None, + ttl_seconds: Optional[float] = None, + serializer: Optional[Callable[[Any], bytes]] = None, + deserializer: Optional[Callable[[bytes], Any]] = None, + ) -> None: + """Configure the registry to construct **file-backed replay** streams. + + Each stream persists its event log to + ``storage_dir / f"{id}.jsonl"`` and rehydrates on construction + if the file already exists (crash-recovery friendly). Same + replay + TTL + cursor semantics as :meth:`use_in_memory_replay`. + """ + storage_dir = Path(storage_dir) + storage_dir.mkdir(parents=True, exist_ok=True) + self._factory = lambda _id: FileBackedReplayEventStream( + path=storage_dir / f"{_id}.jsonl", + cursor_fn=cursor_fn, + ttl_seconds=ttl_seconds, + serializer=serializer, + deserializer=deserializer, + ) + + # ----- Lifecycle (async) ----- + + async def _get_id_lock(self, id: str) -> asyncio.Lock: + async with self._struct_lock: + lock = self._id_locks.get(id) + if lock is None: + lock = asyncio.Lock() + self._id_locks[id] = lock + return lock + + async def get(self, id: str) -> EventStream: + """Look up the existing instance for ``id``. + + Every "id is not currently a live + stream" condition raises :class:`EventStreamNotFoundError`: + + - Unregistered id (never seen). + - Explicitly :meth:`delete`d id (tombstoned). + - Closed stream whose close-clock TTL deadline has elapsed + (auto-tombstoned).
+ """ + slot = self._slots.get(id, None) + if slot is None: + raise EventStreamNotFoundError(id) + if slot is _TOMBSTONE: + raise EventStreamNotFoundError(id) + # — opportunistic close-clock check. + # If the stream's internal _maybe_auto_transition_to_gone + # would fire, install the registry tombstone now and raise + # NotFound. This makes the registry-level auto-tombstone + # observable even without an explicit emit/subscribe on the + # instance. + if await self._tombstone_if_close_clock_elapsed(id, slot): + raise EventStreamNotFoundError(id) + return slot # type: ignore[return-value] + + async def _tombstone_if_close_clock_elapsed(self, id: str, slot: Any) -> bool: + """If the stream's close-clock TTL elapsed, run its + ``_on_delete`` cleanup hook and install the registry + tombstone. Returns True iff the tombstone was installed. + + / — file-backed cleanup happens + BEFORE the registry tombstone install per C-STR-FBR-4. + """ + maybe_check = getattr(slot, "_maybe_auto_transition_to_gone", None) + if maybe_check is None: + return False + # Trigger the check; the instance may flip its state to GONE. + try: + async with getattr(slot, "_lock", asyncio.Lock()): + maybe_check() + except Exception: # pylint: disable=broad-except + return False + # Read the state attribute non-strictly. + if getattr(slot, "_state", None) != "GONE": + return False + # State has transitioned — perform cleanup + install tombstone. + on_delete = getattr(slot, "_on_delete", None) + if on_delete is not None: + try: + await on_delete() + except Exception: # pylint: disable=broad-except + pass + self._slots[id] = _TOMBSTONE + return True + + async def get_or_create(self, id: str) -> EventStream: + """Return cached instance for ``id``, or create a new one. + + Atomic across concurrent callers: a per-id lock prevents + split-brain construction when two coroutines race to create + the same id. A previously-destroyed id is cleared on + re-creation. 
+ """ + # Fast path — already present, not tombstoned + slot = self._slots.get(id, None) + if slot is not None and slot is not _TOMBSTONE: + return slot # type: ignore[return-value] + # Slow path — acquire per-id lock + create + lock = await self._get_id_lock(id) + async with lock: + slot = self._slots.get(id, None) + if slot is not None and slot is not _TOMBSTONE: + return slot # type: ignore[return-value] + instance = self._factory(id) + self._slots[id] = instance + return instance + + async def delete(self, id: str) -> None: + """Destroy the stream registered for ``id``. + + Idempotent — calling on an unregistered or already-destroyed + id is a no-op (but still ensures the tombstone is in place so + subsequent ``get(id)`` raises ``EventStreamNotFoundError``). + + Cleans up backing resources (e.g. file handles for the + file-backed replay backing) before installing the tombstone + / C-STR-FBR-4. + """ + slot = self._slots.get(id, None) + if slot is None: + # Never registered — install tombstone for symmetry + # (the next get(id) raises ``EventStreamNotFoundError``). + # This matches rule 36a's "delete is symmetric with rm -f + # but still leaves a marker" semantics. + self._slots[id] = _TOMBSTONE + return + if slot is _TOMBSTONE: + return # idempotent + # Invoke private cleanup hook on the bundled impl + on_delete = getattr(slot, "_on_delete", None) + if on_delete is not None: + await on_delete() + self._slots[id] = _TOMBSTONE + + +# Module-level singleton — THE public registry. 
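The registry methods above combine a lock-free fast path, a per-id lock for creation, and a tombstone sentinel for deleted ids. The following is a minimal, self-contained sketch of that pattern only. `TinyRegistry`, `NotFound`, and `demo` are hypothetical names introduced for illustration; the sketch omits the close-clock TTL check and the `_on_delete` cleanup hook that the real class performs.

```python
import asyncio

_TOMBSTONE = object()  # sentinel: the id existed (or was deleted) and is gone


class NotFound(KeyError):
    """Hypothetical stand-in for EventStreamNotFoundError."""


class TinyRegistry:
    """Illustrative registry: one instance per id, tombstone on delete."""

    def __init__(self, factory):
        self._factory = factory
        self._slots = {}
        self._id_locks = {}
        self._struct_lock = asyncio.Lock()

    async def _get_id_lock(self, key):
        # Guard the lock table itself so two callers share one per-id lock.
        async with self._struct_lock:
            return self._id_locks.setdefault(key, asyncio.Lock())

    async def get(self, key):
        slot = self._slots.get(key)
        if slot is None or slot is _TOMBSTONE:
            raise NotFound(key)
        return slot

    async def get_or_create(self, key):
        slot = self._slots.get(key)
        if slot is not None and slot is not _TOMBSTONE:
            return slot  # fast path: already live
        async with await self._get_id_lock(key):
            slot = self._slots.get(key)  # re-check under the per-id lock
            if slot is not None and slot is not _TOMBSTONE:
                return slot
            instance = self._factory(key)
            self._slots[key] = instance  # also clears a tombstone
            return instance

    async def delete(self, key):
        # Idempotent: always leaves a tombstone behind.
        self._slots[key] = _TOMBSTONE


async def demo():
    created = []

    def factory(key):
        created.append(key)
        return {"id": key}

    reg = TinyRegistry(factory)
    a, b = await asyncio.gather(reg.get_or_create("s1"), reg.get_or_create("s1"))
    await reg.delete("s1")
    try:
        await reg.get("s1")
        gone = False
    except NotFound:
        gone = True
    again = await reg.get_or_create("s1")
    return a is b, gone, again is not a, created
```

Calling `asyncio.run(demo())` exercises the three behaviours described in the docstrings: racing creators share one instance, `get` fails after `delete`, and a later `get_or_create` replaces the tombstone with a fresh instance.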
+streams = _StreamsRegistry() + + +__all__ = ["streams"] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/__init__.py new file mode 100644 index 000000000000..8f8bcc65c9c1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/__init__.py @@ -0,0 +1,104 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Task subsystem for crash-resilient long-running agents. + +Provides the :func:`task` and :func:`multi_turn_task` decorators +plus supporting types for building Azure AI Hosted Agents that survive +container crashes, OOM kills, and redeployments. + +Key features: + +- **Two decorators** — ``@task`` (one-shot, single run, ephemeral) and + ``@multi_turn_task`` (chain — every ``return X`` is one turn; chain + stays alive in ``suspended`` between turns). +- **Lifecycle automation** — ``.run()`` and ``.start()`` automatically + start, resume, or recover tasks based on their current state. +- **Entry mode** — ``ctx.entry_mode`` tells the handler whether it was + entered fresh, resumed from suspension, or recovered from a crash. +- **RetryPolicy** — configurable retry with exponential, fixed, or linear + backoff (see :class:`RetryPolicy` presets). +- **Streaming** lives in :mod:`azure.ai.agentserver.core.streaming` + : handlers call ``stream = await streams.get_or_create(invocation_id)`` + to obtain a stream handle; ``TaskRun`` itself is NOT iterable. 
+ +Public API:: + + from azure.ai.agentserver.core.tasks import ( + task, + multi_turn_task, + Task, + MultiTurnTask, + RetryPolicy, + TaskContext, + TaskMetadata, + TaskRun, + TaskFailed, + TaskCancelled, + TaskDeferred, + TaskConflictError, + LastInputIdPreconditionFailed, + SteeringQueueFull, + InputTooLarge, + JSONValue, + TaskErrorDict, + TaskExhaustedRetriesErrorDict, + EntryMode, + ) +""" + +from ._context import EntryMode, TaskContext +from ._decorator import MultiTurnTask, Task, multi_turn_task, task +from ._exceptions import ( + InputTooLarge, + LastInputIdPreconditionFailed, + SteeringQueueFull, + TaskCancelled, + TaskConflictError, + TaskDeferred, + TaskErrorDict, + TaskExhaustedRetriesErrorDict, + TaskFailed, +) +from ._metadata import JSONValue, TaskMetadata +from ._retry import RetryPolicy +from ._run import TaskRun + +# Streaming lives in `azure.ai.agentserver.core.streaming` as a peer +# subpackage with a registry-based lifecycle model. The resilient task +# decorators accept no streaming-related kwarg; ``TaskContext`` has +# no streaming attribute. Handlers explicitly do +# ``stream = await streams.get_or_create(invocation_id)`` to obtain a +# stream handle for the current turn. +# +# Attachment-vocabulary errors (``_AttachmentTooLarge``, +# ``_AttachmentLimitExceeded``) are framework-internal — they are +# caught at attachment-write sites and re-raised as the developer- +# facing ``InputTooLarge`` based on the attachment-key prefix. 
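The comment above says internal attachment errors are caught at the write site and re-raised as a developer-facing error chosen by attachment-key prefix. A minimal sketch of that remapping follows. The class and function names (`AttachmentTooLargeSketch`, `InputTooLargeSketch`, `remap`, `provider_write`, `framework_write`) and the 8-byte cap are assumptions for illustration, not the package's API; only the reserved keys `_input` and `_steering_input_` come from this change.

```python
class AttachmentTooLargeSketch(Exception):
    """Hypothetical stand-in for the provider-internal error."""

    def __init__(self, attachment_key: str, size_bytes: int, max_bytes: int) -> None:
        super().__init__(f"{attachment_key!r}: {size_bytes} bytes > {max_bytes} byte cap")
        self.attachment_key = attachment_key
        self.size_bytes = size_bytes
        self.max_bytes = max_bytes


class InputTooLargeSketch(ValueError):
    """Hypothetical stand-in for the developer-facing error."""


INPUT_KEY = "_input"
STEERING_PREFIX = "_steering_input_"


def remap(exc: AttachmentTooLargeSketch) -> Exception:
    """Pick the public exception from the attachment-key prefix."""
    key = exc.attachment_key
    if key == INPUT_KEY or key.startswith(STEERING_PREFIX):
        return InputTooLargeSketch(str(exc))
    # An unknown key means the framework wrote something it should not have.
    return RuntimeError(f"framework bug: unexpected attachment key {key!r}")


def provider_write(store: dict, key: str, value: bytes, max_bytes: int = 8) -> None:
    """Provider layer: enforces the cap with the internal exception."""
    if len(value) > max_bytes:
        raise AttachmentTooLargeSketch(key, len(value), max_bytes)
    store[key] = value


def framework_write(store: dict, key: str, value: bytes) -> None:
    """Framework layer: translates the internal error at the write site."""
    try:
        provider_write(store, key, value)
    except AttachmentTooLargeSketch as internal:
        raise remap(internal) from None
```

Returning the exception from `remap` and raising it in `framework_write` keeps the traceback anchored at the framework's re-raise site rather than inside the provider.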
+__all__ = [ + # Decorators + task classes + "task", + "multi_turn_task", + "Task", + "MultiTurnTask", + # Context + metadata + "TaskContext", + "TaskMetadata", + "EntryMode", + # Type aliases + TypedDicts + "JSONValue", + "TaskErrorDict", + "TaskExhaustedRetriesErrorDict", + # TaskRun + "TaskRun", + # Retry + "RetryPolicy", + # Public exceptions + "TaskFailed", + "TaskCancelled", + "TaskDeferred", + "TaskConflictError", + "LastInputIdPreconditionFailed", + "SteeringQueueFull", + "InputTooLarge", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_attachments.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_attachments.py new file mode 100644 index 000000000000..f3a0739ba33b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_attachments.py @@ -0,0 +1,445 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Task attachments support. + +Helpers for the input-promotion mechanism that lets the resilient +primitive support per-input payloads up to 2 MB by spilling oversized +inputs into ``task.attachments`` (decoupled from the shared 1 MB +``task.payload`` budget). See `the SOT spec` +for the authoritative wire contract and `the SOT spec` +for the speckit-flow spec. + +This module exports: + +- Tunables / constants (``_INPUT_THRESHOLD_BYTES``, + ``_STEERING_THRESHOLD_BYTES``, ``_MAX_ATTACHMENT_SIZE_BYTES``, + ``_MAX_ATTACHMENTS``, ``_STEERING_QUEUE_CAP``, + ``_FUNCTION_INPUT_KEY``, ``_STEERING_INPUT_KEY_PREFIX``). +- Hash helper: :func:`_compute_attachment_hash`. +- Ref helpers: :func:`_make_ref`, :func:`_is_ref`, :func:`_ref_key`, + :func:`_ref_hash`. +- Promotion router: :func:`_resolve_input_storage`. +- Backward-compat-NOT-needed read: :func:`_read_input_value`. 
+- Size enforcement: :func:`_validate_attachment_size`, + :func:`_validate_attachment_count`. + +All names are underscore-prefixed (framework-private); promotion is +invisible to handler authors. +""" + +from __future__ import annotations + +import hashlib +import json +from typing import Any + +from ._exceptions import ( + InputTooLarge, + OutputTooLarge, + _AttachmentLimitExceeded, + _AttachmentTooLarge, +) + +# --------------------------------------------------------------------------- # +# Wire shape constants +# --------------------------------------------------------------------------- # + +#: The single magic key whose value is the ref's nested object. A payload +#: slot or queue entry that is a 1-key dict with this key is a ref; +#: anything else is treated as inline. See spec.md §4.3. +_ATTACHMENT_REF_KEY = "__attachment_ref__" + +#: The framework-reserved attachment key for the function input. +_FUNCTION_INPUT_KEY = "_input" + +#: The framework-reserved attachment-key prefix for queued steering inputs. +#: The full key is ``f"{prefix}{seq}"`` where ``seq`` is the monotonic +#: counter from ``payload["_steering"]["next_input_seq"]``. +_STEERING_INPUT_KEY_PREFIX = "_steering_input_" + +#: — framework-reserved attachment key for the +#: per-turn output value. Output is ALWAYS stored via this attachment +#: (no inline threshold) so it never consumes payload budget. +_OUTPUT_KEY = "_output" + +#: Hash algorithm prefix (RFC-6920-style namespacing). The value after the +#: ``:`` is the lowercase-hex digest. Prefix lets us migrate to a different +#: algorithm in the future without ambiguity. +_HASH_ALGO_PREFIX = "sha256:" + +# --------------------------------------------------------------------------- # +# Size + count caps (authoritative reference: task-attachments.md §2.4 + §3) +# --------------------------------------------------------------------------- # + +#: Function input promotion threshold (200 KiB). 
Inputs whose serialized +#: form exceeds this are promoted to ``attachments["_input"]``. +_INPUT_THRESHOLD_BYTES = 200 * 1024 + +#: Steering input promotion threshold (20 KiB). Inputs whose serialized +#: form exceeds this are promoted to +#: ``attachments["_steering_input_<seq>"]``. +_STEERING_THRESHOLD_BYTES = 20 * 1024 + +#: Per-attachment value cap (2 MiB). Server-side hard cap; enforced +#: client-side via :class:`InputTooLarge` (developer-facing) / +#: :class:`_AttachmentTooLarge` (provider-internal; see +#: :func:`_remap_attachment_error`) before any HTTP call. +_MAX_ATTACHMENT_SIZE_BYTES = 2 * 1024 * 1024 + +#: Per-task attachment-entry cap. Server-side hard cap; enforced +#: client-side via :class:`_AttachmentLimitExceeded` (provider-internal). +_MAX_ATTACHMENTS = 20 + +#: Framework's steering queue hard cap. At most this many entries can be +#: queued (whether inline or refs). The 10th append raises +#: :class:`~._exceptions.SteeringQueueFull`. Combined with the 2 reserved +#: slots for the function input and the per-turn output, the framework +#: uses at most 11 of the 20 attachment slots; the other 9 remain free +#: for future features. +_STEERING_QUEUE_CAP = 9 + + +# --------------------------------------------------------------------------- # +# Hash helper +# --------------------------------------------------------------------------- # + + +def _compute_attachment_hash(serialized: Any) -> str: + """Compute the content hash of a serialized attachment value. + + The value is re-serialized to canonical JSON bytes + (``sort_keys=True``, no whitespace, separators ``(",", ":")``) and + hashed with SHA-256. The output is prefixed with ``"sha256:"`` so a + future migration to a different algorithm is unambiguous. + + :param serialized: The JSON-compatible value (already framework- + serialized via ``_serialize_input``). + :type serialized: Any + :return: ``"sha256:<64 lowercase hex chars>"``.
+ :rtype: str + """ + raw = json.dumps(serialized, sort_keys=True, separators=(",", ":")) + digest = hashlib.sha256(raw.encode("utf-8")).hexdigest() + return f"{_HASH_ALGO_PREFIX}{digest}" + + +# --------------------------------------------------------------------------- # +# Ref helpers +# --------------------------------------------------------------------------- # + + +def _make_ref(key: str, serialized: Any) -> dict[str, dict[str, str]]: + """Build a self-contained ref slot for the given attachment. + + Shape:: + + {"__attachment_ref__": {"key": "", "hash": "sha256:..."}} + + :param key: The attachment key the ref points at. + :type key: str + :param serialized: The attachment value (used to compute the content hash). + :type serialized: Any + :return: The ref slot dict. + :rtype: dict + """ + return { + _ATTACHMENT_REF_KEY: { + "key": key, + "hash": _compute_attachment_hash(serialized), + } + } + + +def _is_ref(slot: Any) -> bool: + """Return True iff *slot* is the strict ref shape. + + A slot is a ref iff ALL of: + + 1. It is a ``dict``. + 2. It has exactly one top-level key. + 3. The sole key is :data:`_ATTACHMENT_REF_KEY`. + 4. The value is itself a ``dict`` containing both ``"key"`` and ``"hash"``. + + Anything else (raw values, dicts shaped differently, other types) is + treated as inline. + + :param slot: The candidate slot. + :type slot: Any + :rtype: bool + """ + if not isinstance(slot, dict): + return False + if len(slot) != 1: + return False + nested = slot.get(_ATTACHMENT_REF_KEY) + if not isinstance(nested, dict): + return False + return "key" in nested and "hash" in nested + + +def _ref_key(slot: dict[str, dict[str, str]]) -> str: + """Return the attachment key carried by a ref slot. + + Caller MUST have validated the slot with :func:`_is_ref` first. + + :param slot: A ref slot (output of :func:`_make_ref`). + :return: The attachment key string. 
+ :rtype: str + """ + return slot[_ATTACHMENT_REF_KEY]["key"] + + +def _ref_hash(slot: dict[str, dict[str, str]]) -> str: + """Return the content hash carried by a ref slot. + + Caller MUST have validated the slot with :func:`_is_ref` first. + + :param slot: A ref slot (output of :func:`_make_ref`). + :return: The ``"sha256:"`` hash string. + :rtype: str + """ + return slot[_ATTACHMENT_REF_KEY]["hash"] + + +# --------------------------------------------------------------------------- # +# Promotion router +# --------------------------------------------------------------------------- # + + +def _serialized_size_bytes(serialized: Any) -> int: + """Return the JSON wire-byte size of an already-serialized value. + + Uses the same canonical encoding as the hash so the byte count + matches what the server will store. (We don't subtract for JSON + framing the server adds around the value because that framing is + constant overhead unrelated to the per-value cap.) + + :param serialized: The JSON-compatible value. + :type serialized: Any + :rtype: int + """ + return len(json.dumps(serialized, separators=(",", ":")).encode("utf-8")) + + +def _resolve_input_storage( + serialized: Any, + *, + threshold_bytes: int, + key_for_attachment: str, + task_id: str, +) -> tuple[str, Any]: + """Decide whether an input goes inline or to an attachment. + + Returns a 2-tuple ``(mode, value)``: + + - ``("inline", serialized)`` — caller writes ``serialized`` directly + into payload (no attachments write). + - ``("attachment", ref_slot)`` — caller writes ``serialized`` into + ``attachments[key_for_attachment]`` AND writes ``ref_slot`` into + payload (or queue), in a SINGLE PATCH. + + Raises :class:`~._exceptions.InputTooLarge` if the serialized form + exceeds the per-attachment cap. + + :param serialized: The JSON-compatible value to route. + :type serialized: Any + :keyword threshold_bytes: Below-or-equal stays inline; strictly + greater is promoted. 
Use :data:`_INPUT_THRESHOLD_BYTES` for + function inputs and :data:`_STEERING_THRESHOLD_BYTES` for + steering inputs. + :paramtype threshold_bytes: int + :keyword key_for_attachment: The attachment key to use if promoted. + Caller-allocated to keep this helper stateless. + :paramtype key_for_attachment: str + :keyword task_id: For error context only. + :paramtype task_id: str + :return: ``("inline", serialized)`` or ``("attachment", ref_slot)``. + :rtype: tuple[str, Any] + :raises InputTooLarge: If the serialized form exceeds the + per-attachment cap (:data:`_MAX_ATTACHMENT_SIZE_BYTES`). + """ + size = _serialized_size_bytes(serialized) + if size > _MAX_ATTACHMENT_SIZE_BYTES: + raise InputTooLarge( + task_id=task_id, + size_bytes=size, + max_bytes=_MAX_ATTACHMENT_SIZE_BYTES, + ) + if size <= threshold_bytes: + return ("inline", serialized) + ref_slot = _make_ref(key_for_attachment, serialized) + return ("attachment", ref_slot) + + +# --------------------------------------------------------------------------- # +# Unified read +# --------------------------------------------------------------------------- # + + +def _read_input_value( + slot: Any, + attachments: dict[str, Any] | None, +) -> Any: + """Return the actual input value from a payload slot. + + If *slot* is a ref (per :func:`_is_ref`), looks up + ``attachments[ref_key]`` and returns that value. + Otherwise returns *slot* as-is (inline). + + No backward-compat for the legacy raw-input shape (per Decision 10 + in ``research.md``): the framework only writes ``("inline", + serialized)`` or ``("attachment", ref)`` — both are handled by this + one function. + + Raises ``KeyError`` if *slot* is a ref but the referenced attachment + is missing — caller's responsibility to surface a meaningful error. + + :param slot: The payload slot or queue entry. + :type slot: Any + :param attachments: The task's attachments dict (from + ``TaskInfo.attachments``). 
May be ``None`` if no attachments + exist; in that case, encountering a ref raises ``KeyError``. + :type attachments: dict[str, Any] | None + :return: The deserialized input value. + :rtype: Any + """ + if not _is_ref(slot): + return slot + key = _ref_key(slot) + if attachments is None: + raise KeyError( + f"Slot is a ref to attachment {key!r} but no attachments are present " + f"on the task. Wire-shape invariant violated." + ) + if key not in attachments: + raise KeyError( + f"Slot is a ref to attachment {key!r} but that attachment is missing. " + f"Available attachment keys: {sorted(attachments.keys())!r}" + ) + return attachments[key] + + +# --------------------------------------------------------------------------- # +# Size + count validators (used by the HTTP client + local provider) +# --------------------------------------------------------------------------- # + + +def _validate_attachment_size( + task_id: str, + attachment_key: str, + value: Any, +) -> None: + """Raise :class:`_AttachmentTooLarge` if *value* exceeds the per-attachment cap. + + Skip if value is ``None`` (representing a delete in a PATCH). + + : this raises the framework-internal + ``_AttachmentTooLarge``. Callers above the provider layer + (framework write paths) catch it and re-raise via + :func:`_remap_attachment_error` as the developer-facing + ``InputTooLarge`` / ``OutputTooLarge`` based on the + attachment-key prefix. + + :param task_id: Task identifier for error context. + :type task_id: str + :param attachment_key: Attachment key for error context. + :type attachment_key: str + :param value: The JSON-compatible attachment value. + :type value: Any + :raises _AttachmentTooLarge: If the serialized form exceeds the cap. 
+ """ + if value is None: + return # null = delete; no size to enforce + size = _serialized_size_bytes(value) + if size > _MAX_ATTACHMENT_SIZE_BYTES: + raise _AttachmentTooLarge( + task_id=task_id, + attachment_key=attachment_key, + size_bytes=size, + max_bytes=_MAX_ATTACHMENT_SIZE_BYTES, + ) + + +def _validate_attachment_count( + task_id: str, + current_count: int, + additions: int = 1, +) -> None: + """Raise :class:`_AttachmentLimitExceeded` if adding *additions* exceeds the per-task cap. + + Internal-only exception; the framework treats + propagation as a bug (the framework's own reserved usage is at + most 11 of 20 slots) and converts to ``RuntimeError`` at the boundary. + + :param task_id: Task identifier for error context. + :type task_id: str + :param current_count: Number of attachments currently on the task + (excluding any that this PATCH deletes). + :type current_count: int + :param additions: Number of new attachment entries this PATCH adds. + :type additions: int + :raises _AttachmentLimitExceeded: If ``current_count + additions > _MAX_ATTACHMENTS``. + """ + if current_count + additions > _MAX_ATTACHMENTS: + raise _AttachmentLimitExceeded( + task_id=task_id, + current_count=current_count, + max_count=_MAX_ATTACHMENTS, + ) + + +def _remap_attachment_error(exc: "_AttachmentTooLarge") -> Exception: + """Translate the internal ``_AttachmentTooLarge`` + raised against a framework-reserved attachment key into the + developer-facing exception. + + Dispatch by attachment-key prefix: + + - ``_input`` → :class:`InputTooLarge` + - ``_steering_input_`` → :class:`InputTooLarge` + - ``_output`` → :class:`OutputTooLarge` + - anything else → :class:`RuntimeError` (framework bug — the + framework's own attachment writes only use the reserved keys). + + Callers do ``raise _remap_attachment_error(internal)`` so the + traceback reflects the framework's re-raise site, not the + provider's raise site.
+ """ + key = getattr(exc, "attachment_key", "") + task_id = getattr(exc, "task_id", "") + size_bytes = getattr(exc, "size_bytes", 0) + max_bytes = getattr(exc, "max_bytes", _MAX_ATTACHMENT_SIZE_BYTES) + if key == _FUNCTION_INPUT_KEY or key.startswith(_STEERING_INPUT_KEY_PREFIX): + return InputTooLarge(task_id=task_id, size_bytes=size_bytes, max_bytes=max_bytes) + if key == _OUTPUT_KEY: + return OutputTooLarge(task_id=task_id, size_bytes=size_bytes, max_bytes=max_bytes) + return RuntimeError( + f"Framework bug: _AttachmentTooLarge raised for unknown " + f"framework-reserved attachment key {key!r} on task {task_id!r}: " + f"{size_bytes} bytes > {max_bytes} byte cap." + ) + + +__all__ = [ + "_ATTACHMENT_REF_KEY", + "_FUNCTION_INPUT_KEY", + "_HASH_ALGO_PREFIX", + "_INPUT_THRESHOLD_BYTES", + "_MAX_ATTACHMENTS", + "_MAX_ATTACHMENT_SIZE_BYTES", + "_OUTPUT_KEY", + "_STEERING_INPUT_KEY_PREFIX", + "_STEERING_QUEUE_CAP", + "_STEERING_THRESHOLD_BYTES", + "_compute_attachment_hash", + "_is_ref", + "_make_ref", + "_read_input_value", + "_ref_hash", + "_ref_key", + "_remap_attachment_error", + "_resolve_input_storage", + "_serialized_size_bytes", + "_validate_attachment_count", + "_validate_attachment_size", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_client.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_client.py new file mode 100644 index 000000000000..402e6a73f6e2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_client.py @@ -0,0 +1,774 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Hosted resilient task provider — HTTP client for the Foundry Task Storage API. + +Communicates with ``{FOUNDRY_PROJECT_ENDPOINT}/tasks`` via +``azure.core.AsyncPipelineClient`` with the standard Azure SDK policy +chain. 
Bearer tokens are obtained lazily by ``AsyncBearerTokenCredentialPolicy``; +call-site code never assembles ``Authorization`` headers directly. + +**`ContentDecodePolicy` is intentionally excluded** from the policy +chain. The responses-storage gzip lesson: that policy +eagerly deserializes every body as JSON in middleware and crashes on +gzip / non-UTF-8 / gateway-HTML payloads before call-site code can +handle the response. Body parsing here happens at the call site with +defensive error handling. + +Every store-write call site funnels through :func:`_classify_store_write_error` + so the manager can react uniformly to +transient / evicted / conflict / permanent outcomes without re-deriving +the classification per-site. +""" + +from __future__ import annotations + +import gzip +import json +import logging +from typing import Any, Literal + +from azure.core import AsyncPipelineClient +from azure.core.configuration import Configuration +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.exceptions import DecodeError +from azure.core.pipeline.policies import ( + AsyncBearerTokenCredentialPolicy, + AsyncRetryPolicy, + DistributedTracingPolicy, + HeadersPolicy, + RequestIdPolicy, + UserAgentPolicy, +) +from azure.core.pipeline.transport import AsyncHttpTransport +from azure.core.rest import HttpRequest + +from .._version import VERSION +from ._attachments import ( + _validate_attachment_count, + _validate_attachment_size, +) +from ._exceptions_internal import TaskNotFound +from ._exceptions_internal import _HostedConflict +from ._models import ( + TaskCreateRequest, + TaskInfo, + TaskPatchRequest, + TaskStatus, +) +from ._task_api_logging_policy import TaskApiLoggingPolicy + +logger = logging.getLogger("azure.ai.agentserver.tasks") + +_AUTH_SCOPE = "https://ai.azure.com/.default" +_API_VERSION = "v1" +_USER_AGENT = f"ai-agentserver-core/{VERSION}" +_BODY_PREFIX_LIMIT = 256 # truncation length for classified error bodies + + +# 
--------------------------------------------------------------------- # +# Classifier +# --------------------------------------------------------------------- # + + +ClassifiedOutcome = Literal["transient", "evicted", "conflict", "permanent"] + + +class TransportClassifiedError(Exception): + """Raised when a non-success response cannot be parsed safely. + + Carries enough metadata for operator triage without exposing + bearer tokens or full response bodies. ``classification`` carries + the outcome label so callers can branch consistently. + """ + + def __init__( + self, + *, + status: int, + classification: ClassifiedOutcome, + message: str, + request_id: str | None = None, + body_prefix: str | None = None, + ) -> None: + super().__init__(message) + self.status = status + self.classification = classification + self.request_id = request_id + self.body_prefix = body_prefix + + +def _classify_store_write_error( # pylint: disable=too-many-return-statements + status_code: int, body: bytes | None +) -> ClassifiedOutcome: + """Classify a non-success task-store response. + + Returns one of ``"transient"`` (retry), ``"evicted"`` (orphan-sandbox + eviction; local cleanup sequence), ``"conflict"`` (etag mismatch or + 409-other), ``"permanent"`` (404 / 400 / unrecognised 4xx). + + Tolerant of non-JSON / empty / shape-unexpected bodies — never + raises from inside the classifier; misshapen evictions are downgraded + to ``"conflict"`` so the framework never invents an eviction event + from noise (guard against false-positive evictions). + + :param status_code: HTTP status code from the response. + :type status_code: int + :param body: Raw response body bytes, or ``None`` if no body. + :type body: bytes | None + :return: Classification outcome for the response. + :rtype: ClassifiedOutcome + """ + # Transient: server-side problems, throttling, timeouts. + if status_code in (408, 429) or 500 <= status_code < 600: + return "transient" + + # 409: requires body inspection. 
+ if status_code == 409: + if not body: + return "conflict" + try: + payload = json.loads(body) + except (ValueError, TypeError, UnicodeDecodeError): + return "conflict" # malformed 409 → safe default + if not isinstance(payload, dict): + return "conflict" + err = payload.get("error") + if isinstance(err, dict) and err.get("code") == "binding_mismatch": + return "evicted" + return "conflict" + + # 412 etag mismatch is a CAS conflict. + if status_code == 412: + return "conflict" + + # Everything else with 4xx is permanent (caller error). + if 400 <= status_code < 500: + return "permanent" + + # Anything else (e.g. 1xx, 3xx) — treat as permanent so callers + # do not silently retry unexpected shapes. + return "permanent" + + +def _body_prefix(body: bytes | None, limit: int = _BODY_PREFIX_LIMIT) -> str | None: + """Return up to ``limit`` decoded characters of ``body``, or ``None`` if empty. + + Tolerant of non-UTF-8 (uses ``errors="replace"``) and non-bytes input. + Used by the classified-error path so operators can see the start of a + non-JSON response without dumping the whole body to logs. + + :param body: Raw bytes from the response, or ``None``. + :type body: bytes | None + :param limit: Maximum characters to include in the prefix. + :type limit: int + :return: A truncated decoded prefix, or ``None`` if ``body`` is empty. + :rtype: str | None + """ + if not body: + return None + try: + text = bytes(body).decode("utf-8", errors="replace") + except Exception: # pylint: disable=broad-exception-caught # noqa: BLE001 + return None + if len(text) > limit: + return text[:limit] + "…" + return text + + +def _maybe_decompress(body: bytes | None, headers: Any) -> bytes | None: + """Decompress ``body`` if the response declares ``Content-Encoding: gzip``. + + Since ``ContentDecodePolicy`` is intentionally absent from the + pipeline, each call site is responsible for honoring + ``Content-Encoding``. 
Returns ``body`` unchanged for other encodings + so the caller's defensive JSON-parse can produce a useful error. + + :param body: Raw response bytes, or ``None``. + :type body: bytes | None + :param headers: Response headers (any mapping-like object). + :type headers: Any + :return: Decompressed body if applicable, otherwise ``body`` unchanged. + :rtype: bytes | None + """ + if not body or not headers: + return body + try: + encoding = headers.get("Content-Encoding") or headers.get("content-encoding") + except Exception: # pylint: disable=broad-exception-caught # noqa: BLE001 + return body + if not encoding: + return body + if encoding.lower().strip() == "gzip": + try: + return gzip.decompress(bytes(body)) + except (OSError, EOFError, ValueError): + # Malformed gzip — let the caller's JSON-parse surface it. + return body + return body + + +def _parse_json_body( + response: Any, + *, + method: str, + url: str, +) -> Any: + """Defensively decode a JSON body from the response. + + : catches ``UnicodeDecodeError``, ``json.JSONDecodeError``, + ``azure.core.exceptions.DecodeError`` and raises + :class:`TransportClassifiedError` carrying the classification, the + request id (if any), and a truncated body prefix. + + :param response: The pipeline response object. + :type response: Any + :keyword method: HTTP method of the originating request (for error context). + :paramtype method: str + :keyword url: Request URL (for error context). + :paramtype url: str + :return: The parsed JSON value on success. 
+ :rtype: Any + """ + status = getattr(response, "status_code", 0) + headers = getattr(response, "headers", {}) or {} + try: + raw = response.body() + except Exception as exc: # noqa: BLE001 + raise TransportClassifiedError( + status=status, + classification=_classify_store_write_error(status, None), + message=(f"task-store {method} {url}: failed to read response body: " f"{type(exc).__name__}: {exc}"), + request_id=str(headers.get("x-ms-request-id", "") or "") or None, + ) from exc + body = _maybe_decompress(raw, headers) + try: + text = bytes(body or b"").decode("utf-8") + except UnicodeDecodeError as exc: + raise TransportClassifiedError( + status=status, + classification=_classify_store_write_error(status, body), + message=( + f"task-store {method} {url}: response body not valid UTF-8 " + f"(status={status}); body_prefix={_body_prefix(body)!r}" + ), + request_id=str(headers.get("x-ms-request-id", "") or "") or None, + body_prefix=_body_prefix(body), + ) from exc + try: + return json.loads(text) + except (json.JSONDecodeError, DecodeError) as exc: + raise TransportClassifiedError( + status=status, + classification=_classify_store_write_error(status, body), + message=( + f"task-store {method} {url}: response body not valid JSON " + f"(status={status}); body_prefix={_body_prefix(body)!r}" + ), + request_id=str(headers.get("x-ms-request-id", "") or "") or None, + body_prefix=_body_prefix(body), + ) from exc + + +def _raise_hosted_conflict_for_response(response: Any) -> None: + """SOT §39.1 — translate service error codes to ``_HostedConflict``. + + The hosted task service emits distinct ``code`` strings inside its JSON + error envelope for each failure cause (``task_immutable``, + ``invalid_state_transition``, ``lease_held_by_another``, + ``task_already_exists``, ``lease_ownership_changed``, + ``etag_mismatch``, ``invalid_request``). 
The framework's lifecycle + code dispatches on these to choose recovery action (retry vs + translate to a public exception vs log-as-bug). + + This function raises ``_HostedConflict(_code=, status_code=)`` + when the response body carries a recognized service code. Otherwise it + returns silently so the caller can fall through to the generic + ``_classify_store_write_error`` path (transient / evicted / conflict / + permanent). + + :param response: The pipeline response object. + :type response: Any + """ + status = getattr(response, "status_code", 0) + headers = getattr(response, "headers", {}) or {} + try: + raw = response.body() + except Exception: # pylint: disable=broad-exception-caught # noqa: BLE001 + raw = None + body = _maybe_decompress(raw, headers) if raw else None + if not body: + return + try: + payload = json.loads(body) + except (ValueError, TypeError, UnicodeDecodeError): + return + if not isinstance(payload, dict): + return + err = payload.get("error") + if not isinstance(err, dict): + return + code = err.get("code") + if code not in _SPEC_020_SERVICE_CODES: + return + message = err.get("message") if isinstance(err.get("message"), str) else None + raise _HostedConflict( + _code=code, + status_code=int(status), + message=message, + ) + + +_SPEC_020_SERVICE_CODES = frozenset( + { + "task_immutable", + "invalid_state_transition", + "lease_held_by_another", + "task_already_exists", + "lease_ownership_changed", + "etag_mismatch", + "invalid_request", + } +) + + +def _raise_classified( + response: Any, + *, + method: str, + url: str, +) -> None: + """Inspect a response and raise :class:`TransportClassifiedError`. + + Replaces the legacy ``response.raise_for_status()`` call sites + so every non-success response funnels through + the classifier and carries the canonical outcome label. 
+ + This function additionally checks for the service's distinct error + codes before the generic classification — when one matches, an + internal ``_HostedConflict`` is raised instead (see §39.1). + + :param response: The pipeline response object. + :type response: Any + :keyword method: HTTP method of the originating request (for error context). + :paramtype method: str + :keyword url: Request URL (for error context). + :paramtype url: str + """ + # Check for service-coded errors first. If matched, + # _HostedConflict is raised and we never reach the generic + # classifier below. + _raise_hosted_conflict_for_response(response) + + status = getattr(response, "status_code", 0) + headers = getattr(response, "headers", {}) or {} + try: + raw = response.body() + except Exception: # pylint: disable=broad-exception-caught # noqa: BLE001 + raw = None + body = _maybe_decompress(raw, headers) if raw else None + classification = _classify_store_write_error(status, body) + raise TransportClassifiedError( + status=status, + classification=classification, + message=(f"task-store {method} {url}: classified={classification} status={status}"), + request_id=str(headers.get("x-ms-request-id", "") or "") or None, + body_prefix=_body_prefix(body), + ) + + +# --------------------------------------------------------------------- # +# HostedTaskProvider — azure.core.AsyncPipelineClient +# --------------------------------------------------------------------- # + + +def _build_default_policies( + credential: AsyncTokenCredential, +) -> list[Any]: + """Construct the canonical policy chain. + + Order: RequestIdPolicy, HeadersPolicy, UserAgentPolicy, + AsyncRetryPolicy (retry on 408 / 429 / 500 / 502 / 503 / 504 only — NEVER on 409), + AsyncBearerTokenCredentialPolicy, TaskApiLoggingPolicy, + DistributedTracingPolicy. + + ``ContentDecodePolicy`` is intentionally NOT included — see module + docstring for the responses-storage gzip lesson. + + :param credential: Async token credential for the bearer-token policy. 
+ :type credential: AsyncTokenCredential + :return: The default ordered policy chain. + :rtype: list[Any] + """ + return [ + RequestIdPolicy(), + HeadersPolicy(base_headers={"Foundry-Features": "Routines=V1Preview"}), + UserAgentPolicy(base_user_agent=_USER_AGENT), + # Retry on 5xx and the standard transient HTTP statuses; 409 + # is explicitly NOT in retry_on_status_codes because + # 409 carries application semantics (conflict / binding_mismatch) + # that retry would silently mask. + AsyncRetryPolicy( + retry_total=3, + retry_on_status_codes=[408, 429, 500, 502, 503, 504], + retry_backoff_factor=0.5, + ), + AsyncBearerTokenCredentialPolicy(credential, _AUTH_SCOPE), + TaskApiLoggingPolicy(), + DistributedTracingPolicy(), + ] + + +class HostedTaskProvider: + """HTTP-backed provider for the Foundry Task Storage API. + + Built on :class:`azure.core.AsyncPipelineClient` with the standard + policy chain. ``ContentDecodePolicy`` is + explicitly excluded; body parsing happens at the call site with + defensive error handling. + + :param project_endpoint: The ``FOUNDRY_PROJECT_ENDPOINT`` base URL. + :type project_endpoint: str + :param credential: An async token credential supporting + ``get_token(scope)`` (e.g. + :class:`azure.identity.aio.DefaultAzureCredential`). + :type credential: AsyncTokenCredential + :keyword transport: Optional :class:`AsyncHttpTransport` override + (used by tests for fake-transport injection per + Conformance Test Map row 14). 
+ :paramtype transport: AsyncHttpTransport | None + """ + + def __init__( + self, + project_endpoint: str, + credential: AsyncTokenCredential, + *, + transport: AsyncHttpTransport | None = None, + ) -> None: + self._base_url = f"{project_endpoint.rstrip('/')}/tasks" + self._credential = credential + config: Configuration = Configuration() + config.user_agent_policy = UserAgentPolicy(base_user_agent=_USER_AGENT) + self._policies: list[Any] = _build_default_policies(credential) + self._client: AsyncPipelineClient = AsyncPipelineClient( + base_url=self._base_url, + config=config, + policies=self._policies, + transport=transport, + ) + + @property + def policies(self) -> list[Any]: + """The policy chain in order — used by tests for composition assertions. + + :return: A shallow copy of the configured policy chain. + :rtype: list[Any] + """ + return list(self._policies) + + async def _send(self, request: HttpRequest) -> Any: + """Send ``request`` through the pipeline and return the HTTP response. + + The pipeline returns a ``PipelineResponse`` whose + ``http_response`` is the wire response we operate on. + + :param request: The HTTP request to send. + :type request: HttpRequest + :return: The wire HTTP response. + :rtype: Any + """ + pipeline_response = await self._client._pipeline.run( + request + ) # pylint: disable=protected-access # noqa: SLF001 + return pipeline_response.http_response + + async def create(self, request: TaskCreateRequest) -> TaskInfo: + """Create a new task via POST /tasks. + + :param request: Task creation parameters. + :type request: TaskCreateRequest + :return: The created task record. 
+ :rtype: TaskInfo + """ + params: dict[str, str] = {"api-version": _API_VERSION} + if request.lease_owner is not None: + params["lease_owner"] = request.lease_owner + if request.lease_instance_id is not None: + params["lease_instance_id"] = request.lease_instance_id + if request.lease_duration_seconds is not None: + params["lease_duration_seconds"] = str(request.lease_duration_seconds) + + body: dict[str, Any] = { + "agent_name": request.agent_name, + "session_id": request.session_id, + } + if request.id is not None: + body["id"] = request.id + if request.status != "pending": + body["status"] = request.status + if request.title is not None: + body["title"] = request.title + if request.description is not None: + body["description"] = request.description + if request.payload is not None: + body["payload"] = request.payload + if request.tags is not None: + body["tags"] = request.tags + if request.source is not None: + body["source"] = request.source + if request.attachments is not None: + # — enforce per-attachment 2 MB and per-task 20-entry + # caps client-side before the HTTP call. Create cannot + # delete anything (no null values meaningful here), so + # count is the number of entries. 
+ additions = sum(1 for v in request.attachments.values() if v is not None) + _validate_attachment_count( + task_id=request.id or "", + current_count=0, + additions=additions, + ) + for k, v in request.attachments.items(): + _validate_attachment_size( + task_id=request.id or "", + attachment_key=k, + value=v, + ) + body["attachments"] = request.attachments + + http_request = HttpRequest( + "POST", + self._base_url, + params=params, + content=json.dumps(body), + headers={"Content-Type": "application/json"}, + ) + response = await self._send(http_request) + if response.status_code >= 400: + _raise_classified(response, method="POST", url=self._base_url) + return TaskInfo.from_dict(_parse_json_body(response, method="POST", url=self._base_url)) + + async def get(self, task_id: str) -> TaskInfo | None: + """Get a task by ID via GET /tasks/{id}. + + :param task_id: The task identifier. + :type task_id: str + :return: The task record, or ``None`` if not found. + :rtype: TaskInfo | None + """ + url = f"{self._base_url}/{task_id}" + http_request = HttpRequest( + "GET", + url, + params={"api-version": _API_VERSION}, + ) + response = await self._send(http_request) + if response.status_code == 404: + return None + if response.status_code >= 400: + _raise_classified(response, method="GET", url=url) + return TaskInfo.from_dict(_parse_json_body(response, method="GET", url=url)) + + async def update(self, task_id: str, patch: TaskPatchRequest) -> TaskInfo: + """Update a task via PATCH /tasks/{id}. + + :param task_id: The task identifier. + :type task_id: str + :param patch: Fields to update. + :type patch: TaskPatchRequest + :return: The updated task record. + :rtype: TaskInfo + :raises TaskNotFound: If the task does not exist. 
+ """ + params: dict[str, str] = {"api-version": _API_VERSION} + if patch.lease_owner is not None: + params["lease_owner"] = patch.lease_owner + if patch.lease_instance_id is not None: + params["lease_instance_id"] = patch.lease_instance_id + if patch.lease_duration_seconds is not None: + params["lease_duration_seconds"] = str(patch.lease_duration_seconds) + + body: dict[str, Any] = {} + if patch.status is not None: + body["status"] = patch.status + if patch.payload is not None: + body["payload"] = patch.payload + if patch.tags is not None: + body["tags"] = patch.tags + if patch.error is not None: + body["error"] = patch.error + if patch.suspension_reason is not None: + body["suspension_reason"] = patch.suspension_reason + if getattr(patch, "clear_attachments", False) and patch.attachments is not None: + raise _HostedConflict( + _code="invalid_request", + status_code=400, + message="clear_attachments cannot be combined with attachments patch.", + task_id=task_id, + ) + if getattr(patch, "clear_attachments", False): + body["attachments"] = None + if patch.attachments is not None: + # — enforce per-attachment 2 MB cap on every + # non-null value in the patch. (We don't enforce the + # per-task 20-entry cap here because we don't have the + # current attachment count without a GET; callers that + # need pre-flight count enforcement should call + # `_validate_attachment_count` themselves. Server will + # reject if exceeded.) + for k, v in patch.attachments.items(): + _validate_attachment_size( + task_id=task_id, + attachment_key=k, + value=v, + ) + body["attachments"] = patch.attachments + + headers: dict[str, str] = {"Content-Type": "application/json"} + if patch.if_match is not None: + # Pass the service-returned etag straight through. The + # hosted task store's comparator (since the server-side + # fix landed) treats the etag value verbatim — no client- + # side stripping or wrapping. 
The local provider already + # accepts bare values; both providers therefore round- + # trip the same byte-for-byte value from a prior GET / + # PATCH response into the next If-Match. + headers["If-Match"] = str(patch.if_match) + + url = f"{self._base_url}/{task_id}" + http_request = HttpRequest( + "PATCH", + url, + params=params, + content=json.dumps(body), + headers=headers, + ) + response = await self._send(http_request) + if response.status_code == 404: + raise TaskNotFound(task_id) + if response.status_code >= 400: + _raise_classified(response, method="PATCH", url=url) + return TaskInfo.from_dict(_parse_json_body(response, method="PATCH", url=url)) + + async def delete( + self, + task_id: str, + *, + force: bool = False, + cascade: bool = False, + ) -> None: + """Delete a task via DELETE /tasks/{id}. + + :param task_id: The task identifier. + :type task_id: str + :keyword force: Release active lease before deleting. + :paramtype force: bool + :keyword cascade: Delete dependent tasks. 
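+ :paramtype cascade: bool + """ + params: dict[str, str] = {"api-version": _API_VERSION} + if force: + params["force"] = "true" + if cascade: + params["cascade"] = "true" + + url = f"{self._base_url}/{task_id}" + http_request = HttpRequest( + "DELETE", + url, + params=params, + ) + response = await self._send(http_request) + if response.status_code == 404: + raise TaskNotFound(task_id) + if response.status_code >= 400: + _raise_classified(response, method="DELETE", url=url) + + async def list( + self, + *, + agent_name: str | None = None, + session_id: str | None = None, + status: TaskStatus | str | None = None, + lease_owner: str | None = None, + tag: dict[str, str] | None = None, + source_type: str | None = None, + has_error: bool | None = None, + lease_expired: bool | None = None, + limit: int | None = None, + after: str | None = None, + before: str | None = None, + order: str | None = None, + omit_attachment_values: bool = False, + ) -> list[TaskInfo]: + """List tasks via GET /tasks with automatic cursor pagination. + + :keyword agent_name: Optional filter to tasks owned by this agent name. + :paramtype agent_name: str | None + :keyword session_id: Optional filter to tasks for this session ID. + :paramtype session_id: str | None + :keyword status: Optional status filter (``pending``, + ``in_progress``, ``suspended``, ``completed``). + :paramtype status: TaskStatus | str | None + :keyword lease_owner: Optional lease-owner string filter. + :paramtype lease_owner: str | None + :keyword tag: Optional tag-equality filter (all key/value pairs + must match). + :paramtype tag: dict[str, str] | None + :keyword source_type: Optional source-type filter. + :paramtype source_type: str | None + :keyword has_error: Optional filter sent as the ``has_error`` query parameter. + :paramtype has_error: bool | None + :keyword lease_expired: Optional filter sent as the ``lease_expired`` query parameter. + :paramtype lease_expired: bool | None + :keyword limit: Page size sent with each request (defaults to 100). + Subsequent pages are still fetched while the service reports ``has_more``. + :paramtype limit: int | None + :keyword after: Optional starting cursor; overwritten with ``last_id`` + as further pages are fetched. + :paramtype after: str | None + :keyword before: Optional cursor sent as the ``before`` query parameter. + :paramtype before: str | None + :keyword order: Optional value sent as the ``order`` query parameter. + :paramtype order: str | None + :keyword omit_attachment_values: When true, sends + ``omit_attachment_values=true`` with each request. + :paramtype omit_attachment_values: bool + :return: All matching tasks across all pages. 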
+ :rtype: list[TaskInfo] + """ + params: dict[str, str] = { + "api-version": _API_VERSION, + "limit": str(limit if limit is not None else 100), + } + if agent_name is not None: + params["agent_name"] = agent_name + if session_id is not None: + params["session_id"] = session_id + if status is not None: + params["status"] = status + if lease_owner is not None: + params["lease_owner"] = lease_owner + if tag: + for key, value in tag.items(): + params[f"tag.{key}"] = value + if source_type is not None: + params["source_type"] = source_type + if has_error is not None: + params["has_error"] = str(has_error).lower() + if lease_expired is not None: + params["lease_expired"] = str(lease_expired).lower() + if after is not None: + params["after"] = after + if before is not None: + params["before"] = before + if order is not None: + params["order"] = order + if omit_attachment_values: + params["omit_attachment_values"] = "true" + + all_tasks: list[TaskInfo] = [] + while True: + http_request = HttpRequest("GET", self._base_url, params=params) + response = await self._send(http_request) + if response.status_code >= 400: + _raise_classified(response, method="GET", url=self._base_url) + data = _parse_json_body(response, method="GET", url=self._base_url) + items: list[dict[str, Any]] = data.get("data", data.get("items", [])) + all_tasks.extend(TaskInfo.from_dict(item) for item in items) + + if not data.get("has_more", False): + break + last_id = data.get("last_id") + if not last_id: + break + params["after"] = last_id + + return all_tasks + + async def close(self) -> None: + """Close the underlying pipeline client.""" + await self._client.close() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_context.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_context.py new file mode 100644 index 000000000000..6744402044ba --- /dev/null +++ 
b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_context.py @@ -0,0 +1,209 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""TaskContext — the single parameter to a resilient task function. + +Provides identity, typed input, mutable metadata, cancellation signals, +and the ``suspend()`` method for pausing execution. + +This module introduces the cancel-cause boolean surface +(``timeout_exceeded``, ``cancel_requested``, ``pending_input_count``, +``is_steered_turn``) and the ``exit_for_recovery()`` graceful-shutdown +shape. The legacy fields ``was_steered`` / ``pending_inputs`` / +``steering_generation`` are removed. +""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +from typing import Any, Callable, Generic, Literal, TypeVar + +from ._metadata import TaskMetadata + +Input = TypeVar("Input") +Output = TypeVar("Output") + +EntryMode = Literal["fresh", "resumed", "recovered"] +"""Why the resilient function was entered. + +- ``"fresh"`` — First execution. Task was just created or started from pending. +- ``"resumed"`` — Re-entered after suspension. On developer-initiated resume + (via ``.run()``), ``ctx.input`` contains the new input. On platform-initiated + resume (via ``/tasks/{task_id}/resume``), ``ctx.input`` contains the task's + persisted input. Also used when a steering input drains from the queue — + check ``ctx.is_steered_turn`` to distinguish steering re-entry from normal + resume. +- ``"recovered"`` — Re-entered after stale task detection. The previous execution + crashed or timed out. ``ctx.input`` contains the task's persisted input. + If a steerable task crashed mid-drain, ``ctx.is_steered_turn`` will be + ``True``. +""" + + +class _Suspended: + """Internal sentinel for suspended tasks. 
See ``Suspended`` in ``_run.py``.""" + + __slots__ = ("reason", "output") + + def __init__( + self, + reason: str | None = None, + output: Any | None = None, + ) -> None: + self.reason = reason + self.output = output + + +class _ExitForRecovery: + """Internal sentinel returned by + :meth:`TaskContext.exit_for_recovery` to signal the framework to + flush metadata, release the lease, and leave the stored status + as ``in_progress``. + """ + + __slots__ = () + + +class TaskContext(Generic[Input]): # pylint: disable=too-many-instance-attributes + """The single parameter to a resilient task function. + + Provides access to the task's identity, typed input, mutable metadata + for progress tracking, cancellation signals (with cause booleans), + and the ability to suspend or exit-for-recovery. + + :param task_id: Unique task identifier. + :type task_id: str + :param input: Typed, validated input value. + :type input: Input + :param metadata: Mutable progress metadata. + :type metadata: TaskMetadata + :param retry_attempt: Resilient retry attempt counter. Survives crashes; + increments only on failure-retries, never on crash recovery. + :type retry_attempt: int + :param recovery_count: Crash-recovery counter. Increments each time the + framework re-enters this task after a lease loss or stale detection. + :type recovery_count: int + :param cancel: Request-level cancellation event. The framework sets + this from multiple causes; observe ``timeout_exceeded``, + ``cancel_requested``, ``pending_input_count`` to disambiguate. + :type cancel: asyncio.Event + :param shutdown: Container-level shutdown event. Precondition for + :meth:`exit_for_recovery`. + :type shutdown: asyncio.Event + """ + + __slots__ = ( + "task_id", + "input_id", + "_session_id", + "input", + "metadata", + "retry_attempt", + "recovery_count", + "cancel", + "shutdown", + "_suspend_callback", + "entry_mode", + # Public cancel-cause / steering surface. 
+ "timeout_exceeded", + "cancel_requested", + "is_steered_turn", + # Internal callable for the live pending_input_count property. + # The framework sets this when constructing the + # TaskContext; the property reads it on each access so the + # count reflects the current backlog including inputs queued + # mid-handler. + "_pending_count_provider", + ) + + def __init__( + self, + *, + task_id: str, + session_id: str, + input: Input, # noqa: A002 — mirrors the spec naming + metadata: TaskMetadata, + retry_attempt: int = 0, + recovery_count: int = 0, + cancel: asyncio.Event | None = None, + shutdown: asyncio.Event | None = None, + entry_mode: EntryMode = "fresh", + is_steered_turn: bool = False, + pending_count_provider: Callable[[], int] | None = None, + input_id: str | None = None, + ) -> None: + self.task_id = task_id + # input_id is part of the public TaskContext + # surface. Defaults to task_id (one-shot 1:1 invariant). + self.input_id = input_id if input_id is not None else task_id + self._session_id = session_id + self.input = input + self.metadata = metadata + self.retry_attempt = retry_attempt + self.recovery_count = recovery_count + self.cancel = cancel or asyncio.Event() + self.shutdown = shutdown or asyncio.Event() + self._suspend_callback: Any = None + self.entry_mode: EntryMode = entry_mode + # Public surface fields. Defaults are + # framework-controlled at construction; framework setters update + # them in place. No public setters. + self.timeout_exceeded: bool = False + self.cancel_requested: bool = False + self.is_steered_turn: bool = is_steered_turn + self._pending_count_provider = pending_count_provider + + @property + def pending_input_count(self) -> int: + """Live count of queued steering inputs. + + Reflects the current backlog including inputs queued mid-handler. + Reads as ``0`` for non-steerable tasks (where the provider + returns 0). Replaces the legacy ``ctx.pending_inputs: Sequence[Any]`` + snapshot. 
+ + :return: Number of queued steering inputs. + :rtype: int + """ + if self._pending_count_provider is None: + return 0 + try: + return int(self._pending_count_provider()) + except Exception: # pylint: disable=broad-exception-caught # noqa: BLE001 + return 0 + + async def exit_for_recovery(self) -> Any: + """Graceful-shutdown shape. + + Callable ONLY when ``ctx.shutdown.is_set()`` is true. Calling it + outside shutdown raises ``RuntimeError`` at the call site + (visible in user-code tracebacks; the task ends in ``failed``). + + When called during shutdown, the framework: + + 1. Flushes ``ctx.metadata`` (auto-flush invariant). + 2. Releases the lease on the persisted record. + 3. Leaves the stored ``status`` as ``in_progress`` (it does NOT + transition to ``suspended``). + 4. Signals in-process awaiters with the standard cooperative- + cancel ``TaskCancelled`` result. + 5. Preserves any queued steering inputs in the persisted state. + + The recovery scan on the next process startup re-enters the + handler with ``ctx.entry_mode == "recovered"``. + + Use as ``return await ctx.exit_for_recovery()``. + + :return: The :class:`_ExitForRecovery` sentinel. + :rtype: Any + :raises RuntimeError: If called when ``ctx.shutdown.is_set()`` is false. + """ + if not self.shutdown.is_set(): + raise RuntimeError( + "ctx.exit_for_recovery() may only be called when " + "ctx.shutdown.is_set() is true. The misuse-as-failed " + "semantic exists so operator logs surface accidental " + "calls loudly." 
+ ) + return _ExitForRecovery() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_decorator.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_decorator.py new file mode 100644 index 000000000000..d9b89cbc9455 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_decorator.py @@ -0,0 +1,1692 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""``@task`` decorator — turns an async function into a crash-resilient +unit of work with automatic task lifecycle management. + +Usage:: + + from azure.ai.agentserver.core.tasks import task, TaskContext + + @task + async def my_task(ctx: TaskContext[MyInput]) -> MyOutput: + ... + + result = await my_task.run(task_id="t1", input=MyInput(...)) +""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +import inspect +import logging as _logging +from collections.abc import Awaitable, Callable +from datetime import timedelta +from typing import ( + TYPE_CHECKING, + Any, + Generic, + TypeVar, + get_args, + get_type_hints, + overload, +) + +import re + +from ._client import TransportClassifiedError as _TransportClassifiedError +from ._context import TaskContext +from ._exceptions_internal import _HostedConflict, _translate_hosted_conflict +from ._retry import RetryPolicy +from ._run import TaskRun + +if TYPE_CHECKING: + from ._models import TaskStatus + +Input = TypeVar("Input") +Output = TypeVar("Output") +F = TypeVar("F", bound=Callable[..., Any]) + +_VALID_TASK_ID_RE = re.compile(r"^[a-zA-Z0-9\-_.:]+$") +_MAX_TASK_ID_LENGTH = 256 + +#: Prefix for framework-reserved tags. Developer tags with this prefix are +#: silently stripped to prevent collisions with auto-stamped tags. 
+_RESERVED_TAG_PREFIX = "_task_" + +_logger = _logging.getLogger("azure.ai.agentserver.tasks") + +# Global registry of resilient task descriptors for recovery purposes. +# Populated at import time when @task decorates a function. +_REGISTERED_DESCRIPTORS: list[tuple[str, Callable[..., Any], "TaskOptions"]] = [] + + +def _strip_reserved_tags(tags: dict[str, str]) -> dict[str, str]: + """Remove framework-reserved tags from developer-provided tags. + + Tags prefixed with ``_task_`` are reserved for framework use. + If a developer provides them, they are dropped and a warning is logged. + + :param tags: Developer-provided tags. + :type tags: dict[str, str] + :return: Tags with reserved keys removed. + :rtype: dict[str, str] + """ + reserved = [k for k in tags if k.startswith(_RESERVED_TAG_PREFIX)] + if reserved: + _logger.warning( + "Ignoring reserved tag(s) %s — tags prefixed with %r are " "framework-owned and cannot be overridden", + reserved, + _RESERVED_TAG_PREFIX, + ) + return {k: v for k, v in tags.items() if not k.startswith(_RESERVED_TAG_PREFIX)} + return tags + + +def _validate_task_id(task_id: str) -> None: + if not task_id or len(task_id) > _MAX_TASK_ID_LENGTH: + raise ValueError(f"task_id must be 1-{_MAX_TASK_ID_LENGTH} characters, " f"got {len(task_id)}") + # fullmatch, not match: with match(), the trailing ``$`` would also accept an ID ending in a newline. + if not _VALID_TASK_ID_RE.fullmatch(task_id): + raise ValueError(f"task_id contains invalid characters: {task_id!r}. " f"Allowed: [a-zA-Z0-9\\-_.:]") + + +def _extract_generic_args( + fn: Callable[..., Any], +) -> tuple[type[Any], type[Any]]: + """Extract Input and Output types from a resilient task function signature. + + The function must accept a single ``TaskContext[Input]`` parameter + and return ``Output``. + + :param fn: The async function to inspect. + :type fn: Callable[..., Any] + :returns: ``(InputType, OutputType)`` tuple. + :rtype: tuple[type[Any], type[Any]] + :raises TypeError: If the signature doesn't match expectations. 
+ """ + hints = get_type_hints(fn) + params = list(inspect.signature(fn).parameters.values()) + + # Find the TaskContext parameter + ctx_param = None + for p in params: + hint = hints.get(p.name) + if hint is not None: + origin = getattr(hint, "__origin__", None) + if origin is TaskContext: + ctx_param = p + break + + if ctx_param is None: + raise TypeError(f"Resilient task function {fn.__qualname__!r} must accept a " f"TaskContext[Input] parameter") + + ctx_hint = hints[ctx_param.name] + args = get_args(ctx_hint) + input_type: type[Any] = args[0] if args else Any # type: ignore[assignment] + + return_hint = hints.get("return", Any) + # Unwrap Optional, Awaitable, etc. + output_type: type[Any] = return_hint if return_hint is not None else type(None) + + return input_type, output_type + + +def _serialize_input(value: Any) -> Any: + """Serialize an input value for storage in the task payload. + + :param value: The input value to serialize. + :type value: Any + :return: The serialized form of the input. + :rtype: Any + """ + # Pydantic model + if hasattr(value, "model_dump"): + return value.model_dump() + # Plain JSON-serializable + return value + + +def _deserialize_input(value: Any, input_type: type[Any]) -> Any: + """Deserialize an input value from the task payload. + + :param value: The serialized input value. + :type value: Any + :param input_type: The expected type to deserialize into. + :type input_type: type[Any] + :return: The deserialized input value. + :rtype: Any + """ + if value is None: + return None + # Pydantic model + if hasattr(input_type, "model_validate"): + return input_type.model_validate(value) + # dict-constructable class + if isinstance(value, dict) and callable(input_type) and input_type not in (dict, str, int, float, bool, list): + try: + return input_type(**value) + except TypeError: + pass + return value + + +# — framework-reserved payload slot for the +# input-precondition primitive. 
Storage layout: top-level +# ``payload["_last_input_id"]: str`` (the ``_`` prefix is the framework- +# reserved convention; flat layout replaces the prior nested +# ``payload["_last_input_id"]`` namespace). +# Callers do not read or write this slot directly — it is managed by the +# framework on behalf of the ``input_id`` / ``if_last_input_id`` kwargs on +# :meth:`Task.start`. +_LAST_INPUT_ID_PAYLOAD_KEY = "_last_input_id" + +# — these were previously developer-visible +# @task kwargs (lease_duration_seconds, max_pending) but had no real +# end-user knob value. Demoted to module-level internal constants. If a +# future need arises to tune them per-task, re-introduce a Sec-Privileged +# API rather than restoring the public surface. +_DEFAULT_LEASE_SECONDS = 60 +# (task-attachments) §3.3 — the steering queue is hard-capped +# at 9 entries. This reserves at most 10 of the 20 attachment slots for +# framework use (9 steering + 1 function input); the other 10 remain +# free for future features. Replaces the prior 10-cap. +_DEFAULT_MAX_PENDING_STEERING = 9 + + +def _read_stored_last_input_id(task_info: Any) -> str | None: + """Read the stored ``last_input_id`` from a task's payload, or ``None``. + + :param task_info: The persisted task record (or ``None`` for a fresh + task that does not exist yet). + :type task_info: TaskInfo | None + :returns: The stored value, or ``None`` if no chain has been recorded. + :rtype: str | None + """ + if task_info is None or not task_info.payload: + return None + value = task_info.payload.get(_LAST_INPUT_ID_PAYLOAD_KEY) + return value if isinstance(value, str) else None + + +def _check_input_precondition( + *, + existing: Any, + task_id: str, + input_id: str | None, + if_last_input_id: str | None, +) -> None: + """Validate the ``if_last_input_id`` precondition before any accept path. + + Semantic rules: + + - Both ``input_id`` and ``if_last_input_id`` ``None``: no precondition. 
+ - ``input_id`` set, ``if_last_input_id`` ``None``: idempotency-only mode + — the caller wants the chain head advanced to ``input_id`` but is + NOT asserting any predecessor. Always succeeds; the chain head is + overwritten on the accept path. Use this for per-turn idempotency + identifiers (e.g. a response_id) when chain-ordering is enforced + externally (e.g. by task_id collapse + TaskConflictError + sequencing for conversation-grouped multi-turn). + - ``if_last_input_id`` set, stored ``last_input_id`` ``None``: the chain + task is brand new (e.g., a steerable conversation's second turn lands + on a freshly-created chain task). The precondition is vacuously + satisfied — the framework cannot locally verify the predecessor's + identity, but ``TaskConflictError`` on the create path protects + against double-create races. We accept and seed. + - Both set with stored: stored ``last_input_id`` must equal + ``if_last_input_id``. + + :keyword existing: The persisted task record (or ``None`` for fresh). + :keyword task_id: The task identifier. + :keyword input_id: The new input's identity (caller-supplied). + :keyword if_last_input_id: The precondition value (caller-supplied). + :raises LastInputIdPreconditionFailed: If the precondition does not hold. + """ + if if_last_input_id is None: + # Either no precondition at all, or idempotency-only mode where + # the caller advances the chain head without asserting any + # predecessor. Both cases succeed unconditionally. + return + from ._exceptions import ( # pylint: disable=import-outside-toplevel + LastInputIdPreconditionFailed, + ) + + stored = _read_stored_last_input_id(existing) + # if_last_input_id is set. + if stored is None: + # No prior chain recorded. The chain task is brand new — accept + # and let the seed write happen on the accept path. + return + # Both stored and if_last_input_id set — must match. 
+ if stored != if_last_input_id: + raise LastInputIdPreconditionFailed( + task_id, + expected_last_input_id=if_last_input_id, + actual_last_input_id=stored, + ) + + +def _build_framework_extras(input_id: str | None) -> dict[str, Any] | None: + """Build the top-level ``payload["_last_input_id"]`` seed dict, or ``None``. + + Used at fresh-create and at suspended-resume to advance the stored + ``last_input_id`` atomically with the input persist. + + :param input_id: The new input's identity, or ``None`` for callers not + opting in to chain semantics. + :type input_id: str | None + :returns: ``{"_last_input_id": input_id}`` if ``input_id`` is set, + else ``None``. + :rtype: dict[str, Any] | None + """ + if input_id is None: + return None + return {_LAST_INPUT_ID_PAYLOAD_KEY: input_id} + + +class TaskOptions: # pylint: disable=too-many-instance-attributes + """Internal task options bag. + + *Internal*: not part of the public ``resilient`` surface as of. + Constructed by the ``@task`` decorator (and ``Task.options()``) from a small + public kwarg set: ``name``, ``title``, ``tags``, ``timeout``, ``ephemeral``, + ``retry``, ``steerable``, . + + :param name: **Stable identity anchor.** Used for recovery routing and + source stamping. If you rename the Python function later, existing + in-flight tasks are still recovered correctly because the framework + matches on this name. + :type name: str + :param title: Human-readable title template. + :type title: str | Callable[[Any, str], str] | None + :param tags: Default tags (static dict or callable factory). + :type tags: dict[str, str] | Callable[[Any, str], dict[str, str]] + :param timeout: Execution timeout. + :type timeout: timedelta | None + :param ephemeral: Whether to delete on terminal exit. 
+ :type ephemeral: bool + """ + + __slots__ = ( + "name", + "title", + "tags", + "timeout", + "ephemeral", + "retry", + "steerable", + "_is_multi_turn", # — True when wrapped by @multi_turn_task + ) + + def __init__( + self, + name: str, + title: str | Callable[[Any, str], str] | None = None, + tags: dict[str, str] | Callable[[Any, str], dict[str, str]] | None = None, + timeout: timedelta | None = None, + ephemeral: bool = True, + retry: RetryPolicy | None = None, + steerable: bool = False, + _is_multi_turn: bool = False, + ) -> None: + self.name = name + self.title = title + self.tags = tags if tags is not None else {} + self.timeout = timeout + self.ephemeral = ephemeral + self.retry = retry + self.steerable = steerable + self._is_multi_turn = _is_multi_turn + + def __repr__(self) -> str: + return ( + f"TaskOptions(name={self.name!r}, " + f"ephemeral={self.ephemeral}, retry={self.retry!r}, " + f"timeout={self.timeout!r}, steerable={self.steerable})" + ) + + +class Task(Generic[Input, Output]): + """A decorated resilient task function. Not callable directly. + + Use :meth:`run` (invoke-and-wait), :meth:`start` (fire-and-forget), + or :meth:`options` (per-call overrides). + + :param fn: The decorated async function. + :param opts: Frozen task options. + :param input_type: Extracted input type. + :param output_type: Extracted output type. + """ + + __slots__ = ("_fn", "_opts", "_input_type", "_output_type", "name") + + def __init__( + self, + fn: Callable[[TaskContext[Input]], Awaitable[Output]], + opts: TaskOptions, + input_type: type[Input], + output_type: type[Output], + ) -> None: + self._fn = fn + self._opts = opts + self._input_type = input_type + self._output_type = output_type + self.name = opts.name + # Register for recovery — manager picks these up at startup + _REGISTERED_DESCRIPTORS.append((opts.name, fn, opts)) + # — if a TaskManager is already initialised (decorators + # declared after startup, e.g. 
in tests), eagerly push into its + # resume tables so _recover_stale_tasks / get_active_run can pick + # up the multi-turn opts (_is_multi_turn). + try: + from ._manager import _manager as _live_manager # pylint: disable=import-outside-toplevel + except ImportError: # pragma: no cover + _live_manager = None # type: ignore[assignment] + if _live_manager is not None: + try: + _live_manager._resume_callbacks[opts.name] = fn # noqa: SLF001 + _live_manager._resume_opts[opts.name] = opts # noqa: SLF001 + except Exception: # noqa: BLE001 + pass + + def _resolve_title(self, input_val: Input, task_id: str) -> str: + if callable(self._opts.title): + return self._opts.title(input_val, task_id) + if isinstance(self._opts.title, str): + return self._opts.title + return f"{self.name}:{task_id[:8]}" + + def _resolve_tags(self, input_val: Input, task_id: str) -> dict[str, str]: + """Resolve decorator-level tags (static dict or callable factory). + + Reserved tags (prefixed with ``_task_``) are stripped to + prevent developer code from colliding with framework-stamped tags. + + :param input_val: The task input value. + :type input_val: Input + :param task_id: The task identifier. + :type task_id: str + :return: Resolved tags dictionary. 
+ :rtype: dict[str, str] + """ + tags = self._opts.tags + if callable(tags): + result = tags(input_val, task_id) + if not isinstance(result, dict): + raise TypeError(f"tags callable must return dict[str, str], got {type(result).__name__}") + return _strip_reserved_tags(result) + return _strip_reserved_tags(dict(tags) if tags else {}) + + def _merge_tags(self, input_val: Input, task_id: str, call_tags: dict[str, str] | None) -> dict[str, str]: + merged = self._resolve_tags(input_val, task_id) + if call_tags: + merged.update(_strip_reserved_tags(call_tags)) + return merged + + async def run( + self, + *, + task_id: str | None = None, + input: Input, # noqa: A002 + input_id: str | None = None, + if_last_input_id: str | None = None, + ) -> Output: + """Run a lifecycle-aware resilient task and return the result. + + Automatically starts, resumes, or recovers the task based on its + current state: + + - No task / pending → create and start (``entry_mode="fresh"``) + - Suspended → resume with new input (``entry_mode="resumed"``) + - In-progress (stale) → recover (``entry_mode="recovered"``) + - In-progress (not stale) → raise :class:`TaskConflictError`, or + queue the input as a steering input when the task is steerable + - Completed → raise :class:`TaskConflictError` + + .. note:: + + ``title``, ``tags``, and ``retry`` are + configured on the ``@task(...)`` + decorator (or via :meth:`Task.options`), not per-call. This + is enforced so the values survive crash recovery: after the + container crashes and the framework re-enters the task, it + has only the registered decorator's options to work with, so a + per-call override would silently disappear at the crash + boundary. Session identity is platform-derived from the + ``FOUNDRY_AGENT_SESSION_ID`` environment variable. + + :keyword task_id: Unique task identifier. Auto-generated + (``uuid4().hex``) when omitted. + :paramtype task_id: str | None + :keyword input: Typed input value. + :paramtype input: Input + :keyword input_id: Optional identifier for the input being accepted.
When + supplied, the framework records it as the task's most-recently-accepted + input id in a framework-reserved slot (``payload["_last_input_id"]``). + + Two modes: + + - **Idempotency-only** (``input_id`` set, ``if_last_input_id`` unset): + advances the stored chain head unconditionally. Always succeeds; no + precondition check. Use this when chain ordering is enforced by + another mechanism (e.g. ``task_id`` collapse + ``TaskConflictError`` + / steering-queue sequencing for conversation-grouped multi-turn). + - **Chain-extension** (paired with ``if_last_input_id``): + implements HTTP If-Match-style optimistic concurrency on the + input queue — see ``if_last_input_id`` below. + :paramtype input_id: str | None + :keyword if_last_input_id: Optional precondition. When supplied, the framework + verifies that the task's currently-stored last input id equals this value + before accepting the new input. If the precondition does not hold (a + concurrent caller advanced the queue, or the caller's view is stale), + raises :class:`LastInputIdPreconditionFailed` before any state mutation. + Modelled on HTTP ``If-Match: `` semantics. Requires ``input_id`` + to also be supplied (raises :class:`TypeError` otherwise — a + precondition without an advancing id is not meaningful). + :paramtype if_last_input_id: str | None + :return: The task result wrapper with output, status, and suspension info. + :rtype: ~azure.ai.agentserver.core.tasks.TaskResult[Output] + :raises TaskFailed: On unhandled exception. + :raises ~azure.ai.agentserver.core.tasks.TaskConflictError: If the + task is already in-progress or completed. + :raises ~azure.ai.agentserver.core.tasks.LastInputIdPreconditionFailed: If + the ``if_last_input_id`` precondition does not match the stored + last input id. + :raises TypeError: If ``if_last_input_id`` is supplied without ``input_id``. + """ + #: one-shot Task.start/.run — task_id is OPTIONAL, + # auto-generated as a GUID when not supplied. 
+ if task_id is None: + import uuid as _uuid # pylint: disable=import-outside-toplevel + + task_id = _uuid.uuid4().hex + _validate_task_id(task_id) + if if_last_input_id is not None and input_id is None: + raise TypeError( + "if_last_input_id requires input_id (a precondition without an advancing id is not meaningful)" + ) + handle = await self._lifecycle_start( + task_id=task_id, + input=input, + input_id=input_id, + if_last_input_id=if_last_input_id, + ) + return await handle.result() + + async def start( + self, + *, + task_id: str | None = None, + input: Input, # noqa: A002 + input_id: str | None = None, + if_last_input_id: str | None = None, + ) -> TaskRun[Output]: + """Start a lifecycle-aware resilient task and return a handle. + + Follows the same lifecycle rules as :meth:`run` but returns + immediately with a :class:`TaskRun` handle instead of blocking. + + .. note:: + + ``title``, ``tags``, and ``retry`` are + configured on the ``@task(...)`` + decorator (or via :meth:`Task.options`), not per-call; + see :meth:`run` for the rationale. Session identity is + platform-derived from the ``FOUNDRY_AGENT_SESSION_ID`` + environment variable. + + :keyword task_id: Unique task identifier. Auto-generated + (``uuid4().hex``) when omitted. + :paramtype task_id: str | None + :keyword input: Typed input value. + :paramtype input: Input + :keyword input_id: Optional identifier for the input being accepted. When + supplied, the framework records it as the task's most-recently-accepted + input id in a framework-reserved slot (``payload["_last_input_id"]``). + + Two modes: + + - **Idempotency-only** (``input_id`` set, ``if_last_input_id`` unset): + advances the stored chain head unconditionally. Always succeeds; no + precondition check. Use this when chain ordering is enforced by + another mechanism (e.g. ``task_id`` collapse + ``TaskConflictError`` + / steering-queue sequencing for conversation-grouped multi-turn).
+ - **Chain-extension** (paired with ``if_last_input_id``): + implements HTTP If-Match-style optimistic concurrency on the + input queue; see ``if_last_input_id`` below. + :paramtype input_id: str | None + :keyword if_last_input_id: Optional precondition. When supplied, the framework + verifies that the task's currently-stored last input id equals this value + before accepting the new input. If the precondition does not hold (a + concurrent caller advanced the queue, or the caller's view is stale), + raises :class:`LastInputIdPreconditionFailed` before any state mutation. + Modelled on HTTP ``If-Match: `` semantics. Requires ``input_id`` + to also be supplied (raises :class:`TypeError` otherwise, because a + precondition without an advancing id is not meaningful). + :paramtype if_last_input_id: str | None + :return: A handle to the running task. + :rtype: TaskRun[Output] + :raises ~azure.ai.agentserver.core.tasks.TaskConflictError: If the + task is already in-progress or completed. + :raises ~azure.ai.agentserver.core.tasks.LastInputIdPreconditionFailed: If + the ``if_last_input_id`` precondition does not match the stored + last input id. + :raises TypeError: If ``if_last_input_id`` is supplied without ``input_id``. + """ + #: one-shot Task.start/.run; task_id is OPTIONAL, + # auto-generated as a GUID when not supplied. + if task_id is None: + import uuid as _uuid # pylint: disable=import-outside-toplevel + + task_id = _uuid.uuid4().hex + _validate_task_id(task_id) + if if_last_input_id is not None and input_id is None: + raise TypeError( + "if_last_input_id requires input_id (a precondition without an advancing id is not meaningful)" + ) + return await self._lifecycle_start( + task_id=task_id, + input=input, + input_id=input_id, + if_last_input_id=if_last_input_id, + ) + + async def _get(self, task_id: str) -> Any: + """Return the full persisted task information (internal). + + ..
note:: + *Internal* as of — public consumers should use + ``manager.provider.get(task_id)`` directly. + + Works for any task state — running, suspended, completed, etc. + Returns whatever is persisted. Returns ``None`` if no task exists. + + :param task_id: The task identifier. + :type task_id: str + :return: Task info or ``None`` if no task exists. + :rtype: TaskInfo | None + """ + from ._manager import ( # pylint: disable=import-outside-toplevel + get_task_manager, + ) + + manager = get_task_manager() + return await manager._provider_get_tracked(task_id) # noqa: SLF001 + + async def get_active_run(self, task_id: str) -> TaskRun[Output] | None: + """Return a TaskRun handle for an active (in-progress) task. + + : consults the store, not only + in-memory state. If the record is in-progress with a dead + lease, performs inline reclaim as a hidden side effect and + returns a usable :class:`TaskRun` bound to the new lifetime. + Terminal records return ``None``. Eviction returns ``None``. + + Enables late-join consumers to iterate a running task's stream + without being the original caller of ``start()``/``run()``, + AND covers the orphan-resurrection case where the previous + lifetime crashed without notice. + + :param task_id: The task identifier. + :type task_id: str + :return: A TaskRun bound to the active task's stream handler, + or ``None`` if not active / terminal / evicted. + :rtype: TaskRun[Output] | None + + Example:: + + # In another coroutine or request handler: + run = await my_task.get_active_run("task-123") + if run is not None: + async for chunk in run: + print(chunk, end="") + """ + from ._manager import ( # pylint: disable=import-outside-toplevel + get_task_manager, + ) + + manager = get_task_manager() + return await manager.get_active_run(task_id) + + #: Task.get is removed. TaskSnapshot is gone. + # Use manager.provider.get(task_id) directly for read-only inspection + # (returns TaskInfo, not a Snapshot wrapper). 
+ + async def _list( + self, + *, + session_id: str | None = None, + status: TaskStatus | None = None, + ) -> list[Any]: + """List tasks created by this resilient task function (internal). + + .. note:: + *Internal* as of — public consumers should use + ``manager.list_tasks(fn_name=...)`` directly. + + Automatically scoped to this function's ``name`` via the + ``_task_name`` tag (server-side) and ``source.type`` + (client-side). Only returns tasks created by this framework. + + :keyword session_id: Session scope override. Defaults to the + manager's configured session ID. + :paramtype session_id: str | None + :keyword status: Filter by task status (e.g., ``"in_progress"``, + ``"suspended"``, ``"completed"``). + :paramtype status: TaskStatus | None + :return: Matching task records. + :rtype: list[TaskInfo] + """ + from ._manager import ( # pylint: disable=import-outside-toplevel + get_task_manager, + ) + + manager = get_task_manager() + return await manager.list_tasks( + fn_name=self.name, + session_id=session_id, + status=status, + ) + + async def _append_steering_input( # pylint: disable=protected-access + self, + manager: Any, + *, + task_id: str, + input_val: Any, + existing: Any, + input_id: str | None = None, + if_last_input_id: str | None = None, + ) -> None: + """Append a steering input to the task's pending queue. + + :param manager: The task manager instance. + :type manager: Any + :keyword task_id: Target task identifier. + :paramtype task_id: str + :keyword input_val: The new steering input value. + :paramtype input_val: Any + :keyword existing: The previously-fetched task record (used for the + first etag attempt; later attempts re-fetch internally). + :paramtype existing: Any + :keyword input_id: When set, the new input's identity. + Used to advance ``payload["_last_input_id"]`` + atomically with the queue append. + :paramtype input_id: str | None + :keyword if_last_input_id: When set, the precondition + value re-checked on each etag-conflict retry. 
+ :paramtype if_last_input_id: str | None + """ + from ._exceptions import ( # pylint: disable=import-outside-toplevel + SteeringQueueFull, + ) + from ._models import ( # pylint: disable=import-outside-toplevel + TaskPatchRequest, + ) + + max_retries = 5 + serialized = _serialize_input(input_val) + + for _attempt in range(max_retries): + task_info = ( + existing + if _attempt == 0 + else await manager._provider_get_tracked(task_id) # pylint: disable=protected-access + ) + if task_info is None: + raise RuntimeError(f"Task {task_id!r} disappeared during steering append") + + # Re-check the input precondition on each retry to + # catch a concurrent steer that may have advanced `last_input_id` + # since we last looked. + if _attempt > 0: + _check_input_precondition( + existing=task_info, + task_id=task_id, + input_id=input_id, + if_last_input_id=if_last_input_id, + ) + + payload = dict(task_info.payload) if task_info.payload else {} + steering = dict(payload.get("_steering", {})) + pending: list[Any] = list(steering.get("pending_inputs", [])) + + if len(pending) >= _DEFAULT_MAX_PENDING_STEERING: + raise SteeringQueueFull(task_id, _DEFAULT_MAX_PENDING_STEERING) + + # — route through the promotion helper. Small steering + # inputs (≤ 20 KiB serialized) stay as raw values in + # ``pending_inputs``; larger ones are written to + # ``attachments["_steering_input_"]`` with a ref slot in + # the queue. The seq counter is monotonic (never reused) so + # other entries' attachment keys are stable across drains. 
+ from ._attachments import ( # pylint: disable=import-outside-toplevel + _STEERING_INPUT_KEY_PREFIX, + _STEERING_THRESHOLD_BYTES, + _resolve_input_storage, + ) + + next_seq = int(steering.get("next_input_seq", 0)) + steering_key = f"{_STEERING_INPUT_KEY_PREFIX}{next_seq}" + store_mode, queue_entry = _resolve_input_storage( + serialized, + threshold_bytes=_STEERING_THRESHOLD_BYTES, + key_for_attachment=steering_key, + task_id=task_id, + ) + attachments_patch: dict[str, Any] | None = None + if store_mode == "attachment": + attachments_patch = {steering_key: serialized} + steering["next_input_seq"] = next_seq + 1 + + pending.append(queue_entry) + steering["pending_inputs"] = pending + steering["cancel_requested"] = True + # SOT: the + # internal _steering["generation"] payload field is removed + # alongside the public ctx.steering_generation surface. + payload["_steering"] = steering + + # When the caller opted in via + # input_id, advance the framework-managed last_input_id slot + # atomically with the queue append. The slot is a top-level + # `_`-prefixed payload key (: flat layout). + if input_id is not None: + payload[_LAST_INPUT_ID_PAYLOAD_KEY] = input_id + + etag = getattr(task_info, "etag", None) or None + # Piggyback lease ownership on the steering-append PATCH so + # the lease is refreshed as a side effect (see + # ``TaskManager._lease_ext_kwargs``). Zero extra round- + # trips: lease params land on the same PATCH that's + # already going out for the payload mutation. No-op when + # the caller is not the active owner of the task (the + # ``_lease_ext_kwargs`` helper returns ``{}`` in that + # case, so the wire format is unchanged). 
+ lease_kwargs = manager._lease_ext_kwargs(task_id) # pylint: disable=protected-access + try: + await manager.provider.update( + task_id, + TaskPatchRequest( + payload=payload, + attachments=attachments_patch, + if_match=etag, + **lease_kwargs, + ), + ) + manager._note_lease_refreshed(task_id) # pylint: disable=protected-access + # Signal the running task's cancel event so it can short-circuit. + # Spec 031 / FR-001a + SOT §13 ordering invariant: record the + # live pending count BEFORE setting cancel, so a handler that + # observes ``ctx.cancel.is_set()`` already sees + # ``ctx.pending_input_count >= 1``. + active = manager._active_tasks.get(task_id) # pylint: disable=protected-access # noqa: SLF001 + if active and hasattr(active, "context") and active.context is not None: + active._pending_input_count = len(pending) # pylint: disable=protected-access # noqa: SLF001 + active.context.cancel.set() + return + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_id) + if translated is None: + continue + raise translated from exc + except ValueError: + # Local provider etag conflict -- retry with the new etag + continue + except _TransportClassifiedError as exc: + # Hosted task store etag conflict (412 / 409) -- retry. + if getattr(exc, "classification", None) == "conflict": + continue + raise + + raise RuntimeError(f"Failed to append steering input after {max_retries} retries") + + def _create_steering_ack_run( + self, + manager: Any, + task_id: str, + future: Any, + input_id: str | None = None, + input_val: Any = None, + ) -> TaskRun[Output]: + """Create a TaskRun for a queued steering input. + + :param manager: The task manager owning the active execution. + :type manager: Any + :param task_id: Stable task identifier. + :type task_id: str + :param future: Future that will resolve with the next-turn outcome. + :type future: Any + :param input_id: The input_id stamped on the queued input (if any). 
+ :type input_id: str | None + :param input_val: The raw queued input value (used to identify the + slot when ``cancel()`` is invoked on the returned handle). + :type input_val: Any + :return: A :class:`TaskRun` whose result resolves with the queued turn. + :rtype: TaskRun[Output] + """ + + async def _queued_cancel_cb() -> None: + await manager._cancel_queued_steering_input( # pylint: disable=protected-access + task_id=task_id, + future=future, + input_id=input_id, + input_val=input_val, + ) + + return TaskRun( + task_id=task_id, + provider=manager.provider, + result_future=future, + input_id=input_id, + queued_cancel_callback=_queued_cancel_cb, + ) + + async def _lifecycle_start( # pylint: disable=too-many-locals + self, + *, + task_id: str, + input: Input, # noqa: A002 + input_id: str | None = None, + if_last_input_id: str | None = None, + ) -> TaskRun[Output]: + """Resolve lifecycle state and start/resume/recover accordingly. + + Title, tags, retry, stream handler, and stale timeout are all sourced + from ``self._opts`` (the decorator-time configuration). This is + deliberate: those settings must survive the crash boundary, and the + framework can only rely on the registered decorator's view of the task + on recovery. + + :keyword task_id: The task identifier. + :paramtype task_id: str + :keyword input: Typed input value. + :paramtype input: Input + :keyword input_id: When set, the new input's identity + recorded in the framework-reserved + ``payload["_last_input_id"]`` slot. + :paramtype input_id: str | None + :keyword if_last_input_id: Precondition value checked + against the stored ``last_input_id`` before any accept path. + :paramtype if_last_input_id: str | None + :return: A handle to the running task. 
+ :rtype: TaskRun[Output] + """ + from ._exceptions import ( # pylint: disable=import-outside-toplevel + TaskConflictError, + ) + + #: orphan-sandbox eviction at scheduling + # entry points MUST surface as TaskConflictError(current_status= + # "in_progress") — the same shape as the live-elsewhere case + # per Invariant 1. Operator-facing WARNING logs (in _manager.py + # and _lease.py) are the only differentiator. + try: + return await self._lifecycle_start_inner( + task_id=task_id, + input=input, + input_id=input_id, + if_last_input_id=if_last_input_id, + ) + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_id) + if translated is None: + if exc._code == "lease_ownership_changed": + raise TaskConflictError(task_id, "in_progress") from exc + raise RuntimeError(f"Task {task_id!r} operation did not converge after retryable conflict") from exc + raise translated from exc + except _TransportClassifiedError as exc: + if getattr(exc, "classification", None) == "evicted": + # Pre-import only at the eviction site to avoid a cycle. + raise TaskConflictError(task_id, "in_progress") from exc + raise + + async def _lifecycle_start_inner( # pylint: disable=too-many-locals,too-many-statements + self, + *, + task_id: str, + input: Input, # noqa: A002 + input_id: str | None = None, + if_last_input_id: str | None = None, + ) -> TaskRun[Output]: + """Inner body of :meth:`_lifecycle_start`. See that method for docs. + + Split out so the outer wrapper can convert evictions + to ``TaskConflictError`` without indenting the entire body. + + :keyword task_id: Stable task identifier (same as outer method). + :paramtype task_id: str + :keyword input: Input value for the task (same as outer method). + :paramtype input: Input + :keyword input_id: Optional input identifier for sequential-input + acceptance preconditions (same as outer method). 
+ :paramtype input_id: str | None + :keyword if_last_input_id: Optional if-match precondition on the + last persisted ``input_id`` (same as outer method). + :paramtype if_last_input_id: str | None + :return: A :class:`TaskRun` handle for the started task. + :rtype: TaskRun[Output] + """ + from ._exceptions import ( # pylint: disable=import-outside-toplevel + TaskConflictError, + ) + from ._manager import ( # pylint: disable=import-outside-toplevel + get_task_manager, + ) + + manager = get_task_manager() + existing = await manager._provider_get_tracked(task_id) # pylint: disable=protected-access + + resolved_retry = self._opts.retry + + # Pre-acceptance check: if the caller supplied an + # ``if_last_input_id`` precondition, verify the stored last input id + # matches before proceeding to any accept path. The actual advance + # (storing ``input_id`` into ``payload["_last_input_id"]``) is bundled + # into the create/append/resume code paths below so it lands atomically + # with the input persist. 
+ _check_input_precondition( + existing=existing, + task_id=task_id, + input_id=input_id, + if_last_input_id=if_last_input_id, + ) + + if existing is None or existing.status == "pending": + # Fresh start + if existing is not None and existing.status == "pending": + # Pending task exists — patch to in_progress and execute + return await manager._start_existing_task( # pylint: disable=protected-access + fn=self._fn, + fn_name=self.name, + task_info=existing, + entry_mode="fresh", + input_val=input, + input_type=self._input_type, + opts=self._opts, + retry=resolved_retry, + ) + # No task exists — create new + return await manager.create_and_start( + fn=self._fn, + fn_name=self.name, + task_id=task_id, + input_val=input, + input_type=self._input_type, + session_id=None, + title=self._resolve_title(input, task_id), + tags=self._merge_tags(input, task_id, None), + opts=self._opts, + retry=resolved_retry, + entry_mode="fresh", + initial_payload_extras=_build_framework_extras(input_id), + ) + + if existing.status == "suspended": + # Resume — patch input onto task, then start. + # Etag-protected retry loop so concurrent + # suspended-resume POSTs race safely instead of silently + # overwriting each other. + # On the same atomic patch, advance the + # framework's `payload["_last_input_id"]` slot when the caller + # opted in via `input_id`. The precondition check already ran + # at the top of `_lifecycle_start` against the read existing. + serialized = _serialize_input(input) + from ._attachments import ( # pylint: disable=import-outside-toplevel + _FUNCTION_INPUT_KEY, + _INPUT_THRESHOLD_BYTES, + _resolve_input_storage, + ) + from ._models import ( # pylint: disable=import-outside-toplevel + TaskPatchRequest, + ) + + # — promotion: route the resume input through the + # same helper as the create path. Inline stays raw in payload; + # > 200 KiB spills into ``attachments["_input"]`` with a ref + # in payload. Single PATCH carries both. 
+ input_mode, input_value = _resolve_input_storage( + serialized, + threshold_bytes=_INPUT_THRESHOLD_BYTES, + key_for_attachment=_FUNCTION_INPUT_KEY, + task_id=task_id, + ) + attachments_patch: dict[str, Any] | None = None + if input_mode == "attachment": + attachments_patch = {_FUNCTION_INPUT_KEY: serialized} + + max_resume_retries = 5 + current_info = existing + for _attempt in range(max_resume_retries): + etag = getattr(current_info, "etag", None) or None + # Build the resume patch: input + (optionally) advance the + # framework-managed last_input_id slot (flat layout). + resume_payload: dict[str, Any] = {"input": input_value} + if input_id is not None: + resume_payload[_LAST_INPUT_ID_PAYLOAD_KEY] = input_id + try: + # PATCH returns the updated TaskInfo -- capture it + # to skip the post-patch refetch below. + # / — route through the + # manager's per-task write queue so the etag cache + # is refreshed from the response (otherwise the + # subsequent _start_existing_task PATCH would carry + # a stale if_match and 412 against itself). + updated_info = await manager._provider_update_locked( # pylint: disable=protected-access + task_id, + TaskPatchRequest( + payload=resume_payload, + attachments=attachments_patch, + if_match=etag, + ), + ) + break + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_id) + if translated is not None: + raise translated from exc + refreshed = await manager._provider_get_tracked(task_id) # pylint: disable=protected-access + if refreshed is None: + raise RuntimeError(f"Task {task_id!r} disappeared during suspended-resume retry") from exc + _check_input_precondition( + existing=refreshed, + task_id=task_id, + input_id=input_id, + if_last_input_id=if_last_input_id, + ) + current_info = refreshed + except (ValueError, _TransportClassifiedError) as exc: + # Etag conflict -- re-fetch, re-check precondition, retry. 
+ # Local provider raises ValueError; hosted task store + # raises TransportClassifiedError with classification= + # "conflict" (412 etag mismatch or 409). Both are + # the same logical concurrency outcome. + if ( + isinstance(exc, _TransportClassifiedError) + and getattr(exc, "classification", None) != "conflict" + ): + raise + refreshed = await manager._provider_get_tracked(task_id) # pylint: disable=protected-access + if refreshed is None: + raise RuntimeError(f"Task {task_id!r} disappeared during suspended-resume retry") from exc + # Re-check the precondition against the now-refreshed view. + # On a precondition failure here, the exception propagates + # out (validation failure, not concurrency conflict). + _check_input_precondition( + existing=refreshed, + task_id=task_id, + input_id=input_id, + if_last_input_id=if_last_input_id, + ) + current_info = refreshed + else: + raise RuntimeError( + f"Failed to apply suspended-resume input patch after " + f"{max_resume_retries} retries (task {task_id!r})" + ) + # PATCH already returned the updated TaskInfo -- no GET needed. + if updated_info is None: + raise RuntimeError(f"Task {task_id!r} disappeared after input patch") + return await manager._start_existing_task( # pylint: disable=protected-access + fn=self._fn, + fn_name=self.name, + task_info=updated_info, + entry_mode="resumed", + input_val=input, + input_type=self._input_type, + opts=self._opts, + retry=resolved_retry, + ) + + if existing.status == "in_progress": + # Layer 3 +: consult the lease + # state to decide recovery vs. conflict. The legacy + # _LEGACY_INPROCESS_STALE_THRESHOLD_SECONDS wall-clock + # heuristic over updated_at is replaced by the proper + # lease-state determination via _lease_is_dead. If the + # lease is dead, inline-reclaim via _reclaim_one and + # re-enter as recovered (Layer 3); if alive, + # either queue the steering input or raise TaskConflictError. 
+ from ._manager import ( # pylint: disable=import-outside-toplevel + _lease_is_dead, + ) + + active_locally = manager._active_tasks.get(task_id) is not None # pylint: disable=protected-access + lease_dead = _lease_is_dead( + existing, + this_lease_owner=manager._lease_owner, # pylint: disable=protected-access + active_locally=active_locally, + ) + + if lease_dead: + # Inline reclaim per layer (c). On race-lost / + # eviction the TransportClassifiedError propagates and + # the outer _lifecycle_start wrapper converts it to + # TaskConflictError (Invariant 1 shape). + try: + await manager._reclaim_one(existing) # pylint: disable=protected-access + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_id) + if translated is None or getattr(translated, "current_status", None) == "in_progress": + raise TaskConflictError(task_id, "in_progress") from exc + raise translated from exc + except _TransportClassifiedError as exc: + if getattr(exc, "classification", None) == "evicted": + raise TaskConflictError(task_id, "in_progress") from exc + raise + + # Stale with steering recovery state — recover via steered path + if self._opts.steerable and existing.payload: + steering = existing.payload.get("_steering", {}) + if steering.get("drain_in_progress") or steering.get("pending_inputs"): + return await manager._start_existing_task( # pylint: disable=protected-access + fn=self._fn, + fn_name=self.name, + task_info=existing, + entry_mode="recovered", + input_val=input, + input_type=self._input_type, + opts=self._opts, + retry=resolved_retry, + ) + # Normal recovery + return await manager._start_existing_task( # pylint: disable=protected-access + fn=self._fn, + fn_name=self.name, + task_info=existing, + entry_mode="recovered", + input_val=input, + input_type=self._input_type, + opts=self._opts, + retry=resolved_retry, + ) + if self._opts.steerable: + # Steering path: append input to queue, signal cancel, return ack + # pylint: 
disable=protected-access + ack_future = manager._register_steering_future(task_id) + await self._append_steering_input( + manager, + task_id=task_id, + input_val=input, + existing=existing, + input_id=input_id, + if_last_input_id=if_last_input_id, + ) + # Set cancel on in-memory context if task runs in this process + active = manager._active_tasks.get(task_id) + # pylint: enable=protected-access + if active: + active.context.cancel.set() + return self._create_steering_ack_run(manager, task_id, ack_future, input_id=input_id, input_val=input) + raise TaskConflictError(task_id, "in_progress") + + # completed (or any other terminal status) + raise TaskConflictError(task_id, existing.status) + + +@overload +def task( + fn: Callable[[TaskContext[Input]], Awaitable[Output]], +) -> Task[Input, Output]: ... + + +@overload +def task( + *, + name: str | None = ..., + title: str | None = ..., + timeout: timedelta | None = ..., + retry: RetryPolicy | None = ..., +) -> Callable[ + [Callable[[TaskContext[Input]], Awaitable[Output]]], + Task[Input, Output], +]: ... + + +def task( + fn: Callable[..., Any] | None = None, + *, + name: str | None = None, + title: str | None = None, + timeout: timedelta | None = None, + retry: RetryPolicy | None = None, + **_extra_kwargs: Any, +) -> Any: + """Turn an async function into a crash-resilient one-shot task. + + One-shot tasks are always ephemeral — the persisted record is + deleted on terminal exit. ``task_id`` is optional on the resulting + handle's ``.start`` / ``.run`` calls; the framework auto-generates + a GUID and defaults ``input_id`` to ``task_id`` (1:1 invariant). + + Can be used with or without arguments:: + + @task + async def my_task(ctx: TaskContext[MyInput]) -> MyOutput: ... + + @task(name="custom-name") + async def my_task(ctx: TaskContext[MyInput]) -> MyOutput: ... + + :param fn: The async function to decorate (when used without parens). 
+ :type fn: Callable[..., Any] | None + :keyword name: **Stable identity anchor.** Used for recovery routing and + source stamping. Defaults to ``fn.__qualname__``. Always provide an + explicit name for production tasks — if you rename the function later, + existing in-flight tasks are still recovered correctly because the + framework matches on this name, not the Python function name. + :keyword title: Static human-readable string. + :keyword timeout: Per-turn, wall-clock, resilient, cooperative-only + execution budget. When the budget elapses for the current turn, + ``ctx.timeout_exceeded`` is set then ``ctx.cancel`` is set; the + handler decides whether to wind down. The watchdog does NOT + force-stop the handler. See the developer guide §4 Timeout for + the full mechanic (including the crash-mid-turn budget-preserving + recovery semantics). + :keyword retry: Default retry policy for this task. Recovery-safe: applied + by the framework on every entry, including crash recovery. + :return: A ``Task[Input, Output]`` wrapper. + :rtype: Any + + .. note:: + Use ``@multi_turn_task`` for steerable chains. Passing + ``ephemeral=`` or ``steerable=`` to ``@task`` raises + ``TypeError`` at decoration time. + """ + # Reject unknown / unsupported kwargs at decoration time. + # ``steerable=`` and ``ephemeral=`` are NOT accepted on @task; use + # @multi_turn_task for steerable chains. 
+ _validate_task_kwargs(**_extra_kwargs) + _validate_title(title) + + def _wrap(func: Callable[..., Any]) -> Task[Any, Any]: + if not asyncio.iscoroutinefunction(func): + raise TypeError(f"@task requires an `async def` function, got {func.__qualname__!r}") + _validate_handler_signature(func, "task") + + input_type, output_type = _extract_generic_args(func) + + opts = TaskOptions( + name=name or func.__qualname__, + title=title, + tags={}, + timeout=timeout, + ephemeral=True, + retry=retry, + steerable=False, + ) + + return Task( + fn=func, + opts=opts, + input_type=input_type, + output_type=output_type, + ) + + if fn is not None: + return _wrap(fn) + return _wrap + + +# ========================================================================= +# Phase 2: class split + identifier supply + handler-sig validation +# ========================================================================= +# +# - `MultiTurnTask` is a DISTINCT public class from `Task` (not a +# subclass; type checker enforces "no .delete() on one-shot"). +# - `@multi_turn_task(steerable=...)` decorator returns MultiTurnTask. +# - Both decorators validate kwargs at decoration time and accept +# only static-string `title`. +# - Handler signature validation. + + +def _validate_title(title: object) -> None: + """Title must be `str | None`. 
Callable form REMOVED.""" + if title is not None and not isinstance(title, str): + raise TypeError( + f"@task / @multi_turn_task `title=` must be `str | None`; " + f"callable-factory form is not supported (got {type(title).__name__}: {title!r})" + ) + + +def _validate_handler_signature(func: Callable[..., Any], decorator_name: str) -> None: + """— handler must be `async def fn(ctx: TaskContext[Input]) -> Output`.""" + if not asyncio.iscoroutinefunction(func): + raise TypeError( + f"@{decorator_name} requires an `async def` (async function) handler, " + f"got synchronous {func.__qualname__!r}" + ) + try: + sig = inspect.signature(func) + except (ValueError, TypeError): + return # builtins / C-level callables — let downstream binding catch it + params = list(sig.parameters.values()) + if not params: + raise TypeError( + f"@{decorator_name} handler must accept a `ctx: TaskContext[Input]` " + f"first positional argument; got zero-arg signature in " + f"{func.__qualname__!r}" + ) + first = params[0] + if first.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + raise TypeError( + f"@{decorator_name} handler must accept a `ctx: TaskContext[Input]` " + f"as first positional argument; got *{first.name} / **{first.name} in " + f"{func.__qualname__!r}" + ) + if first.name != "ctx": + #: first parameter MUST be named ``ctx``. + raise TypeError( + f"@{decorator_name} handler first argument must be named `ctx` " + f"(found {first.name!r} in {func.__qualname__!r})" + ) + # The remaining positional/keyword args must all have defaults (the + # framework calls handler(ctx) with no extra args). 
+ for p in params[1:]: + if p.default is inspect.Parameter.empty and p.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ): + raise TypeError( + f"@{decorator_name} handler must accept only `ctx`; extra " + f"required argument {p.name!r} in {func.__qualname__!r} has no default" + ) + + +_ALLOWED_TASK_KWARGS = frozenset({"name", "title", "timeout", "retry"}) +_ALLOWED_MULTI_TURN_TASK_KWARGS = frozenset({"name", "title", "timeout", "retry", "steerable"}) + + +def _validate_task_kwargs(**kwargs: Any) -> None: + """@task allow-list: name / title / timeout / retry only. + + Unknown kwargs raise ``TypeError`` at decoration time. ``steerable=`` + and ``ephemeral=`` are explicitly NOT accepted — use + ``@multi_turn_task`` for steerable chains; one-shot tasks are always + ephemeral. + """ + unknown = set(kwargs) - _ALLOWED_TASK_KWARGS + if unknown: + msg = f"@task got unexpected kwargs: {sorted(unknown)}. Allowed: {sorted(_ALLOWED_TASK_KWARGS)}." + if "steerable" in unknown or "ephemeral" in unknown: + msg += ( + " Use @multi_turn_task for steerable chains; one-shot " + "@task is always ephemeral (the record is deleted on " + "terminal exit)." + ) + raise TypeError(msg) + + +def _validate_multi_turn_task_kwargs(**kwargs: Any) -> None: + """@multi_turn_task allow-list.""" + unknown = set(kwargs) - _ALLOWED_MULTI_TURN_TASK_KWARGS + if unknown: + if "ephemeral" in unknown: + raise TypeError("@multi_turn_task does not accept `ephemeral=` (chains are never ephemeral)") + if "tags" in unknown: + raise TypeError("@multi_turn_task does not accept `tags=` (tags surface is not part of)") + raise TypeError( + f"@multi_turn_task got unexpected kwargs: {sorted(unknown)}. " + f"Allowed: {sorted(_ALLOWED_MULTI_TURN_TASK_KWARGS)}" + ) + + +class MultiTurnTask(Generic[Input, Output]): + """A decorated multi-turn resilient task chain. + + Distinct public class from :class:`Task` — NOT a subclass. 
+ The type checker enforces "no ``.delete()`` on one-shot" and + "multi-turn ``get_active_run`` takes both ``task_id`` AND ``input_id``" + statically. + + Returned by the :func:`multi_turn_task` decorator. + + This class wraps an internal :class:`Task` (same execution model) but + exposes a strictly-typed multi-turn surface. The wrapped Task carries + ``ephemeral=False`` so the framework knows the chain semantics. + """ + + # Internal flag — multi-turn chains never auto-delete on suspend. + _is_multi_turn = True + + def __init__( + self, + fn: Callable[..., Any], + opts: TaskOptions, + input_type: type | None = None, + output_type: type | None = None, + ) -> None: + self._inner = Task( + fn=fn, + opts=opts, + input_type=input_type, + output_type=output_type, + ) + + @property + def _fn(self) -> Callable[..., Any]: + return self._inner._fn # noqa: SLF001 + + @property + def _opts(self) -> TaskOptions: + return self._inner._opts # noqa: SLF001 + + @property + def _input_type(self) -> Any: + return self._inner._input_type # noqa: SLF001 + + @property + def _output_type(self) -> Any: + return self._inner._output_type # noqa: SLF001 + + @property + def name(self) -> str: + """The registered task name (proxy of the wrapped Task).""" + return self._inner.name + + async def run( + self, + *, + task_id: str, + input: Any, # noqa: A002 + input_id: str | None = None, + if_last_input_id: str | None = None, + ) -> Any: + """Run one turn on the chain identified by ``task_id``. + + :keyword task_id: The chain identifier (mandatory). + :keyword input: The turn's input value. + :keyword input_id: Optional per-turn identifier; auto-generated + when omitted. + :keyword if_last_input_id: Optional ``If-Match``-style + precondition on the chain's last-accepted ``input_id``. + :return: The handler's return value for this turn. 
+ """ + return await self._inner.run( + task_id=task_id, + input=input, + input_id=input_id, + if_last_input_id=if_last_input_id, + ) + + async def start( + self, + *, + task_id: str, + input: Any, # noqa: A002 + input_id: str | None = None, + if_last_input_id: str | None = None, + ) -> "TaskRun[Output]": + """Start one turn on the chain identified by ``task_id`` and + return a :class:`TaskRun` handle for that turn. + + :keyword task_id: The chain identifier (mandatory). + :keyword input: The turn's input value. + :keyword input_id: Optional per-turn identifier. + :keyword if_last_input_id: Optional ``If-Match``-style precondition. + :return: A :class:`TaskRun` handle bound to the turn. + """ + return await self._inner.start( + task_id=task_id, + input=input, + input_id=input_id, + if_last_input_id=if_last_input_id, + ) + + async def get_active_run( + self, + task_id: str, + input_id: str, + ) -> "TaskRun[Output] | None": + """Multi-turn variant of ``get_active_run`` — REQUIRES ``input_id``. + + The current turn's input_id is the match key; mismatch returns + ``None``. + + :param task_id: The chain task_id. + :type task_id: str + :param input_id: The exact input_id of the currently in-flight turn. + :type input_id: str + :return: The TaskRun handle bound to the currently in-flight turn + iff ``(task_id, input_id)`` exactly matches; ``None`` otherwise. + :rtype: TaskRun[Output] | None + """ + run = await self._inner.get_active_run(task_id) + if run is None: + return None + if getattr(run, "input_id", None) != input_id: + return None + return run + + async def delete(self, task_id: str) -> None: + """Force-delete the chain record + any queued inputs. + + Removes the chain record and all queued + steerers; resolves active + queued callers' ``.result()`` futures + with :class:`TaskCancelled`. Idempotent (no-op when the chain is + already gone). + + :param task_id: The chain task_id to delete. 
+ :type task_id: str + """ + from ._manager import get_task_manager # pylint: disable=import-outside-toplevel + from ._exceptions import TaskCancelled # pylint: disable=import-outside-toplevel + + try: + mgr = get_task_manager() + except RuntimeError: + return # no manager -> nothing to delete + + # 1. Resolve any active in-process caller's future with TaskCancelled. + active = getattr(mgr, "_active_tasks", {}).get(task_id) + if active is not None: + fut = getattr(active, "result_future", None) + if fut is not None and not fut.done(): + fut.set_exception(TaskCancelled()) + # Signal the handler's cancel event so the running coroutine + # winds down cooperatively. + cancel_evt = getattr(active.context, "cancel", None) + if cancel_evt is not None: + cancel_evt.set() + # Force-cancel the running execution_task so handlers blocked + # on awaits that don't check ctx.cancel still exit. + exec_task = getattr(active, "execution_task", None) + if exec_task is not None and not exec_task.done(): + exec_task.cancel() + + # 2. Resolve all queued steerer futures with TaskCancelled. + pending = getattr(mgr, "_pending_steering_futures", {}).pop(task_id, []) + for queued_fut in pending: + if not queued_fut.done(): + queued_fut.set_exception(TaskCancelled()) + + # 3. Force-delete the record (idempotent — only swallow the + # "already-gone" classes). + provider = getattr(mgr, "_provider", None) + if provider is not None: + from ._exceptions_internal import TaskNotFound # pylint: disable=import-outside-toplevel + + try: + await provider.delete(task_id, force=True) + except TaskNotFound: + pass # idempotent: already gone + + +@overload +def multi_turn_task( + fn: Callable[[TaskContext[Input]], Awaitable[Output]], +) -> MultiTurnTask[Input, Output]: ... 
+ + +@overload +def multi_turn_task( + *, + name: str | None = ..., + title: str | None = ..., + timeout: timedelta | None = ..., + retry: RetryPolicy | None = ..., + steerable: bool = ..., +) -> Callable[ + [Callable[[TaskContext[Input]], Awaitable[Output]]], + MultiTurnTask[Input, Output], +]: ... + + +def multi_turn_task( + fn: Callable[..., Any] | None = None, + *, + name: str | None = None, + title: str | None = None, + timeout: timedelta | None = None, + retry: RetryPolicy | None = None, + steerable: bool = False, + **_extra_kwargs: Any, +) -> Any: + """Decorator producing a multi-turn resilient chain. + + Multi-turn chains accept inputs across many turns against the same + ``task_id``. The handler's ``return X`` is the implicit-suspend + signal — there is no ``ctx.suspend``. The chain stays + alive across handler raises. + + :keyword name: Stable chain-identity anchor. + :keyword title: Static human-readable string. Callable-factory form is + not supported. + :keyword timeout: Per-turn cooperative timeout. + :keyword retry: Default retry policy. + :keyword steerable: When True, ``start()`` against an in-flight chain + queues the new input instead of raising ``TaskConflictError``. + :return: A :class:`MultiTurnTask` instance (distinct public class from + :class:`Task`). 
+ """ + # — reject unknown kwargs at decoration time + _validate_multi_turn_task_kwargs(**_extra_kwargs) + # / — title must be str | None + _validate_title(title) + + def _wrap(func: Callable[..., Any]) -> MultiTurnTask[Any, Any]: + # — handler-signature validation + _validate_handler_signature(func, "multi_turn_task") + + input_type, output_type = _extract_generic_args(func) + + opts = TaskOptions( + name=name or func.__qualname__, + title=title, + tags={}, + timeout=timeout, + ephemeral=False, # multi-turn chains are NEVER ephemeral + retry=retry, + steerable=steerable, + _is_multi_turn=True, # — signals new raise/persistence semantics + ) + + return MultiTurnTask( + fn=func, + opts=opts, + input_type=input_type, + output_type=output_type, + ) + + if fn is not None: + return _wrap(fn) + return _wrap diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_exceptions.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_exceptions.py new file mode 100644 index 000000000000..7f104a0f2813 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_exceptions.py @@ -0,0 +1,280 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Exception types for the resilient task subsystem. + + reshape: public exceptions no longer carry +``task_id`` (caller has it via the run handle / call site). Constructors +ACCEPT legacy ``task_id`` positional args for back-compat during the +transition, but discard them (the attribute is never set). +""" + +from typing import Any +import inspect + + +class TaskFailed(Exception): + """Raised when a resilient task function raises an unhandled exception. + + : only ``error`` is carried. ``task_id`` is no longer + on the exception (caller has it from the run handle). 
+ + :keyword error: Structured error details (matches one of TaskErrorDict + or TaskExhaustedRetriesErrorDict). + :paramtype error: dict[str, Any] + """ + + error: "TaskErrorDict | TaskExhaustedRetriesErrorDict" + + def __init__(self, *args: Any, error: dict[str, Any] | None = None) -> None: + # Legacy: TaskFailed(task_id, error_dict) + if args: + if len(args) == 2 and error is None: + # Legacy positional (task_id, error_dict): discard task_id. + error = args[1] + elif len(args) == 1 and error is None: + error = args[0] + if not isinstance(error, dict): + raise TypeError("TaskFailed: 'error' keyword (dict) is required") + self.error = error # type: ignore[assignment] + super().__init__(error.get("message", "Task failed")) + + +#: visible signature is `error` only. +TaskFailed.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=[inspect.Parameter("error", inspect.Parameter.KEYWORD_ONLY)] +) + + +class TaskCancelled(Exception): + """Raised when a resilient task is cancelled (: bare).""" + + # NO __slots__ + NO instance state — requires no fields. + # __str__ is hardcoded; legacy positional task_id is accepted and discarded. + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() # args MUST be () + + def __str__(self) -> str: # pragma: no cover -- minor str formatting + return "Task was cancelled" + + +# Override inspect signature to show empty parameter list. +TaskCancelled.__signature__ = inspect.Signature(parameters=[]) # type: ignore[attr-defined] + + +class TaskNotFound(Exception): + """Internal-only — not exported from public surface.""" + + def __init__(self, task_id: str | None = None) -> None: + self.task_id = task_id + super().__init__(f"Task {task_id!r} not found") + + +class TaskConflictError(RuntimeError): + """Raised when a task lifecycle conflict cannot be resolved. + + : only ``current_status`` is carried. + + :keyword current_status: The task's current status. 
+ :paramtype current_status: str + """ + + __slots__ = ("current_status",) + + def __init__(self, *args: Any, current_status: str | None = None) -> None: + # Legacy: TaskConflictError(task_id, current_status) + if args: + if len(args) == 2 and current_status is None: + current_status = args[1] + elif len(args) == 1 and current_status is None: + current_status = args[0] + if current_status is None: + raise TypeError("TaskConflictError: 'current_status' is required") + self.current_status = current_status + super().__init__(f"Task is already {current_status}") + + +#: visible signature is current_status only. +TaskConflictError.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=[inspect.Parameter("current_status", inspect.Parameter.KEYWORD_ONLY)] +) + + +class EtagConflict(RuntimeError): + """Raised when an optimistic concurrency (etag) check fails.""" + + __slots__ = ("task_id",) + + def __init__(self, task_id: str, message: str | None = None) -> None: + self.task_id = task_id + msg = message or f"Etag conflict on task '{task_id}'" + super().__init__(msg) + + +class SteeringQueueFull(RuntimeError): + """Raised when the steering pending-input queue is at capacity (: bare).""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__("Steering queue is full") + + +SteeringQueueFull.__signature__ = inspect.Signature(parameters=[]) # type: ignore[attr-defined] + + +class TaskPreconditionFailed(RuntimeError): + """Internal-only base — not exported.""" + + __slots__ = ("task_id",) + + def __init__(self, task_id: str = "", message: str = "") -> None: + self.task_id = task_id + super().__init__(message or "task precondition failed") + + +class LastInputIdPreconditionFailed(TaskPreconditionFailed): + """Raised when ``Task.start``'s ``if_last_input_id`` precondition is not met. + + : only ``actual_last_input_id`` is carried. 
+ """ + + __slots__ = ("actual_last_input_id",) + + def __init__( + self, + *args: Any, + actual_last_input_id: str | None = None, + expected_last_input_id: str | None = None, # accepted, discarded + task_id: str | None = None, # accepted, discarded + ) -> None: + legacy_task_id = task_id + if args: + if len(args) == 1: + if actual_last_input_id is None and expected_last_input_id is None: + actual_last_input_id = args[0] + else: + legacy_task_id = args[0] + elif len(args) == 3: + legacy_task_id = args[0] + actual_last_input_id = args[2] + self.actual_last_input_id = actual_last_input_id + # IMPORTANT: do NOT call super().__init__ — the parent + # TaskPreconditionFailed sets ``self.task_id``, which + # forbids on public exceptions. Initialise via the + # RuntimeError base directly. + msg = f"if_last_input_id precondition failed: " f"actual last_input_id={actual_last_input_id!r}" + RuntimeError.__init__(self, msg) + + +LastInputIdPreconditionFailed.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=[inspect.Parameter("actual_last_input_id", inspect.Parameter.KEYWORD_ONLY)] +) + + +class InputTooLarge(ValueError): + """Raised when an input's serialized size exceeds the per-input cap (: bare).""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__("Input exceeds the per-input cap") + + +InputTooLarge.__signature__ = inspect.Signature(parameters=[]) # type: ignore[attr-defined] + + +#: OutputTooLarge is REMOVED from public surface. The +# class is kept as internal-only (no longer in __init__'s __all__). +class OutputTooLarge(ValueError): + """Internal-only — not exported. 
Kept for legacy raise sites.""" + + __slots__ = ("task_id", "size_bytes", "max_bytes") + + def __init__(self, task_id: str = "", size_bytes: int = 0, max_bytes: int = 0) -> None: + self.task_id = task_id + self.size_bytes = size_bytes + self.max_bytes = max_bytes + super().__init__( + f"Output for task {task_id!r} exceeds the per-output cap: " f"{size_bytes} bytes > {max_bytes} byte cap." + ) + + +class _AttachmentTooLarge(ValueError): + """— provider-internal cap-violation signal.""" + + __slots__ = ("task_id", "attachment_key", "size_bytes", "max_bytes") + + def __init__( + self, + task_id: str, + attachment_key: str, + size_bytes: int, + max_bytes: int, + ) -> None: + self.task_id = task_id + self.attachment_key = attachment_key + self.size_bytes = size_bytes + self.max_bytes = max_bytes + super().__init__( + f"Attachment {attachment_key!r} on task {task_id!r} is too large: " + f"{size_bytes} bytes > {max_bytes} byte per-attachment cap." + ) + + +class _AttachmentLimitExceeded(ValueError): + """— provider-internal per-task attachment-count cap violation.""" + + __slots__ = ("task_id", "current_count", "max_count") + + def __init__(self, task_id: str, current_count: int, max_count: int) -> None: + self.task_id = task_id + self.current_count = current_count + self.max_count = max_count + super().__init__(f"Task {task_id!r} already has {current_count} attachments; " f"per-task cap is {max_count}.") + + +# Backward-compatible aliases for any in-tree caller that still imports +# the pre-019 names. 
+AttachmentTooLarge = _AttachmentTooLarge +AttachmentLimitExceeded = _AttachmentLimitExceeded + + +# ========================================================================= +# Additions to the exception taxonomy +# ========================================================================= + +try: + from typing import Literal, TypedDict +except ImportError: # pragma: no cover + from typing_extensions import Literal, TypedDict # type: ignore[assignment] + + +class TaskDeferred(Exception): + """Raised when the handler calls ``ctx.exit_for_recovery``. + + Semantically DISTINCT from :class:`TaskCancelled` — the task stays + ``in_progress`` and recovery re-invokes the handler in a future + lifetime. Bare exception. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__("Task deferred to next process lifetime") + + +TaskDeferred.__signature__ = inspect.Signature(parameters=[]) # type: ignore[attr-defined] + + +class TaskErrorDict(TypedDict): + """Shape of :attr:`TaskFailed.error` for a normal handler-raise failure.""" + + type: str + message: str + traceback: str + + +class TaskExhaustedRetriesErrorDict(TypedDict): + """Shape of :attr:`TaskFailed.error` when the retry budget was exhausted.""" + + type: Literal["exhausted_retries"] + attempts: int + last_error: str + last_error_type: str + traceback: str diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_exceptions_internal.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_exceptions_internal.py new file mode 100644 index 000000000000..432cd049ebea --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_exceptions_internal.py @@ -0,0 +1,148 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Internal framework-private exceptions. 
+ +These exception types are NEVER exported from +``azure.ai.agentserver.core.tasks.__init__``. They exist purely as +internal discriminators the framework's classifier code raises so +that lifecycle / retry / error-mapping code can branch on the +underlying cause without leaking service-API vocabulary onto the +developer surface. + +The translation from these internal types → developer-facing types +is documented in ``docs/task-and-streaming-spec.md`` §39.1. + +: ``TaskNotFound`` and ``TaskPreconditionFailed`` +live here as internal-only re-exports (the classes themselves are +defined in ``_exceptions.py`` for now, but the canonical import +path for in-tree callers is this module). +""" + +from __future__ import annotations + +import logging + +from ._exceptions import ( + TaskConflictError, + TaskNotFound, + TaskPreconditionFailed, +) + +__all__ = [ + "_HostedConflict", + "_translate_hosted_conflict", + "TaskNotFound", + "TaskPreconditionFailed", + "TaskConflictError", +] + +logger = logging.getLogger("azure.ai.agentserver.tasks") + + +class _HostedConflict(Exception): + """Internal discriminator for service-emitted error codes. + + The hosted task service returns distinct error codes for distinct + failure conditions (``task_immutable``, ``invalid_state_transition``, + ``lease_held_by_another``, ``task_already_exists``, + ``lease_ownership_changed``, ``etag_mismatch``, ``invalid_request``). + The hosted provider's response classifier wraps each in this type + so the framework's lifecycle code can dispatch on ``_code`` and + translate to the appropriate public exception (or retry + transparently for ``etag_mismatch`` / ``lease_ownership_changed``). + + The local file provider raises the same type with the same ``_code`` + directly for the equivalent in-process conditions, so the + framework's dispatch table works against either backing. 
+ + The leading underscore on the class name AND on ``_code`` is the + Python-canonical signal: package-private, never imported by + developer code, never appears in docstrings of public APIs. + + :param _code: One of the service's structured error code strings. + Matches the ``code`` field of the JSON error envelope on the + wire. + :type _code: str + :param status_code: The HTTP status code the service would return + (or would have returned, in local mode). 400 / 409 / 412 per + §39.1. + :type status_code: int + :param message: Optional human-readable message for diagnostic + purposes. NEVER reaches developer code as-is — the framework's + translation step writes its own framework-vocabulary message + on the public exception. + :type message: str | None + :param task_id: Optional task identifier for log correlation. + :type task_id: str | None + """ + + __slots__ = ("_code", "status_code", "message", "task_id") + + def __init__( + self, + _code: str, + status_code: int, + message: str | None = None, + task_id: str | None = None, + ) -> None: + super().__init__(message or _code) + self._code = _code + self.status_code = status_code + self.message = message + self.task_id = task_id + + def __repr__(self) -> str: + return ( + f"_HostedConflict(_code={self._code!r}, " f"status_code={self.status_code!r}, " f"task_id={self.task_id!r})" + ) + + +# Public name "_HostedConflict" is exported via class definition above. +# Intentionally NOT added to any __all__; underscore prefix already +# excludes it from `from _exceptions_internal import *` and signals +# package-private intent. +__all__: list[str] = [] + + +def _translate_hosted_conflict( + exc: "_HostedConflict", + task_id: str | None = None, + observed_status: str | None = None, +) -> "Exception | None": + """Translate a `_HostedConflict` to a developer-facing exception. + + Returns None for transient codes the caller should retry + (``etag_mismatch``, ``lease_ownership_changed``). 
Otherwise returns the + public exception the caller should raise. + """ + effective_task_id = task_id or exc.task_id or "" + code = exc._code + + if code in {"etag_mismatch", "lease_ownership_changed"}: + return None + if code == "lease_held_by_another": + return TaskConflictError(effective_task_id, "in_progress") + if code == "task_immutable": + return TaskConflictError(effective_task_id, "completed") + if code == "task_already_exists": + return TaskConflictError(effective_task_id, observed_status or "in_progress") + if code == "invalid_request": + return TaskPreconditionFailed( + effective_task_id, + exc.message or "the task request failed a validation precondition", + ) + if code == "invalid_state_transition": + logger.warning( + "Framework generated an invalid task state transition for task %s", + effective_task_id, + exc_info=True, + ) + return RuntimeError("Framework generated an invalid task state transition.") + + logger.warning( + "Task provider returned an unrecognized internal conflict for task %s", + effective_task_id, + exc_info=True, + ) + return RuntimeError("Task operation failed due to an internal conflict.") diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_lease.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_lease.py new file mode 100644 index 000000000000..5cc34b6a22f4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_lease.py @@ -0,0 +1,281 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Lease identity derivation and renewal loop for resilient tasks. + +Provides utility functions for constructing stable lease owner strings, +generating ephemeral instance IDs, and running the background lease +renewal loop. 
+""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +import logging +import os +import time +import uuid +from collections.abc import Awaitable, Callable +from typing import Any + +from ._models import TaskPatchRequest +from ._provider import TaskProvider +from ._client import TransportClassifiedError +from ._exceptions_internal import _HostedConflict, _translate_hosted_conflict + +logger = logging.getLogger("azure.ai.agentserver.tasks") + + +def derive_lease_owner(agent_name: str, session_id: str) -> str: + """Derive a stable lease owner string from the agent name and session ID. + + : the lease owner string MUST be derived from + BOTH the agent name (from ``FOUNDRY_AGENT_NAME``) AND the session + identifier — not from the session ID alone. Two different agents + that happen to share a session ID (a misconfiguration or a future + multi-agent platform topology) would otherwise collide on lease + ownership and step on each other's tasks. The platform's + ``binding_mismatch`` protection covers split-brain on the + same agent+session but is silent on this orthogonal case. + + The owner is stable across process restarts within the same + ``(agent_name, session_id)`` pair, enabling dual-identity lease + reclamation. + + On-the-wire format: ``"{agent_name}|session:{session_id}"``. Both + components are recoverable from the string by splitting on the + first ``"|"``; the format is chosen for operator readability in + logs. + + :param agent_name: The agent name (resolved from + ``FOUNDRY_AGENT_NAME``). Falls back to ``"unknown-agent"`` when + the env var is unset — the caller decides whether to do the + fallback or pass ``"unknown-agent"`` directly. The fallback + string matches the rest of the framework's agent-name + conventions so traces, logs, and lease ownership agree. + :type agent_name: str + :param session_id: The agent session identifier. 
+ :type session_id: str + :return: A lease owner string containing both components in a + stable, parseable format. + :rtype: str + """ + safe_agent = agent_name or "unknown-agent" + return f"{safe_agent}|session:{session_id}" + + +def generate_instance_id() -> str: + """Generate an ephemeral lease instance ID unique to this process. + + Combines the PID, a random UUID fragment, and a timestamp to ensure + uniqueness even after rapid restarts. + + :return: A unique instance identifier. + :rtype: str + """ + return f"worker-{os.getpid()}-{uuid.uuid4().hex[:8]}-{int(time.time())}" + + +async def lease_renewal_loop( + provider: TaskProvider, + task_id: str, + *, + lease_owner: str, + lease_instance_id: str, + lease_duration_seconds: int, + cancel_event: asyncio.Event, + on_failure_count: int = 3, + on_cancel_callback: asyncio.Event | None = None, + steering_poll_callback: Callable[[], Awaitable[None]] | None = None, + last_refresh_provider: Callable[[], float] | None = None, + update_via_queue: Callable[[str, "TaskPatchRequest"], Awaitable[Any]] | None = None, +) -> None: + """Run a background lease renewal loop at half the lease duration. + + Renews the lease by PATCHing the task with the same owner/instance. + On ``on_failure_count`` consecutive failures, signals the optional + ``on_cancel_callback`` event to give the task function a chance to + checkpoint. + + The loop exits when ``cancel_event`` is set or the task is cancelled. + + :param provider: The storage provider. + :type provider: TaskProvider + :param task_id: The task to renew. + :type task_id: str + :keyword lease_owner: The stable lease owner. + :paramtype lease_owner: str + :keyword lease_instance_id: The ephemeral instance ID. + :paramtype lease_instance_id: str + :keyword lease_duration_seconds: The lease TTL in seconds. + :paramtype lease_duration_seconds: int + :keyword cancel_event: Event that stops the loop when set.
+ :paramtype cancel_event: asyncio.Event + :keyword on_failure_count: Consecutive failures before signalling cancel. + :paramtype on_failure_count: int + :keyword on_cancel_callback: Event to signal on repeated renewal failure. + :paramtype on_cancel_callback: asyncio.Event | None + :keyword steering_poll_callback: Async callback invoked each renewal to poll + for steering inputs. Called after successful lease renewal. + :paramtype steering_poll_callback: Callable[[], Awaitable[None]] | None + :keyword last_refresh_provider: Optional ``() -> float`` callable + returning the ``asyncio.get_event_loop().time()`` value at the + most-recent lease refresh (heartbeat OR side-effect refresh + from a payload PATCH that piggybacked lease ownership via + ``TaskManager._lease_ext_kwargs``). When provided, the loop + skips the heartbeat for any tick whose due-time has been + pushed past by a more-recent refresh, avoiding a redundant + network round-trip. ``None`` preserves the legacy fixed-tick + behaviour for tests. + :paramtype last_refresh_provider: Callable[[], float] | None + :keyword update_via_queue: Optional callable + through which the heartbeat PATCH MUST be issued so that it + acquires the per-task write lock (and is etag-aware). When + supplied, the loop uses this instead of ``provider.update``. + When ``None``, falls back to the raw provider call (used by + tests that don't construct a TaskManager). + :paramtype update_via_queue: Callable[[str, TaskPatchRequest], Awaitable[Any]] | None + """ + interval = max(1, lease_duration_seconds // 2) + consecutive_failures = 0 + + while not cancel_event.is_set(): + try: + await asyncio.wait_for( + _wait_for_event(cancel_event), + timeout=interval, + ) + # cancel_event was set — exit the loop + break + except asyncio.TimeoutError: + pass + + # Every payload PATCH that piggybacks lease ownership + # (TaskManager._lease_ext_kwargs) refreshes the lease as a + # side effect.
Skip a redundant heartbeat when a more-recent + # refresh has happened within the last ``interval`` seconds. + if last_refresh_provider is not None: + try: + last_refresh_t = float(last_refresh_provider()) + except Exception: # pylint: disable=broad-exception-caught # noqa: BLE001 + last_refresh_t = 0.0 + if last_refresh_t > 0.0: + now_t = asyncio.get_event_loop().time() + age = now_t - last_refresh_t + if age < interval: + remaining = interval - age + try: + await asyncio.wait_for( + _wait_for_event(cancel_event), + timeout=remaining, + ) + break # cancel fired + except asyncio.TimeoutError: + continue # re-check on the next iteration + + try: + patch = TaskPatchRequest( + lease_owner=lease_owner, + lease_instance_id=lease_instance_id, + lease_duration_seconds=lease_duration_seconds, + ) + if update_via_queue is not None: + await update_via_queue(task_id, patch) + else: + await provider.update(task_id, patch) + consecutive_failures = 0 + logger.debug("Lease renewed for task %s", task_id) + + # Poll for steering inputs after successful renewal + if steering_poll_callback is not None: + try: + await steering_poll_callback() + except Exception: # pylint: disable=broad-exception-caught + logger.debug("Steering poll failed for task %s", task_id, exc_info=True) + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_id) + if translated is None or getattr(translated, "current_status", None) == "in_progress": + if on_cancel_callback is not None: + logger.warning( + "Lease renewal lost ownership for task %s — cancelling local execution", + task_id, + ) + on_cancel_callback.set() + break + consecutive_failures += 1 + logger.warning( + "Lease renewal failed for task %s (attempt %d/%d): %s", + task_id, + consecutive_failures, + on_failure_count, + translated, + exc_info=True, + ) + if consecutive_failures >= on_failure_count and on_cancel_callback is not None: + logger.error( + "Lease renewal failed %d times for task %s — signalling 
cancellation", + on_failure_count, + task_id, + ) + on_cancel_callback.set() + break + except TransportClassifiedError as exc: + if getattr(exc, "classification", None) == "evicted" and on_cancel_callback is not None: + #: orphan-sandbox eviction at the lease-renewal + # site. Stop renewing immediately; signal the local cleanup + # callback so _manager.py can cancel the local execution, + # suppress any pending terminal write, and signal awaiters + # with TaskConflictError. The local cleanup sequence is + # atomic per Invariant 1 (no partial cleanup state observable). + logger.warning( + "Lease renewal rejected with binding_mismatch for task %s " + "(orphan-sandbox eviction) — cancelling local execution", + task_id, + ) + on_cancel_callback.set() + break + # Non-eviction classified errors fall through to the generic + # failure-counter path (e.g. transient 503 → retry). + consecutive_failures += 1 + logger.warning( + "Lease renewal failed for task %s (attempt %d/%d): %s", + task_id, + consecutive_failures, + on_failure_count, + exc, + exc_info=True, + ) + if consecutive_failures >= on_failure_count and on_cancel_callback is not None: + logger.error( + "Lease renewal failed %d times for task %s — signalling cancellation", + on_failure_count, + task_id, + ) + on_cancel_callback.set() + break + except Exception: # pylint: disable=broad-exception-caught + consecutive_failures += 1 + logger.warning( + "Lease renewal failed for task %s (attempt %d/%d)", + task_id, + consecutive_failures, + on_failure_count, + exc_info=True, + ) + if consecutive_failures >= on_failure_count and on_cancel_callback is not None: + logger.error( + "Lease renewal failed %d times for task %s — signalling cancellation", + on_failure_count, + task_id, + ) + on_cancel_callback.set() + break + + +async def _wait_for_event(event: asyncio.Event) -> None: + """Await an asyncio event. Used with ``wait_for`` for interruptible sleep. + + :param event: The asyncio event to wait for. 
+ :type event: asyncio.Event + """ + await event.wait() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_local_provider.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_local_provider.py new file mode 100644 index 000000000000..979c0eef8570 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_local_provider.py @@ -0,0 +1,676 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Local filesystem-backed resilient task provider. + +Stores tasks as JSON files under +``${AGENTSERVER_STATE_ROOT:-~/.agentserver}/tasks/{agent_name}/{session_id}/`` +(unified storage layout) for local development with +full lifecycle parity. +""" + +from __future__ import annotations + +import datetime +import hashlib +import json +import logging +import os +from pathlib import Path +from typing import Any, Iterable + +from . 
import _validation +from ._attachments import ( + _validate_attachment_count, + _validate_attachment_size, +) +from ._exceptions_internal import TaskNotFound +from ._exceptions_internal import _HostedConflict +from ._models import ( + LeaseInfo, + TaskCreateRequest, + TaskInfo, + TaskPatchRequest, + TaskStatus, +) + +logger = logging.getLogger("azure.ai.agentserver.tasks") + + +class _LocalEtagMismatch(_HostedConflict, ValueError): + """ETag mismatch that preserves legacy local-provider ValueError checks.""" + + +def _now_iso() -> str: + return datetime.datetime.now(datetime.timezone.utc).isoformat() + + +def _generate_etag(data: dict[str, Any]) -> str: + raw = json.dumps(data, sort_keys=True) + return f"local-{hashlib.sha256(raw.encode()).hexdigest()[:16]}" + + +def _is_lease_expired(lease: LeaseInfo | None) -> bool: + if lease is None: + return True + try: + expires = datetime.datetime.fromisoformat(lease.expires_at) + now = datetime.datetime.now(datetime.timezone.utc) + return now >= expires + except (ValueError, TypeError): + return True + + +def _expires_at(duration_seconds: int) -> str: + return (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=duration_seconds)).isoformat() + + +def _invalid_request(message: str, task_id: str | None = None) -> None: + raise _HostedConflict(_code="invalid_request", status_code=400, message=message, task_id=task_id) + + +def _lease_held(task_id: str) -> None: + raise _HostedConflict( + _code="lease_held_by_another", + status_code=409, + message="Lease is held by another owner or instance.", + task_id=task_id, + ) + + +def _etag_mismatch(task_id: str) -> None: + raise _LocalEtagMismatch( + _code="etag_mismatch", + status_code=412, + message="ETag mismatch.", + task_id=task_id, + ) + + +class LocalFileTaskProvider: + """Filesystem-backed provider for local development. + + Tasks are stored as individual JSON files. Lease expiry is simulated + by checking timestamps on read. 
+ + :param base_dir: Root directory for task storage. + Defaults to ``${AGENTSERVER_STATE_ROOT:-~/.agentserver}/tasks`` + via :func:`azure.ai.agentserver.core.storage_paths.resolve_state_subdir`. + :type base_dir: Path | None + """ + + def __init__(self, base_dir: Path | None = None) -> None: + if base_dir is None: + from ..storage_paths import ( # pylint: disable=import-outside-toplevel + resolve_state_subdir, + ) + + base_dir = resolve_state_subdir("tasks") + self._base_dir = base_dir + + def _task_dir(self, agent_name: str, session_id: str) -> Path: + return self._base_dir / agent_name / session_id + + def _task_path(self, agent_name: str, session_id: str, task_id: str) -> Path: + return self._task_dir(agent_name, session_id) / f"{task_id}.json" + + def _find_task_path(self, task_id: str) -> Path | None: + """Search all agent/session dirs for a task file. + + :param task_id: The task identifier. + :type task_id: str + :return: The path to the task file, or None. + :rtype: ~pathlib.Path | None + """ + if not self._base_dir.exists(): + return None + for agent_dir in self._base_dir.iterdir(): + if not agent_dir.is_dir(): + continue + for session_dir in agent_dir.iterdir(): + if not session_dir.is_dir(): + continue + path = session_dir / f"{task_id}.json" + if path.exists(): + return path + return None + + def _iter_task_paths(self, agent_name: str | None, session_id: str | None) -> Iterable[Path]: + if not self._base_dir.exists(): + return [] + if agent_name is not None and session_id is not None: + task_dir = self._task_dir(agent_name, session_id) + return task_dir.glob("*.json") if task_dir.exists() else [] + if agent_name is not None: + agent_dir = self._base_dir / agent_name + if not agent_dir.exists(): + return [] + return ( + path + for session_dir in agent_dir.iterdir() + if session_dir.is_dir() + for path in session_dir.glob("*.json") + ) + if session_id is not None: + return ( + path + for agent_dir in self._base_dir.iterdir() + if agent_dir.is_dir() + for 
session_dir in agent_dir.iterdir() + if session_dir.is_dir() and session_dir.name == session_id + for path in session_dir.glob("*.json") + ) + return ( + path + for agent_dir in self._base_dir.iterdir() + if agent_dir.is_dir() + for session_dir in agent_dir.iterdir() + if session_dir.is_dir() + for path in session_dir.glob("*.json") + ) + + def _read_task(self, path: Path) -> TaskInfo | None: + if not path.exists(): + return None + try: + data = json.loads(path.read_text(encoding="utf-8")) + return TaskInfo.from_dict(data) + except (json.JSONDecodeError, KeyError): + logger.warning("Corrupt task file: %s", path) + return None + + def _write_task(self, task: TaskInfo) -> None: + path = self._task_path(task.agent_name, task.session_id, task.id) + path.parent.mkdir(parents=True, exist_ok=True) + data = task.to_dict() + data["etag"] = _generate_etag(data) + task.etag = data["etag"] + path.write_text(json.dumps(data, indent=2), encoding="utf-8") + + @staticmethod + def _validate_create_request(request: TaskCreateRequest, task_id: str) -> str: + _validation.validate_task_id(task_id) + _validation.validate_required_string(request.agent_name, "agent_name", _validation.MAX_AGENT_NAME_LEN) + _validation.validate_required_string(request.session_id, "session_id", _validation.MAX_SESSION_ID_LEN) + _validation.validate_required_string(request.title, "title", _validation.MAX_TITLE_LEN) + _validation.validate_optional_string(request.description, "description", _validation.MAX_DESCRIPTION_LEN) + _validation.validate_tags(request.tags) + _validation.validate_payload_size(request.payload) + _validation.validate_source(request.source) + _validation.validate_attachment_keys(request.attachments) + try: + return _validation.validate_create_status(request.status) + except _HostedConflict: + # A few local-only recovery tests seed terminal/suspended records + # directly through the provider. Preserve that legacy seeding path + # while still rejecting the reserved "failed" input status. 
+ if request.status == "failed": + raise + return _validation.validate_patch_status(request.status) or "pending" + + @staticmethod + def _validate_create_attachments(task_id: str, attachments: dict[str, Any] | None) -> dict[str, Any] | None: + if attachments is None: + return None + additions = sum(1 for value in attachments.values() if value is not None) + _validate_attachment_count(task_id=task_id, current_count=0, additions=additions) + for key, value in attachments.items(): + _validate_attachment_size(task_id=task_id, attachment_key=key, value=value) + created = {key: value for key, value in attachments.items() if value is not None} + return created or None + + async def create(self, request: TaskCreateRequest) -> TaskInfo: + """Create a new task as a JSON file. + + :param request: Task creation parameters. + :type request: TaskCreateRequest + :return: The created task record. + :rtype: TaskInfo + """ + now = _now_iso() + task_id = request.id or f"task-{os.urandom(8).hex()}" + status = self._validate_create_request(request, task_id) + lease_request = _validation.validate_lease_params( + request.lease_owner, + request.lease_instance_id, + request.lease_duration_seconds, + ) + + if status == "pending" and lease_request is not None: + _invalid_request( + "lease_owner, lease_instance_id, and lease_duration_seconds must " + "not be provided when status is pending.", + task_id, + ) + if self._find_task_path(task_id) is not None: + raise _HostedConflict( + _code="task_already_exists", + status_code=409, + message=f"Task {task_id!r} already exists.", + task_id=task_id, + ) + + lease: LeaseInfo | None = None + started_at: str | None = None + completed_at: str | None = now if status == "completed" else None + if lease_request is not None: + owner, instance_id, duration_seconds = lease_request + lease = LeaseInfo( + owner=owner, + instance_id=instance_id, + generation=0, + expires_at=_expires_at(duration_seconds), + expiry_count=0, + heartbeat_at=now, + ) + if status == 
"in_progress": + started_at = now + + task = TaskInfo( + id=task_id, + agent_name=request.agent_name, + session_id=request.session_id, + status=status, # type: ignore[arg-type] + title=request.title, + description=request.description, + lease=lease, + payload=request.payload, + tags=request.tags, + source=request.source, + attachments=self._validate_create_attachments(task_id, request.attachments), + created_at=now, + updated_at=now, + started_at=started_at, + completed_at=completed_at, + ) + self._write_task(task) + logger.debug("Created local task %s", task_id) + return task + + async def get(self, task_id: str) -> TaskInfo | None: + """Get a task by ID from the filesystem. + + :param task_id: The task identifier. + :type task_id: str + :return: The task record, or ``None`` if not found. + :rtype: TaskInfo | None + """ + path = self._find_task_path(task_id) + if path is None: + return None + return self._read_task(path) + + @staticmethod + def _reject_immutable_patch_fields(patch: TaskPatchRequest | dict[str, Any], task_id: str) -> None: + for field_name in _validation.IMMUTABLE_PATCH_FIELDS: + if isinstance(patch, dict): + value = patch.get(field_name) + else: + value = getattr(patch, field_name, None) + if value is None: + continue + if field_name == "source": + _validation.validate_source(value) + _invalid_request(f"{field_name} is immutable and cannot be patched.", task_id) + + @staticmethod + def _patch_is_completed_noop( + patch: TaskPatchRequest, + normalized_status: str | None, + lease_request: tuple[str, str, int] | None, + ) -> bool: + return ( + normalized_status in (None, "completed") + and patch.payload is None + and patch.tags is None + and patch.error is None + and patch.suspension_reason is None + and lease_request is None + and patch.attachments is None + and not getattr(patch, "clear_attachments", False) + ) + + @staticmethod + def _lease_matches(lease: LeaseInfo | None, owner: str, instance_id: str) -> bool: + return lease is not None and 
lease.owner == owner and lease.instance_id == instance_id + + @staticmethod + def _apply_lease_acquisition( + task: TaskInfo, + lease_request: tuple[str, str, int], + now: str, + ) -> None: + owner, instance_id, duration_seconds = lease_request + current = task.lease + generation = 0 + expiry_count = 0 + if current is not None: + expired = _is_lease_expired(current) + expiry_count = current.expiry_count + if current.owner == owner and current.instance_id == instance_id: + generation = current.generation + elif current.owner == owner: + generation = current.generation + 1 + if expired: + expiry_count = current.expiry_count + 1 + elif expired: + generation = current.generation + 1 + expiry_count = current.expiry_count + 1 + else: + _lease_held(task.id) + + task.lease = LeaseInfo( + owner=owner, + instance_id=instance_id, + generation=generation, + expires_at=_expires_at(duration_seconds), + expiry_count=expiry_count, + heartbeat_at=now, + ) + + @staticmethod + def _validate_lease_rules( + task: TaskInfo, + target_status: str, + status_change: bool, + lease_request: tuple[str, str, int] | None, + ) -> None: + if lease_request is None: + if status_change and task.status == "in_progress" and target_status == "pending": + _lease_held(task.id) + return + + owner, instance_id, duration_seconds = lease_request + if status_change and duration_seconds == 0: + _invalid_request( + "lease_duration_seconds=0 cannot be combined with a status change.", + task.id, + ) + if status_change and target_status in {"completed", "suspended"}: + _invalid_request( + "lease parameters cannot be supplied when transitioning to " f"{target_status}.", + task.id, + ) + if status_change and task.status == "in_progress" and target_status == "pending": + if not LocalFileTaskProvider._lease_matches(task.lease, owner, instance_id): + _lease_held(task.id) + if not status_change and duration_seconds > 0 and task.status != "in_progress": + _invalid_request( + "Lease renewal is only allowed when current 
status is in_progress.", + task.id, + ) + if duration_seconds == 0: + if task.lease is None: + _invalid_request("No lease is available to force-expire.", task.id) + if not _is_lease_expired(task.lease) and not LocalFileTaskProvider._lease_matches( + task.lease, owner, instance_id + ): + _lease_held(task.id) + elif task.lease is not None and task.lease.owner != owner and not _is_lease_expired(task.lease): + _lease_held(task.id) + + @staticmethod + def _apply_payload_patch(task: TaskInfo, payload: Any) -> None: + if payload is None: + return + if isinstance(payload, dict): + current = task.payload if isinstance(task.payload, dict) else {} + merged = dict(current) + merged.update(payload) + _validation.validate_payload_size(merged) + task.payload = merged + else: + _validation.validate_payload_size(payload) + task.payload = payload + + @staticmethod + def _apply_tags_patch(task: TaskInfo, tags: dict[str, Any]) -> None: + merged = dict(task.tags or {}) + for key, value in tags.items(): + if value is None: + merged.pop(key, None) + else: + merged[key] = value + _validation.validate_tags(merged) + task.tags = merged or None + + @staticmethod + def _apply_attachments_patch( + task: TaskInfo, + attachments: dict[str, Any] | None, + clear_attachments: bool, + ) -> None: + if clear_attachments: + task.attachments = None + return + if attachments is None: + return + _validation.validate_attachment_keys(attachments) + for key, value in attachments.items(): + _validate_attachment_size(task_id=task.id, attachment_key=key, value=value) + merged = dict(task.attachments or {}) + for key, value in attachments.items(): + if value is None: + merged.pop(key, None) + else: + merged[key] = value + _validate_attachment_count(task_id=task.id, current_count=len(merged), additions=0) + task.attachments = merged or None + + async def update( # pylint: disable=too-many-branches,too-many-statements + self, task_id: str, patch: TaskPatchRequest + ) -> TaskInfo: + """Update a task via PATCH 
semantics. + + :param task_id: The task identifier. + :type task_id: str + :param patch: Fields to update. + :type patch: TaskPatchRequest + :return: The updated task record. + :rtype: TaskInfo + :raises TaskNotFound: If the task does not exist. + """ + path = self._find_task_path(task_id) + if path is None: + raise TaskNotFound(task_id) + + task = self._read_task(path) + if task is None: + raise TaskNotFound(task_id) + + if patch.if_match is not None and patch.if_match != task.etag: + _etag_mismatch(task_id) + + normalized_status = _validation.validate_patch_status(patch.status) + lease_request = _validation.validate_lease_params( + patch.lease_owner, + patch.lease_instance_id, + patch.lease_duration_seconds, + ) + self._reject_immutable_patch_fields(patch, task_id) + _validation.validate_tags(patch.tags) + _validation.validate_payload_size(patch.payload) + _validation.validate_error(patch.error) + normalized_error = _validation.normalize_error(patch.error) + _validation.validate_optional_string( + patch.suspension_reason, + "suspension_reason", + _validation.MAX_SUSPENSION_REASON_LEN, + ) + + if getattr(patch, "clear_attachments", False) and patch.attachments is not None: + _invalid_request("clear_attachments cannot be combined with attachments patch.", task_id) + + target_status = normalized_status or task.status + if patch.suspension_reason is not None and target_status != "suspended": + _invalid_request( + "suspension_reason is only allowed when target status is suspended.", + task_id, + ) + + if task.status == "completed": + if self._patch_is_completed_noop(patch, normalized_status, lease_request): + return task + raise _HostedConflict( + _code="task_immutable", + status_code=409, + message="Completed tasks are immutable.", + task_id=task_id, + ) + + status_change = normalized_status is not None and normalized_status != task.status + if status_change: + _validation.validate_transition(task.status, target_status) + self._validate_lease_rules(task, 
target_status, status_change, lease_request) + + now = _now_iso() + if status_change: + task.status = target_status # type: ignore[assignment] + if target_status == "pending": + task.lease = None + task.suspension_reason = None + elif target_status == "in_progress": + if lease_request is not None: + self._apply_lease_acquisition(task, lease_request, now) + if task.started_at is None: + task.started_at = now + task.suspension_reason = None + task.completed_at = None + elif target_status == "completed": + task.lease = None + task.suspension_reason = None + if task.completed_at is None: + task.completed_at = now + elif target_status == "suspended": + task.lease = None + task.suspension_reason = patch.suspension_reason + task.completed_at = None + elif lease_request is not None: + _, _, duration_seconds = lease_request + if duration_seconds == 0: + assert task.lease is not None + task.lease.expires_at = now + task.lease.heartbeat_at = now + else: + self._apply_lease_acquisition(task, lease_request, now) + + self._apply_payload_patch(task, patch.payload) + if patch.tags is not None: + self._apply_tags_patch(task, patch.tags) + self._apply_attachments_patch( + task, + patch.attachments, + getattr(patch, "clear_attachments", False), + ) + if normalized_error is not None: + task.error = normalized_error + if not status_change and patch.suspension_reason is not None: + task.suspension_reason = patch.suspension_reason + + task.updated_at = now + self._write_task(task) + return task + + async def delete( + self, + task_id: str, + *, + force: bool = False, + cascade: bool = False, # pylint: disable=unused-argument + if_match: str | None = None, + ) -> None: + """Delete a task JSON file. + + :param task_id: The task identifier. + :type task_id: str + :keyword force: Required for non-terminal tasks. + :paramtype force: bool + :keyword cascade: Delete dependent tasks (no-op for local). + :paramtype cascade: bool + :keyword if_match: ETag precondition for delete. 
+ :paramtype if_match: str | None + """ + path = self._find_task_path(task_id) + if path is None: + raise TaskNotFound(task_id) + task = self._read_task(path) + if task is None: + raise TaskNotFound(task_id) + if if_match is not None and if_match != task.etag: + _etag_mismatch(task_id) + if task.status != "completed" and not force: + _invalid_request("Non-terminal tasks require force=true for deletion.", task_id) + path.unlink(missing_ok=True) + logger.debug("Deleted local task %s", task_id) + + async def list( + self, + *, + agent_name: str | None = None, + session_id: str | None = None, + status: TaskStatus | str | None = None, + lease_owner: str | None = None, + tag: dict[str, str] | None = None, + source_type: str | None = None, + has_error: bool | None = None, + lease_expired: bool | None = None, + limit: int | None = None, + after: str | None = None, + before: str | None = None, + order: str | None = None, + omit_attachment_values: bool = False, + ) -> list[TaskInfo]: + """List tasks from the filesystem.""" + if before is not None: + _invalid_request("before is not supported for task list.") + page_size = 20 if limit is None else limit + if page_size <= 0: + _invalid_request("limit must be greater than 0.") + page_size = min(page_size, 100) + sort_order = order or "desc" + if sort_order not in {"asc", "desc"}: + _invalid_request("order must be 'asc' or 'desc'.") + normalized_status = _validation.normalize_legacy_status(status) + + results: list[TaskInfo] = [] + for path in self._iter_task_paths(agent_name, session_id): + task = self._read_task(path) + if task is None: + continue + if agent_name is not None and task.agent_name != agent_name: + continue + if session_id is not None and task.session_id != session_id: + continue + if normalized_status is not None and task.status != normalized_status: + continue + if lease_owner is not None: + if task.lease is None or task.lease.owner != lease_owner: + continue + if tag is not None: + task_tags = task.tags or {} + 
if not all(task_tags.get(key) == value for key, value in tag.items()): + continue + if source_type is not None: + task_source = task.source or {} + if task_source.get("type") != source_type: + continue + if has_error is not None and bool(task.error) != has_error: + continue + if lease_expired is not None and _is_lease_expired(task.lease) != lease_expired: + continue + results.append(task) + + results.sort(key=lambda item: item.created_at or "", reverse=sort_order == "desc") + if after is not None: + for index, task in enumerate(results): + if task.id == after: + results = results[index + 1 :] + break + else: + results = [] + results = results[:page_size] + if omit_attachment_values: + for task in results: + if task.attachments is not None: + task.attachments = {key: None for key in task.attachments} + return results diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_manager.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_manager.py new file mode 100644 index 000000000000..5848672e8176 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_manager.py @@ -0,0 +1,3493 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""TaskManager — lifecycle orchestration for resilient tasks. + +Manages task creation, lease acquisition, execution, recovery, and +shutdown. One instance per ``AgentServerHost``, accessed via the +module-level ``get_task_manager()`` function. 
+""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +import logging +import traceback +from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import Any, Optional, TypeVar + +from .._config import AgentConfig +from ._client import TransportClassifiedError +from ._context import EntryMode, TaskContext +from ._attachments import ( + _FUNCTION_INPUT_KEY, + _INPUT_THRESHOLD_BYTES, + _MAX_ATTACHMENT_SIZE_BYTES, + _is_ref, + _make_ref, + _read_input_value, + _ref_key, + _remap_attachment_error, + _resolve_input_storage, + _serialized_size_bytes, +) +from ._decorator import TaskOptions, _deserialize_input, _serialize_input +from ._exceptions import ( + EtagConflict, + OutputTooLarge, + TaskConflictError, + TaskFailed, + TaskNotFound, + _AttachmentTooLarge, +) +from ._exceptions_internal import _HostedConflict, _translate_hosted_conflict +from ._lease import derive_lease_owner, generate_instance_id, lease_renewal_loop +from ._metadata import TaskMetadata +from ._models import TaskCreateRequest, TaskInfo, TaskPatchRequest, TaskStatus +from ._provider import TaskProvider +from ._retry import RetryPolicy +from ._run import TaskRun +from .._version import VERSION as _CORE_VERSION +from .._server_version import build_server_version as _build_server_version + +logger = logging.getLogger("azure.ai.agentserver.tasks") + +#: Auto-stamped source type for all tasks created by this framework. +_SOURCE_TYPE = "agentserver.task" + +#: Reserved tag key for task name filtering via the LIST API. +_TAG_TASK_NAME = "_task_name" + +#: Default lease TTL. The per-task +#: ``lease_duration_seconds`` knob was demoted (no developer use case justified +#: exposing it on ``@task``). This constant is the framework's choice. +_DEFAULT_LEASE_SECONDS = 60 + +#: Pre-computed server version segment for source stamps.
+_SOURCE_SERVER_VERSION = _build_server_version("azure-ai-agentserver-core", _CORE_VERSION) + +Input = TypeVar("Input") +Output = TypeVar("Output") + +# Module-level manager singleton +_manager: TaskManager | None = None + + +def _is_evicted(exc: BaseException) -> bool: + """Return True if ``exc`` is the eviction-classified rejection. + + Helper used by every store-write call site that must + funnel through the local-cleanup sequence on + orphan-sandbox eviction. The HostedTaskProvider raises + ``TransportClassifiedError(classification="evicted")`` after the + pipeline classifier maps an HTTP 409 + ``binding_mismatch`` body; + in-test stubs raise the same typed exception so the framework's + cleanup runs identically against both. + + :param exc: The exception to classify. + :type exc: BaseException + :return: True if the exception is an eviction-classified rejection. + :rtype: bool + """ + return isinstance(exc, TransportClassifiedError) and getattr(exc, "classification", None) == "evicted" + + +# Layer 2 recovery +# periodic background scan interval. Module-level constant so tests +# can monkey-patch it to a small value for deterministic exercise +# without adding a public surface to TaskManager. Default ~300s +# matches the spec's "internal-only interval" requirement. +_PERIODIC_RECOVERY_INTERVAL_SECONDS: float = 300.0 + +# Bounded retry budget for the +# transient-error path in the startup scan / inline reclaim. +# Exponential backoff: 0.2 → 0.4 → 0.8 across attempts 1..3. +_RECLAIM_MAX_RETRIES: int = 3 +_RECLAIM_BACKOFF_BASE_SECONDS: float = 0.2 + +# SOT top-level payload field +# storing the ISO-8601 UTC timestamp of when the current turn started. +# Persisted at every turn-start boundary (fresh entry, +# suspended-to-in_progress resume, steering drain re-entry); NOT +# re-stamped on crash recovery so the watchdog can compute remaining +# budget = max(0, opts.timeout - (now - _turn_started_at)). 
+_TURN_STARTED_AT_KEY: str = "_turn_started_at" + + +def _utc_now_iso() -> str: + """Return current UTC time as an ISO-8601 string with Z suffix. + + Persisted turn-start timestamps use this format. + Z suffix matches `datetime.fromisoformat`'s expectations from + Python 3.11+ (older Pythons need the `+00:00` form). + + :return: An ISO-8601 UTC timestamp ending in ``Z``. + :rtype: str + """ + from datetime import datetime, timezone # pylint: disable=import-outside-toplevel + + return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z" + + +def _parse_turn_started_at(value: Any) -> float | None: + """Parse a persisted ``_turn_started_at`` value to a POSIX timestamp. + + Returns ``None`` if the value is missing, malformed, or empty — + the caller falls back to "spawn watchdog with full budget" in + that case (graceful degradation during the rollout window where + older records may not have the field yet). + + :param value: Raw persisted value (typically a string). + :type value: Any + :return: POSIX timestamp, or ``None`` if the value is invalid. + :rtype: float | None + """ + from datetime import datetime, timezone # pylint: disable=import-outside-toplevel + + if not value or not isinstance(value, str): + return None + try: + normalized = value.replace("Z", "+00:00") if value.endswith("Z") else value + dt = datetime.fromisoformat(normalized) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.timestamp() + except (ValueError, TypeError): + return None + + +def _resolve_queued_steerers_on_terminal( + pending_steering_futures: dict[str, list["asyncio.Future[Any]"]], + task_id: str, + *, + current_status: str, +) -> None: + """(Subscriber) helper. 
+ + When a steerable task terminates (handler returned a value or + raised), any callers that queued a steering input via + ``.start()`` (and got back a TaskRun bound to a future from + ``_pending_steering_futures``) MUST receive ``TaskConflictError`` + on their ``.result()`` — the same shape a fresh ``.start()`` + against an already-terminal task would raise. + + Pops every queued steerer future for ``task_id`` and resolves + each with ``TaskConflictError(task_id, current_status)``. + + :param pending_steering_futures: Per-task list of pending steerer + futures (mutated in-place — emptied for the given ``task_id``). + :type pending_steering_futures: dict[str, list[asyncio.Future[Any]]] + :param task_id: The task whose queued steerers should be resolved. + :type task_id: str + :keyword current_status: Status string to carry on + ``TaskConflictError`` so callers can branch. + :paramtype current_status: str + """ + # TaskConflictError is already imported at module top level. + + queued = pending_steering_futures.pop(task_id, []) + for fut in queued: + if not fut.done(): + fut.set_exception(TaskConflictError(task_id, current_status)) + + +def _lease_is_dead( + task_info: Any, + *, + this_lease_owner: str, + active_locally: bool, +) -> bool: + """Determine whether an in-progress record's lease is dead. + + A lease is "live" only if EITHER ownership + matches this process AND an in-memory active entry tracks it (so we + know the local execution is running), OR the lease ownership belongs + to this process AND the expiry has not passed. + + "Dead" means the framework should reclaim. "Live" means the record + is currently being executed (here or elsewhere) and the + caller should observe the conflict shape. + + Because the lease owner includes agent_name + session_id, a record + whose owner differs from ours belongs to a different agent — the + framework MUST NOT reclaim it (that would steal another agent's + work). 
Such records appear "dead from this process's perspective" + but should NOT be subject to reclaim; the scheduling primitive + raises TaskConflictError instead. + + For the LocalFileTaskProvider used in tests (no real expiry + tracking), absence of a local in-memory entry combined with + matching ownership suffices to detect a previous-lifetime crash. + + :param task_info: The persisted task record (any object exposing + ``lease.owner`` and ``lease.expires_at``). + :type task_info: Any + :keyword this_lease_owner: Lease-owner string for this process. + :paramtype this_lease_owner: str + :keyword active_locally: True if this process has an in-memory + ``_ActiveTask`` entry tracking the record. + :paramtype active_locally: bool + :return: True if the lease is dead AND eligible for reclaim by us. + :rtype: bool + """ + if active_locally: + # We are actively executing it; lease is definitely live in + # this process. + return False + # TaskInfo carries lease state as a nested LeaseInfo object. + lease = getattr(task_info, "lease", None) + owner = getattr(lease, "owner", None) if lease is not None else None + owner = owner or "" + # Owner matches ours but no local in-memory entry → previous + # lifetime owned by THIS (agent, session) pair crashed; lease + # is dead and eligible for reclaim. + if owner and owner == this_lease_owner: + return True + # Foreign owner: this record belongs to a different agent OR a + # different session. We MUST NOT reclaim it. Caller + # observes the live-elsewhere conflict shape. + if owner and owner != this_lease_owner: + return False + # No owner recorded — treat as dead since no live executor + # claims it. (Empty owner happens for freshly-created records + # before lease assignment.) + return True + + +def get_task_manager() -> TaskManager: + """Return the active TaskManager singleton. + + :raises RuntimeError: If no manager has been initialized. + :return: The active manager. 
+ :rtype: TaskManager + """ + if _manager is None: + raise RuntimeError( + "TaskManager not initialized. Ensure resilient tasks " + "are enabled on the AgentServerHost." # pylint: disable=implicit-str-concat + ) + return _manager + + +def set_task_manager(manager: TaskManager | None) -> None: + """Set the module-level TaskManager singleton. + + Called by ``AgentServerHost`` during startup/shutdown. + + :param manager: The manager to set, or ``None`` to clear. + :type manager: TaskManager | None + """ + global _manager # pylint: disable=global-statement + _manager = manager + + +class _ActiveTask: # pylint: disable=too-many-instance-attributes + """In-memory tracking for a running task.""" + + __slots__ = ( + "task_id", + "fn_name", + "context", + "execution_task", + "renewal_task", + "renewal_cancel", + "result_future", + "terminate_event", + "fn", + "input_type", + "opts", + "retry", + "lease_last_refresh_monotonic", + # / — latest known etag for this task. + # Refreshed from every GET/CREATE/PATCH response. Used as + # ``if_match`` on every subsequent PATCH. + "current_etag", + # Spec 031 / FR-002 — live count of queued steering inputs as + # observed by THIS process. Read by ``_make_pending_count_provider`` + # to back ``ctx.pending_input_count``. Written (before ``ctx.cancel`` + # is set, per SOT §13 ordering invariant) by the same-process + # steering enqueue and by the cross-process steering poll. Must be a + # slot or it is unsettable (the historic bug: it was read but never + # storable). 
+ "_pending_input_count", + ) + + def __init__( + self, + task_id: str, + fn_name: str, + context: TaskContext[Any], + execution_task: asyncio.Task[Any], + renewal_task: asyncio.Task[None] | None, + renewal_cancel: asyncio.Event, + result_future: asyncio.Future[Any], + terminate_event: asyncio.Event | None = None, + fn: Callable[..., Awaitable[Any]] | None = None, + input_type: type[Any] | None = None, + opts: TaskOptions | None = None, + retry: RetryPolicy | None = None, + ) -> None: + self.task_id = task_id + self.fn_name = fn_name + self.context = context + self.execution_task = execution_task + self.renewal_task = renewal_task + self.renewal_cancel = renewal_cancel + self.result_future = result_future + self.terminate_event = terminate_event or asyncio.Event() + self.fn = fn + self.input_type = input_type + self.opts = opts + self.retry = retry + # ``asyncio.get_event_loop().time()`` value at the last successful + # lease refresh -- updated by the renewal loop AND by every + # payload PATCH that piggybacks lease ownership (see + # ``_lease_ext_kwargs`` / ``_note_lease_refreshed``). The + # renewal loop reads this to push out its next scheduled tick + # so it doesn't issue a redundant heartbeat the moment after a + # payload PATCH already refreshed the lease. + self.lease_last_refresh_monotonic: float = 0.0 + # — latest known etag, refreshed on every + # store interaction (create response, get response, update response). + # Used as ``if_match`` on subsequent PATCHes. + self.current_etag: str | None = None + # Spec 031 / FR-002 — see __slots__ note. Live in-process count of + # queued steering inputs backing ``ctx.pending_input_count``. + self._pending_input_count: int = 0 + + +class TaskManager: # pylint: disable=too-many-instance-attributes + """Lifecycle orchestrator for resilient tasks. + + Manages provider selection, task creation, lease management, + execution dispatch, crash recovery, and graceful shutdown. + + :param config: Resolved agent configuration. 
+ :type config: AgentConfig + :param provider: Optional explicit provider (for testing). + :type provider: TaskProvider | None + :param shutdown_event: Shared shutdown event from the host. + :type shutdown_event: asyncio.Event | None + :param shutdown_grace_seconds: Seconds to wait for tasks to checkpoint + before force-expiring leases during shutdown. Defaults to 25.0. + :type shutdown_grace_seconds: float + """ + + def __init__( + self, + config: AgentConfig, + *, + provider: TaskProvider | None = None, + shutdown_event: asyncio.Event | None = None, + shutdown_grace_seconds: float = 25.0, + ) -> None: + self._config = config + self._provider = provider or self._create_provider(config) + self._active_tasks: dict[str, _ActiveTask] = {} + self._resume_callbacks: dict[str, Callable[..., Any]] = {} + self._resume_opts: dict[str, TaskOptions] = {} + self._lease_owner = derive_lease_owner( + config.agent_name or "unknown-agent", + config.session_id or "local", + ) + self._instance_id = generate_instance_id() + self._shutdown_event = shutdown_event or asyncio.Event() + self._shutdown_grace_seconds = shutdown_grace_seconds + self._active_generation_future: dict[str, asyncio.Future[Any]] = {} + self._pending_steering_futures: dict[str, list[asyncio.Future[Any]]] = {} + # Layer 2: periodic recovery scan task. Created + # at startup() time; cancelled at shutdown(). + self._periodic_recovery_task: asyncio.Task[None] | None = None + # / C-WQ-1..3 — per-task write-queue + # registry. A single asyncio.Lock per task_id serializes all + # in-process PATCHes against that task so etag conflicts become + # rare (only cross-process). Lazy-created on first use; dropped + # in ``_active_tasks_pop`` (no leaks). + # — also tracks the latest known etag + # per task_id outside the _ActiveTask entry, so reclaim/scan + # paths (which have no _ActiveTask yet) can still benefit. 
+ self._task_write_locks: dict[str, asyncio.Lock] = {} + self._task_etag_cache: dict[str, str] = {} + # SOT §52 — per-turn timeout watchdog registry. Each per-turn + # watchdog gets registered here so that the steering-drain + # re-entry can cancel the prior turn's watchdog and respawn a + # fresh one bound to the new turn's _turn_started_at. Cleared + # on terminal exit. + self._timeout_watchdogs: dict[str, asyncio.Task[None]] = {} + + @staticmethod + def _build_source(fn_name: str) -> dict[str, str]: + """Build the framework-owned source stamp for a task. + + The ``fn_name`` is the developer-provided ``name`` from the decorator + (or ``fn.__qualname__`` when omitted). It serves as the **stable + identity anchor** — recovery routing matches ``source.name`` against + registered callbacks to dispatch recovered tasks back to the correct + function. + + :param fn_name: The task name (from ``@task(name=...)``). + :type fn_name: str + :return: Source metadata dict. + :rtype: dict[str, str] + """ + return { + "type": _SOURCE_TYPE, + "name": fn_name, + "server_version": _SOURCE_SERVER_VERSION, + } + + @staticmethod + def _create_provider(config: AgentConfig) -> TaskProvider: + """Auto-select provider based on hosting environment. + + In hosted environments (``FOUNDRY_HOSTING_ENVIRONMENT`` is set), + the HTTP-backed ``HostedTaskProvider`` is used by default — the + hosted task-storage API is what makes resilient recovery, + cross-instance lease handoff, and the platform's lease/readiness + keep-alive path work. + + In non-hosted environments (local dev, tests), the + ``LocalFileTaskProvider`` is used — file-backed under + ``${AGENTSERVER_STATE_ROOT:-~/.agentserver}/tasks/``. This keeps + the local development loop self-contained with no external + dependencies. + + **Operator override** — set ``AGENTSERVER_TASKS_BACKEND=local`` + to force the file-backed provider even in hosted environments. 
+ This is useful for repro / debugging hosted-only scenarios on a + local workstation without standing up the hosted task API, and + for hosted environments where operators want to opt out of the + task-storage API (e.g. running the hosted runtime with disk + persistence only). + + :param config: The agent configuration. + :type config: AgentConfig + :return: The storage provider instance. + :rtype: TaskProvider + """ + import os # pylint: disable=import-outside-toplevel + + backend_override = os.environ.get("AGENTSERVER_TASKS_BACKEND", "").strip().lower() + if backend_override and backend_override not in ("local", "hosted"): + raise ValueError(f"AGENTSERVER_TASKS_BACKEND must be 'local' or 'hosted' (got {backend_override!r})") + + use_hosted = config.is_hosted if not backend_override else (backend_override == "hosted") + + if use_hosted: + from ._client import ( # pylint: disable=import-outside-toplevel + HostedTaskProvider, + ) + + try: + from azure.identity.aio import ( # type: ignore[import-untyped] + DefaultAzureCredential, + ) + except ImportError as exc: + raise ImportError( + "azure-identity is required for hosted mode. " + "Install with: pip install azure-ai-agentserver-core[hosted]" + ) from exc + + logger.info("Hosted environment detected; using HostedTaskProvider") + return HostedTaskProvider( + project_endpoint=config.project_endpoint, + credential=DefaultAzureCredential(), + ) + + from ._local_provider import ( # pylint: disable=import-outside-toplevel + LocalFileTaskProvider, + ) + from ..storage_paths import ( # pylint: disable=import-outside-toplevel + resolve_state_subdir, + ) + + if backend_override == "local" and config.is_hosted: + logger.info("AGENTSERVER_TASKS_BACKEND=local overrides hosted detection; " "using LocalFileTaskProvider") + + # Resolve the tasks subdirectory via the + # unified storage-paths helper. ``AGENTSERVER_STATE_ROOT`` is + # the single env-var operator knob covering tasks / streams / + # responses. 
The legacy ``AGENTSERVER_STATE_TASKS_PATH`` env + # var is deleted (was: per-subsystem override). + return LocalFileTaskProvider(base_dir=resolve_state_subdir("tasks")) + + @property + def provider(self) -> TaskProvider: + """The storage provider. + + :return: The active provider. + :rtype: TaskProvider + """ + return self._provider + + def register_resume_callback( + self, + fn_name: str, + fn: Callable[..., Any], + opts: TaskOptions | None = None, + ) -> None: + """Register a function as a resume callback. + + :param fn_name: The resilient task function name. + :type fn_name: str + :param fn: The async function to call on resume. + :type fn: Callable[..., Any] + :param opts: The task options (opts subset). + :type opts: TaskOptions | None + """ + self._resume_callbacks[fn_name] = fn + if opts is not None: + self._resume_opts[fn_name] = opts + + async def list_tasks( + self, + *, + fn_name: str, + session_id: str | None = None, + status: TaskStatus | None = None, + ) -> list[TaskInfo]: + """List tasks scoped to a specific task function. + + Uses server-side filtering (``agent_name``, ``session_id``, + ``_task_name`` tag, ``status``, ``source_type``) to return only + tasks created by this framework for the given function. + + :keyword fn_name: The task function name (stable identity anchor). + :paramtype fn_name: str + :keyword session_id: Session scope override. Defaults to config. + :paramtype session_id: str | None + :keyword status: Filter by task status. + :paramtype status: ~azure.ai.agentserver.core.tasks.TaskStatus | None + :return: Matching task records. 
+ :rtype: list[TaskInfo] + """ + resolved_session = session_id or self._config.session_id or "local" + agent_name = self._config.agent_name or "default" + + # All filters are now server-side + try: + return await self._provider.list( + agent_name=agent_name, + session_id=resolved_session, + status=status, + tag={_TAG_TASK_NAME: fn_name}, + source_type=_SOURCE_TYPE, + ) + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc) + if translated is None: + raise RuntimeError("Task list did not converge after retryable conflict") from exc + raise translated from exc + + def _register_steering_future(self, task_id: str) -> asyncio.Future[Any]: + """Create and register a future for a queued steering input. + + Must be called BEFORE ``_append_steering_input()`` to avoid a race + where the drain pops the queue before the future exists. + + :param task_id: The task identifier. + :type task_id: str + :return: The registered future. + :rtype: asyncio.Future[Any] + """ + loop = asyncio.get_event_loop() + future: asyncio.Future[Any] = loop.create_future() + if task_id not in self._pending_steering_futures: + self._pending_steering_futures[task_id] = [] + self._pending_steering_futures[task_id].append(future) + return future + + async def _cancel_queued_steering_input( + self, + *, + task_id: str, + future: asyncio.Future[Any], + input_id: str | None, + input_val: Any, + ) -> None: + """Remove a queued steering input from the chain's pending queue. + + Invoked by :meth:`TaskRun.cancel` when called on a handle bound to + a queued (not-yet-promoted) steering input. The associated entry + in ``payload["_steering"]["pending_inputs"]`` is removed, the + corresponding ``_steering_input_`` attachment (if any) is + deleted, and the queued steerer's future is resolved with + ``TaskCancelled``. The active turn (if any) is not affected. + + :keyword task_id: The chain task identifier. + :keyword future: The queued steerer's result_future. 
+ :keyword input_id: The input_id of the queued slot (used for the + future-list cleanup; the queue entry itself is identified by + ``input_val``). + :keyword input_val: The raw queued value used to identify which + ``pending_inputs`` entry to remove. + """ + from ._attachments import _is_ref, _ref_key # pylint: disable=import-outside-toplevel + from ._exceptions import TaskCancelled # pylint: disable=import-outside-toplevel + + async with self._get_task_write_lock(task_id): + try: + task_info = await self._provider_get_tracked(task_id) + except Exception: # pylint: disable=broad-exception-caught + task_info = None + if task_info is None or not task_info.payload: + # Chain already gone — just resolve the future. + if not future.done(): + future.set_exception(TaskCancelled()) + return + steering = dict(task_info.payload.get("_steering") or {}) + pending = list(steering.get("pending_inputs") or []) + attachments_patch: dict[str, Any] = {} + # Drop the first queue entry whose raw value matches ``input_val``. + removed = False + new_pending: list[Any] = [] + for entry in pending: + if not removed: + raw = entry + if _is_ref(entry): + # For ref-shaped entries, resolve via attachment to + # compare against input_val. If the attachment is + # missing, fall back to ref identity (unlikely). + key = _ref_key(entry) + raw = (task_info.attachments or {}).get(key, entry) + if raw == input_val: + removed = True + if _is_ref(entry): + attachments_patch[_ref_key(entry)] = None + continue + new_pending.append(entry) + if not removed: + # Queue entry already drained or never landed; just resolve. 
+ if not future.done(): + future.set_exception(TaskCancelled()) + return + steering["pending_inputs"] = new_pending + steering["cancel_requested"] = len(new_pending) > 0 + payload_patch: dict[str, Any] = {"_steering": steering} + try: + # Spec 031 / FR-005a+b: the outer lock is already held, so use + # the lock-held update primitive (avoids re-entrant lock + # acquisition) which carries the tracked ``if_match`` — no blind + # writes (SOT §25.1). ``task_info`` was read inside this same + # lock above, so the tracked etag is current. + await self._provider_update_lock_held( + task_id, + TaskPatchRequest( + payload=payload_patch, + attachments=attachments_patch or None, + **self._lease_ext_kwargs(task_id), + ), + ) + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to remove queued steering input from task %s; " + "future will still be resolved with TaskCancelled", + task_id, + exc_info=True, + ) + # Remove the future from the registered pending list and resolve it. + pending_list = self._pending_steering_futures.get(task_id) or [] + if future in pending_list: + pending_list.remove(future) + if not future.done(): + future.set_exception(TaskCancelled()) + + async def startup(self) -> None: + """Initialize the manager and recover stale tasks. + + Called by ``AgentServerHost`` during lifespan startup. + """ + logger.info( + "TaskManager starting (owner=%s, instance=%s, hosted=%s)", + self._lease_owner, + self._instance_id, + self._config.is_hosted, + ) + # Pick up descriptors registered at import time (for recovery) + from ._decorator import ( # pylint: disable=import-outside-toplevel + _REGISTERED_DESCRIPTORS, + ) + + for fn_name, fn, opts in _REGISTERED_DESCRIPTORS: + self._resume_callbacks[fn_name] = fn + self._resume_opts[fn_name] = opts + + await self._recover_stale_tasks() + + # Layer 2: start the periodic recovery task. 
+ # Reads _PERIODIC_RECOVERY_INTERVAL_SECONDS at spawn time; + # tests monkey-patch the constant to drive the scan + # deterministically. + try: + loop = asyncio.get_running_loop() + self._periodic_recovery_task = loop.create_task(self._periodic_recovery_loop()) + except RuntimeError: + # No running loop (called from outside async context); skip + # — the layer-1 startup scan above still covered the + # initial reclaim pass. + pass + + async def _periodic_recovery_loop(self) -> None: + """Layer 2: periodic background recovery scan. + + Runs at the interval defined by ``_PERIODIC_RECOVERY_INTERVAL_SECONDS`` + (monkey-patchable for tests). Each iteration calls + :meth:`_recover_stale_tasks` and tolerates per-record + exceptions so a single failed reclaim does not break the + scan. Exits cleanly when ``_shutdown_event`` is set or the + task is cancelled. + """ + while not self._shutdown_event.is_set(): + try: + await asyncio.wait_for( + self._shutdown_event.wait(), + timeout=_PERIODIC_RECOVERY_INTERVAL_SECONDS, + ) + # shutdown_event was set — exit + return + except asyncio.TimeoutError: + pass + except asyncio.CancelledError: + return + try: + await self._recover_stale_tasks() + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Periodic recovery scan iteration failed", exc_info=True) + + async def shutdown(self) -> None: + """Signal shutdown on all active tasks and force-expire leases. + + Called by ``AgentServerHost`` during lifespan shutdown. + """ + logger.info("TaskManager shutting down") + self._shutdown_event.set() + + # Layer 2: stop the periodic recovery scan task. + # Cancel cleanly so the shutdown event in its sleep wakes + # immediately and the task exits. 
+ if self._periodic_recovery_task is not None: + self._periodic_recovery_task.cancel() + try: + await self._periodic_recovery_task + except ( + asyncio.CancelledError, + Exception, + ): # pylint: disable=broad-exception-caught + pass + self._periodic_recovery_task = None + + # Signal shutdown on all active contexts. Yield once so the bridge + # tasks (running in the event loop) get a chance to observe the + # shutdown event and notify their handlers before we proceed — + # otherwise on a fast lifespan teardown the shutdown grace sleep + # may be cancelled before the bridge has had a chance to fire. + for active in self._active_tasks.values(): + active.context.shutdown.set() + if self._active_tasks: + await asyncio.sleep(0) + + # Wait for tasks to checkpoint before force-expiring leases. + # On a forced lifespan teardown (e.g., HTTP test client closing) the + # sleep can be cancelled — that's fine, fall through to force-expire + # and execution_task.cancel() below so handlers wind down. + # + # Poll for ``_active_tasks`` becoming empty rather than + # an unconditional sleep so the shutdown returns promptly when + # all task bodies have checkpointed. The grace value is the + # MAXIMUM wait, not the minimum — without polling, a 25s default + # blocks every shutdown for the full window even when tasks are + # already done. + if self._active_tasks: + deadline = asyncio.get_event_loop().time() + self._shutdown_grace_seconds + try: + while self._active_tasks: + if asyncio.get_event_loop().time() >= deadline: + break + # Drop entries whose execution_task already completed + # so we don't keep waiting for them. + self._active_tasks = { + task_id: active + for task_id, active in self._active_tasks.items() + if not active.execution_task.done() + } + if not self._active_tasks: + break + await asyncio.sleep(0.05) + except asyncio.CancelledError: + logger.info("TaskManager shutdown grace period interrupted") + + # Force-expire all leases. Tolerate cancellation here too. 
+ try: + for active in list(self._active_tasks.values()): + try: + await self._provider.update( + active.task_id, + TaskPatchRequest( + lease_owner=self._lease_owner, + lease_instance_id=self._instance_id, + lease_duration_seconds=0, + ), + ) + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to force-expire lease for task %s", + active.task_id, + exc_info=True, + ) + except asyncio.CancelledError: + logger.info("TaskManager shutdown lease-expire interrupted; " "continuing to in-process task cancellation") + + # Cancel all renewal and execution tasks. Always do this so handlers + # listening on the cancellation signal wake up and exit cleanly. + for active in self._active_tasks.values(): + active.renewal_cancel.set() + if active.renewal_task and not active.renewal_task.done(): + active.renewal_task.cancel() + if not active.execution_task.done(): + active.execution_task.cancel() + + self._active_tasks.clear() + set_task_manager(None) + + async def create_and_run( + self, + *, + fn: Callable[..., Awaitable[Any]], + fn_name: str, + task_id: str, + input_val: Any, + input_type: type[Any], + session_id: str | None, + title: str, + tags: dict[str, str], + opts: TaskOptions, + retry: RetryPolicy | None = None, + entry_mode: EntryMode = "fresh", + ) -> Any: + """Create a task, run the function, and return the result. + + :keyword fn: The async function to execute. + :paramtype fn: Callable[..., Awaitable[Any]] + :keyword fn_name: The registered function name. + :paramtype fn_name: str + :keyword task_id: Unique task identifier. + :paramtype task_id: str + :keyword input_val: The input value. + :paramtype input_val: Any + :keyword input_type: The input type. + :paramtype input_type: type[Any] + :keyword session_id: Session scope. + :paramtype session_id: str | None + :keyword tags: Task tags. + :paramtype tags: dict[str, str] + :keyword opts: Task options. + :paramtype opts: TaskOptions + :keyword entry_mode: Entry mode. 
+ :paramtype entry_mode: EntryMode + :keyword retry: Retry policy. + :paramtype retry: RetryPolicy | None + :keyword title: Human-readable title. + :paramtype title: str + :returns: The function's return value. + :rtype: Any + :raises TaskFailed: On unhandled exception. + """ + handle = await self.create_and_start( + fn=fn, + fn_name=fn_name, + task_id=task_id, + input_val=input_val, + input_type=input_type, + session_id=session_id, + title=title, + tags=tags, + opts=opts, + retry=retry, + entry_mode=entry_mode, + ) + return await handle.result() + + async def create_and_start( # pylint: disable=too-many-locals + self, + *, + fn: Callable[..., Awaitable[Any]], + fn_name: str, + task_id: str, + input_val: Any, + input_type: type[Any], # pylint: disable=unused-argument + session_id: str | None, + title: str, + tags: dict[str, str], + opts: TaskOptions, + retry: RetryPolicy | None = None, + entry_mode: EntryMode = "fresh", + initial_payload_extras: dict[str, Any] | None = None, + ) -> TaskRun[Any]: + """Create a task, start the function, and return a handle. + + Source provenance is auto-stamped by the framework using + ``fn_name`` and the core SDK version. + + :keyword fn: The async task function. + :paramtype fn: Callable[..., Awaitable[Any]] + :keyword fn_name: Function name for logging. + :paramtype fn_name: str + :keyword task_id: The task identifier. + :paramtype task_id: str + :keyword input_val: The task input value. + :paramtype input_val: Any + :keyword input_type: Type for deserializing input. + :paramtype input_type: type[Any] + :keyword session_id: Session scope identifier. + :paramtype session_id: str | None + :keyword title: Human-readable task title. + :paramtype title: str + :keyword tags: Merged decorator + call-site tags. + :paramtype tags: dict[str, str] + :keyword opts: Task options. + :paramtype opts: TaskOptions + :keyword retry: Retry policy. + :paramtype retry: RetryPolicy | None + :keyword entry_mode: Why this execution is starting. 
+ :paramtype entry_mode: EntryMode + :keyword initial_payload_extras: + Framework-reserved top-level payload slots (e.g., + ``{"_last_input_id": "msg-1"}``) merged into the initial + payload alongside ``input`` and ``metadata``. Reserved keys + ``input`` and ``metadata`` cannot be overridden via this + channel. + :paramtype initial_payload_extras: dict[str, Any] | None + :return: A ``TaskRun`` handle. + :rtype: TaskRun + """ + resolved_session = session_id or self._config.session_id or "local" + agent_name = self._config.agent_name or "default" + + # Build payload — input is always persisted (the + # per-task `store_input` knob is dropped). Route the + # input through the promotion helper so > 200 KiB inputs spill into + # ``attachments["_input"]`` and ``payload["input"]`` becomes a ref + # slot. The single create-PATCH carries payload + attachments + # together (atomic). + serialized_input = _serialize_input(input_val) + input_mode, input_value = _resolve_input_storage( + serialized_input, + threshold_bytes=_INPUT_THRESHOLD_BYTES, + key_for_attachment=_FUNCTION_INPUT_KEY, + task_id=task_id, + ) + payload: dict[str, Any] = {"input": input_value} + attachments: dict[str, Any] | None = None + if input_mode == "attachment": + attachments = {_FUNCTION_INPUT_KEY: serialized_input} + payload["metadata"] = {} + # Persist a turn-start timestamp at every + # turn-start boundary so the per-turn watchdog can compute + # remaining = max(0, opts.timeout - (now - turn_started_at)) + # across crashes. Field name + format chosen per + # conformance-SOT.md: top-level _turn_started_at, + # ISO-8601 UTC with Z suffix. + payload[_TURN_STARTED_AT_KEY] = _utc_now_iso() + + # Framework-reserved top-level slots + # (e.g., `_last_input_id`) supplied by `Task.start(input_id=...)`. + # Merged shallowly so callers cannot clobber `input` or `metadata`. 
+ if initial_payload_extras: + for k, v in initial_payload_extras.items(): + if k in ("input", "metadata"): + continue + payload[k] = v + + # Auto-stamp source provenance (framework-owned, not user-overridable) + source = self._build_source(fn_name) + + # Auto-stamp task name tag for LIST filtering + if tags is None: + tags = {} + tags[_TAG_TASK_NAME] = fn_name + + # Create task with lease + try: + task_info = await self._provider.create( + TaskCreateRequest( + id=task_id, + agent_name=agent_name, + session_id=resolved_session, + status="in_progress", + title=title, + payload=payload, + tags=tags or None, + source=source, + attachments=attachments, + lease_owner=self._lease_owner, + lease_instance_id=self._instance_id, + lease_duration_seconds=_DEFAULT_LEASE_SECONDS, + ) + ) + except _HostedConflict as exc: + observed_status: str | None = None + if exc._code == "task_already_exists": + try: + observed = await self._provider.get(task_id) + observed_status = getattr(observed, "status", None) if observed else None + except Exception: # pylint: disable=broad-exception-caught + observed_status = None + translated = _translate_hosted_conflict(exc, task_id=task_id, observed_status=observed_status) + if translated is None: + if exc._code == "lease_ownership_changed": + raise TaskConflictError(task_id, "in_progress") from exc + raise RuntimeError(f"Task {task_id!r} create did not converge after retryable conflict") from exc + raise translated from exc + # — track the etag from the create response + # so the next PATCH carries it as if_match. 
+ self._track_etag(task_id, getattr(task_info, "etag", None)) + + logger.info("Created resilient task %s (%s)", task_id, fn_name) + + # Register resume callback + self._resume_callbacks[fn_name] = fn + self._resume_opts[fn_name] = opts + + # Build context + cancel_event = asyncio.Event() + metadata = TaskMetadata( + flush_callback=self._make_metadata_flush(task_id), + ) + + lease_gen = task_info.lease.generation if task_info.lease else 0 + + ctx: TaskContext[Any] = TaskContext( + task_id=task_id, + session_id=resolved_session, + input=input_val, + metadata=metadata, + retry_attempt=0, + recovery_count=lease_gen, + cancel=cancel_event, + shutdown=self._shutdown_event, + entry_mode=entry_mode, + pending_count_provider=self._make_pending_count_provider(task_id), + input_id=(initial_payload_extras or {}).get("_last_input_id"), + ) + loop = asyncio.get_event_loop() + result_future: asyncio.Future[Any] = loop.create_future() + + # Start lease renewal + renewal_cancel = asyncio.Event() + + # Build steering poll callback for steerable tasks + steering_poll_cb_cs: Callable[[], Awaitable[None]] | None = None + if opts.steerable: + + async def _steering_poll_cs() -> None: + active = self._active_tasks.get(task_id) + if active is None or active.context.cancel.is_set(): + return + info = await self._provider_get_tracked(task_id) + if info is None or not info.payload: + return + st = info.payload.get("_steering", {}) + pending = st.get("pending_inputs") or [] + if pending: + # Spec 031 / FR-002 + SOT §13: record the cross-process + # observed count BEFORE setting cancel. 
+ active._pending_input_count = len(pending) + active.context.cancel.set() + + steering_poll_cb_cs = _steering_poll_cs + + renewal_task = asyncio.create_task( + lease_renewal_loop( + self._provider, + task_id, + lease_owner=self._lease_owner, + lease_instance_id=self._instance_id, + lease_duration_seconds=_DEFAULT_LEASE_SECONDS, + cancel_event=renewal_cancel, + on_cancel_callback=cancel_event, + steering_poll_callback=steering_poll_cb_cs, + last_refresh_provider=lambda tid=task_id: ( + self._active_tasks[tid].lease_last_refresh_monotonic if tid in self._active_tasks else 0.0 + ), + # — heartbeat PATCH MUST be routed + # through the per-task write queue so it serializes + # with metadata flushes / steering / suspend / fail. + update_via_queue=self._provider_update_locked, + ) + ) + + # Start execution + terminate_event = asyncio.Event() + terminate_reason_ref: list[str | None] = [None] + execution_task = asyncio.create_task( + self._execute_task( + fn=fn, + ctx=ctx, + task_id=task_id, + opts=opts, + result_future=result_future, + renewal_cancel=renewal_cancel, + retry=retry, + terminate_event=terminate_event, + terminate_reason_ref=terminate_reason_ref, + ) + ) + + # Track active task + active = _ActiveTask( + task_id=task_id, + fn_name=fn_name, + context=ctx, + execution_task=execution_task, + renewal_task=renewal_task, + renewal_cancel=renewal_cancel, + result_future=result_future, + terminate_event=terminate_event, + fn=fn, + input_type=input_type, + opts=opts, + retry=retry, + ) + self._active_tasks[task_id] = active + + #: metadata is flushed explicitly at + # lifecycle boundaries via ``_flush_all()``. There is no auto- + # flush loop. 
+ + return TaskRun( + task_id=task_id, + provider=self._provider, + result_future=result_future, + metadata=metadata, + cancel_event=cancel_event, + terminate_event=terminate_event, + execution_task=execution_task, + terminate_reason_ref=terminate_reason_ref, + input_id=ctx.input_id, + ) + + #: TaskManager.handle_resume + _resume_route are removed. + # Resume happens via .start()/.run() against a suspended task; the lifecycle + # state machine in _lifecycle_start_inner handles the resume transition. + + async def get_active_run(self, task_id: str) -> TaskRun[Any] | None: # pylint: disable=too-many-return-statements + """Return a TaskRun handle for an active (in-progress) task. + + : consults the store, not only + in-memory state. If the record is in-progress with a dead + lease (per :func:`_lease_is_dead`), performs inline reclaim as + a hidden side effect and returns a usable :class:`TaskRun` + bound to the new lifetime. Terminal records return ``None``. + Eviction also returns ``None`` — same shape as + "not active in this process" per Invariant 1. + + :param task_id: The task identifier. + :type task_id: str + :return: A TaskRun bound to the active task's stream handler, + or ``None`` if not active / terminal / evicted. + :rtype: TaskRun[Any] | None + """ + # Fast path: locally-tracked active execution. + active = self._active_tasks.get(task_id) + if active is not None: + return TaskRun( + task_id=task_id, + provider=self._provider, + result_future=active.result_future, + metadata=active.context.metadata, + cancel_event=active.context.cancel, + terminate_event=active.terminate_event, + execution_task=active.execution_task, + input_id=getattr(active.context, "input_id", None), + ) + + #: consult the store for tasks not active in + # this process. Reads are not rejected for orphan sandboxes + # per the spec's assumptions. 
+ try: + task_info = await self._provider_get_tracked(task_id) + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_id) + if translated is None or getattr(translated, "current_status", None) == "in_progress": + return None + raise translated from exc + except TransportClassifiedError as exc: + if _is_evicted(exc): + # Even reads classified as evicted (unexpected per + # assumption but defensive) map to "not active". + return None + raise + if task_info is None or task_info.status in ( + "completed", + "suspended", + "pending", + ): + return None + # Status is in_progress. Check whether the lease is dead per + # . If so, perform inline reclaim and re-enter as + # recovered. If reclaim fails (race lost / evicted), return None + # per Invariant 1. + if task_info.status == "in_progress" and _lease_is_dead( + task_info, + this_lease_owner=self._lease_owner, + active_locally=False, + ): + fn = self._find_resume_callback(task_info) + if fn is None: + return None + fn_name = (task_info.source or {}).get("name", task_info.agent_name) + opts = self._resume_opts.get(fn_name) + try: + await self._reclaim_one(task_info) + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_id) + if translated is None or getattr(translated, "current_status", None) == "in_progress": + logger.warning( + "get_active_run: reclaim of %s lost a provider race; " + "returning None (same shape as 'not active here')", + task_id, + ) + return None + raise translated from exc + except TransportClassifiedError as exc: + if _is_evicted(exc): + logger.warning( + "get_active_run: reclaim of %s rejected with eviction; " + "returning None (same shape as 'not active here')", + task_id, + ) + return None + raise + await self._start_existing_task( + fn=fn, + fn_name=task_info.agent_name, + task_info=task_info, + entry_mode="recovered", + opts=opts, + ) + # Re-check the active-tasks table now that reclaim is done. 
+ active = self._active_tasks.get(task_id) + if active is not None: + return TaskRun( + task_id=task_id, + provider=self._provider, + result_future=active.result_future, + metadata=active.context.metadata, + cancel_event=active.context.cancel, + terminate_event=active.terminate_event, + execution_task=active.execution_task, + ) + return None + + async def _reclaim_one(self, task_info: TaskInfo) -> "TaskInfo | None": + """: CAS-protected lease reclaim helper. + + Updates the lease ownership to this process's owner+instance + with ``If-Match: `` so two concurrent reclaims produce + exactly one winner. The LocalFileTaskProvider enforces + ``if_match`` strictly (matching the hosted task API), so the CAS + is deterministic against both providers. + + Routes through :meth:`_provider_update_locked`, which refreshes + the tracked etag from the post-reclaim record. Returns that + record so callers can pick up the post-reclaim lease + generation/instance/etag — critical for the recovery path, where + the lease-renewal heartbeat would otherwise keep sending the + stale pre-reclaim etag and 412 on its first tick. + + :param task_info: The task to reclaim. + :type task_info: TaskInfo + :return: The post-reclaim task record, or None if the provider + returned no record. + :rtype: TaskInfo | None + :raises TransportClassifiedError: With classification='evicted' + on orphan-sandbox rejection; with other classifications on + transient / conflict / permanent outcomes. 
+ """ + etag = getattr(task_info, "etag", None) or None + return await self._provider_update_locked( + task_info.id, + TaskPatchRequest( + lease_owner=self._lease_owner, + lease_instance_id=self._instance_id, + lease_duration_seconds=_DEFAULT_LEASE_SECONDS, + if_match=etag, + ), + ) + + async def _start_existing_task( # pylint: disable=too-many-locals,too-many-statements + self, + *, + fn: Callable[..., Awaitable[Any]], + fn_name: str, + task_info: TaskInfo, + entry_mode: EntryMode, + input_val: Any | None = None, + input_type: type[Any] | None = None, + opts: TaskOptions | None = None, + retry: RetryPolicy | None = None, + ) -> TaskRun[Any]: + """Transition an existing task to in_progress and execute it. + + Used by lifecycle-aware ``.run()``/``.start()`` for suspended, + pending, and stale in_progress tasks. + + :keyword fn: The resilient task function. + :paramtype fn: Callable[..., Awaitable[Any]] + :keyword fn_name: Function name for logging. + :paramtype fn_name: str + :keyword task_info: The current task record. + :paramtype task_info: TaskInfo + :keyword entry_mode: Why this execution is starting. + :paramtype entry_mode: EntryMode + :keyword input_val: New input (overrides persisted input). + :paramtype input_val: Any | None + :keyword input_type: Type for deserializing persisted input. + :paramtype input_type: type[Any] | None + :keyword opts: Task options (uses defaults if not provided). + :paramtype opts: TaskOptions | None + :keyword retry: Retry policy. + :paramtype retry: RetryPolicy | None + :return: A TaskRun handle. + :rtype: TaskRun[Any] + """ + task_id = task_info.id + resolved_opts = opts or TaskOptions(name=fn_name, ephemeral=False) + lease_duration = _DEFAULT_LEASE_SECONDS + + #: write a new turn-start timestamp for + # every NEW turn boundary — fresh entry from suspended/pending + # and developer-initiated resume. 
EXCEPTION: do NOT re-stamp + # on recovery (entry_mode == "recovered") so the watchdog's + # remaining-budget computation honors the original turn-start. + turn_start_payload: dict[str, Any] = {} + if entry_mode != "recovered": + turn_start_payload[_TURN_STARTED_AT_KEY] = _utc_now_iso() + + # / SOT §11/§20: the framework does not write + # payload["output"] at any point. No clear is needed on resume. + # Decide whether this PATCH is actually necessary, and whether + # the status field belongs in it. + # + # On the recovery path the immediately-prior ``_reclaim_one`` + # call already wrote the new lease against the stale + # in_progress task, AND we explicitly do NOT re-stamp + # ``_turn_started_at`` on recovery (exception above) + # AND the existing task status is already ``in_progress``. + # In that case the PATCH would re-write the same status + + # same lease + an empty payload — a full network round-trip + # against the same record, with no observable change. Skip + # the call (and the follow-up re-fetch) entirely. + # + # For other entries (suspended/pending/queued -> in_progress) + # the PATCH is required for the status flip and/or turn-start + # write. The ``status`` field is only sent when the current + # status differs from in_progress, so we never re-write the + # same status onto a record that already carries it. + needs_status_flip = task_info.status != "in_progress" + needs_turn_start_write = bool(turn_start_payload) + if not needs_status_flip and not needs_turn_start_write: + # No-op PATCH would be sent — skip it. The reclaim has + # already established our lease; nothing else to write. + # The in-memory ``task_info`` already reflects the + # post-reclaim state we observed when ``_reclaim_one`` + # returned, so the re-fetch is also unnecessary. + updated_info: TaskInfo | None = task_info + else: + # PATCH returns the full updated TaskInfo -- no follow-up + # GET needed. (Saves one network round-trip per call.) 
+ updated_info = await self._provider_update_locked( + task_id, + TaskPatchRequest( + status="in_progress" if needs_status_flip else None, + lease_owner=self._lease_owner, + lease_instance_id=self._instance_id, + lease_duration_seconds=lease_duration, + payload=turn_start_payload if turn_start_payload else None, + ), + ) + if updated_info is None: + raise TaskNotFound(task_id) + task_info = updated_info # type: ignore[assignment] + + # Resolve input. + # SOT §16 (recovery contract): on entry_mode == "recovered", the + # original turn's persisted input is the source of truth. Any new + # caller-provided input is irrelevant to the recovered handler — + # the developer started the same turn via the same task_id; we are + # picking up where the previous lifetime left off. For all other + # entry modes (fresh / resumed / queued), prefer the caller's + # input and fall back to persisted. + # + # ``payload["input"]`` may be a raw inline value OR a ref slot + # pointing into ``task_info.attachments``. Route the read through + # ``_read_input_value`` to handle both shapes uniformly. + use_persisted = entry_mode == "recovered" or input_val is None + if not use_persisted: + resolved_input = input_val + elif task_info.payload and "input" in task_info.payload: + raw_input = _read_input_value(task_info.payload["input"], task_info.attachments) + if input_type is not None: + resolved_input = _deserialize_input(raw_input, input_type) + else: + resolved_input = raw_input + else: + resolved_input = None + + # Build context for execution + cancel_event = asyncio.Event() + #: restore ALL namespaces, not just default. + # ``from_payload`` decodes ``payload["metadata"]`` into the default + # namespace and every ``payload["metadata:"]`` into its named + # sibling, all sharing the same flush_callback so the framework can + # _flush_all() at lifecycle boundaries. 
+ metadata = TaskMetadata.from_payload( + task_info.payload, + flush_callback=self._make_metadata_flush(task_id), + ) + + lease_gen = task_info.lease.generation if task_info.lease else 0 + + # Extract steering context from payload + steering = (task_info.payload or {}).get("_steering", {}) + #: is_steered_turn is True if and only if + # THIS invocation was constructed by the steering-drain code + # path. For initial entry from a recovered drain (the + # crash-mid-drain case), drain_in_progress signals that the + # previous lifetime was mid-drain, so this entry IS the + # continuation of a steered turn. Sticky-True is avoided + # because pending_inputs / generation > 0 alone do NOT imply + # this entry was constructed by the drain. + is_steered_turn = bool(steering.get("drain_in_progress")) + + # For steerable recovery with drain_in_progress, use active_input + if entry_mode == "recovered" and steering.get("drain_in_progress") and "active_input" in steering: + raw_active = steering["active_input"] + if input_type is not None: + resolved_input = _deserialize_input(raw_active, input_type) + else: + resolved_input = raw_active + + # Pre-set cancel if cancel_requested is True (steering short-circuit) + if steering.get("cancel_requested"): + cancel_event.set() + + #: restore the persisted retry_attempt so the + # recovered (or developer-resumed) handler observes the correct + # cross-lifetime budget on its first invocation. ``_retry_attempt`` is + # written by ``_execute_task_loop`` on every handler-raised exception + # and cleared by the steering-drain path; default 0 covers fresh and + # never-failed tasks. 
+ persisted_retry_attempt = (task_info.payload or {}).get("_retry_attempt") or 0 + + ctx: TaskContext[Any] = TaskContext( + task_id=task_id, + session_id=task_info.session_id, + input=resolved_input, + metadata=metadata, + retry_attempt=persisted_retry_attempt, + recovery_count=lease_gen, + cancel=cancel_event, + shutdown=self._shutdown_event, + entry_mode=entry_mode, + is_steered_turn=is_steered_turn, + pending_count_provider=self._make_pending_count_provider(task_id), + input_id=(task_info.payload or {}).get("_last_input_id"), + ) + + loop = asyncio.get_event_loop() + result_future: asyncio.Future[Any] = loop.create_future() + + renewal_cancel = asyncio.Event() + + # Build steering poll callback for steerable tasks + steering_poll_cb: Callable[[], Awaitable[None]] | None = None + if resolved_opts.steerable: + + async def _steering_poll() -> None: + """Poll provider for new steering inputs and signal cancel.""" + active = self._active_tasks.get(task_id) + if active is None or active.context.cancel.is_set(): + return + info = await self._provider_get_tracked(task_id) + if info is None or not info.payload: + return + st = info.payload.get("_steering", {}) + pending = st.get("pending_inputs") or [] + if pending: + # Spec 031 / FR-002 + SOT §13: record the cross-process + # observed count BEFORE setting cancel. + active._pending_input_count = len(pending) + active.context.cancel.set() + + steering_poll_cb = _steering_poll + + renewal_task = asyncio.create_task( + lease_renewal_loop( + self._provider, + task_id, + lease_owner=self._lease_owner, + lease_instance_id=self._instance_id, + lease_duration_seconds=lease_duration, + cancel_event=renewal_cancel, + on_cancel_callback=cancel_event, + steering_poll_callback=steering_poll_cb, + last_refresh_provider=lambda tid=task_id: ( + self._active_tasks[tid].lease_last_refresh_monotonic if tid in self._active_tasks else 0.0 + ), + # — route through the per-task write queue. 
+ update_via_queue=self._provider_update_locked, + ) + ) + + terminate_event = asyncio.Event() + terminate_reason_ref: list[str | None] = [None] + execution_task = asyncio.create_task( + self._execute_task( + fn=fn, + ctx=ctx, + task_id=task_id, + opts=resolved_opts, + result_future=result_future, + renewal_cancel=renewal_cancel, + retry=retry, + terminate_event=terminate_event, + terminate_reason_ref=terminate_reason_ref, + ) + ) + + active = _ActiveTask( + task_id=task_id, + fn_name=fn_name, + context=ctx, + execution_task=execution_task, + renewal_task=renewal_task, + renewal_cancel=renewal_cancel, + result_future=result_future, + terminate_event=terminate_event, + fn=fn, + input_type=input_type, + opts=resolved_opts, + retry=retry, + ) + self._active_tasks[task_id] = active + + return TaskRun( + task_id=task_id, + provider=self._provider, + result_future=result_future, + metadata=metadata, + cancel_event=cancel_event, + terminate_event=terminate_event, + execution_task=execution_task, + terminate_reason_ref=terminate_reason_ref, + lease_expiry_count=task_info.lease.expiry_count if task_info.lease else 0, + ) + + async def _timeout_watchdog( + self, + timeout_seconds: float, + cancel_event: asyncio.Event, + ctx: "TaskContext[Any] | None" = None, + *, + remaining_seconds: float | None = None, + ) -> None: + """/: per-turn timeout watchdog. + + Cooperative-only. On firing, sets ``ctx.timeout_exceeded = True`` + then sets ``cancel_event`` and exits. Does NOT cancel the lease + renewal or force-stop the handler. An ignoring handler runs + until process death or external :meth:`TaskRun.cancel`. + + :param timeout_seconds: Total per-turn timeout budget (used as + the clock-skew clamp ceiling). + :type timeout_seconds: float + :param cancel_event: Event to set for cooperative cancel. + :type cancel_event: asyncio.Event + :param ctx: TaskContext to set ``timeout_exceeded`` on BEFORE + ``cancel_event`` (ordering invariant). 
+ :type ctx: TaskContext[Any] | None + :keyword remaining_seconds: Optional override for "time left in + this turn" — used on recovery to honor the persisted + turn-start timestamp. Clamped to + ``[0, timeout_seconds]`` for clock-skew safety. + When ``None``, the watchdog uses ``timeout_seconds`` directly + (fresh-entry / drain-re-entry case). + :paramtype remaining_seconds: float | None + """ + if remaining_seconds is None: + sleep_for = timeout_seconds + else: + #: clamp to [0, timeout_seconds] in both directions. + sleep_for = max(0.0, min(remaining_seconds, timeout_seconds)) + + #: if remaining == 0 (recovered watchdog with budget + # already exceeded), fire IMMEDIATELY so the recovered handler + # sees the cause from its first checkpoint. + if sleep_for > 0: + await asyncio.sleep(sleep_for) + # ordering: cause boolean FIRST, then cancel. + if ctx is not None: + ctx.timeout_exceeded = True + cancel_event.set() + logger.info( + "Timeout watchdog fired cooperative cancel (slept %.3fs of " + "%.3fs budget; cooperative-only — handler must check " + "ctx.cancel.is_set() and ctx.timeout_exceeded to wind down)", + sleep_for, + timeout_seconds, + ) + + async def _execute_task( + self, + *, + fn: Callable[..., Awaitable[Any]], + ctx: TaskContext[Any], + task_id: str, + opts: TaskOptions, + result_future: asyncio.Future[Any], + renewal_cancel: asyncio.Event, + retry: RetryPolicy | None = None, + terminate_event: asyncio.Event | None = None, + terminate_reason_ref: list[str | None] | None = None, + ) -> None: + """Run the task function and handle completion/failure/suspend. + + When a ``RetryPolicy`` is provided, failed attempts are retried + with the configured delay and backoff. Suspend and cancellation + always exit immediately — they are not retried. + + :keyword fn: The async task function. + :paramtype fn: Callable[..., Awaitable[Any]] + :keyword ctx: The task context. + :paramtype ctx: TaskContext[Any] + :keyword task_id: The task identifier. 
+ :paramtype task_id: str + :keyword opts: The task options. + :paramtype opts: TaskOptions + :keyword result_future: Future to resolve with the result. + :paramtype result_future: asyncio.Future[Any] + :keyword renewal_cancel: Event to cancel lease renewal. + :paramtype renewal_cancel: asyncio.Event + :keyword retry: Optional retry policy. + :paramtype retry: RetryPolicy | None + :keyword terminate_event: Optional terminate event. + :paramtype terminate_event: asyncio.Event | None + :keyword terminate_reason_ref: Mutable ref for terminate reason. + :paramtype terminate_reason_ref: list[str | None] | None + """ + resolved_terminate = terminate_event or asyncio.Event() + + # SOT §52 — per-turn timeout watchdog with resilient budget. The + # watchdog is spawned per turn (initial + every steering drain + # re-entry) so the queued turn gets a fresh full budget, not + # whatever was left over from the prior turn. + await self._spawn_watchdog_for_turn(task_id=task_id, opts=opts, ctx=ctx) + + attempt = 0 # pylint: disable=unused-variable + try: + await self._execute_task_loop( + fn=fn, + ctx=ctx, + task_id=task_id, + opts=opts, + result_future=result_future, + renewal_cancel=renewal_cancel, + retry=retry, + terminate_event=resolved_terminate, + terminate_reason_ref=terminate_reason_ref, + ) + finally: + await self._cancel_watchdog_for_turn(task_id) + + async def _spawn_watchdog_for_turn( + self, + *, + task_id: str, + opts: TaskOptions, + ctx: "TaskContext[Any]", + ) -> None: + """Spawn a per-turn timeout watchdog and register it. + + Cancels and replaces any existing watchdog for this task so the + steering-drain re-entry path can re-arm with a fresh budget. + No-op when ``opts.timeout`` is ``None``. 
+ """ + await self._cancel_watchdog_for_turn(task_id) + if opts.timeout is None: + return + timeout_seconds = opts.timeout.total_seconds() + remaining = await self._compute_remaining_for_watchdog(task_id, timeout_seconds, ctx) + self._timeout_watchdogs[task_id] = asyncio.create_task( + self._timeout_watchdog( + timeout_seconds=timeout_seconds, + cancel_event=ctx.cancel, + ctx=ctx, + remaining_seconds=remaining, + ) + ) + + async def _cancel_watchdog_for_turn(self, task_id: str) -> None: + """Cancel and drop the registered per-turn watchdog (if any).""" + watchdog_task = self._timeout_watchdogs.pop(task_id, None) + if watchdog_task is not None and not watchdog_task.done(): + watchdog_task.cancel() + try: + await watchdog_task + except asyncio.CancelledError: + pass + + async def _compute_remaining_for_watchdog( + self, + task_id: str, + timeout_seconds: float, + ctx: "TaskContext[Any]", + ) -> float: + """Compute the remaining per-turn budget. + + Reads the persisted ``_turn_started_at`` for ``task_id`` and + returns ``max(0, timeout_seconds - (now - turn_started_at))`` + clamped to ``[0, timeout_seconds]``. If the timestamp is + missing or unparseable (e.g., an older record during + rollout), returns ``timeout_seconds`` so the watchdog spawns + with a fresh budget (graceful degradation). + + Immediate fire on recovery: if remaining == 0, also + pre-sets ``ctx.timeout_exceeded = True`` and ``ctx.cancel`` so + the recovered handler sees the cause from its first checkpoint. + + :param task_id: The task identifier. + :type task_id: str + :param timeout_seconds: The per-turn budget configured on the + decorator (also the clock-skew clamp ceiling). + :type timeout_seconds: float + :param ctx: TaskContext used to surface the recovered cause when + the remaining budget is zero. + :type ctx: TaskContext[Any] + :return: Remaining seconds clamped to ``[0, timeout_seconds]``. 
+ :rtype: float + """ + try: + task_info = await self._provider_get_tracked(task_id) + except Exception: # pylint: disable=broad-exception-caught + return timeout_seconds + if task_info is None or not task_info.payload: + return timeout_seconds + started_ts = _parse_turn_started_at(task_info.payload.get(_TURN_STARTED_AT_KEY)) + if started_ts is None: + return timeout_seconds + import time # pylint: disable=import-outside-toplevel + + elapsed = time.time() - started_ts + # clock-skew clamping: clamp to [0, timeout_seconds] in + # both directions (backward skew → elapsed negative → remaining + # > timeout; forward skew → elapsed huge → remaining < 0). + remaining = max(0.0, min(timeout_seconds - elapsed, timeout_seconds)) + + # immediate-fire: if recovered watchdog computes + # remaining == 0, pre-set the cause boolean + cancel before + # the handler even runs its first checkpoint. + if remaining == 0.0: + ctx.timeout_exceeded = True + ctx.cancel.set() + return remaining + + async def _execute_task_loop( # pylint: disable=too-many-statements,too-many-branches,too-many-nested-blocks,unused-argument + self, + *, + fn: Callable[..., Awaitable[Any]], + ctx: TaskContext[Any], + task_id: str, + opts: TaskOptions, + result_future: asyncio.Future[Any], + renewal_cancel: asyncio.Event, + retry: RetryPolicy | None = None, + terminate_event: asyncio.Event | None = None, + terminate_reason_ref: list[str | None] | None = None, + ) -> None: + """Inner execution loop — separated from watchdog management. + + :keyword fn: The async task function. + :paramtype fn: Callable[..., Awaitable[Any]] + :keyword ctx: The task context. + :paramtype ctx: TaskContext[Any] + :keyword task_id: The task identifier. + :paramtype task_id: str + :keyword opts: The task options. + :paramtype opts: TaskOptions + :keyword result_future: Future to resolve with the result. + :paramtype result_future: asyncio.Future[Any] + :keyword renewal_cancel: Event to cancel lease renewal. 
+ :paramtype renewal_cancel: asyncio.Event + :keyword retry: Optional retry policy. + :paramtype retry: RetryPolicy | None + :keyword terminate_event: Optional terminate event (currently unused). + :paramtype terminate_event: asyncio.Event | None + :keyword terminate_reason_ref: Mutable ref for terminate reason + (currently unused). + :paramtype terminate_reason_ref: list[str | None] | None + """ + # Honor the persisted retry_attempt so the + # cross-lifetime budget is respected. ``_start_existing_task`` and + # ``create_and_start`` populate ``ctx.retry_attempt`` from + # ``payload["_retry_attempt"]`` (default 0 for fresh tasks). + attempt = ctx.retry_attempt + # Mutable ref: steering drain may swap the active result_future + current_result_future = result_future + while True: + ctx.retry_attempt = attempt + try: + result = await fn(ctx) + + # The handler returned the + # _ExitForRecovery sentinel via ``ctx.exit_for_recovery()``. + # Flush metadata, release the lease, leave the stored + # status as 'in_progress' (do NOT write terminal), + # preserve queued steering inputs, and signal + # awaiters with TaskDeferred. + from ._context import ( + _ExitForRecovery as _ExitSentinel, + ) # pylint: disable=import-outside-toplevel + + if isinstance(result, _ExitSentinel): + # `ctx.exit_for_recovery` surfaces `TaskDeferred` + # (NOT `TaskCancelled`) to awaiters. The task + # stays `in_progress`; the recovery scanner re-invokes + # the handler in the next process lifetime. + from ._exceptions import ( # pylint: disable=import-outside-toplevel + TaskDeferred, + ) + + renewal_cancel.set() + # (a) Flush metadata (auto-flush). + await ctx.metadata._flush_all() + # (b) Release the lease (lease_duration_seconds=0) so the + # next process reclaims immediately. SOT §22: + # force-expire on exit_for_recovery. The renewal loop above + # was just cancelled but may have raced a PATCH; on + # ETag conflict re-read and retry up to 5 times so + # the release actually lands. 
+ _release_attempts = 0 + while True: + _release_attempts += 1 + try: + await self._provider_update_locked( + task_id, + TaskPatchRequest( + lease_owner=self._lease_owner, + lease_instance_id=self._instance_id, + lease_duration_seconds=0, + ), + ) + break + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_id) + # Eviction-shape conflicts: someone else already owns it + # (binding_mismatch / not_owner / etc.) → nothing to release. + if translated is not None: + logger.info( + "exit_for_recovery: lease for task %s already " + "owned by another instance (%s); no release needed", + task_id, + type(translated).__name__, + ) + break + # Pure ETag race vs our own renewer — re-read and retry. + if _release_attempts >= 5: + logger.warning( + "exit_for_recovery: lease release for task %s " + "still conflicting after %d attempts; the next " + "process startup recovery will reclaim", + task_id, + _release_attempts, + exc_info=True, + ) + break + try: + refreshed = await self._provider_get_tracked(task_id) + if refreshed is not None: + self._track_etag(task_id, getattr(refreshed, "etag", None)) + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "exit_for_recovery: failed to refresh etag " "for retry on task %s", + task_id, + exc_info=True, + ) + break + continue + except TransportClassifiedError as exc: + if not _is_evicted(exc): + logger.warning( + "exit_for_recovery: lease release for task " + "%s failed with classification=%s; the next " + "process startup recovery will reclaim", + task_id, + getattr(exc, "classification", None), + ) + break + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "exit_for_recovery: lease release for task %s " + "failed; the next process startup recovery will " + "reclaim", + task_id, + exc_info=True, + ) + break + # (c) Do NOT write a terminal record — status MUST + # remain 'in_progress' so the recovery scan picks + # it up next process start. 
+ # (d) Signal awaiters with TaskDeferred per + # / (NOT TaskCancelled — the task + # is deferring to next lifetime, not terminating). + if not current_result_future.done(): + current_result_future.set_exception(TaskDeferred()) + # (e) Queued steerers: preserved in + # persisted state — already untouched here, so + # no action needed. + break + + # Handler returned a value (multi-turn implicit suspend, + # one-shot terminal completion). No ``Suspended`` sentinel: + # the framework's ``return X`` is the only end-of-turn + # signal. Success flow. + renewal_cancel.set() + await ctx.metadata._flush_all() + try: + completed = await self._handle_success( + task_id=task_id, + result=result, + metadata=ctx.metadata, + opts=opts, + ) + except TaskConflictError as exc: + if not current_result_future.done(): + current_result_future.set_exception(exc) + _resolve_queued_steerers_on_terminal( + self._pending_steering_futures, + task_id, + current_status=exc.current_status, + ) + break + except OutputTooLarge as exc: + # Surface OutputTooLarge to the caller directly, NOT + # wrapped in TaskFailed. The handler succeeded; the + # framework's persistence step rejected the output as + # too large. Developer-facing precondition violation, + # not a handler bug. + if not current_result_future.done(): + current_result_future.set_exception(exc) + _resolve_queued_steerers_on_terminal( + self._pending_steering_futures, + task_id, + current_status="failed", + ) + break + # Set the current turn's caller's result_future to the + # completion outcome FIRST, then resolve any queued + # steerers with TaskConflictError (since the task has now + # terminated). The handler's return value is delivered + # unchanged to the current caller; the queued steerers + # see the "task is busy / terminal" shape per Invariant 1. 
+ is_multi_turn_success = getattr(opts, "_is_multi_turn", False) + if not current_result_future.done(): + # Both one-shot and multi-turn return the raw Output + # unwrapped; multi-turn keeps the chain alive. + current_result_future.set_result(result) + if not is_multi_turn_success: + # One-shot path: queued steerers get TaskConflictError + # on terminal completion (one-shot is never steerable + # in practice; this is defense-in-depth). + _resolve_queued_steerers_on_terminal( + self._pending_steering_futures, + task_id, + current_status="completed", + ) + else: + # Multi-turn path: try drain promotes queued head as a + # new turn. + try: + new_ctx = await self._try_drain_steering( + task_id=task_id, + ctx=ctx, + opts=opts, + result_future=current_result_future, + ) + if new_ctx is not None: + ctx = new_ctx + attempt = 0 + active = self._active_tasks.get(task_id) + if active and active.result_future is not current_result_future: + current_result_future = active.result_future + # SOT §52 — drain re-entry is a new turn boundary; + # re-arm the per-turn timeout watchdog so the queued + # turn gets its own full budget. + await self._spawn_watchdog_for_turn(task_id=task_id, opts=opts, ctx=ctx) + continue + except Exception: # noqa: BLE001 + logger.warning( + "Failed to drain steering queue after multi-turn success for task %s", + task_id, + exc_info=True, + ) + if not completed: + # Etag conflict on steerable completion — but the + # caller's future is now resolved with the completion + # outcome, so we don't re-drain; the next .start() + # will pick up any queued state. + pass + + break # exit retry loop on success or suspend + + except asyncio.CancelledError: + renewal_cancel.set() + await ctx.metadata._flush_all() + # asyncio.CancelledError is the cooperative-cancel path — + # the handler chose to raise it (or the framework signalled + # cancel via ctx.cancel and the handler did not catch). + # Resolve the caller's future with TaskCancelled. 
+ from ._exceptions import ( # pylint: disable=import-outside-toplevel + TaskCancelled, + ) + + if not current_result_future.done(): + current_result_future.set_exception(TaskCancelled()) + + is_multi_turn_cancel = getattr(opts, "_is_multi_turn", False) + if opts.ephemeral: + # One-shot is always ephemeral: delete the persisted + # record so the recovery scanner doesn't re-invoke a + # cancelled handler. + try: + await self._provider.delete(task_id, force=True) + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to delete cancelled ephemeral task %s", + task_id, + exc_info=True, + ) + elif is_multi_turn_cancel: + # Multi-turn chain: transition the chain to ``suspended`` + # with the cancel reflected as a terminal-of-turn write + # (input + _retry_attempt + _steering.active_input + any + # promoted _input attachment cleared atomically — see + # SOT §23.8 item #3). The chain stays alive for the next + # turn's ``.start`` / ``.run``. Errors from this write + # surface only via the logger because the caller's + # future is already resolved with TaskCancelled above. + error_dict = { + "type": "cancelled", + "message": "Task cancelled", + } + try: + await self._handle_multi_turn_failure( + task_id=task_id, + exc=TaskCancelled(), + metadata=ctx.metadata, + opts=opts, + error_dict=error_dict, + ) + except TaskConflictError: + # 412 RE-READ decided ABANDON; lease is no longer + # ours, another instance / process owns the record. + # Nothing to clean up here — the caller already saw + # TaskCancelled. + pass + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to transition multi-turn chain %s " + "to suspended after cancel; chain may need " + "recovery scan to pick up the in_progress record", + task_id, + exc_info=True, + ) + # Promote queued steerers (if any) per the same drain + # rule as raise — chain stays alive, queued head takes + # over the next turn. 
+ if opts.steerable: + await asyncio.sleep(0) + try: + new_ctx = await self._try_drain_steering( + task_id=task_id, + ctx=ctx, + opts=opts, + result_future=current_result_future, + ) + if new_ctx is not None: + ctx = new_ctx + attempt = 0 + active = self._active_tasks.get(task_id) + if active and active.result_future is not current_result_future: + current_result_future = active.result_future + # SOT §52 — drain re-entry is a new turn boundary; + # re-arm the per-turn timeout watchdog so the queued + # turn gets its own full budget. + await self._spawn_watchdog_for_turn(task_id=task_id, opts=opts, ctx=ctx) + continue + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to drain steering queue after " "multi-turn cancel for task %s", + task_id, + exc_info=True, + ) + break # cancellation is never retried + + except Exception as exc: # pylint: disable=broad-exception-caught + if retry and retry.should_retry(attempt, exc): + delay = retry.compute_delay(attempt) + logger.warning( + "Task %s attempt %d failed (%s: %s), retrying in %.1fs", + task_id, + attempt, + type(exc).__name__, + exc, + delay, + ) + # /: persist the post-bump + # retry_attempt alongside the error field in a single + # patch. A subsequent crash + recover will restore this + # counter via ``_start_existing_task`` so the resilient + # max_attempts budget is honored across lifetimes. + try: + #: NO interim error PATCH between retries. + # Only the _retry_attempt counter is persisted across retries. 
+ await self._provider_update_locked( + task_id, + TaskPatchRequest( + payload={"_retry_attempt": attempt + 1}, + ), + ) + except Exception: # pylint: disable=broad-exception-caught + logger.debug("Failed to update _retry_attempt counter", exc_info=True) + await asyncio.sleep(delay) + attempt += 1 + continue + + # Exhausted or non-retryable — terminal failure + renewal_cancel.set() + await ctx.metadata._flush_all() + + if retry and attempt > 0: + # Retries were attempted but exhausted + error_dict: dict[str, Any] = { + "type": "exhausted_retries", + "attempts": attempt + 1, + "last_error": str(exc), + "last_error_type": type(exc).__name__, + "traceback": traceback.format_exc(), + } + else: + error_dict = { + "type": type(exc).__name__, + "message": str(exc), + "traceback": traceback.format_exc(), + } + + await self._handle_failure( + task_id=task_id, + exc=exc, + metadata=ctx.metadata, + opts=opts, + ) + # / step 5 — caller's future resolution: + # CancelledError → bare TaskCancelled() else TaskFailed. + is_multi_turn_failure = getattr(opts, "_is_multi_turn", False) + if not current_result_future.done(): + if isinstance(exc, asyncio.CancelledError): + # — bare TaskCancelled (no fields). + current_result_future.set_exception(TaskCancelled()) + else: + current_result_future.set_exception(TaskFailed(task_id, error_dict)) + # — discard callback so "Future exception + # was never retrieved" doesn't fire when no caller awaits + # (multi-turn: caller may have already moved on / GC'd). + if is_multi_turn_failure: + + def _discard(fut: asyncio.Future[Any]) -> None: + try: + fut.exception() # retrieve to silence asyncio + except Exception: # noqa: BLE001 + pass + + current_result_future.add_done_callback(_discard) + # 7-step ordering: step 5 (resolve current's + # future) MUST be observable BEFORE step 6 (promote queued + # head). Yield so any awaiter of current_result_future is + # scheduled before the next handler dispatches. 
+ await asyncio.sleep(0) + # (Subscriber) — legacy one-shot path: queued steerers + # see TaskConflictError on terminal failure since the task is done. + # — multi-turn path: queued steerers PROMOTE + # (chain stays alive); do NOT reject them here. + if not is_multi_turn_failure: + _resolve_queued_steerers_on_terminal( + self._pending_steering_futures, + task_id, + current_status="failed", + ) + else: + # Multi-turn: chain stays in suspended; try drain steering + # queue. Promoted turn dispatches with + # ctx.entry_mode="resumed" per the existing _try_drain_steering + # mechanics. If no queued steerers, chain remains suspended. + try: + new_ctx = await self._try_drain_steering( + task_id=task_id, + ctx=ctx, + opts=opts, + result_future=current_result_future, + ) + if new_ctx is not None: + # Queued head promoted; new turn dispatching. + # _execute_task continues into next attempt with new ctx. + ctx = new_ctx + attempt = 0 + # Refresh current_result_future from rotated + # active.result_future /. + active = self._active_tasks.get(task_id) + if active and active.result_future is not current_result_future: + current_result_future = active.result_future + # SOT §52 — drain re-entry is a new turn boundary; + # re-arm the per-turn timeout watchdog so the queued + # turn gets its own full budget. + await self._spawn_watchdog_for_turn(task_id=task_id, opts=opts, ctx=ctx) + continue + except Exception: # noqa: BLE001 + logger.warning( + "Failed to drain steering queue after multi-turn raise for task %s", + task_id, + exc_info=True, + ) + break + + self._active_tasks_pop(task_id) + + async def _try_drain_steering( # pylint: disable=too-many-branches,too-many-statements,too-many-locals + self, + *, + task_id: str, + ctx: TaskContext[Any], + opts: TaskOptions, + result_future: asyncio.Future[Any], + partial_output: Any | None = None, + _conflict_attempt: int = 0, + ) -> TaskContext[Any] | None: + """Check for pending steering inputs and drain the next one. 
+ + Called BEFORE persisting suspend/complete to avoid lease/status conflicts. + Returns a new ``TaskContext`` if a drain occurred, or ``None`` if no + pending inputs exist. + + :keyword task_id: The task identifier. + :keyword ctx: Current task context. + :keyword opts: Task options. + :keyword result_future: The current generation's result future. + :keyword partial_output: Output from the previously-running generation, + delivered in-process via ``TaskResult(output=..., status="superseded")`` + to whoever was awaiting the steered-out turn's result_future + (see ``_manager.py`` line ~1386). NOT persisted — if the + process crashes between completion and delivery, this output is + lost. (scenario 11: the previously-existing + backup write at ``_steering["generation_results"]`` was removed + because no consumer existed.) + :keyword _conflict_attempt: Internal recursion-depth counter + for etag-conflict retries. Bounded so the hosted task + store's etag-comparator pre-fix behaviour cannot loop + forever. + :return: New context for the drained generation, or None. + """ + # Spec 031 / FR-005 + SOT §25.2 — the read-state + compute-PATCH + + # apply cycle MUST be atomic under the per-task write lock so the + # in-process lease heartbeat (and any other in-process writer) cannot + # bump the etag between our read and our write. Previously the read + # was lock-free and the write pinned that lock-free etag, which let + # the heartbeat invalidate the pinned etag and (under contention) + # starve the drain's retry budget. We still pin the freshly-read etag + # (detect-not-clobber) so a genuine cross-process write is detected; + # cross-process conflicts retry OUTSIDE the lock via the recursion + # below (the per-task ``asyncio.Lock`` is non-reentrant). 
+ drain_conflict: BaseException | None = None + async with self._get_task_write_lock(task_id): + task_info = await self._provider_get_tracked(task_id) + if task_info is None: + return None + + payload = dict(task_info.payload) if task_info.payload else {} + steering = dict(payload.get("_steering", {})) + pending = list(steering.get("pending_inputs", [])) + + if not pending: + return None + + # Pop the next input from the queue.: the entry may be + # either a raw inline value (≤ 20 KiB at append) or a ref slot + # pointing into ``task_info.attachments``. Resolve uniformly via + # ``_read_input_value``; if it was a ref, the same drain PATCH + # MUST also delete the attachment (C-9 /). + next_entry = pending.pop(0) + attachments_patch = {} + if _is_ref(next_entry): + attachments_patch[_ref_key(next_entry)] = None + next_input_raw = _read_input_value(next_entry, task_info.attachments) + + # Update steering state. (: previous_input is + # no longer mirrored into _steering; only the active input + queue + # state need to survive a crash mid-drain.) + steering["active_input"] = next_input_raw + steering["pending_inputs"] = pending + # SOT: internal + # _steering["generation"] writes removed. The drain transition + # IS the generation advance — no separate counter needed. + steering["cancel_requested"] = len(pending) > 0 + steering["drain_in_progress"] = True + #: the steering drain re-entry is a NEW + # turn-start boundary — write a fresh _turn_started_at so the + # respawned watchdog computes a full per-turn budget. + payload[_TURN_STARTED_AT_KEY] = _utc_now_iso() + payload["_steering"] = steering + # SOT §11/§20: the framework does not write payload["output"]; + # no clear is needed at the drain transition. + + try: + etag = getattr(task_info, "etag", None) or None + # Spec 031 (hosted re-test finding) — the multi-turn turn that + # just ended already wrote ``status="suspended"`` (see + # ``_handle_multi_turn_success``). 
The drain starts a NEW turn, + # so it MUST transition the record back to ``in_progress`` in + # this same PATCH. This is also REQUIRED for the lease-extension + # piggyback to be valid: the hosted task store rejects lease + # *renewal* on a non-in_progress task ("lease renewal is only + # supported for in_progress tasks"), but ACCEPTS lease params as + # part of a suspended→in_progress *claim*. Without the status + # flip the drain PATCH 409s and the steered turn never runs. + await self._provider_update_lock_held( + task_id, + TaskPatchRequest( + status="in_progress", + payload=payload, + attachments=attachments_patch, + if_match=etag, + **self._lease_ext_kwargs(task_id), + ), + ) + # Spec 031 / FR-002 — the drain consumed the head; the steered + # turn's live backlog is the remaining ``pending``. Keep + # ``ctx.pending_input_count`` in sync for the new turn. + active_now = self._active_tasks.get(task_id) + if active_now is not None: + active_now._pending_input_count = len(pending) + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_id) + if translated is not None: + raise translated from exc + drain_conflict = exc + except (EtagConflict, ValueError, TransportClassifiedError) as exc: + if isinstance(exc, TransportClassifiedError) and getattr(exc, "classification", None) != "conflict": + raise + if isinstance(exc, ValueError) and "etag" not in str(exc).lower(): + raise + drain_conflict = exc + + # Lock released — a genuine (cross-process) conflict retries here, + # re-reading the NEW state under a fresh lock acquisition. Bounded so + # the hosted store's etag comparator cannot loop forever. 
+ if drain_conflict is not None: + if _conflict_attempt >= 5: + raise RuntimeError( + f"Steering drain for {task_id!r} did not converge " "after 5 etag-conflict retries" + ) from drain_conflict + logger.warning( + "Etag conflict during steering drain for %s, retrying " "(attempt %d)", + task_id, + _conflict_attempt + 1, + ) + return await self._try_drain_steering( + task_id=task_id, + ctx=ctx, + opts=opts, + result_future=result_future, + partial_output=partial_output, + _conflict_attempt=_conflict_attempt + 1, + ) + + # Pop and bind the next pending steering future (if any) + new_future: asyncio.Future[Any] | None = None + steering_futures = self._pending_steering_futures.get(task_id, []) + if steering_futures: + new_future = steering_futures.pop(0) + + # Resolve the queued steerer's future binding for the new turn. + # / (Subscriber): the OLD result_future is NOT + # set to "superseded" here — the suspend path (or completion + # path) above has ALREADY set it to the natural multi-turn + # outcome before this drain runs. The drain just rotates the + # active result_future so the next turn's handler invocation + # is bound to the steerer's future (the caller that queued the + # input via .start()) if one was registered. + if new_future is None: + # No registered steerer for this drain — reuse the OLD + # result_future as the new turn's future. This is the rare + # case where the drain was triggered by a poll-based + # backlog rather than a fresh .start() call. The future + # may already be done (from the suspend resolution above); + # if so, leave it. 
+ new_future = result_future + + # Update active generation future + if new_future is not None: + self._active_generation_future[task_id] = new_future + + # Deserialize input + active_task = self._active_tasks.get(task_id) + input_type = active_task.input_type if active_task else None + if input_type is not None: + resolved_input = _deserialize_input(next_input_raw, input_type) + else: + resolved_input = next_input_raw + + # Build new context, reusing metadata and shutdown event + cancel_event = asyncio.Event() + if steering["cancel_requested"]: + cancel_event.set() + + new_ctx: TaskContext[Any] = TaskContext( + task_id=task_id, + session_id=ctx._session_id, # pylint: disable=protected-access + input=resolved_input, + metadata=ctx.metadata, + retry_attempt=0, + recovery_count=ctx.recovery_count, + cancel=cancel_event, + shutdown=ctx.shutdown, + entry_mode="resumed", + is_steered_turn=True, + pending_count_provider=self._make_pending_count_provider(task_id), + input_id=(task_info.payload or {}).get("_last_input_id"), + ) + + # Update active task tracking + if active_task is not None: + active_task.context = new_ctx + if new_future is not None: + active_task.result_future = new_future + + # Clear drain_in_progress + steering["drain_in_progress"] = False + payload["_steering"] = steering + #: a steering input is a new logical request + # from the developer; the retry budget resets. Persist the reset so a + # subsequent crash does not resurrect the prior counter from + # ``payload["_retry_attempt"]``. 
+ payload["_retry_attempt"] = 0 + try: + await self._provider_update_locked( + task_id, + TaskPatchRequest( + payload=payload, + **self._lease_ext_kwargs(task_id), + ), + ) + except Exception: # pylint: disable=broad-exception-caught + logger.debug("Failed to clear drain_in_progress for %s", task_id) + + logger.info( + "Steering drain: task %s drained next input", + task_id, + ) + return new_ctx + + async def _handle_multi_turn_success( + self, + *, + task_id: str, + metadata: TaskMetadata, + opts: TaskOptions, + ) -> bool: + """Multi-turn return handler. + + : + - Multi-turn ``return X`` is implicit suspend. Chain transitions to + ``suspended`` (NOT ``completed``) so it accepts the next input. + - NO ``payload["output"]`` is written. + - ``payload["input"]`` cleared at the transition. + - ``payload["_retry_attempt"]`` cleared too. + - ``payload["_last_input_id"]`` preserved for the + ``if_last_input_id`` precondition. + - ``suspension_reason="run_completion"`` stamped internally. + + Returns True (terminal write succeeded). False is reserved for + the legacy etag-conflict-retry-drain pattern; the multi-turn + path raises TaskConflictError on 412 instead. + """ + # Auto-flush metadata BEFORE the chain PATCH. + try: + await metadata._flush_all() # noqa: SLF001 — framework-internal fence + except Exception: # noqa: BLE001 + logger.warning( + "Failed to auto-flush metadata before multi-turn success PATCH for task %s", + task_id, + exc_info=True, + ) + + # SOT §23.8 item #3 — the turn-end PATCH MUST atomically clear ALL of: + # payload["input"], payload["_steering"]["active_input"], + # payload["_retry_attempt"], and (if input was promoted) the + # attachments["_input"] entry. Splitting any of these into + # multiple PATCHes opens a crash window where the attachment + # exists without its ref (or vice versa). 
+ task_info = await self._provider_get_tracked(task_id) + if task_info is not None: + self._track_etag(task_id, getattr(task_info, "etag", None)) + steering_patch: dict[str, Any] = {} + attachments_patch: dict[str, Any] = {} + if task_info is not None and task_info.payload: + existing_steering = task_info.payload.get("_steering") or {} + if existing_steering: + steering_patch = dict(existing_steering) + steering_patch["active_input"] = None + existing_input_slot = task_info.payload.get("input") + if _is_ref(existing_input_slot): + attachments_patch[_ref_key(existing_input_slot)] = None + + payload_patch: dict[str, Any] = { + "metadata": metadata.to_dict(), + "input": None, + "_retry_attempt": None, + # NO "output", NO "error" + } + if steering_patch: + payload_patch["_steering"] = steering_patch + + try: + await self._terminal_write_locked( + task_id, + TaskPatchRequest( + status="suspended", + suspension_reason="run_completion", + payload=payload_patch, + attachments=attachments_patch or None, + ), + ) + except TaskConflictError: + raise + except _HostedConflict as hosted_exc: + translated = _translate_hosted_conflict(hosted_exc, task_id=task_id) + if translated is None: + if hosted_exc._code == "lease_ownership_changed": + raise TaskConflictError(task_id, "in_progress") from hosted_exc + raise EtagConflict(task_id) from hosted_exc + raise translated from hosted_exc + except TransportClassifiedError as transport_exc: + if _is_evicted(transport_exc): + logger.warning( + "Eviction on multi-turn return PATCH for task %s — " "signalling awaiters with TaskConflictError", + task_id, + ) + raise TaskConflictError(task_id, "in_progress") from transport_exc + raise + return True + + async def _handle_success( + self, + *, + task_id: str, + result: Any, + metadata: TaskMetadata, + opts: TaskOptions, + ) -> bool: + """Handle successful task completion. 
+ + / /: multi-turn handlers (decorated + with @multi_turn_task — TaskOptions._is_multi_turn=True) treat + ``return X`` as the implicit-suspend signal. The framework + transitions the chain to ``suspended`` with + ``suspension_reason="run_completion"``, NO ``payload["output"]`` + is written, and ``payload["input"]`` is cleared. + The caller's ``.result()`` future will be resolved with ``X`` + directly by the caller path (preserving the return value). + + Legacy one-shot (ephemeral) and non-ephemeral-non-multi-turn paths + keep their existing behavior during the transition window. + + :keyword task_id: The task identifier. + :paramtype task_id: str + :keyword result: The task result value. + :paramtype result: Any + :keyword metadata: The task metadata. + :paramtype metadata: TaskMetadata + :keyword opts: The task options. + :paramtype opts: TaskOptions + :return: True if completion succeeded, False if etag conflict + detected (steerable tasks only — caller should re-drain). + :rtype: bool + """ + # — multi-turn success → suspended (NOT completed), + # no payload['output'] written, payload['input'] cleared. + is_multi_turn = getattr(opts, "_is_multi_turn", False) + if is_multi_turn: + return await self._handle_multi_turn_success( + task_id=task_id, + metadata=metadata, + opts=opts, + ) + + # One-shot tasks are always ephemeral — delete on terminal exit. + try: + await self._provider.delete(task_id, force=True) + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Failed to delete ephemeral task %s", task_id, exc_info=True) + + logger.info("Task %s completed successfully", task_id) + return True + + async def _handle_multi_turn_failure( + self, + *, + task_id: str, + exc: Exception, + metadata: TaskMetadata, + opts: TaskOptions, + error_dict: dict[str, Any], + ) -> None: + """Multi-turn raise handler. + + Per 7-step ordering: + 1. (caller) Run the failure handler (this method). + 2. 
Auto-flush ctx.metadata BEFORE the chain-PATCH (load-bearing). + 3. Clear payload["input"] and payload["_retry_attempt"]. + 4. PATCH chain record to ``suspended`` (NOT ``completed``) with + ``suspension_reason="run_completion"``. No ``payload["error"]`` + is written. ``payload["_last_input_id"]`` MUST be + preserved. Steering queue MUST be preserved. + 5. (caller) Resolve current caller's ``.result`` future: + ``CancelledError`` → bare ``TaskCancelled()`` else + ``TaskFailed(error_dict)``. + 6. (caller) If queued steerers exist, promote head. + 7. (caller) Else leave chain in ``suspended`` awaiting future + ``.run()`` / ``.start()``. + + Steps 5/6/7 are handled by the caller (``_execute_task``) after this + method returns; this method owns steps 2/3/4. + """ + # Step 2: auto-flush metadata BEFORE the chain-PATCH. + try: + await metadata._flush_all() # noqa: SLF001 — framework-internal fence + except Exception: # noqa: BLE001 + logger.warning( + "Failed to auto-flush metadata before multi-turn failure PATCH for task %s", + task_id, + exc_info=True, + ) + + # Step 3 + 4: PATCH to suspended (NOT completed); clear input + _retry_attempt + # + _steering.active_input + promoted _input attachment if any (SOT §23.8 + # single-PATCH invariant); NO payload["error"] written; _last_input_id preserved. 
+ task_info = await self._provider_get_tracked(task_id) + if task_info is not None: + self._track_etag(task_id, getattr(task_info, "etag", None)) + steering_patch: dict[str, Any] = {} + attachments_patch: dict[str, Any] = {} + if task_info is not None and task_info.payload: + existing_steering = task_info.payload.get("_steering") or {} + if existing_steering: + steering_patch = dict(existing_steering) + steering_patch["active_input"] = None + existing_input_slot = task_info.payload.get("input") + if _is_ref(existing_input_slot): + attachments_patch[_ref_key(existing_input_slot)] = None + + payload_patch: dict[str, Any] = { + "metadata": metadata.to_dict(), + "input": None, + "_retry_attempt": None, + # NO "output", NO "error" + } + if steering_patch: + payload_patch["_steering"] = steering_patch + + try: + await self._terminal_write_locked( + task_id, + TaskPatchRequest( + status="suspended", + suspension_reason="run_completion", + payload=payload_patch, + attachments=attachments_patch or None, + ), + ) + except TaskConflictError: + # 412 RE-READ decided ABANDON; propagate so the active caller + # receives the eviction-shape exception. 
+ raise + except _HostedConflict as hosted_exc: + translated = _translate_hosted_conflict(hosted_exc, task_id=task_id) + if translated is None: + if hosted_exc._code == "lease_ownership_changed": + raise TaskConflictError(task_id, "in_progress") from hosted_exc + raise EtagConflict(task_id) from hosted_exc + raise translated from hosted_exc + except TransportClassifiedError as transport_exc: + if _is_evicted(transport_exc): + logger.warning( + "Eviction on multi-turn raise PATCH for task %s — " "signalling awaiters with TaskConflictError", + task_id, + ) + raise TaskConflictError(task_id, "in_progress") from transport_exc + raise + except Exception: # noqa: BLE001 + logger.warning( + "Failed to PATCH multi-turn suspended-on-raise for task %s", + task_id, + exc_info=True, + ) + # — structured failure log/telemetry for every handler + # failure, independent of listener presence. Logged at ERROR per + # (the chain has just lost a turn). + active = self._active_tasks.get(task_id) + input_id = None + if active is not None: + input_id = getattr(active.context, "input_id", None) + logger.error( + "resilient_task_handler_failure: task=%s exc_type=%s", + task_id, + type(exc).__name__, + extra={ + "event": "resilient_task_handler_failure", + "event_name": "resilient_task_handler_failure", + "task_id": task_id, + "input_id": input_id, + "error_type": type(exc).__name__, + "error_message": str(exc), + "primitive": "multi_turn_task", + }, + ) + + async def _handle_failure( + self, + *, + task_id: str, + exc: Exception, + metadata: TaskMetadata, + opts: TaskOptions, + ) -> None: + """Handle task failure. + + / / — multi-turn handlers (decorated + with @multi_turn_task — TaskOptions._is_multi_turn=True) transition + to ``suspended`` (chain stays alive) on raise, NOT ``completed``. + Per NO ``payload["error"]`` is written for multi-turn + failures. Per ``payload["input"]`` and + ``payload["_retry_attempt"]`` are cleared. 
+ + Legacy one-shot (ephemeral) and non-ephemeral-non-multi-turn paths + keep their existing behavior during the transition window. + + :keyword task_id: The task identifier. + :paramtype task_id: str + :keyword exc: The exception that caused the failure. + :paramtype exc: Exception + :keyword metadata: The task metadata. + :paramtype metadata: TaskMetadata + :keyword opts: The task options. + :paramtype opts: TaskOptions + """ + error_dict = { + "type": type(exc).__name__, + "message": str(exc), + "traceback": traceback.format_exc(), + } + + # — multi-turn raise → suspended (NOT completed). + # Auto-flush metadata BEFORE the chain-PATCH (step 2 of). + is_multi_turn = getattr(opts, "_is_multi_turn", False) + if is_multi_turn: + await self._handle_multi_turn_failure( + task_id=task_id, + exc=exc, + metadata=metadata, + opts=opts, + error_dict=error_dict, + ) + return + + # One-shot tasks are always ephemeral — delete on terminal failure. + try: + await self._provider.delete(task_id, force=True) + except _HostedConflict as hosted_exc: + translated = _translate_hosted_conflict(hosted_exc, task_id=task_id) + if translated is None: + raise TaskConflictError(task_id, "in_progress") from hosted_exc + raise translated from hosted_exc + except TransportClassifiedError as transport_exc: + if _is_evicted(transport_exc): + logger.warning( + "Eviction (binding_mismatch) on failed-task delete for " + "task %s (session=%s) — suppressing delete, signalling " + "awaiters with TaskConflictError", + task_id, + self._config.session_id or "local", + ) + raise TaskConflictError(task_id, "in_progress") from transport_exc + raise + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to delete failed ephemeral task %s", + task_id, + exc_info=True, + ) + + logger.error("Task %s failed: %s", task_id, exc) + + async def _steering_cleanup_orphan_attachments(self, task_info: TaskInfo) -> "TaskInfo | None": + """— delete orphaned ``_steering_input_*`` 
attachments. + + On startup-scan / recovery, walk ``task_info.attachments`` for + ``_steering_input_*`` keys whose corresponding ref slot is no + longer present in ``pending_inputs``. Delete them via a single + PATCH. + + This is defense-in-depth: the steering-append PATCH and the + steering-drain PATCH each carry payload + attachments in one + atomic write, so the happy path never produces orphans. But a + crash window between an attachment add and a queue append + (across separate PATCHes in some future code path) could + theoretically leave one — this cleanup costs ~one extra PATCH + per recovery and closes that window. + + :param task_info: The recovered ``TaskInfo`` (pre-reclaim). + :type task_info: TaskInfo + :return: The updated task record when a cleanup PATCH was + issued (so the caller can refresh its stale ``task_info`` + before reclaim), or None when nothing was written. + :rtype: TaskInfo | None + """ + if not task_info.attachments: + return None + from ._attachments import ( # pylint: disable=import-outside-toplevel + _STEERING_INPUT_KEY_PREFIX, + ) + + steering_keys = {k for k in task_info.attachments if k.startswith(_STEERING_INPUT_KEY_PREFIX)} + if not steering_keys: + return None + pending: list[Any] = (task_info.payload or {}).get("_steering", {}).get("pending_inputs", []) + referenced = { + _ref_key(entry) + for entry in pending + if _is_ref(entry) and _ref_key(entry).startswith(_STEERING_INPUT_KEY_PREFIX) + } + orphans = steering_keys - referenced + if not orphans: + return None + logger.info( + "Deleting %d orphan steering attachment(s) on task %s: %s", + len(orphans), + task_info.id, + sorted(orphans), + ) + return await self._provider_update_locked( + task_info.id, + TaskPatchRequest( + attachments={k: None for k in orphans}, + if_match=getattr(task_info, "etag", None) or None, + ), + ) + + async def _recover_stale_tasks(self) -> None: + """Recover stale in-progress tasks from previous instances.""" + agent_name = self._config.agent_name or 
"default" + session_id = self._config.session_id or "local" + + try: + # / C-FLT-1 — scope the recovery scan to + # framework-owned tasks via source_type. Tasks created by + # other systems sharing the same (agent, session, + # lease_owner) triple MUST NOT be enumerated by the + # framework's reclaim path. + stale_tasks = await self._provider.list( + agent_name=agent_name, + session_id=session_id, + status="in_progress", + lease_owner=self._lease_owner, + source_type=_SOURCE_TYPE, + ) + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Failed to query stale tasks for recovery", exc_info=True) + return + + for task_info in stale_tasks: + # Skip if we're already tracking this task + if task_info.id in self._active_tasks: + continue + + # — opportunistic orphan attachment cleanup. If a prior + # lifetime crashed between a steering-append attachment PATCH + # and the queue update (cannot happen in the happy path + # because Phase 4 makes them a single atomic PATCH, but + # defense-in-depth is cheap), delete any + # ``_steering_input_*`` attachment that no live ref in + # ``pending_inputs`` references. + try: + refreshed = await self._steering_cleanup_orphan_attachments(task_info) + if refreshed is not None: + # Cleanup wrote — adopt the post-cleanup record so the + # reclaim below carries the current etag (else reclaim + # 412s on the stale scan etag and recovery is skipped). + task_info = refreshed + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Orphan attachment cleanup failed for %s", + task_info.id, + exc_info=True, + ) + + # Reclaim the lease with our new instance ID + try: + # / C-LSE-2 — both reclaim sites + # (inline AND cold-start/periodic) carry if_match. On + # 412, ABANDON per §25.3 — another process beat us; + # let the next scan re-evaluate. + # + # Route through _reclaim_one so the reclaim takes the + # per-task write lock AND refreshes the tracked etag from + # the post-reclaim record. 
Adopt that record as task_info + # so (a) the lease-renewal heartbeat's tracked etag + # matches the store — otherwise its first tick sends the + # stale pre-reclaim etag, 412s, and recovery is cancelled + # as "lost ownership" ~one lease-half-life in — and (b) + # _start_existing_task sees the post-reclaim lease + # generation/instance. + reclaimed_info = await self._reclaim_one(task_info) + if reclaimed_info is not None: + task_info = reclaimed_info + logger.info( + "Reclaimed stale task %s (generation will increment)", + task_info.id, + ) + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_info.id) + if translated is None or getattr(translated, "current_status", None) == "in_progress": + logger.info( + "Reclaim conflict for task %s — another process beat us; " "letting next scan re-evaluate.", + task_info.id, + ) + continue + logger.warning("Failed to reclaim task %s", task_info.id, exc_info=True) + continue + except (EtagConflict, ValueError) as exc: + # 412 ABANDON for reclaim per §25.3. + if isinstance(exc, ValueError) and "etag" not in str(exc).lower(): + logger.warning("Failed to reclaim task %s", task_info.id, exc_info=True) + continue + logger.info( + "Reclaim 412 for task %s — another process beat us; " + "letting next scan re-evaluate (/ §25.3 ABANDON).", + task_info.id, + ) + continue + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Failed to reclaim task %s", task_info.id, exc_info=True) + continue + + # Find resume callback and dispatch + fn = self._find_resume_callback(task_info) + if fn is not None: + try: + # Look up stored opts for resumed-task configuration. 
+ fn_name = (task_info.source or {}).get("name", "") + opts = self._resume_opts.get(fn_name) + await self._start_existing_task( + fn=fn, + fn_name=task_info.agent_name, + task_info=task_info, + entry_mode="recovered", + opts=opts, + ) + logger.info("Recovered task %s is now active", task_info.id) + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to resume recovered task %s", + task_info.id, + exc_info=True, + ) + + def _find_resume_callback(self, task_info: TaskInfo) -> Callable[..., Any] | None: + """Find a registered resume callback for a task. + + Matches by ``source.name`` (auto-stamped function name) first, + then falls back to title prefix match or single-callback default. + + :param task_info: The task record to match. + :type task_info: TaskInfo + :return: A matching resume callback, or None. + :rtype: Callable[..., Any] | None + """ + # Preferred: match by source.name (framework auto-stamped fn name) + if task_info.source and "name" in task_info.source: + source_name = task_info.source["name"] + if source_name in self._resume_callbacks: + return self._resume_callbacks[source_name] + + # Fallback: title prefix match + for name, fn in self._resume_callbacks.items(): + if task_info.title and task_info.title.startswith(name): + return fn + + # Last resort: single registered callback + if len(self._resume_callbacks) == 1: + return next(iter(self._resume_callbacks.values())) + return None + + # --------------------------------------------------------------- # + # — Per-task write queue + etag tracking + # --------------------------------------------------------------- # + + def _get_task_write_lock(self, task_id: str) -> asyncio.Lock: + """/ C-WQ-1 — return the per-task write lock. + + Lazily creates the lock on first use. All in-process PATCH- + issuing code paths MUST acquire this lock before reading + state + computing the PATCH + applying it. + + Reads do NOT call this method (— reads are lock-free). 
+ + The lock entry is dropped by :meth:`_active_tasks_pop` when + the local active-entry is torn down. + """ + lock = self._task_write_locks.get(task_id) + if lock is None: + lock = asyncio.Lock() + self._task_write_locks[task_id] = lock + return lock + + def _track_etag(self, task_id: str, etag: str | None) -> None: + """— refresh the latest known etag for a task. + + Called by every store-interaction site after a successful + response carries an etag. Stored in two places: the per-task + etag cache (so reclaim/scan paths without an _ActiveTask can + still benefit) AND, if present, on the _ActiveTask entry + itself. + """ + if etag is None: + return + self._task_etag_cache[task_id] = etag + active = self._active_tasks.get(task_id) + if active is not None: + active.current_etag = etag + + def _get_tracked_etag(self, task_id: str) -> str | None: + """— read the latest tracked etag for a task. + + Returns ``None`` if no PATCH/GET response has been observed + yet (this can happen on the very first write — typically a + ``create`` where ``if_match`` is intentionally absent). + """ + active = self._active_tasks.get(task_id) + if active is not None and active.current_etag is not None: + return active.current_etag + return self._task_etag_cache.get(task_id) + + def _active_tasks_pop(self, task_id: str) -> None: + """— pop the active task entry AND drop its + per-task write lock + etag cache so the registries do not + leak across many task lifetimes. + """ + self._active_tasks.pop(task_id, None) + self._task_write_locks.pop(task_id, None) + self._task_etag_cache.pop(task_id, None) + + async def _provider_get_tracked(self, task_id: str) -> Any: + """— read a task AND refresh the tracked etag. + + Thin wrapper around ``self._provider.get(task_id)`` that calls + ``_track_etag`` on the response's etag. Use at every read site + where a subsequent PATCH may rely on the latest etag (the + normal read-then-PATCH pattern across the framework). 
+ """ + try: + info = await self._provider.get(task_id) + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_id) + if translated is None: + if exc._code == "lease_ownership_changed": + raise TaskConflictError(task_id, "in_progress") from exc + raise EtagConflict(task_id) from exc + raise translated from exc + if info is not None: + self._track_etag(task_id, getattr(info, "etag", None)) + return info + + async def _provider_update_lock_held( + self, + task_id: str, + patch: TaskPatchRequest, + *, + force_if_match: bool = True, + ) -> Any: + """Spec 031 / FR-005a — apply a PATCH while the per-task write lock + is ALREADY held by the caller. + + The per-task lock is a non-reentrant ``asyncio.Lock``; callers that + already hold it (e.g. ``_cancel_queued_steering_input``, the steering + drain) MUST use this variant rather than :meth:`_provider_update_locked` + to avoid self-deadlock. It selects ``if_match`` from the tracked etag + when the caller has not set one (no blind writes — SOT §25.1), + refreshes the tracked etag from the response, and bumps the + lease-last-refresh when the PATCH piggybacked the lease. + + The caller is responsible for holding ``_get_task_write_lock(task_id)``. + """ + if force_if_match and patch.if_match is None: + patch.if_match = self._get_tracked_etag(task_id) + result = await self._provider.update(task_id, patch) + etag = getattr(result, "etag", None) + if etag: + self._track_etag(task_id, etag) + if patch.lease_owner is not None: + self._note_lease_refreshed(task_id) + return result + + async def _provider_update_locked( + self, + task_id: str, + patch: TaskPatchRequest, + *, + force_if_match: bool = True, + ) -> Any: + """/ C-WQ-3 — apply a PATCH under the per-task + write lock with the tracked etag as ``if_match``. + + - Acquires the per-task write lock. + - Populates ``patch.if_match`` from the tracked etag when the + caller hasn't set one and ``force_if_match=True``. 
+ - Calls ``self._provider.update(task_id, patch)``. + - Refreshes the tracked etag from the response. + - Bumps lease-last-refresh if the PATCH carried lease ext + kwargs (— dynamic cadence shadows next heartbeat). + + Does NOT implement the RE-READ-AND-DECIDE policy — + that lives in :meth:`_terminal_write_locked` for the terminal + suspend/complete/fail sites. Delegates the actual write to + :meth:`_provider_update_lock_held` (Spec 031 / FR-005a) so the + lock-held and lock-acquiring paths share one implementation. + """ + async with self._get_task_write_lock(task_id): + return await self._provider_update_lock_held(task_id, patch, force_if_match=force_if_match) + + async def _terminal_write_locked( + self, + task_id: str, + patch: TaskPatchRequest, + *, + max_attempts: int = 5, + ) -> Any: + """/ C-WQ-3 / SC-3b — terminal-write 412 + RE-READ-AND-DECIDE. + + On 412 (EtagConflict from the provider, OR a hosted-provider + TransportClassifiedError(classification='conflict')), the + framework re-reads the record and decides: + + - (a) Lease no longer ours (owner or instance_id differ; + ``expiry_count`` is intentionally not consulted) → ABANDON + and raise ``TaskConflictError(current_status='in_progress')``. + The new owner is mid-recovery; clobbering their state would + silently cancel their execution. + - (b) ``status`` already ``completed`` → ABANDON. Another + actor already wrote the terminal; raise + ``TaskConflictError(current_status='completed')``. + - (c) Lease still ours, status still ``in_progress`` → retry + the terminal PATCH against the new etag, up to + ``max_attempts`` times. Steering inputs another process + appended in the racing window are silently superseded — + the steerer's ``.result()`` then raises + ``TaskConflictError(current_status='completed')`` per the + C-STR-6 cross-process steering-after-terminate contract. + + Default budget is 5 attempts. 
+ """ + prior_lease_owner = patch.lease_owner + prior_lease_instance = patch.lease_instance_id + async with self._get_task_write_lock(task_id): + attempts = 0 + cached_expiry_count = self._cached_expiry_count(task_id) + while True: + attempts += 1 + if patch.if_match is None: + patch.if_match = self._get_tracked_etag(task_id) + try: + result = await self._provider.update(task_id, patch) + etag = getattr(result, "etag", None) + if etag: + self._track_etag(task_id, etag) + return result + except _HostedConflict as exc: + translated = _translate_hosted_conflict(exc, task_id=task_id) + if translated is not None: + raise translated from exc + if attempts >= max_attempts: + if exc._code == "lease_ownership_changed": + raise TaskConflictError(task_id, "in_progress") from exc + raise EtagConflict(task_id) from exc + decision = await self._terminal_412_decide( + task_id, + prior_lease_owner=prior_lease_owner, + prior_lease_instance=prior_lease_instance, + cached_expiry_count=cached_expiry_count, + ) + if decision == "abandon_lease_lost": + raise TaskConflictError(task_id, "in_progress") from exc + if decision == "abandon_already_terminal": + raise TaskConflictError(task_id, "completed") from exc + patch.if_match = None + except (EtagConflict, ValueError) as exc: + # The local provider raises ValueError on etag + # mismatch; the hosted provider raises + # TransportClassifiedError(classification="conflict") + # which the caller translates to EtagConflict at + # the boundary. Both arrive here as either type. 
+ if isinstance(exc, ValueError) and "etag" not in str(exc).lower(): + raise + if attempts >= max_attempts: + raise + decision = await self._terminal_412_decide( + task_id, + prior_lease_owner=prior_lease_owner, + prior_lease_instance=prior_lease_instance, + cached_expiry_count=cached_expiry_count, + ) + if decision == "abandon_lease_lost": + raise TaskConflictError(task_id, "in_progress") from exc + if decision == "abandon_already_terminal": + raise TaskConflictError(task_id, "completed") from exc + # decision == "retry" — clear if_match and loop. + patch.if_match = None + except TransportClassifiedError as exc: + # Hosted-provider conflict (412 etag) or eviction + # (binding_mismatch). Eviction goes to the eviction + # path — fall through to the existing handler shape. + if getattr(exc, "classification", "") == "conflict": + if attempts >= max_attempts: + raise + decision = await self._terminal_412_decide( + task_id, + prior_lease_owner=prior_lease_owner, + prior_lease_instance=prior_lease_instance, + cached_expiry_count=cached_expiry_count, + ) + if decision == "abandon_lease_lost": + raise TaskConflictError(task_id, "in_progress") from exc + if decision == "abandon_already_terminal": + raise TaskConflictError(task_id, "completed") from exc + patch.if_match = None + continue + raise + + def _cached_expiry_count(self, task_id: str) -> int: + """Best-effort cache of the prior lease.expiry_count for + branch (a) detection. Not authoritative; absence means "no + cached value" and the decision falls back on lease owner / + instance_id comparison. + """ + return getattr(self, "_expiry_count_cache", {}).get(task_id, 0) + + async def _terminal_412_decide( + self, + task_id: str, + *, + prior_lease_owner: str | None, + prior_lease_instance: str | None, + cached_expiry_count: int, + ) -> str: + """— decide what to do after a terminal-write 412. + + Returns one of: + + - ``"abandon_lease_lost"`` — RE-READ shows lease no longer ours + (owner or instance_id differ). 
New owner is authoritative; + do not retry. + - ``"abandon_already_terminal"`` — RE-READ shows status already + terminal (``completed``). + - ``"retry"`` — Lease still ours, status still ``in_progress``; + safe to retry against the new etag. + + Note: per C-LSE-3, every real expiry-driven handoff bumps the + ``lease_instance_id``, so instance-id comparison alone is + sufficient to detect lease loss. An additional ``expiry_count`` + leg would require populating a snapshot cache at every write + site (otherwise the default ``cached_expiry_count=0`` causes + any reclaimed task with `expiry_count >= 1` to spuriously + abandon on legitimate retry-able 412s). We rely on instance-id + comparison and intentionally do NOT consult ``expiry_count`` + in this decision. + """ + _ = cached_expiry_count # retained for binary-compat / future use + try: + fresh = await self._provider_get_tracked(task_id) + except Exception: # pylint: disable=broad-exception-caught + # Can't re-read — be conservative; treat as lost. + return "abandon_lease_lost" + if fresh is None: + # Record vanished — treat as terminal. + return "abandon_already_terminal" + # Refresh tracked etag from the re-read. + etag = getattr(fresh, "etag", None) + if etag: + self._track_etag(task_id, etag) + # Branch (b): already terminal. + if getattr(fresh, "status", None) == "completed": + return "abandon_already_terminal" + # Branch (a): lease no longer ours (owner or instance_id differ). + if ( + fresh.lease is None + or fresh.lease.owner != (prior_lease_owner or self._lease_owner) + or fresh.lease.instance_id != (prior_lease_instance or self._instance_id) + ): + return "abandon_lease_lost" + # Branch (c): retry. + return "retry" + + def _lease_ext_kwargs(self, task_id: str) -> dict[str, Any]: + """Return lease-ownership kwargs for piggyback on a payload PATCH. 
+ + Every framework-issued PATCH that mutates payload (metadata + flush, steering-queue append, steering drain, terminal complete + on a steerable task) can refresh the lease as a side effect by + including the lease ownership query params on the request. This + eliminates the once-per-30-second redundant heartbeat PATCH for + an active task and pushes the renewal-loop tick out via + ``_note_lease_refreshed`` below. Zero extra network round-trips: + the lease params land on the same PATCH that was already going + out for the payload mutation. + + Returns the kwargs only when ``task_id`` is currently tracked + as an active local task. Otherwise returns an empty dict + (caller writes a plain payload-only PATCH; this is what + recovery/reclaim/restart paths want before they have bound a + new lease). + + :param task_id: The task identifier. + :type task_id: str + :return: kwargs for ``TaskPatchRequest`` carrying lease params, + or ``{}`` if this task is not active locally. + :rtype: dict[str, Any] + """ + if self._active_tasks.get(task_id) is None: + return {} + return { + "lease_owner": self._lease_owner, + "lease_instance_id": self._instance_id, + "lease_duration_seconds": _DEFAULT_LEASE_SECONDS, + } + + def _note_lease_refreshed(self, task_id: str) -> None: + """Record that the lease for ``task_id`` was just refreshed. + + Called by every PATCH path that piggybacks lease ownership + (see :meth:`_lease_ext_kwargs`) AND by the renewal loop itself + on a successful renewal. The renewal loop reads this timestamp + to push its next scheduled tick out -- so a payload PATCH that + already refreshed the lease delays the heartbeat by the same + margin, avoiding a redundant network round-trip. + + :param task_id: The task identifier. 
+ :type task_id: str + """ + active = self._active_tasks.get(task_id) + if active is None: + return + try: + active.lease_last_refresh_monotonic = asyncio.get_running_loop().time() + except RuntimeError: # no running loop (sync context) + pass + + def _make_metadata_flush(self, task_id: str) -> Callable[[Optional[str], dict[str, Any]], Awaitable[None]]: + """Create a per-namespace flush callback for metadata persistence. + + The callback persists each namespace into its dedicated payload + slot (layout): ``payload["metadata"]`` for the + default namespace and ``payload["metadata:"]`` for named + namespaces. Patches are shallow-merged by the provider so + flushing one namespace does NOT clobber another. + + :param task_id: The task identifier. + :type task_id: str + :return: An async callback that flushes one namespace. + :rtype: Callable[[Optional[str], dict[str, Any]], Awaitable[None]] + """ + + async def _flush(namespace: Optional[str], data: dict[str, Any]) -> None: + slot = "metadata" if namespace is None else f"metadata:{namespace}" + # / — route through the per-task + # write queue and use the tracked etag as if_match. The + # helper refreshes the etag from the response and bumps + # lease-last-refresh (cadence shadow). + # + # Spec 031 / FR-006 + SOT §25.3 — on a genuine (cross-process) + # etag conflict, re-read to refresh the tracked etag and retry. + # The patch addresses only this namespace's slot and the provider + # shallow-merges, so last-write-wins on the slot is correct (no + # logical re-merge of OTHER namespaces is needed). Bounded so the + # store's etag comparator cannot loop forever. A translated + # conflict (lease lost / already terminal) is NOT retried — the + # owner changed, so persisting our metadata would clobber theirs. 
+ for attempt in range(5): + try: + await self._provider_update_locked( + task_id, + TaskPatchRequest( + payload={slot: data}, + **self._lease_ext_kwargs(task_id), + ), + ) + return + except _HostedConflict as exc: + if _translate_hosted_conflict(exc, task_id=task_id) is not None or attempt == 4: + raise + await self._provider_get_tracked(task_id) + except (EtagConflict, ValueError, TransportClassifiedError) as exc: + if isinstance(exc, TransportClassifiedError) and getattr(exc, "classification", None) != "conflict": + raise + if isinstance(exc, ValueError) and "etag" not in str(exc).lower(): + raise + if attempt == 4: + raise + await self._provider_get_tracked(task_id) + + return _flush + + def _make_pending_count_provider(self, task_id: str) -> Callable[[], int]: + """: factory for the live pending-input-count + callable bound onto :class:`TaskContext`. + + The returned callable reads the in-memory steering state for + ``task_id`` on each access so ``ctx.pending_input_count`` + reflects the current backlog including inputs queued + mid-handler (as opposed to a snapshot frozen at handler entry). + + Returns 0 for tasks that are not steerable or have no pending + inputs. + + :param task_id: The task identifier the callable should track. + :type task_id: str + :return: A callable returning the current pending-input count. + :rtype: Callable[[], int] + """ + + def _provider() -> int: + active = self._active_tasks.get(task_id) + if active is None: + return 0 + # Read live count from the persisted-but-cached steering + # tracker. The fastest place is the in-memory _ActiveTask + # entry; we annotate it via a side-channel below. Default + # to 0 if not yet populated. 
+ count = getattr(active, "_pending_input_count", 0) + try: + return int(count) + except Exception: # pylint: disable=broad-exception-caught # noqa: BLE001 + return 0 + + return _provider diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_metadata.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_metadata.py new file mode 100644 index 000000000000..3d6b3a0657a0 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_metadata.py @@ -0,0 +1,353 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Mutable progress metadata for resilient tasks. + +Provides a dict-like interface with typed mutation methods plus a +**named-namespace** facility: + + ctx.metadata["key"] = "value" # default namespace + ctx.metadata("custom")["k"] = 1 # named namespace facade + ctx.metadata("_reserved")["seq"] = 5 # framework-layer convention + +Each namespace persists to a distinct payload slot: + +* ``payload["metadata"]`` — default namespace +* ``payload["metadata:"]`` — named namespaces + +There is **no auto-flush loop**. Flushes are explicit: + +* :meth:`TaskMetadata.flush` — flush THIS namespace only. +* :meth:`TaskMetadata._flush_all` — flush every dirty namespace + (called by the framework at lifecycle boundaries: suspend, + complete, fail, steering drain). + +The CORE primitive does NOT enforce namespace-name conventions. +Wrapper layers (e.g., responses) may reject ``_*`` names in their +:class:`ResilienceContext` facade — that is wrapper-layer policy +. 
+""" + +from __future__ import annotations + +import collections.abc +import logging +from collections.abc import Iterator +from typing import Any, Awaitable, Callable, Optional + +logger = logging.getLogger("azure.ai.agentserver.tasks") + +# Sentinel to distinguish "not set" from None +_NOT_SET = object() + +# Type alias for the per-namespace flush callback. +# The framework supplies a callback that knows how to persist data for +# a given namespace into the underlying task payload. +NamespaceFlushCallback = Callable[[Optional[str], dict[str, Any]], Awaitable[None]] + + +class TaskMetadata(collections.abc.MutableMapping): + """Mutable progress dict persisted to the task record's payload. + + The default namespace exposes a ``MutableMapping`` interface + directly. Named namespaces are accessed via the **callable** + protocol — ``meta(name)`` returns a sibling namespace facade. + + :param initial: Initial values for the **default** namespace. + :type initial: dict[str, Any] | None + :param flush_callback: Async callable invoked by :meth:`flush` to + persist dirty data. Signature: ``(namespace, data)`` where + ``namespace`` is ``None`` for the default namespace and a + ``str`` for named namespaces. + :type flush_callback: NamespaceFlushCallback | None + """ + + def __init__( + self, + initial: dict[str, Any] | None = None, + *, + flush_callback: NamespaceFlushCallback | None = None, + _namespace_name: Optional[str] = None, + _registry: dict[Optional[str], "TaskMetadata"] | None = None, + ) -> None: + self._data: dict[str, Any] = dict(initial) if initial else {} + self._dirty = False + self._flush_callback: NamespaceFlushCallback | None = flush_callback + self._namespace_name: Optional[str] = _namespace_name + # Registry of namespaces, keyed by namespace name. ``None`` is + # the default namespace. Child instances created via + # :meth:`__call__` share the SAME registry so namespace lookups + # are stable from any facade. 
+ if _registry is None: + self._registry: dict[Optional[str], "TaskMetadata"] = {None: self} + else: + self._registry = _registry + + # -- Namespace callable protocol -------------- + + def __call__(self, name: Optional[str] = None) -> "TaskMetadata": + """Return a namespace facade. + + ``meta()`` returns the default namespace; ``meta("custom")`` + returns the named-namespace facade (auto-vivified). + + The core primitive does NOT enforce namespace-name conventions + (e.g. the leading-underscore reservation). That is a wrapper- + layer concern — handler-facing wrappers (composed protocol + packages) may reject ``_*`` names so handlers can't collide with + framework-reserved namespaces. Framework-layered code (a wrapper + orchestrator itself) reaches reserved namespaces directly via + this API. + + :param name: Namespace name. ``None`` returns the default + namespace; a string returns the named namespace. + :type name: str | None + :return: A namespace facade. + :rtype: TaskMetadata + """ + if name is None: + return self._registry[None] + if name in self._registry: + return self._registry[name] + # Auto-vivify a new namespace; share the registry and inherit + # the parent's per-namespace flush callback. + child = TaskMetadata( + flush_callback=self._flush_callback, + _namespace_name=name, + _registry=self._registry, + ) + self._registry[name] = child + return child + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + *, + flush_callback: NamespaceFlushCallback | None = None, + ) -> "TaskMetadata": + """Construct a fresh :class:`TaskMetadata` from a recovered payload. + + Decodes the per-namespace persistence layout: + + * ``payload["metadata"]`` → default namespace. + * ``payload["metadata:<name>"]`` → named namespace ``<name>``. + + :param payload: The task's payload dict (or ``None``). + :type payload: dict[str, Any] | None + :keyword flush_callback: Per-namespace flush callback to wire into + every restored namespace. 
+ :paramtype flush_callback: NamespaceFlushCallback | None + :return: A fully populated :class:`TaskMetadata` with all named + namespaces pre-vivified to their recovered state. + :rtype: TaskMetadata + """ + payload = payload or {} + default_data = payload.get("metadata") or {} + if not isinstance(default_data, dict): + default_data = {} + + root = cls(initial=default_data, flush_callback=flush_callback) + for key, value in payload.items(): + if not isinstance(key, str) or not key.startswith("metadata:"): + continue + name = key[len("metadata:") :] + if not name or not isinstance(value, dict): + continue + # Auto-vivify and seed + ns = root(name) + ns._data = dict(value) # pylint: disable=protected-access + ns._dirty = False # pylint: disable=protected-access + return root + + # -- Typed mutation methods (operate on THIS namespace) ---------------- # + + def set(self, key: str, value: Any) -> None: + """Set a key-value pair in this namespace. + + :param key: Metadata key (must be a string). + :type key: str + :param value: Any JSON-serializable value. + :type value: Any + :raises TypeError: If key is not a string. + """ + if not isinstance(key, str): + raise TypeError(f"Metadata key must be a string, got {type(key).__name__}") + self._data[key] = value + self._mark_dirty() + + def get(self, key: str, default: Any = None) -> Any: + """Get a value by key. + + :param key: Metadata key. + :type key: str + :param default: Default value if key is absent. + :type default: Any + :return: The value, or *default*. + :rtype: Any + """ + return self._data.get(key, default) + + def increment(self, key: str, delta: int = 1) -> None: + """Atomically increment a numeric value. + + :param key: Metadata key. + :type key: str + :param delta: Amount to add (default 1). + :type delta: int + :raises TypeError: If the existing value is not numeric. 
+ """ + if not isinstance(delta, (int, float)): + raise TypeError(f"Delta must be numeric, got {type(delta).__name__}") + current = self._data.get(key, 0) + if not isinstance(current, (int, float)): + raise TypeError(f"Cannot increment non-numeric value at key {key!r}: " f"{type(current).__name__}") + self._data[key] = current + delta + self._mark_dirty() + + def append(self, key: str, value: Any) -> None: + """Append a value to a list. + + Creates the list if the key is absent. + + :param key: Metadata key. + :type key: str + :param value: Value to append. + :type value: Any + :raises TypeError: If the existing value is not a list. + """ + current = self._data.get(key, _NOT_SET) + if current is _NOT_SET: + self._data[key] = [value] + elif isinstance(current, list): + current.append(value) + else: + raise TypeError(f"Cannot append to non-list value at key {key!r}: " f"{type(current).__name__}") + self._mark_dirty() + + def to_dict(self) -> dict[str, Any]: + """Return a snapshot of this namespace's data. + + :return: A shallow copy of the namespace's dict. + :rtype: dict[str, Any] + """ + return dict(self._data) + + # -- Dict protocol (MutableMapping) ------------------------------------ # + + def __setitem__(self, key: str, value: Any) -> None: + if not isinstance(key, str): + raise TypeError(f"Metadata key must be a string, got {type(key).__name__}") + self._data[key] = value + self._mark_dirty() + + def __getitem__(self, key: str) -> Any: + return self._data[key] + + def __delitem__(self, key: str) -> None: + del self._data[key] + self._mark_dirty() + + def __contains__(self, key: object) -> bool: + return key in self._data + + def __iter__(self) -> Iterator[str]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def keys(self) -> collections.abc.KeysView[str]: + """Return a view of metadata keys. + + :return: A view of the metadata keys. 
+ :rtype: ~collections.abc.KeysView[str] + """ + return self._data.keys() + + def values(self) -> collections.abc.ValuesView[Any]: + """Return a view of metadata values. + + :return: A view of the metadata values. + :rtype: ~collections.abc.ValuesView[Any] + """ + return self._data.values() + + def items(self) -> collections.abc.ItemsView[str, Any]: + """Return a view of metadata key-value pairs. + + :return: A view of the metadata key-value pairs. + :rtype: ~collections.abc.ItemsView[str, Any] + """ + return self._data.items() + + # -- Flush API (explicit; no auto-flush loop) -------------------------- # + + async def flush(self) -> None: + """Force-flush this namespace's pending changes to storage. + + No-op if there are no pending changes in THIS namespace or no + flush callback. Sibling namespaces are NOT touched. + """ + await self._do_flush_one() + + async def _flush_all(self) -> None: + """— framework-internal: flush every dirty + namespace (default + all named). + + Called by the framework at lifecycle boundaries (suspend, + complete, fail, steering drain) to guarantee all in-memory + mutations land in the task payload before the task transitions. + + The leading underscore is the canonical signal for + "package-private; not part of the documented developer + surface." Developers MUST NOT call this — per-namespace + :meth:`flush` is the only fence pattern they need. 
+ """ + for ns in list(self._registry.values()): + await ns._do_flush_one() # pylint: disable=protected-access + + def _mark_dirty(self) -> None: + self._dirty = True + + async def _do_flush_one(self) -> None: + if not self._dirty or self._flush_callback is None: + return + try: + await self._flush_callback(self._namespace_name, dict(self._data)) + self._dirty = False + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to flush metadata namespace %r", + self._namespace_name, + exc_info=True, + ) + + +# ========================================================================= +# — JSONValue recursive type alias +# ========================================================================= +# +# Public type alias exported via tasks.__init__. TaskMetadata values +# SHOULD be JSON-serializable; this alias documents the value space. + +from typing import Union, List, Dict + +try: + from typing import TypeAlias # Python 3.10+ +except ImportError: # pragma: no cover + from typing_extensions import TypeAlias # type: ignore[assignment] + +# Recursive JSON type alias. Forward refs allow self-recursion. +# Use ForwardRef-via-string for the recursive arms so this type-checks +# on all Python versions, and the test's ForwardRef-detection logic +# resolves the recursion to the same alias. +JSONValue: TypeAlias = Union[ + str, + int, + float, + bool, + None, + List["JSONValue"], + Dict[str, "JSONValue"], +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_models.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_models.py new file mode 100644 index 000000000000..5276c400546c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_models.py @@ -0,0 +1,441 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Internal data models for the resilient task subsystem. + +These types represent wire-level task records and request/response shapes +used by providers. They are **not** part of the public API. +""" + +from __future__ import annotations + +from typing import Any, Literal + +TaskStatus = Literal["pending", "in_progress", "suspended", "completed"] +"""Valid task status values.""" + + +class LeaseInfo: + """Lease details on a task record. + + :param owner: Stable lease owner (e.g. ``"session:session_abc"``). + :type owner: str + :param instance_id: Ephemeral per-process instance identifier. + :type instance_id: str + :param generation: Fencing token — increments on re-acquisition. + :type generation: int + :param expires_at: ISO 8601 expiry timestamp. + :type expires_at: str + :param expiry_count: Number of times ownership changed via expiry. + :type expiry_count: int + :param heartbeat_at: ISO 8601 wall-time of the most recent lease + write (acquisition, renewal, or force-expire). Provider-stamped; + the framework never writes this. See SOT §22.1 LSE-W-10. 
+ :type heartbeat_at: str + """ + + __slots__ = ( + "owner", + "instance_id", + "generation", + "expires_at", + "expiry_count", + "heartbeat_at", + ) + + def __init__( + self, + owner: str, + instance_id: str, + generation: int, + expires_at: str, + expiry_count: int = 0, + heartbeat_at: str = "", + ) -> None: + self.owner = owner + self.instance_id = instance_id + self.generation = generation + self.expires_at = expires_at + self.expiry_count = expiry_count + self.heartbeat_at = heartbeat_at + + def __repr__(self) -> str: + return ( + f"LeaseInfo(owner={self.owner!r}, instance_id={self.instance_id!r}, " + f"generation={self.generation!r}, expires_at={self.expires_at!r}, " + f"expiry_count={self.expiry_count!r}, heartbeat_at={self.heartbeat_at!r})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, LeaseInfo): + return NotImplemented + return ( + self.owner == other.owner + and self.instance_id == other.instance_id + and self.generation == other.generation + and self.expires_at == other.expires_at + and self.expiry_count == other.expiry_count + and self.heartbeat_at == other.heartbeat_at + ) + + +class TaskInfo: # pylint: disable=too-many-instance-attributes + """Internal representation of a task record from the store. + + :param id: Unique task identifier. + :type id: str + :param agent_name: Agent scope. + :type agent_name: str + :param session_id: Session scope. + :type session_id: str + :param status: Current task status. + :type status: TaskStatus + :param title: Human-readable title. + :type title: str | None + :param description: Optional description. + :type description: str | None + :param lease: Active lease details, or ``None``. + :type lease: LeaseInfo | None + :param payload: Arbitrary JSON payload (input, metadata, output buckets). + :type payload: dict[str, Any] | None + :param tags: Key-value tags. + :type tags: dict[str, str] | None + :param error: Structured error details on failure. 
+ :type error: dict[str, Any] | None + :param suspension_reason: Reason for suspension. + :type suspension_reason: str | None + :param etag: Optimistic concurrency token. + :type etag: str + :param created_at: ISO 8601 creation timestamp. + :type created_at: str + :param updated_at: ISO 8601 last-update timestamp. + :type updated_at: str + :param started_at: ISO 8601 timestamp of first ``in_progress`` transition. + Set once when the task first enters ``in_progress`` and never updated + thereafter — lease re-acquisition, recovery scanner takeover, and + suspend/resume cycles do NOT reset this timestamp. + :type started_at: str | None + :param completed_at: ISO 8601 timestamp of ``completed`` transition. + :type completed_at: str | None + :param source: Source/initiator metadata (free-form key/value). + :type source: dict[str, Any] | None + :param attachments: Optional companion store for + per-input payloads larger than the framework's inline-payload + thresholds. Maximum 20 entries, each ≤ 2 MB. Keys starting with + ``_`` are reserved for the framework (``_input``, + ``_steering_input_``). See + `the SOT spec`. 
+ :type attachments: dict[str, Any] | None + """ + + __slots__ = ( + "id", + "agent_name", + "session_id", + "status", + "title", + "description", + "lease", + "payload", + "tags", + "error", + "suspension_reason", + "etag", + "created_at", + "updated_at", + "started_at", + "completed_at", + "source", + "attachments", + ) + + def __init__( + self, + id: str, # noqa: A002 + agent_name: str, + session_id: str, + status: TaskStatus, + title: str | None = None, + description: str | None = None, + lease: LeaseInfo | None = None, + payload: dict[str, Any] | None = None, + tags: dict[str, str] | None = None, + error: dict[str, Any] | None = None, + suspension_reason: str | None = None, + etag: str = "", + created_at: str = "", + updated_at: str = "", + started_at: str | None = None, + completed_at: str | None = None, + source: dict[str, Any] | None = None, + attachments: dict[str, Any] | None = None, + ) -> None: + self.id = id + self.agent_name = agent_name + self.session_id = session_id + self.status = status + self.title = title + self.description = description + self.lease = lease + self.payload = payload + self.tags = tags + self.error = error + self.suspension_reason = suspension_reason + self.etag = etag + self.created_at = created_at + self.updated_at = updated_at + self.started_at = started_at + self.completed_at = completed_at + self.source = source + self.attachments = attachments + + def __repr__(self) -> str: + return f"TaskInfo(id={self.id!r}, status={self.status!r}, agent_name={self.agent_name!r})" + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> TaskInfo: + """Construct a :class:`TaskInfo` from a JSON-decoded dict. + + :param data: Dictionary as returned by the Task Storage API. + :type data: dict[str, Any] + :return: A populated TaskInfo instance. 
+ :rtype: TaskInfo + """ + lease_data = data.get("lease") + lease = ( + LeaseInfo( + owner=lease_data["owner"], + instance_id=lease_data["instance_id"], + generation=lease_data.get("generation", 0), + expires_at=lease_data.get("expires_at", ""), + expiry_count=lease_data.get("expiry_count", 0), + heartbeat_at=lease_data.get("heartbeat_at", ""), + ) + if lease_data + else None + ) + return cls( + id=data["id"], + agent_name=data.get("agent_name", ""), + session_id=data.get("session_id", ""), + status=data.get("status", "pending"), + title=data.get("title"), + description=data.get("description"), + lease=lease, + payload=data.get("payload"), + tags=data.get("tags"), + error=data.get("error"), + suspension_reason=data.get("suspension_reason"), + etag=data.get("etag", ""), + created_at=data.get("created_at", ""), + updated_at=data.get("updated_at", ""), + started_at=data.get("started_at"), + completed_at=data.get("completed_at"), + source=data.get("source"), + attachments=data.get("attachments"), + ) + + def to_dict(self) -> dict[str, Any]: + """Serialize to a JSON-compatible dictionary. + + :return: Dictionary suitable for JSON serialization. 
+ :rtype: dict[str, Any] + """ + result: dict[str, Any] = { + "object": "task", + "id": self.id, + "agent_name": self.agent_name, + "session_id": self.session_id, + "status": self.status, + } + if self.title is not None: + result["title"] = self.title + if self.description is not None: + result["description"] = self.description + if self.lease is not None: + result["lease"] = { + "owner": self.lease.owner, + "instance_id": self.lease.instance_id, + "generation": self.lease.generation, + "expires_at": self.lease.expires_at, + "expiry_count": self.lease.expiry_count, + "heartbeat_at": self.lease.heartbeat_at, + } + else: + result["lease"] = None + if self.payload is not None: + result["payload"] = self.payload + if self.tags is not None: + result["tags"] = self.tags + if self.error is not None: + result["error"] = self.error + if self.suspension_reason is not None: + result["suspension_reason"] = self.suspension_reason + if self.source is not None: + result["source"] = self.source + if self.attachments is not None: + result["attachments"] = self.attachments + result["etag"] = self.etag + result["created_at"] = self.created_at + result["updated_at"] = self.updated_at + result["started_at"] = self.started_at + result["completed_at"] = self.completed_at + return result + + +class TaskCreateRequest: # pylint: disable=too-many-instance-attributes + """Request body for creating a task. + + :param agent_name: Agent scope. + :type agent_name: str + :param session_id: Session scope. + :type session_id: str + :param status: Initial status (``"pending"`` or ``"in_progress"``). + :type status: TaskStatus + :param id: Optional client-supplied task ID. + :type id: str | None + :param title: Human-readable title. + :type title: str | None + :param description: Optional description. + :type description: str | None + :param payload: Initial payload (input bucket). + :type payload: dict[str, Any] | None + :param tags: Initial tags. 
+ :type tags: dict[str, str] | None + :param lease_owner: Required when ``status`` is ``"in_progress"``. + :type lease_owner: str | None + :param lease_instance_id: Required when ``status`` is ``"in_progress"``. + :type lease_instance_id: str | None + :param lease_duration_seconds: Lease TTL. Required with lease params. + :type lease_duration_seconds: int | None + :param attachments: Optional initial attachments map. + Each value must be ≤ 2 MB; total entries ≤ 20. Keys starting + with ``_`` are reserved for the framework. See + `the SOT spec`. + :type attachments: dict[str, Any] | None + """ + + __slots__ = ( + "agent_name", + "session_id", + "status", + "id", + "title", + "description", + "payload", + "tags", + "source", + "lease_owner", + "lease_instance_id", + "lease_duration_seconds", + "attachments", + ) + + def __init__( + self, + agent_name: str, + session_id: str, + status: TaskStatus = "pending", + id: str | None = None, # noqa: A002 + title: str | None = None, + description: str | None = None, + payload: dict[str, Any] | None = None, + tags: dict[str, str] | None = None, + source: dict[str, Any] | None = None, + lease_owner: str | None = None, + lease_instance_id: str | None = None, + lease_duration_seconds: int | None = None, + attachments: dict[str, Any] | None = None, + ) -> None: + self.agent_name = agent_name + self.session_id = session_id + self.status = status + self.id = id + self.title = title + self.description = description + self.payload = payload + self.tags = tags + self.source = source + self.lease_owner = lease_owner + self.lease_instance_id = lease_instance_id + self.lease_duration_seconds = lease_duration_seconds + self.attachments = attachments + + +class TaskPatchRequest: + """Request body for patching a task. + + Only non-``None`` fields are included in the PATCH payload. + + :param status: New status. + :type status: TaskStatus | None + :param payload: Payload patch (shallow-merge semantics). 
+ :type payload: dict[str, Any] | None + :param tags: Tags patch (null-as-delete merge). + :type tags: dict[str, str] | None + :param error: Structured error (on failure). + :type error: dict[str, Any] | None + :param suspension_reason: Reason for suspension. + :type suspension_reason: str | None + :param lease_owner: Lease owner for transitions. + :type lease_owner: str | None + :param lease_instance_id: Lease instance for transitions. + :type lease_instance_id: str | None + :param lease_duration_seconds: Lease TTL override. + :type lease_duration_seconds: int | None + :param if_match: ETag for optimistic concurrency. + :type if_match: str | None + :param attachments: Attachments patch. Same null-as- + delete semantics as ``tags``: keys with a non-``None`` value are + upserted; keys with value ``None`` are deleted; keys absent + from the dict are unchanged. ``None`` for the field itself + means "no attachments changes in this PATCH". + :type attachments: dict[str, Any] | None + :param clear_attachments: When ``True``, wipe ALL attachments on + the task. The hosted provider serializes this as the wire form + ``"attachments": null`` (the service's "clear all" gesture + per §23.10); the local provider clears the dict directly. + Mutually exclusive with ``attachments={...}`` in the same + request — combination is rejected as ``invalid_request``. 
+ :type clear_attachments: bool + """ + + __slots__ = ( + "status", + "payload", + "tags", + "error", + "suspension_reason", + "lease_owner", + "lease_instance_id", + "lease_duration_seconds", + "if_match", + "attachments", + "clear_attachments", + ) + + def __init__( + self, + status: TaskStatus | None = None, + payload: dict[str, Any] | None = None, + tags: dict[str, str] | None = None, + error: dict[str, Any] | None = None, + suspension_reason: str | None = None, + lease_owner: str | None = None, + lease_instance_id: str | None = None, + lease_duration_seconds: int | None = None, + if_match: str | None = None, + attachments: dict[str, Any] | None = None, + clear_attachments: bool = False, + ) -> None: + self.status = status + self.payload = payload + self.tags = tags + self.error = error + self.suspension_reason = suspension_reason + self.lease_owner = lease_owner + self.lease_instance_id = lease_instance_id + self.lease_duration_seconds = lease_duration_seconds + self.if_match = if_match + self.attachments = attachments + self.clear_attachments = clear_attachments diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_provider.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_provider.py new file mode 100644 index 000000000000..0a69fc1acb5b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_provider.py @@ -0,0 +1,112 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Storage provider protocol for the resilient task subsystem. + +Defines the structural typing contract that hosted and local providers +must satisfy. Uses :class:`typing.Protocol` (PEP 544) — implementations +do not need to inherit from this class. 
+""" + +from __future__ import annotations + +from typing import Protocol, runtime_checkable + +from ._models import TaskCreateRequest, TaskInfo, TaskPatchRequest, TaskStatus + + +@runtime_checkable +class TaskProvider(Protocol): + """Async storage backend for resilient tasks. + + Both :class:`HostedTaskProvider` (HTTP → Task Storage API) and + :class:`LocalFileTaskProvider` (filesystem) implement this + protocol. + """ + + async def create(self, request: TaskCreateRequest) -> TaskInfo: + """Create a new task. + + :param request: Task creation parameters. + :type request: TaskCreateRequest + :return: The created task record. + :rtype: TaskInfo + """ + ... + + async def get(self, task_id: str) -> TaskInfo | None: + """Get a single task by ID. + + :param task_id: The task identifier. + :type task_id: str + :return: The task record, or ``None`` if not found. + :rtype: TaskInfo | None + """ + ... + + async def update(self, task_id: str, patch: TaskPatchRequest) -> TaskInfo: + """Update a task via PATCH semantics. + + :param task_id: The task identifier. + :type task_id: str + :param patch: Fields to update. + :type patch: TaskPatchRequest + :return: The updated task record. + :rtype: TaskInfo + :raises TaskNotFound: If the task does not exist. + """ + ... + + async def delete( + self, + task_id: str, + *, + force: bool = False, + cascade: bool = False, + ) -> None: + """Delete a task. + + :param task_id: The task identifier. + :type task_id: str + :keyword force: Release active lease before deleting. + :paramtype force: bool + :keyword cascade: Delete dependent tasks. + :paramtype cascade: bool + """ + ... 
+ + async def list( + self, + *, + agent_name: str | None = None, + session_id: str | None = None, + status: TaskStatus | str | None = None, + lease_owner: str | None = None, + tag: dict[str, str] | None = None, + source_type: str | None = None, + has_error: bool | None = None, + lease_expired: bool | None = None, + limit: int | None = None, + after: str | None = None, + before: str | None = None, + order: str | None = None, + omit_attachment_values: bool = False, + ) -> list[TaskInfo]: + """List tasks with filters. + + :keyword agent_name: Filter by agent name. + :paramtype agent_name: str | None + :keyword session_id: Filter by session ID. + :paramtype session_id: str | None + :keyword status: Filter by task status. + :paramtype status: TaskStatus | str | None + :keyword lease_owner: Filter by lease owner. + :paramtype lease_owner: str | None + :keyword tag: Filter by tags (AND semantics — all must match). + :paramtype tag: dict[str, str] | None + :keyword source_type: Filter by source type. + :paramtype source_type: str | None + :return: Matching task records. + :rtype: list[TaskInfo] + """ + ... diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_retry.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_retry.py new file mode 100644 index 000000000000..f832742efe54 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_retry.py @@ -0,0 +1,355 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""RetryPolicy — configurable retry behaviour for resilient tasks. + +Aligned with industry conventions (Temporal, Celery). 
+Delay formula: ``min(initial_delay * backoff_coefficient ** attempt, max_delay)`` +With jitter: ``delay * uniform(0.75, 1.25)`` +""" + +from __future__ import annotations + +import random +from datetime import timedelta + + +class RetryPolicy: + """Retry configuration for resilient tasks. + + :param initial_delay: Base delay between retries. A float is interpreted as seconds. + :type initial_delay: ~datetime.timedelta or float + :param backoff_coefficient: Multiplier applied per attempt. + :type backoff_coefficient: float + :param max_delay: Upper bound on computed delay. A float is interpreted as seconds. + :type max_delay: ~datetime.timedelta or float + :param max_attempts: Total attempts (including the first try). This is a + single **resilient** budget that counts handler-raised failures across + ALL lifetimes — the count is persisted to + ``payload["_retry_attempt"]`` and restored on recovery. Crash + recovery does NOT consume the budget; only handler-raised exceptions + do. A steering input resets the counter (a steering input is a new + logical request). + :type max_attempts: int + :param retry_on: Exception types that trigger retry. ``None`` means all. + :type retry_on: tuple[type[Exception], ...] | None + :param jitter: Whether to add ±25% randomization to delays. + :type jitter: bool + + .. versionadded:: 2.1.0 + """ + + __slots__ = ( + "initial_delay", + "backoff_coefficient", + "max_delay", + "max_attempts", + "retry_on", + "jitter", + "_linear", + ) + + def __init__( + self, + *, + initial_delay: timedelta | float = timedelta(seconds=1), + backoff_coefficient: float = 2.0, + max_delay: timedelta | float = timedelta(seconds=60), + max_attempts: int = 3, + retry_on: tuple[type[Exception], ...] | None = None, + jitter: bool | float = True, + _linear: bool = False, + ) -> None: + # Accept both timedelta and float (seconds) for + # initial_delay / max_delay. Store as the type provided so + # ``policy.initial_delay == 1.0`` works for float callers and + # ``.total_seconds()`` works for timedelta callers. 
+ def _seconds(v: timedelta | float) -> float: + return v.total_seconds() if isinstance(v, timedelta) else float(v) + + if _seconds(initial_delay) < 0: + raise ValueError(f"initial_delay must be >= 0, got {initial_delay}") + if backoff_coefficient < 1.0: + raise ValueError(f"backoff_coefficient must be >= 1.0, got {backoff_coefficient}") + if _seconds(max_delay) < _seconds(initial_delay): + raise ValueError(f"max_delay ({max_delay}) must be >= initial_delay ({initial_delay})") + if max_attempts < 1: + raise ValueError(f"max_attempts must be >= 1, got {max_attempts}") + if retry_on is not None: + # Accept a bare class as a single-element tuple — Pythonic. + if isinstance(retry_on, type) and issubclass(retry_on, BaseException): + retry_on = (retry_on,) + elif isinstance(retry_on, type): + # Non-Exception class (e.g., str) passed directly — reject. + raise TypeError(f"retry_on entries must be Exception subclasses, got {retry_on!r}") + for exc_type in retry_on: + if not isinstance(exc_type, type) or not issubclass(exc_type, Exception): + raise TypeError(f"retry_on entries must be Exception subclasses, got {exc_type!r}") + + self.initial_delay = initial_delay + self.backoff_coefficient = backoff_coefficient + self.max_delay = max_delay + self.max_attempts = max_attempts + self.retry_on = retry_on + self.jitter = jitter + self._linear = _linear + + def compute_delay(self, attempt: int) -> float: + """Return the delay in seconds for the given attempt (0-indexed). + + :param attempt: The 0-based attempt number that just failed. + :type attempt: int + :return: Delay in seconds before the next attempt. 
+ :rtype: float + """ + base_seconds = ( + self.initial_delay.total_seconds() + if isinstance(self.initial_delay, timedelta) + else float(self.initial_delay) + ) + max_seconds = self.max_delay.total_seconds() if isinstance(self.max_delay, timedelta) else float(self.max_delay) + if self._linear: + raw = base_seconds * (attempt + 1) + else: + raw = base_seconds * (self.backoff_coefficient**attempt) + + capped = min(raw, max_seconds) + + if self.jitter: + capped *= random.uniform(0.75, 1.25) + + return max(0.0, capped) + + def should_retry(self, attempt: int, error: Exception) -> bool: + """Return whether the task should be retried. + + :param attempt: The 0-based attempt number that just failed. + :type attempt: int + :param error: The exception that was raised. + :type error: Exception + :return: ``True`` if the task should be retried. + :rtype: bool + """ + # attempt is 0-indexed; max_attempts includes the first try + if attempt >= self.max_attempts - 1: + return False + if self.retry_on is None: + return True + return isinstance(error, self.retry_on) + + def __repr__(self) -> str: + return ( + f"RetryPolicy(initial_delay={self.initial_delay!r}, " + f"backoff_coefficient={self.backoff_coefficient}, " + f"max_delay={self.max_delay!r}, " + f"max_attempts={self.max_attempts}, " + f"retry_on={self.retry_on!r}, " + f"jitter={self.jitter})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RetryPolicy): + return NotImplemented + return ( + self.initial_delay == other.initial_delay + and self.backoff_coefficient == other.backoff_coefficient + and self.max_delay == other.max_delay + and self.max_attempts == other.max_attempts + and self.retry_on == other.retry_on + and self.jitter == other.jitter + and self._linear == other._linear + ) + + # ------------------------------------------------------------------ + # Convenience presets + # ------------------------------------------------------------------ + + @classmethod + def exponential_backoff( + 
cls, + *, + max_attempts: int = 3, + initial_delay: timedelta = timedelta(seconds=1), + max_delay: timedelta = timedelta(seconds=60), + backoff_coefficient: float = 2.0, + jitter: bool = True, + ) -> RetryPolicy: + """Exponential backoff — the most common pattern. + + Delay doubles per attempt: 1 s → 2 s → 4 s → … capped at *max_delay*. + + :keyword max_attempts: Total attempts including the first try. + :paramtype max_attempts: int + :keyword initial_delay: Base delay. + :paramtype initial_delay: ~datetime.timedelta + :keyword max_delay: Upper bound. + :paramtype max_delay: ~datetime.timedelta + :keyword backoff_coefficient: Multiplier applied per attempt. + :paramtype backoff_coefficient: float + :keyword jitter: Add ±25% randomization. + :paramtype jitter: bool + :return: A configured ``RetryPolicy``. + :rtype: RetryPolicy + """ + return cls( + initial_delay=initial_delay, + backoff_coefficient=backoff_coefficient, + max_delay=max_delay, + max_attempts=max_attempts, + jitter=jitter, + ) + + @classmethod + def fixed_delay( + cls, + *, + delay: timedelta = timedelta(seconds=5), + max_attempts: int = 3, + ) -> RetryPolicy: + """Fixed delay — constant interval between retries. + + Useful for rate-limited APIs where you want to wait a fixed + amount of time between each attempt. + + :keyword delay: Constant delay between retries. + :paramtype delay: ~datetime.timedelta + :keyword max_attempts: Total attempts including the first try. + :paramtype max_attempts: int + :return: A configured ``RetryPolicy``. + :rtype: RetryPolicy + """ + return cls( + initial_delay=delay, + backoff_coefficient=1.0, + max_delay=delay, + max_attempts=max_attempts, + jitter=False, + ) + + @classmethod + def linear_backoff( + cls, + *, + initial_delay: timedelta = timedelta(seconds=1), + max_delay: timedelta = timedelta(seconds=60), + max_attempts: int = 5, + ) -> RetryPolicy: + """Linear backoff — delay grows additively. 
+ + Delay is ``initial_delay * (attempt + 1)``: 1 s → 2 s → 3 s → … + + :keyword initial_delay: Base delay unit. + :paramtype initial_delay: ~datetime.timedelta + :keyword max_delay: Upper bound. + :paramtype max_delay: ~datetime.timedelta + :keyword max_attempts: Total attempts including the first try. + :paramtype max_attempts: int + :return: A configured ``RetryPolicy``. + :rtype: RetryPolicy + """ + return cls( + initial_delay=initial_delay, + backoff_coefficient=1.0, + max_delay=max_delay, + max_attempts=max_attempts, + jitter=False, + _linear=True, + ) + + @classmethod + def no_retry(cls) -> RetryPolicy: + """No retry — the function runs once and fails on exception. + + Equivalent to not setting a retry policy at all. + + :return: A ``RetryPolicy`` that never retries. + :rtype: RetryPolicy + """ + return cls( + initial_delay=timedelta(0), + backoff_coefficient=1.0, + max_delay=timedelta(0), + max_attempts=1, + jitter=False, + ) + + +# ========================================================================= +# Module-level convenience wrappers around the preset +# classmethods (exposed as `exponential_backoff` etc. +# with explicit kwargs). +# ========================================================================= + + +def exponential_backoff( + *, + initial_delay: "timedelta" = timedelta(seconds=1), + backoff_coefficient: float = 2.0, + max_delay: "timedelta" = timedelta(seconds=60), + max_attempts: int = 5, + jitter: bool = True, +) -> RetryPolicy: + """Module-level wrapper for :meth:`RetryPolicy.exponential_backoff`. + + Preset factories enumerate their kwargs explicitly. + + :keyword initial_delay: Initial delay before the first retry. + :keyword backoff_coefficient: Multiplier applied per attempt. + :keyword max_delay: Cap on the per-attempt delay. + :keyword max_attempts: Total attempts including the first try. + :keyword jitter: When True, add ±25% jitter per attempt. + :return: A configured :class:`RetryPolicy`. 
+ :rtype: RetryPolicy + """ + return RetryPolicy.exponential_backoff( + initial_delay=initial_delay, + backoff_coefficient=backoff_coefficient, + max_delay=max_delay, + max_attempts=max_attempts, + jitter=jitter, + ) + + +def fixed_delay( + *, + delay: "timedelta" = timedelta(seconds=1), + max_attempts: int = 5, +) -> RetryPolicy: + """Module-level wrapper for :meth:`RetryPolicy.fixed_delay`. + + :keyword delay: Constant delay between retries. + :keyword max_attempts: Total attempts including the first try. + :return: A configured :class:`RetryPolicy`. + :rtype: RetryPolicy + """ + return RetryPolicy.fixed_delay(delay=delay, max_attempts=max_attempts) + + +def linear_backoff( + *, + initial_delay: "timedelta" = timedelta(seconds=1), + max_delay: "timedelta" = timedelta(seconds=60), + max_attempts: int = 5, +) -> RetryPolicy: + """Module-level wrapper for :meth:`RetryPolicy.linear_backoff`. + + :keyword initial_delay: Delay increment per attempt. + :keyword max_delay: Cap on the per-attempt delay. + :keyword max_attempts: Total attempts including the first try. + :return: A configured :class:`RetryPolicy`. + :rtype: RetryPolicy + """ + return RetryPolicy.linear_backoff( + initial_delay=initial_delay, + max_delay=max_delay, + max_attempts=max_attempts, + ) + + +def no_retry() -> RetryPolicy: + """Module-level wrapper for :meth:`RetryPolicy.no_retry`. + + :return: A :class:`RetryPolicy` that never retries. + :rtype: RetryPolicy + """ + return RetryPolicy.no_retry() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_run.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_run.py new file mode 100644 index 000000000000..aee7dac65066 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_run.py @@ -0,0 +1,187 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""TaskRun handle for the resilient task subsystem. + +Slim public shape (Q9 / Q17). + +Public surface: +- attributes: ``task_id``, ``input_id`` +- properties: ``metadata``, ``is_queued`` +- methods: ``result()`` (returns ``Output``), ``cancel()`` +- dunder: ``__await__`` + +The legacy ``status``, ``lease_expiry_count``, ``delete()``, ``refresh()``, +and the ``Suspended`` sentinel are intentionally removed. The +``TaskResult`` wrapper is no longer exposed: ``await run`` / ``await +run.result()`` resolves to the raw ``Output`` value (or raises +``TaskFailed`` / ``TaskCancelled`` / ``TaskDeferred``). +""" + +from __future__ import annotations + +import asyncio # pylint: disable=do-not-import-asyncio +from typing import Any, Generic, TypeVar + +from ._metadata import TaskMetadata + +Output = TypeVar("Output") + + +def _unwrap_result(res: Any) -> Any: + """Return ``res`` unchanged; futures now resolve to raw Output directly. + + Identity helper retained so older monkey-patches in tests that + pre-wrap futures still pass unchanged. + """ + return res + + +class TaskRun(Generic[Output]): # pylint: disable=too-many-instance-attributes + """Handle to a running or completed resilient task. + + Returned by :meth:`Task.start`. Provides external observation + and control of the task lifecycle. + + :param task_id: The task identifier. + :type task_id: str + :param provider: Accepted for constructor compatibility; no longer stored. + :type provider: TaskProvider + :param result_future: Future that resolves with the task output. + :type result_future: asyncio.Future[Output] + :param metadata: The task's metadata instance. + :type metadata: TaskMetadata + :param cancel_event: Event to signal cancellation. + :type cancel_event: asyncio.Event + :param status: Accepted for constructor compatibility; ignored. 
+ :type status: TaskStatus + """ + + __slots__ = ( + "task_id", + "input_id", # — public read-only attribute + "_result_future", + "_metadata", + "_cancel_event", + "_cancel_ctx_ref", + "_execution_task", + "_queued_cancel_callback", + ) + + def __init__( + self, + task_id: str, + *, + provider: Any = None, # noqa: ARG002 — kept for ctor compat, no longer stored (Phase 5) + result_future: asyncio.Future[Any], + metadata: TaskMetadata | None = None, + cancel_event: asyncio.Event | None = None, + status: Any = None, # noqa: ARG002 — accepted but ignored (Phase 5) + terminate_event: asyncio.Event | None = None, # noqa: ARG002 — accepted but ignored (Phase 5) + execution_task: asyncio.Task[Any] | None = None, + terminate_reason_ref: list[str | None] | None = None, # noqa: ARG002 — accepted but ignored (Phase 5) + lease_expiry_count: int = 0, # noqa: ARG002 — accepted but ignored (Phase 5) + cancel_ctx_ref: Any = None, + input_id: str | None = None, + queued_cancel_callback: Any = None, + ) -> None: + self.task_id = task_id + # — `input_id` is a public read-only attribute on + # TaskRun. For one-shot tasks it defaults to ``task_id`` (1:1 invariant + # ); for multi-turn tasks the framework auto-generates a + # separate GUID per turn and sets it here. + self.input_id: str = input_id if input_id is not None else task_id + self._result_future = result_future + self._metadata = metadata or TaskMetadata() + self._cancel_event = cancel_event or asyncio.Event() + self._execution_task: asyncio.Task[Any] | None = execution_task + #: weak reference to the TaskContext so + # TaskRun.cancel() can set ctx.cancel_requested = True before + # setting ctx.cancel. + self._cancel_ctx_ref: Any = cancel_ctx_ref + # Optional callback installed by the framework when this handle + # represents a queued (not-yet-promoted) steering input. 
+ # ``cancel()`` invokes the callback instead of the in-process + # cancel signal — the callback removes the queued slot from + # ``_steering.pending_inputs`` and resolves the future with + # ``TaskCancelled``. + self._queued_cancel_callback: Any = queued_cancel_callback + + @property + def metadata(self) -> TaskMetadata: + """The task's metadata. + + For in-process handles, this is the live metadata reference. + + :return: The task metadata instance. + :rtype: TaskMetadata + """ + return self._metadata + + @property + def is_queued(self) -> bool: + """Whether this handle represents a *queued* steering input. + + ``True`` when this :class:`TaskRun` is a queued (not-yet-promoted) + steering input on a steerable chain — i.e. the request landed while a + turn was already in flight and is awaiting drain — and ``False`` for a + freshly-started or active run. A queued run's :meth:`cancel` removes the + queued slot and resolves :meth:`result` with ``TaskCancelled`` without + affecting the active turn. + + This is the supported, public way to distinguish a queued steering + handle from a freshly-started one. + + :return: ``True`` if this handle is a queued steering input. + :rtype: bool + """ + return self._queued_cancel_callback is not None + + async def result(self) -> Output: + """Await task completion and return the raw output value. + + : returns ``Output`` directly (not a wrapper). + Failures, cancellation, deferral are raised as exceptions. + + :return: The task's output value. + :rtype: Output + :raises TaskFailed: If the function raised an exception (one-shot). + :raises TaskCancelled: If the task was cancelled. + :raises TaskDeferred: If the task called ``ctx.exit_for_recovery()``. + """ + return _unwrap_result(await self._result_future) + + async def cancel(self) -> None: + """Signal cancellation to the running task. 
+ + : sets ``ctx.cancel_requested = True`` + BEFORE setting ``ctx.cancel``, so a handler observing + ``ctx.cancel.is_set() == True`` is guaranteed to see at least + one cause boolean already ``True``. + + The handler should check ``ctx.cancel.is_set()`` (and optionally + branch on which cause boolean is set) to wind down cleanly. + + For a queued (not-yet-promoted) steering input, ``cancel()`` + removes the queued slot from the chain's pending-inputs queue + and resolves :meth:`result` with ``TaskCancelled``. The active + turn (if any) is not affected. + """ + if self._queued_cancel_callback is not None: + await self._queued_cancel_callback() + return + ctx = self._cancel_ctx_ref + if ctx is not None: + ctx.cancel_requested = True + self._cancel_event.set() + + def __await__(self) -> Any: + """Awaiting a :class:`TaskRun` returns its raw :meth:`result`. + + : resolves to ``Output`` (not a wrapper). Mirrors + ``await run.result()`` exactly. + + :return: The raw output value. + :rtype: Output + """ + return self.result().__await__() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_task_api_logging_policy.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_task_api_logging_policy.py new file mode 100644 index 000000000000..33dbea9a9b96 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_task_api_logging_policy.py @@ -0,0 +1,135 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Task-API logging policy for the hosted task-store pipeline. + +, this policy logs request/response metadata for the +``HostedTaskProvider`` ``azure.core.AsyncPipelineClient`` chain. 
The +policy: + +- Logs an allow-listed set of operational headers (`x-ms-client-request-id`, + `x-ms-request-id`, `etag`, `if-match`, `retry-after`, standard Azure + operational headers like `x-ms-correlation-request-id`). +- NEVER logs the `Authorization` header (or any header whose name matches + a credential-bearing pattern). +- NEVER logs request or response bodies above DEBUG. +- Logs status codes and methods at INFO for successful responses, WARNING + for client errors (4xx), ERROR for server errors (5xx). + +Reference: spec.md, (the classifier funnels errors but does +NOT log them — that's this policy's job). +""" + +from __future__ import annotations + +import logging +from typing import Any + +from azure.core.pipeline import PipelineRequest, PipelineResponse +from azure.core.pipeline.policies import SansIOHTTPPolicy + +logger = logging.getLogger("azure.ai.agentserver.tasks.taskapi") + + +# Allow-listed operational headers. Logging anything else risks leaking +# auth, internal correlation IDs, or large payloads. +_ALLOWED_REQUEST_HEADERS: frozenset[str] = frozenset( + h.lower() + for h in ( + "x-ms-client-request-id", + "x-ms-correlation-request-id", + "if-match", + "if-none-match", + "content-type", + "content-length", + "user-agent", + "api-version", + ) +) +_ALLOWED_RESPONSE_HEADERS: frozenset[str] = frozenset( + h.lower() + for h in ( + "x-ms-client-request-id", + "x-ms-request-id", + "x-ms-correlation-request-id", + "etag", + "retry-after", + "content-type", + "content-length", + "date", + ) +) + + +def _redact_headers(headers: Any, allowed: frozenset[str]) -> dict[str, str]: + """Return a copy of ``headers`` keeping only the allow-listed keys. + + Defensive: ``headers`` may be a real Mapping or a custom HeaderDict. + Anything not in the allow-list is replaced with ``""`` + so the log line still shows the header was present without exposing + the value. + + :param headers: Header collection to copy (any mapping-like object). 
+ :type headers: Any + :param allowed: Lower-cased header names that may be logged in full. + :type allowed: frozenset[str] + :return: A redacted copy of the headers. + :rtype: dict[str, str] + """ + if not headers: + return {} + out: dict[str, str] = {} + try: + items = list(headers.items()) + except Exception: # pylint: disable=broad-exception-caught # noqa: BLE001 + return {} + for name, value in items: + try: + key = str(name).lower() + except Exception: # pylint: disable=broad-exception-caught # noqa: BLE001 + continue + if key in allowed: + out[name] = str(value) + else: + out[name] = "" + return out + + +class TaskApiLoggingPolicy(SansIOHTTPPolicy): + """Sans-I/O logging policy for the task-store pipeline. + + Sits late in the chain (after retries and credential injection) so + each emitted line reflects what actually went over the wire. + """ + + def on_request(self, request: PipelineRequest) -> None: + if not logger.isEnabledFor(logging.INFO): + return + http_request = request.http_request + method = http_request.method + url = str(http_request.url) + headers = _redact_headers(http_request.headers, _ALLOWED_REQUEST_HEADERS) + logger.info("task-store request: %s %s headers=%s", method, url, headers) + + def on_response(self, request: PipelineRequest, response: PipelineResponse) -> None: + http_response = response.http_response + status = getattr(http_response, "status_code", 0) + if status >= 500: + level = logging.ERROR + elif status >= 400: + level = logging.WARNING + else: + level = logging.INFO + if not logger.isEnabledFor(level): + return + method = request.http_request.method + url = str(request.http_request.url) + resp_headers = _redact_headers(getattr(http_response, "headers", {}), _ALLOWED_RESPONSE_HEADERS) + logger.log( + level, + "task-store response: %s %s -> %d headers=%s", + method, + url, + status, + resp_headers, + ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_validation.py 
b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_validation.py new file mode 100644 index 000000000000..b1b4d28c6199 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tasks/_validation.py @@ -0,0 +1,333 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared field validation for the resilient-task primitive. + +Both the hosted and local providers MUST enforce the same input +validation rules so a developer running locally observes the same +accept / reject decisions they would observe deployed against the +hosted service. This module is the single source of truth for those +rules — used by ``LocalFileTaskProvider`` to reject pre-write and by +the framework-side construction code in ``_decorator.py`` / +``_run.py`` to fail fast before any provider call. + +Spec source: ``docs/task-and-streaming-spec.md`` §28a + §22.1 + §23.9 ++ §24 + C-VAL / C-LSE / C-ATT / C-LCM conformance items. 
+""" + +from __future__ import annotations + +import json +import re +from typing import Any + +from ._exceptions_internal import _HostedConflict + + +# ── Regex patterns (per §28a.1 / §23.9) ────────────────────────────── + +_TASK_ID_RE = re.compile(r"^[a-zA-Z0-9_-]{1,128}$") +_TAG_KEY_RE = re.compile(r"^[a-zA-Z0-9_.\-]{1,64}$") +_ATTACHMENT_KEY_RE = re.compile(r"^[a-zA-Z0-9_.\-]{1,64}$") + + +# ── Length / count / size caps (per §28a.1, §28a.2, §23.7) ────────── + +MAX_AGENT_NAME_LEN = 128 +MAX_SESSION_ID_LEN = 128 +MAX_TITLE_LEN = 256 +MAX_DESCRIPTION_LEN = 1024 +MAX_SUSPENSION_REASON_LEN = 256 +MAX_TAG_VALUE_LEN = 256 +MAX_TAG_ENTRIES = 16 +MAX_PAYLOAD_BYTES = 1024 * 1024 # 1 MB +MAX_ERROR_BYTES = 64 * 1024 # 64 KB +MAX_SOURCE_BYTES = 4 * 1024 # 4 KB +MAX_ATTACHMENT_VALUE_BYTES = 2 * 1024 * 1024 # 2 MB (also enforced in _attachments) +MAX_ATTACHMENT_ENTRIES = 20 # (also enforced in _attachments) +MAX_LEASE_IDENTITY_LEN = 256 + +# ── Lease duration bounds (per §22.1 LSE-W-1) ──────────────────────── + +LEASE_DURATION_MIN = 10 +LEASE_DURATION_MAX = 3600 + + +# ── Allowed status values + state-transition matrix (per §24, §24.1) ─ + +_LEGAL_STATUSES = {"pending", "in_progress", "suspended", "completed"} +_LEGACY_STATUS_ALIASES = {"done": "completed"} # §28a.5 + +_ALLOWED_TRANSITIONS: dict[str, set[str]] = { + "pending": {"in_progress", "completed"}, + "in_progress": {"pending", "in_progress", "suspended", "completed"}, + "suspended": {"pending", "in_progress", "suspended", "completed"}, + # 'completed' MUST be terminal except no-op completed→completed + # without other field changes (see §24.2 — checked at the call site, + # not in the matrix alone). + "completed": {"completed"}, +} + +# Fields that MUST NOT appear in a PATCH body (§28a.6 / §24). 
+IMMUTABLE_PATCH_FIELDS = frozenset({"id", "agent_name", "session_id", "title", "description", "source"}) + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _reject(code: str, message: str) -> None: + """Raise an ``invalid_request``-coded :class:`_HostedConflict`. + + All validation rejections funnel through here so the wire-status + and code are uniform. The framework's translation layer converts + this to the developer-facing :class:`TaskPreconditionFailed`. + """ + raise _HostedConflict(_code=code, status_code=400, message=message) + + +def _canonical_json_bytes(value: Any) -> int: + """Return the UTF-8 byte length of ``value`` serialized as canonical JSON. + + Canonicalization matches the service's measurement: + ``sort_keys=True`` + compact separators (no whitespace). + """ + return len(json.dumps(value, sort_keys=True, separators=(",", ":")).encode("utf-8")) + + +def normalize_legacy_status(status: str | None) -> str | None: + """Map legacy status aliases to canonical values (§28a.5).""" + if status is None: + return None + return _LEGACY_STATUS_ALIASES.get(status, status) + + +# ── Validators (called by both providers) ──────────────────────────── + + +def validate_task_id(task_id: str) -> None: + """C-VAL-1: task id MUST match ``^[a-zA-Z0-9_-]{1,128}$``.""" + if not task_id or not _TASK_ID_RE.match(task_id): + _reject( + "invalid_request", + "id must match [a-zA-Z0-9_-] and be 128 characters or fewer.", + ) + + +def validate_required_string(value: str | None, field_name: str, max_len: int) -> None: + """C-VAL-2 / §28a.1: a required string field is non-empty after trim + and at or under ``max_len``.""" + if value is None or not value.strip(): + _reject("invalid_request", f"{field_name} must be provided.") + if len(value.strip()) > max_len: + _reject( + "invalid_request", + f"{field_name} exceeds the maximum allowed length of {max_len}.", + ) + + +def validate_optional_string(value: str | None, field_name: str, max_len: int) 
-> None: + """§28a.1: an optional string field, when present, is at or under ``max_len``.""" + if value is None: + return + if len(value.strip()) > max_len: + _reject( + "invalid_request", + f"{field_name} exceeds the maximum allowed length of {max_len}.", + ) + + +def validate_tags(tags: dict[str, Any] | None) -> None: + """C-VAL-5: tag key regex, value length, total entry count.""" + if tags is None: + return + if len(tags) > MAX_TAG_ENTRIES: + _reject( + "invalid_request", + f"tags must contain {MAX_TAG_ENTRIES} entries or fewer.", + ) + for key, value in tags.items(): + if not _TAG_KEY_RE.match(key or ""): + _reject( + "invalid_request", + "tag keys must match [a-zA-Z0-9_.-] and be 64 characters or fewer.", + ) + # null-as-delete (PATCH) — value None is meaningful, skip length check + if value is None: + continue + if not isinstance(value, str): + _reject("invalid_request", "tag values must be strings or null.") + if len(value) > MAX_TAG_VALUE_LEN: + _reject( + "invalid_request", + f"tag values must be {MAX_TAG_VALUE_LEN} characters or fewer.", + ) + + +def validate_payload_size(payload: Any) -> None: + """C-VAL-6: payload canonical-JSON byte count ≤ 1 MB.""" + if payload is None: + return + if _canonical_json_bytes(payload) > MAX_PAYLOAD_BYTES: + _reject( + "invalid_request", + f"payload exceeds the maximum allowed size of {MAX_PAYLOAD_BYTES} bytes.", + ) + + +def validate_error(error: dict[str, Any] | None) -> None: + """C-VAL-6 / C-VAL-8: error JSON ≤ 64 KB; required message + type.""" + if error is None: + return + if not isinstance(error, dict): + _reject("invalid_request", "error must be an object.") + if _canonical_json_bytes(error) > MAX_ERROR_BYTES: + _reject( + "invalid_request", + f"error exceeds the maximum allowed size of {MAX_ERROR_BYTES} bytes.", + ) + msg = error.get("message") + if not isinstance(msg, str) or not msg.strip(): + _reject("invalid_request", "error.message must be a non-empty string.") + typ = error.get("type") + if not 
isinstance(typ, str) or not typ.strip(): + _reject("invalid_request", "error.type must be a non-empty string.") + + +def normalize_error(error: dict[str, Any] | None) -> dict[str, Any] | None: + """C-VAL-8: error PATCH defaults ``code`` to ``"error"`` if missing. + Returns the canonicalized dict (a copy with defaults applied). + """ + if error is None: + return None + out = dict(error) + if not out.get("code"): + out["code"] = "error" + return out + + +def validate_source(source: dict[str, Any] | None) -> None: + """C-VAL-6 / C-VAL-7: source ≤ 4 KB and has non-empty ``type``.""" + if source is None: + return + if not isinstance(source, dict): + _reject("invalid_request", "source must be an object.") + if _canonical_json_bytes(source) > MAX_SOURCE_BYTES: + _reject( + "invalid_request", + f"source exceeds the maximum allowed size of {MAX_SOURCE_BYTES} bytes.", + ) + src_type = source.get("type") + if not isinstance(src_type, str) or not src_type.strip(): + _reject("invalid_request", "source.type must be a non-empty string.") + + +def validate_attachment_key(key: str) -> None: + """C-ATT-8: attachment keys MUST match the regex; non-empty after trim.""" + if not key or not key.strip() or not _ATTACHMENT_KEY_RE.match(key.strip()): + _reject( + "invalid_request", + "attachment keys must match [a-zA-Z0-9_.-] and be 64 characters or fewer.", + ) + + +def validate_attachment_keys(attachments: dict[str, Any] | None) -> None: + """Validate every key in an attachments dict.""" + if not attachments: + return + for key in attachments.keys(): + validate_attachment_key(key) + + +def validate_create_status(status: str | None) -> str: + """Normalize + validate the ``status`` field on CREATE. + + Per §24 / C-LCM-1: only ``pending`` or ``in_progress`` are allowed on + create. Empty/None defaults to ``pending``. ``"done"`` normalizes + to ``"completed"`` but is then rejected (create→completed is not + allowed). ``"failed"`` is rejected outright per §28a.5. 
+ """ + status = (status or "pending").strip().lower() + if status == "failed": + _reject( + "invalid_request", + "Unsupported status 'failed'. Represent failures as completed tasks " "with a non-null error.", + ) + normalized = _LEGACY_STATUS_ALIASES.get(status, status) + if normalized not in {"pending", "in_progress"}: + _reject("invalid_request", "status on create must be pending or in_progress.") + return normalized + + +def validate_patch_status(status: str | None) -> str | None: + """Normalize + validate the ``status`` field on PATCH. + + Per §24 / C-VAL-9. ``"failed"`` rejected, ``"done"`` normalized. + Returns the normalized status (or None when not patching status). + """ + if status is None: + return None + status = status.strip().lower() + if status == "failed": + _reject( + "invalid_request", + "Unsupported status 'failed'. Represent failures as completed tasks " "with a non-null error.", + ) + normalized = _LEGACY_STATUS_ALIASES.get(status, status) + if normalized not in _LEGAL_STATUSES: + _reject("invalid_request", f"Unsupported status '{status}'.") + return normalized + + +def validate_transition(current: str, target: str) -> None: + """C-LCM-5: enforce the §24.1 transition matrix.""" + current = normalize_legacy_status(current) or current + target = normalize_legacy_status(target) or target + allowed = _ALLOWED_TRANSITIONS.get(current, set()) + if target not in allowed: + # invalid_state_transition is technically a 409 per §39.1 but + # the framework treats it as a framework bug. Use the proper + # code; translation step handles the rest. + raise _HostedConflict( + _code="invalid_state_transition", + status_code=409, + message=f"Cannot transition task from '{current}' to '{target}'.", + ) + + +def validate_lease_params( + owner: str | None, + instance_id: str | None, + duration_seconds: int | None, +) -> tuple[str, str, int] | None: + """C-LSE-6 / C-LSE-7: all-or-nothing triplet, duration bounds. 
+ + Returns the normalized triplet when all three are supplied, ``None`` + when none are supplied. Raises ``_HostedConflict`` when partial. + """ + any_set = bool(owner) or bool(instance_id) or duration_seconds is not None + all_set = bool(owner) and bool(instance_id) and duration_seconds is not None + if any_set and not all_set: + _reject( + "invalid_request", + "lease_owner, lease_instance_id, and lease_duration_seconds must " "be provided together.", + ) + if not all_set: + return None + assert owner is not None and instance_id is not None # type narrowing + assert duration_seconds is not None + if duration_seconds != 0 and not (LEASE_DURATION_MIN <= duration_seconds <= LEASE_DURATION_MAX): + _reject( + "invalid_request", + f"lease_duration_seconds must be 0 or between {LEASE_DURATION_MIN} " f"and {LEASE_DURATION_MAX}.", + ) + if len(owner) > MAX_LEASE_IDENTITY_LEN: + _reject( + "invalid_request", + f"lease_owner exceeds the maximum allowed length of " f"{MAX_LEASE_IDENTITY_LEN}.", + ) + if len(instance_id) > MAX_LEASE_IDENTITY_LEN: + _reject( + "invalid_request", + f"lease_instance_id exceeds the maximum allowed length of " f"{MAX_LEASE_IDENTITY_LEN}.", + ) + return (owner.strip(), instance_id.strip(), duration_seconds) diff --git a/sdk/agentserver/azure-ai-agentserver-core/docs/streaming-guide.md b/sdk/agentserver/azure-ai-agentserver-core/docs/streaming-guide.md new file mode 100644 index 000000000000..3026a905c03c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/docs/streaming-guide.md @@ -0,0 +1,520 @@ +# Streaming guide — `azure.ai.agentserver.core.streaming` + +This package gives you one way to **emit events from one coroutine +and receive them from one or more other coroutines** — typically: +your `@task` handler produces events, and your HTTP layer fans them +out to a Server-Sent-Events / WebSocket / long-poll endpoint. + +You pick a backing once at app startup, then everywhere else you +look streams up by id and call `emit` / `subscribe`. 
+ +--- + +## 5-minute getting started + +```python +from azure.ai.agentserver.core.streaming import streams + +# 1. At app startup — pick a backing. +streams.use_in_memory_replay(cursor_fn=lambda ev: ev["n"], ttl_seconds=600) + +# 2. The producer (e.g. your @task handler): +async def produce(stream_id: str) -> None: + stream = await streams.get_or_create(stream_id) + try: + for n in range(5): + await stream.emit({"n": n, "msg": f"hello {n}"}) + finally: + await stream.close() + +# 3. The subscriber (e.g. your HTTP handler) — attach BEFORE the +# producer starts (see §Subscribing for why): +async def consume(stream_id: str) -> None: + stream = await streams.get_or_create(stream_id) + async for event in stream.subscribe(): + print(event) + # Loop terminates cleanly when the producer calls close(). +``` + +`streams.get_or_create(id)` is idempotent: the producer and the +subscriber both call it with the same id and get the **same** +`EventStream` instance back. + +--- + +## Public surface + +Five exports, total: + +```python +from azure.ai.agentserver.core.streaming import ( + streams, # the process-level registry singleton + EventStream, # @runtime_checkable Protocol + EventStreamError, # base exception (catch-all) + EventStreamClosedError, # emit on a closed stream + EventStreamNotFoundError, # any op on an id that isn't currently a live stream +) +``` + +That's it. Obtain stream instances from the registry and program +against the `EventStream` Protocol. + +--- + +## Choosing a backing + +| Backing | Use when | Reconnect / replay? | Survives process restart? | Notes | +|---|---|---|---|---| +| `use_in_memory_live()` (default) | Single subscriber that attaches before the producer; lowest memory; you don't need late subscribers to catch up. | No — late subscribers miss earlier events. | No. | Constant memory: only the subscriber list, no event buffer.
| +| `use_in_memory_replay(...)` | Multiple subscribers that may attach at different times; client may reconnect within `ttl_seconds`. | Yes (within the per-event TTL window). | No. | Each event is retained until its TTL elapses (or `delete` runs). | +| `use_file_backed_replay(...)` | Long-running turns where you need to survive a process crash and a fresh worker resuming the same turn. | Yes. | Yes — events are persisted to `storage_dir / f"{id}.jsonl"` and rehydrated on the next `get_or_create(id)`. | Single-writer-per-file enforced. | + +**Call a configurator before you create any streams** (typically +once at app startup). Later calls only affect streams created +after the call — streams already in the registry keep their original +backing. Switching mid-process is supported but discouraged. + +### Configurator signatures + +```python +streams.use_in_memory_live() -> None + +streams.use_in_memory_replay( + *, + cursor_fn: Callable[[Any], int] | None = None, + ttl_seconds: float | None = None, +) -> None + +streams.use_file_backed_replay( + *, + storage_dir: Path, + cursor_fn: Callable[[Any], int] | None = None, + ttl_seconds: float | None = None, + serializer: Callable[[Any], bytes] | None = None, + deserializer: Callable[[bytes], Any] | None = None, +) -> None +``` + +- **`cursor_fn`** — pass this if you want cursored re-subscription + (`subscribe(after=N)`) and a usable `last_cursor()`. It receives + each payload and returns an `int` you choose as its cursor (a + monotonically increasing sequence number is typical). Without it, + `subscribe(after=...)` is silently ignored and `last_cursor()` + always returns `None`. +- **`ttl_seconds`** — per-event retention. Each emitted event becomes + evictable `ttl_seconds` after its emit time, regardless of whether + the stream is still active. Use this to bound memory / disk usage. 
+ Once the stream is closed AND its last retained event has expired + AND at least one event was ever emitted, the stream itself + transitions to "destroyed" (see §Lifecycle). A stream that was + created and closed without ever emitting stays in CLOSED forever + (or until `streams.delete(id)`). +- **`storage_dir`** (file-backed only) — directory that holds one + `.jsonl` file per stream. Created if it doesn't exist. +- **`serializer` / `deserializer`** (file-backed only) — bring your + own codec for non-JSON-serializable payloads. Defaults assume the + payload is JSON-serializable. + +--- + +## The stream id + +A stream id is the identity of a single producer/consumer +conversation. Pick the per-turn identifier from your framework: + +| Context | Use as id | +|---|---| +| Inside `azure-ai-agentserver-invocations` | `request.state.invocation_id` (HTTP layer); `ctx.input["invocation_id"]` (handler) | +| Inside `azure-ai-agentserver-responses` | `response_id` | +| Bare-Python / custom | Any per-turn `str` you control end-to-end | + +**Do NOT use a resilient `task_id` as the stream id.** A resilient task +can span multiple turns (steering, recovery). Reusing the id across +turns means the second turn finds the previous turn's already-closed +stream and `emit` raises `EventStreamClosedError`. Always scope the +id to one logical request/turn/invocation. + +**File-backed backing only:** because the file-backed backing maps +the id directly to `/.jsonl`, the id must be safe +for use as a single filename — no path separators, no characters +your filesystem rejects, ideally short. The framework-provided +`invocation_id` / `response_id` values already satisfy this; if you +mint your own id, sanitize it. + +--- + +## The `EventStream` Protocol + +Every stream — regardless of backing — exposes the same four +methods: + +```python +class EventStream(Protocol): + async def emit(self, payload: Any, *, close: bool = False) -> None: ... + async def close(self) -> None: ... 
+ def subscribe(self, *, after: int | None = None) -> AsyncIterator[Any]: ... + async def last_cursor(self) -> int | None: ... +``` + +### `emit(payload, *, close=False)` + +Publishes one event to every currently-attached subscriber. + +- `payload` is yours — pass any value compatible with your + serializer. For file-backed replay the default expects JSON- + serializable values. +- `close=True` is an **atomic emit-and-close**: the payload is + delivered + the stream is closed in one step, with no opportunity + to emit again in between. For replay backings, the payload is + still retained in history and a late subscriber can see it; for + the live backing, late subscribers see neither the payload nor any + earlier events. +- Raises `EventStreamClosedError` if you call `emit` after `close`. + This means a producer bug (you should not be emitting any more); + HTTP layers should treat this as `5xx`, not a client error. +- Raises `EventStreamNotFoundError` if the stream has been destroyed. + +### `close()` + +Marks the stream done. Idempotent — calling it twice (or on a +destroyed stream) is a no-op, never raises. After `close()`: + +- New `emit` calls raise `EventStreamClosedError`. +- Existing subscriber iterators drain any in-flight events, then + exit cleanly with `StopAsyncIteration`. +- New `subscribe` calls still work as long as the stream hasn't yet + been destroyed (for replay backings, they will see the retained + history). + +### `subscribe(*, after=None)` + +Returns an **async iterator** over emitted payloads. **Not** a +coroutine — call it WITHOUT `await`, use directly in `async for`: + +```python +async for event in stream.subscribe(): + handle(event) +``` + +The iterator terminates cleanly with `StopAsyncIteration` when the +stream is closed (after draining any in-flight events) **or** when +the stream is destroyed while you are iterating (whether by +`streams.delete(id)` or by the auto-transition described in +§Lifecycle). 
`subscribe()` itself raises `EventStreamNotFoundError` +synchronously only if the stream is already destroyed at the time +you call it. + +`after=N` is the **reconnection primitive** — only yield events +whose cursor is strictly greater than `N`. Requires the active +backing to have a `cursor_fn`; silently ignored otherwise. See +§Recovery & resumption. + +Multiple subscribers are supported; each gets its own independent +queue. + +### `last_cursor()` + +Returns the highest cursor value seen so far, or `None` if no +events were emitted, or `None` if the active backing has no +`cursor_fn`. After the stream is closed, this is the last cursor +the backing saw — even if that event has since expired from +replay. Raises `EventStreamNotFoundError` if the stream is destroyed. + +`last_cursor()` is the producer's recovery primitive: a recovering +handler reads it to learn "what cursor should I assign to my next +emit?". + +--- + +## Lifecycle: ACTIVE → CLOSED → (destroyed) + +Each stream is **ACTIVE** or **CLOSED**. After CLOSED, the id may +be destroyed; once destroyed, every operation against it raises +`EventStreamNotFoundError`. + +| State | What it means | How you reach it | +|---|---|---| +| **ACTIVE** | Open to `emit`. Subscribable. | Construction (first `get_or_create(id)`). | +| **CLOSED** | No new emits (`emit` raises `EventStreamClosedError`). Existing subscribers drain. New subscribers can still attach (replay backings) but no new events arrive. | `close()` from ACTIVE. | + +Three independent paths into destroyed: + +- the id was **never registered** (no `get_or_create(id)` for it ever ran); +- the id was **explicitly `streams.delete(id)`**d; +- the id's stream was **Closed** and its close-clock TTL + (`close_time + ttl_seconds`) **elapsed** — only applies to replay + backings constructed with `ttl_seconds`. + +A few practical implications: + +- The live backing (`use_in_memory_live`) never auto-destroys — it + has no TTL machinery. 
Call `streams.delete(id)` explicitly if you + need to release the id. +- After `close_time + ttl_seconds`, the id is destroyed — regardless + of whether anyone is still subscribed or any retained events are + still in the buffer. +- `last_cursor()` is safe to call during the close window — a + recovering handler can always read the last cursor it had seen + before close. + +--- + +## The registry + +```python +streams.get(id) -> EventStream # raises NotFound for any id that is not currently live +streams.get_or_create(id) -> EventStream # idempotent +streams.delete(id) -> None # idempotent +``` + +- `get(id)` returns the registered stream, or raises + `EventStreamNotFoundError`. Treat any `NotFound` uniformly: + "this id is not a live stream; subscribe to a new id or treat as + missing". +- `get_or_create(id)` is idempotent — every caller using the same + id gets the same `EventStream` instance, even from concurrent + coroutines. If the id was previously destroyed, a fresh stream is + created. +- `delete(id)` removes the stream and any backing resources (including + the on-disk log for file-backed replay). Idempotent — safe to call + on an unknown or already-deleted id. + +You typically do not need to call `delete(id)` for replay backings +with `ttl_seconds` configured — the close-clock auto-destroy +cleans up for you. Call `delete(id)` explicitly when you want +immediate cleanup (end-of-request hook, test teardown) or for +backings without `ttl_seconds`. + +--- + +## Exceptions → wire mapping + +```text +EventStreamError (base — catch-all) +├── EventStreamClosedError producer bug — wire-map to HTTP 5xx +└── EventStreamNotFoundError id is not currently a live stream — HTTP 404 +``` + +Every "this id is not currently a live stream" condition raises +`EventStreamNotFoundError` (HTTP 404). Treat it uniformly: +subscribe to a new id, or render the id as missing. 
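As a rough sketch of the mapping above, an HTTP layer could translate these exceptions into status codes with a small helper. The classes below are stand-ins that mirror the documented hierarchy so the snippet runs without the SDK installed; `http_status_for` is a hypothetical name, and picking 500 for the "5xx" case is an assumption rather than something the guide prescribes.

```python
# Stand-in classes mirroring the documented hierarchy (assumption: the real
# types live in azure.ai.agentserver.core.streaming with the same shape).
class EventStreamError(Exception):
    """Base exception (catch-all)."""


class EventStreamClosedError(EventStreamError):
    """Raised when emit is called on a closed stream (producer bug)."""


class EventStreamNotFoundError(EventStreamError):
    """Raised when the id is not currently a live stream."""


def http_status_for(exc: EventStreamError) -> int:
    """Hypothetical helper: map a stream exception to an HTTP status code."""
    if isinstance(exc, EventStreamNotFoundError):
        # The id is not a live stream: render it as missing.
        return 404
    # EventStreamClosedError and any other EventStreamError are treated as
    # server-side faults. 500 is an assumed choice within the 5xx range.
    return 500
```

In a real handler the exception types would be imported from `azure.ai.agentserver.core.streaming` rather than redefined, and the helper would sit in whatever error-handling hook the web framework provides.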
+ +--- + +## Subscribing — the subscribe-before-start rule + +For the **default live backing** (`use_in_memory_live`), subscribers +only see events emitted after they attach. With the live backing +"attach" means **`async for` over the iterator has begun (i.e. +`__aiter__` has run)** — not merely that you've called +`get_or_create` or `subscribe`. So just calling +`asyncio.create_task(_serve_sse(stream))` does not guarantee the SSE +task has actually begun iterating before your producer starts +emitting — there is a race. + +Safe options: + +1. **Use a replay backing** (`use_in_memory_replay` or + `use_file_backed_replay`). Late subscribers catch up via the + retained history, so the race doesn't matter. This is the + recommended default for HTTP layers. +2. **Drive iteration before starting the producer.** Spawn the SSE + task, then `await asyncio.sleep(0)` (or any explicit signal from + the SSE task that it has started its `async for`) before calling + `task.start(...)`. This is harder to get right than option 1; we + recommend option 1 unless you have a strong reason to avoid + buffering. + +Once you've picked your strategy, the canonical pattern is: + +1. HTTP layer reads the per-turn id from the request. +2. HTTP layer calls `await streams.get_or_create(id)` and arranges + for a subscriber to be attached (per the strategy above). +3. HTTP layer starts the producer (e.g. `await task.start(...)`) + with the id propagated via input. +4. Producer also calls `await streams.get_or_create(id)` and gets + the same instance. 
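
Option 2 in the list of safe options (an explicit "I have started iterating" signal) can be illustrated with a toy model. This is only a sketch: `ToyLiveStream`, `serve`, and `main` are hypothetical names, the class is not the SDK's live backing, and its `subscribe(attached)` signature is invented so the example stays self-contained.

```python
import asyncio
from collections.abc import AsyncIterator


class ToyLiveStream:
    """Hypothetical stand-in for a live backing; not the SDK class."""

    def __init__(self) -> None:
        self._queues: list[asyncio.Queue] = []

    async def emit(self, event: dict) -> None:
        # Live fan-out: only queues attached right now receive the event.
        for queue in self._queues:
            queue.put_nowait(event)

    async def close(self) -> None:
        for queue in self._queues:
            queue.put_nowait(None)  # sentinel that ends iteration

    async def subscribe(self, attached: asyncio.Event) -> AsyncIterator[dict]:
        # An async-generator body runs only once iteration begins, so the
        # queue is attached when the consumer first advances the iterator,
        # not when subscribe() is called.
        queue: asyncio.Queue = asyncio.Queue()
        self._queues.append(queue)
        attached.set()  # explicit "I have started iterating" signal
        while (event := await queue.get()) is not None:
            yield event


async def serve(stream: ToyLiveStream, attached: asyncio.Event, seen: list[dict]) -> None:
    async for event in stream.subscribe(attached):
        seen.append(event)


async def main() -> list[dict]:
    stream = ToyLiveStream()
    attached = asyncio.Event()
    seen: list[dict] = []
    consumer = asyncio.create_task(serve(stream, attached, seen))
    await attached.wait()  # do not start the producer before this
    await stream.emit({"n": 0})
    await stream.emit({"n": 1})
    await stream.close()
    await consumer
    return seen
```

Dropping the `await attached.wait()` line in this toy makes the first emits race the consumer task, which is the hazard the rule describes.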
+ +```python +import asyncio + +# At startup (option 1 — recommended): +streams.use_in_memory_replay(cursor_fn=lambda ev: ev["n"], ttl_seconds=600) + +# HTTP layer +async def handle_request(request): + inv_id = request.state.invocation_id + + stream = await streams.get_or_create(inv_id) # 1 + 2 + sse = asyncio.create_task(_serve_sse(stream)) # safe: replay backing + + await my_task.start( + task_id=..., + input={"invocation_id": inv_id, ...}, # 3 + ) + return StreamingResponse(...) + +# Handler +@task +async def my_task(ctx): + inv_id = ctx.input["invocation_id"] + stream = await streams.get_or_create(inv_id) # 4 — same instance + # Every event carries "n" because cursor_fn above reads ev["n"]. + await stream.emit({"n": 0, "event": "hello"}) +``` + +--- + +## Recovery & resumption + +### Cursored reconnect (client side) + +If your subscriber drops (network blip, client refresh) and your +backing has a `cursor_fn`, the client reconnects with the last +cursor it saw and the SDK only re-delivers later events: + +```python +# Client reconnects with Last-Event-ID: 42 +stream = await streams.get_or_create(stream_id) +async for event in stream.subscribe(after=42): + push_to_client(event) +``` + +Events with cursor ≤ 42 are skipped from the retained history; +delivery resumes at 43. + +### Crash-recoverable producer (file-backed) + +With `use_file_backed_replay`, a fresh process resuming the same +turn rehydrates the stream automatically: + +```python +from pathlib import Path + +from azure.ai.agentserver.core.streaming import ( + streams, EventStreamNotFoundError, +) + +streams.use_file_backed_replay( + storage_dir=Path("/var/lib/myapp/streams"), + cursor_fn=lambda ev: ev["n"], + ttl_seconds=3600, +) + +@task +async def producer(ctx): + inv_id = ctx.input["invocation_id"] + stream = await streams.get_or_create(inv_id) + try: + # On crash recovery this is the highest n that made it to disk. + last = await stream.last_cursor() + except EventStreamNotFoundError: + # The previous run closed the stream AND every persisted event + # has since expired.
The on-disk log is stale; drop it and start + # fresh. delete() removes the file and records the deletion; + # the next get_or_create() then mints a brand-new stream. + await streams.delete(inv_id) + stream = await streams.get_or_create(inv_id) + last = None + + next_n = (last + 1) if last is not None else 0 + for n in range(next_n, total): + await stream.emit({"n": n, "msg": ...}) + await stream.close() +``` + +The typical recovery scenario — process crashed mid-stream, no +terminal marker on disk — is handled by the first branch: +rehydration loads the persisted events, `last_cursor()` returns the +highest cursor, and the handler resumes emitting from the next +cursor. + +The `EventStreamNotFoundError` branch handles the edge case where the +previous run completed cleanly (wrote a close marker to disk) AND +every persisted event has since expired AND your application policy +is "start over with a fresh stream". Without the explicit +`delete(id)`, the next `get_or_create(id)` would re-hand-back the +same expired stream. `delete(id)` lets you mint a fresh one. + +### Don't double-track in `@task` metadata + +Anti-pattern: + +```python +# Don't do this. +await stream.emit({"n": n, ...}) +ctx.metadata.set("last_event_n", n) +await ctx.metadata.flush() +``` + +The stream already persisted the event; `last_cursor()` will return +`n` for you. `ctx.metadata` is for **workflow** watermarks — which +units of side-effecting work (LLM calls, tool invocations) you've +already completed — not for mirroring stream state. 
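
The two cursor rules used in this section (the subscriber's `after=N` filter and the producer's "next cursor" arithmetic) reduce to a few lines. The sketch below is illustrative only: `replay_after` and `next_cursor` are hypothetical helper names, and the retained history is modelled as a plain list rather than a real backing.

```python
from __future__ import annotations

from collections.abc import Callable, Iterable


def replay_after(
    history: Iterable[dict],
    cursor_fn: Callable[[dict], int],
    after: int | None,
) -> list[dict]:
    """Keep only events whose cursor is strictly greater than ``after``."""
    if after is None:
        return list(history)
    return [event for event in history if cursor_fn(event) > after]


def next_cursor(last: int | None) -> int:
    """Cursor a recovering producer assigns to its next emit."""
    return 0 if last is None else last + 1
```

With a history of `n = 41..45` and `after=42`, the filter keeps 43, 44, and 45; a producer that reads `last_cursor() == 42` resumes at 43, and one that reads `None` starts at 0, matching the `next_n` expression in the producer example.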
+ +--- + +## HTTP / SSE bridging pattern + +Typical helper for serving a stream over Server-Sent-Events: + +```python +import json + +from azure.ai.agentserver.core.streaming import EventStreamNotFoundError + +async def _serve_sse(stream): + """Bridge an EventStream to an SSE wire format.""" + last_seen: int | None = None + try: + async for event in stream.subscribe(): + cursor = event.get("n") + yield f"id: {cursor}\ndata: {json.dumps(event)}\n\n".encode() + last_seen = cursor + except EventStreamNotFoundError: + # Server-side cleanup ran while we were attached; tell the + # client we're done. + yield b"event: gone\ndata: {}\n\n" +``` + +If your client sends `Last-Event-ID`, pass it through to +`stream.subscribe(after=int(last_event_id))` to skip already-delivered +events. + +--- + +## Bringing your own `EventStream` implementation + +You can write your own `EventStream` Protocol impl (e.g. a Redis- +backed stream). It will be accepted anywhere the Protocol is — the +`@runtime_checkable` decorator on the Protocol means +`isinstance(s, EventStream)` works. + +**But** don't register your custom impl with the SDK `streams` +registry — its cleanup is wired to the bundled backings only. Ship +your own peer registry instead, and let consumers pick which one +to call: + +```python +class _MyRedisStreams: + """Peer namespace to the SDK ``streams`` registry.""" + def __init__(self, *, redis_url, **opts): ... + async def get(self, id: str) -> EventStream: ... + async def get_or_create(self, id: str) -> EventStream: ... + async def delete(self, id: str) -> None: ... + +my_redis_streams = _MyRedisStreams(redis_url="...") +``` + +Consumers explicitly choose which registry they want: +`await my_redis_streams.get_or_create(id)` vs +`await streams.get_or_create(id)`. The shared interface is the +`EventStream` Protocol; lifecycle is each registry's own concern. 
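
What the `@runtime_checkable` remark buys a custom implementation can be shown with a stand-in Protocol. This is a sketch, not the SDK definition: `EventStreamLike` is a hypothetical Protocol that lists only three of the methods this guide uses (`emit`, `close`, `last_cursor`), and `RedisStream` / `NotAStream` are toy classes.

```python
from __future__ import annotations

from typing import Protocol, runtime_checkable


@runtime_checkable
class EventStreamLike(Protocol):
    """Hypothetical stand-in for the SDK's EventStream Protocol."""

    async def emit(self, event: dict) -> None: ...
    async def close(self) -> None: ...
    async def last_cursor(self) -> int | None: ...


class RedisStream:
    """Toy custom implementation; it does not inherit from the Protocol."""

    async def emit(self, event: dict) -> None: ...
    async def close(self) -> None: ...
    async def last_cursor(self) -> int | None:
        return None


class NotAStream:
    """Has emit() only, so it does not satisfy the Protocol."""

    async def emit(self, event: dict) -> None: ...
```

Here `isinstance(RedisStream(), EventStreamLike)` is true and `isinstance(NotAStream(), EventStreamLike)` is false. A runtime Protocol check only verifies that the named members exist; it does not validate signatures or whether the methods are coroutines, so a static type checker is still the stronger guarantee.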
+ +--- + +## See also + +- [`tasks-guide.md`](./tasks-guide.md) — `@task` developer + guide; Pattern E shows the streaming integration end-to-end. +- `samples/resilient_streaming/resilient_streaming.py` (in this package) + — minimal standalone sample. +- `azure-ai-agentserver-invocations/samples/resilient_research/`, + `resilient_langgraph/`, `resilient_copilot/` — HTTP-server samples + exercising the registry + per-turn `invocation_id` + + subscribe-before-start pattern end-to-end. diff --git a/sdk/agentserver/azure-ai-agentserver-core/docs/task-and-streaming-spec.md b/sdk/agentserver/azure-ai-agentserver-core/docs/task-and-streaming-spec.md new file mode 100644 index 000000000000..6b7a0ae8d843 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/docs/task-and-streaming-spec.md @@ -0,0 +1,4356 @@ +# Resilient Task & Streaming Primitives — Design Specification + +**Status:** Authoritative, source-of-truth specification. +**Scope:** The **`@task` resilient-task primitive** and the **`streams` +streaming primitive** in `azure-ai-agentserver-core` — i.e. +everything that ships under `azure.ai.agentserver.core.tasks.*` +and `azure.ai.agentserver.core.streaming.*`. NOT a spec for the +rest of the core package (the hosting foundation, middleware, +logging, tracing, server-side ASGI plumbing, etc. are outside +this document's scope). +**Audience:** Implementers building or maintaining these two +primitives in any language (Python, .NET, …), and contributors +modifying the canonical Python implementation. Treat this document +as the only doc a re-implementer needs. +**Out of scope:** Everything else in `azure-ai-agentserver-core` +beyond the two named primitives. The `azure-ai-agentserver-responses` +and `azure-ai-agentserver-invocations` packages. Response-event-stream +wire shapes. HTTP route plumbing for response APIs. The platform +itself. + +This document is the authoritative single source of truth for the +two primitives in scope. 
+ +It **references** the *Foundry Task Storage Protocol Specification* +as the authoritative description of the hosted task store's HTTP +contract (routes, request/response envelopes, server-side merge +rules, authentication, activation, ETag/CAS, error codes). Where +this spec talks about wire shape, the framework MUST conform to +that protocol spec; this spec only describes **how the framework +uses** the store, plus the framework-reserved keys / conventions +it layers on top. + +--- + +## Table of contents + +### Part I — Orientation +- §1. Purpose and design goals +- §2. Non-goals +- §3. Architecture overview +- §4. Glossary (forward-reference) + +### Part II — Programming model (developer-facing concepts) +- §5. The resilient task primitive +- §6. Lifecycle and entry mode +- §7. Identity (`task_id`, `agent_name`, `session_id`, lease owner) +- §8. Inputs, outputs, and per-input size limit +- §9. Persistence ownership (framework vs developer) +- §10. Crash recovery +- §11. Suspend, resume, and multi-turn +- §12. Steering primitive +- §13. Cancellation and cause booleans +- §14. Timeout (per-turn, cooperative) +- §15. Retry +- §16. Shutdown and `exit_for_recovery` +- §17. Metadata namespaces + +### Part III — Storage contract (wire-level) +- §18. Reference to the Foundry Task Storage Protocol +- §19. The framework's view of the task record +- §20. Framework-reserved payload keys +- §21. Framework-reserved tag and source values +- §22. Lease structure and ownership semantics (+ §22.1 lease write rules) +- §23. Attachments and input promotion (+ §23.9 key validation, §23.10 clear-all) +- §24. Status state machine (+ §24.1 transition matrix, §24.2 terminal immutability, §24.3 delete force semantics) +- §25. ETag (optimistic concurrency) usage +- §26. Recovery — internal lifecycle (no public HTTP endpoint) + +### Part IV — Provider abstraction (storage backends) +- §27. `TaskProvider` interface +- §28. Hosted provider (HTTP) +- §28a. 
Field validation (shared between providers) +- §29. Local provider (file-backed) +- §30. Provider auto-selection +- §31. Background loops +- §31a. List filter parity (internal `list()`) + +### Part V — Public API surface (language-agnostic) +- §32. `task` and `multi_turn_task` decorators +- §33. `Task` (one-shot) and `MultiTurnTask` (multi-turn) handles +- §34. `TaskContext` +- §35. `TaskRun` +- §35a. Read-only inspection (internal — via the task manager's provider) +- §36. `TaskRun.result()` returns `Output` directly +- §37. `TaskMetadata` +- §38. `RetryPolicy` +- §39. Error taxonomy + +### Part VI — Streaming primitive (peer subpackage) +- §40. Why streaming is decoupled from `@task` +- §41. `EventStream` protocol +- §42. The `streams` registry +- §43. Stream lifecycle states (Active ↔ Closed; registry tombstones) +- §44. Concrete backings (live, replay, file-backed) +- §45. Cursor and `subscribe(after=...)` +- §46. TTL eviction and the close-clock (replay backings) +- §47. Streaming error taxonomy +- §48. Third-party stream-impl pattern + +### Part VII — Implementation guidance (algorithms) +- §49. Cold-start sequence +- §50. `.start()` lifecycle resolution +- §51. Steering append (atomic) +- §52. Steering drain (two-phase) +- §53. Suspend write +- §54. Recovery + reclaim +- §55. Periodic recovery loop +- §56. Lease renewal loop +- §57. Per-turn watchdog +- §58. Orphan attachment cleanup + +### Part VIII — Conformance items +- §59. Conformance items (C-1 … C-N) + +### Part IX — References +- §60. References + +### Part X — Appendices (informative) +- §A. Language-mapping cheat sheet +- §B. Representative full task record +- §C. Steering sequence (append → cancel → drain → result) +- §D. Cold-start recovery sequence + +--- +## Part I — Orientation + +### §1. Purpose and design goals + +The resilient-task primitive turns a single async agent function into a +**crash-resilient, steerable, long-running** unit of work backed by a +resilient task store. 
It exists to close the gap between: + +- **What the platform sees.** A unit of work it can place, restart, + liveness-check, and reclaim. +- **What the application owns.** A plain function the developer writes + once, that survives container crashes, OOM kills, redeployments, and + cooperative cancellation without hand-rolling lease, heartbeat, + checkpoint, recovery, or steering plumbing. + +The streaming primitive (`azure.ai.agentserver.core.streaming`) is a +**peer** to the resilient primitive — it does *not* nest under +`@task`. It exists to give every async producer/consumer pair in the +agentserver family a single Protocol to program against (in-memory live +fan-out, in-memory replay with cursor, file-backed crash-recoverable +replay), independent of whether the producer happens to be a `@task`. + +Five design goals constrain every decision in this document: + +1. **Single invariant for the resilient primitive.** For any given + `task_id`, at most one handler runs at a time. Every other behavior + falls out of this invariant. +2. **Crash-recovery is first-class, not a feature.** Every API + decision is evaluated against the question "what does this look + like after a crash?" A primitive that disappears at the crash + boundary (a per-call kwarg, an in-memory listener, a closure-only + state) is not acceptable; it must be reified into the resilient + record or it must be on the developer. +3. **Cooperative everywhere.** The framework signals; it does not + preempt. Cancellation, timeout, and steering all reduce to "set + `ctx.cancel`; let the handler decide the terminal shape." Forced + teardown belongs to the platform layer, not the primitive. +4. **Storage shape is the public contract.** The framework writes a + structured task record. 
The shape of that record (which + payload keys are reserved, what attachments look like, what tags + are stamped) is part of the spec — implementers in other languages + MUST produce byte-compatible records so a recovery scan from one + process can pick up a task created by another. +5. **Pay only for what you use.** Streaming is decoupled because + handlers that do not stream pay nothing. Attachments are + thresholded because small inputs pay only the inline cost. + Steering is opt-in because non-steerable tasks pay no queue + overhead. + +### §2. Non-goals + +The primitive is intentionally narrow. The following are explicit +non-goals — they will NOT be added to the spec without explicit +re-scoping: + +1. **Not deterministic replay.** No record-and-replay of effects. + After a crash the handler is re-invoked from the top; only + resilient state (`ctx.input`, `ctx.metadata`, framework counters) + survives. Determinism inside the handler is the developer's + responsibility — the standard at-most-once side-effect pattern in + §10 covers the common case. +2. **Not a workflow engine.** No fan-out/fan-in, no child workflows, + no signals or timers as first-class primitives. Use Temporal / + Orleans for that — `@task` can live inside + such an engine but does not replace it. +3. **Not a bulk-data store.** `ctx.metadata` is small (tens of KB + per namespace; the whole task payload caps at 1 MB). It is a + watermark / dedup-token store, not a chat-log store. Per-input + payloads up to 2 MB are accepted via the attachments mechanism + (§23) but anything larger MUST be externalized by the caller. +4. **Not a competing-consumer queue.** A `task_id` identifies one + logical unit of work owned by one current lifetime. N workers + pulling jobs off a shared queue is the wrong fit; use a queue. +5. **Not multi-process streaming.** The streaming primitive's bundled + backings are single-process. 
A future remote-backed implementation + could plug into the same protocol but is out of scope here. +6. **No exactly-once side-effect guarantee.** The framework provides + at-most-once via a developer-issued dedup token (the at-most-once + pattern). Anything stronger requires external transactionality. +7. **Single wire shape.** The framework reads and writes exactly + the shapes documented in this spec. The primitive is in private + preview; there is no version-skew compatibility to maintain. + +### §3. Architecture overview + +The framework's runtime decomposes into the following components. +Boxes are types/objects; arrows show the dominant call direction. + +``` + ┌──────────────────────────────┐ + │ application code │ + │ (user-written @task funcs) │ + └──────────────┬───────────────┘ + │ decorator registration + ▼ + ┌─────────────┐ .start / ┌─────────────────┐ create / get / + │ caller │ ─ .run ────▶ │ Task (handle) │ ─ update / list ──▶ ┌──────────────┐ + │ (HTTP,etc.) │ ◀─ TaskRun ─ │ │ │ TaskProvider │ + └─────────────┘ Output └─────────┬───────┘ └──────┬───────┘ + │ │ + invokes user fn ┌──────┴──────┐ + │ │ Hosted via │ + ▼ │ HTTP + │ + ┌─────────────────┐ │ classifier │ + │ TaskContext │ └──────┬──────┘ + │ (ctx.input, │ │ + │ ctx.metadata, │ │ + │ ctx.cancel,…) │ ▼ + └────────┬────────┘ ┌──────────────────┐ + │ flush / suspend / │ Foundry Task │ + │ exit_for_recovery │ Storage (HTTP) │ + ▼ └──────────────────┘ + ┌─────────────────┐ ▲ + │ TaskManager │ ──── lease_renewal_loop ──────┤ + │ (singleton) │ ──── periodic_recovery_loop ─┤ + │ │ ──── timeout_watchdog ───────┤ + └─────────────────┘ │ + │ + ┌────────────────────────────────────────┐ │ + │ Local file provider (dev/test only) │ ◀──────┘ + │ (~/.agentserver-tasks///…) │ + └────────────────────────────────────────┘ + + ┌──────────────────────────────────────────────────────────────────┐ + │ Streaming subpackage (PEER — not nested under @task) │ + │ │ + │ ┌───────────────────┐ get_or_create(id) 
┌──────────────┐ │ + │ │ streams registry │ ──────────────────────▶│ EventStream │ │ + │ │ (process-level) │ ◀───────────────────── │ (3 backings)│ │ + │ └───────────────────┘ delete(id) └──────┬───────┘ │ + │ │ │ │ + │ │ emit / subscribe │ + │ ▼ ▼ │ + │ use_in_memory_live() / producers / │ + │ use_in_memory_replay() / consumers │ + │ use_file_backed_replay() │ + └──────────────────────────────────────────────────────────────────┘ +``` + +**Key relationships:** + +- The `Task` handle is the developer-facing object created by the + `@task` decorator; the singleton `TaskManager` is the *runtime* + that owns the active-task table, the periodic recovery loop, and + the provider. +- The `TaskProvider` is an abstraction over the task store. Two + concrete providers ship: `HostedTaskProvider` (HTTP-backed, used + when the platform is detected) and `LocalFileTaskProvider` + (JSON-on-disk under `~/.agentserver-tasks///.json` + by default; used otherwise). The framework auto-selects. +- The `TaskContext` is what the handler receives; it is wired by the + manager and exposes both inputs (`input`, `metadata`, `entry_mode`) + and signals (`cancel`, `shutdown`, cause booleans). +- Three background loops run while the manager is up: the periodic + recovery scan (default 300s), one lease-renewal loop per active + task (half the lease duration), and one timeout watchdog per + active execution (when the task declares a timeout). +- The streaming subpackage is independent. Handlers that want to + stream do `await streams.get_or_create(id)` and `emit` / `close` + on the returned object; the HTTP layer attaches `subscribe(after=…)` + consumers. The framework never touches a stream from the resilient + path. + +### §4. Glossary (forward-referenced) + +| Term | Meaning | +|---|---| +| **Task** | A unit of resilient work, identified by `task_id`, persisted in the task store. | +| **Lifetime** | One contiguous in-memory execution of a task by a particular process. 
A task can have multiple lifetimes over its life (each crash starts a new lifetime). | +| **Turn** | One handler invocation. A fresh task with no resume/recover is one turn. A suspend/resume cycle is two turns. A steering-driven re-entry is the next turn. | +| **Generation / sequence number** | Monotonic counter inside the steering queue used to derive attachment keys; never reused (see §23). | +| **Lease** | The fenced ownership record on the task. While a process holds the lease, no other lifetime is allowed to run the task. | +| **Entry mode** | The framework's signal to the handler about WHY this turn started: `fresh` (first), `resumed` (after suspend or steering drain), `recovered` (previous lifetime crashed). | +| **Steering** | A new caller `.start()` against an already-running steerable task: the new input is queued, the current turn is cancelled cooperatively, and on the next turn the queued input is consumed. | +| **Attachment** | Per-task secondary storage slot for values larger than a payload-friendly inline threshold (§23). | +| **Ref / attachment ref** | The sentinel value the framework writes into `payload` to indicate "this slot has been promoted to `attachments[]`" (§23.3). | +| **Cause boolean** | A read-only field on `TaskContext` (`timeout_exceeded`, `cancel_requested`) or counter (`pending_input_count`) that explains why `ctx.cancel` was set. | +| **Promotion** | The framework's act of moving an oversized input from inline `payload` into `attachments`, replacing the inline value with a ref (§23). | +| **Drain** | Popping a single steering input off the queue and re-entering the handler with it (§52). | +| **Reclaim** | A different lifetime taking over a task whose lease has expired (§54). | + +--- + + +## Part II — Programming model + +This part is the developer-facing mental model. It is normative for +behavior visible to handler code, but the *wire-level realization* of +each concept lives in Part III. + +### §5. 
The resilient task primitive + +A resilient task is created by decorating a single async function: + +``` +@task(name="my_task") # decorator +async def my_task(ctx) -> Out: # exactly one parameter: TaskContext[Input] + return ... +``` + +The decoration registers the function with the process-wide +descriptor table (consulted at recovery time). The returned object — +the *task handle* — is what callers invoke (`.run()` / `.start()`). + +The framework guarantees one invariant: **for a given `task_id`, at +most one handler runs at a time in any process owning the active +lease.** Every higher-level behavior in this spec is derived from +that invariant. + +### §6. Lifecycle and entry mode + +The task store records each task in one of four statuses: + +| Status | Meaning | +|---|---| +| `pending` | Created, not yet picked up by a handler. (Rarely observed by handler code — the framework moves through it atomically.) | +| `in_progress` | A handler is currently executing this task (or claims to be — a stale lease may need to be reclaimed). | +| `suspended` | (Multi-turn only.) Handler's turn ended with `return X`; the chain is parked between turns awaiting the next `.run()` / `.start()` to drive the next turn. | +| `completed` | Terminal. The handler is finished (success, raise, cancel) and will not run again. The *outcome* (success / failure / cancelled) is communicated via the typed exceptions (§39) — **NOT encoded in the status field**. | + +Every time the framework invokes the handler, it computes an entry +mode from the persisted state and exposes it as `ctx.entry_mode`: + +| Persisted state at entry | `entry_mode` | What it means | +|---|---|---| +| No task / status `pending` | `"fresh"` | First invocation. No prior state. | +| `suspended` | `"resumed"` | Caller provided new input; resume from where we suspended. | +| `in_progress` (previous lifetime died) | `"recovered"` | We are the new lifetime; check your watermark. 
| +| `in_progress` (steerable, mid-flight, steering drain) | `"resumed"` (with `ctx.is_steered_turn = True`) | Another input was queued; we are the next-turn re-entry. | + +The handler is REQUIRED to be safe to enter in any of these modes. +Branching on `ctx.entry_mode` at the top is the canonical pattern. + +`entry_mode` and `is_steered_turn` are orthogonal. The combination +`(entry_mode="recovered", is_steered_turn=True)` is legal: a previous +process crashed mid-drain and the recovered handler is taking over. + +### §7. Identity + +A task is identified by three independent strings: + +| Field | Source | Lifetime | Purpose | +|---|---|---|---| +| `task_id` | Caller-supplied at `.start()` / `.run()`. | Identical across resume / recovery / steering. | The conversation / unit-of-work key. | +| `agent_name` | Platform-supplied (env `FOUNDRY_AGENT_NAME`); fallback `"unknown-agent"`. | Fixed per process. | Scoping; multiple agents share a store. | +| `session_id` | Platform-supplied (env `FOUNDRY_AGENT_SESSION_ID`). | Fixed per process. | Scoping; multiple sessions share an agent. | + +The framework derives the **lease owner** string from both +`agent_name` AND `session_id`: + +``` +lease_owner = "|session:" +``` + +Deriving the owner from BOTH components (not session alone) prevents +silent cross-agent ownership collisions in topologies where two +different agents happen to share a session identifier. + +Each *process* generates a fresh **instance id** at startup: + +``` +lease_instance_id = "worker---" +``` + +The `(owner, instance_id)` pair lets recovery distinguish: + +- **Same-owner same-instance** = my own running task (renew, do not reclaim). +- **Same-owner different-instance** = a previous lifetime of mine that + is gone (reclaim immediately on cold start; no expiry wait). +- **Different-owner** = someone else's task; do not touch. + +#### `task_id` validation + +Implementers MUST reject `task_id` values that: + +- Are empty. +- Exceed 256 characters. 
+- Contain characters outside `[a-zA-Z0-9\-_.:]`. + +Rejection is at the call site (`.start()` / `.run()` raise) before +any network is touched. + +### §8. Inputs, outputs, and the per-input size limit + +A task carries exactly one **input** value at any time — the value +passed to `.start(input=...)` or `.run(input=...)`. The input is JSON- +serialized for persistence and is re-hydrated into `ctx.input` on +every handler entry (fresh, resumed, recovered). + +The handler's return value (the handler's `return X`) is the +**output**, also JSON-serialized. + +| Bound | Limit | Raised as | +|---|---|---| +| Per-input maximum size | **2 MB** after JSON serialization, for the function input AND each individual queued steering input. | `InputTooLarge` from `.start()` / `.run()` — pre-network, at the call site. | +| Concurrent queued steering inputs | **9** | `SteeringQueueFull` from `.start()` against a steerable task whose queue is full. | + +Inputs and outputs that fit easily in the inline payload budget stay +inline. Inputs whose JSON size exceeds a per-channel threshold are +**promoted** into the task's `attachments` slot transparently — +developers do not configure or opt in. See §23 for the wire +mechanism; the per-input ceiling above is the only developer-visible +limit. + +The framework uses JSON canonicalization rules (`sort_keys=True`, +separators `(",", ":")`) when computing serialized sizes and content +hashes (§23.6). Implementers MUST use the same canonicalization for +both, or hashes will not match across implementations. + +If the handler's input or output cannot be JSON-serialized (e.g. it +contains non-JSON-native types), the framework raises before the +HTTP call. Implementations using a richer model (Pydantic-style) +SHOULD attempt model-aware serialization (`model_dump`) first. + +### §9. Persistence ownership + +The framework persists: + +- The current `ctx.input` value (inline or as an attachment ref).
+- A snapshot of every touched `ctx.metadata` namespace at every + terminal-of-turn boundary (suspend, complete, cancel, raise, + steering drain, `exit_for_recovery`) and at every explicit + `metadata.flush()` call. +- Lifecycle counters: `retry_attempt`, `recovery_count` (the + `expiry_count` of the lease record), `_last_input_id` (the + optional caller-provided chain head — see §11). +- A per-turn `_turn_started_at` ISO-8601 UTC timestamp used by the + watchdog (§14) to compute remaining budget across crashes. +- Steering state (`pending_inputs` queue, `cancel_requested`, + `drain_in_progress`, `active_input`, `next_input_seq`) for + steerable tasks (§12). +- The handler's terminal outcome: a structured `error` dict on + failure (when persisted by the layer above the primitive), + `suspension_reason` on suspend. The handler's `return X` value + is NOT persisted in the record — it resolves the in-process + caller's `TaskRun.result()` future and is then no longer + reachable from the persisted record. + +The framework does NOT persist: + +- Handler-local variables. +- In-memory closures over the handler's body. +- Caller-provided callbacks or futures (those are bound to a single + lifetime; a crash discards them). +- Streaming events (those live in the streaming subpackage, which has + its own backings; see Part VI). +- Any bulk data the developer chooses to compute. The developer is + responsible for that — typically through a sibling framework + (LangGraph checkpoint, custom DB, blob storage) with only a small + reference token in `ctx.metadata`. + +The dividing line is "what does the framework need to decide +`entry_mode` and reproduce `ctx`?" — that is what it persists; nothing +more. + +### §10. Crash recovery + +Recovery is **framework-managed**. There is no developer-tunable +threshold and no opt-in. + +**When recovery happens:** + +1. **Cold start** of a new process. 
The manager's `startup()` scans + the task store for tasks owned by `(agent_name, session_id)` + whose lease has expired OR whose lease is owned by a different + instance of the same owner (a previous dead lifetime). Each is + reclaimed inline. +2. **Periodic scan.** While the manager is up, a background loop + re-runs the same scan every 300 seconds (default; see §31). This + catches tasks that became reclaimable AFTER cold start — typically + leases that expired during this process's lifetime because a sibling + process died. +3. **Inline reclaim.** When a caller `.start()`s a `task_id` whose + current record shows an `in_progress` status with an expired or + foreign-instance lease, the lifecycle resolver reclaims it inline + (no waiting for the periodic scan). + +**What recovery does:** + +The reclaiming process: + +1. Issues a PATCH that re-takes the lease atomically: new + `lease_owner` (always self), new `lease_instance_id` (always + self), new `lease_expires_at`, bumps the lease's `expiry_count` IF the + previous lease had actually expired (not bumped for same-owner + dead-instance handoff). This PATCH MUST be guarded by the read + `etag` for CAS safety. +2. Reads the (now self-owned) record, looks up the registered + resume callback by `source.name` (§21), invokes the handler + with `ctx.entry_mode="recovered"` and the persisted `ctx.input` + re-hydrated. +3. From the handler's perspective, the recovery looks identical to + a fresh entry except that `entry_mode == "recovered"` and any + `ctx.metadata` writes from the previous lifetime are already + present. + +**Crash-recovery does NOT consume the retry budget** (§15). A +lifetime that died before the handler raised does not advance +`retry_attempt`. 
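
The reclaim decision described above, combined with the `(owner, instance_id)` rules from §7, can be sketched as a small pure function. Everything here is illustrative: `Lease` and `classify` are hypothetical names, the field names only echo the `lease_owner` / `lease_instance_id` / `lease_expires_at` keys mentioned in the text, and treating a same-instance lease as "renew" regardless of expiry is an assumption of this sketch.

```python
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime


@dataclass(frozen=True)
class Lease:
    """Hypothetical in-memory view of a task's lease record."""

    owner: str
    instance_id: str
    expires_at: datetime


def classify(
    lease: Lease, my_owner: str, my_instance: str, now: datetime
) -> tuple[str, bool]:
    """Return (action, bump_expiry_count) for one in_progress record."""
    if lease.owner != my_owner:
        return ("skip", False)  # someone else's task: do not touch
    if lease.instance_id == my_instance:
        return ("renew", False)  # my own running task
    # Same owner, different instance: a previous lifetime that is gone.
    # Reclaim immediately; bump expiry_count only if the lease had
    # actually expired, not for a plain dead-instance handoff.
    expired = lease.expires_at <= now
    return ("reclaim", expired)
```

In a real provider the "reclaim" branch would be followed by the etag-guarded PATCH described in step 1; this sketch covers only the decision, not the write.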
+
+**Pattern — at-most-once side effect across recovery:**
+
+```python
+from uuid import uuid4
+
+# Reuse the token a previous lifetime fenced, or mint and fence a new one.
+token = ctx.metadata.get("dedup_token")
+if token is None:
+    token = uuid4().hex
+    ctx.metadata["dedup_token"] = token
+    await ctx.metadata.flush()  # fence: the token is durable before the effect
+# Always issue the call. Crash-recovered lifetimes re-issue it with the
+# SAME token, letting the downstream system de-dupe.
+await do_side_effect(idempotency_key=token)
+```
+
+This pattern is the standard answer to "I crashed mid-effect; how
+do I avoid duplicate effects?" The framework does NOT provide
+exactly-once semantics — the developer issues the dedup token and
+fences it before the effect.
+
+### §11. Suspend, resume, and multi-turn
+
+Multi-turn chains end every turn with a bare `return X` from the
+handler. The framework treats this as **return-is-implicit-suspend**.
+It:
+
+1. Transitions the stored status from `in_progress` to `suspended`
+   with `suspension_reason="run_completion"`.
+2. Persists a snapshot of every touched metadata namespace.
+3. Does NOT persist `X` anywhere in the task record. `X` resolves
+   the caller's `await run.result()` in-process and is then gone.
+4. Clears `payload["input"]` (and the corresponding attachment if
+   the input was promoted) — the consumed input is no longer needed
+   and would inflate the next payload write.
+5. Clears `_steering["active_input"]` (the steering mechanism state
+   is kept; only the consumed input value is dropped).
+6. Clears `payload["_retry_attempt"]` so the next turn starts with
+   a fresh retry budget.
+7. Preserves `payload["_last_input_id"]` so the next
+   `if_last_input_id` precondition can be evaluated.
+
+The caller's `await run.result()` resolves to `X` directly (typed
+as the handler's `Output`). No wrapper class.
+
+The next `.run(task_id=same, input=new)` or
+`.start(task_id=same, input=new)` transitions the status back to
+`in_progress` and re-invokes the handler with
+`ctx.entry_mode="resumed"`, `ctx.input=new`, and `ctx.metadata`
+re-hydrated.
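
The record-level effect of the implicit-suspend steps above can be sketched as a pure function over a record dict. This is an illustration under assumptions, not the framework's code: `apply_implicit_suspend` is a hypothetical helper, "clears" is modeled as removing the key (the spec text does not say whether a cleared key is removed or nulled), and the metadata snapshot of step 2 is left out.

```python
from __future__ import annotations

import copy
from typing import Any


def apply_implicit_suspend(record: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of `record` as it would look after `return X`."""
    out = copy.deepcopy(record)
    out["status"] = "suspended"
    out["suspension_reason"] = "run_completion"

    payload = out.setdefault("payload", {})
    payload.pop("input", None)           # consumed input is dropped
    payload.pop("_retry_attempt", None)  # next turn gets a fresh retry budget
    # `_last_input_id` is deliberately left untouched.

    steering = payload.get("_steering")
    if isinstance(steering, dict):
        steering.pop("active_input", None)  # queue and counters stay

    attachments = out.get("attachments")
    if isinstance(attachments, dict):
        attachments.pop("_input", None)  # promoted input attachment, if any
    return out


before = {
    "status": "in_progress",
    "payload": {
        "input": {"text": "hello"},
        "_retry_attempt": 2,
        "_last_input_id": "in-1",
        "_steering": {"active_input": "x", "pending_inputs": ["y"]},
        "metadata": {"step": 3},
    },
    "attachments": {"_input": "big blob"},
}
after = apply_implicit_suspend(before)
```

Note that the handler's return value never appears in `after`; it only resolves the in-process caller's future.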
+ +The same machinery is what multi-turn conversations and +human-in-the-loop approval flows ride. + +One-shot tasks do NOT use this mechanism. A one-shot `@task` +handler's `return X` is a terminal completion: the framework +resolves the caller's `.result()` with `X` and then deletes the +record (one-shot is always ephemeral). + +#### Multi-turn raise semantics + +If a multi-turn handler RAISES (an unhandled exception other than +`asyncio.CancelledError`), the chain still transitions to +`suspended` (NOT `completed` / `failed`) so subsequent turns can +continue: + +1. Transitions to `suspended` with + `suspension_reason="run_completion"`. +2. NO `payload["error"]` is written — the chain record does not + carry the per-turn failure diagnostic. +3. The framework emits a structured ERROR log named + `resilient_task_handler_failure` with `task_id`, `input_id`, + `error_type`, `error_message`. +4. The caller's `await run.result()` raises + `TaskFailed(error=TaskErrorDict(...))`. +5. Queued steerers (multi-turn `steerable=True`) promote per §12: + the next queued input becomes the next turn's input, and the + handler re-invokes with `ctx.entry_mode="resumed"`, + `ctx.is_steered_turn=True`. + +#### Chain identity: `input_id` and `if_last_input_id` + +Both `.run()` and `.start()` accept two optional keyword arguments +that thread caller-supplied chain identity through the persisted +record: + +- **`input_id`** — record-only. The framework writes + `payload["_last_input_id"] = input_id` after accepting the input; + no precondition is checked. +- **`if_last_input_id`** — precondition. The framework requires the + stored `_last_input_id` to equal `if_last_input_id` (the + predecessor the caller claims to be extending). Mismatch raises + `LastInputIdPreconditionFailed(actual_last_input_id=)`. + +For multi-turn, `input_id` is the per-turn identity. For one-shot, +`input_id` defaults to `task_id` (the 1:1 invariant `task_id == +input_id`). 
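
The chain-identity rules in this subsection can be sketched as a single acceptance check, including the rule that `if_last_input_id` without `input_id` is a `TypeError`. The function `accept_input` and the local exception class are hypothetical stand-ins that borrow the names used in this spec; the real check runs against the stored record, not a plain dict.

```python
from __future__ import annotations

from typing import Any


class LastInputIdPreconditionFailed(Exception):
    """Local stand-in for the exception named in this section."""

    def __init__(self, actual_last_input_id: str | None) -> None:
        super().__init__(f"stored _last_input_id is {actual_last_input_id!r}")
        self.actual_last_input_id = actual_last_input_id


def accept_input(
    payload: dict[str, Any],
    *,
    input_id: str | None = None,
    if_last_input_id: str | None = None,
) -> None:
    """Apply the `input_id` / `if_last_input_id` rules to `payload`."""
    if if_last_input_id is not None and input_id is None:
        raise TypeError("if_last_input_id requires input_id")
    if if_last_input_id is not None:
        actual = payload.get("_last_input_id")
        if actual != if_last_input_id:
            # The caller is extending a predecessor that is no longer the head.
            raise LastInputIdPreconditionFailed(actual_last_input_id=actual)
    if input_id is not None:
        payload["_last_input_id"] = input_id  # record-only, no precondition
```

A caller that loses the precondition can read `actual_last_input_id` from the exception to learn which input currently heads the chain.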
+
+Implementations MUST reject `if_last_input_id` provided without
+`input_id` (`TypeError` at the call site). The two arguments play
+distinct roles: `input_id` alone is idempotency / chain-head tracking;
+`(input_id, if_last_input_id)` together is HTTP-`If-Match`-style
+chain extension.
+
+### §12. Steering primitive
+
+`@multi_turn_task(steerable=True)` upgrades a multi-turn chain from
+"one turn at a time" to "callers can queue a new input while a turn
+is mid-flight."
+
+Steering is exclusive to multi-turn chains. One-shot `@task` does
+not support steering (the one-shot lifecycle is one input, one run);
+`@multi_turn_task` without `steerable=True` rejects a concurrent
+`.start()` call with `TaskConflictError`.
+
+#### What `.start()` does on an in-flight steerable chain
+
+`.start(task_id=, input=NEW)` against an in-flight
+steerable chain:
+
+1. The new input is **queued** at the tail of an internal
+   pending-inputs FIFO.
+2. The cancel signal is raised on the currently-executing turn —
+   `ctx.cancel.is_set()` becomes True for the handler that is
+   running right now. `ctx.pending_input_count` flips from 0 to
+   the live backlog size.
+3. A new `TaskRun` handle is returned to the caller. Its
+   `.result()` resolves with **whatever the next turn emits** —
+   the caller is the *steerer* of the next turn.
+
+If the steering queue is at its cap (9), `.start()` raises
+`SteeringQueueFull`.
+
+#### What the first turn's caller sees
+
+The first turn's caller observes the natural multi-turn outcome of
+the in-flight turn:
+
+| Handler ends turn 1 with... | First caller's `await run.result()` |
+|---|---|
+| `return X` (clean return) | Resolves with `X` (typed as `Output`). The chain transitions to `suspended` (return-is-implicit-suspend). The framework then promotes the queued steering input as the next turn. |
+| `raise SomeError` (non-CancelledError) | Raises `TaskFailed(error=...)`.
The chain stays alive in `suspended` with no `payload["error"]` written; the queued steerer is promoted as the next turn. | +| `raise asyncio.CancelledError()` | Raises `TaskCancelled()`. The chain stays alive in `suspended`; the queued steerer is promoted as the next turn. | +| Handler calls `ctx.exit_for_recovery()` (shutdown only) | Raises `TaskDeferred()`. The chain stays `in_progress`; the recovery scanner re-invokes the handler in a future lifetime. The queued steerer remains queued. | + +The handler's `return X` value is delivered **unconditionally** to +the first caller; it is never replaced by what a later turn +produces. + +#### Cooperative cancellation in steering + +`ctx.cancel` is advisory. The framework sets it when a steering +input arrives (alongside the cause counter +`ctx.pending_input_count`), but does not preempt the handler. The +handler decides: + +- **A — Yield immediately.** Check `ctx.cancel.is_set()` (or + `ctx.pending_input_count > 0`) at the next boundary and `return` + with whatever you have. +- **B — Wind down to a safe checkpoint.** Finish the current tool + call / token batch, persist a clean checkpoint, then `return` + with the final value. +- **C — Ignore cancel and finish.** Do not read `ctx.cancel`; let + the handler complete. The chain still transitions to + `suspended` and the queued steerer is promoted as the next + turn. + +#### Steering observability fields + +On a steering-driven re-entry, `TaskContext` exposes: + +- `ctx.is_steered_turn: bool` — `True` iff this turn was + constructed by the steering-drain code path. False for every + other entry path. Orthogonal to `entry_mode`: + `(entry_mode="recovered", is_steered_turn=True)` is legal. +- `ctx.pending_input_count: int` — live count of currently queued + steering inputs. Reads as 0 for non-steerable chains. Useful for + "I am three turns behind, I should short-circuit even harder" + decisions. 
It is derived from the **in-process observed** steering + state (the property is synchronous — it does NOT issue a store read + per access), and is **failure-tolerant** (any compute failure reads + as 0). It is recorded *before* `ctx.cancel` is set (see §13 ordering + invariant) by both the same-process enqueue and the cross-process + steering poll, and is decremented as the drain consumes inputs, so a + handler that observes `ctx.cancel.is_set()` for a steering cause + already sees `pending_input_count >= 1`. It must be backed by a + settable runtime field (historically it was read from an attribute + that was never storable, so it was stuck at 0). + +#### Force delete + +`MultiTurnTask.delete(task_id)` is the only API that force-removes +a chain. It cancels the in-flight turn (active caller's +`.result()` resolves with `TaskCancelled`), resolves all queued +steerer callers' `.result()` futures with `TaskCancelled`, and +force-deletes the record. Idempotent (no-op on a missing chain). + +### §13. Cancellation and cause booleans + +`ctx.cancel` is a bare event (e.g. `asyncio.Event` in Python). The +framework sets it from multiple causes; a handler observing the bare +event does NOT know *why* it was set. Three independent **cause +booleans** answer the why: + +| Cause | Set when | Reset? | +|---|---|---| +| `ctx.timeout_exceeded: bool` | Per-turn timeout watchdog has fired for this turn. | Never within a turn. | +| `ctx.cancel_requested: bool` | `TaskRun.cancel()` was invoked against this run from external caller code. | Never within a turn. | +| `ctx.pending_input_count: int` (read as a count, not boolean) | Live count of queued steering inputs >= 1. | Decrements as drains consume inputs. | + +**Causes accumulate.** Multiple cause booleans can be `True` +simultaneously (e.g., timeout AND external cancel AND steering). + +**Ordering invariant.** Each cause is set BEFORE the framework sets +`ctx.cancel`. 
A handler observing `ctx.cancel.is_set() == True` is +guaranteed to see at least one cause already set (cause booleans +or pending_input_count > 0). + +Canonical reaction pattern: + +```python +while not ctx.cancel.is_set(): + await do_a_unit_of_work() +# Branch on cause: +if ctx.timeout_exceeded: + return "(timed out — partial result)" +if ctx.cancel_requested: + raise asyncio.CancelledError() # caller observes TaskCancelled +if ctx.pending_input_count > 0: + return "(pre-empted by queued steering input)" +raise RuntimeError("ctx.cancel set with no recognised cause") +``` + +The handler's choice of terminal shape (`return X` / `raise`) +controls what the caller observes. The framework does NOT pick +the terminal shape on the handler's behalf. For multi-turn, +`return X` is the implicit-suspend boundary (chain stays alive, +caller's `.result()` resolves to `X`); for one-shot, `return X` +ends the run (record is deleted). + +### §14. Timeout (per-turn, cooperative) + +`@task(timeout=...)` is **cooperative-only**. When the budget elapses, +the framework: + +1. Sets `ctx.timeout_exceeded = True`. +2. Sets `ctx.cancel`. +3. Exits the watchdog. + +It does **NOT** force-stop the handler, end the task, or cancel +the lease renewal. An ignoring handler runs until process exit or +external `TaskRun.cancel()`. + +The budget is **per-turn** and **wall-clock**: + +- Each handler turn (fresh entry, suspended-to-resume) gets a + fresh budget. +- A process crash mid-turn does NOT reset the budget. When the + recovered handler enters, the watchdog computes + `remaining = max(0, timeout - (now - turn_started_at))` from the + persisted `_turn_started_at` and fires immediately if elapsed. +- Clock skew is clamped to `[0, timeout]` in both directions. 
+- **Known gap on steering drain re-entry:** the canonical Python + implementation spawns the watchdog ONCE per `_execute_task` + invocation; steering drain re-enters in-place inside + `_execute_task_loop` without spawning a fresh watchdog. The + steered turn inherits whatever budget remained on the original + watchdog. The persisted `_turn_started_at` IS stamped per drain + (§52 Phase 1), so a CRASH-then-recover from a drained turn + correctly honors the new turn's budget; the in-process drain + path itself does not. Other-language implementers SHOULD spawn + a fresh watchdog per drain to honor the design intent. + +The framework MUST persist `payload["_turn_started_at"]` (ISO-8601 +UTC) at every turn-start boundary: fresh entry, suspended -> in_progress +resume, steering drain re-entry. It is NOT re-stamped on crash +recovery — that is precisely what allows the watchdog to honor the +original budget across crashes. + +### §15. Retry + +`@task(retry=RetryPolicy(...))` and +`@multi_turn_task(retry=RetryPolicy(...))` configure the framework's +retry behavior for handler-raised exceptions. + +`RetryPolicy` parameters: + +| Field | Default | Meaning | +|---|---|---| +| `max_attempts` | `3` | Total failure-retry budget across all lifetimes. Counts the original try. | +| `initial_delay` | `1 second` | Delay before the first retry. | +| `backoff_coefficient` | `2.0` | Multiplier for exponential backoff. | +| `max_delay` | `60 seconds` | Cap on per-retry delay. | +| `jitter` | `True` | Add randomized jitter to delays. | +| `retry_on` | `None` (all exceptions) | Tuple of exception types to retry; others propagate. A bare exception class is accepted as a single-element tuple. | + +Presets: `exponential_backoff()`, `fixed_delay(delay)`, +`linear_backoff()`, `no_retry()`. + +Semantics: + +- **`retry_attempt` is the cross-lifetime counter.** Persisted as + `payload["_retry_attempt"]`. Re-hydrated on every handler entry + via `ctx.retry_attempt`. 
Increments only when the handler raises + (not on crash). Cleared on every turn-start boundary so each new + turn (multi-turn) or each new run (one-shot) gets a fresh budget. +- **Crash recovery does NOT consume the budget.** A lifetime that + is gone before the handler raised does not advance + `retry_attempt`. The recovered handler sees the same + `ctx.retry_attempt` value the crashed lifetime saw. +- **`return X` bypasses retry.** A handler that returns + (multi-turn = implicit suspend; one-shot = terminal completion) + is not a failure; the retry counter is unaffected. +- When `retry_attempt >= max_attempts`, the framework gives up: + it stops re-invoking, and the awaiting caller observes + `TaskFailed(error=TaskExhaustedRetriesErrorDict(...))` carrying + `attempts`, `last_error`, `last_error_type`, `traceback`. + +#### Interim retry persistence + +Between every failed attempt and the next retry the framework +PATCHes only `payload["_retry_attempt"] = `. NO +`payload["error"]` is written between attempts — the per-turn +failure diagnostic is not projected onto the record. The status +stays `in_progress` throughout. + +When the budget is exhausted (or the exception is non-retryable), +the failure handler runs: + +- **One-shot (`@task`)**: the record is DELETED entirely; nothing + survives on disk. The caller observes `TaskFailed` raised from + `.result()`. +- **Multi-turn (`@multi_turn_task`)**: the chain transitions to + `suspended` with `suspension_reason="run_completion"`; NO + `payload["error"]` is written; queued steerers promote per §12. + The caller of the failing turn observes `TaskFailed` raised + from `.result()`. The chain stays alive — a future + `.run()`/`.start()` against the same `task_id` resumes the + chain with a fresh retry budget. + +The framework emits a structured ERROR log named +`resilient_task_handler_failure` on every handler raise (including +non-final attempts). 
Observers learn "what just failed, which +attempt am I on" from logs, NOT from a persisted `error` field on +the record. + +`TaskFailed.error` is one of two `TypedDict` shapes: + +```python +class TaskErrorDict(TypedDict): + type: str # exception class name, e.g. "ValueError" + message: str # str(exc) + traceback: str # traceback.format_exc() + +class TaskExhaustedRetriesErrorDict(TypedDict): + type: Literal["exhausted_retries"] + attempts: int + last_error: str + last_error_type: str + traceback: str +``` + +Type-checkers can discriminate on the `type` literal. + +### §16. Shutdown and `exit_for_recovery` + +The container can be shut down at any time (deployment, rolling +restart, eviction). The framework sets `ctx.shutdown` when it +receives the shutdown signal. The handler has three legitimate +responses: + +| Shape | When to use | Stored outcome | Caller observes | +|---|---|---|---| +| `await ctx.exit_for_recovery()` | Container shutting down AND you want this turn re-entered later. | `in_progress` (preserved across shutdown). | `TaskDeferred`. | +| `return X` (multi-turn) | Handler reached a clean checkpoint AND wants to expose `X` to the caller. | `suspended` (caller can `.run()` again to drive the next turn). | `X` (typed as `Output`). | +| `raise asyncio.CancelledError()` | Handler decided to abort. | One-shot: record deleted. Multi-turn: chain transitions to `suspended` (stays alive). | `TaskCancelled()`. | + +`ctx.exit_for_recovery()` is the resilient-deferral primitive. The +method: + +1. Flushes all touched metadata namespaces. +2. **Releases ownership** of the persisted record so the next + process can take over (force-expires the lease). +3. Leaves status as `in_progress` (NOT `suspended`). +4. Raises `TaskDeferred()` upward — the caller of `.result()` + sees this. Semantically distinct from `TaskCancelled`: the + task is not cancelled; this lifetime is just deferring to the + next. +5. 
Preserves any queued steering inputs — they are NOT drained + during shutdown; on recovery they remain queued. + +When the recovery scanner re-acquires the deferred task, the +handler re-enters with `ctx.entry_mode="recovered"` and the +persisted `payload["input"]` — exactly as if the lifetime had +crashed. + +Misuse: calling `ctx.exit_for_recovery()` when +`ctx.shutdown.is_set() == False` MUST raise `RuntimeError` at the +call site. This makes misuse loudly visible to operators (the task +ends in error, not silently `in_progress`). + +### §17. Metadata namespaces + +`ctx.metadata` is a **callable namespace facade** for the small, +resilient, per-task state the handler owns: + +- `ctx.metadata["key"] = value` — read/write the **default** + namespace, persisted at `payload["metadata"]`. +- `ctx.metadata("session")["upstream_id"] = sid` — read/write a + **named** sibling namespace, persisted at + `payload["metadata:session"]`. + +Each namespace is independent: a write to one does not dirty the +other; `flush()` on one persists only that namespace's data. + +`metadata.flush()` is the fence the developer uses to make +at-most-once side-effect patterns work across a crash. The framework +**auto-flushes** all touched namespaces at every terminal-of-turn +boundary, so writes the developer forgets to flush are still resilient +across a graceful boundary. Explicit `flush()` is for mid-handler +fence semantics. + +**Naming convention:** namespaces and top-level metadata keys +starting with `_` are RESERVED for the framework. The primitive +treats this as a convention at the API surface; layers built on top +(e.g. the responses framework's `_responses` namespace) MAY enforce +it more strictly. + +`TaskMetadata` MUST expose dict-like semantics +(`__getitem__`/`__setitem__`/`__contains__`/`__iter__`/`.get()`/`.to_dict()`) +plus: + +- `flush()` — persist this namespace only. 
+- `increment(key)` — in-memory atomic numeric increment **on the + metadata namespace object** (read/modify/write under an in- + memory lock). The change is NOT pushed to the store until the + next `flush()` / auto-flush. This is NOT a store-level + compare-and-swap; concurrent processes incrementing the same + key would race at the store level. Use for handler-local + counters that get flushed at clean boundaries; for cross- + process atomic counters, use the store's CAS protocol directly + via the provider. +- `append(key, value)` — append to a list-valued key. Same + in-memory semantics as `increment`: atomic within the namespace + object, NOT atomic against the resilient record. + +Flush failures are logged, not raised — a failed flush should not +crash a handler. The framework retries on the next flush call or +auto-flush boundary. + +--- + + +## Part III — Storage contract (wire-level) + +This part documents how the framework projects the programming model +onto the resilient task record. The HTTP routes, request/response +envelopes, and server-side merge rules themselves are defined by the +*Foundry Task Storage Protocol* specification; this section names which +fields the framework reads/writes and what the framework-reserved +keys mean. + +### §18. Reference to the Foundry Task Storage Protocol + +The hosted task store's transport-level contract — routes +(`POST /tasks`, `GET /tasks`, `GET /tasks/{id}`, `PATCH /tasks/{id}`, +`DELETE /tasks/{id}`), authentication, activation, payload PATCH merge +semantics, attachment PATCH merge semantics, ETag/CAS rules, +classification of 409/412 responses — is specified by +`foundrysdk_specs/specs/hosted-agents/container-spec/docs/foundry-task-storage-protocol-spec.md`. + +This document does **not** restate that contract. Implementers MUST +conform to the protocol spec for any hosted-provider implementation. +The conformance items in §59 reference both this document and the +protocol spec. 
+ +Where this spec uses terms like "PATCH" or "etag", it does so under +the protocol spec's definitions. + +### §19. The framework's view of the task record + +The framework writes/reads the following fields on every task record. +Field meanings beyond this table are defined in the protocol spec. + +| Field | Type | Owned by | Set on | +|---|---|---|---| +| `id` | string | caller | `create`. | +| `agent_name` | string | framework | `create`. | +| `session_id` | string | framework | `create`. | +| `status` | `pending` / `in_progress` / `suspended` / `completed` | framework | `create`, status transitions (§24). | +| `title` | string \| null | caller | `create` (optional). | +| `description` | string \| null | caller | `create` (optional). | +| `lease` | LeaseInfo (§22) | framework | `create`, every renewal, every reclaim. | +| `payload` | object | framework + developer | almost every transition (§20). | +| `tags` | map of string -> string | framework + caller | `create` (framework stamps `_task_name`); caller-set tags allowed. | +| `error` | object \| null | framework | on handler raise. | +| `suspension_reason` | string \| null | framework | on suspend. | +| `source` | object | framework | `create` (§21). | +| `attachments` | object \| null | framework + developer | on input promotion / drain / suspend / orphan cleanup (§23). | +| `etag` | string | server | every server-issued response. | +| `created_at` | ISO-8601 string | server | `create`. | +| `updated_at` | ISO-8601 string | server | every PATCH. | +| `started_at` | ISO-8601 string \| null | server | **set once on first `in_progress` transition; never updated thereafter** (lease re-acquisition, recovery scanner takeover, and suspend/resume cycles do NOT reset). | +| `completed_at` | ISO-8601 string \| null | server | terminal transition. | + +Caller-controlled fields (`tags` keys NOT starting with `_task_`, +`title`, `description`) are passed through verbatim. 
Framework-owned
+fields MUST NOT be set by caller code.
+
+### §20. Framework-reserved payload keys
+
+`payload` is the JSON object that holds both the framework's
+runtime state and the developer's metadata. The framework reserves
+the following top-level keys: `input`, `metadata`, the
+`metadata:`-prefixed namespace keys, and every key starting with `_`:
+
+| Key | Type | Lifetime | Meaning |
+|---|---|---|---|
+| `input` | any JSON value, or a ref dict (§23) | Set on every `in_progress` transition; cleared at suspend; cleared by drain after consumption. | The current input value (or a ref to its attachment). |
+| `metadata` | object | Persisted at boundaries; auto-flushed. | The DEFAULT user metadata namespace. |
+| `metadata:` | object | Same as above. | NAMED user metadata namespace ``. |
+| `_last_input_id` | string \| null | Set when caller supplies `input_id`. | Chain-head tracking (§11). |
+| `_turn_started_at` | ISO-8601 UTC string | Set at every turn-start boundary; NEVER re-stamped on recovery. | Source of truth for the per-turn watchdog (§14). |
+| `_retry_attempt` | integer | Incremented on handler raise; reset to 0 on steering drain. (Not also reset on success in the canonical Python implementation.) | Resilient retry counter (§15). |
+| `_steering` | object (see below) | Only present on steerable tasks. | Steering mechanism state (§12). |
+
+The framework does NOT persist the handler's return value in the
+task record. There is no `payload["output"]` key and no `_output`
+attachment. The handler's return value resolves the in-process
+caller's `TaskRun.result()` future and is then no longer reachable
+from the persisted record. Per-turn outputs that need to survive
+crashes are the handler's responsibility — write them through
+your own storage (e.g., LangGraph checkpoint, your own DB) before
+returning.
+
+Likewise, `error` from a handler raise is NOT persisted.
The +framework emits a structured ERROR log (named +`resilient_task_handler_failure`) on every handler raise, but the +chain record itself does not carry the per-turn diagnostic. + +`_steering` object shape: + +| Sub-key | Type | Meaning | +|---|---|---| +| `pending_inputs` | array of input values OR refs (§23) | FIFO of queued steering inputs. | +| `next_input_seq` | integer | Monotonic counter for promoted-attachment key allocation (NEVER reused). | +| `cancel_requested` | boolean | Resilient cancel signal; set on steering append; cleared after drain when pending is empty. | +| `drain_in_progress` | boolean | True between the start of a drain PATCH and the next turn-start; protects against partial drain on crash. | +| `active_input` | any JSON value OR ref | The single input being drained (mirror copy used by the race-recovery contract). Cleared at suspend / terminal. | + +Implementers in other languages MUST use these exact key names. A +process built in language X must be able to recover a task created +by language Y. + +Keys NOT in this table are caller-controlled (e.g. user metadata +namespaces); the framework leaves them alone. + +### §21. Framework-reserved tag keys and `source` shape + +#### Reserved tag keys + +The framework stamps the following `tags` entries on `create`: + +| Tag key | Value | Purpose | +|---|---|---| +| `_task_name` | The decorator's `name` (or `fn.__qualname__` fallback). | Server-side `LIST` filtering by task name. | + +Tag keys starting with `_task_` are RESERVED. Caller-supplied tags +using this prefix are stripped at the call site with a warning; +the framework does not pass them to the server. 
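
The reserved-prefix rule for tags can be sketched as a small filter applied before `create`. The helper names `strip_reserved_tags` and `build_create_tags` are hypothetical, and using the standard `warnings` module is an assumption; the spec only says a warning is emitted at the call site.

```python
from __future__ import annotations

import warnings

RESERVED_TAG_PREFIX = "_task_"


def strip_reserved_tags(tags: dict[str, str]) -> dict[str, str]:
    """Drop caller tags that use the framework-reserved prefix."""
    kept: dict[str, str] = {}
    for key, value in tags.items():
        if key.startswith(RESERVED_TAG_PREFIX):
            warnings.warn(
                f"dropping reserved tag key {key!r}", UserWarning, stacklevel=2
            )
            continue
        kept[key] = value
    return kept


def build_create_tags(task_name: str, caller_tags: dict[str, str]) -> dict[str, str]:
    """Combine filtered caller tags with the framework-stamped `_task_name`."""
    tags = strip_reserved_tags(caller_tags)
    tags["_task_name"] = task_name  # stamped by the framework, never the caller
    return tags
```

Filtering before stamping means a caller-supplied `_task_name` can never override the decorator's name.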
+ +#### `source` shape + +The framework stamps `source` on `create`: + +``` +{ + "type": "agentserver.task", + "name": "", + "server_version": "/ (/)" +} +``` + +`source.name` is the **canonical identity anchor** for recovery +routing — the framework looks up the registered handler callback +by matching `source.name` against the decorator-supplied names. +`source.type` is currently a single fixed string but is reserved +for future namespacing. + +### §22. Lease structure and ownership semantics + +`lease` is a sub-object with the following fields: + +| Field | Type | Meaning | +|---|---|---| +| `owner` | string | `\|session:` (§7). Stable across process lifetimes. | +| `instance_id` | string | `worker---`. Fresh per process. | +| `generation` | integer | Increments each time the lease is re-acquired with a different `instance_id`. Mirrored to `ctx.recovery_count`. The local provider AND the hosted task store both bump this. | +| `expires_at` | ISO-8601 UTC string | When the lease expires (and another process may reclaim). | +| `expiry_count` | integer | Number of times ownership has changed via **actual expiry** (i.e. lease was reclaimed because the prior lease's `expires_at` passed, NOT because the same owner restarted). **Server- / provider-only counter** — the framework never writes this field (it is not on `TaskPatchRequest`). The hosted task store bumps it; the local file provider also bumps it on actual-expiry reclaim for parity (so local-mode tests can assert expiry-counter behavior). Surfaced on the framework's internal `TaskInfo`; NOT projected onto the public `TaskRun` handle (lease bookkeeping is framework-internal). | +| `heartbeat_at` | ISO-8601 UTC string | Wall time of the most recent lease write (acquisition, renewal, or force-expire). Stamped by the provider on every lease-touching PATCH. 
**Provider-only field** — the framework never writes this; consumers and observability tooling read it to distinguish "fresh lease" from "lease that hasn't expired yet". NOT projected onto the public `TaskRun` handle — it's a framework / operator concern, not a developer one. | + +The framework's interaction with the lease: + +- On `create`, the framework sets `lease_owner = self.owner`, + `lease_instance_id = self.instance_id`, and + `lease_duration_seconds = 60` (the framework default). +- The lease renewal loop (§56) renews at half the lease duration + (every 30s by default), but its next tick is computed + DYNAMICALLY from the per-task last-refresh time, NOT a fixed + cadence. So a PATCH within the last `interval` seconds fully + shadows the next heartbeat. +- **Every PATCH the framework issues** (renewal, metadata, + steering, suspend, drain, complete, fail, reclaim) MUST + piggyback (`lease_owner`, `lease_instance_id`, + `lease_duration_seconds`) to refresh the lease as a side effect. + See §25.4. +- On reclaim (§54), the framework PATCHes the lease to itself with + `if_match: ` for CAS. BOTH the inline reclaim + AND the cold-start/periodic scan reclaim use `if_match` (closes + the prior known gap). +- On `ctx.exit_for_recovery()` (§16), the framework force-expires + the lease so the next process can reclaim immediately. + +The framework recognizes three lease states for a foreign-instance +or expired record: + +1. **Live and same-instance** — my own running task; do nothing. +2. **Live and different-instance, same-owner** — a previous lifetime + of mine. RECLAIM immediately (no expiry wait). `expiry_count` is + NOT bumped (the server only bumps on actual-expiry handoff, and + this isn't one). +3. **Expired (any owner)** — RECLAIM. `expiry_count` IS bumped + (server-side, in the hosted store; AND in the local provider + for parity — see the table above). 
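
The three lease states above can be sketched as a classifier that a scanner might run per record. This is an illustration only: `Lease`, `classify`, `should_reclaim`, and `bumps_expiry_count` are hypothetical names. Two choices are assumptions of the sketch: expiry is checked first, so an expired lease counts as expired whoever owns it, and a fourth `foreign_live` value is added for a live lease held by a different owner, so that the function is total.

```python
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Literal

LeaseState = Literal[
    "own_live", "same_owner_dead_instance", "expired", "foreign_live"
]


@dataclass(frozen=True)
class Lease:
    owner: str
    instance_id: str
    expires_at: datetime


def classify(lease: Lease, *, owner: str, instance_id: str,
             now: datetime) -> LeaseState:
    """Map a lease to one of the states as seen by (owner, instance_id)."""
    if lease.expires_at <= now:
        return "expired"                   # state 3: any owner
    if lease.owner != owner:
        return "foreign_live"              # not reclaimable (sketch assumption)
    if lease.instance_id == instance_id:
        return "own_live"                  # state 1: do nothing
    return "same_owner_dead_instance"      # state 2: reclaim without waiting


def should_reclaim(state: LeaseState) -> bool:
    return state in ("same_owner_dead_instance", "expired")


def bumps_expiry_count(state: LeaseState) -> bool:
    # Only an actual-expiry handoff moves the provider-side counter.
    return state == "expired"


now = datetime(2025, 1, 1, tzinfo=timezone.utc)
live = now + timedelta(seconds=30)
dead_sibling = classify(Lease("a", "w0", live), owner="a", instance_id="w1", now=now)
```

Here `dead_sibling` is reclaimable immediately, and `bumps_expiry_count(dead_sibling)` is false because no expiry occurred.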
+ +**Important: the framework never writes `expiry_count`.** It is not +a field on `TaskPatchRequest` (only `lease_owner`, +`lease_instance_id`, `lease_duration_seconds` are writable). The +hosted task store and the local file provider both increment it +server-side / provider-side on actual-expiry ownership change; the +framework only reads it. + +#### 22.1 Lease write rules (provider-enforced, identical for hosted and local) + +These rules MUST be enforced by **both** providers identically. +Violations raise the internal `_HostedConflict` (§39) which the +framework translates to public exceptions per the translation table +(also §39). Local file provider raises the same logical conditions +directly, with the same internal classification, so the framework +behaves identically against either backing. + +| # | Rule | When violated | +|---|---|---| +| LSE-W-1 | `lease_duration_seconds` MUST be `0` (force-expire) OR in the range `10..3600` (renewal). | Reject as `invalid_request` (400). | +| LSE-W-2 | The triplet `(lease_owner, lease_instance_id, lease_duration_seconds)` is all-or-nothing. Supplying any one without all three is rejected. | Reject as `invalid_request` (400). | +| LSE-W-3 | Lease acquisition / renewal against a record whose lease is currently held by a **different** owner AND not yet expired is rejected. | Raise `_HostedConflict(_code="lease_held_by_another")` → `TaskConflictError(current_status="in_progress")`. | +| LSE-W-4 | When transitioning a task from `in_progress` → `pending`, the supplied `(lease_owner, lease_instance_id)` MUST match the record's current lease. | Raise `_HostedConflict(_code="lease_held_by_another")`. | +| LSE-W-5 | Lease renewal (no status change, `lease_duration_seconds > 0`) is only valid when the current status is `in_progress`. Renewing on `pending` / `suspended` / `completed` is rejected. | Reject as `invalid_request` (400). 
| +| LSE-W-6 | `lease_duration_seconds = 0` (force-expire) cannot be combined with a status transition in the same PATCH. | Reject as `invalid_request` (400). | +| LSE-W-7 | Force-expire (`lease_duration_seconds = 0`) requires the caller's `(lease_owner, lease_instance_id)` to match the current lease UNLESS the lease is already expired (in which case any caller may force-expire). | Raise `_HostedConflict(_code="lease_held_by_another")` if mismatched and lease is still live. | +| LSE-W-8 | `started_at` is **immutable** after the first `in_progress` transition. Lease re-acquisition (including expired-lease takeover by a different owner OR same-owner restart) MUST NOT update `started_at`. The original wall-clock time of the first turn-start is preserved across recovery, restarts, and suspend/resume cycles. | (Behavioral — observable via the task manager's provider; not on the public `TaskRun` handle.) | +| LSE-W-9 | On lease handoff to a different owner where the prior lease was **expired**, `expiry_count` MUST be incremented. Same-owner different-instance handoff before expiry does NOT bump. | (Behavioral — observable via the task manager's provider; not on the public `TaskRun` handle.) | +| LSE-W-10 | On every successful lease write (acquisition, renewal, force-expire), the provider MUST stamp the lease's `heartbeat_at` field to "now". This field exists on `LeaseInfo` so consumers and observability tooling can distinguish a fresh lease from one that simply hasn't expired yet. | (Behavioral — observable through `LeaseInfo.heartbeat_at` in the internal `TaskInfo`. Not on the public surface.) | + +### §23. Attachments and input promotion + +The hosted task store provides a second per-task storage slot, +`attachments`, alongside `payload`. 
The two stores have different +budgets: + +| Slot | Per-task cap | Per-value cap | Entry count cap | +|---|---|---|---| +| `payload` | 1 MB | n/a (shared) | unlimited keys | +| `attachments` | n/a (per-entry only) | 2 MB per attachment | 20 attachments max | + +`attachments` lets the framework lift the per-input ceiling from +"however much fits in payload alongside everything else" to +**2 MB per input** without evicting metadata budget. + +#### 23.1 PATCH merge semantics + +The hosted store's merge semantics for `attachments` mirror `tags`: + +- Key present with non-null value -> **upsert** (new) or **replace** (existing). +- Key present with `null` -> **delete** that entry. +- Key absent -> **unchanged**. +- `attachments` field absent entirely -> no attachment changes. + +PATCHes that include BOTH `payload` and `attachments` are atomic +across both stores. This is load-bearing: every promote, drain, +suspend, and orphan-cleanup write co-PATCHes payload + attachments +in a single round trip. + +#### 23.2 Input promotion thresholds (framework-owned) + +The framework promotes inputs to attachments by size, with a +separate threshold per input channel. + +| Channel | Promotion rule | Attachment key | +|---|---|---| +| Function input (`payload["input"]`) | > 200 KiB serialized → ref; otherwise inline. | `_input` | +| Each steering input (entry in `_steering["pending_inputs"]`) | > 20 KiB serialized → ref; otherwise inline. | `_steering_input_` | + +Different rules because: + +- The function input is set once per turn-start. A 200 KiB inline + budget keeps small inputs cheap and only spills clearly-large ones. +- Steering inputs may accumulate (up to 9 queued). A 20 KiB + threshold caps the worst-case inline payload contribution from + steering at ~180 KiB even when the queue is full. + +There is no `_output` channel and no output promotion. 
The +framework does not persist handler return values; outputs resolve +the in-process caller's `TaskRun.result()` future directly and are +never projected onto the task record. + +Sizes are measured in bytes of canonical JSON +(`sort_keys=True`, separators `(",", ":")`). + +Worst-case framework attachment usage: +`_input` (1) + `_steering_input_*` (up to 9) = +**10 of 20** per-task attachment slots. Leaves 10 slots free for +future use. + +#### 23.3 Wire shapes — two only + +A slot that would hold an input (`payload["input"]`, an entry in +`_steering["pending_inputs"]`) is represented in exactly one of two +shapes: + +**Inline** (size <= threshold): the raw JSON value, verbatim. + +**Ref** (size > threshold): a single-magic-key dict pointing at the +attachment: + +```json +{ + "__attachment_ref__": { + "key": "", + "hash": "sha256:<64 lowercase hex chars>" + } +} +``` + +**Detection rule** (used everywhere the framework reads a slot): +the slot is a ref iff (1) it is a JSON object, (2) it has exactly +one key, (3) that key is `__attachment_ref__`, (4) the value is an +object with both `key` and `hash`. Everything else is inline. + +The inline + ref shapes are **disjoint**: a developer-supplied +inline value cannot accidentally be misread as a ref because the +detection rule's 4-step structure is too specific to occur +incidentally. + +#### 23.4 Single wire shape + +The framework reads and writes exactly the inline + ref shapes +documented in §23.3. The primitive is in private preview; there is +no version-skew compatibility to maintain. + +#### 23.5 Sequence number invariants (steering) + +`payload["_steering"]["next_input_seq"]` is the monotonic counter +the framework uses to derive `_steering_input_` keys. Critical +invariants: + +- **Advances ONLY on promotion.** Inline steering appends do not + bump `next_input_seq`. 
+- **Never reused.** A drained-and-deleted key is never re-allocated; + the next promoted append always uses the current + `next_input_seq`, then `next_input_seq += 1`. +- **Stable for surviving entries.** A drain pops the head of + `pending_inputs` and (if it was a ref) deletes the corresponding + `_steering_input_` attachment. It does NOT renumber any + other entry. A queue of `[ref_3, ref_4]` becomes `[ref_4]` after + one drain; `ref_4` keeps its key. + +The stability invariant is what allows the framework to drain without +re-uploading attachments — a property that would be impossible if +keys encoded queue position. + +#### 23.6 Content hash + +Every ref carries `hash: "sha256:"` where the hex is the +SHA-256 of the canonical JSON bytes +(`sort_keys=True`, separators `(",", ":")`) of the attachment +value. The framework writes the hash on promotion. + +**Hash validation (known gap).** The canonical Python +implementation today writes the hash on every promotion but does +NOT validate it on read — `_read_input_value()` resolves the ref +key against `attachments` and returns the value without +recomputing the hash. Other-language implementers SHOULD validate +on read (recompute hash from the attachment value, compare against +the ref's hash, raise on mismatch) to detect store-side +corruption. Cross-implementation byte-compatibility requires using +the SAME canonicalization rules so a write from one language can +be validated by another. + +The hash is sufficient for ref validity once validated (no separate +write-timestamp is needed): the SHA-256 birthday-bound collision +probability at a fleet-wide rate of one trillion hashes per second +sustained for 100 years is < 1 in 10^33. + +#### 23.7 Caps and pre-network enforcement + +Caps: + +- Per-attachment value: **2 MB** serialized. +- Per-task attachment count: **20**. 
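
The promotion, detection, and hashing rules from §23.2, §23.3, and §23.6, together with the per-value cap above, can be sketched as follows. This is a minimal sketch under stated assumptions: `promote`, `is_ref`, `canonical_bytes`, and `InputTooLargeSketch` are hypothetical names, "2 MB" is read as 2 MiB, and `json.dumps` is used with its other defaults (including ASCII escaping), which the text does not specify.

```python
import hashlib
import json
from typing import Any

INPUT_THRESHOLD = 200 * 1024            # function input threshold (§23.2)
STEERING_THRESHOLD = 20 * 1024          # per steering input threshold (§23.2)
MAX_ATTACHMENT_BYTES = 2 * 1024 * 1024  # assumption: "2 MB" interpreted as 2 MiB
REF_KEY = "__attachment_ref__"


class InputTooLargeSketch(ValueError):
    """Stand-in for the developer-facing InputTooLarge exception."""


def canonical_bytes(value: Any) -> bytes:
    """Canonical JSON used for both size measurement and hashing."""
    return json.dumps(value, sort_keys=True, separators=(",", ":")).encode("utf-8")


def is_ref(slot: Any) -> bool:
    """Four-step detection rule: object, one key, magic key, key + hash inside."""
    if not isinstance(slot, dict) or len(slot) != 1 or REF_KEY not in slot:
        return False
    inner = slot[REF_KEY]
    return isinstance(inner, dict) and "key" in inner and "hash" in inner


def promote(value: Any, key: str, threshold: int) -> tuple[Any, dict[str, Any]]:
    """Return (slot value, attachments patch) for one input."""
    raw = canonical_bytes(value)
    if len(raw) <= threshold:
        return value, {}  # inline: the raw value, no attachment
    if len(raw) > MAX_ATTACHMENT_BYTES:
        # Checked before any network call would be made.
        raise InputTooLargeSketch(f"{key}: {len(raw)} bytes exceeds the cap")
    digest = "sha256:" + hashlib.sha256(raw).hexdigest()
    return {REF_KEY: {"key": key, "hash": digest}}, {key: value}
```

Returning the slot and the attachments patch together mirrors the requirement that both are written in a single PATCH.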
+ +The framework enforces (pre-network) and surfaces developer-facing +exceptions based on which channel the violation occurs on: + +| Cap | Where enforced | Developer-facing exception | +|---|---|---| +| Per-value (2 MB) on `_input` | Create + PATCH, both providers | `InputTooLarge` (the framework remaps an internal `_AttachmentTooLarge` based on attachment-key prefix) | +| Per-value (2 MB) on `_steering_input_` | Steering append site (always reads state first to count) | `InputTooLarge` | +| Per-task count (20) on `create` | Create path | `_AttachmentLimitExceeded` (internal) — reachable only via direct provider use, which is unsupported | +| Per-task count (20) on `patch` | Local provider (cheap count); hosted PATCH relies on server-side check | `_AttachmentLimitExceeded` (internal) | + +Internal exceptions `_AttachmentTooLarge` and +`_AttachmentLimitExceeded` are **provider-internal** — they are +NOT exported from `tasks/__init__.py`. The framework catches +`_AttachmentTooLarge` and re-raises the appropriate developer- +facing exception based on the attachment key prefix (`_input` / +`_steering_input_*` → `InputTooLarge`). +`_AttachmentLimitExceeded` is unreachable in normal framework +operation (worst case is 10 of 20 slots; see §23.2) and if it ever +propagates indicates a framework bug — caught at the boundary and +converted to `RuntimeError`. + +#### 23.8 Atomic co-writes + +These transitions MUST be single PATCHes carrying BOTH `payload` and +`attachments`: + +1. **Promote on `.start()` (fresh)**: `attachments["_input"] = ` + + `payload["input"] = {ref}` (CREATE on the hosted store). +2. **Promote on resume**: same fields, but PATCH. +3. **Suspend (multi-turn turn-end via `return X`)**: + - `payload["input"] = null` + - `payload["_steering"]["active_input"] = null` + - `payload["_retry_attempt"] = null` (fresh budget for the next turn) + - `attachments["_input"] = null` (delete) IF the input was a ref +4. 
**Steering append (promoted)**: `payload["_steering"]["pending_inputs"] + += [{ref}]`, `attachments["_steering_input_"] = `, + `payload["_steering"]["next_input_seq"] += 1`, + `payload["_steering"]["cancel_requested"] = true`. +5. **Steering drain (promoted entry, Phase 1)**: + `payload["_steering"]["pending_inputs"]` without the popped + head, `attachments["_steering_input_"] = null`, + plus the new turn's `_turn_started_at`. +6. **One-shot completion**: the record is deleted (one-shot is + always ephemeral). +7. **Failure**: one-shot → record deleted; multi-turn → status="suspended" + with `suspension_reason="run_completion"`. No `payload["error"]` + is written; the per-turn failure surfaces to the caller via + `TaskFailed(error=...)` and via the structured log + `resilient_task_handler_failure`. +8. **Resume (suspended → in_progress)**: status="in_progress", + `_turn_started_at` re-stamped, `_retry_attempt` reset to 0. + New input written (inline or as ref + attachment per §23.2). + +Splitting any of these into multiple PATCHes opens a crash window +where the attachment exists without its ref (or vice versa). The +framework treats this as a single-PATCH invariant. + +#### 23.9 Attachment key validation + +Attachment keys MUST match the regex `^[a-zA-Z0-9_.\-]{1,64}$` and +MUST NOT be empty after trimming whitespace. Both providers enforce +this on every CREATE / PATCH write. The framework's reserved keys (`_input`, `_steering_input_`) all conform. +Developer-supplied attachment keys (none exist today — attachments +are framework-owned per §23.7) would also be validated against this +regex if the surface is ever expanded. + +#### 23.10 Clear-all gesture + +In addition to per-key null-as-delete (§23.1), the provider accepts a +top-level "clear all attachments" gesture: + +- Wire form: `PATCH ... { "attachments": null }`. +- Effect: deletes every attachment on the task, regardless of which + keys currently exist. 
Per-key entries supplied in the same PATCH + are NOT applied (the clear takes precedence). +- Typed-API form: `TaskPatchRequest.clear_attachments = true`. When + set, the hosted provider serializes `attachments: null`; the local + provider clears the attachments dict directly. Mutually exclusive + with `attachments={...}` (per-key patch) in the same request — the + combination is rejected as `invalid_request`. +- The framework today never emits this gesture; per-key delete + covers all current needs. It is documented for parity with the + service and for future internal callers (e.g. orphan-attachment + cleanup post-recovery). + +DELETE on a task removes all attachments along with the task. The +local provider achieves this trivially (attachments live in the +same JSON file as the task record; unlinking the file removes +both). The hosted provider relies on the service's blob-cleanup +hook. + +### §24. Status state machine + +The framework drives the following transitions: + +``` + create() handler returns + │ or raises + ▼ ┌──────────────┐ + ┌──────────┐ auto-start ┌──────────────│ completed │ + │ pending │ ──────────────▶│ in_progress │ (terminal) │ + └──────────┘ │ │ │ + │ └──────────────┘ + │ return X (multi-turn) + ▼ ▲ + ┌──────────┐ │ + │suspended │ ────────┘ + └──────────┘ .run/.start with new input + ▲ + │ + │ reclaim (same status, + │ new lease) + │ + └─── in_progress (foreign lease) +``` + +Notes: + +- The framework usually creates with `status = in_progress` directly + (the `pending` state is rarely externally observed). +- `in_progress -> in_progress` is the most-traversed transition + (every lease renewal, every reclaim, every steering drain, every + successful retry). +- `completed` is terminal; the *outcome* (success / failure / + cancel) is communicated through the typed exceptions, not via a + separate status value. 
+- `ctx.exit_for_recovery()` preserves `in_progress` and force-expires + the lease — it is the only way to release ownership without moving + to a different status (§16). + +#### 24.1 Allowed transition matrix (provider-enforced) + +The provider rejects PATCHes whose declared `status` transition is +not in this table. Internal classification `_HostedConflict(_code="invalid_state_transition")`, +translated to a generic framework error at the boundary (this +condition should never escape to developer code — the framework +chooses transitions, not the developer; if it ever does escape it's +a framework bug per Workstream C). + +| From → To | `pending` | `in_progress` | `suspended` | `completed` | +|---|---|---|---|---| +| `pending` | n/a | ✅ | ❌ | ✅ | +| `in_progress` | ✅ (with matching lease) | ✅ (lease renewal) | ✅ | ✅ | +| `suspended` | ✅ | ✅ | ✅ | ✅ | +| `completed` | ❌ (terminal) | ❌ | ❌ | ✅ (no-op only — see §24.2) | + +#### 24.2 Terminal immutability + +A PATCH against a task whose current status is `completed` is +rejected UNLESS the PATCH is a no-op `completed → completed` AND +carries no other field changes (no `payload`, no `tags`, no +`error`, no `suspension_reason`, no lease). The no-op pass-through +returns the existing record without modification — this lets +idempotent retry-loops behave predictably. + +Any other PATCH against a completed task raises +`_HostedConflict(_code="task_immutable")` → translated to +`TaskConflictError(current_status="completed")`. + +#### 24.3 Delete force semantics + +DELETE on a task in any **non-terminal** status (`pending`, +`in_progress`, `suspended`) requires `force=true`. Without it the +provider rejects the delete as `invalid_request` (400) — note this +is **NOT** a conflict (409); the service's PR 2135250 explicitly +moved this from 409 → 400 with code `invalid_request`. + +DELETE on a **terminal** (`completed`) task always succeeds (no +force required). 
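
The transition matrix in §24.1 and the delete rule above can be sketched as two pure checks. This is an illustrative sketch, not the provider's code: `transition_allowed` and `delete_allowed` are hypothetical names, the `pending` → `pending` cell (marked n/a) is treated as disallowed by assumption, and the extra conditions on two cells (matching lease for `in_progress` → `pending`, no-op only for `completed` → `completed`) are not modeled.

```python
from typing import Literal

Status = Literal["pending", "in_progress", "suspended", "completed"]

# (from, to) pairs carrying a check mark in the §24.1 matrix.
ALLOWED: frozenset[tuple[str, str]] = frozenset({
    ("pending", "in_progress"), ("pending", "completed"),
    ("in_progress", "pending"), ("in_progress", "in_progress"),
    ("in_progress", "suspended"), ("in_progress", "completed"),
    ("suspended", "pending"), ("suspended", "in_progress"),
    ("suspended", "suspended"), ("suspended", "completed"),
    ("completed", "completed"),
})


def transition_allowed(current: Status, target: Status) -> bool:
    """True when the matrix permits moving from `current` to `target`."""
    return (current, target) in ALLOWED


def delete_allowed(current: Status, *, force: bool) -> bool:
    """Terminal tasks always delete; non-terminal tasks require force=True."""
    return current == "completed" or force
```

A provider following the text would reject a failed `delete_allowed` check as `invalid_request` (400) rather than as a conflict.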
+ +DELETE additionally honors `If-Match`: when supplied, the +provider rejects the delete with `_HostedConflict(_code="etag_mismatch")` +→ `EtagConflict` if the supplied etag does not match the current +record. + +### §25. ETag (optimistic concurrency) + in-process write serialization + +The framework uses the hosted store's ETag/CAS protocol per the +Foundry Task Storage Protocol spec. + +#### 25.1 Etag tracking — always-on after the first read/create + +After the first successful read/create on a `task_id`, **every +subsequent PATCH MUST carry `If-Match` with the latest known etag** +for that task. The framework tracks the latest etag in the +in-memory active-task entry, updating it from every PATCH/GET +response. `delete()` is the only operation that MUST NOT carry +`if_match` — deletion is intentionally unconditional and tolerates +a concurrent winner. + +**No blind writes.** This applies to *every* PATCH-issuing site, +including those that hold the per-task write lock and call the +provider directly to avoid re-entrant lock acquisition (e.g. the +queued-steering-cancel path): such sites MUST go through the +lock-held update helper that selects `If-Match` from the tracked +etag, never a bare `provider.update` with no `if_match`. + +The service-returned `etag` value is passed verbatim as `If-Match` +on the next PATCH. The framework does NOT strip surrounding quotes, +normalize whitespace, or otherwise rewrite it. + +#### 25.2 Per-task in-process write queue + +Without coordination, the framework has multiple concurrent +PATCH-issuing code paths against the same task: lease renewal +heartbeats, metadata flushes (handler-issued AND auto-flush at +turn boundaries), steering append, steering drain Phase-1/3, +suspend, complete, fail, output writes, and reclaim. All of these +race in-process for the same etag and can produce avoidable 412 +conflicts in steady state. 
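
The always-on etag tracking from §25.1, and the kind of stale-etag collision described in the paragraph above, can be illustrated with a small in-memory stand-in. Everything here is hypothetical (`FakeStore`, `TrackedTask`, `EtagMismatch`), and the sketch is synchronous for brevity even though the framework's interface is async.

```python
from dataclasses import dataclass, field
from typing import Any


class EtagMismatch(Exception):
    """Stand-in for the 412 returned on a stale If-Match."""


@dataclass
class FakeStore:
    """In-memory stand-in for one task record with an opaque etag."""
    payload: dict[str, Any] = field(default_factory=dict)
    version: int = 0

    @property
    def etag(self) -> str:
        # Quoted on purpose: callers must pass the value back verbatim.
        return f'"v{self.version}"'

    def patch(self, changes: dict[str, Any], if_match: str) -> str:
        if if_match != self.etag:
            raise EtagMismatch(f"expected {self.etag}, got {if_match}")
        self.payload.update(changes)  # shallow merge
        self.version += 1
        return self.etag


@dataclass
class TrackedTask:
    """Keeps the latest known etag and sends it on every PATCH."""
    store: FakeStore
    etag: str

    def update(self, changes: dict[str, Any]) -> None:
        # No blind writes: If-Match always comes from the tracked etag,
        # and the tracked value is refreshed from the response.
        self.etag = self.store.patch(changes, if_match=self.etag)
```

Two writers that each captured the same etag would see the second PATCH fail with `EtagMismatch`, which is the in-process race that write serialization is meant to remove.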
+ +The framework MUST serialize these writes through a **per-task +asyncio lock** held for the read-state + compute-PATCH + apply +cycle. Reads (e.g., `Task.get(task_id)`) do NOT take this lock — +they're snapshot operations that don't move the etag. + +The read MUST happen **inside** the lock for any read-modify-write +sequence (steering drain, queued-steering-cancel, etc.), so the +record read and the PATCH are atomic with respect to other +in-process writers (notably the lease-renewal heartbeat). A site +that reads the record (or pins an etag) *before* acquiring the lock +can have its etag invalidated by the heartbeat between the read and +the write, which under contention starves the retry budget. Because +the per-task lock is a **non-reentrant** `asyncio.Lock`, the +framework provides two helpers: a lock-acquiring update (for callers +that do not hold the lock) and a lock-held update (for callers that +already hold it, e.g. the drain); both select `If-Match` from the +tracked etag and refresh it on success. + +Lock lifecycle: + +- Per-`task_id` `asyncio.Lock` allocated lazily on first write. +- Released after the PATCH response is recorded (etag updated). +- Removed from the in-memory lock table when the local active-task + entry is torn down (no leaked locks). + +In-process contention now serializes; cross-process contention +(another worker reclaimed the lease) still surfaces as 412 because +the queue is in-process only. + +#### 25.3 412 (etag conflict) resolution — per-operation policy + +When a PATCH inside the queue gets a 412, the appropriate response +depends on the operation's INTENT. There is no single retry rule: + +| Operation | On 412, do what | +|---|---| +| Metadata flush | re-read state, overwrite the addressed namespace with local value (last-write-wins), retry (up to 5 attempts). | +| Steering append | re-read `_steering`, append to the NEW state's `pending_inputs`, bump `next_input_seq` from the NEW state, retry (up to 5 attempts). 
Idempotent when `input_id` is supplied. | +| Steering drain (Phase 1) | re-read `_steering`, drain the NEW head, retry (up to 5 attempts). | +| Steering drain (Phase 3) | re-read, retry (up to 5 attempts). | +| Lease renewal heartbeat | re-read lease; if still ours, retry; otherwise signal eviction. | +| Suspend / complete / fail terminal writes | **RE-READ + decide.** A 412 here means our etag is stale — that's all we know on its own. Re-read the record, then choose: (a) if the lease is **no longer ours** (`lease.owner` differs OR `lease.instance_id` differs OR `lease.expiry_count` bumped past our cached value) → ABANDON and signal awaiters via the eviction path (C-LSE-4 / C-ERR-2); the new owner is authoritative and our terminal would clobber their in-flight recovery. (b) If `status` is already terminal (`completed`) → ABANDON; another actor already wrote the terminal. (c) Otherwise (lease still ours, status still `in_progress`) → retry the terminal PATCH against the new etag, up to 5 attempts. Steering inputs that another process appended between our read and our retry are silently superseded by the terminal write — that is correct behavior because the steerer's `.result()` MUST then raise `TaskConflictError(current_status="completed")` per C-STR-6, which is how cross-process steering-after-terminate is supposed to surface. | +| Output write (part of suspend/complete) | inherits the parent operation's policy. | +| Resume-clear-output (part of resume) | re-read, retry (up to 5 attempts). | +| Recovery reclaim (inline) | ABANDON. The 412 IS the race-detection — another process beat us to the reclaim. Let the next caller / scan re-evaluate. | +| Recovery reclaim (cold-start / periodic) | ABANDON. Same reasoning. | + +Default retry budget is 5 attempts unless noted. Each retry +re-acquires the per-task lock before the re-read + re-merge + re-write +cycle. 
`LastInputIdPreconditionFailed` (for `if_last_input_id`) and +`EtagConflict` (for low-level callers) propagate as today. + +#### 25.4 Auto-extension piggyback on every PATCH + +Every PATCH the framework issues — renewal, metadata, steering, +suspend, etc. — MUST include the lease-extension trio +(`lease_owner`, `lease_instance_id`, `lease_duration_seconds`) so +the lease is refreshed as a side effect. The renewal loop's next +tick is computed dynamically from the per-task last-refresh time +(NOT a fixed cadence), so a PATCH within the last `interval` +seconds fully shadows the next heartbeat. See §56. + +**Lease renewal requires `in_progress`.** The task store accepts the +lease-extension trio as a *renewal* only when the record is already +`in_progress`, and as a *claim* only when the same PATCH transitions +the record INTO `in_progress` (e.g. reclaim, or the steering-drain +Phase-1 PATCH per §52). A PATCH that carries the lease trio against a +`suspended`/`pending`/terminal record WITHOUT a status flip to +`in_progress` is rejected ("lease renewal is only supported for +in_progress tasks"). Therefore any framework path that writes to a +record left `suspended` by a prior turn (notably the steering drain) +MUST set `status='in_progress'` in the same PATCH. The local provider +enforces this same rule so the conflict is reproducible without a +hosted deployment. + +### §26. Recovery — internal lifecycle, no public HTTP endpoint + +There is no HTTP route for resume. Resume is initiated from +caller code via the normal `Task.start` / `Task.run` (one-shot) +or `MultiTurnTask.start` / `MultiTurnTask.run` (multi-turn) entry +points. The framework's lifecycle state machine transitions a +`suspended` task back to `in_progress` and re-enters the handler +without exposing a server-side endpoint. 
+ +Crash recovery for tasks that died mid-`in_progress` is handled +internally by the periodic recovery scanner described in §55: +the scanner detects abandoned leases and re-invokes the handler +with the persisted `payload["input"]` and +`entry_mode="recovered"`. + +--- + +## Part IV — Provider abstraction (storage backends) + +> **Visibility:** Everything in this part is **framework-internal**. +> The `TaskProvider` interface and the two concrete providers +> (`HostedTaskProvider`, `LocalFileTaskProvider`) are NOT part of +> the public surface defined in Part V — in the canonical Python +> implementation, all of these live in `_`-prefixed modules +> (`_provider.py`, `_client.py`, `_local_provider.py`) and are +> NOT re-exported from `tasks/__init__.py`'s `__all__`. The +> abstraction exists to keep the manager testable and to let the +> framework swap hosted vs. local backends — but framework +> consumers are not expected (and not supported) to construct or +> consume providers directly. This part documents the contract a +> re-implementer (in another language) MUST satisfy when writing +> the provider layer. + +### §27. `TaskProvider` interface + +The framework abstracts over the storage backend via a single +async interface. Two providers ship: hosted (HTTP-backed) and local +(file-backed); a third (in-memory) is conceptually possible. + +``` +class TaskProvider: + async def create(self, request: TaskCreateRequest) -> TaskInfo: ... + async def get(self, task_id: str) -> TaskInfo | None: ... + async def update(self, task_id: str, patch: TaskPatchRequest) -> TaskInfo: ... + async def delete(self, task_id: str, *, force: bool = False, cascade: bool = False) -> None: ... + async def list(self, *, agent_name: str | None = None, + session_id: str | None = None, + status: TaskStatus | None = None, + tag: dict[str, str] | None = None, + source_type: str | None = None) -> list[TaskInfo]: ... +``` + +Semantic requirements: + +- `get(task_id)` MUST return `None` for missing tasks (not raise). 
+- `update()` MUST honor the `if_match` field on the patch for CAS. +- `update()` payload MUST shallow-merge. +- `update()` tags MUST null-as-delete merge. +- `update()` attachments MUST null-as-delete merge (§23.1). +- `delete()` MUST be idempotent at the SCHEDULING level (multiple + `.delete()` calls do not error). The provider's lower-level + `provider.delete(task_id)` MAY raise `TaskNotFound` for already- + deleted records; callers of the provider directly MUST handle + this. The canonical Python implementation's hosted provider + raises on 404 and the local provider raises on missing files; + `MultiTurnTask.delete(task_id)` shields user code from these by catching + "not found" substring matches and re-raising as `TaskNotFound` + the first time, and being a no-op only at the user-facing + `Task` surface. +- `list(...)` MUST filter server-side; framework relies on it. + +`TaskCreateRequest` and `TaskPatchRequest` are simple structs +mirroring the writable subset of `TaskInfo` (plus `if_match`, +`lease_owner`, `lease_instance_id`, `lease_duration_seconds`). + +### §28. Hosted provider (HTTP) + +The hosted provider implements `TaskProvider` over HTTP against the +Foundry Task Storage service. Selected when the platform-supplied +environment variable `FOUNDRY_HOSTING_ENVIRONMENT` is set. + +Key implementation notes: + +- **API version:** Pinned at framework build time. The framework + carries one `_API_VERSION` constant (current canonical value: + `"v1"`) and passes it as the `api-version` query parameter on + every request. +- **Authentication:** Bearer token from a `TokenCredential` + resolved at request time. Scope is `https://ai.azure.com/.default`. +- **User-Agent:** Identifies the framework + version + runtime + (`ai-agentserver-core/`). +- **Custom error classification:** The provider classifies every + non-success response into one of four labels and raises a typed + `TransportClassifiedError(classification=