diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py
index 7de7520e..3301d2c4 100644
--- a/src/strands_evals/__init__.py
+++ b/src/strands_evals/__init__.py
@@ -1,5 +1,6 @@
 from . import evaluators, extractors, generators, providers, simulation, telemetry, types
 from .case import Case
+from .eval_task_handler import EvalTaskHandler, TracedHandler, eval_task
 from .evaluation_data_store import EvaluationDataStore
 from .experiment import Experiment
 from .local_file_task_result_store import LocalFileTaskResultStore
@@ -11,6 +12,9 @@
     "Case",
     "LocalFileTaskResultStore",
     "EvaluationDataStore",
+    "EvalTaskHandler",
+    "TracedHandler",
+    "eval_task",
     "evaluators",
     "extractors",
     "providers",
diff --git a/src/strands_evals/eval_task_handler.py b/src/strands_evals/eval_task_handler.py
new file mode 100644
index 00000000..5a6c6bc0
--- /dev/null
+++ b/src/strands_evals/eval_task_handler.py
@@ -0,0 +1,158 @@
+"""Decorator and handlers for wrapping task functions with evaluation behavior."""
+
+import functools
+import inspect
+from collections.abc import Callable
+from typing import Any
+
+from strands import Agent
+
+from .case import Case
+from .mappers.strands_in_memory_session_mapper import StrandsInMemorySessionMapper
+from .telemetry import StrandsEvalsTelemetry
+
+
+class EvalTaskHandler:
+    """Base handler that normalizes task function return values.
+
+    Subclass to add behavior before/after task execution (e.g., telemetry collection).
+    """
+
+    def before(self, case: Case) -> None:
+        """Called before the task function runs. Override to add setup logic."""
+        pass
+
+    def after(self, case: Case, result: Any) -> dict[str, Any]:
+        """Called after the task function runs. Normalizes the result to a dict.
+
+        A dict is returned unchanged (same object, no validation); any other value
+        is wrapped as ``{"output": str(result)}``.
+
+        Args:
+            case: The test case that was executed.
+            result: The raw return value from the task function.
+
+        Returns:
+            A dict compatible with Experiment (must have at least "output" key).
+        """
+        if isinstance(result, dict):
+            return result
+        return {"output": str(result)}
+
+
+class TracedHandler(EvalTaskHandler):
+    """Handler that collects OpenTelemetry spans and maps them to a Session.
+
+    Use with @eval_task when your evaluators need trajectory data.
+
+    This handler shares a single span exporter across calls. Use only with
+    sequential execution (run_evaluations) or run_evaluations_async with
+    max_workers=1. For concurrent execution, each worker needs its own
+    TracedHandler instance.
+
+    Args:
+        mapper: Session mapper to use. Defaults to StrandsInMemorySessionMapper.
+
+    Example:
+        @eval_task(TracedHandler())
+        def my_task():
+            return Agent(model="...", tools=[calculator])
+    """
+
+    def __init__(self, mapper: Any | None = None) -> None:
+        self._telemetry = StrandsEvalsTelemetry().setup_in_memory_exporter()
+        # Test against None instead of truthiness so a custom mapper that happens
+        # to be falsy (e.g. one defining __len__ or __bool__) is still honored.
+        self._mapper = mapper if mapper is not None else StrandsInMemorySessionMapper()
+
+    def before(self, case: Case) -> None:
+        """Clear the exporter so that only spans from the upcoming run are collected."""
+        self._telemetry.in_memory_exporter.clear()
+
+    def after(self, case: Case, result: Any) -> dict[str, Any]:
+        """Normalize the result and attach the collected spans as "trajectory".
+
+        A "trajectory" already present in the task's result is kept, and the dict
+        returned by the task function is never mutated.
+
+        Args:
+            case: The test case that was executed; its session_id is passed to the mapper.
+            result: The raw return value from the task function.
+
+        Returns:
+            A new dict holding the normalized result plus a "trajectory" entry.
+        """
+        # Shallow-copy: the task function may still hold this dict (or return the
+        # same object for a later case), so the session must not be written into it.
+        processed = dict(super().after(case, result))
+
+        # Only map spans when the task did not supply its own trajectory.
+        if "trajectory" not in processed:
+            spans = list(self._telemetry.in_memory_exporter.get_finished_spans())
+            processed["trajectory"] = self._mapper.map_to_session(spans, case.session_id)
+
+        return processed
+
+
+def _accepts_case(fn: Callable) -> bool:
+    """Check if a function can receive the case as a positional argument.
+
+    True when the signature declares at least one positional parameter or ``*args``.
+    """
+    positional_kinds = (
+        inspect.Parameter.POSITIONAL_ONLY,
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        inspect.Parameter.VAR_POSITIONAL,
+    )
+    return any(p.kind in positional_kinds for p in inspect.signature(fn).parameters.values())
+
+
+def eval_task(handler: EvalTaskHandler | None = None) -> Callable:
+    """Decorator that wraps a task function with evaluation behavior.
+
+    The decorated function can:
+    - Take no arguments (simple) or a Case argument (for per-case customization)
+    - Return an Agent (auto-invoked with case.input), a string, or a dict
+
+    Args:
+        handler: Handler that runs before/after the task function.
+            Defaults to EvalTaskHandler (normalizes return values only).
+
+    Raises:
+        TypeError: At decoration time, if the task function is a coroutine function.
+
+    Example:
+        @eval_task()
+        def my_task():
+            return Agent(model="...", tools=[calculator])
+
+        @eval_task(TracedHandler())
+        def my_task(case):
+            tools = [calculator] if case.metadata.get("use_calc") else []
+            return Agent(model="...", tools=tools)
+    """
+    if handler is None:
+        handler = EvalTaskHandler()
+
+    def decorator(fn: Callable) -> Callable[[Case], dict[str, Any]]:
+        # The wrapper is synchronous: calling an ``async def`` task would hand back
+        # an un-awaited coroutine whose repr would be recorded as the output.
+        if inspect.iscoroutinefunction(fn):
+            raise TypeError(f"@eval_task does not support coroutine functions: {fn!r}")
+
+        takes_case = _accepts_case(fn)
+
+        @functools.wraps(fn)
+        def wrapper(case: Case) -> dict[str, Any]:
+            handler.before(case)
+
+            result = fn(case) if takes_case else fn()
+
+            if isinstance(result, Agent):
+                result = str(result(case.input))
+
+            return handler.after(case, result)
+
+        return wrapper
+
+    return decorator
diff --git a/tests/strands_evals/test_eval_task.py b/tests/strands_evals/test_eval_task.py
new file mode 100644
index 00000000..64871879
--- /dev/null
+++ b/tests/strands_evals/test_eval_task.py
@@ -0,0 +1,259 @@
+"""Tests for @eval_task decorator and handlers."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from strands_evals.case import Case
+from strands_evals.eval_task_handler import EvalTaskHandler, TracedHandler, eval_task
+
+
+class TestEvalTaskDecorator:
+    """Tests for the @eval_task decorator."""
+
+    def test_decorated_function_is_callable(self):
+        @eval_task()
+        def my_task(case):
+            return "output"
+
+        assert callable(my_task)
+
+    def test_passes_case_to_function_with_case_param(self):
+        received_case = None
+
+        @eval_task()
+        def my_task(case):
+            nonlocal received_case
+            received_case = case
+            return "output"
+
+        case = Case(name="test", input="hello")
+        my_task(case)
+        assert received_case is case
+
+    def test_no_case_param_function_works(self):
+        """Functions with no parameters are called without case."""
+
+        @eval_task()
+        def my_task():
+            return "output"
+
+        result = my_task(Case(name="test", input="hi"))
+        assert result == {"output": "output"}
+
+    def test_agent_return_auto_invoked(self):
+        """When function returns an Agent, decorator invokes it with case.input."""
+        from strands import Agent
+
+        mock_agent = MagicMock(spec=Agent)
+        mock_agent.return_value.__str__ = MagicMock(return_value="42")
+
+        @eval_task()
+        def my_task():
+            return mock_agent
+
+        result = my_task(Case(name="test", input="What is 2+2?"))
+        mock_agent.assert_called_once_with("What is 2+2?")
+        assert result["output"] == "42"
+
+    def test_agent_return_with_case_param(self):
+        """Agent returned from function with case param is also auto-invoked."""
+        from strands import Agent
+
+        mock_agent = MagicMock(spec=Agent)
+        mock_agent.return_value.__str__ = MagicMock(return_value="answer")
+
+        @eval_task()
+        def my_task(case):
+            return mock_agent
+
+        result = my_task(Case(name="test", input="question"))
+        mock_agent.assert_called_once_with("question")
+        assert result["output"] == "answer"
+
+    def test_string_return_wrapped_as_dict(self):
+        @eval_task()
+        def my_task(case):
+            return "the answer"
+
+        result = my_task(Case(name="test", input="hi"))
+        assert result == {"output": "the answer"}
+
+    def test_dict_return_passed_through(self):
+        @eval_task()
+        def my_task(case):
+            return {"output": "answer", "custom_key": "value"}
+
+        result = my_task(Case(name="test", input="hi"))
+        assert result == {"output": "answer", "custom_key": "value"}
+
+    def test_var_positional_function_receives_case(self):
+        """Functions declared with *args still receive the case."""
+        received = []
+
+        @eval_task()
+        def my_task(*args):
+            received.extend(args)
+            return "output"
+
+        case = Case(name="test", input="hi")
+        my_task(case)
+        assert len(received) == 1
+        assert received[0] is case
+
+    def test_coroutine_function_rejected(self):
+        """Async task functions fail fast instead of yielding a stringified coroutine."""
+        with pytest.raises(TypeError):
+            @eval_task()
+            async def my_task(case):
+                return "output"
+
+    def test_works_with_experiment(self):
+        from strands_evals.evaluators.evaluator import Evaluator
+        from strands_evals.experiment import Experiment
+        from strands_evals.types.evaluation import EvaluationOutput
+
+        class PassingEvaluator(Evaluator):
+            def evaluate(self, evaluation_case):
+                return [EvaluationOutput(score=1.0, test_pass=True, reason="pass")]
+
+        from strands import Agent
+
+        mock_agent = MagicMock(spec=Agent)
+        mock_agent.return_value.__str__ = MagicMock(return_value="output")
+
+        @eval_task()
+        def my_task():
+            return mock_agent
+
+        experiment = Experiment(
+            cases=[Case(name="test", input="hi")],
+            evaluators=[PassingEvaluator()],
+        )
+        reports = experiment.run_evaluations(my_task)
+        assert reports[0].scores[0] == 1.0
+
+
+class TestEvalTaskHandler:
+    """Tests for the base EvalTaskHandler."""
+
+    def test_before_is_noop(self):
+        handler = EvalTaskHandler()
+        handler.before(Case(name="test", input="hi"))  # should not raise
+
+    def test_after_wraps_string(self):
+        handler = EvalTaskHandler()
+        result = handler.after(Case(name="test", input="hi"), "output")
+        assert result == {"output": "output"}
+
+    def test_after_passes_dict_through(self):
+        handler = EvalTaskHandler()
+        result = handler.after(Case(name="test", input="hi"), {"output": "x", "extra": "y"})
+        assert result == {"output": "x", "extra": "y"}
+
+
+class TestTracedHandler:
+    """Tests for TracedHandler that collects telemetry spans."""
+
+    @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry")
+    @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper")
+    def test_clears_spans_on_before(self, mock_mapper_cls, mock_telemetry_cls):
+        mock_telemetry = MagicMock()
+        mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry
+        mock_telemetry_cls.return_value = mock_telemetry
+
+        handler = TracedHandler()
+        handler.before(Case(name="test", input="hi"))
+
+        mock_telemetry.in_memory_exporter.clear.assert_called_once()
+
+    @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry")
+    @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper")
+    def test_after_adds_trajectory(self, mock_mapper_cls, mock_telemetry_cls):
+        mock_spans = [MagicMock()]
+        mock_telemetry = MagicMock()
+        mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry
+        mock_telemetry.in_memory_exporter.get_finished_spans.return_value = mock_spans
+        mock_telemetry_cls.return_value = mock_telemetry
+
+        mock_session = MagicMock()
+        mock_mapper_cls.return_value.map_to_session.return_value = mock_session
+
+        handler = TracedHandler()
+        case = Case(name="test", input="hi", session_id="sess-1")
+        result = handler.after(case, "output")
+
+        assert result["output"] == "output"
+        assert result["trajectory"] is mock_session
+        mock_mapper_cls.return_value.map_to_session.assert_called_once_with(list(mock_spans), "sess-1")
+
+    @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry")
+    @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper")
+    def test_does_not_override_user_trajectory(self, mock_mapper_cls, mock_telemetry_cls):
+        mock_telemetry = MagicMock()
+        mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry
+        mock_telemetry.in_memory_exporter.get_finished_spans.return_value = []
+        mock_telemetry_cls.return_value = mock_telemetry
+        mock_mapper_cls.return_value.map_to_session.return_value = MagicMock()
+
+        handler = TracedHandler()
+        user_trajectory = ["tool1", "tool2"]
+        result = handler.after(
+            Case(name="test", input="hi"),
+            {"output": "x", "trajectory": user_trajectory},
+        )
+
+        assert result["trajectory"] is user_trajectory
+        mock_mapper_cls.return_value.map_to_session.assert_not_called()
+
+    @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry")
+    @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper")
+    def test_accepts_custom_mapper(self, mock_mapper_cls, mock_telemetry_cls):
+        mock_telemetry = MagicMock()
+        mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry
+        mock_telemetry.in_memory_exporter.get_finished_spans.return_value = []
+        mock_telemetry_cls.return_value = mock_telemetry
+
+        custom_mapper = MagicMock()
+        custom_session = MagicMock()
+        custom_mapper.map_to_session.return_value = custom_session
+
+        handler = TracedHandler(mapper=custom_mapper)
+        result = handler.after(Case(name="test", input="hi", session_id="s1"), "out")
+
+        custom_mapper.map_to_session.assert_called_once()
+        assert result["trajectory"] is custom_session
+
+    @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry")
+    @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper")
+    def test_full_decorator_flow(self, mock_mapper_cls, mock_telemetry_cls):
+        mock_telemetry = MagicMock()
+        mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry
+        mock_telemetry.in_memory_exporter.get_finished_spans.return_value = []
+        mock_telemetry_cls.return_value = mock_telemetry
+
+        mock_session = MagicMock()
+        mock_mapper_cls.return_value.map_to_session.return_value = mock_session
+
+        @eval_task(TracedHandler())
+        def my_task(case):
+            return f"answer to {case.input}"
+
+        result = my_task(Case(name="test", input="question"))
+        assert result["output"] == "answer to question"
+        assert result["trajectory"] is mock_session
+
+    @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry")
+    @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper")
+    def test_does_not_mutate_task_result(self, mock_mapper_cls, mock_telemetry_cls):
+        mock_telemetry = MagicMock()
+        mock_telemetry.setup_in_memory_exporter.return_value = mock_telemetry
+        mock_telemetry.in_memory_exporter.get_finished_spans.return_value = []
+        mock_telemetry_cls.return_value = mock_telemetry
+
+        handler = TracedHandler()
+        task_result = {"output": "x"}
+        result = handler.after(Case(name="test", input="hi"), task_result)
+
+        assert "trajectory" in result
+        assert task_result == {"output": "x"}