Skip to content
Open
4 changes: 4 additions & 0 deletions src/strands_evals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import evaluators, extractors, generators, providers, simulation, telemetry, types
from .case import Case
from .eval_task_handler import EvalTaskHandler, TracedHandler, eval_task
from .evaluation_data_store import EvaluationDataStore
from .experiment import Experiment
from .local_file_task_result_store import LocalFileTaskResultStore
Expand All @@ -11,6 +12,9 @@
"Case",
"LocalFileTaskResultStore",
"EvaluationDataStore",
"EvalTaskHandler",
"TracedHandler",
"eval_task",
"evaluators",
"extractors",
"providers",
Expand Down
123 changes: 123 additions & 0 deletions src/strands_evals/eval_task_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Decorator and handlers for wrapping task functions with evaluation behavior."""

import functools
import inspect
from collections.abc import Callable
from typing import Any

from strands import Agent

from .case import Case
from .mappers.strands_in_memory_session_mapper import StrandsInMemorySessionMapper
from .telemetry import StrandsEvalsTelemetry


class EvalTaskHandler:
    """Default pair of hooks that run around a task function call.

    ``before`` is invoked ahead of the task function and ``after`` converts
    whatever the task returned into a result dict. Subclass and override either
    hook to add behavior around task execution (e.g., telemetry collection).
    """

    def before(self, case: Case) -> None:
        """Hook invoked before the task function runs; the base version does nothing."""
        return None

    def after(self, case: Case, result: Any) -> dict[str, Any]:
        """Hook invoked after the task function runs; normalizes its return value.

        Args:
            case: The test case that was executed (not used by the base handler).
            result: The raw return value from the task function.

        Returns:
            ``result`` itself (the same object, not validated) when it is already
            a dict; otherwise a new dict ``{"output": str(result)}``. The result is
            meant to be consumed by Experiment, which needs at least an "output" key.
        """
        if not isinstance(result, dict):
            # Non-dict values are stringified and stored under the "output" key.
            return {"output": str(result)}
        return result


class TracedHandler(EvalTaskHandler):
    """Handler that collects OpenTelemetry spans and maps them to a Session.

    Use with @eval_task when your evaluators need trajectory data.

    This handler shares a single span exporter across calls. Use only with
    sequential execution (run_evaluations) or run_evaluations_async with
    max_workers=1. For concurrent execution, each worker needs its own
    TracedHandler instance.

    Args:
        mapper: Session mapper to use. Defaults to StrandsInMemorySessionMapper.

    Example:
        @eval_task(TracedHandler())
        def my_task():
            return Agent(model="...", tools=[calculator])
    """

    def __init__(self, mapper=None):
        self._telemetry = StrandsEvalsTelemetry().setup_in_memory_exporter()
        # Explicit None check: a caller-supplied mapper is never replaced just
        # because it happens to evaluate as falsy.
        self._mapper = StrandsInMemorySessionMapper() if mapper is None else mapper

    def before(self, case: Case) -> None:
        """Clear the in-memory exporter so spans from earlier cases are dropped."""
        self._telemetry.in_memory_exporter.clear()

    def after(self, case: Case, result: Any) -> dict[str, Any]:
        """Normalize the result and attach the collected spans as a trajectory.

        Args:
            case: The test case that was executed; its ``session_id`` is passed
                to the mapper.
            result: The raw return value from the task function.

        Returns:
            A new dict holding the normalized result. If the task did not supply
            its own "trajectory" key, the finished spans are mapped to a session
            and stored under "trajectory"; a user-provided trajectory is kept.
        """
        # Shallow copy: the base handler hands back the task's own dict object, and
        # writing into it would leak this case's trajectory into a dict the task may
        # reuse (later cases would then keep the stale trajectory).
        processed = dict(super().after(case, result))

        if "trajectory" not in processed:
            spans = list(self._telemetry.in_memory_exporter.get_finished_spans())
            processed["trajectory"] = self._mapper.map_to_session(spans, case.session_id)

        return processed


def _accepts_case(fn: Callable) -> bool:
"""Check if a function accepts a positional argument."""
sig = inspect.signature(fn)
params = [p for p in sig.parameters.values() if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
return len(params) >= 1


def eval_task(handler: EvalTaskHandler | None = None) -> Callable:
    """Decorator that wraps a task function with evaluation behavior.

    The decorated function can:
    - Take no arguments (simple) or a Case argument (for per-case customization)
    - Return an Agent (auto-invoked with case.input), a string, or a dict
    - Be a regular function or an ``async def`` coroutine function; for the
      latter the returned wrapper is itself a coroutine function that awaits it

    Args:
        handler: Handler that runs before/after the task function.
            Defaults to EvalTaskHandler (normalizes return values only).

    Returns:
        A decorator that turns the task function into a callable taking a
        single Case and producing the dict returned by ``handler.after``.

    Example:
        @eval_task()
        def my_task():
            return Agent(model="...", tools=[calculator])

        @eval_task(TracedHandler())
        def my_task(case):
            tools = [calculator] if case.metadata.get("use_calc") else []
            return Agent(model="...", tools=tools)
    """
    if handler is None:
        handler = EvalTaskHandler()

    def decorator(fn: Callable) -> Callable[[Case], Any]:
        # Inspect the signature once at decoration time, not on every call.
        takes_case = _accepts_case(fn)

        if inspect.iscoroutinefunction(fn):
            # An async task must be awaited; a plain wrapper would hand the
            # un-awaited coroutine object to handler.after, which stringifies it.

            @functools.wraps(fn)
            async def async_wrapper(case: Case) -> dict[str, Any]:
                handler.before(case)

                result = await (fn(case) if takes_case else fn())

                if isinstance(result, Agent):
                    # Use the Agent's async entry point so the running event
                    # loop is not blocked by a synchronous invocation.
                    result = str(await result.invoke_async(case.input))

                return handler.after(case, result)

            return async_wrapper

        @functools.wraps(fn)
        def wrapper(case: Case) -> dict[str, Any]:
            handler.before(case)

            result = fn(case) if takes_case else fn()

            if isinstance(result, Agent):
                # A returned Agent is invoked with the case input and its
                # response is reduced to text.
                result = str(result(case.input))

            return handler.after(case, result)

        return wrapper

    return decorator
220 changes: 220 additions & 0 deletions tests/strands_evals/test_eval_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""Tests for @eval_task decorator and handlers."""

from unittest.mock import MagicMock, patch

from strands_evals.case import Case
from strands_evals.eval_task_handler import EvalTaskHandler, TracedHandler, eval_task


class TestEvalTaskDecorator:
    """Behavioral checks for functions wrapped with @eval_task."""

    def test_decorated_function_is_callable(self):
        @eval_task()
        def task_fn(case):
            return "output"

        assert callable(task_fn)

    def test_passes_case_to_function_with_case_param(self):
        # A mutable holder stands in for a nonlocal variable.
        captured = {"case": None}

        @eval_task()
        def task_fn(case):
            captured["case"] = case
            return "output"

        given = Case(name="test", input="hello")
        task_fn(given)
        assert captured["case"] is given

    def test_no_case_param_function_works(self):
        """A zero-parameter task is invoked without the case."""

        @eval_task()
        def task_fn():
            return "output"

        outcome = task_fn(Case(name="test", input="hi"))
        assert outcome == {"output": "output"}

    def test_agent_return_auto_invoked(self):
        """A returned Agent is called with case.input by the decorator."""
        from strands import Agent

        agent = MagicMock(spec=Agent)
        agent.return_value.__str__ = MagicMock(return_value="42")

        @eval_task()
        def task_fn():
            return agent

        outcome = task_fn(Case(name="test", input="What is 2+2?"))
        agent.assert_called_once_with("What is 2+2?")
        assert outcome["output"] == "42"

    def test_agent_return_with_case_param(self):
        """The Agent auto-invocation also applies when the task takes a case."""
        from strands import Agent

        agent = MagicMock(spec=Agent)
        agent.return_value.__str__ = MagicMock(return_value="answer")

        @eval_task()
        def task_fn(case):
            return agent

        outcome = task_fn(Case(name="test", input="question"))
        agent.assert_called_once_with("question")
        assert outcome["output"] == "answer"

    def test_string_return_wrapped_as_dict(self):
        @eval_task()
        def task_fn(case):
            return "the answer"

        outcome = task_fn(Case(name="test", input="hi"))
        assert outcome == {"output": "the answer"}

    def test_dict_return_passed_through(self):
        @eval_task()
        def task_fn(case):
            return {"output": "answer", "custom_key": "value"}

        outcome = task_fn(Case(name="test", input="hi"))
        assert outcome == {"output": "answer", "custom_key": "value"}

    def test_works_with_experiment(self):
        from strands import Agent

        from strands_evals.evaluators.evaluator import Evaluator
        from strands_evals.experiment import Experiment
        from strands_evals.types.evaluation import EvaluationOutput

        class PassingEvaluator(Evaluator):
            # Always reports a passing score regardless of the evaluated case.
            def evaluate(self, evaluation_case):
                return [EvaluationOutput(score=1.0, test_pass=True, reason="pass")]

        agent = MagicMock(spec=Agent)
        agent.return_value.__str__ = MagicMock(return_value="output")

        @eval_task()
        def task_fn():
            return agent

        cases = [Case(name="test", input="hi")]
        experiment = Experiment(cases=cases, evaluators=[PassingEvaluator()])
        reports = experiment.run_evaluations(task_fn)
        assert reports[0].scores[0] == 1.0


class TestEvalTaskHandler:
    """Checks for the default hooks of the base EvalTaskHandler."""

    def test_before_is_noop(self):
        # Passing means the call completes without raising.
        EvalTaskHandler().before(Case(name="test", input="hi"))

    def test_after_wraps_string(self):
        case = Case(name="test", input="hi")
        normalized = EvalTaskHandler().after(case, "output")
        assert normalized == {"output": "output"}

    def test_after_passes_dict_through(self):
        case = Case(name="test", input="hi")
        payload = {"output": "x", "extra": "y"}
        normalized = EvalTaskHandler().after(case, payload)
        assert normalized == {"output": "x", "extra": "y"}


class TestTracedHandler:
    """Checks for TracedHandler, which attaches collected spans as a trajectory."""

    @staticmethod
    def _install_telemetry(telemetry_cls, finished_spans=None):
        """Wire the patched telemetry class to return a single chained mock.

        ``setup_in_memory_exporter`` returns the same mock so the handler's
        ``_telemetry`` attribute is that mock. ``finished_spans`` is configured
        on the exporter only when provided.
        """
        telemetry = MagicMock()
        telemetry.setup_in_memory_exporter.return_value = telemetry
        if finished_spans is not None:
            telemetry.in_memory_exporter.get_finished_spans.return_value = finished_spans
        telemetry_cls.return_value = telemetry
        return telemetry

    @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry")
    @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper")
    def test_clears_spans_on_before(self, mock_mapper_cls, mock_telemetry_cls):
        telemetry = self._install_telemetry(mock_telemetry_cls)

        TracedHandler().before(Case(name="test", input="hi"))

        telemetry.in_memory_exporter.clear.assert_called_once()

    @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry")
    @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper")
    def test_after_adds_trajectory(self, mock_mapper_cls, mock_telemetry_cls):
        spans = [MagicMock()]
        self._install_telemetry(mock_telemetry_cls, finished_spans=spans)

        session = MagicMock()
        map_to_session = mock_mapper_cls.return_value.map_to_session
        map_to_session.return_value = session

        traced = TracedHandler()
        outcome = traced.after(Case(name="test", input="hi", session_id="sess-1"), "output")

        assert outcome["output"] == "output"
        assert outcome["trajectory"] is session
        map_to_session.assert_called_once_with(list(spans), "sess-1")

    @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry")
    @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper")
    def test_does_not_override_user_trajectory(self, mock_mapper_cls, mock_telemetry_cls):
        self._install_telemetry(mock_telemetry_cls, finished_spans=[])
        mock_mapper_cls.return_value.map_to_session.return_value = MagicMock()

        supplied = ["tool1", "tool2"]
        task_result = {"output": "x", "trajectory": supplied}
        outcome = TracedHandler().after(Case(name="test", input="hi"), task_result)

        assert outcome["trajectory"] is supplied

    @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry")
    @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper")
    def test_accepts_custom_mapper(self, mock_mapper_cls, mock_telemetry_cls):
        self._install_telemetry(mock_telemetry_cls, finished_spans=[])

        session = MagicMock()
        mapper = MagicMock()
        mapper.map_to_session.return_value = session

        traced = TracedHandler(mapper=mapper)
        outcome = traced.after(Case(name="test", input="hi", session_id="s1"), "out")

        mapper.map_to_session.assert_called_once()
        assert outcome["trajectory"] is session

    @patch("strands_evals.eval_task_handler.StrandsEvalsTelemetry")
    @patch("strands_evals.eval_task_handler.StrandsInMemorySessionMapper")
    def test_full_decorator_flow(self, mock_mapper_cls, mock_telemetry_cls):
        self._install_telemetry(mock_telemetry_cls, finished_spans=[])

        session = MagicMock()
        mock_mapper_cls.return_value.map_to_session.return_value = session

        @eval_task(TracedHandler())
        def task_fn(case):
            return f"answer to {case.input}"

        outcome = task_fn(Case(name="test", input="question"))
        assert outcome["output"] == "answer to question"
        assert outcome["trajectory"] is session
Loading