diff --git a/AGENTS.md b/AGENTS.md index 9b4c958..23b84c2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -40,16 +40,17 @@ chainweaver/ ├── compiler_llm.py Offline build-time LLM flow compiler: LLMProposal + llm_propose_flows() + write_proposals() (#28); banned from executor.py ├── optimizer.py Offline build-time tool-description optimizer: OptimizationStrategy + ToolDescriptionProposal + optimize_tool_descriptions()/optimize_new_tool_description() (#100); banned from executor.py ├── _offline_llm.py Private shared internals for the offline LLM proposers: LLMFn type + parse_llm_yaml() + render_tool_catalogue() (#28, #100) -├── contracts.py ToolSafetyContract + SideEffectLevel/StabilityLevel/DeterminismLevel enums + merge_safety() + evaluate_predicate() — determinism + operational safety vocabulary (#19, #125, #293, #9, #8) +├── contracts.py ToolSafetyContract + SideEffectLevel/StabilityLevel/DeterminismLevel enums + merge_safety() + side_effect_exceeds() (#356) + evaluate_predicate() — determinism + operational safety vocabulary (#19, #125, #293, #9, #8) +├── approvals.py ApprovalCallback Protocol + ApprovalContext/ApprovalDecision/ApprovalRecord + coerce_approval_callback — execution-time ToolSafetyContract enforcement seam (#356); mirrors decisions.py ├── decorators.py @tool decorator for zero-boilerplate tool definition -├── tools.py Tool class: named callable with Pydantic I/O schemas + schema_hash + safety contract (#19); Tool.from_flow() wraps a Flow as a Tool (#24) with derived safety (#125) +├── tools.py Tool class: named callable with Pydantic I/O schemas + schema_hash + safety contract (#19) + metadata provenance (#358/#359/#371) + dry_run_fn/run_dry (#357); Tool.from_flow() wraps a Flow as a Tool (#24) with derived safety (#125) ├── flow.py FlowStep + Flow + DAGFlow + FlowStatus + FlowLifecycle + FlowGovernance + DriftInfo + ConditionalEdge (#9) + determinism_level property (#8) + ContextCollisionPolicy / on_context_collision (#337) ├── registry.py 
FlowRegistry: multi-version catalogue with status filtering (store-backed) + copy-on-write update_flow_state (#335) ├── storage.py RegistryStore protocol + InMemoryStore + FileStore (#16) ├── analyzer.py ChainAnalyzer: offline schema-compatibility analysis (#77) ├── attest.py attest_flow() + AttestationReport: observed-determinism evidence (#154) ├── decisions.py DecisionCallback Protocol + DecisionContext + coerce_decision_callback (#102) -├── executor.py FlowExecutor: sequential/DAG runner + drift detection + stream_flow + opt-in async DAG-level concurrency (max_step_concurrency, #344) (main entry point) +├── executor.py FlowExecutor: sequential/DAG runner + drift detection + stream_flow + opt-in async DAG-level concurrency (max_step_concurrency, #344) + opt-in execution-time safety enforcement (approval_callback/strict_safety/max_side_effect_level, #356) + dry-run mode (execute_flow(dry_run=...), #357) (main entry point) ├── _execution/ Internal, no-I/O execution collaborators shared by both lanes (#330, #331); banned from importing LLM/network/random — see invariants │ ├── __init__.py Re-exports merge_step_outputs │ └── context.py merge_step_outputs: single context-merge honouring on_context_collision (#337) @@ -68,7 +69,7 @@ chainweaver/ ├── mcp/ MCP integration (issues #70, #72, #150); requires chainweaver[mcp] │ ├── __init__.py Public surface: MCPToolAdapter, FlowServer, jsonschema_to_pydantic │ ├── _schema.py JSON Schema ↔ Pydantic bridge -│ ├── adapter.py MCPToolAdapter: wrap MCP server tools as ChainWeaver Tools (#70, #150) +│ ├── adapter.py MCPToolAdapter: wrap MCP server tools as ChainWeaver Tools (#70, #150) + untrusted-metadata trust controls — annotation_trust→ToolSafetyContract (#371), MetadataPolicy name/description sanitisation (#359), schema-hash pinning + on_drift (#358), build_pin_file/load_pins │ └── server.py FlowServer: safely expose governed flows as MCP tools via FastMCP (#72, #259, #294) ├── contrib/ Curated deterministic stdlib tools 
(#145); pip install 'chainweaver[contrib]' │ ├── __init__.py Re-exports the public tool set @@ -320,6 +321,7 @@ integration. | `started_at` | `datetime` | UTC timestamp when execution began. | | `ended_at` | `datetime` | UTC timestamp when execution finished. | | `total_duration_ms` | `float` | Wall-clock duration in ms (via `time.perf_counter`). | +| `dry_run` | `bool` | `True` when produced by `execute_flow(dry_run=True)` (#357); a rehearsal trace, never a real run. | ### `StepRecord` (Pydantic `BaseModel`) @@ -335,6 +337,7 @@ integration. | `started_at` | `datetime` | UTC timestamp when the step began. | | `ended_at` | `datetime` | UTC timestamp when the step finished. | | `duration_ms` | `float` | Wall-clock duration in ms (via `time.perf_counter`). | +| `approval` | `ApprovalRecord \| None` | The decision for a step gated by an execution-time approval callback (#356); `None` when no approval was required. | > **Serialization:** `ExecutionResult` and `StepRecord` are Pydantic models; > `result.model_dump_json()` and `ExecutionResult.model_validate_json(...)` diff --git a/README.md b/README.md index 8a3bcac..a48c1c8 100644 --- a/README.md +++ b/README.md @@ -800,6 +800,8 @@ All errors are typed and traceable: | `SchemaValidationError` | Input or output fails Pydantic validation | | `InputMappingError` | A mapping key is not present in the context | | `FlowExecutionError` | The tool callable raises an unexpected exception | +| `ApprovalDeniedError` | An execution-time approval callback denied a step, raised, or returned an invalid value — or `strict_safety=True` and a required-approval step has no callback | +| `SafetyCeilingError` | A step's `ToolSafetyContract.side_effects` exceeds the executor's configured `max_side_effect_level` | | `ToolDefinitionError` | The `@tool` decorator cannot build a tool from a function | | `DAGDefinitionError` | A `DAGFlow` has a cycle, duplicate `step_id`, or unknown dependency | | `FlowCompositionError` | A composed flow has a 
sub-flow cycle, exceeds `max_composition_depth`, or references an unregistered sub-flow | @@ -812,6 +814,8 @@ All errors are typed and traceable: | `FixtureStaleError` | A `record_then_replay` replay invocation cannot be matched to a recording (missing/stale fixture) | | `FuzzConfigError` | A property-based fuzzing run is misconfigured (no properties, `runs < 1`, a flow with no `input_schema` and no base input, or an unsupported input-field type) | | `CostProfileError` | A cost estimate is requested for a `(provider, model)` pair absent from the maintained `PROVIDER_PRICES` table | +| `MCPMetadataError` | A server-provided MCP tool name fails the adapter's `MetadataPolicy` (and `on_invalid_name="error"`) | +| `MCPSchemaDriftError` | A pinned MCP tool's raw schema changed under `MCPToolAdapter(on_drift="error")` | All exceptions inherit from `ChainWeaverError`. diff --git a/chainweaver/__init__.py b/chainweaver/__init__.py index 6c9ac46..052dbef 100644 --- a/chainweaver/__init__.py +++ b/chainweaver/__init__.py @@ -41,6 +41,15 @@ from chainweaver import cli from chainweaver.analyzer import ChainAnalyzer, Suggestion, ToolChain, suggest_optimizations +from chainweaver.approvals import ( + ApprovalCallable, + ApprovalCallback, + ApprovalContext, + ApprovalDecision, + ApprovalRecord, + BaseApprovalCallback, + coerce_approval_callback, +) from chainweaver.attest import AttestationInputError, AttestationReport, attest_flow from chainweaver.builder import FlowBuilder, FlowBuilderError from chainweaver.cache import FileStepCache, InMemoryStepCache, StepCache, StepCacheKey @@ -66,6 +75,7 @@ ToolSafetyContract, evaluate_predicate, merge_safety, + side_effect_exceeds, ) from chainweaver.cost import ( PROVIDER_PRICES, @@ -85,6 +95,7 @@ from chainweaver.events import FlowEvent from chainweaver.exceptions import ( AgentTraceImportError, + ApprovalDeniedError, AsyncLaneUnsupportedError, ChainWeaverError, CheckpointDriftError, @@ -106,11 +117,14 @@ InvalidFlowVersionError, 
KernelInvocationError, MCPError, + MCPMetadataError, MCPSchemaConversionError, + MCPSchemaDriftError, MCPToolInvocationError, OfflineLLMError, PluginDiscoveryError, PredicateSyntaxError, + SafetyCeilingError, SchemaValidationError, ToolDefinitionError, ToolNotFoundError, @@ -238,11 +252,18 @@ "PROVIDER_PRICES", "AgentTraceEvent", "AgentTraceImportError", + "ApprovalCallable", + "ApprovalCallback", + "ApprovalContext", + "ApprovalDecision", + "ApprovalDeniedError", + "ApprovalRecord", "AsyncLaneUnsupportedError", "AttestationInputError", "AttestationReport", "BacktestMismatch", "BacktestReport", + "BaseApprovalCallback", "BaseDecisionCallback", "BaseMiddleware", "CancellationToken", @@ -321,7 +342,9 @@ "LessonEvidenceStep", "LessonReview", "MCPError", + "MCPMetadataError", "MCPSchemaConversionError", + "MCPSchemaDriftError", "MCPToolInvocationError", "ObservedStep", "ObservedTrace", @@ -337,6 +360,7 @@ "ReplayMode", "ReplayResult", "RetryPolicy", + "SafetyCeilingError", "SafetyLevel", "SchemaValidationError", "ServiceConfig", @@ -369,6 +393,7 @@ "check_flow_compatibility", "classify_safety", "cli", + "coerce_approval_callback", "coerce_decision_callback", "compile_flow", "discover_flows", @@ -397,6 +422,7 @@ "result_to_mermaid", "schema_fingerprint", "score_candidate", + "side_effect_exceeds", "suggest_optimizations", "tool", "trace_to_lesson_candidate", diff --git a/chainweaver/approvals.py b/chainweaver/approvals.py new file mode 100644 index 0000000..b39c434 --- /dev/null +++ b/chainweaver/approvals.py @@ -0,0 +1,188 @@ +"""Execution-time approval seam for ToolSafetyContract enforcement (issue #356). + +ChainWeaver already ships a rich, composable safety vocabulary +(:mod:`chainweaver.contracts`): side-effect levels, approval flags, dry-run +support, and ``merge_safety()``. In v1 the contract was purely *advisory* — +:class:`~chainweaver.executor.FlowExecutor` never acted on it. 
An +:class:`ApprovalCallback` is the opt-in seam that makes the contract +*actionable*: when a step's effective contract has ``requires_approval=True`` +and a callback is registered, the executor asks the callback to approve the +step **before** the tool function runs. + +The seam deliberately mirrors :class:`~chainweaver.decisions.DecisionCallback` +(issue #102): the executor only ever *calls* a user-supplied callback, so the +three hard executor invariants (no LLM, no network I/O, no randomness in +:mod:`chainweaver.executor`) are preserved — the callback is where a host can +inject a human prompt, a policy service, or an RPC, none of which the executor +performs itself. + +Two equivalent callback shapes are accepted, exactly like the decision seam:: + + # Class-based + class CliApprover: + def approve(self, ctx: ApprovalContext) -> ApprovalDecision: + return ApprovalDecision.APPROVE + + # Plain callable + def approve_all(ctx: ApprovalContext) -> ApprovalDecision: + return ApprovalDecision.APPROVE + + FlowExecutor(registry, approval_callback=approve_all) + +Failure semantics: a ``DENY`` decision, a callback that raises, or a callback +that returns a non-:class:`ApprovalDecision` aborts the step with +:class:`~chainweaver.exceptions.ApprovalDeniedError` and a failed +``StepRecord`` — the same abort-the-step path tool failures take. +""" + +from __future__ import annotations + +from collections.abc import Callable +from enum import Enum +from typing import Any, Protocol, runtime_checkable + +from pydantic import BaseModel, ConfigDict + +from chainweaver.contracts import ToolSafetyContract + + +class ApprovalDecision(str, Enum): + """Explicit outcome of an :class:`ApprovalCallback` — no boolean ambiguity. + + Attributes: + APPROVE: Allow the step's tool to run. + DENY: Refuse the step; the executor aborts it with + :class:`~chainweaver.exceptions.ApprovalDeniedError`. 
+ """ + + APPROVE = "approve" + DENY = "deny" + + +class ApprovalContext(BaseModel): + """Snapshot of execution state passed to an :class:`ApprovalCallback`. + + Attributes: + trace_id: UUID4 hex string for the running execution. + flow_name: Name of the flow being executed. + step_index: Zero-based position of the step inside the flow. + step_id: ``DAGFlowStep.step_id`` when running a ``DAGFlow``; ``None`` + for linear ``Flow`` execution. + tool_name: Name of the tool the step is about to run. + inputs: The step's resolved (already redacted, when a redaction policy + is configured) inputs. Read-only — mutating has no effect. + safety: The effective :class:`ToolSafetyContract` that triggered the + approval gate. + """ + + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + + trace_id: str + flow_name: str + step_index: int + step_id: str | None + tool_name: str + inputs: dict[str, Any] + safety: ToolSafetyContract + + +class ApprovalRecord(BaseModel): + """Audit record of an approval decision, attached to ``StepRecord.approval``. + + Persisted on the trace so a completed :class:`~chainweaver.executor.ExecutionResult` + is a full record of which side-effecting steps were gated and how they were + resolved. + + Attributes: + decision: The :class:`ApprovalDecision` the callback returned (or + ``DENY`` when the callback raised / no callback was registered under + ``strict_safety``). + reason: Optional human-readable explanation carried alongside the + decision. + """ + + model_config = ConfigDict(frozen=True) + + decision: ApprovalDecision + reason: str | None = None + + +@runtime_checkable +class ApprovalCallback(Protocol): + """Structural protocol for execution-time approval callbacks (issue #356).""" + + def approve(self, ctx: ApprovalContext) -> ApprovalDecision: + """Return :attr:`ApprovalDecision.APPROVE` or :attr:`ApprovalDecision.DENY`. + + Args: + ctx: Snapshot of the flow execution state at the approval point. 
+ + Returns: + An :class:`ApprovalDecision`. Returning anything else aborts the + step with :class:`~chainweaver.exceptions.ApprovalDeniedError`. + + Raises: + Exception: Any exception is caught by the executor, converted to an + :class:`~chainweaver.exceptions.ApprovalDeniedError`, and aborts + the step like any other tool failure. + """ + ... + + +class BaseApprovalCallback: + """Convenience base class for class-based :class:`ApprovalCallback`. + + Subclass and override :meth:`approve`. Stateful approvers (batching + prompts, caching policy decisions) typically inherit from this; pure + stateless approvers can use a plain function and skip the class entirely. + """ + + def approve(self, ctx: ApprovalContext) -> ApprovalDecision: # pragma: no cover — abstract + raise NotImplementedError("BaseApprovalCallback subclasses must override 'approve'.") + + +# Type alias for accepted callback shapes; bare callables are wrapped so the +# executor's call site stays uniform (``cb.approve(ctx)``). +ApprovalCallable = Callable[[ApprovalContext], ApprovalDecision] + + +class _CallableApprovalCallback: + """Adapter that wraps a bare callable into an :class:`ApprovalCallback`.""" + + __slots__ = ("_fn",) + + def __init__(self, fn: ApprovalCallable) -> None: + self._fn = fn + + def approve(self, ctx: ApprovalContext) -> ApprovalDecision: + return self._fn(ctx) + + +def coerce_approval_callback( + cb: ApprovalCallback | ApprovalCallable | None, +) -> ApprovalCallback | None: + """Normalize *cb* into an :class:`ApprovalCallback`, or ``None``. + + Accepts either an object implementing ``approve(ctx)`` or a bare callable + with the equivalent signature. Bare callables are wrapped so the executor + can call ``cb.approve(ctx)`` uniformly. + + Args: + cb: An :class:`ApprovalCallback`, a bare callable, or ``None``. + + Returns: + An :class:`ApprovalCallback` instance, or ``None`` if *cb* was ``None``. + + Raises: + TypeError: If *cb* is neither an :class:`ApprovalCallback` nor callable. 
+ """ + if cb is None: + return None + if isinstance(cb, ApprovalCallback): + return cb + if callable(cb): + return _CallableApprovalCallback(cb) + raise TypeError( + f"approval_callback must implement ApprovalCallback or be callable; " + f"got {type(cb).__name__}." + ) diff --git a/chainweaver/compat.py b/chainweaver/compat.py index 4be636e..0cce34e 100644 --- a/chainweaver/compat.py +++ b/chainweaver/compat.py @@ -35,6 +35,30 @@ def schema_fingerprint(model: type[BaseModel]) -> str: return hashlib.sha256(canonical.encode()).hexdigest()[:16] +def schema_dict_fingerprint(raw_schema: dict[str, object]) -> str: + """Compute a deterministic fingerprint of a *raw* JSON Schema dict (issue #358). + + Counterpart to :func:`schema_fingerprint`, which fingerprints a Pydantic + model. This variant fingerprints a JSON Schema *mapping* directly — used by + :class:`~chainweaver.mcp.adapter.MCPToolAdapter` to pin the schemas a remote + MCP server advertises *before* they are projected to Pydantic, so a server + silently changing a tool's ``inputSchema`` / ``outputSchema`` between sessions + is detectable. + + The canonicalisation (sorted keys, compact separators) makes the fingerprint + insensitive to JSON key ordering, so a server reordering schema keys without + changing their meaning does not register as drift. + + Args: + raw_schema: A JSON-Schema mapping (e.g. an MCP tool's ``inputSchema``). + + Returns: + A 16-character hex digest string, matching :func:`schema_fingerprint`. + """ + canonical = json.dumps(raw_schema, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(canonical.encode()).hexdigest()[:16] + + @dataclass class CompatibilityIssue: """A single compatibility problem detected between a flow and its tools. 
diff --git a/chainweaver/contracts.py b/chainweaver/contracts.py index 0ae635f..734719e 100644 --- a/chainweaver/contracts.py +++ b/chainweaver/contracts.py @@ -222,6 +222,18 @@ def requires_review(self) -> bool: return self.requires_approval +def side_effect_exceeds(level: SideEffectLevel, ceiling: SideEffectLevel) -> bool: + """Return ``True`` when *level* is strictly more restrictive than *ceiling*. + + Used by :class:`~chainweaver.executor.FlowExecutor`'s execution-time + side-effect ceiling (issue #356): a step whose contract declares a + :class:`SideEffectLevel` above the configured ``max_side_effect_level`` is + refused. Ordering follows :class:`SideEffectLevel`'s documented + most-permissive → most-restrictive scale. + """ + return _SIDE_EFFECT_ORDER[level] > _SIDE_EFFECT_ORDER[ceiling] + + def merge_safety( contracts: Iterable[ToolSafetyContract], *, @@ -479,4 +491,5 @@ def _eval(node: ast.AST) -> Any: "ToolSafetyContract", "evaluate_predicate", "merge_safety", + "side_effect_exceeds", ] diff --git a/chainweaver/exceptions.py b/chainweaver/exceptions.py index 9f2da2e..6123458 100644 --- a/chainweaver/exceptions.py +++ b/chainweaver/exceptions.py @@ -456,6 +456,109 @@ def __init__(self, tool_name: str, detail: str) -> None: super().__init__(f"MCP tool '{tool_name}' invocation failed: {detail}.") +class ApprovalDeniedError(ChainWeaverError): + """Raised when an :class:`~chainweaver.approvals.ApprovalCallback` denies a step (issue #356). + + Execution-time enforcement of :class:`~chainweaver.contracts.ToolSafetyContract` + is opt-in: when a step's effective contract has ``requires_approval=True`` and + a callback is registered on the executor, the callback is asked to approve the + step *before* the tool function runs. A ``DENY`` decision (or a callback that + raises or returns an invalid value, or a missing callback under ``strict_safety=True``) + aborts the step with this typed error rather than running the side-effecting tool unattended. 
+ + Attributes: + tool_name: Name of the tool whose invocation was denied. + step_index: Zero-based position of the step inside the flow. + detail: Human-readable description of why approval was denied. + """ + + def __init__(self, tool_name: str, step_index: int, detail: str) -> None: + self.tool_name = tool_name + self.step_index = step_index + self.detail = detail + # Normalise so the message ends with exactly one period (repo convention, + # AGENTS.md §6) regardless of whether *detail* already carried one. + normalised = detail.rstrip(".") + super().__init__( + f"Approval denied for tool '{tool_name}' at step {step_index}: {normalised}." + ) + + +class SafetyCeilingError(ChainWeaverError): + """Raised when a step's side-effect level exceeds the executor ceiling (issue #356). + + When :class:`~chainweaver.executor.FlowExecutor` is configured with + ``max_side_effect_level=...``, a step whose effective + :class:`~chainweaver.contracts.ToolSafetyContract` declares a + :class:`~chainweaver.contracts.SideEffectLevel` above that ceiling is refused + before it runs, rather than silently executing a higher-risk operation than + the host opted into. + + Attributes: + tool_name: Name of the tool that exceeded the ceiling. + step_index: Zero-based position of the step inside the flow. + level: The step's declared side-effect level (value string). + ceiling: The configured maximum side-effect level (value string). + """ + + def __init__(self, tool_name: str, step_index: int, level: str, ceiling: str) -> None: + self.tool_name = tool_name + self.step_index = step_index + self.level = level + self.ceiling = ceiling + super().__init__( + f"Tool '{tool_name}' at step {step_index} has side-effect level " + f"'{level}' which exceeds the configured ceiling '{ceiling}'." + ) + + +class MCPMetadataError(MCPError): + """Raised when server-provided MCP tool metadata violates the metadata policy (issue #359). 
+ + Tool names and descriptions wrapped from an MCP server are untrusted input: + they become ChainWeaver :attr:`Tool.description` / :attr:`Tool.name` values and + can be re-exported to LLM clients or rendered into proposer prompts. When a + server advertises a tool name that fails the configured validation pattern (and + the policy is not in sanitising mode), :class:`MCPToolAdapter` refuses it with + this error instead of adopting a look-alike or control-character-laden name. + + Attributes: + tool_name: The offending server-provided tool name (server-prefixed when a + prefix was supplied). + detail: Human-readable explanation of which rule was violated. + """ + + def __init__(self, tool_name: str, detail: str) -> None: + self.tool_name = tool_name + self.detail = detail + super().__init__(f"MCP tool metadata for '{tool_name}' rejected: {detail}.") + + +class MCPSchemaDriftError(MCPError): + """Raised when a discovered MCP tool schema no longer matches its pin (issue #358). + + Tools wrapped from remote MCP servers get the same schema-drift discipline as + locally registered tools: :class:`MCPToolAdapter` fingerprints each tool's raw + JSON Schema at discovery and, when a pin is supplied, verifies it. Under the + ``on_drift="error"`` policy a mismatch raises this exception naming the tool and + both fingerprints, rather than transparently rebuilding models around a silently + changed remote schema. + + Attributes: + tool_name: Name of the MCP tool whose schema drifted (server-side name). + expected: The pinned fingerprint. + actual: The fingerprint computed from the freshly discovered schema. + """ + + def __init__(self, tool_name: str, expected: str, actual: str) -> None: + self.tool_name = tool_name + self.expected = expected + self.actual = actual + super().__init__( + f"MCP tool '{tool_name}' schema drifted: pinned '{expected}', discovered '{actual}'." 
+ ) + + class DecisionCallbackError(ChainWeaverError): """Raised when a :class:`~chainweaver.decisions.DecisionCallback` fails (issue #102). diff --git a/chainweaver/executor.py b/chainweaver/executor.py index ad05c33..cb496ee 100644 --- a/chainweaver/executor.py +++ b/chainweaver/executor.py @@ -29,10 +29,22 @@ from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt, wait_fixed from chainweaver._execution import merge_step_outputs +from chainweaver.approvals import ( + ApprovalCallable, + ApprovalCallback, + ApprovalContext, + ApprovalDecision, + ApprovalRecord, + coerce_approval_callback, +) from chainweaver.cache import StepCache, StepCacheKey, compute_input_value_hash from chainweaver.cancellation import CancellationToken from chainweaver.checkpoint import Checkpointer, ExecutionSnapshot -from chainweaver.contracts import evaluate_predicate +from chainweaver.contracts import ( + SideEffectLevel, + evaluate_predicate, + side_effect_exceeds, +) from chainweaver.cost import CostProfile, CostReport, compute_cost_report from chainweaver.decisions import ( DecisionCallable, @@ -42,6 +54,7 @@ ) from chainweaver.events import FlowEvent from chainweaver.exceptions import ( + ApprovalDeniedError, AsyncLaneUnsupportedError, CheckpointDriftError, CheckpointerNotConfiguredError, @@ -54,6 +67,7 @@ FlowStatusError, InputMappingError, PredicateSyntaxError, + SafetyCeilingError, SchemaValidationError, ToolNotFoundError, ToolOutputSizeError, @@ -138,6 +152,13 @@ def __init__(self) -> None: self.active_flow_version: str = "" self.in_replay: bool = False self.resume_snapshot: ExecutionSnapshot | None = None + # Dry-run mode (issue #357): set for the duration of an + # ``execute_flow(dry_run=True)`` call. ``dry_run_unsupported`` is the + # policy applied to side-effecting steps that declare no ``dry_run_fn`` + # ("skip" stubs them, "abort" fails the step). Both are saved/restored + # around sub-flow recursion alongside ``active_flow_version``. 
+ self.dry_run: bool = False + self.dry_run_unsupported: str = "skip" class _StreamCollectorMiddleware(BaseMiddleware): @@ -407,6 +428,10 @@ class StepRecord(BaseModel): sub_result: For a composed sub-flow step (issue #75), the nested :class:`ExecutionResult` of the sub-flow run, so the parent trace retains the full sub-flow execution log. ``None`` for tool steps. + approval: For a step gated by an execution-time approval callback + (issue #356), the :class:`~chainweaver.approvals.ApprovalRecord` + describing the decision. ``None`` for steps whose effective + contract did not require approval (the common case). """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -428,6 +453,7 @@ class StepRecord(BaseModel): fallback_used: bool = False flow_name: str | None = None sub_result: ExecutionResult | None = None + approval: ApprovalRecord | None = None class ExecutionResult(BaseModel): @@ -479,6 +505,12 @@ class ExecutionResult(BaseModel): initial_input: The initial context dictionary that was passed to ``execute_flow``. Stored on the result so the trace can be replayed deterministically by :meth:`FlowExecutor.replay_flow`. + dry_run: ``True`` when the result was produced by + ``execute_flow(dry_run=True)`` (issue #357). Dry-run traces run + ``dry_run_fn`` for tools that declare it, run read-only steps + normally, and skip/abort other side-effecting steps — so a dry-run + trace must never be mistaken for a real run. ``False`` for normal + executions (the default). 
""" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -494,6 +526,7 @@ class ExecutionResult(BaseModel): total_duration_ms: float cost_report: CostReport | None = None initial_input: dict[str, Any] = Field(default_factory=dict) + dry_run: bool = False def to_mermaid(self, *, direction: str = "LR") -> str: """Return a Mermaid graph overlaying this result on the flow (#79).""" @@ -604,6 +637,9 @@ def __init__( checkpointer: Checkpointer | None = None, delete_on_success: bool = True, decision_callback: DecisionCallback | DecisionCallable | None = None, + approval_callback: ApprovalCallback | ApprovalCallable | None = None, + strict_safety: bool = False, + max_side_effect_level: SideEffectLevel | None = None, discover_plugins: bool = False, max_composition_depth: int = 10, max_step_concurrency: int = 1, @@ -640,6 +676,19 @@ def __init__( self._decision_callback: DecisionCallback | None = coerce_decision_callback( decision_callback ) + # Execution-time safety enforcement (issue #356). All opt-in: + # ``approval_callback`` is the seam invoked before a step whose effective + # ``ToolSafetyContract`` has ``requires_approval=True``; ``strict_safety`` + # refuses such steps when no callback is registered (instead of the + # advisory default of running them); ``max_side_effect_level`` is a + # ceiling above which any step is refused. Like ``decision_callback``, + # the callback is a user-supplied seam — the executor never performs I/O + # itself, so the no-LLM/no-network/no-randomness invariants hold. + self._approval_callback: ApprovalCallback | None = coerce_approval_callback( + approval_callback + ) + self._strict_safety = strict_safety + self._max_side_effect_level = max_side_effect_level # Step-result cache (issue #127). ``None`` (the default) # disables caching entirely — every tool runs every call. 
# When set, eligible step outputs are read from / written to @@ -1229,6 +1278,66 @@ def execute_flow( force: bool = False, deadline: float | None = None, cancel_token: CancellationToken | None = None, + dry_run: bool = False, + dry_run_unsupported: str = "skip", + ) -> ExecutionResult: + """Execute a registered flow from *initial_input*. + + See :meth:`_execute_flow_impl` for the full parameter reference; the two + extra keyword arguments here add the opt-in dry-run mode (issue #357): + + Args: + dry_run: When ``True``, run the flow as a side-effect-free rehearsal. + Read-only steps (``side_effects`` in ``NONE``/``READ``) run + normally; tools that declare a ``dry_run_fn`` run it; other + side-effecting steps are handled per *dry_run_unsupported*. The + step cache and checkpointer are bypassed and + :attr:`ExecutionResult.dry_run` is set, so a rehearsal trace can + never be confused with a real run. Composed sub-flows (#75) + inherit the dry-run mode. + dry_run_unsupported: Policy for side-effecting steps with no + ``dry_run_fn`` under ``dry_run=True``: ``"skip"`` (the default) + records a ``skipped`` stub and continues; ``"abort"`` fails the + step for a high-fidelity rehearsal. + + Returns: + An :class:`ExecutionResult`; ``dry_run`` reflects the mode it ran in. + """ + if dry_run_unsupported not in ("skip", "abort"): + raise ValueError( + f"dry_run_unsupported must be 'skip' or 'abort', got {dry_run_unsupported!r}." + ) + # Save/restore the per-thread dry-run markers so a composed sub-flow + # (#75) inherits the parent's mode and a later normal call on the same + # thread is unaffected. ``dry_run or previous`` lets a nested call with + # the default ``dry_run=False`` stay in the parent's dry run. 
+ previous_dry_run = self._local.dry_run + previous_unsupported = self._local.dry_run_unsupported + self._local.dry_run = dry_run or previous_dry_run + if dry_run: + self._local.dry_run_unsupported = dry_run_unsupported + try: + return self._execute_flow_impl( + flow_name, + initial_input, + version=version, + force=force, + deadline=deadline, + cancel_token=cancel_token, + ) + finally: + self._local.dry_run = previous_dry_run + self._local.dry_run_unsupported = previous_unsupported + + def _execute_flow_impl( + self, + flow_name: str, + initial_input: dict[str, Any], + *, + version: str | None = None, + force: bool = False, + deadline: float | None = None, + cancel_token: CancellationToken | None = None, ) -> ExecutionResult: """Execute a registered flow from *initial_input*. @@ -2012,6 +2121,8 @@ async def _execute_step_async( started_at = _now_utc() t0 = time.perf_counter() tool_attempts = [0] + # Approval audit record (issue #356); set by the safety gate below. + approval_record: ApprovalRecord | None = None def _record( *, @@ -2041,6 +2152,7 @@ def _record( skipped=skipped, cached=False, fallback_used=fallback_used, + approval=approval_record, ) def _finish(record: StepRecord) -> StepRecord: @@ -2101,6 +2213,30 @@ def _finish(record: StepRecord) -> StepRecord: redaction=self._redaction_policy, ) + # Execution-time safety enforcement (issue #356) — mirrors the sync lane. + # ``DAGFlowStep`` carries a ``step_id``; a linear ``FlowStep`` does not. 
+ gate_error, approval_record = self._evaluate_safety_gate( + step=step, + tool=tool, + step_index=step_index, + inputs=inputs, + flow_name=flow_name, + trace_id=trace_id, + step_id=getattr(step, "step_id", None), + ) + if gate_error is not None: + log_step_error(_logger, step_index, step.display_name, gate_error) + return _finish( + _record( + inputs=inputs, + outputs=None, + error=gate_error, + success=False, + skipped=False, + retry_errors=[], + ) + ) + retry_errors: list[str] = [] try: outputs = await self._invoke_tool_async( @@ -2591,6 +2727,7 @@ def _make_result( total_duration_ms=total_ms, cost_report=cost_report, initial_input=dict(initial_input), + dry_run=self._local.dry_run, ) if self._trace_recorder is not None: self._record_observed_trace(result) @@ -2625,7 +2762,9 @@ def _save_linear_snapshot( relevant tool's current ``schema_hash`` so resume can detect drift since the snapshot was written. """ - if self._checkpointer is None: + # Dry runs (#357) never persist snapshots — a rehearsal must not leave + # resumable state behind. + if self._checkpointer is None or self._local.dry_run: return tool_hashes: dict[str, str] = {} for step in flow.steps: @@ -2662,7 +2801,7 @@ def _save_dag_snapshot( run sequentially, but on resume the level is replayed from scratch. No-op when no checkpointer is configured. """ - if self._checkpointer is None: + if self._checkpointer is None or self._local.dry_run: return tool_hashes: dict[str, str] = {} for step in flow.steps: @@ -3340,6 +3479,10 @@ def _execute_step( # accurate when ``on_error="skip"`` or ``on_error="fallback:…"`` # appends extra entries to ``retry_errors`` for context. tool_attempts = [0] + # Approval audit record (issue #356); ``None`` until the safety gate + # below evaluates an approval decision. ``_record`` reads it at call + # time, so every record produced after the gate carries the decision. 
+ approval_record: ApprovalRecord | None = None def _record( *, @@ -3374,6 +3517,7 @@ def _record( skipped=skipped, cached=cached, fallback_used=fallback_used, + approval=approval_record, ) def _finish(record: StepRecord) -> StepRecord: @@ -3462,15 +3606,60 @@ def _finish(record: StepRecord) -> StepRecord: redaction=self._redaction_policy, ) + # Execution-time safety enforcement (issue #356): side-effect ceiling + + # approval gate. Runs after inputs are resolved (so the approval + # callback sees real inputs) and before any tool work. + gate_error, approval_record = self._evaluate_safety_gate( + step=step, + tool=tool, + step_index=step_index, + inputs=inputs, + flow_name=flow_name, + trace_id=trace_id, + step_id=step_id, + ) + if gate_error is not None: + log_step_error(_logger, step_index, step.display_name, gate_error) + return _finish( + _record( + inputs=inputs, + outputs=None, + error=gate_error, + success=False, + skipped=False, + retry_errors=[], + ) + ) + + # Dry-run dispatch (issue #357): side-effecting steps run ``dry_run_fn`` + # or are skipped/aborted; read-only steps fall through and run normally + # (with the cache bypassed below). + if self._local.dry_run: + dry_record = self._dry_run_step( + step=step, + tool=tool, + step_index=step_index, + inputs=inputs, + record_fn=_record, + ) + if dry_record is not None: + return _finish(dry_record) + # Cache lookup (issue #127). Skip caching during replay_flow - # (replay always re-executes) and for tools that opt out via + # (replay always re-executes), during dry runs (#357, rehearsals must + # not read or write real cache state), and for tools that opt out via # ``cacheable=False``. Input validation runs inside the cache # path so we can hash the *validated* form — equivalent inputs # that differ only in field ordering or coercion collapse onto # the same key. If validation fails, fall through to the # normal execution path, which surfaces the same error. 
cache_key: StepCacheKey | None = None - if self._step_cache is not None and tool.cacheable and not self._local.in_replay: + if ( + self._step_cache is not None + and tool.cacheable + and not self._local.in_replay + and not self._local.dry_run + ): try: validated = tool.input_schema.model_validate(inputs) except ValidationError: @@ -3607,6 +3796,164 @@ def _finish(record: StepRecord) -> StepRecord: ) ) + def _evaluate_safety_gate( + self, + *, + step: FlowStep, + tool: Tool, + step_index: int, + inputs: dict[str, Any], + flow_name: str, + trace_id: str, + step_id: str | None, + ) -> tuple[Exception | None, ApprovalRecord | None]: + """Enforce the execution-time safety contract for a step (issue #356). + + Returns ``(error, approval)``: ``error`` is non-``None`` when the step + must be aborted (ceiling exceeded, approval denied, callback misbehaved, + or ``strict_safety`` with no callback); ``approval`` is the audit record + to attach to the step when an approval decision was made. Both are + ``None`` for the common case of a step that needs no gating. This helper + performs no I/O itself — the only outward call is the user-supplied + approval callback, mirroring the ``decision_callback`` seam — so the + executor's determinism invariants are preserved. + """ + contract = tool.safety + ceiling = self._max_side_effect_level + if ceiling is not None and side_effect_exceeds(contract.side_effects, ceiling): + return ( + SafetyCeilingError( + tool.name, step_index, contract.side_effects.value, ceiling.value + ), + None, + ) + + if not contract.requires_approval: + return None, None + + if self._approval_callback is None: + if self._strict_safety: + reason = "no approval_callback registered (strict_safety=True)" + return ( + ApprovalDeniedError( + tool.name, + step_index, + f"step requires approval but {reason}", + ), + ApprovalRecord(decision=ApprovalDecision.DENY, reason=reason), + ) + # Advisory default (pre-#356 behaviour): run the step. 
+ return None, None + + redacted = ( + self._redaction_policy.redact(inputs) + if self._redaction_policy is not None + else dict(inputs) + ) + ctx = ApprovalContext( + trace_id=trace_id, + flow_name=flow_name, + step_index=step_index, + step_id=step_id, + tool_name=tool.name, + inputs=redacted, + safety=contract, + ) + try: + decision = self._approval_callback.approve(ctx) + except Exception as exc: + reason = f"approval callback raised {type(exc).__name__}: {exc}" + err = ApprovalDeniedError(tool.name, step_index, reason) + err.__cause__ = exc + return err, ApprovalRecord(decision=ApprovalDecision.DENY, reason=reason) + if not isinstance(decision, ApprovalDecision): + reason = f"approval callback returned {decision!r}, not an ApprovalDecision" + return ( + ApprovalDeniedError(tool.name, step_index, reason), + ApprovalRecord(decision=ApprovalDecision.DENY, reason=reason), + ) + if decision is ApprovalDecision.DENY: + reason = "approval callback returned DENY" + return ( + ApprovalDeniedError(tool.name, step_index, reason), + ApprovalRecord(decision=ApprovalDecision.DENY, reason=reason), + ) + return None, ApprovalRecord(decision=ApprovalDecision.APPROVE) + + def _dry_run_step( + self, + *, + step: FlowStep, + tool: Tool, + step_index: int, + inputs: dict[str, Any], + record_fn: Callable[..., StepRecord], + ) -> StepRecord | None: + """Dispatch one step under ``dry_run=True`` (issue #357). + + Returns ``None`` when the step is read-only and should run normally + (the caller falls through to the real execution path, with the cache + bypassed). Otherwise returns a finished :class:`StepRecord`: a + ``dry_run_fn`` preview, a stubbed ``skipped`` record, or a failed record + under the ``"abort"`` policy. 
+ """ + contract = tool.safety + if contract.side_effects in (SideEffectLevel.NONE, SideEffectLevel.READ): + return None # safe to actually run a read-only step + + if tool.supports_dry_run and tool.dry_run_fn is not None: + try: + outputs = tool.run_dry(inputs) + except Exception as exc: + wrapped = self._wrap_tool_exception(step, step_index, exc) + log_step_error(_logger, step_index, step.display_name, wrapped) + return record_fn( + inputs=inputs, + outputs=None, + error=wrapped, + success=False, + skipped=False, + retry_errors=[], + ) + log_step_end( + _logger, step_index, step.display_name, outputs, redaction=self._redaction_policy + ) + return record_fn( + inputs=inputs, + outputs=outputs, + error=None, + success=True, + skipped=False, + retry_errors=[], + ) + + if self._local.dry_run_unsupported == "abort": + err = FlowExecutionError( + step.display_name, + step_index, + "dry-run abort: side-effecting tool declares no dry_run_fn", + ) + log_step_error(_logger, step_index, step.display_name, err) + return record_fn( + inputs=inputs, + outputs=None, + error=err, + success=False, + skipped=False, + retry_errors=[], + ) + + # "skip": stub the side-effecting step so the rehearsal continues + # without merging any output into the context. 
+ return record_fn( + inputs=inputs, + outputs={}, + error=None, + success=True, + skipped=True, + retry_errors=[], + ) + def _invoke_tool( self, tool: Tool, diff --git a/chainweaver/mcp/__init__.py b/chainweaver/mcp/__init__.py index 64f28f0..eaefeb7 100644 --- a/chainweaver/mcp/__init__.py +++ b/chainweaver/mcp/__init__.py @@ -27,13 +27,25 @@ from __future__ import annotations from chainweaver.mcp._schema import jsonschema_to_pydantic, pydantic_to_jsonschema -from chainweaver.mcp.adapter import MCPToolAdapter +from chainweaver.mcp.adapter import ( + AnnotationTrust, + DriftPolicy, + MCPToolAdapter, + MetadataPolicy, + build_pin_file, + load_pins, +) from chainweaver.mcp.server import FlowServer, TransportName __all__ = [ + "AnnotationTrust", + "DriftPolicy", "FlowServer", "MCPToolAdapter", + "MetadataPolicy", "TransportName", + "build_pin_file", "jsonschema_to_pydantic", + "load_pins", "pydantic_to_jsonschema", ] diff --git a/chainweaver/mcp/adapter.py b/chainweaver/mcp/adapter.py index b1ae03c..8b73c0d 100644 --- a/chainweaver/mcp/adapter.py +++ b/chainweaver/mcp/adapter.py @@ -36,17 +36,33 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any - -from pydantic import BaseModel - -from chainweaver.exceptions import MCPToolInvocationError +import logging +import re +import unicodedata +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import BaseModel, ConfigDict + +from chainweaver.compat import schema_dict_fingerprint +from chainweaver.contracts import ( + DeterminismLevel, + SideEffectLevel, + StabilityLevel, + ToolSafetyContract, +) +from chainweaver.exceptions import ( + MCPMetadataError, + MCPSchemaDriftError, + MCPToolInvocationError, +) from chainweaver.mcp._schema import jsonschema_to_pydantic from chainweaver.tools import Tool try: # Optional dependency. 
from mcp import ClientSession - from mcp.types import CallToolResult, TextContent + from mcp.types import CallToolResult, TextContent, ToolAnnotations from mcp.types import Tool as MCPRemoteTool except ImportError as exc: # pragma: no cover — depends on install layout raise ImportError( @@ -58,6 +74,23 @@ from collections.abc import Iterable, Mapping +_logger = logging.getLogger("chainweaver.mcp.adapter") + +AnnotationTrust = Literal["trust", "ignore", "cap"] +"""Trust policy for server-declared :class:`ToolAnnotations` (issue #371). + +* ``"trust"`` — map declared annotations onto a :class:`ToolSafetyContract`; + tools with **no** annotations are left ``safety=None`` (nothing to trust). +* ``"ignore"`` — never derive safety from annotations (``safety=None`` always). +* ``"cap"`` (default) — like ``"trust"`` but conservative: a tool with no + annotations still gets an ``EXTERNAL`` contract, and a declared read-only tool + gets ``READ`` (never ``NONE``) since a remote call still observes the world. +""" + +DriftPolicy = Literal["error", "warn", "accept"] +"""How :class:`MCPToolAdapter` reacts to a pinned schema changing (issue #358).""" + + DEFAULT_SERVER_PREFIX_SEP = "__" """Default separator between server prefix and the MCP tool's own name. @@ -68,6 +101,199 @@ """ +_DEFAULT_NAME_PATTERN = r"^[A-Za-z0-9._-]+$" + + +class MetadataPolicy(BaseModel): + """Trust policy for server-provided MCP tool names and descriptions (issue #359). + + Tool descriptions and names wrapped from an MCP server travel further than they + first appear — they become ChainWeaver :attr:`Tool.description` / :attr:`Tool.name` + values, can be re-exported to LLM clients via :class:`~chainweaver.mcp.FlowServer` + or the ``export`` adapters, and may be rendered into the offline proposer prompts. 
+ This policy treats that metadata as untrusted input: conservative defaults strip + control characters, normalise whitespace, cap description length, and validate the + tool name, while leaving the verbatim-adoption escape hatch explicit. + + Attributes: + max_description_length: Cap on the adopted description length; longer + descriptions are truncated with a visible marker. ``None`` disables the + cap. Defaults to ``2000``. + strip_control_chars: Remove C0/C1 control characters (other than ``\\n`` / + ``\\t``) from descriptions. Defaults to ``True``. + normalize_whitespace: Collapse runs of whitespace to single spaces and strip + the ends. Defaults to ``True``. + description_mode: ``"server"`` adopts the (sanitised) server description; + ``"placeholder"`` replaces it with a neutral generated description, making + verbatim adoption of remote text an explicit opt-in. Defaults to + ``"server"``. + name_pattern: Regex a (prefixed) tool name must fully match. Defaults to + ``^[A-Za-z0-9._-]+$``. + on_invalid_name: ``"error"`` rejects a non-matching name with + :class:`~chainweaver.exceptions.MCPMetadataError`; ``"sanitize"`` replaces + each disallowed character with ``_``. Defaults to ``"error"``. + """ + + model_config = ConfigDict(frozen=True) + + max_description_length: int | None = 2000 + strip_control_chars: bool = True + normalize_whitespace: bool = True + description_mode: Literal["server", "placeholder"] = "server" + name_pattern: str = _DEFAULT_NAME_PATTERN + on_invalid_name: Literal["error", "sanitize"] = "error" + + @classmethod + def permissive(cls) -> MetadataPolicy: + """Return a policy that restores the pre-#359 verbatim behaviour. + + No length cap, no control-char stripping, no whitespace normalisation, the + server description adopted as-is, and any tool name accepted unchanged. Use + this to opt out of hardening for a fully trusted server. 
+ """ + return cls( + max_description_length=None, + strip_control_chars=False, + normalize_whitespace=False, + description_mode="server", + name_pattern=r"(?s).*", + ) + + def apply_name(self, name: str) -> str: + """Validate (or sanitise) a (prefixed) tool *name* against the policy. + + Sanitisation maps any character outside the default safe set + (``[A-Za-z0-9._-]``) to ``_``; the result always satisfies the default + :attr:`name_pattern`. When a *custom* ``name_pattern`` is configured the + sanitised name is re-validated against it and rejected with + :class:`~chainweaver.exceptions.MCPMetadataError` if it still does not + match, rather than returning a name that silently violates the policy. + """ + if re.fullmatch(self.name_pattern, name): + return name + if self.on_invalid_name == "sanitize": + sanitised = re.sub(r"[^A-Za-z0-9._-]", "_", name) + # Guard the degenerate all-invalid case so we never return an empty name. + sanitised = sanitised or "mcp_tool" + # The sanitisation charset matches the *default* pattern; a custom, + # stricter pattern may still reject it, so re-validate and fail loudly + # instead of adopting a name that violates the configured policy. + if not re.fullmatch(self.name_pattern, sanitised): + raise MCPMetadataError( + name, + f"sanitised name {sanitised!r} still does not match the configured " + f"name_pattern {self.name_pattern!r}", + ) + return sanitised + raise MCPMetadataError( + name, + f"name does not match required pattern {self.name_pattern!r}; " + "set MetadataPolicy(on_invalid_name='sanitize') to coerce it", + ) + + def apply_description(self, raw: str | None, *, cw_name: str, server: str | None) -> str: + """Return the description to adopt for a tool, per the policy.""" + if self.description_mode == "placeholder": + origin = f" from server '{server}'" if server else "" + return f"MCP tool '{cw_name}'{origin}." + text = raw if raw is not None else f"MCP tool '{cw_name}'." 
+ if self.strip_control_chars: + text = "".join(ch for ch in text if ch in "\n\t" or unicodedata.category(ch)[0] != "C") + if self.normalize_whitespace: + text = " ".join(text.split()) + if self.max_description_length is not None and len(text) > self.max_description_length: + text = text[: self.max_description_length] + "…(truncated)" + # An all-control-character description can normalise to empty; keep a stable, + # non-blank value so downstream catalogues never render a nameless entry. + return text or f"MCP tool '{cw_name}'." + + +def _remote_contract(side_effects: SideEffectLevel, *, idempotent: bool) -> ToolSafetyContract: + """Build a conservative :class:`ToolSafetyContract` for a remote MCP tool. + + Remote tools are never cached by default and their determinism cannot be + attested from self-declared hints, so ``determinism_level`` is pinned to + ``NONE`` and ``stability`` to ``BEST_EFFORT`` regardless of the annotation. + ``read_only`` is derived from *side_effects* to satisfy the contract validator. + """ + return ToolSafetyContract( + side_effects=side_effects, + stability=StabilityLevel.BEST_EFFORT, + determinism_level=DeterminismLevel.NONE, + idempotent=idempotent, + cacheable=False, + safe_to_retry=idempotent, + supports_dry_run=False, + ) + + +def _safety_from_annotations( + annotations: ToolAnnotations | None, + trust: AnnotationTrust, +) -> ToolSafetyContract | None: + """Map MCP :class:`ToolAnnotations` onto a :class:`ToolSafetyContract` (issue #371). + + Conservative by construction: ``readOnlyHint`` maps to ``READ`` (a remote call + still observes the world, so never ``NONE``), ``destructiveHint`` to + ``DESTRUCTIVE``, and an absent/ambiguous annotation to ``EXTERNAL``. Returns + ``None`` (unknown) when *trust* is ``"ignore"``, or when *trust* is ``"trust"`` + and the tool declares no annotations at all. 
+ """ + if trust == "ignore": + return None + if annotations is None: + if trust == "cap": + return _remote_contract(SideEffectLevel.EXTERNAL, idempotent=False) + return None + if annotations.destructiveHint: + side = SideEffectLevel.DESTRUCTIVE + elif annotations.readOnlyHint: + side = SideEffectLevel.READ + else: + side = SideEffectLevel.EXTERNAL + if annotations.idempotentHint is not None: + idempotent = bool(annotations.idempotentHint) + else: + idempotent = side in {SideEffectLevel.NONE, SideEffectLevel.READ} + return _remote_contract(side, idempotent=idempotent) + + +def build_pin_file(tools: Iterable[Tool], *, server: str) -> dict[str, Any]: + """Build a pin-file mapping from discovered *tools* (issue #358). + + The returned structure records the server identity, a UTC timestamp, and each + tool's pinned raw-schema fingerprint (read from ``tool.metadata['mcp_schema_hash']``, + populated by :meth:`MCPToolAdapter.discover_tools`). Serialise it with + :func:`json.dump` to produce a ``.chainweaver/mcp-pins.json`` lockfile, then pass + it back via ``discover_tools(pins=...)`` on later runs to detect drift. + + Args: + tools: Tools previously returned by :meth:`MCPToolAdapter.discover_tools`. + server: Identifier recorded for the MCP server the tools came from. + + Returns: + A JSON-serialisable pin mapping keyed by the tools' server-side names. 
+ """ + pinned: dict[str, str] = {} + for tool in tools: + remote_name = tool.metadata.get("mcp_remote_name", tool.name) + schema_hash = tool.metadata.get("mcp_schema_hash") + if schema_hash is not None: + pinned[remote_name] = schema_hash + return { + "server": server, + "pinned_at": datetime.now(timezone.utc).isoformat(), + "tools": pinned, + } + + +def load_pins(pins_path: str | Path) -> dict[str, str]: + """Load the ``tools`` fingerprint mapping from a pin file (issue #358).""" + data = json.loads(Path(pins_path).read_text(encoding="utf-8")) + tools = data.get("tools", {}) if isinstance(data, dict) else {} + return {str(name): str(value) for name, value in tools.items()} + + class _MCPToolOutput(BaseModel): """Permissive output schema used when the MCP tool has no ``outputSchema``. @@ -102,6 +328,22 @@ class MCPToolAdapter: applied to every discovered tool. Per-tool overrides are available by mutating ``tool.timeout_seconds`` after discovery. + annotation_trust: How to map server-declared :class:`ToolAnnotations` + onto each wrapped tool's :class:`ToolSafetyContract` (issue #371). + ``"cap"`` (the default) derives a conservative contract for every + tool; ``"trust"`` only for annotated tools; ``"ignore"`` never. + metadata_policy: Trust policy for server-provided tool names and + descriptions (issue #359). ``None`` (the default) applies the + conservative :class:`MetadataPolicy` defaults; pass + ``MetadataPolicy.permissive()`` to opt out. + on_drift: How to react when a discovered tool's raw schema no longer + matches a supplied pin (issue #358): ``"error"`` (the default) + raises :class:`~chainweaver.exceptions.MCPSchemaDriftError`, + ``"warn"`` logs and continues, ``"accept"`` silently adopts the new + schema. Only consulted when ``discover_tools`` is given pins. + server_name: Optional identifier for the MCP server, recorded on each + tool's metadata and used by ``description_mode="placeholder"`` and + :func:`build_pin_file`. 
Example:: @@ -123,9 +365,26 @@ def __init__( session: ClientSession, *, timeout_seconds: float | None = None, + annotation_trust: AnnotationTrust = "cap", + metadata_policy: MetadataPolicy | None = None, + on_drift: DriftPolicy = "error", + server_name: str | None = None, ) -> None: + # Validate the policy literals at construction so a typo (e.g. + # ``on_drift="erorr"``) fails loudly instead of silently falling through + # to "accept" and disabling drift protection on a security surface. + if annotation_trust not in ("trust", "ignore", "cap"): + raise ValueError( + f"annotation_trust must be 'trust', 'ignore', or 'cap', got {annotation_trust!r}." + ) + if on_drift not in ("error", "warn", "accept"): + raise ValueError(f"on_drift must be 'error', 'warn', or 'accept', got {on_drift!r}.") self.session = session self.timeout_seconds = timeout_seconds + self.annotation_trust: AnnotationTrust = annotation_trust + self.metadata_policy = metadata_policy if metadata_policy is not None else MetadataPolicy() + self.on_drift: DriftPolicy = on_drift + self.server_name = server_name async def discover_tools( self, @@ -135,6 +394,8 @@ async def discover_tools( include: Iterable[str] | None = None, exclude: Iterable[str] | None = None, schema_overrides: Mapping[str, type[BaseModel]] | None = None, + pins: Mapping[str, str] | None = None, + pins_path: str | Path | None = None, ) -> list[Tool]: """List the MCP server's tools and project each into a ChainWeaver Tool. @@ -159,6 +420,14 @@ async def discover_tools( is insufficient (e.g. the server advertises a loose schema you want to tighten). Keyed by the MCP tool's own name, not the (optionally prefixed) ChainWeaver name. + pins: Optional mapping of MCP-side tool name to a pinned raw-schema + fingerprint (issue #358). When supplied, a discovered tool whose + schema fingerprint differs from its pin is handled per the + adapter's ``on_drift`` policy. Tools absent from the mapping are + not drift-checked. 
+ pins_path: Optional path to a JSON pin file (as written by + :func:`build_pin_file`); its ``tools`` mapping is merged under any + explicit *pins* (explicit entries win on conflict). Returns: A list of :class:`Tool` instances ready for @@ -167,12 +436,22 @@ async def discover_tools( Raises: MCPSchemaConversionError: When a tool's ``inputSchema`` is structurally invalid. + MCPMetadataError: When a tool name fails the metadata policy and + ``on_invalid_name="error"`` (issue #359). + MCPSchemaDriftError: When a pinned tool's schema changed and + ``on_drift="error"`` (issue #358). """ result = await self.session.list_tools() wanted: set[str] | None = set(include) if include is not None else None unwanted: set[str] = set(exclude) if exclude is not None else set() overrides: Mapping[str, type[BaseModel]] = schema_overrides or {} + resolved_pins: dict[str, str] = {} + if pins_path is not None: + resolved_pins.update(load_pins(pins_path)) + if pins is not None: + resolved_pins.update(pins) + tools: list[Tool] = [] for mcp_tool in result.tools: if wanted is not None and mcp_tool.name not in wanted: @@ -185,6 +464,7 @@ async def discover_tools( server_prefix=server_prefix, prefix_separator=prefix_separator, input_override=overrides.get(mcp_tool.name), + pin=resolved_pins.get(mcp_tool.name), ) ) return tools @@ -196,12 +476,16 @@ def _build_tool( server_prefix: str, prefix_separator: str, input_override: type[BaseModel] | None = None, + pin: str | None = None, ) -> Tool: """Project a single MCP tool descriptor into a ChainWeaver ``Tool``.""" if server_prefix: - cw_name = f"{server_prefix}{prefix_separator}{mcp_tool.name}" + raw_name = f"{server_prefix}{prefix_separator}{mcp_tool.name}" else: - cw_name = mcp_tool.name + raw_name = mcp_tool.name + # Validate / sanitise the server-provided name before it becomes a + # ChainWeaver tool identifier (issue #359). 
+ cw_name = self.metadata_policy.apply_name(raw_name) if input_override is not None: input_schema: type[BaseModel] = input_override @@ -223,6 +507,26 @@ def _build_tool( output_schema = _MCPToolOutput project_result = _project_unstructured_output + # Fingerprint the *raw* JSON Schema(s) the server advertised, before the + # Pydantic projection, and verify it against any supplied pin (issue #358). + schema_hash = schema_dict_fingerprint( + {"input": mcp_tool.inputSchema, "output": mcp_tool.outputSchema} + ) + if pin is not None and pin != schema_hash: + if self.on_drift == "error": + raise MCPSchemaDriftError(mcp_tool.name, pin, schema_hash) + if self.on_drift == "warn": + _logger.warning( + "MCP tool '%s' schema drifted: pinned '%s', discovered '%s'.", + mcp_tool.name, + pin, + schema_hash, + ) + # "accept" (and the warn path) fall through and adopt the new schema. + + # Derive a conservative safety contract from server annotations (issue #371). + safety = _safety_from_annotations(mcp_tool.annotations, self.annotation_trust) + session = self.session remote_name = mcp_tool.name @@ -237,20 +541,39 @@ async def fn(validated_input: BaseModel) -> dict[str, Any]: raise MCPToolInvocationError(cw_name, str(exc)) from exc return project_result(call_result, cw_name) - return Tool( - name=cw_name, - description=(mcp_tool.description or f"MCP tool '{remote_name}'."), - input_schema=input_schema, - output_schema=output_schema, - fn=fn, - timeout_seconds=self.timeout_seconds, - # MCP tools may have side effects on the remote server; - # opt out of the in-process step cache by default so each - # invocation actually hits the server. Callers can flip - # this on a per-tool basis after discovery for tools they - # know to be pure. 
- cacheable=False, + description = self.metadata_policy.apply_description( + mcp_tool.description, cw_name=cw_name, server=self.server_name ) + metadata: dict[str, Any] = { + "mcp_remote_name": remote_name, + "mcp_schema_hash": schema_hash, + "mcp_annotation_source": "server" if mcp_tool.annotations is not None else "absent", + "mcp_annotation_trust": self.annotation_trust, + } + if self.server_name is not None: + metadata["mcp_server"] = self.server_name + # Preserve the raw server description for audit even when it was replaced + # or sanitised, so nothing is lost (issue #359). + if mcp_tool.description is not None: + metadata["mcp_raw_description"] = mcp_tool.description + + # MCP tools may have side effects on the remote server; opt out of the + # in-process step cache by default so each invocation actually hits the + # server. When a safety contract is derived it already declares + # ``cacheable=False``; only pass the explicit flag when no contract is + # derived (``safety=None``), to avoid the Tool's conflict guard. + common_kwargs: dict[str, Any] = { + "name": cw_name, + "description": description, + "input_schema": input_schema, + "output_schema": output_schema, + "fn": fn, + "timeout_seconds": self.timeout_seconds, + "metadata": metadata, + } + if safety is not None: + return Tool(safety=safety, **common_kwargs) + return Tool(cacheable=False, **common_kwargs) def _join_text_content(content: list[Any]) -> str: diff --git a/chainweaver/tools.py b/chainweaver/tools.py index c65a9ae..3bb7b29 100644 --- a/chainweaver/tools.py +++ b/chainweaver/tools.py @@ -98,6 +98,17 @@ class Tool: as the UTF-8 byte length of its JSON serialization. When set and exceeded, :class:`~chainweaver.exceptions.ToolOutputSizeError` is raised. ``None`` (the default) disables the size check. + metadata: Optional free-form provenance/annotation metadata (issues + #358, #359, #371). Audit information without a first-class field — + e.g. 
an MCP tool's raw server-provided description, the source of a + derived safety contract, or a pinned remote schema fingerprint. + Never consumed by the executor. Stored as a (shallow-copied) dict; + defaults to ``{}``. + dry_run_fn: Optional **synchronous** effect-free preview callable (issue + #357) with the same ``(validated_input) -> dict`` shape as a sync + ``fn``. Required when ``safety.supports_dry_run=True`` (validated at + construction). Invoked by ``execute_flow(dry_run=True)`` in place of + ``fn`` for side-effecting tools. Example:: @@ -137,6 +148,8 @@ def __init__( schema_version: str = "0.0.0", cacheable: bool | None = None, safety: ToolSafetyContract | None = None, + metadata: dict[str, Any] | None = None, + dry_run_fn: Callable[[Any], dict[str, Any]] | None = None, ) -> None: self.name = name self.description = description @@ -176,6 +189,27 @@ def __init__( ) self.cacheable = safety.cacheable self.safety = safety + # Free-form provenance / annotation metadata (issues #358, #359, #371). + # Carries audit information that has no first-class field — e.g. the raw + # server-provided description an MCP tool was sanitised from, the source + # of a derived safety contract ("server" vs "author"), or the pinned + # remote schema fingerprint. Never consumed by the executor; downstream + # reviewers and the drift workflow read it. Defaults to an empty dict so + # callers can always ``tool.metadata.get(...)`` without a None guard. + self.metadata: dict[str, Any] = dict(metadata) if metadata else {} + # Effect-free preview callable (issue #357). Takes the validated input + # model like ``fn`` and returns a ``dict``, but — unlike ``fn`` — must be + # a *synchronous* callable (dry-run is a synchronous ``execute_flow`` + # feature) and must perform no side effects. + # ``FlowExecutor.execute_flow(dry_run=True)`` calls this instead of ``fn`` + # for side-effecting tools that declare it. A tool whose contract sets + # ``supports_dry_run=True`` MUST supply one. 
+ self.dry_run_fn = dry_run_fn + if self.safety.supports_dry_run and dry_run_fn is None: + raise ToolDefinitionError( + name, + "safety.supports_dry_run=True requires a dry_run_fn to be supplied.", + ) # Whether ``fn`` is a coroutine function — pre-computed once # because ``inspect.iscoroutinefunction`` doesn't recognise # callables whose ``__call__`` is async, so we also inspect the @@ -253,6 +287,37 @@ async def run_async(self, raw_inputs: dict[str, Any]) -> dict[str, Any]: raw_output = await self._call_fn_async(validated_input) return self._validate_output(raw_output) + @property + def supports_dry_run(self) -> bool: + """Whether this tool can be previewed effect-free (issue #357). + + Mirrors :attr:`ToolSafetyContract.supports_dry_run`; a ``dry_run_fn`` is + required whenever this is ``True`` (enforced at construction). + """ + return self.safety.supports_dry_run + + def run_dry(self, raw_inputs: dict[str, Any]) -> dict[str, Any]: + """Validate *raw_inputs*, run the effect-free ``dry_run_fn``, validate output. + + The dry-run counterpart to :meth:`run` (issue #357): applies the same + input/output schema validation and size cap, but dispatches to + ``dry_run_fn`` so no side effects occur. Used by + :meth:`FlowExecutor.execute_flow` under ``dry_run=True``. + + Raises: + ToolDefinitionError: When the tool has no ``dry_run_fn``. + pydantic.ValidationError: When inputs or the preview output do not + match the declared schemas. + ToolOutputSizeError: When ``max_output_size`` is exceeded. + """ + if self.dry_run_fn is None: + raise ToolDefinitionError( + self.name, "run_dry() called but the tool has no dry_run_fn." 
+ ) + validated_input = self.input_schema.model_validate(raw_inputs) + raw_output = self.dry_run_fn(validated_input) + return self._validate_output(raw_output) + def _validate_output(self, raw_output: dict[str, Any]) -> dict[str, Any]: """Apply size cap + schema validation; shared by ``run`` / ``run_async``.""" if self.max_output_size is not None: diff --git a/docs/security.md b/docs/security.md index 20cc65c..2f029ec 100644 --- a/docs/security.md +++ b/docs/security.md @@ -75,6 +75,104 @@ applied recursively to nested dicts and lists. --- +## Execution-time safety enforcement (#356) + +By default a `ToolSafetyContract` is **advisory** — the executor records it but +does not act on it. Three opt-in `FlowExecutor` controls make it actionable for +hosts that expose flows to LLM clients, all behaviour-preserving when unset: + +```python +from chainweaver import FlowExecutor, ApprovalContext, ApprovalDecision, SideEffectLevel + +def approver(ctx: ApprovalContext) -> ApprovalDecision: + # ctx carries trace_id, flow_name, step_index, tool_name, redacted inputs, + # and the effective ToolSafetyContract. + return ApprovalDecision.APPROVE if ctx.tool_name in TRUSTED else ApprovalDecision.DENY + +executor = FlowExecutor( + registry, + approval_callback=approver, # gate requires_approval=True steps + strict_safety=True, # refuse such steps if no callback + max_side_effect_level=SideEffectLevel.WRITE, # refuse DESTRUCTIVE/EXTERNAL-over-ceiling +) +``` + +* A step whose **effective contract** has `requires_approval=True` invokes the + callback *before* the tool runs. `DENY`, a callback exception, or an invalid + return aborts the step with `ApprovalDeniedError` and a failed `StepRecord`; + the decision is recorded on `StepRecord.approval`. +* With **no** callback registered, the default is unchanged (the step runs); + `strict_safety=True` instead refuses approval-requiring steps. 
+* `max_side_effect_level` refuses, with `SafetyCeilingError`, any step whose + `side_effects` exceeds the ceiling. +* The callback is a **user-supplied seam** — the executor never performs I/O + itself, so the no-LLM / no-network / no-randomness invariants are preserved + (the same model as `decision_callback`). Enforcement applies on both the sync + and async lanes. + +## Dry-run rehearsals (#357) + +`execute_flow(dry_run=True)` runs a side-effect-free rehearsal that validates +wiring and data shapes against real systems without committing side effects: + +```python +deploy = Tool(name="deploy", fn=do_deploy, dry_run_fn=plan_deploy, + safety=ToolSafetyContract(side_effects=SideEffectLevel.EXTERNAL, + supports_dry_run=True)) +result = executor.execute_flow("release", inputs, dry_run=True) +assert result.dry_run is True +``` + +* Read-only steps (`side_effects` in `NONE`/`READ`) run normally; tools that + declare a `dry_run_fn` (and `supports_dry_run=True`) run it; other + side-effecting steps are **skipped** (stubbed) by default, or fail the step + under `dry_run_unsupported="abort"` for a high-fidelity rehearsal. +* The step cache and checkpointer are **bypassed** so a rehearsal never reads or + writes real state, and `ExecutionResult.dry_run` is set so a dry-run trace can + never be confused with a real run. Composed sub-flows inherit the mode. + +## Trusting MCP-imported tool metadata (#358, #359, #371) + +Tools wrapped from a remote MCP server arrive as **untrusted input**: their +names, descriptions, schemas, and annotations are server-declared and travel on +into `Tool` objects, re-exports, and proposer prompts. 
`MCPToolAdapter` applies +conservative defaults: + +```python +from chainweaver.mcp import MCPToolAdapter, MetadataPolicy + +adapter = MCPToolAdapter( + session, + annotation_trust="cap", # derive a conservative ToolSafetyContract (#371) + metadata_policy=MetadataPolicy(),# sanitise names/descriptions (#359) + on_drift="error", # reject changed pinned schemas (#358) + server_name="search-tools", +) +tools = await adapter.discover_tools(pins_path=".chainweaver/mcp-pins.json") +``` + +* **Annotations → contract (#371):** `readOnlyHint → READ` (never `NONE` — a + remote call still observes the world), `destructiveHint → DESTRUCTIVE`, + unannotated → `EXTERNAL`; remote `determinism_level` is always `NONE`. + `annotation_trust` is `"cap"` (conservative, default), `"trust"` (declared + only), or `"ignore"`. The contract source is recorded on `tool.metadata`. +* **Metadata policy (#359):** control characters stripped, whitespace + normalised, descriptions length-capped, names validated against + `^[A-Za-z0-9._-]+$`; `description_mode="placeholder"` drops remote text + entirely. The raw server description is preserved on + `tool.metadata["mcp_raw_description"]` for audit. `MetadataPolicy.permissive()` + restores the pre-hardening verbatim behaviour for a fully trusted server. +* **Schema pinning (#358):** each tool's raw JSON Schema is fingerprinted at + discovery (`tool.metadata["mcp_schema_hash"]`); supply `pins` / `pins_path` + (write one with `build_pin_file`) and a changed schema is handled per + `on_drift` (`"error"` / `"warn"` / `"accept"`). + +> These controls verify **declared** metadata and schemas, not remote +> *behaviour*. Keep human review in your promotion workflow; a server can +> still change what a tool *does* without changing its schema. + +--- + ## Recommendations for production 1. 
**Always configure a `RedactionPolicy`** for flows whose tools handle diff --git a/tests/fixtures/public_api.json b/tests/fixtures/public_api.json index 24f3344..aab934c 100644 --- a/tests/fixtures/public_api.json +++ b/tests/fixtures/public_api.json @@ -2,12 +2,19 @@ "__all__": [ "AgentTraceEvent", "AgentTraceImportError", + "ApprovalCallable", + "ApprovalCallback", + "ApprovalContext", + "ApprovalDecision", + "ApprovalDeniedError", + "ApprovalRecord", "AsyncLaneUnsupportedError", "AttestationInputError", "AttestationReport", "BUILTIN_PROPERTIES", "BacktestMismatch", "BacktestReport", + "BaseApprovalCallback", "BaseDecisionCallback", "BaseMiddleware", "CancellationToken", @@ -86,7 +93,9 @@ "LessonEvidenceStep", "LessonReview", "MCPError", + "MCPMetadataError", "MCPSchemaConversionError", + "MCPSchemaDriftError", "MCPToolInvocationError", "ObservedStep", "ObservedTrace", @@ -103,6 +112,7 @@ "ReplayMode", "ReplayResult", "RetryPolicy", + "SafetyCeilingError", "SafetyLevel", "SchemaValidationError", "ServiceConfig", @@ -135,6 +145,7 @@ "check_flow_compatibility", "classify_safety", "cli", + "coerce_approval_callback", "coerce_decision_callback", "compile_flow", "discover_flows", @@ -163,6 +174,7 @@ "result_to_mermaid", "schema_fingerprint", "score_candidate", + "side_effect_exceeds", "suggest_optimizations", "tool", "trace_to_lesson_candidate", @@ -200,6 +212,51 @@ "qualname": "AgentTraceImportError", "signature": "(detail: str, *, source: str | None = None, line: int | None = None) -> None" }, + "ApprovalCallable": { + "kind": "_CallableGenericAlias", + "module": "collections.abc", + "qualname": "Callable" + }, + "ApprovalCallback": { + "kind": "class", + "module": "chainweaver.approvals", + "qualname": "ApprovalCallback", + "signature": "(*args, **kwargs)" + }, + "ApprovalContext": { + "kind": "pydantic-model", + "model_fields": { + "flow_name": "str", + "inputs": "dict[str, Any]", + "safety": "chainweaver.contracts.ToolSafetyContract", + "step_id": "str | 
NoneType", + "step_index": "int", + "tool_name": "str", + "trace_id": "str" + }, + "module": "chainweaver.approvals", + "qualname": "ApprovalContext" + }, + "ApprovalDecision": { + "kind": "enum", + "module": "chainweaver.approvals", + "qualname": "ApprovalDecision" + }, + "ApprovalDeniedError": { + "kind": "class", + "module": "chainweaver.exceptions", + "qualname": "ApprovalDeniedError", + "signature": "(tool_name: str, step_index: int, detail: str) -> None" + }, + "ApprovalRecord": { + "kind": "pydantic-model", + "model_fields": { + "decision": "chainweaver.approvals.ApprovalDecision", + "reason": "str | NoneType" + }, + "module": "chainweaver.approvals", + "qualname": "ApprovalRecord" + }, "AsyncLaneUnsupportedError": { "kind": "class", "module": "chainweaver.exceptions", @@ -262,6 +319,12 @@ "module": "chainweaver.traces", "qualname": "BacktestReport" }, + "BaseApprovalCallback": { + "kind": "class", + "module": "chainweaver.approvals", + "qualname": "BaseApprovalCallback", + "signature": "()" + }, "BaseDecisionCallback": { "kind": "class", "module": "chainweaver.decisions", @@ -543,6 +606,7 @@ "kind": "pydantic-model", "model_fields": { "cost_report": "chainweaver.cost.CostReport | NoneType", + "dry_run": "bool", "ended_at": "datetime.datetime", "execution_log": "list[chainweaver.executor.StepRecord]", "final_output": "dict[str, Any] | NoneType", @@ -695,7 +759,7 @@ "kind": "class", "module": "chainweaver.executor", "qualname": "FlowExecutor", - "signature": "(registry: FlowRegistry, *, cost_profile: CostProfile | None = None, redaction_policy: RedactionPolicy | None = None, trace_recorder: TraceRecorder | None = None, middleware: list[FlowExecutorMiddleware] | None = None, step_cache: StepCache | None = None, checkpointer: Checkpointer | None = None, delete_on_success: bool = True, decision_callback: DecisionCallback | DecisionCallable | None = None, discover_plugins: bool = False, max_composition_depth: int = 10, max_step_concurrency: int = 1) -> None" + 
"signature": "(registry: FlowRegistry, *, cost_profile: CostProfile | None = None, redaction_policy: RedactionPolicy | None = None, trace_recorder: TraceRecorder | None = None, middleware: list[FlowExecutorMiddleware] | None = None, step_cache: StepCache | None = None, checkpointer: Checkpointer | None = None, delete_on_success: bool = True, decision_callback: DecisionCallback | DecisionCallable | None = None, approval_callback: ApprovalCallback | ApprovalCallable | None = None, strict_safety: bool = False, max_side_effect_level: SideEffectLevel | None = None, discover_plugins: bool = False, max_composition_depth: int = 10, max_step_concurrency: int = 1) -> None" }, "FlowExecutorMiddleware": { "kind": "class", @@ -927,12 +991,24 @@ "qualname": "MCPError", "signature": null }, + "MCPMetadataError": { + "kind": "class", + "module": "chainweaver.exceptions", + "qualname": "MCPMetadataError", + "signature": "(tool_name: str, detail: str) -> None" + }, "MCPSchemaConversionError": { "kind": "class", "module": "chainweaver.exceptions", "qualname": "MCPSchemaConversionError", "signature": "(tool_name: str, detail: str) -> None" }, + "MCPSchemaDriftError": { + "kind": "class", + "module": "chainweaver.exceptions", + "qualname": "MCPSchemaDriftError", + "signature": "(tool_name: str, expected: str, actual: str) -> None" + }, "MCPToolInvocationError": { "kind": "class", "module": "chainweaver.exceptions", @@ -1057,6 +1133,12 @@ "module": "chainweaver.flow", "qualname": "RetryPolicy" }, + "SafetyCeilingError": { + "kind": "class", + "module": "chainweaver.exceptions", + "qualname": "SafetyCeilingError", + "signature": "(tool_name: str, step_index: int, level: str, ceiling: str) -> None" + }, "SafetyLevel": { "kind": "enum", "module": "chainweaver.traces", @@ -1182,6 +1264,7 @@ "StepRecord": { "kind": "pydantic-model", "model_fields": { + "approval": "chainweaver.approvals.ApprovalRecord | NoneType", "cached": "bool", "duration_ms": "float", "ended_at": "datetime.datetime", @@ 
-1232,7 +1315,7 @@ "kind": "class", "module": "chainweaver.tools", "qualname": "Tool", - "signature": "(*, name: str, description: str, input_schema: type[BaseModel], output_schema: type[BaseModel], fn: Callable[[Any], dict[str, Any] | Awaitable[dict[str, Any]]], timeout_seconds: float | None = None, max_output_size: int | None = None, schema_version: str = '0.0.0', cacheable: bool | None = None, safety: ToolSafetyContract | None = None) -> None" + "signature": "(*, name: str, description: str, input_schema: type[BaseModel], output_schema: type[BaseModel], fn: Callable[[Any], dict[str, Any] | Awaitable[dict[str, Any]]], timeout_seconds: float | None = None, max_output_size: int | None = None, schema_version: str = '0.0.0', cacheable: bool | None = None, safety: ToolSafetyContract | None = None, metadata: dict[str, Any] | None = None, dry_run_fn: Callable[[Any], dict[str, Any]] | None = None) -> None" }, "ToolChain": { "kind": "GenericAlias", @@ -1332,6 +1415,12 @@ "module": null, "qualname": "chainweaver.cli" }, + "coerce_approval_callback": { + "kind": "function", + "module": "chainweaver.approvals", + "qualname": "coerce_approval_callback", + "signature": "(cb: ApprovalCallback | ApprovalCallable | None) -> ApprovalCallback | None" + }, "coerce_decision_callback": { "kind": "function", "module": "chainweaver.decisions", @@ -1500,6 +1589,12 @@ "qualname": "score_candidate", "signature": "(events: Sequence[AgentTraceEvent], sequence: Sequence[str]) -> CandidateScore" }, + "side_effect_exceeds": { + "kind": "function", + "module": "chainweaver.contracts", + "qualname": "side_effect_exceeds", + "signature": "(level: SideEffectLevel, ceiling: SideEffectLevel) -> bool" + }, "suggest_optimizations": { "kind": "function", "module": "chainweaver.analyzer", diff --git a/tests/test_execution_safety.py b/tests/test_execution_safety.py new file mode 100644 index 0000000..be92139 --- /dev/null +++ b/tests/test_execution_safety.py @@ -0,0 +1,297 @@ +"""Tests for execution-time 
safety enforcement and dry-run (issues #356, #357). + +* **#356** — the approval callback seam, ``strict_safety``, and the + ``max_side_effect_level`` ceiling enforced by :class:`FlowExecutor` at + execution time. +* **#357** — ``execute_flow(dry_run=True)``: read-only steps run, ``dry_run_fn`` + previews run, other side-effecting steps skip/abort, cache/checkpoint bypassed. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest +from pydantic import BaseModel + +from chainweaver import ( + ApprovalContext, + ApprovalDecision, + Flow, + FlowExecutor, + FlowRegistry, + FlowStep, + InMemoryStepCache, + SideEffectLevel, + Tool, + ToolSafetyContract, +) +from chainweaver.exceptions import ToolDefinitionError + + +class _In(BaseModel): + x: int + + +class _Out(BaseModel): + y: int + + +def _make_tool( + name: str, + side_effects: SideEffectLevel, + *, + requires_approval: bool = False, + supports_dry_run: bool = False, + dry_run_fn: Any = None, + counter: list[int] | None = None, +) -> Tool: + def fn(inp: _In) -> dict[str, Any]: + if counter is not None: + counter[0] += 1 + return {"y": inp.x + 1} + + return Tool( + name=name, + description=f"{name} tool.", + input_schema=_In, + output_schema=_Out, + fn=fn, + safety=ToolSafetyContract( + side_effects=side_effects, + requires_approval=requires_approval, + supports_dry_run=supports_dry_run, + ), + dry_run_fn=dry_run_fn, + ) + + +def _single_step_executor(tool: Tool, **executor_kwargs: Any) -> FlowExecutor: + registry = FlowRegistry() + registry.register_flow( + Flow(name="f", description="d", steps=[FlowStep(tool_name=tool.name, input_mapping={})]) + ) + executor = FlowExecutor(registry, **executor_kwargs) + executor.register_tool(tool) + return executor + + +# --------------------------------------------------------------------------- +# #356 — approval enforcement +# --------------------------------------------------------------------------- + + +class 
TestApprovalEnforcement: + def test_approve_runs_step_and_records_decision(self) -> None: + tool = _make_tool("writer", SideEffectLevel.WRITE, requires_approval=True) + executor = _single_step_executor( + tool, approval_callback=lambda ctx: ApprovalDecision.APPROVE + ) + result = executor.execute_flow("f", {"x": 1}) + assert result.success is True + assert result.final_output == {"x": 1, "y": 2} + record = result.execution_log[0] + assert record.approval is not None + assert record.approval.decision is ApprovalDecision.APPROVE + + def test_deny_aborts_step(self) -> None: + tool = _make_tool("writer", SideEffectLevel.WRITE, requires_approval=True) + executor = _single_step_executor(tool, approval_callback=lambda ctx: ApprovalDecision.DENY) + result = executor.execute_flow("f", {"x": 1}) + assert result.success is False + record = result.execution_log[0] + assert record.error_type == "ApprovalDeniedError" + assert record.approval is not None + assert record.approval.decision is ApprovalDecision.DENY + + def test_callback_receives_context(self) -> None: + seen: list[ApprovalContext] = [] + + def approver(ctx: ApprovalContext) -> ApprovalDecision: + seen.append(ctx) + return ApprovalDecision.APPROVE + + tool = _make_tool("writer", SideEffectLevel.WRITE, requires_approval=True) + executor = _single_step_executor(tool, approval_callback=approver) + executor.execute_flow("f", {"x": 5}) + assert len(seen) == 1 + assert seen[0].tool_name == "writer" + assert seen[0].inputs == {"x": 5} + assert seen[0].safety.requires_approval is True + + def test_callback_raises_is_denied(self) -> None: + def boom(ctx: ApprovalContext) -> ApprovalDecision: + raise RuntimeError("approver exploded") + + tool = _make_tool("writer", SideEffectLevel.WRITE, requires_approval=True) + executor = _single_step_executor(tool, approval_callback=boom) + result = executor.execute_flow("f", {"x": 1}) + assert result.success is False + record = result.execution_log[0] + assert record.error_type == 
"ApprovalDeniedError" + # A misbehaving callback is still an approval outcome: recorded as DENY. + assert record.approval is not None + assert record.approval.decision is ApprovalDecision.DENY + assert record.approval.reason is not None + + def test_callback_returns_invalid_is_denied(self) -> None: + tool = _make_tool("writer", SideEffectLevel.WRITE, requires_approval=True) + executor = _single_step_executor(tool, approval_callback=lambda ctx: "yes") + result = executor.execute_flow("f", {"x": 1}) + assert result.success is False + record = result.execution_log[0] + assert record.error_type == "ApprovalDeniedError" + assert record.approval is not None + assert record.approval.decision is ApprovalDecision.DENY + + def test_no_callback_advisory_by_default(self) -> None: + # requires_approval with no callback and no strict_safety: runs (advisory). + tool = _make_tool("writer", SideEffectLevel.WRITE, requires_approval=True) + executor = _single_step_executor(tool) + result = executor.execute_flow("f", {"x": 1}) + assert result.success is True + assert result.execution_log[0].approval is None + + def test_strict_safety_refuses_without_callback(self) -> None: + tool = _make_tool("writer", SideEffectLevel.WRITE, requires_approval=True) + executor = _single_step_executor(tool, strict_safety=True) + result = executor.execute_flow("f", {"x": 1}) + assert result.success is False + record = result.execution_log[0] + assert record.error_type == "ApprovalDeniedError" + # Denial under strict_safety is recorded for audit completeness. 
+ assert record.approval is not None + assert record.approval.decision is ApprovalDecision.DENY + + def test_no_approval_required_ignores_callback(self) -> None: + called: list[int] = [] + + def approver(ctx: ApprovalContext) -> ApprovalDecision: + called.append(1) + return ApprovalDecision.APPROVE + + tool = _make_tool("reader", SideEffectLevel.READ) # requires_approval=False + executor = _single_step_executor(tool, approval_callback=approver) + result = executor.execute_flow("f", {"x": 1}) + assert result.success is True + assert called == [] # callback never consulted + + def test_approval_enforced_on_async_lane(self) -> None: + tool = _make_tool("writer", SideEffectLevel.WRITE, requires_approval=True) + executor = _single_step_executor(tool, approval_callback=lambda ctx: ApprovalDecision.DENY) + result = asyncio.run(executor.execute_flow_async("f", {"x": 1})) + assert result.success is False + assert result.execution_log[0].error_type == "ApprovalDeniedError" + + def test_approval_record_roundtrips(self) -> None: + tool = _make_tool("writer", SideEffectLevel.WRITE, requires_approval=True) + executor = _single_step_executor( + tool, approval_callback=lambda ctx: ApprovalDecision.APPROVE + ) + result = executor.execute_flow("f", {"x": 1}) + from chainweaver import ExecutionResult + + restored = ExecutionResult.model_validate_json(result.model_dump_json()) + assert restored.execution_log[0].approval is not None + assert restored.execution_log[0].approval.decision is ApprovalDecision.APPROVE + + +# --------------------------------------------------------------------------- +# #356 — side-effect ceiling +# --------------------------------------------------------------------------- + + +class TestSideEffectCeiling: + def test_ceiling_refuses_higher_level(self) -> None: + tool = _make_tool("destroyer", SideEffectLevel.DESTRUCTIVE) + executor = _single_step_executor(tool, max_side_effect_level=SideEffectLevel.READ) + result = executor.execute_flow("f", {"x": 1}) + 
assert result.success is False + assert result.execution_log[0].error_type == "SafetyCeilingError" + + def test_ceiling_allows_level_at_or_below(self) -> None: + tool = _make_tool("reader", SideEffectLevel.READ) + executor = _single_step_executor(tool, max_side_effect_level=SideEffectLevel.WRITE) + result = executor.execute_flow("f", {"x": 1}) + assert result.success is True + + +# --------------------------------------------------------------------------- +# #357 — dry-run +# --------------------------------------------------------------------------- + + +class TestDryRun: + def test_construction_requires_dry_run_fn(self) -> None: + with pytest.raises(ToolDefinitionError): + Tool( + name="t", + description="d", + input_schema=_In, + output_schema=_Out, + fn=lambda i: {"y": 1}, + safety=ToolSafetyContract( + side_effects=SideEffectLevel.EXTERNAL, supports_dry_run=True + ), + ) + + def test_read_only_step_runs_in_dry_run(self) -> None: + tool = _make_tool("reader", SideEffectLevel.READ) + executor = _single_step_executor(tool) + result = executor.execute_flow("f", {"x": 1}, dry_run=True) + assert result.dry_run is True + assert result.success is True + assert result.final_output == {"x": 1, "y": 2} + + def test_dry_run_fn_used_for_side_effecting_step(self) -> None: + tool = _make_tool( + "deploy", + SideEffectLevel.EXTERNAL, + supports_dry_run=True, + dry_run_fn=lambda i: {"y": 999}, + ) + executor = _single_step_executor(tool) + result = executor.execute_flow("f", {"x": 1}, dry_run=True) + assert result.dry_run is True + assert result.success is True + assert result.final_output == {"x": 1, "y": 999} + assert result.execution_log[0].skipped is False + + def test_skip_policy_stubs_side_effecting_step(self) -> None: + tool = _make_tool("writer", SideEffectLevel.WRITE) + executor = _single_step_executor(tool) + result = executor.execute_flow("f", {"x": 1}, dry_run=True) + assert result.success is True + record = result.execution_log[0] + assert record.skipped is True 
+ # Skipped step merges nothing — only the initial input remains. + assert result.final_output == {"x": 1} + + def test_abort_policy_fails_side_effecting_step(self) -> None: + tool = _make_tool("writer", SideEffectLevel.WRITE) + executor = _single_step_executor(tool) + result = executor.execute_flow("f", {"x": 1}, dry_run=True, dry_run_unsupported="abort") + assert result.success is False + assert result.execution_log[0].error_type == "FlowExecutionError" + + def test_invalid_unsupported_policy_rejected(self) -> None: + tool = _make_tool("writer", SideEffectLevel.WRITE) + executor = _single_step_executor(tool) + with pytest.raises(ValueError): + executor.execute_flow("f", {"x": 1}, dry_run=True, dry_run_unsupported="nope") + + def test_dry_run_bypasses_cache(self) -> None: + counter = [0] + tool = _make_tool("reader", SideEffectLevel.READ, counter=counter) + executor = _single_step_executor(tool, step_cache=InMemoryStepCache()) + executor.execute_flow("f", {"x": 1}, dry_run=True) + executor.execute_flow("f", {"x": 1}, dry_run=True) + # Each dry run actually invokes the tool; nothing is served from cache. + assert counter[0] == 2 + + def test_normal_run_not_marked_dry(self) -> None: + tool = _make_tool("reader", SideEffectLevel.READ) + executor = _single_step_executor(tool) + result = executor.execute_flow("f", {"x": 1}) + assert result.dry_run is False diff --git a/tests/test_mcp_adapter_trust.py b/tests/test_mcp_adapter_trust.py new file mode 100644 index 0000000..43b2d1a --- /dev/null +++ b/tests/test_mcp_adapter_trust.py @@ -0,0 +1,294 @@ +"""Tests for the MCP adapter trust-boundary hardening (issues #358, #359, #371). + +Covers three adjacent concerns that all live on the MCP import boundary +(``chainweaver/mcp/adapter.py``): + +* **#371** — mapping server-declared ``ToolAnnotations`` onto a conservative + :class:`~chainweaver.contracts.ToolSafetyContract`. 
+* **#359** — the :class:`~chainweaver.mcp.MetadataPolicy` trust controls for + server-provided tool names and descriptions. +* **#358** — raw-schema fingerprint pinning and drift detection. + +The integration tests drive an in-memory FastMCP server using the same +``create_connected_server_and_client_session`` helper as ``test_mcp_adapter.py``. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Any + +import pytest +from mcp.server.fastmcp import FastMCP +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import ToolAnnotations + +from chainweaver.contracts import DeterminismLevel, SideEffectLevel +from chainweaver.exceptions import MCPMetadataError, MCPSchemaDriftError +from chainweaver.mcp import MCPToolAdapter, MetadataPolicy, build_pin_file, load_pins +from chainweaver.mcp.adapter import _safety_from_annotations + + +def _run(coro: Any) -> Any: + """Run *coro* on a fresh asyncio event loop and return its result.""" + return asyncio.run(coro) + + +def _build_demo_server() -> FastMCP: + server = FastMCP(name="demo") + + @server.tool(name="echo", description="Echoes the supplied text.") + def echo(text: str) -> dict: # type: ignore[type-arg] + return {"echoed": text} + + @server.tool(name="add", description="Adds two integers.") + def add(a: int, b: int) -> int: + return a + b + + return server + + +# --------------------------------------------------------------------------- +# #371 — ToolAnnotations → ToolSafetyContract mapping +# --------------------------------------------------------------------------- + + +class TestAdapterValidation: + def test_invalid_annotation_trust_rejected(self) -> None: + with pytest.raises(ValueError): + MCPToolAdapter(None, annotation_trust="bogus") # type: ignore[arg-type] + + def test_invalid_on_drift_rejected(self) -> None: + # A typo must fail loudly, not silently fall through to "accept". 
+ with pytest.raises(ValueError): + MCPToolAdapter(None, on_drift="erorr") # type: ignore[arg-type] + + +class TestAnnotationMapping: + def test_ignore_always_none(self) -> None: + ann = ToolAnnotations(readOnlyHint=True) + assert _safety_from_annotations(ann, "ignore") is None + assert _safety_from_annotations(None, "ignore") is None + + def test_trust_unannotated_is_none(self) -> None: + assert _safety_from_annotations(None, "trust") is None + + def test_cap_unannotated_is_external(self) -> None: + contract = _safety_from_annotations(None, "cap") + assert contract is not None + assert contract.side_effects is SideEffectLevel.EXTERNAL + assert contract.read_only is False + assert contract.determinism_level is DeterminismLevel.NONE + + def test_read_only_maps_to_read_not_none(self) -> None: + # A declared read-only remote tool still observed the world: READ, not NONE. + for trust in ("trust", "cap"): + contract = _safety_from_annotations(ToolAnnotations(readOnlyHint=True), trust) + assert contract is not None + assert contract.side_effects is SideEffectLevel.READ + assert contract.read_only is True + + def test_destructive_maps_to_destructive(self) -> None: + contract = _safety_from_annotations(ToolAnnotations(destructiveHint=True), "cap") + assert contract is not None + assert contract.side_effects is SideEffectLevel.DESTRUCTIVE + assert contract.read_only is False + + def test_destructive_wins_over_read_only(self) -> None: + contract = _safety_from_annotations( + ToolAnnotations(readOnlyHint=True, destructiveHint=True), "cap" + ) + assert contract is not None + assert contract.side_effects is SideEffectLevel.DESTRUCTIVE + + def test_idempotent_hint_propagates(self) -> None: + contract = _safety_from_annotations( + ToolAnnotations(destructiveHint=True, idempotentHint=True), "cap" + ) + assert contract is not None + assert contract.idempotent is True + # Destructive but idempotent → retry is at least not unsafe by idempotency. 
+ assert contract.safe_to_retry is True + + def test_remote_determinism_always_none(self) -> None: + contract = _safety_from_annotations(ToolAnnotations(readOnlyHint=True), "cap") + assert contract is not None + assert contract.determinism_level is DeterminismLevel.NONE + assert contract.cacheable is False + + def test_integration_unannotated_tool_capped_external(self) -> None: + async def go() -> None: + server = _build_demo_server() + async with create_connected_server_and_client_session(server._mcp_server) as session: + await session.initialize() + adapter = MCPToolAdapter(session, annotation_trust="cap") + tools = {t.name: t for t in await adapter.discover_tools()} + # FastMCP does not emit annotations for these tools → capped EXTERNAL. + assert tools["echo"].safety.side_effects is SideEffectLevel.EXTERNAL + assert tools["echo"].metadata["mcp_annotation_source"] == "absent" + + _run(go()) + + def test_integration_ignore_leaves_permissive_default(self) -> None: + async def go() -> None: + server = _build_demo_server() + async with create_connected_server_and_client_session(server._mcp_server) as session: + await session.initialize() + adapter = MCPToolAdapter(session, annotation_trust="ignore") + tools = {t.name: t for t in await adapter.discover_tools()} + # safety=None → Tool falls back to its permissive default contract, + # but remains uncached (the historical adapter behaviour). 
+ assert tools["echo"].cacheable is False + assert tools["echo"].safety_declared is False + + _run(go()) + + +# --------------------------------------------------------------------------- +# #359 — MetadataPolicy +# --------------------------------------------------------------------------- + + +class TestMetadataPolicy: + def test_strips_control_chars(self) -> None: + policy = MetadataPolicy() + out = policy.apply_description("hi\x00\x07there", cw_name="t", server=None) + assert out == "hithere" + + def test_normalizes_whitespace(self) -> None: + policy = MetadataPolicy() + out = policy.apply_description("a b\n\nc", cw_name="t", server=None) + assert out == "a b c" + + def test_truncates_to_cap(self) -> None: + policy = MetadataPolicy(max_description_length=5, normalize_whitespace=False) + out = policy.apply_description("abcdefghij", cw_name="t", server=None) + assert out == "abcde…(truncated)" + + def test_placeholder_mode_ignores_remote_text(self) -> None: + policy = MetadataPolicy(description_mode="placeholder") + out = policy.apply_description("malicious instructions", cw_name="t", server="srv") + assert "malicious" not in out + assert "srv" in out + + def test_empty_after_sanitize_falls_back(self) -> None: + policy = MetadataPolicy() + out = policy.apply_description("\x00\x01\x02", cw_name="tool_x", server=None) + assert out == "MCP tool 'tool_x'." 
+ + def test_invalid_name_errors_by_default(self) -> None: + policy = MetadataPolicy() + with pytest.raises(MCPMetadataError): + policy.apply_name("bad name!") + + def test_invalid_name_sanitized_when_configured(self) -> None: + policy = MetadataPolicy(on_invalid_name="sanitize") + assert policy.apply_name("bad name!") == "bad_name_" + + def test_valid_name_unchanged(self) -> None: + assert MetadataPolicy().apply_name("search__query") == "search__query" + + def test_permissive_restores_verbatim(self) -> None: + policy = MetadataPolicy.permissive() + raw = "x" * 5000 + "\x07" + out = policy.apply_description(raw, cw_name="t", server=None) + assert out == raw # no cap, no stripping + assert policy.apply_name("weird name!!") == "weird name!!" + + def test_integration_raw_description_preserved(self) -> None: + async def go() -> None: + server = _build_demo_server() + async with create_connected_server_and_client_session(server._mcp_server) as session: + await session.initialize() + adapter = MCPToolAdapter( + session, metadata_policy=MetadataPolicy(description_mode="placeholder") + ) + tools = {t.name: t for t in await adapter.discover_tools()} + # Description replaced, but the raw server text is retained for audit. + assert "Echoes" not in tools["echo"].description + assert tools["echo"].metadata["mcp_raw_description"] == "Echoes the supplied text." 
+ + _run(go()) + + +# --------------------------------------------------------------------------- +# #358 — schema-hash pinning + drift +# --------------------------------------------------------------------------- + + +class TestSchemaPinning: + def test_metadata_carries_schema_hash(self) -> None: + async def go() -> None: + server = _build_demo_server() + async with create_connected_server_and_client_session(server._mcp_server) as session: + await session.initialize() + adapter = MCPToolAdapter(session) + tools = await adapter.discover_tools() + for tool in tools: + assert isinstance(tool.metadata["mcp_schema_hash"], str) + assert len(tool.metadata["mcp_schema_hash"]) == 16 + + _run(go()) + + def test_drift_error_raises(self) -> None: + async def go() -> None: + server = _build_demo_server() + async with create_connected_server_and_client_session(server._mcp_server) as session: + await session.initialize() + adapter = MCPToolAdapter(session, on_drift="error") + with pytest.raises(MCPSchemaDriftError) as excinfo: + await adapter.discover_tools(pins={"echo": "0000000000000000"}) + assert excinfo.value.tool_name == "echo" + + _run(go()) + + def test_drift_warn_continues(self, caplog: pytest.LogCaptureFixture) -> None: + async def go() -> list[Any]: + server = _build_demo_server() + async with create_connected_server_and_client_session(server._mcp_server) as session: + await session.initialize() + adapter = MCPToolAdapter(session, on_drift="warn") + with caplog.at_level(logging.WARNING, logger="chainweaver.mcp.adapter"): + return await adapter.discover_tools(pins={"echo": "0000000000000000"}) + + tools = _run(go()) + assert {t.name for t in tools} == {"echo", "add"} + assert any("drifted" in r.message for r in caplog.records) + + def test_matching_pin_no_drift(self) -> None: + async def go() -> None: + server = _build_demo_server() + async with create_connected_server_and_client_session(server._mcp_server) as session: + await session.initialize() + adapter = 
MCPToolAdapter(session, on_drift="error") + first = await adapter.discover_tools() + pins = { + t.metadata["mcp_remote_name"]: t.metadata["mcp_schema_hash"] for t in first + } + # Re-discovering with the captured pins must not raise. + again = await adapter.discover_tools(pins=pins) + assert len(again) == len(first) + + _run(go()) + + def test_pin_file_roundtrip(self, tmp_path: Any) -> None: + async def go() -> None: + server = _build_demo_server() + async with create_connected_server_and_client_session(server._mcp_server) as session: + await session.initialize() + adapter = MCPToolAdapter(session, on_drift="error") + tools = await adapter.discover_tools() + pin_file = tmp_path / "mcp-pins.json" + pin_file.write_text(json.dumps(build_pin_file(tools, server="demo"))) + + loaded = load_pins(pin_file) + assert loaded == { + t.metadata["mcp_remote_name"]: t.metadata["mcp_schema_hash"] for t in tools + } + # pins_path consumed end-to-end with no drift. + again = await adapter.discover_tools(pins_path=pin_file) + assert len(again) == len(tools) + + _run(go())