Skip to content

Commit 0d428fb

Browse files
CLowbrow and Alex Z authored
Trace injection in python to mirror the JS implementation (#175)
This is a direct port of: #168 --------- Co-authored-by: Alex Z <alex.zelenskiy@braintrustdata.com>
1 parent d99a37c commit 0d428fb

4 files changed

Lines changed: 520 additions & 1 deletion

File tree

py/autoevals/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,5 @@ async def evaluate_qa():
134134
from .ragas import *
135135
from .score import Score, Scorer, SerializableDataClass
136136
from .string import *
137+
from .thread_utils import *
137138
from .value import ExactMatch

py/autoevals/llm.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,13 @@
4545
```
4646
"""
4747

48+
import asyncio
49+
import inspect
4850
import json
4951
import os
5052
import re
5153
from collections import defaultdict
54+
from collections.abc import Callable
5255
from dataclasses import dataclass
5356

5457
import chevron
@@ -58,6 +61,11 @@
5861

5962
from .oai import Client, arun_cached_request, get_default_model, run_cached_request
6063
from .score import Score
64+
from .thread_utils import (
65+
THREAD_VARIABLE_NAMES,
66+
compute_thread_template_vars,
67+
template_uses_thread_variables,
68+
)
6169

6270
# Disable HTML escaping in chevron.
6371
chevron.renderer._html_escape = lambda x: x # type: ignore[attr-defined]
@@ -243,6 +251,9 @@ def _request_args(self, output, expected, **kwargs):
243251

244252
return ret
245253

254+
async def _request_args_async(self, output, expected, **kwargs):
255+
return self._request_args(output, expected, **kwargs)
256+
246257
def _process_response(self, resp):
247258
metadata = {}
248259
if "tool_calls" not in resp:
@@ -268,7 +279,9 @@ def _postprocess_response(self, resp):
268279
raise ValueError("Empty response from OpenAI")
269280

270281
async def _run_eval_async(self, output, expected, **kwargs):
271-
return self._postprocess_response(await arun_cached_request(**self._request_args(output, expected, **kwargs)))
282+
return self._postprocess_response(
283+
await arun_cached_request(**(await self._request_args_async(output, expected, **kwargs)))
284+
)
272285

273286
def _run_eval_sync(self, output, expected, **kwargs):
274287
return self._postprocess_response(run_cached_request(**self._request_args(output, expected, **kwargs)))
@@ -330,10 +343,15 @@ class LLMClassifier(OpenAILLMClassifier):
330343
api_key: Deprecated. Use client instead.
331344
base_url: Deprecated. Use client instead.
332345
client: OpenAI client. If not provided, uses global client from init().
346+
trace: Optional trace object for multi-turn scoring. When provided at
347+
evaluation time and the template references thread variables
348+
(`{{thread}}`, `{{thread_count}}`, etc.), thread variables are
349+
derived from `trace.get_thread()` / `trace.getThread()`.
333350
**extra_render_args: Additional template variables
334351
"""
335352

336353
_SPEC_FILE_CONTENTS: dict[str, str] = defaultdict(str)
354+
_thread_variable_names = THREAD_VARIABLE_NAMES
337355

338356
def __init__(
339357
self,
@@ -353,6 +371,7 @@ def __init__(
353371
client: Client | None = None,
354372
**extra_render_args,
355373
):
374+
self._template_uses_thread_variables = template_uses_thread_variables(prompt_template)
356375
choice_strings = list(choice_scores.keys())
357376
# Use configured default model if not specified
358377
if model is None:
@@ -384,6 +403,67 @@ def __init__(
384403
client=client,
385404
)
386405

406+
@staticmethod
407+
def _get_trace_thread_method(trace) -> Callable[..., object] | None:
408+
if hasattr(trace, "get_thread") and callable(trace.get_thread):
409+
return trace.get_thread
410+
return None
411+
412+
def _compute_thread_vars_sync(self, trace) -> dict[str, object]:
413+
method = self._get_trace_thread_method(trace)
414+
if method is None:
415+
raise TypeError("trace must implement async get_thread(options=None)")
416+
417+
thread_awaitable = method()
418+
if not inspect.isawaitable(thread_awaitable):
419+
raise TypeError("trace.get_thread() must return an awaitable")
420+
try:
421+
asyncio.get_running_loop()
422+
except RuntimeError:
423+
thread = asyncio.run(thread_awaitable)
424+
else:
425+
raise RuntimeError("trace.get_thread() is async; use eval_async() when already inside an event loop")
426+
427+
if not isinstance(thread, list):
428+
thread = list(thread)
429+
430+
computed = compute_thread_template_vars(thread)
431+
return {name: computed[name] for name in self._thread_variable_names}
432+
433+
async def _compute_thread_vars_async(self, trace) -> dict[str, object]:
434+
method = self._get_trace_thread_method(trace)
435+
if method is None:
436+
raise TypeError("trace must implement async get_thread(options=None)")
437+
438+
thread_awaitable = method()
439+
if not inspect.isawaitable(thread_awaitable):
440+
raise TypeError("trace.get_thread() must return an awaitable")
441+
thread = await thread_awaitable
442+
443+
if not isinstance(thread, list):
444+
thread = list(thread)
445+
446+
computed = compute_thread_template_vars(thread)
447+
return {name: computed[name] for name in self._thread_variable_names}
448+
449+
def _request_args(self, output, expected, **kwargs):
450+
trace = kwargs.get("trace")
451+
thread_vars: dict[str, object] = {}
452+
if trace is not None and self._template_uses_thread_variables:
453+
thread_vars = self._compute_thread_vars_sync(trace)
454+
455+
# Thread vars come first so explicit render args can override.
456+
return super()._request_args(output, expected, **thread_vars, **kwargs)
457+
458+
async def _request_args_async(self, output, expected, **kwargs):
459+
trace = kwargs.get("trace")
460+
thread_vars: dict[str, object] = {}
461+
if trace is not None and self._template_uses_thread_variables:
462+
thread_vars = await self._compute_thread_vars_async(trace)
463+
464+
# Thread vars come first so explicit render args can override.
465+
return super()._request_args(output, expected, **thread_vars, **kwargs)
466+
387467
@classmethod
388468
def from_spec(cls, name: str, spec: ModelGradedSpec, client: Client | None = None, **kwargs):
389469
spec_kwargs = {}

py/autoevals/test_llm.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from autoevals import init
1212
from autoevals.llm import Battle, Factuality, LLMClassifier, OpenAILLMClassifier, build_classification_tools
1313
from autoevals.oai import OpenAIV1Module, get_default_model
14+
from autoevals.thread_utils import compute_thread_template_vars
1415

1516

1617
class TestModel(BaseModel):
@@ -54,6 +55,47 @@ def test_render_messages():
5455
assert rendered[5]["content"] == ""
5556

5657

58+
def test_render_messages_with_thread_variables():
59+
classifier = OpenAILLMClassifier(
60+
"test",
61+
messages=[
62+
{"role": "user", "content": "{{thread}}"},
63+
{"role": "user", "content": "First message: {{thread.0}}"},
64+
{"role": "user", "content": "Count: {{thread_count}}"},
65+
{"role": "user", "content": "First: {{first_message}}"},
66+
{"role": "user", "content": "Users: {{user_messages}}"},
67+
{"role": "user", "content": "Pairs: {{human_ai_pairs}}"},
68+
{
69+
"role": "user",
70+
"content": "Messages:{{#thread}}\n- {{role}}: {{content}}{{/thread}}",
71+
},
72+
],
73+
model="gpt-4",
74+
choice_scores={"A": 1},
75+
classification_tools=[],
76+
)
77+
78+
sample_thread = [
79+
{"role": "user", "content": "Hello, how are you?"},
80+
{"role": "assistant", "content": "I am doing well, thank you!"},
81+
{"role": "user", "content": "What is the weather like?"},
82+
{"role": "assistant", "content": "It is sunny and warm today."},
83+
]
84+
thread_vars = compute_thread_template_vars(sample_thread)
85+
rendered = classifier._render_messages(**thread_vars)
86+
87+
assert "User:" in rendered[0]["content"]
88+
assert "Hello, how are you?" in rendered[0]["content"]
89+
assert "Assistant:" in rendered[0]["content"]
90+
assert rendered[1]["content"] == "First message: user: Hello, how are you?"
91+
assert rendered[2]["content"] == "Count: 4"
92+
assert rendered[3]["content"] == "First: user: Hello, how are you?"
93+
assert "Users: User:" in rendered[4]["content"]
94+
assert "Pairs:" in rendered[5]["content"]
95+
assert "human" in rendered[5]["content"]
96+
assert rendered[6]["content"].startswith("Messages:\n- user: Hello, how are you?")
97+
98+
5799
def test_openai():
58100
e = OpenAILLMClassifier(
59101
"title",
@@ -547,3 +589,130 @@ def capture_model(request):
547589

548590
# Reset for other tests
549591
init(None)
592+
593+
594+
@respx.mock
595+
def test_llm_classifier_injects_thread_vars_from_trace():
596+
captured_request_body = None
597+
598+
class TraceStub:
599+
def __init__(self, thread):
600+
self.thread = thread
601+
self.calls = 0
602+
603+
async def get_thread(self):
604+
self.calls += 1
605+
return self.thread
606+
607+
thread = [
608+
{"role": "user", "content": "Hello"},
609+
{"role": "assistant", "content": "Hi there"},
610+
{"role": "user", "content": "Can you help me?"},
611+
]
612+
trace = TraceStub(thread)
613+
614+
def capture_request(request):
615+
nonlocal captured_request_body
616+
captured_request_body = json.loads(request.content.decode("utf-8"))
617+
return Response(
618+
200,
619+
json={
620+
"id": "chatcmpl-test",
621+
"object": "chat.completion",
622+
"created": 1234567890,
623+
"model": "gpt-4o",
624+
"choices": [
625+
{
626+
"index": 0,
627+
"message": {
628+
"role": "assistant",
629+
"content": None,
630+
"tool_calls": [
631+
{
632+
"id": "call_test",
633+
"type": "function",
634+
"function": {"name": "select_choice", "arguments": '{"choice": "1"}'},
635+
}
636+
],
637+
},
638+
"finish_reason": "tool_calls",
639+
}
640+
],
641+
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
642+
},
643+
)
644+
645+
respx.post("https://api.openai.com/v1/chat/completions").mock(side_effect=capture_request)
646+
client = OpenAI(api_key="test-api-key", base_url="https://api.openai.com/v1")
647+
init(client)
648+
649+
classifier = LLMClassifier(
650+
"thread_test",
651+
"Thread:\n{{thread}}\nCount: {{thread_count}}\nFirst: {{first_message}}\nUsers:\n{{user_messages}}",
652+
{"1": 1, "2": 0},
653+
)
654+
classifier.eval(output="irrelevant", expected="irrelevant", trace=trace)
655+
656+
content = captured_request_body["messages"][0]["content"]
657+
assert trace.calls == 1
658+
assert "Thread:" in content
659+
assert "User:" in content
660+
assert "Assistant:" in content
661+
assert "Count: 3" in content
662+
assert "First: user: Hello" in content
663+
assert "Users:" in content
664+
665+
666+
@respx.mock
667+
def test_llm_classifier_does_not_fetch_thread_when_template_does_not_use_it():
668+
class TraceStub:
669+
def __init__(self):
670+
self.calls = 0
671+
672+
async def get_thread(self):
673+
self.calls += 1
674+
return [{"role": "user", "content": "unused"}]
675+
676+
trace = TraceStub()
677+
678+
respx.post("https://api.openai.com/v1/chat/completions").mock(
679+
return_value=Response(
680+
200,
681+
json={
682+
"id": "chatcmpl-test",
683+
"object": "chat.completion",
684+
"created": 1234567890,
685+
"model": "gpt-4o",
686+
"choices": [
687+
{
688+
"index": 0,
689+
"message": {
690+
"role": "assistant",
691+
"content": None,
692+
"tool_calls": [
693+
{
694+
"id": "call_test",
695+
"type": "function",
696+
"function": {"name": "select_choice", "arguments": '{"choice": "1"}'},
697+
}
698+
],
699+
},
700+
"finish_reason": "tool_calls",
701+
}
702+
],
703+
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
704+
},
705+
)
706+
)
707+
708+
client = OpenAI(api_key="test-api-key", base_url="https://api.openai.com/v1")
709+
init(client)
710+
711+
classifier = LLMClassifier(
712+
"thread_unused_test",
713+
"Output: {{output}}",
714+
{"1": 1, "2": 0},
715+
)
716+
classifier.eval(output="x", expected="y", trace=trace)
717+
718+
assert trace.calls == 0

0 commit comments

Comments (0)