From 69a8b0c3d067815a88a25fd6781b8426cb12b34a Mon Sep 17 00:00:00 2001 From: wuyangfan Date: Fri, 12 Jun 2026 23:59:09 +0800 Subject: [PATCH] refactor: name flow validation step indexes --- AGENTS.md | 2 +- chainweaver/__init__.py | 4 +++ chainweaver/compiler.py | 5 ++-- chainweaver/executor.py | 47 ++++++++++++++++++---------------- chainweaver/export/callable.py | 5 ++-- chainweaver/mcp/server.py | 3 ++- chainweaver/step_index.py | 19 ++++++++++++++ chainweaver/tools.py | 11 ++++---- tests/fixtures/public_api.json | 13 ++++++++++ tests/test_data_integrity.py | 14 +++++++--- tests/test_flow_execution.py | 7 ++--- tests/test_step_contracts.py | 3 ++- 12 files changed, 92 insertions(+), 41 deletions(-) create mode 100644 chainweaver/step_index.py diff --git a/AGENTS.md b/AGENTS.md index 517446d..601dbb4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -280,7 +280,7 @@ integration. | Field | Type | Meaning | |-------|------|---------| -| `step_index` | `int` | Zero-based position (`-1` = flow-input validation, `len(steps)` = flow-output validation). | +| `step_index` | `int` | Zero-based position (`FLOW_INPUT_STEP_INDEX` = flow-input validation, `flow_output_step_index(flow)` = flow-output/context validation). | | `tool_name` | `str` | Tool invoked (or flow name for validation records). | | `inputs` | `dict` | Validated inputs passed to the tool. | | `outputs` | `dict \| None` | Validated outputs, or `None` on failure. 
| diff --git a/chainweaver/__init__.py b/chainweaver/__init__.py index 9f46935..4ec93da 100644 --- a/chainweaver/__init__.py +++ b/chainweaver/__init__.py @@ -10,6 +10,7 @@ FlowBuilder, FlowRegistry, FlowExecutor, RetryPolicy, ExecutionPlan, ExecutionResult, ReplayMode, ReplayResult, StepDiff, StepPlan, StepRecord, + FLOW_INPUT_STEP_INDEX, flow_output_step_index, RedactionPolicy, TraceRecorder, ObservedStep, ObservedTrace, CostProfile, CostReport, PriceSnap, PROVIDER_PRICES, lookup_price, validate_dag_topology, @@ -191,6 +192,7 @@ ServiceMetrics, ServiceProposal, ) +from chainweaver.step_index import FLOW_INPUT_STEP_INDEX, flow_output_step_index from chainweaver.storage import FileStore, InMemoryStore, RegistryStore from chainweaver.testing.replay import FixtureStaleError from chainweaver.tools import Tool @@ -233,6 +235,7 @@ __all__ = [ "BUILTIN_PROPERTIES", + "FLOW_INPUT_STEP_INDEX", "PROVIDER_PRICES", "AgentTraceEvent", "AgentTraceImportError", @@ -374,6 +377,7 @@ "flow_from_dict", "flow_from_json", "flow_from_yaml", + "flow_output_step_index", "flow_schema_json", "flow_to_ascii", "flow_to_dict", diff --git a/chainweaver/compiler.py b/chainweaver/compiler.py index df3e09e..7083987 100644 --- a/chainweaver/compiler.py +++ b/chainweaver/compiler.py @@ -15,6 +15,7 @@ from pydantic.fields import FieldInfo from chainweaver.flow import Flow +from chainweaver.step_index import flow_output_step_index from chainweaver.tools import Tool # Types considered numeric for widening compatibility. 
@@ -293,7 +294,7 @@ def compile_flow(flow: Flow, tools: dict[str, Tool]) -> CompilationResult: if name not in context_fields: errors.append( CompilationError( - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), tool_name=flow.name, field_name=name, issue_type="output_schema_gap", @@ -309,7 +310,7 @@ def compile_flow(flow: Flow, tools: dict[str, Tool]) -> CompilationResult: if not _types_compatible(actual_type, expected_type): errors.append( CompilationError( - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), tool_name=flow.name, field_name=name, issue_type="output_type_mismatch", diff --git a/chainweaver/executor.py b/chainweaver/executor.py index 88034d2..4701736 100644 --- a/chainweaver/executor.py +++ b/chainweaver/executor.py @@ -83,6 +83,7 @@ ) from chainweaver.observation import TraceRecorder from chainweaver.registry import AnyFlow, FlowRegistry +from chainweaver.step_index import FLOW_INPUT_STEP_INDEX, flow_output_step_index from chainweaver.tools import Tool _logger = get_logger("chainweaver.executor") @@ -336,8 +337,9 @@ class StepRecord(BaseModel): step_index: Position of this record in the flow. For normal steps this is the zero-based step index. Two sentinel values are used for flow-level schema validation: - ``-1`` — input validation (before any step runs), - ``len(steps)`` — output validation (after all steps complete). + ``FLOW_INPUT_STEP_INDEX`` — input validation (before any step + runs), and ``flow_output_step_index(flow)`` — output validation + (after all steps complete). tool_name: Name of the tool that was invoked (or the flow name for flow-level validation records). inputs: The validated inputs that were passed to the tool. @@ -423,10 +425,10 @@ class ExecutionResult(BaseModel): one entry per executed tool step. 
When ``input_schema`` or ``output_schema`` is set on the flow and the corresponding validation **fails**, a synthetic record is appended carrying - the validation error (``step_index == -1`` for input failures, - ``step_index == len(steps)`` for output failures); successful - validations do not produce records, so the log is unchanged - on the happy path. + the validation error (``FLOW_INPUT_STEP_INDEX`` for input + failures, ``flow_output_step_index(flow)`` for output failures); + successful validations do not produce records, so the log is + unchanged on the happy path. trace_id: UUID4 hex string assigned at the start of the execution. Use this to correlate the result with logs or external systems. started_at: UTC timestamp when the execution began. @@ -1234,7 +1236,7 @@ def execute_flow( flow_name=flow_name, payload=initial_input, schema=flow.input_schema, - step_index=-1, + step_index=FLOW_INPUT_STEP_INDEX, context_label="flow_input", ) if validation_record is not None: @@ -1349,7 +1351,7 @@ def execute_flow( flow_name=flow_name, payload=context, schema=flow.output_schema, - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), context_label="flow_output", ) if validation_record is not None: @@ -1376,7 +1378,7 @@ def execute_flow( flow_name=flow_name, payload=context, schema=flow.context_schema, - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), context_label="flow_context", ) if context_record is not None: @@ -1570,7 +1572,7 @@ async def _execute_linear_flow_async( flow_name=flow_name, payload=initial_input, schema=flow.input_schema, - step_index=-1, + step_index=FLOW_INPUT_STEP_INDEX, context_label="flow_input", ) if validation_record is not None: @@ -1625,7 +1627,7 @@ async def _execute_linear_flow_async( flow_name=flow_name, payload=context, schema=flow.output_schema, - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), context_label="flow_output", ) if validation_record is not None: @@ -1689,7 
+1691,7 @@ async def _execute_dag_flow_async( flow_name=flow.name, payload=initial_input, schema=flow.input_schema, - step_index=-1, + step_index=FLOW_INPUT_STEP_INDEX, context_label="flow_input", ) if validation_record is not None: @@ -1817,7 +1819,7 @@ async def _execute_dag_flow_async( flow_name=flow.name, payload=context, schema=flow.output_schema, - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), context_label="flow_output", ) if validation_record is not None: @@ -2402,10 +2404,11 @@ def _make_result( Args: tool_step_count: Number of *tool* step records in ``execution_log`` (excluding the synthetic flow-level - schema-validation records that may carry ``step_index == - -1`` or ``step_index == len(steps)``). Used to compute - ``cost_report.llm_calls_avoided`` so validation records - don't inflate the estimate. When ``None`` (the default), + schema-validation records that may carry + ``FLOW_INPUT_STEP_INDEX`` or ``flow_output_step_index(flow)``). + Used to compute ``cost_report.llm_calls_avoided`` so + validation records don't inflate the estimate. When ``None`` + (the default), falls back to ``len(execution_log)`` for callers that do not append validation records. 
Composed sub-flow steps (issue #75) are expanded to their nested tool invocations @@ -2676,7 +2679,7 @@ def _resume_linear_flow( flow_name=flow_name, payload=context, schema=flow.output_schema, - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), context_label="flow_output", ) if validation_record is not None: @@ -2698,7 +2701,7 @@ def _resume_linear_flow( flow_name=flow_name, payload=context, schema=flow.context_schema, - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), context_label="flow_context", ) if context_record is not None: @@ -3765,7 +3768,7 @@ def _execute_dag_flow( flow_name=flow.name, payload=initial_input, schema=flow.input_schema, - step_index=-1, + step_index=FLOW_INPUT_STEP_INDEX, context_label="flow_input", ) if validation_record is not None: @@ -4147,7 +4150,7 @@ def _execute_dag_flow( flow_name=flow.name, payload=context, schema=flow.output_schema, - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), context_label="flow_output", ) if validation_record is not None: @@ -4174,7 +4177,7 @@ def _execute_dag_flow( flow_name=flow.name, payload=context, schema=flow.context_schema, - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), context_label="flow_context", ) if context_record is not None: diff --git a/chainweaver/export/callable.py b/chainweaver/export/callable.py index afa79b9..ea2f4c8 100644 --- a/chainweaver/export/callable.py +++ b/chainweaver/export/callable.py @@ -22,6 +22,7 @@ from chainweaver.exceptions import FlowExecutionError from chainweaver.export._schema import derive_flow_input_schema +from chainweaver.step_index import FLOW_INPUT_STEP_INDEX, flow_output_step_index if TYPE_CHECKING: # pragma: no cover — type-only references from chainweaver.executor import FlowExecutor @@ -78,7 +79,7 @@ def _call(raw_inputs: dict[str, Any]) -> dict[str, Any]: failed = next((r for r in result.execution_log if not r.success), None) if failed is None: detail = "Flow 
execution failed without recording a failing step." - step_index = -1 + step_index = FLOW_INPUT_STEP_INDEX tool_name = flow_name else: detail = failed.error_message or failed.error_type or "Unknown error." @@ -88,7 +89,7 @@ def _call(raw_inputs: dict[str, Any]) -> dict[str, Any]: if result.final_output is None: raise FlowExecutionError( tool_name=flow_name, - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), detail="Flow reported success but produced no final output.", ) return result.final_output diff --git a/chainweaver/mcp/server.py b/chainweaver/mcp/server.py index f807784..a158b76 100644 --- a/chainweaver/mcp/server.py +++ b/chainweaver/mcp/server.py @@ -47,6 +47,7 @@ from chainweaver.contracts import SideEffectLevel, ToolSafetyContract, merge_safety from chainweaver.exceptions import FlowExecutionError from chainweaver.flow import DAGFlow, Flow, FlowLifecycle +from chainweaver.step_index import FLOW_INPUT_STEP_INDEX try: # Optional dependency. from fastmcp import FastMCP @@ -328,7 +329,7 @@ async def _dispatcher(**kwargs: Any) -> dict[str, Any]: if last is not None and last.error_type is not None else "flow execution failed without recorded step error" ) - raise FlowExecutionError(flow_name, -1, detail) + raise FlowExecutionError(flow_name, FLOW_INPUT_STEP_INDEX, detail) if output_schema is not None and result.final_output is not None: validated_out = output_schema.model_validate(result.final_output) return validated_out.model_dump()
diff --git a/chainweaver/step_index.py b/chainweaver/step_index.py
new file mode 100644
index 0000000..9c91ec7
--- /dev/null
+++ b/chainweaver/step_index.py
@@ -0,0 +1,32 @@
+"""Named step-index sentinels for flow-level validation records."""
+
+from __future__ import annotations
+
+from collections.abc import Sized
+from typing import Final, Protocol
+
+
+class _FlowLike(Protocol):
+    """Structural type for anything exposing a sized ``steps`` collection."""
+
+    # Read-only on purpose: a plain ``steps: Sized`` attribute would be
+    # mutable and therefore invariant, so type checkers would reject classes
+    # that declare a narrower type such as ``steps: list[...]``.
+    @property
+    def steps(self) -> Sized: ...
+
+
+# Synthetic record emitted when flow input validation fails before step 0.
+FLOW_INPUT_STEP_INDEX: Final = -1
+
+
+def flow_output_step_index(flow: _FlowLike) -> int:
+    """Return the synthetic index used after the final flow step.
+
+    Args:
+        flow: Any object whose ``steps`` attribute supports ``len()``.
+
+    Returns:
+        ``len(flow.steps)``, one past the last zero-based step index.
+    """
+    return len(flow.steps)
diff --git a/chainweaver/tools.py b/chainweaver/tools.py index c65a9ae..11ea28e 100644 --- a/chainweaver/tools.py +++ b/chainweaver/tools.py @@ -48,6 +48,7 @@ ToolTimeoutError, ) from chainweaver.flow import DAGFlow, DAGFlowStep, Flow, FlowStep +from chainweaver.step_index import FLOW_INPUT_STEP_INDEX, flow_output_step_index if TYPE_CHECKING: from chainweaver.executor import FlowExecutor @@ -507,7 +508,7 @@ def _flow_fn(validated_input: BaseModel) -> dict[str, Any]: failed = next((r for r in result.execution_log if not r.success), None) if failed is None: detail = "Flow execution failed without recording a failing step." - step_index = -1 + step_index = FLOW_INPUT_STEP_INDEX else: detail = failed.error_message or failed.error_type or "Unknown error." step_index = failed.step_index @@ -516,13 +517,11 @@ def _flow_fn(validated_input: BaseModel) -> dict[str, Any]: # Defensive: a successful run should always have a final_output, # but the executor's contract allows None on failure paths and # this is the only place the closure can guarantee non-None. - # Use ``len(flow.steps)`` (the flow-output validation sentinel - # per AGENTS.md §5 StepRecord) — this anomaly is a flow-output - # contract violation, not a flow-input validation failure - # (which is what ``step_index=-1`` would denote). + # Use the flow-output validation sentinel per AGENTS.md + # StepRecord: this anomaly is a flow-output contract violation.
raise FlowExecutionError( tool_name=tool_name, - step_index=len(flow.steps), + step_index=flow_output_step_index(flow), detail="Flow reported success but produced no final output.", ) return result.final_output diff --git a/tests/fixtures/public_api.json b/tests/fixtures/public_api.json index 21bfa18..0148e47 100644 --- a/tests/fixtures/public_api.json +++ b/tests/fixtures/public_api.json @@ -41,6 +41,7 @@ "ExecutionPlan", "ExecutionResult", "ExecutionSnapshot", + "FLOW_INPUT_STEP_INDEX", "FaultConfig", "FileCheckpointer", "FileStepCache", @@ -142,6 +143,7 @@ "flow_from_dict", "flow_from_json", "flow_from_yaml", + "flow_output_step_index", "flow_schema_json", "flow_to_ascii", "flow_to_dict", @@ -559,6 +561,11 @@ "module": "chainweaver.checkpoint", "qualname": "ExecutionSnapshot" }, + "FLOW_INPUT_STEP_INDEX": { + "kind": "int", + "module": null, + "qualname": null + }, "FaultConfig": { "kind": "class", "module": "chainweaver.fuzz", @@ -1370,6 +1377,12 @@ "qualname": "flow_from_yaml", "signature": "(data: str) -> AnyFlow" }, + "flow_output_step_index": { + "kind": "function", + "module": "chainweaver.step_index", + "qualname": "flow_output_step_index", + "signature": "(flow: _FlowLike) -> int" + }, "flow_schema_json": { "kind": "function", "module": "chainweaver.schemas", diff --git a/tests/test_data_integrity.py b/tests/test_data_integrity.py index 5332b90..02704f6 100644 --- a/tests/test_data_integrity.py +++ b/tests/test_data_integrity.py @@ -7,7 +7,15 @@ from pydantic import BaseModel -from chainweaver import Flow, FlowExecutor, FlowRegistry, FlowStep, Tool +from chainweaver import ( + FLOW_INPUT_STEP_INDEX, + Flow, + FlowExecutor, + FlowRegistry, + FlowStep, + Tool, + flow_output_step_index, +) class NumberInput(BaseModel): @@ -186,7 +194,7 @@ def test_schema_validated_execution_context() -> None: input_result = input_executor.execute_flow("input_validated", {"number": "bad"}) assert not input_result.success - assert input_result.execution_log[0].step_index == 
-1 + assert input_result.execution_log[0].step_index == FLOW_INPUT_STEP_INDEX assert input_result.execution_log[0].error_type == "SchemaValidationError" output_validated = Flow( @@ -203,5 +211,5 @@ def test_schema_validated_execution_context() -> None: output_result = output_executor.execute_flow("output_validated", {"number": 7}) assert not output_result.success - assert output_result.execution_log[-1].step_index == len(output_validated.steps) + assert output_result.execution_log[-1].step_index == flow_output_step_index(output_validated) assert output_result.execution_log[-1].error_type == "SchemaValidationError" diff --git a/tests/test_flow_execution.py b/tests/test_flow_execution.py index d442d87..cad1406 100644 --- a/tests/test_flow_execution.py +++ b/tests/test_flow_execution.py @@ -17,6 +17,7 @@ from chainweaver.executor import FlowExecutor from chainweaver.flow import DAGFlow, DAGFlowStep, Flow, FlowStep from chainweaver.registry import FlowRegistry +from chainweaver.step_index import FLOW_INPUT_STEP_INDEX, flow_output_step_index from chainweaver.tools import Tool # --------------------------------------------------------------------------- @@ -513,7 +514,7 @@ def test_invalid_input_caught_before_execution( assert result.final_output is None # The only record should be the flow-level input validation failure. assert len(result.execution_log) == 1 - assert result.execution_log[0].step_index == -1 + assert result.execution_log[0].step_index == FLOW_INPUT_STEP_INDEX assert result.execution_log[0].error_type == "SchemaValidationError" def test_invalid_output_caught_after_execution( @@ -541,7 +542,7 @@ def test_invalid_output_caught_after_execution( # Normal step succeeded + one output-validation record. 
assert len(result.execution_log) == 2 output_record = result.execution_log[-1] - assert output_record.step_index == len(flow.steps) + assert output_record.step_index == flow_output_step_index(flow) assert output_record.error_type == "SchemaValidationError" def test_none_schemas_behave_unchanged( @@ -1259,7 +1260,7 @@ def test_invalid_input_schema_caught_before_execution(self) -> None: assert result.success is False assert len(result.execution_log) == 1 - assert result.execution_log[0].step_index == -1 + assert result.execution_log[0].step_index == FLOW_INPUT_STEP_INDEX assert result.execution_log[0].error_type == "SchemaValidationError" def test_invalid_output_schema_caught_after_execution(self) -> None: diff --git a/tests/test_step_contracts.py b/tests/test_step_contracts.py index 1afcab7..d0d16aa 100644 --- a/tests/test_step_contracts.py +++ b/tests/test_step_contracts.py @@ -31,6 +31,7 @@ ) from chainweaver.registry import FlowRegistry from chainweaver.serialization import flow_from_json +from chainweaver.step_index import flow_output_step_index from chainweaver.tools import Tool # --------------------------------------------------------------------------- @@ -354,7 +355,7 @@ def test_context_schema_failure_aborts_flow_after_steps(self, double_tool: Tool) # context-schema gate. assert len(result.execution_log) == 2 gate = result.execution_log[-1] - assert gate.step_index == len(flow.steps) + assert gate.step_index == flow_output_step_index(flow) assert gate.error_type == "SchemaValidationError" assert "flow_context" in (gate.error_message or "")