From 88f7c275ba50a2e0255be0eacc42c7bb58d2c573 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 13 Apr 2026 15:25:32 -0400 Subject: [PATCH 1/7] feat: add agent_factory parameter to simplify agent evaluation setup Introduce adapter that wraps a no-arg agent factory into a task callable, automatically handling telemetry setup, span collection, and session mapping. Add as an alternative to in and , with mutual exclusivity validation, so users no longer need to manually wire up in-memory exporters and mappers. --- src/strands_evals/__init__.py | 2 + src/strands_evals/agent_task_adapter.py | 41 +++++ src/strands_evals/experiment.py | 29 +++- .../strands_evals/test_agent_task_adapter.py | 152 ++++++++++++++++++ .../test_experiment_agent_factory.py | 139 ++++++++++++++++ 5 files changed, 361 insertions(+), 2 deletions(-) create mode 100644 src/strands_evals/agent_task_adapter.py create mode 100644 tests/strands_evals/test_agent_task_adapter.py create mode 100644 tests/strands_evals/test_experiment_agent_factory.py diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py index 7de7520e..5f2eddb0 100644 --- a/src/strands_evals/__init__.py +++ b/src/strands_evals/__init__.py @@ -1,4 +1,5 @@ from . import evaluators, extractors, generators, providers, simulation, telemetry, types +from .agent_task_adapter import create_agent_task from .case import Case from .evaluation_data_store import EvaluationDataStore from .experiment import Experiment @@ -11,6 +12,7 @@ "Case", "LocalFileTaskResultStore", "EvaluationDataStore", + "create_agent_task", "evaluators", "extractors", "providers", diff --git a/src/strands_evals/agent_task_adapter.py b/src/strands_evals/agent_task_adapter.py new file mode 100644 index 00000000..f5633087 --- /dev/null +++ b/src/strands_evals/agent_task_adapter.py @@ -0,0 +1,41 @@ +"""Adapts agent factories into task callables compatible with Experiment.run_evaluations.""" + +from collections.abc import Callable +from typing import Any + +from .case import Case +from .mappers.strands_in_memory_session_mapper import StrandsInMemorySessionMapper +from .telemetry import StrandsEvalsTelemetry + + +def create_agent_task(agent_factory: Callable[[], Any]) -> Callable[[Case], dict[str, Any]]: + """Wrap an agent factory into a task function for use with Experiment.run_evaluations. + + Per invocation, this: + 1. Clears the shared in-memory span exporter + 2. Creates a fresh Agent from the factory + 3. Calls agent(case.input) + 4. Collects finished spans and maps them to a Session + 5. Returns {"output": ..., "trajectory": session} + + Args: + agent_factory: A no-arg callable that returns a strands Agent instance. + + Returns: + A task callable that takes a Case and returns a structured dict. + """ + telemetry = StrandsEvalsTelemetry().setup_in_memory_exporter() + mapper = StrandsInMemorySessionMapper() + + def task(case: Case) -> dict[str, Any]: + telemetry.in_memory_exporter.clear() + + agent = agent_factory() + result = agent(case.input) + + spans = list(telemetry.in_memory_exporter.get_finished_spans()) + session = mapper.map_to_session(spans, case.session_id) + + return {"output": str(result), "trajectory": session} + + return task diff --git a/src/strands_evals/experiment.py b/src/strands_evals/experiment.py index e11e0dd4..7dee0bf6 100644 --- a/src/strands_evals/experiment.py +++ b/src/strands_evals/experiment.py @@ -16,6 +16,7 @@ ) from typing_extensions import Any, Generic +from .agent_task_adapter import create_agent_task from .case import Case from .evaluation_data_store import EvaluationDataStore from .evaluators.deterministic import Contains, Equals, StartsWith, StateEquals, ToolCalled @@ -500,8 +501,9 @@ async def _worker( def run_evaluations( self, - task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]], + task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]] | None = None, evaluation_data_store: EvaluationDataStore | None = None, + agent_factory: Callable[[], Any] | None = None, ) -> list[EvaluationReport]: """ Run the evaluations for all of the test cases with all evaluators. @@ -513,11 +515,22 @@ def run_evaluations( OutputT or {"output": OutputT, "trajectory": ...}. evaluation_data_store: Optional store for loading/saving evaluation data. When provided, cached results are loaded instead of running the task, and new results are saved after task execution. + agent_factory: A no-arg callable that returns a strands Agent. When provided, the agent + is automatically invoked with case.input and telemetry/span collection is handled + internally. Mutually exclusive with task. Return: A list of EvaluationReport objects, one for each evaluator, containing the overall score, individual case results, and basic feedback for each test case. """ + if task is not None and agent_factory is not None: + raise ValueError("Cannot specify both 'task' and 'agent_factory'. Use one or the other.") + if task is None and agent_factory is None: + raise ValueError("Must specify either 'task' or 'agent_factory'.") + + if agent_factory is not None: + task = create_agent_task(agent_factory) + if asyncio.iscoroutinefunction(task): raise ValueError("Async task is not supported. Please use run_evaluations_async instead.") @@ -525,9 +538,10 @@ def run_evaluations( async def run_evaluations_async( self, - task: Callable, + task: Callable | None = None, max_workers: int = 10, evaluation_data_store: EvaluationDataStore | None = None, + agent_factory: Callable[[], Any] | None = None, ) -> list[EvaluationReport]: """ Run evaluations asynchronously using a queue for parallel processing. @@ -539,10 +553,21 @@ async def run_evaluations_async( max_workers: Maximum number of parallel workers (default: 10) evaluation_data_store: Optional store for loading/saving evaluation data. When provided, cached results are loaded instead of running the task, and new results are saved after task execution. + agent_factory: A no-arg callable that returns a strands Agent. When provided, the agent + is automatically invoked with case.input and telemetry/span collection is handled + internally. Mutually exclusive with task. Returns: List of EvaluationReport objects, one for each evaluator, containing evaluation results """ + if task is not None and agent_factory is not None: + raise ValueError("Cannot specify both 'task' and 'agent_factory'. Use one or the other.") + if task is None and agent_factory is None: + raise ValueError("Must specify either 'task' or 'agent_factory'.") + + if agent_factory is not None: + task = create_agent_task(agent_factory) + if evaluation_data_store is not None: self._validate_case_names() diff --git a/tests/strands_evals/test_agent_task_adapter.py b/tests/strands_evals/test_agent_task_adapter.py new file mode 100644 index 00000000..718ef4f8 --- /dev/null +++ b/tests/strands_evals/test_agent_task_adapter.py @@ -0,0 +1,152 @@ +"""Tests for the agent task adapter that wraps agent factories into task callables.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from strands_evals.agent_task_adapter import create_agent_task +from strands_evals.case import Case + + +class TestCreateAgentTask: + """Tests for create_agent_task function.""" + + def test_returns_callable(self): + """create_agent_task returns a callable.""" + factory = MagicMock() + task = create_agent_task(factory) + assert callable(task) + + @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") + @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") + def test_calls_factory_to_create_agent(self, mock_mapper_cls, mock_telemetry_cls): + """The task calls the factory to get a fresh agent per case.""" + mock_agent = MagicMock() + mock_agent.return_value = MagicMock(message="result") + mock_agent.return_value.__str__ = MagicMock(return_value="result") + + factory = MagicMock(return_value=mock_agent) + + mock_telemetry = MagicMock() + mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry + mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] + mock_telemetry_cls.return_value = mock_telemetry + + mock_mapper = MagicMock() + mock_mapper.map_to_session.return_value = MagicMock() + mock_mapper_cls.return_value = mock_mapper + + case = Case(name="test", input="hello") + task = create_agent_task(factory) + task(case) + + factory.assert_called_once() + + @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") + @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") + def test_calls_agent_with_case_input(self, mock_mapper_cls, mock_telemetry_cls): + """The task invokes the agent with case.input.""" + mock_agent = MagicMock() + mock_result = MagicMock() + mock_result.__str__ = MagicMock(return_value="the answer") + mock_agent.return_value = mock_result + + factory = MagicMock(return_value=mock_agent) + + mock_telemetry = MagicMock() + mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry + mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] + mock_telemetry_cls.return_value = mock_telemetry + + mock_mapper = MagicMock() + mock_mapper.map_to_session.return_value = MagicMock() + mock_mapper_cls.return_value = mock_mapper + + case = Case(name="test", input="What is 2+2?") + task = create_agent_task(factory) + task(case) + + mock_agent.assert_called_once_with("What is 2+2?") + + @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") + @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") + def test_returns_dict_with_output_and_trajectory(self, mock_mapper_cls, mock_telemetry_cls): + """The task returns a dict with 'output' and 'trajectory' keys.""" + mock_agent = MagicMock() + mock_result = MagicMock() + mock_result.__str__ = MagicMock(return_value="42") + mock_agent.return_value = mock_result + + factory = MagicMock(return_value=mock_agent) + + mock_telemetry = MagicMock() + mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry + mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] + mock_telemetry_cls.return_value = mock_telemetry + + mock_session = MagicMock() + mock_mapper = MagicMock() + mock_mapper.map_to_session.return_value = mock_session + mock_mapper_cls.return_value = mock_mapper + + case = Case(name="test", input="What is 2+2?") + task = create_agent_task(factory) + result = task(case) + + assert isinstance(result, dict) + assert "output" in result + assert "trajectory" in result + assert result["output"] == "42" + assert result["trajectory"] is mock_session + + @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") + @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") + def test_collects_spans_and_maps_to_session(self, mock_mapper_cls, mock_telemetry_cls): + """The task collects spans from telemetry and maps them to a Session.""" + mock_agent = MagicMock() + mock_agent.return_value = MagicMock(__str__=MagicMock(return_value="res")) + + factory = MagicMock(return_value=mock_agent) + + mock_spans = [MagicMock(), MagicMock()] + mock_telemetry = MagicMock() + mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry + mock_telemetry.in_memory_exporter.get_finished_spans.return_value = mock_spans + mock_telemetry_cls.return_value = mock_telemetry + + mock_session = MagicMock() + mock_mapper = MagicMock() + mock_mapper.map_to_session.return_value = mock_session + mock_mapper_cls.return_value = mock_mapper + + case = Case(name="test", input="hi", session_id="sess-123") + task = create_agent_task(factory) + result = task(case) + + mock_mapper.map_to_session.assert_called_once_with(list(mock_spans), "sess-123") + assert result["trajectory"] is mock_session + + @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") + @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") + def test_clears_spans_between_calls(self, mock_mapper_cls, mock_telemetry_cls): + """Each call clears the in-memory exporter before running the agent.""" + mock_agent = MagicMock() + mock_agent.return_value = MagicMock(__str__=MagicMock(return_value="res")) + + factory = MagicMock(return_value=mock_agent) + + mock_telemetry = MagicMock() + mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry + mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] + mock_telemetry_cls.return_value = mock_telemetry + + mock_mapper_cls.return_value = MagicMock() + + task = create_agent_task(factory) + + case1 = Case(name="test1", input="a") + case2 = Case(name="test2", input="b") + task(case1) + task(case2) + + assert mock_telemetry.in_memory_exporter.clear.call_count == 2 diff --git a/tests/strands_evals/test_experiment_agent_factory.py b/tests/strands_evals/test_experiment_agent_factory.py new file mode 100644 index 00000000..4274595e --- /dev/null +++ b/tests/strands_evals/test_experiment_agent_factory.py @@ -0,0 +1,139 @@ +"""Tests for Experiment.run_evaluations with agent_factory parameter.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from strands_evals.case import Case +from strands_evals.evaluators.evaluator import Evaluator +from strands_evals.experiment import Experiment +from strands_evals.types.evaluation import EvaluationOutput + + +class PassingEvaluator(Evaluator): + """Evaluator that always passes for testing.""" + + def evaluate(self, evaluation_case): + return [EvaluationOutput(score=1.0, test_pass=True, reason="pass")] + + +class TestRunEvaluationsWithAgentFactory: + """Tests for the agent_factory parameter on run_evaluations.""" + + def test_rejects_both_task_and_agent_factory(self): + """Passing both task and agent_factory raises ValueError.""" + experiment = Experiment( + cases=[Case(name="test", input="hi")], + evaluators=[PassingEvaluator()], + ) + + def my_task(case): + return "output" + + def my_factory(): + return MagicMock() + + with pytest.raises(ValueError, match="Cannot specify both"): + experiment.run_evaluations(task=my_task, agent_factory=my_factory) + + def test_rejects_neither_task_nor_agent_factory(self): + """Passing neither task nor agent_factory raises ValueError.""" + experiment = Experiment( + cases=[Case(name="test", input="hi")], + evaluators=[PassingEvaluator()], + ) + + with pytest.raises(ValueError, match="Must specify either"): + experiment.run_evaluations() + + @patch("strands_evals.experiment.create_agent_task") + def test_agent_factory_creates_task_via_adapter(self, mock_create_agent_task): + """When agent_factory is passed, it's wrapped via create_agent_task.""" + mock_task = MagicMock(return_value="output") + mock_create_agent_task.return_value = mock_task + + factory = MagicMock() + experiment = Experiment( + cases=[Case(name="test", input="hi")], + evaluators=[PassingEvaluator()], + ) + experiment.run_evaluations(agent_factory=factory) + + mock_create_agent_task.assert_called_once_with(factory) + + @patch("strands_evals.experiment.create_agent_task") + def test_agent_factory_results_flow_to_evaluators(self, mock_create_agent_task): + """Results from the adapted agent task are evaluated normally.""" + mock_task = MagicMock(return_value={"output": "42", "trajectory": None}) + mock_create_agent_task.return_value = mock_task + + factory = MagicMock() + experiment = Experiment( + cases=[Case(name="test", input="hi")], + evaluators=[PassingEvaluator()], + ) + reports = experiment.run_evaluations(agent_factory=factory) + + assert len(reports) == 1 + assert reports[0].scores[0] == 1.0 + assert reports[0].test_passes[0] is True + + def test_existing_task_parameter_still_works(self): + """Existing task= parameter continues to work unchanged.""" + def my_task(case): + return "output" + + experiment = Experiment( + cases=[Case(name="test", input="hi")], + evaluators=[PassingEvaluator()], + ) + reports = experiment.run_evaluations(task=my_task) + + assert len(reports) == 1 + assert reports[0].scores[0] == 1.0 + + +class TestRunEvaluationsAsyncWithAgentFactory: + """Tests for the agent_factory parameter on run_evaluations_async.""" + + def test_rejects_both_task_and_agent_factory(self): + """Passing both task and agent_factory raises ValueError.""" + import asyncio + + experiment = Experiment( + cases=[Case(name="test", input="hi")], + evaluators=[PassingEvaluator()], + ) + + with pytest.raises(ValueError, match="Cannot specify both"): + asyncio.run(experiment.run_evaluations_async(task=MagicMock(), agent_factory=MagicMock())) + + def test_rejects_neither_task_nor_agent_factory(self): + """Passing neither task nor agent_factory raises ValueError.""" + import asyncio + + experiment = Experiment( + cases=[Case(name="test", input="hi")], + evaluators=[PassingEvaluator()], + ) + + with pytest.raises(ValueError, match="Must specify either"): + asyncio.run(experiment.run_evaluations_async()) + + @patch("strands_evals.experiment.create_agent_task") + def test_agent_factory_works_async(self, mock_create_agent_task): + """agent_factory works with run_evaluations_async.""" + import asyncio + + mock_task = MagicMock(return_value="output") + mock_create_agent_task.return_value = mock_task + + factory = MagicMock() + experiment = Experiment( + cases=[Case(name="test", input="hi")], + evaluators=[PassingEvaluator()], + ) + reports = asyncio.run(experiment.run_evaluations_async(agent_factory=factory)) + + mock_create_agent_task.assert_called_once_with(factory) + assert len(reports) == 1 From 2b3f6ced16814f5a24cdb37ce85962b10b07d5a6 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 13 Apr 2026 15:59:19 -0400 Subject: [PATCH 2/7] refactor: replace create_agent_task with AgentTask and TracedAgentTask classes Replace the `create_agent_task` factory function with two class-based task adapters: `AgentTask` for simple agent invocations and `TracedAgentTask` for invocations that also collect OpenTelemetry spans for trajectory evaluation. Additionally, simplify `Experiment.run_evaluations` by removing the `agent_factory` parameter and making `task` required, eliminating ambiguous dual-path configuration. Users now pass an `AgentTask` or `TracedAgentTask` instance directly as the task argument. --- src/strands_evals/__init__.py | 5 +- src/strands_evals/agent_task_adapter.py | 55 ++++--- src/strands_evals/experiment.py | 29 +--- .../strands_evals/test_agent_task_adapter.py | 145 ++++++++++-------- .../test_experiment_agent_factory.py | 139 ----------------- 5 files changed, 119 insertions(+), 254 deletions(-) delete mode 100644 tests/strands_evals/test_experiment_agent_factory.py diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py index 5f2eddb0..2339e4f7 100644 --- a/src/strands_evals/__init__.py +++ b/src/strands_evals/__init__.py @@ -1,5 +1,5 @@ from . import evaluators, extractors, generators, providers, simulation, telemetry, types -from .agent_task_adapter import create_agent_task +from .agent_task_adapter import AgentTask, TracedAgentTask from .case import Case from .evaluation_data_store import EvaluationDataStore from .experiment import Experiment @@ -12,7 +12,8 @@ "Case", "LocalFileTaskResultStore", "EvaluationDataStore", - "create_agent_task", + "AgentTask", + "TracedAgentTask", "evaluators", "extractors", "providers", diff --git a/src/strands_evals/agent_task_adapter.py b/src/strands_evals/agent_task_adapter.py index f5633087..d747a266 100644 --- a/src/strands_evals/agent_task_adapter.py +++ b/src/strands_evals/agent_task_adapter.py @@ -8,34 +8,51 @@ from .telemetry import StrandsEvalsTelemetry -def create_agent_task(agent_factory: Callable[[], Any]) -> Callable[[Case], dict[str, Any]]: - """Wrap an agent factory into a task function for use with Experiment.run_evaluations. +class AgentTask: + """A task that creates a fresh Agent per case and invokes it with case.input. - Per invocation, this: - 1. Clears the shared in-memory span exporter - 2. Creates a fresh Agent from the factory - 3. Calls agent(case.input) - 4. Collects finished spans and maps them to a Session - 5. Returns {"output": ..., "trajectory": session} + Args: + agent_factory: A no-arg callable that returns a strands Agent instance. + + Example: + task = AgentTask(lambda: Agent(model="...", tools=[calculator])) + experiment.run_evaluations(task=task) + """ + + def __init__(self, agent_factory: Callable[[], Any]): + self._agent_factory = agent_factory + + def __call__(self, case: Case) -> dict[str, Any]: + agent = self._agent_factory() + result = agent(case.input) + return {"output": str(result)} + + +class TracedAgentTask(AgentTask): + """An AgentTask that also collects OpenTelemetry spans and maps them to a Session. + + Use this when your evaluators need trajectory data (e.g., TrajectoryEvaluator). Args: agent_factory: A no-arg callable that returns a strands Agent instance. - Returns: - A task callable that takes a Case and returns a structured dict. + Example: + task = TracedAgentTask(lambda: Agent(model="...", tools=[calculator])) + experiment.run_evaluations(task=task) """ - telemetry = StrandsEvalsTelemetry().setup_in_memory_exporter() - mapper = StrandsInMemorySessionMapper() - def task(case: Case) -> dict[str, Any]: - telemetry.in_memory_exporter.clear() + def __init__(self, agent_factory: Callable[[], Any]): + super().__init__(agent_factory) + self._telemetry = StrandsEvalsTelemetry().setup_in_memory_exporter() + self._mapper = StrandsInMemorySessionMapper() + + def __call__(self, case: Case) -> dict[str, Any]: + self._telemetry.in_memory_exporter.clear() - agent = agent_factory() + agent = self._agent_factory() result = agent(case.input) - spans = list(telemetry.in_memory_exporter.get_finished_spans()) - session = mapper.map_to_session(spans, case.session_id) + spans = list(self._telemetry.in_memory_exporter.get_finished_spans()) + session = self._mapper.map_to_session(spans, case.session_id) return {"output": str(result), "trajectory": session} - - return task diff --git a/src/strands_evals/experiment.py b/src/strands_evals/experiment.py index 7dee0bf6..e11e0dd4 100644 --- a/src/strands_evals/experiment.py +++ b/src/strands_evals/experiment.py @@ -16,7 +16,6 @@ ) from typing_extensions import Any, Generic -from .agent_task_adapter import create_agent_task from .case import Case from .evaluation_data_store import EvaluationDataStore from .evaluators.deterministic import Contains, Equals, StartsWith, StateEquals, ToolCalled @@ -501,9 +500,8 @@ async def _worker( def run_evaluations( self, - task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]] | None = None, + task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]], evaluation_data_store: EvaluationDataStore | None = None, - agent_factory: Callable[[], Any] | None = None, ) -> list[EvaluationReport]: """ Run the evaluations for all of the test cases with all evaluators. @@ -515,22 +513,11 @@ def run_evaluations( OutputT or {"output": OutputT, "trajectory": ...}. evaluation_data_store: Optional store for loading/saving evaluation data. When provided, cached results are loaded instead of running the task, and new results are saved after task execution. - agent_factory: A no-arg callable that returns a strands Agent. When provided, the agent - is automatically invoked with case.input and telemetry/span collection is handled - internally. Mutually exclusive with task. Return: A list of EvaluationReport objects, one for each evaluator, containing the overall score, individual case results, and basic feedback for each test case. """ - if task is not None and agent_factory is not None: - raise ValueError("Cannot specify both 'task' and 'agent_factory'. Use one or the other.") - if task is None and agent_factory is None: - raise ValueError("Must specify either 'task' or 'agent_factory'.") - - if agent_factory is not None: - task = create_agent_task(agent_factory) - if asyncio.iscoroutinefunction(task): raise ValueError("Async task is not supported. Please use run_evaluations_async instead.") @@ -538,10 +525,9 @@ def run_evaluations( async def run_evaluations_async( self, - task: Callable | None = None, + task: Callable, max_workers: int = 10, evaluation_data_store: EvaluationDataStore | None = None, - agent_factory: Callable[[], Any] | None = None, ) -> list[EvaluationReport]: """ Run evaluations asynchronously using a queue for parallel processing. @@ -553,21 +539,10 @@ async def run_evaluations_async( max_workers: Maximum number of parallel workers (default: 10) evaluation_data_store: Optional store for loading/saving evaluation data. When provided, cached results are loaded instead of running the task, and new results are saved after task execution. - agent_factory: A no-arg callable that returns a strands Agent. When provided, the agent - is automatically invoked with case.input and telemetry/span collection is handled - internally. Mutually exclusive with task. Returns: List of EvaluationReport objects, one for each evaluator, containing evaluation results """ - if task is not None and agent_factory is not None: - raise ValueError("Cannot specify both 'task' and 'agent_factory'. Use one or the other.") - if task is None and agent_factory is None: - raise ValueError("Must specify either 'task' or 'agent_factory'.") - - if agent_factory is not None: - task = create_agent_task(agent_factory) - if evaluation_data_store is not None: self._validate_case_names() diff --git a/tests/strands_evals/test_agent_task_adapter.py b/tests/strands_evals/test_agent_task_adapter.py index 718ef4f8..6dfec4ea 100644 --- a/tests/strands_evals/test_agent_task_adapter.py +++ b/tests/strands_evals/test_agent_task_adapter.py @@ -1,82 +1,106 @@ -"""Tests for the agent task adapter that wraps agent factories into task callables.""" +"""Tests for AgentTask and TracedAgentTask.""" from unittest.mock import MagicMock, patch -import pytest - -from strands_evals.agent_task_adapter import create_agent_task +from strands_evals.agent_task_adapter import AgentTask, TracedAgentTask from strands_evals.case import Case -class TestCreateAgentTask: - """Tests for create_agent_task function.""" +class TestAgentTask: + """Tests for the base AgentTask class.""" - def test_returns_callable(self): - """create_agent_task returns a callable.""" - factory = MagicMock() - task = create_agent_task(factory) + def test_is_callable(self): + """AgentTask instances are callable.""" + task = AgentTask(MagicMock()) assert callable(task) - @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") - @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") - def test_calls_factory_to_create_agent(self, mock_mapper_cls, mock_telemetry_cls): - """The task calls the factory to get a fresh agent per case.""" + def test_calls_factory_to_create_agent(self): + """Calls the factory to get a fresh agent per case.""" mock_agent = MagicMock() - mock_agent.return_value = MagicMock(message="result") mock_agent.return_value.__str__ = MagicMock(return_value="result") + factory = MagicMock(return_value=mock_agent) + + task = AgentTask(factory) + task(Case(name="test", input="hello")) + factory.assert_called_once() + + def test_calls_agent_with_case_input(self): + """Invokes the agent with case.input.""" + mock_agent = MagicMock() + mock_agent.return_value.__str__ = MagicMock(return_value="answer") factory = MagicMock(return_value=mock_agent) - mock_telemetry = MagicMock() - mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry - mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] - mock_telemetry_cls.return_value = mock_telemetry + task = AgentTask(factory) + task(Case(name="test", input="What is 2+2?")) - mock_mapper = MagicMock() - mock_mapper.map_to_session.return_value = MagicMock() - mock_mapper_cls.return_value = mock_mapper + mock_agent.assert_called_once_with("What is 2+2?") - case = Case(name="test", input="hello") - task = create_agent_task(factory) - task(case) + def test_returns_dict_with_output(self): + """Returns a dict with 'output' key.""" + mock_agent = MagicMock() + mock_agent.return_value.__str__ = MagicMock(return_value="42") + factory = MagicMock(return_value=mock_agent) - factory.assert_called_once() + task = AgentTask(factory) + result = task(Case(name="test", input="What is 2+2?")) + + assert isinstance(result, dict) + assert result["output"] == "42" + assert "trajectory" not in result + + def test_works_with_experiment_run_evaluations(self): + """AgentTask works when passed as the task parameter to Experiment.""" + from strands_evals.evaluators.evaluator import Evaluator + from strands_evals.experiment import Experiment + from strands_evals.types.evaluation import EvaluationOutput + + class PassingEvaluator(Evaluator): + def evaluate(self, evaluation_case): + return [EvaluationOutput(score=1.0, test_pass=True, reason="pass")] + + mock_agent = MagicMock() + mock_agent.return_value.__str__ = MagicMock(return_value="output") + factory = MagicMock(return_value=mock_agent) + + experiment = Experiment( + cases=[Case(name="test", input="hi")], + evaluators=[PassingEvaluator()], + ) + reports = experiment.run_evaluations(task=AgentTask(factory)) + + assert len(reports) == 1 + assert reports[0].scores[0] == 1.0 + + +class TestTracedAgentTask: + """Tests for TracedAgentTask that adds telemetry collection.""" @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") def test_calls_agent_with_case_input(self, mock_mapper_cls, mock_telemetry_cls): - """The task invokes the agent with case.input.""" + """Invokes the agent with case.input.""" mock_agent = MagicMock() - mock_result = MagicMock() - mock_result.__str__ = MagicMock(return_value="the answer") - mock_agent.return_value = mock_result - + mock_agent.return_value.__str__ = MagicMock(return_value="answer") factory = MagicMock(return_value=mock_agent) mock_telemetry = MagicMock() mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] mock_telemetry_cls.return_value = mock_telemetry + mock_mapper_cls.return_value.map_to_session.return_value = MagicMock() - mock_mapper = MagicMock() - mock_mapper.map_to_session.return_value = MagicMock() - mock_mapper_cls.return_value = mock_mapper - - case = Case(name="test", input="What is 2+2?") - task = create_agent_task(factory) - task(case) + task = TracedAgentTask(factory) + task(Case(name="test", input="What is 2+2?")) mock_agent.assert_called_once_with("What is 2+2?") @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") def test_returns_dict_with_output_and_trajectory(self, mock_mapper_cls, mock_telemetry_cls): - """The task returns a dict with 'output' and 'trajectory' keys.""" + """Returns a dict with both 'output' and 'trajectory' keys.""" mock_agent = MagicMock() - mock_result = MagicMock() - mock_result.__str__ = MagicMock(return_value="42") - mock_agent.return_value = mock_result - + mock_agent.return_value.__str__ = MagicMock(return_value="42") factory = MagicMock(return_value=mock_agent) mock_telemetry = MagicMock() @@ -85,27 +109,20 @@ def test_returns_dict_with_output_and_trajectory(self, mock_mapper_cls, mock_tel mock_telemetry_cls.return_value = mock_telemetry mock_session = MagicMock() - mock_mapper = MagicMock() - mock_mapper.map_to_session.return_value = mock_session - mock_mapper_cls.return_value = mock_mapper + mock_mapper_cls.return_value.map_to_session.return_value = mock_session - case = Case(name="test", input="What is 2+2?") - task = create_agent_task(factory) - result = task(case) + task = TracedAgentTask(factory) + result = task(Case(name="test", input="What is 2+2?")) - assert isinstance(result, dict) - assert "output" in result - assert "trajectory" in result assert result["output"] == "42" assert result["trajectory"] is mock_session @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") def test_collects_spans_and_maps_to_session(self, mock_mapper_cls, mock_telemetry_cls): - """The task collects spans from telemetry and maps them to a Session.""" + """Collects spans from telemetry and maps them to a Session.""" mock_agent = MagicMock() - mock_agent.return_value = MagicMock(__str__=MagicMock(return_value="res")) - + mock_agent.return_value.__str__ = MagicMock(return_value="res") factory = MagicMock(return_value=mock_agent) mock_spans = [MagicMock(), MagicMock()] @@ -119,9 +136,8 @@ def test_collects_spans_and_maps_to_session(self, mock_mapper_cls, mock_telemetr mock_mapper.map_to_session.return_value = mock_session mock_mapper_cls.return_value = mock_mapper - case = Case(name="test", input="hi", session_id="sess-123") - task = create_agent_task(factory) - result = task(case) + task = TracedAgentTask(factory) + result = task(Case(name="test", input="hi", session_id="sess-123")) mock_mapper.map_to_session.assert_called_once_with(list(mock_spans), "sess-123") assert result["trajectory"] is mock_session @@ -131,22 +147,17 @@ def test_collects_spans_and_maps_to_session(self, mock_mapper_cls, mock_telemetr def test_clears_spans_between_calls(self, mock_mapper_cls, mock_telemetry_cls): """Each call clears the in-memory exporter before running the agent.""" mock_agent = MagicMock() - mock_agent.return_value = MagicMock(__str__=MagicMock(return_value="res")) - + mock_agent.return_value.__str__ = MagicMock(return_value="res") factory = MagicMock(return_value=mock_agent) mock_telemetry = MagicMock() mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] mock_telemetry_cls.return_value = mock_telemetry - mock_mapper_cls.return_value = MagicMock() - task = create_agent_task(factory) - - case1 = Case(name="test1", input="a") - case2 = Case(name="test2", input="b") - task(case1) - task(case2) + task = TracedAgentTask(factory) + task(Case(name="test1", input="a")) + task(Case(name="test2", input="b")) assert mock_telemetry.in_memory_exporter.clear.call_count == 2 diff --git a/tests/strands_evals/test_experiment_agent_factory.py b/tests/strands_evals/test_experiment_agent_factory.py deleted file mode 100644 index 4274595e..00000000 --- a/tests/strands_evals/test_experiment_agent_factory.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Tests for Experiment.run_evaluations with agent_factory parameter.""" - -from unittest.mock import MagicMock, patch - -import pytest - -from strands_evals.case import Case -from strands_evals.evaluators.evaluator import Evaluator -from strands_evals.experiment import Experiment -from strands_evals.types.evaluation import EvaluationOutput - - -class PassingEvaluator(Evaluator): - """Evaluator that always passes for testing.""" - - def evaluate(self, evaluation_case): - return [EvaluationOutput(score=1.0, test_pass=True, reason="pass")] - - -class TestRunEvaluationsWithAgentFactory: - """Tests for the agent_factory parameter on run_evaluations.""" - - def test_rejects_both_task_and_agent_factory(self): - """Passing both task and agent_factory raises ValueError.""" - experiment = Experiment( - cases=[Case(name="test", input="hi")], - evaluators=[PassingEvaluator()], - ) - - def my_task(case): - return "output" - - def my_factory(): - return MagicMock() - - with pytest.raises(ValueError, match="Cannot specify both"): - experiment.run_evaluations(task=my_task, agent_factory=my_factory) - - def test_rejects_neither_task_nor_agent_factory(self): - """Passing neither task nor agent_factory raises ValueError.""" - experiment = Experiment( - cases=[Case(name="test", input="hi")], - evaluators=[PassingEvaluator()], - ) - - with pytest.raises(ValueError, match="Must specify either"): - experiment.run_evaluations() - - @patch("strands_evals.experiment.create_agent_task") - def test_agent_factory_creates_task_via_adapter(self, mock_create_agent_task): - """When agent_factory is passed, it's wrapped via create_agent_task.""" - mock_task = MagicMock(return_value="output") - mock_create_agent_task.return_value = mock_task - - factory = MagicMock() - experiment = Experiment( - cases=[Case(name="test", input="hi")], - evaluators=[PassingEvaluator()], - ) - experiment.run_evaluations(agent_factory=factory) - - mock_create_agent_task.assert_called_once_with(factory) - - @patch("strands_evals.experiment.create_agent_task") - def test_agent_factory_results_flow_to_evaluators(self, mock_create_agent_task): - """Results from the adapted agent task are evaluated normally.""" - mock_task = MagicMock(return_value={"output": "42", "trajectory": None}) - mock_create_agent_task.return_value = mock_task - - factory = MagicMock() - experiment = Experiment( - cases=[Case(name="test", input="hi")], - evaluators=[PassingEvaluator()], - ) - reports = experiment.run_evaluations(agent_factory=factory) - - assert len(reports) == 1 - assert reports[0].scores[0] == 1.0 - assert reports[0].test_passes[0] is True - - def test_existing_task_parameter_still_works(self): - """Existing task= parameter continues to work unchanged.""" - def my_task(case): - return "output" - - experiment = Experiment( - cases=[Case(name="test", input="hi")], - evaluators=[PassingEvaluator()], - ) - reports = experiment.run_evaluations(task=my_task) - - assert len(reports) == 1 - assert reports[0].scores[0] == 1.0 - - -class TestRunEvaluationsAsyncWithAgentFactory: - """Tests for the agent_factory parameter on run_evaluations_async.""" - - def test_rejects_both_task_and_agent_factory(self): - """Passing both task and agent_factory raises ValueError.""" - import asyncio - - experiment = Experiment( - cases=[Case(name="test", input="hi")], - evaluators=[PassingEvaluator()], - ) - - with pytest.raises(ValueError, match="Cannot specify both"): - asyncio.run(experiment.run_evaluations_async(task=MagicMock(), agent_factory=MagicMock())) - - def test_rejects_neither_task_nor_agent_factory(self): - """Passing neither task nor agent_factory raises ValueError.""" - import asyncio - - experiment = Experiment( - cases=[Case(name="test", input="hi")], - evaluators=[PassingEvaluator()], - ) - - with pytest.raises(ValueError, match="Must specify either"): - asyncio.run(experiment.run_evaluations_async()) - - @patch("strands_evals.experiment.create_agent_task") - def test_agent_factory_works_async(self, mock_create_agent_task): - """agent_factory works with run_evaluations_async.""" - import asyncio - - mock_task = MagicMock(return_value="output") - mock_create_agent_task.return_value = mock_task - - factory = MagicMock() - experiment = Experiment( - cases=[Case(name="test", input="hi")], - evaluators=[PassingEvaluator()], - ) - reports = asyncio.run(experiment.run_evaluations_async(agent_factory=factory)) - - mock_create_agent_task.assert_called_once_with(factory) - assert len(reports) == 1 From c334712880614f588b37f9bf6012d84720d58926 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 13 Apr 2026 16:06:32 -0400 Subject: [PATCH 3/7] refactor: replace agent_factory callable with direct Agent kwargs in AgentTask Simplify AgentTask and TracedAgentTask API by accepting **agent_kwargs instead of requiring a lambda/callable factory. The classes now directly construct strands.Agent instances internally via a _create_agent method, reducing boilerplate for users (e.g., `AgentTask(model="...")` instead of `AgentTask(lambda: Agent(model="..."))`). Updated tests to verify kwargs forwarding and fresh agent creation per call. --- src/strands_evals/agent_task_adapter.py | 30 +++---- .../strands_evals/test_agent_task_adapter.py | 81 ++++++++++++------- 2 files changed, 66 insertions(+), 45 deletions(-) diff --git a/src/strands_evals/agent_task_adapter.py b/src/strands_evals/agent_task_adapter.py index d747a266..ac76eee2 100644 --- a/src/strands_evals/agent_task_adapter.py +++ b/src/strands_evals/agent_task_adapter.py @@ -1,8 +1,9 @@ -"""Adapts agent factories into task callables compatible with Experiment.run_evaluations.""" +"""Adapts agent configuration into task callables compatible with Experiment.run_evaluations.""" -from collections.abc import Callable from typing import Any +from strands import Agent + from .case import Case from .mappers.strands_in_memory_session_mapper import StrandsInMemorySessionMapper from .telemetry import StrandsEvalsTelemetry @@ -11,19 +12,21 @@ class AgentTask: """A task that creates a fresh Agent per case and invokes it with case.input. - Args: - agent_factory: A no-arg callable that returns a strands Agent instance. + Accepts the same keyword arguments as strands.Agent. Example: - task = AgentTask(lambda: Agent(model="...", tools=[calculator])) + task = AgentTask(model="us.anthropic.claude-sonnet-4-20250514-v1:0", tools=[calculator]) experiment.run_evaluations(task=task) """ - def __init__(self, agent_factory: Callable[[], Any]): - self._agent_factory = agent_factory + def __init__(self, **agent_kwargs: Any): + self._agent_kwargs = agent_kwargs + + def _create_agent(self) -> Agent: + return Agent(**self._agent_kwargs) def __call__(self, case: Case) -> dict[str, Any]: - agent = self._agent_factory() + agent = self._create_agent() result = agent(case.input) return {"output": str(result)} @@ -33,23 +36,20 @@ class TracedAgentTask(AgentTask): Use this when your evaluators need trajectory data (e.g., TrajectoryEvaluator). - Args: - agent_factory: A no-arg callable that returns a strands Agent instance. - Example: - task = TracedAgentTask(lambda: Agent(model="...", tools=[calculator])) + task = TracedAgentTask(model="us.anthropic.claude-sonnet-4-20250514-v1:0", tools=[calculator]) experiment.run_evaluations(task=task) """ - def __init__(self, agent_factory: Callable[[], Any]): - super().__init__(agent_factory) + def __init__(self, **agent_kwargs: Any): + super().__init__(**agent_kwargs) self._telemetry = StrandsEvalsTelemetry().setup_in_memory_exporter() self._mapper = StrandsInMemorySessionMapper() def __call__(self, case: Case) -> dict[str, Any]: self._telemetry.in_memory_exporter.clear() - agent = self._agent_factory() + agent = self._create_agent() result = agent(case.input) spans = list(self._telemetry.in_memory_exporter.get_finished_spans()) diff --git a/tests/strands_evals/test_agent_task_adapter.py b/tests/strands_evals/test_agent_task_adapter.py index 6dfec4ea..03b94821 100644 --- a/tests/strands_evals/test_agent_task_adapter.py +++ b/tests/strands_evals/test_agent_task_adapter.py @@ -1,6 +1,6 @@ """Tests for AgentTask and TracedAgentTask.""" -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch from strands_evals.agent_task_adapter import AgentTask, TracedAgentTask from strands_evals.case import Case @@ -11,45 +11,62 @@ class TestAgentTask: def test_is_callable(self): """AgentTask instances are callable.""" - task = AgentTask(MagicMock()) + task = AgentTask(model="test-model") assert callable(task) - def test_calls_factory_to_create_agent(self): - """Calls the factory to get a fresh agent per case.""" + @patch("strands_evals.agent_task_adapter.Agent") + def test_creates_fresh_agent_per_call(self, mock_agent_cls): + """Creates a new Agent for each case.""" mock_agent = MagicMock() mock_agent.return_value.__str__ = MagicMock(return_value="result") - factory = MagicMock(return_value=mock_agent) + mock_agent_cls.return_value = mock_agent - task = AgentTask(factory) - task(Case(name="test", input="hello")) + task = AgentTask(model="test-model") + task(Case(name="test1", input="a")) + task(Case(name="test2", input="b")) + + assert mock_agent_cls.call_count == 2 + + @patch("strands_evals.agent_task_adapter.Agent") + def test_passes_kwargs_to_agent(self, mock_agent_cls): + """Passes stored kwargs to Agent constructor.""" + mock_agent = MagicMock() + mock_agent.return_value.__str__ = MagicMock(return_value="result") + mock_agent_cls.return_value = mock_agent + + task = AgentTask(model="test-model", system_prompt="Be helpful") + task(Case(name="test", input="hi")) - factory.assert_called_once() + mock_agent_cls.assert_called_once_with(model="test-model", system_prompt="Be helpful") - def test_calls_agent_with_case_input(self): + @patch("strands_evals.agent_task_adapter.Agent") + def test_calls_agent_with_case_input(self, mock_agent_cls): """Invokes the agent with case.input.""" mock_agent = MagicMock() mock_agent.return_value.__str__ = MagicMock(return_value="answer") - factory = MagicMock(return_value=mock_agent) + mock_agent_cls.return_value = mock_agent - task = AgentTask(factory) + task = AgentTask(model="test-model") task(Case(name="test", input="What is 2+2?")) mock_agent.assert_called_once_with("What is 2+2?") - def test_returns_dict_with_output(self): - """Returns a dict with 'output' key.""" + @patch("strands_evals.agent_task_adapter.Agent") + def test_returns_dict_with_output(self, mock_agent_cls): + """Returns a dict with 'output' key and no trajectory.""" mock_agent = MagicMock() mock_agent.return_value.__str__ = MagicMock(return_value="42") - factory = MagicMock(return_value=mock_agent) + mock_agent_cls.return_value = mock_agent - task = AgentTask(factory) + task = AgentTask(model="test-model") result = task(Case(name="test", input="What is 2+2?")) assert isinstance(result, dict) assert result["output"] == "42" assert "trajectory" not in result - def test_works_with_experiment_run_evaluations(self): + @patch("strands_evals.agent_task_adapter.Agent") + def test_works_with_experiment_run_evaluations(self, mock_agent_cls): """AgentTask works when passed as the task parameter to Experiment.""" from strands_evals.evaluators.evaluator import Evaluator from strands_evals.experiment import Experiment @@ -61,13 +78,13 @@ def evaluate(self, evaluation_case): mock_agent = MagicMock() mock_agent.return_value.__str__ = MagicMock(return_value="output") - factory = MagicMock(return_value=mock_agent) + mock_agent_cls.return_value = mock_agent experiment = Experiment( cases=[Case(name="test", input="hi")], evaluators=[PassingEvaluator()], ) - reports = experiment.run_evaluations(task=AgentTask(factory)) + reports = experiment.run_evaluations(task=AgentTask(model="test-model")) assert len(reports) == 1 assert reports[0].scores[0] == 1.0 @@ -76,13 +93,14 @@ def evaluate(self, evaluation_case): class TestTracedAgentTask: """Tests for TracedAgentTask that adds telemetry collection.""" + @patch("strands_evals.agent_task_adapter.Agent") @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") - def test_calls_agent_with_case_input(self, mock_mapper_cls, mock_telemetry_cls): + def test_calls_agent_with_case_input(self, mock_mapper_cls, mock_telemetry_cls, mock_agent_cls): """Invokes the agent with case.input.""" mock_agent = MagicMock() mock_agent.return_value.__str__ = MagicMock(return_value="answer") - factory = MagicMock(return_value=mock_agent) + mock_agent_cls.return_value = mock_agent mock_telemetry = MagicMock() mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry @@ -90,18 +108,19 @@ def test_calls_agent_with_case_input(self, mock_mapper_cls, mock_telemetry_cls): mock_telemetry_cls.return_value = mock_telemetry mock_mapper_cls.return_value.map_to_session.return_value = MagicMock() - task = TracedAgentTask(factory) + task = TracedAgentTask(model="test-model") task(Case(name="test", input="What is 2+2?")) mock_agent.assert_called_once_with("What is 2+2?") + @patch("strands_evals.agent_task_adapter.Agent") @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") - def test_returns_dict_with_output_and_trajectory(self, mock_mapper_cls, mock_telemetry_cls): + def test_returns_dict_with_output_and_trajectory(self, mock_mapper_cls, mock_telemetry_cls, mock_agent_cls): """Returns a dict with both 'output' and 'trajectory' keys.""" mock_agent = MagicMock() mock_agent.return_value.__str__ = MagicMock(return_value="42") - factory = MagicMock(return_value=mock_agent) + mock_agent_cls.return_value = mock_agent mock_telemetry = MagicMock() mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry @@ -111,19 +130,20 @@ def test_returns_dict_with_output_and_trajectory(self, mock_mapper_cls, mock_tel mock_session = MagicMock() mock_mapper_cls.return_value.map_to_session.return_value = mock_session - task = TracedAgentTask(factory) + task = TracedAgentTask(model="test-model") result = task(Case(name="test", input="What is 2+2?")) assert result["output"] == "42" assert result["trajectory"] is mock_session + @patch("strands_evals.agent_task_adapter.Agent") @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") - def test_collects_spans_and_maps_to_session(self, mock_mapper_cls, mock_telemetry_cls): + def test_collects_spans_and_maps_to_session(self, mock_mapper_cls, mock_telemetry_cls, mock_agent_cls): """Collects spans from telemetry and maps them to a Session.""" mock_agent = MagicMock() mock_agent.return_value.__str__ = MagicMock(return_value="res") - factory = MagicMock(return_value=mock_agent) + mock_agent_cls.return_value = mock_agent mock_spans = [MagicMock(), MagicMock()] mock_telemetry = MagicMock() @@ -136,19 +156,20 @@ def test_collects_spans_and_maps_to_session(self, mock_mapper_cls, mock_telemetr mock_mapper.map_to_session.return_value = mock_session mock_mapper_cls.return_value = mock_mapper - task = TracedAgentTask(factory) + task = TracedAgentTask(model="test-model") result = task(Case(name="test", input="hi", session_id="sess-123")) mock_mapper.map_to_session.assert_called_once_with(list(mock_spans), "sess-123") assert result["trajectory"] is mock_session + @patch("strands_evals.agent_task_adapter.Agent") @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") - def test_clears_spans_between_calls(self, mock_mapper_cls, mock_telemetry_cls): + def test_clears_spans_between_calls(self, mock_mapper_cls, mock_telemetry_cls, mock_agent_cls): """Each call clears the in-memory exporter before running the agent.""" mock_agent = MagicMock() mock_agent.return_value.__str__ = MagicMock(return_value="res") - factory = MagicMock(return_value=mock_agent) + mock_agent_cls.return_value = mock_agent mock_telemetry = MagicMock() mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry @@ -156,7 +177,7 @@ def test_clears_spans_between_calls(self, mock_mapper_cls, mock_telemetry_cls): mock_telemetry_cls.return_value = mock_telemetry mock_mapper_cls.return_value = MagicMock() - task = TracedAgentTask(factory) + task = TracedAgentTask(model="test-model") task(Case(name="test1", input="a")) task(Case(name="test2", input="b")) From d39c59d6bbad38de407e402e5fefa6b38f606184 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Tue, 14 Apr 2026 15:37:58 -0400 Subject: [PATCH 4/7] feat: add @eval_task decorator and handlers for wrapping task functions Introduce `EvalTaskHandler` and `TracedHandler` classes along with an `@eval_task` decorator to standardize how task functions are wrapped with evaluation behavior. `EvalTaskHandler` normalizes return values to dicts, while `TracedHandler` adds OpenTelemetry span collection and session mapping for trajectory-based evaluators. Includes public exports and comprehensive tests. --- src/strands_evals/__init__.py | 6 +- src/strands_evals/agent_task_adapter.py | 58 ------ src/strands_evals/eval_task.py | 102 ++++++++++ .../strands_evals/test_agent_task_adapter.py | 184 ----------------- tests/strands_evals/test_eval_task.py | 185 ++++++++++++++++++ 5 files changed, 290 insertions(+), 245 deletions(-) delete mode 100644 src/strands_evals/agent_task_adapter.py create mode 100644 src/strands_evals/eval_task.py delete mode 100644 tests/strands_evals/test_agent_task_adapter.py create mode 100644 tests/strands_evals/test_eval_task.py diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py index 2339e4f7..2b89af0d 100644 --- a/src/strands_evals/__init__.py +++ b/src/strands_evals/__init__.py @@ -1,6 +1,6 @@ from . import evaluators, extractors, generators, providers, simulation, telemetry, types -from .agent_task_adapter import AgentTask, TracedAgentTask from .case import Case +from .eval_task import EvalTaskHandler, TracedHandler from .evaluation_data_store import EvaluationDataStore from .experiment import Experiment from .local_file_task_result_store import LocalFileTaskResultStore @@ -12,8 +12,8 @@ "Case", "LocalFileTaskResultStore", "EvaluationDataStore", - "AgentTask", - "TracedAgentTask", + "EvalTaskHandler", + "TracedHandler", "evaluators", "extractors", "providers", diff --git a/src/strands_evals/agent_task_adapter.py b/src/strands_evals/agent_task_adapter.py deleted file mode 100644 index ac76eee2..00000000 --- a/src/strands_evals/agent_task_adapter.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Adapts agent configuration into task callables compatible with Experiment.run_evaluations.""" - -from typing import Any - -from strands import Agent - -from .case import Case -from .mappers.strands_in_memory_session_mapper import StrandsInMemorySessionMapper -from .telemetry import StrandsEvalsTelemetry - - -class AgentTask: - """A task that creates a fresh Agent per case and invokes it with case.input. - - Accepts the same keyword arguments as strands.Agent. - - Example: - task = AgentTask(model="us.anthropic.claude-sonnet-4-20250514-v1:0", tools=[calculator]) - experiment.run_evaluations(task=task) - """ - - def __init__(self, **agent_kwargs: Any): - self._agent_kwargs = agent_kwargs - - def _create_agent(self) -> Agent: - return Agent(**self._agent_kwargs) - - def __call__(self, case: Case) -> dict[str, Any]: - agent = self._create_agent() - result = agent(case.input) - return {"output": str(result)} - - -class TracedAgentTask(AgentTask): - """An AgentTask that also collects OpenTelemetry spans and maps them to a Session. - - Use this when your evaluators need trajectory data (e.g., TrajectoryEvaluator). - - Example: - task = TracedAgentTask(model="us.anthropic.claude-sonnet-4-20250514-v1:0", tools=[calculator]) - experiment.run_evaluations(task=task) - """ - - def __init__(self, **agent_kwargs: Any): - super().__init__(**agent_kwargs) - self._telemetry = StrandsEvalsTelemetry().setup_in_memory_exporter() - self._mapper = StrandsInMemorySessionMapper() - - def __call__(self, case: Case) -> dict[str, Any]: - self._telemetry.in_memory_exporter.clear() - - agent = self._create_agent() - result = agent(case.input) - - spans = list(self._telemetry.in_memory_exporter.get_finished_spans()) - session = self._mapper.map_to_session(spans, case.session_id) - - return {"output": str(result), "trajectory": session} diff --git a/src/strands_evals/eval_task.py b/src/strands_evals/eval_task.py new file mode 100644 index 00000000..5918c40a --- /dev/null +++ b/src/strands_evals/eval_task.py @@ -0,0 +1,102 @@ +"""Decorator and handlers for wrapping task functions with evaluation behavior.""" + +import functools +from collections.abc import Callable +from typing import Any + +from .case import Case +from .mappers.strands_in_memory_session_mapper import StrandsInMemorySessionMapper +from .telemetry import StrandsEvalsTelemetry + + +class EvalTaskHandler: + """Base handler that normalizes task function return values. + + Subclass to add behavior before/after task execution (e.g., telemetry collection). + """ + + def before(self, case: Case) -> None: + """Called before the task function runs. Override to add setup logic.""" + pass + + def after(self, case: Case, result: Any) -> dict[str, Any]: + """Called after the task function runs. Normalizes the result to a dict. + + Args: + case: The test case that was executed. + result: The raw return value from the task function. + + Returns: + A dict compatible with Experiment (must have at least "output" key). + """ + if isinstance(result, dict): + return result + return {"output": str(result)} + + +class TracedHandler(EvalTaskHandler): + """Handler that collects OpenTelemetry spans and maps them to a Session. + + Use with @eval_task when your evaluators need trajectory data. + + Args: + mapper: Session mapper to use. Defaults to StrandsInMemorySessionMapper. + + Example: + @eval_task(TracedHandler()) + def my_task(case): + return str(my_agent(case.input)) + + @eval_task(TracedHandler(mapper=MyCustomMapper())) + def my_task(case): + return str(my_agent(case.input)) + """ + + def __init__(self, mapper=None): + self._telemetry = StrandsEvalsTelemetry().setup_in_memory_exporter() + self._mapper = mapper or StrandsInMemorySessionMapper() + + def before(self, case: Case) -> None: + self._telemetry.in_memory_exporter.clear() + + def after(self, case: Case, result: Any) -> dict[str, Any]: + processed = super().after(case, result) + + spans = list(self._telemetry.in_memory_exporter.get_finished_spans()) + session = self._mapper.map_to_session(spans, case.session_id) + processed.setdefault("trajectory", session) + + return processed + + +def eval_task(handler: EvalTaskHandler | None = None) -> Callable: + """Decorator that wraps a task function with evaluation behavior. + + Args: + handler: Handler that runs before/after the task function. + Defaults to EvalTaskHandler (normalizes return values only). + + Example: + # Simple — just normalize output + @eval_task() + def my_task(case): + return str(my_agent(case.input)) + + # With telemetry collection + @eval_task(TracedHandler()) + def my_task(case): + return str(my_agent(case.input)) + """ + if handler is None: + handler = EvalTaskHandler() + + def decorator(fn: Callable[[Case], Any]) -> Callable[[Case], dict[str, Any]]: + @functools.wraps(fn) + def wrapper(case: Case) -> dict[str, Any]: + handler.before(case) + result = fn(case) + return handler.after(case, result) + + return wrapper + + return decorator diff --git a/tests/strands_evals/test_agent_task_adapter.py b/tests/strands_evals/test_agent_task_adapter.py deleted file mode 100644 index 03b94821..00000000 --- a/tests/strands_evals/test_agent_task_adapter.py +++ /dev/null @@ -1,184 +0,0 @@ -"""Tests for AgentTask and TracedAgentTask.""" - -from unittest.mock import MagicMock, call, patch - -from strands_evals.agent_task_adapter import AgentTask, TracedAgentTask -from strands_evals.case import Case - - -class TestAgentTask: - """Tests for the base AgentTask class.""" - - def test_is_callable(self): - """AgentTask instances are callable.""" - task = AgentTask(model="test-model") - assert callable(task) - - @patch("strands_evals.agent_task_adapter.Agent") - def test_creates_fresh_agent_per_call(self, mock_agent_cls): - """Creates a new Agent for each case.""" - mock_agent = MagicMock() - mock_agent.return_value.__str__ = MagicMock(return_value="result") - mock_agent_cls.return_value = mock_agent - - task = AgentTask(model="test-model") - task(Case(name="test1", input="a")) - task(Case(name="test2", input="b")) - - assert mock_agent_cls.call_count == 2 - - @patch("strands_evals.agent_task_adapter.Agent") - def test_passes_kwargs_to_agent(self, mock_agent_cls): - """Passes stored kwargs to Agent constructor.""" - mock_agent = MagicMock() - mock_agent.return_value.__str__ = MagicMock(return_value="result") - mock_agent_cls.return_value = mock_agent - - task = AgentTask(model="test-model", system_prompt="Be helpful") - task(Case(name="test", input="hi")) - - mock_agent_cls.assert_called_once_with(model="test-model", system_prompt="Be helpful") - - @patch("strands_evals.agent_task_adapter.Agent") - def test_calls_agent_with_case_input(self, mock_agent_cls): - """Invokes the agent with case.input.""" - mock_agent = MagicMock() - mock_agent.return_value.__str__ = MagicMock(return_value="answer") - mock_agent_cls.return_value = mock_agent - - task = AgentTask(model="test-model") - task(Case(name="test", input="What is 2+2?")) - - mock_agent.assert_called_once_with("What is 2+2?") - - @patch("strands_evals.agent_task_adapter.Agent") - def test_returns_dict_with_output(self, mock_agent_cls): - """Returns a dict with 'output' key and no trajectory.""" - mock_agent = MagicMock() - mock_agent.return_value.__str__ = MagicMock(return_value="42") - mock_agent_cls.return_value = mock_agent - - task = AgentTask(model="test-model") - result = task(Case(name="test", input="What is 2+2?")) - - assert isinstance(result, dict) - assert result["output"] == "42" - assert "trajectory" not in result - - @patch("strands_evals.agent_task_adapter.Agent") - def test_works_with_experiment_run_evaluations(self, mock_agent_cls): - """AgentTask works when passed as the task parameter to Experiment.""" - from strands_evals.evaluators.evaluator import Evaluator - from strands_evals.experiment import Experiment - from strands_evals.types.evaluation import EvaluationOutput - - class PassingEvaluator(Evaluator): - def evaluate(self, evaluation_case): - return [EvaluationOutput(score=1.0, test_pass=True, reason="pass")] - - mock_agent = MagicMock() - mock_agent.return_value.__str__ = MagicMock(return_value="output") - mock_agent_cls.return_value = mock_agent - - experiment = Experiment( - cases=[Case(name="test", input="hi")], - evaluators=[PassingEvaluator()], - ) - reports = experiment.run_evaluations(task=AgentTask(model="test-model")) - - assert len(reports) == 1 - assert reports[0].scores[0] == 1.0 - - -class TestTracedAgentTask: - """Tests for TracedAgentTask that adds telemetry collection.""" - - @patch("strands_evals.agent_task_adapter.Agent") - @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") - @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") - def test_calls_agent_with_case_input(self, mock_mapper_cls, mock_telemetry_cls, mock_agent_cls): - """Invokes the agent with case.input.""" - mock_agent = MagicMock() - mock_agent.return_value.__str__ = MagicMock(return_value="answer") - mock_agent_cls.return_value = mock_agent - - mock_telemetry = MagicMock() - mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry - mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] - mock_telemetry_cls.return_value = mock_telemetry - mock_mapper_cls.return_value.map_to_session.return_value = MagicMock() - - task = TracedAgentTask(model="test-model") - task(Case(name="test", input="What is 2+2?")) - - mock_agent.assert_called_once_with("What is 2+2?") - - @patch("strands_evals.agent_task_adapter.Agent") - @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") - @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") - def test_returns_dict_with_output_and_trajectory(self, mock_mapper_cls, mock_telemetry_cls, mock_agent_cls): - """Returns a dict with both 'output' and 'trajectory' keys.""" - mock_agent = MagicMock() - mock_agent.return_value.__str__ = MagicMock(return_value="42") - mock_agent_cls.return_value = mock_agent - - mock_telemetry = MagicMock() - mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry - mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] - mock_telemetry_cls.return_value = mock_telemetry - - mock_session = MagicMock() - mock_mapper_cls.return_value.map_to_session.return_value = mock_session - - task = TracedAgentTask(model="test-model") - result = task(Case(name="test", input="What is 2+2?")) - - assert result["output"] == "42" - assert result["trajectory"] is mock_session - - @patch("strands_evals.agent_task_adapter.Agent") - @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") - @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") - def test_collects_spans_and_maps_to_session(self, mock_mapper_cls, mock_telemetry_cls, mock_agent_cls): - """Collects spans from telemetry and maps them to a Session.""" - mock_agent = MagicMock() - mock_agent.return_value.__str__ = MagicMock(return_value="res") - mock_agent_cls.return_value = mock_agent - - mock_spans = [MagicMock(), MagicMock()] - mock_telemetry = MagicMock() - mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry - mock_telemetry.in_memory_exporter.get_finished_spans.return_value = mock_spans - mock_telemetry_cls.return_value = mock_telemetry - - mock_session = MagicMock() - mock_mapper = MagicMock() - mock_mapper.map_to_session.return_value = mock_session - mock_mapper_cls.return_value = mock_mapper - - task = TracedAgentTask(model="test-model") - result = task(Case(name="test", input="hi", session_id="sess-123")) - - mock_mapper.map_to_session.assert_called_once_with(list(mock_spans), "sess-123") - assert result["trajectory"] is mock_session - - @patch("strands_evals.agent_task_adapter.Agent") - @patch("strands_evals.agent_task_adapter.StrandsEvalsTelemetry") - @patch("strands_evals.agent_task_adapter.StrandsInMemorySessionMapper") - def test_clears_spans_between_calls(self, mock_mapper_cls, mock_telemetry_cls, mock_agent_cls): - """Each call clears the in-memory exporter before running the agent.""" - mock_agent = MagicMock() - mock_agent.return_value.__str__ = MagicMock(return_value="res") - mock_agent_cls.return_value = mock_agent - - mock_telemetry = MagicMock() - mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry - mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] - mock_telemetry_cls.return_value = mock_telemetry - mock_mapper_cls.return_value = MagicMock() - - task = TracedAgentTask(model="test-model") - task(Case(name="test1", input="a")) - task(Case(name="test2", input="b")) - - assert mock_telemetry.in_memory_exporter.clear.call_count == 2 diff --git a/tests/strands_evals/test_eval_task.py b/tests/strands_evals/test_eval_task.py new file mode 100644 index 00000000..1d5cacee --- /dev/null +++ b/tests/strands_evals/test_eval_task.py @@ -0,0 +1,185 @@ +"""Tests for @eval_task decorator and handlers.""" + +from unittest.mock import MagicMock, patch + +from strands_evals.case import Case +from strands_evals.eval_task import EvalTaskHandler, TracedHandler, eval_task + + +class TestEvalTaskDecorator: + """Tests for the @eval_task decorator.""" + + def test_decorated_function_is_callable(self): + @eval_task() + def my_task(case): + return "output" + + assert callable(my_task) + + def test_passes_case_to_function(self): + received_case = None + + @eval_task() + def my_task(case): + nonlocal received_case + received_case = case + return "output" + + case = Case(name="test", input="hello") + my_task(case) + assert received_case is case + + def test_string_return_wrapped_as_dict(self): + @eval_task() + def my_task(case): + return "the answer" + + result = my_task(Case(name="test", input="hi")) + assert result == {"output": "the answer"} + + def test_dict_return_passed_through(self): + @eval_task() + def my_task(case): + return {"output": "answer", "custom_key": "value"} + + result = my_task(Case(name="test", input="hi")) + assert result == {"output": "answer", "custom_key": "value"} + + def test_works_without_handler(self): + @eval_task() + def my_task(case): + return "output" + + result = my_task(Case(name="test", input="hi")) + assert result == {"output": "output"} + + def test_works_with_experiment(self): + from strands_evals.evaluators.evaluator import Evaluator + from strands_evals.experiment import Experiment + from strands_evals.types.evaluation import EvaluationOutput + + class PassingEvaluator(Evaluator): + def evaluate(self, evaluation_case): + return [EvaluationOutput(score=1.0, test_pass=True, reason="pass")] + + @eval_task() + def my_task(case): + return "output" + + experiment = Experiment( + cases=[Case(name="test", input="hi")], + evaluators=[PassingEvaluator()], + ) + reports = experiment.run_evaluations(my_task) + assert reports[0].scores[0] == 1.0 + + +class TestEvalTaskHandler: + """Tests for the base EvalTaskHandler.""" + + def test_before_is_noop(self): + handler = EvalTaskHandler() + handler.before(Case(name="test", input="hi")) # should not raise + + def test_after_wraps_string(self): + handler = EvalTaskHandler() + result = handler.after(Case(name="test", input="hi"), "output") + assert result == {"output": "output"} + + def test_after_passes_dict_through(self): + handler = EvalTaskHandler() + result = handler.after(Case(name="test", input="hi"), {"output": "x", "extra": "y"}) + assert result == {"output": "x", "extra": "y"} + + +class TestTracedHandler: + """Tests for TracedHandler that collects telemetry spans.""" + + @patch("strands_evals.eval_task.StrandsEvalsTelemetry") + @patch("strands_evals.eval_task.StrandsInMemorySessionMapper") + def test_clears_spans_on_before(self, mock_mapper_cls, mock_telemetry_cls): + mock_telemetry = MagicMock() + mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry + mock_telemetry_cls.return_value = mock_telemetry + + handler = TracedHandler() + handler.before(Case(name="test", input="hi")) + + mock_telemetry.in_memory_exporter.clear.assert_called_once() + + @patch("strands_evals.eval_task.StrandsEvalsTelemetry") + @patch("strands_evals.eval_task.StrandsInMemorySessionMapper") + def test_after_adds_trajectory(self, mock_mapper_cls, mock_telemetry_cls): + mock_spans = [MagicMock()] + mock_telemetry = MagicMock() + mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry + mock_telemetry.in_memory_exporter.get_finished_spans.return_value = mock_spans + mock_telemetry_cls.return_value = mock_telemetry + + mock_session = MagicMock() + mock_mapper_cls.return_value.map_to_session.return_value = mock_session + + handler = TracedHandler() + case = Case(name="test", input="hi", session_id="sess-1") + result = handler.after(case, "output") + + assert result["output"] == "output" + assert result["trajectory"] is mock_session + mock_mapper_cls.return_value.map_to_session.assert_called_once_with( + list(mock_spans), "sess-1" + ) + + @patch("strands_evals.eval_task.StrandsEvalsTelemetry") + @patch("strands_evals.eval_task.StrandsInMemorySessionMapper") + def test_does_not_override_user_trajectory(self, mock_mapper_cls, mock_telemetry_cls): + mock_telemetry = MagicMock() + mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry + mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] + mock_telemetry_cls.return_value = mock_telemetry + mock_mapper_cls.return_value.map_to_session.return_value = MagicMock() + + handler = TracedHandler() + user_trajectory = ["tool1", "tool2"] + result = handler.after( + Case(name="test", input="hi"), + {"output": "x", "trajectory": user_trajectory}, + ) + + assert result["trajectory"] is user_trajectory + + @patch("strands_evals.eval_task.StrandsEvalsTelemetry") + @patch("strands_evals.eval_task.StrandsInMemorySessionMapper") + def test_accepts_custom_mapper(self, mock_mapper_cls, mock_telemetry_cls): + mock_telemetry = MagicMock() + mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry + mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] + mock_telemetry_cls.return_value = mock_telemetry + + custom_mapper = MagicMock() + custom_session = MagicMock() + custom_mapper.map_to_session.return_value = custom_session + + handler = TracedHandler(mapper=custom_mapper) + result = handler.after(Case(name="test", input="hi", session_id="s1"), "out") + + custom_mapper.map_to_session.assert_called_once() + assert result["trajectory"] is custom_session + + @patch("strands_evals.eval_task.StrandsEvalsTelemetry") + @patch("strands_evals.eval_task.StrandsInMemorySessionMapper") + def test_full_decorator_flow(self, mock_mapper_cls, mock_telemetry_cls): + mock_telemetry = MagicMock() + mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry + mock_telemetry.in_memory_exporter.get_finished_spans.return_value = [] + mock_telemetry_cls.return_value = mock_telemetry + + mock_session = MagicMock() + mock_mapper_cls.return_value.map_to_session.return_value = mock_session + + @eval_task(TracedHandler()) + def my_task(case): + return f"answer to {case.input}" + + result = my_task(Case(name="test", input="question")) + assert result["output"] == "answer to question" + assert result["trajectory"] is mock_session From 272ca4a9011badbdff1495cb15cea1c06aaff2b5 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Tue, 14 Apr 2026 16:38:28 -0400 Subject: [PATCH 5/7] feat(eval_task): Support Agent return and optional case parameter Allow eval_task-decorated functions to either accept a Case argument or take no arguments. When the function returns an Agent instance, the decorator auto-invokes it with case.input and stringifies the result. This simplifies the common pattern where tasks just need to construct and run an agent. --- src/strands_evals/eval_task.py | 42 +++++++++++++------ tests/strands_evals/test_eval_task.py | 58 ++++++++++++++++++++++----- 2 files changed, 76 insertions(+), 24 deletions(-) diff --git a/src/strands_evals/eval_task.py b/src/strands_evals/eval_task.py index 5918c40a..5d4861b0 100644 --- a/src/strands_evals/eval_task.py +++ b/src/strands_evals/eval_task.py @@ -1,9 +1,12 @@ """Decorator and handlers for wrapping task functions with evaluation behavior.""" import functools +import inspect from collections.abc import Callable from typing import Any +from strands import Agent + from .case import Case from .mappers.strands_in_memory_session_mapper import StrandsInMemorySessionMapper from .telemetry import StrandsEvalsTelemetry @@ -44,12 +47,8 @@ class TracedHandler(EvalTaskHandler): Example: @eval_task(TracedHandler()) - def my_task(case): - return str(my_agent(case.input)) - - @eval_task(TracedHandler(mapper=MyCustomMapper())) - def my_task(case): - return str(my_agent(case.input)) + def my_task(): + return Agent(model="...", tools=[calculator]) """ def __init__(self, mapper=None): @@ -69,32 +68,49 @@ def after(self, case: Case, result: Any) -> dict[str, Any]: return processed +def _accepts_case(fn: Callable) -> bool: + """Check if a function accepts a positional argument.""" + sig = inspect.signature(fn) + params = [p for p in sig.parameters.values() if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)] + return len(params) >= 1 + + def eval_task(handler: EvalTaskHandler | None = None) -> Callable: """Decorator that wraps a task function with evaluation behavior. + The decorated function can: + - Take no arguments (simple) or a Case argument (for per-case customization) + - Return an Agent (auto-invoked with case.input), a string, or a dict + Args: handler: Handler that runs before/after the task function. Defaults to EvalTaskHandler (normalizes return values only). Example: - # Simple — just normalize output @eval_task() - def my_task(case): - return str(my_agent(case.input)) + def my_task(): + return Agent(model="...", tools=[calculator]) - # With telemetry collection @eval_task(TracedHandler()) def my_task(case): - return str(my_agent(case.input)) + tools = [calculator] if case.metadata.get("use_calc") else [] + return Agent(model="...", tools=tools) """ if handler is None: handler = EvalTaskHandler() - def decorator(fn: Callable[[Case], Any]) -> Callable[[Case], dict[str, Any]]: + def decorator(fn: Callable) -> Callable[[Case], dict[str, Any]]: + takes_case = _accepts_case(fn) + @functools.wraps(fn) def wrapper(case: Case) -> dict[str, Any]: handler.before(case) - result = fn(case) + + result = fn(case) if takes_case else fn() + + if isinstance(result, Agent): + result = str(result(case.input)) + return handler.after(case, result) return wrapper diff --git a/tests/strands_evals/test_eval_task.py b/tests/strands_evals/test_eval_task.py index 1d5cacee..2b053fb1 100644 --- a/tests/strands_evals/test_eval_task.py +++ b/tests/strands_evals/test_eval_task.py @@ -16,7 +16,7 @@ def my_task(case): assert callable(my_task) - def test_passes_case_to_function(self): + def test_passes_case_to_function_with_case_param(self): received_case = None @eval_task() @@ -29,6 +29,45 @@ def my_task(case): my_task(case) assert received_case is case + def test_no_case_param_function_works(self): + """Functions with no parameters are called without case.""" + @eval_task() + def my_task(): + return "output" + + result = my_task(Case(name="test", input="hi")) + assert result == {"output": "output"} + + def test_agent_return_auto_invoked(self): + """When function returns an Agent, decorator invokes it with case.input.""" + from strands import Agent + + mock_agent = MagicMock(spec=Agent) + mock_agent.return_value.__str__ = MagicMock(return_value="42") + + @eval_task() + def my_task(): + return mock_agent + + result = my_task(Case(name="test", input="What is 2+2?")) + mock_agent.assert_called_once_with("What is 2+2?") + assert result["output"] == "42" + + def test_agent_return_with_case_param(self): + """Agent returned from function with case param is also auto-invoked.""" + from strands import Agent + + mock_agent = MagicMock(spec=Agent) + mock_agent.return_value.__str__ = MagicMock(return_value="answer") + + @eval_task() + def my_task(case): + return mock_agent + + result = my_task(Case(name="test", input="question")) + mock_agent.assert_called_once_with("question") + assert result["output"] == "answer" + def test_string_return_wrapped_as_dict(self): @eval_task() def my_task(case): @@ -45,14 +84,6 @@ def my_task(case): result = my_task(Case(name="test", input="hi")) assert result == {"output": "answer", "custom_key": "value"} - def test_works_without_handler(self): - @eval_task() - def my_task(case): - return "output" - - result = my_task(Case(name="test", input="hi")) - assert result == {"output": "output"} - def test_works_with_experiment(self): from strands_evals.evaluators.evaluator import Evaluator from strands_evals.experiment import Experiment @@ -62,9 +93,14 @@ class PassingEvaluator(Evaluator): def evaluate(self, evaluation_case): return [EvaluationOutput(score=1.0, test_pass=True, reason="pass")] + from strands import Agent + + mock_agent = MagicMock(spec=Agent) + mock_agent.return_value.__str__ = MagicMock(return_value="output") + @eval_task() - def my_task(case): - return "output" + def my_task(): + return mock_agent experiment = Experiment( cases=[Case(name="test", input="hi")], From 384f83026cb808247c1adbf4a4f7fe4838796320 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Wed, 15 Apr 2026 11:20:22 -0400 Subject: [PATCH 6/7] feat: export eval_task decorator from package public API - Add `eval_task` to `__init__.py` imports and `__all__` for direct access - Document TracedHandler concurrency limitations (sequential only or max_workers=1 due to shared span exporter) - Minor test formatting cleanup --- src/strands_evals/__init__.py | 3 ++- src/strands_evals/eval_task.py | 5 +++++ tests/strands_evals/test_eval_task.py | 5 ++--- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py index 2b89af0d..f130646e 100644 --- a/src/strands_evals/__init__.py +++ b/src/strands_evals/__init__.py @@ -1,6 +1,6 @@ from . import evaluators, extractors, generators, providers, simulation, telemetry, types from .case import Case -from .eval_task import EvalTaskHandler, TracedHandler +from .eval_task import EvalTaskHandler, TracedHandler, eval_task from .evaluation_data_store import EvaluationDataStore from .experiment import Experiment from .local_file_task_result_store import LocalFileTaskResultStore @@ -14,6 +14,7 @@ "EvaluationDataStore", "EvalTaskHandler", "TracedHandler", + "eval_task", "evaluators", "extractors", "providers", diff --git a/src/strands_evals/eval_task.py b/src/strands_evals/eval_task.py index 5d4861b0..5a6c6bc0 100644 --- a/src/strands_evals/eval_task.py +++ b/src/strands_evals/eval_task.py @@ -42,6 +42,11 @@ class TracedHandler(EvalTaskHandler): Use with @eval_task when your evaluators need trajectory data. + This handler shares a single span exporter across calls. Use only with + sequential execution (run_evaluations) or run_evaluations_async with + max_workers=1. For concurrent execution, each worker needs its own + TracedHandler instance. + Args: mapper: Session mapper to use. Defaults to StrandsInMemorySessionMapper. diff --git a/tests/strands_evals/test_eval_task.py b/tests/strands_evals/test_eval_task.py index 2b053fb1..5f739fdc 100644 --- a/tests/strands_evals/test_eval_task.py +++ b/tests/strands_evals/test_eval_task.py @@ -31,6 +31,7 @@ def my_task(case): def test_no_case_param_function_works(self): """Functions with no parameters are called without case.""" + @eval_task() def my_task(): return "output" @@ -161,9 +162,7 @@ def test_after_adds_trajectory(self, mock_mapper_cls, mock_telemetry_cls): assert result["output"] == "output" assert result["trajectory"] is mock_session - mock_mapper_cls.return_value.map_to_session.assert_called_once_with( - list(mock_spans), "sess-1" - ) + mock_mapper_cls.return_value.map_to_session.assert_called_once_with(list(mock_spans), "sess-1") @patch("strands_evals.eval_task.StrandsEvalsTelemetry") @patch("strands_evals.eval_task.StrandsInMemorySessionMapper") From d56574e0364a34533b101f138f6018cb6518d3b2 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Tue, 21 Apr 2026 13:18:47 -0400 Subject: [PATCH 7/7] fix: rename eval_task module to avoid shadowing by function import Rename eval_task.py to eval_task_handler.py so the module name no longer collides with the eval_task function re-exported in __init__.py. This was causing @patch targets in TestTracedHandler to resolve to the function instead of the module, breaking all 5 TracedHandler tests with AttributeError. --- src/strands_evals/__init__.py | 2 +- .../{eval_task.py => eval_task_handler.py} | 0 tests/strands_evals/test_eval_task.py | 22 +++++++++---------- 3 files changed, 12 insertions(+), 12 deletions(-) rename src/strands_evals/{eval_task.py => eval_task_handler.py} (100%) diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py index f130646e..3301d2c4 100644 --- a/src/strands_evals/__init__.py +++ b/src/strands_evals/__init__.py @@ -1,6 +1,6 @@ from . import evaluators, extractors, generators, providers, simulation, telemetry, types from .case import Case -from .eval_task import EvalTaskHandler, TracedHandler, eval_task +from .eval_task_handler import EvalTaskHandler, TracedHandler, eval_task from .evaluation_data_store import EvaluationDataStore from .experiment import Experiment from .local_file_task_result_store import LocalFileTaskResultStore diff --git a/src/strands_evals/eval_task.py b/src/strands_evals/eval_task_handler.py similarity index 100% rename from src/strands_evals/eval_task.py rename to src/strands_evals/eval_task_handler.py diff --git a/tests/strands_evals/test_eval_task.py b/tests/strands_evals/test_eval_task.py index 5f739fdc..64871879 100644 --- a/tests/strands_evals/test_eval_task.py +++ b/tests/strands_evals/test_eval_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch from strands_evals.case import Case -from strands_evals.eval_task import EvalTaskHandler, TracedHandler, eval_task +from strands_evals.eval_task_handler import EvalTaskHandler, TracedHandler, eval_task class TestEvalTaskDecorator: @@ -132,8 +132,8 @@ def test_after_passes_dict_through(self): class TestTracedHandler: """Tests for TracedHandler that collects telemetry spans.""" - @patch("strands_evals.eval_task.StrandsEvalsTelemetry") - @patch("strands_evals.eval_task.StrandsInMemorySessionMapper") + @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry") + @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper") def test_clears_spans_on_before(self, mock_mapper_cls, mock_telemetry_cls): mock_telemetry = MagicMock() mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry @@ -144,8 +144,8 @@ def test_clears_spans_on_before(self, mock_mapper_cls, mock_telemetry_cls): mock_telemetry.in_memory_exporter.clear.assert_called_once() - @patch("strands_evals.eval_task.StrandsEvalsTelemetry") - @patch("strands_evals.eval_task.StrandsInMemorySessionMapper") + @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry") + @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper") def test_after_adds_trajectory(self, mock_mapper_cls, mock_telemetry_cls): mock_spans = [MagicMock()] mock_telemetry = MagicMock() @@ -164,8 +164,8 @@ def test_after_adds_trajectory(self, mock_mapper_cls, mock_telemetry_cls): assert result["trajectory"] is mock_session mock_mapper_cls.return_value.map_to_session.assert_called_once_with(list(mock_spans), "sess-1") - @patch("strands_evals.eval_task.StrandsEvalsTelemetry") - @patch("strands_evals.eval_task.StrandsInMemorySessionMapper") + @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry") + @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper") def test_does_not_override_user_trajectory(self, mock_mapper_cls, mock_telemetry_cls): mock_telemetry = MagicMock() mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry @@ -182,8 +182,8 @@ def test_does_not_override_user_trajectory(self, mock_mapper_cls, mock_telemetry assert result["trajectory"] is user_trajectory - @patch("strands_evals.eval_task.StrandsEvalsTelemetry") - @patch("strands_evals.eval_task.StrandsInMemorySessionMapper") + @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry") + @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper") def test_accepts_custom_mapper(self, mock_mapper_cls, mock_telemetry_cls): mock_telemetry = MagicMock() mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry @@ -200,8 +200,8 @@ def test_accepts_custom_mapper(self, mock_mapper_cls, mock_telemetry_cls): custom_mapper.map_to_session.assert_called_once() assert result["trajectory"] is custom_session - @patch("strands_evals.eval_task.StrandsEvalsTelemetry") - @patch("strands_evals.eval_task.StrandsInMemorySessionMapper") + @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry") + @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper") def test_full_decorator_flow(self, mock_mapper_cls, mock_telemetry_cls): mock_telemetry = MagicMock() mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry