4545```
4646"""
4747
48+ import asyncio
49+ import inspect
4850import json
4951import os
5052import re
5153from collections import defaultdict
54+ from collections .abc import Callable
5255from dataclasses import dataclass
5356
5457import chevron
5861
5962from .oai import Client , arun_cached_request , get_default_model , run_cached_request
6063from .score import Score
64+ from .thread_utils import (
65+ THREAD_VARIABLE_NAMES ,
66+ compute_thread_template_vars ,
67+ template_uses_thread_variables ,
68+ )
6169
6270# Disable HTML escaping in chevron.
6371chevron .renderer ._html_escape = lambda x : x # type: ignore[attr-defined]
@@ -243,6 +251,9 @@ def _request_args(self, output, expected, **kwargs):
243251
244252 return ret
245253
254+ async def _request_args_async (self , output , expected , ** kwargs ):
255+ return self ._request_args (output , expected , ** kwargs )
256+
246257 def _process_response (self , resp ):
247258 metadata = {}
248259 if "tool_calls" not in resp :
@@ -268,7 +279,9 @@ def _postprocess_response(self, resp):
268279 raise ValueError ("Empty response from OpenAI" )
269280
async def _run_eval_async(self, output, expected, **kwargs):
    """Run one evaluation asynchronously.

    Awaits ``_request_args_async`` to obtain the request keyword arguments,
    awaits ``arun_cached_request`` with them, and passes the response
    through ``_postprocess_response``.

    Args:
        output: Forwarded to ``_request_args_async``.
        expected: Forwarded to ``_request_args_async``.
        **kwargs: Forwarded to ``_request_args_async``.

    Returns:
        Whatever ``_postprocess_response`` returns for the response.
    """
    request_args = await self._request_args_async(output, expected, **kwargs)
    response = await arun_cached_request(**request_args)
    return self._postprocess_response(response)
272285
def _run_eval_sync(self, output, expected, **kwargs):
    """Run one evaluation synchronously.

    Builds the request keyword arguments with ``_request_args``, issues the
    request through ``run_cached_request``, and passes the response through
    ``_postprocess_response``.

    Args:
        output: Forwarded to ``_request_args``.
        expected: Forwarded to ``_request_args``.
        **kwargs: Forwarded to ``_request_args``.

    Returns:
        Whatever ``_postprocess_response`` returns for the response.
    """
    request_args = self._request_args(output, expected, **kwargs)
    response = run_cached_request(**request_args)
    return self._postprocess_response(response)
@@ -330,10 +343,15 @@ class LLMClassifier(OpenAILLMClassifier):
330343 api_key: Deprecated. Use client instead.
331344 base_url: Deprecated. Use client instead.
332345 client: OpenAI client. If not provided, uses global client from init().
346+ trace: Optional trace object for multi-turn scoring. When provided at
347+ evaluation time and the template references thread variables
348+ (`{{thread}}`, `{{thread_count}}`, etc.), thread variables are
349+ derived from `trace.get_thread()` / `trace.getThread()`.
333350 **extra_render_args: Additional template variables
334351 """
335352
336353 _SPEC_FILE_CONTENTS : dict [str , str ] = defaultdict (str )
354+ _thread_variable_names = THREAD_VARIABLE_NAMES
337355
338356 def __init__ (
339357 self ,
@@ -353,6 +371,7 @@ def __init__(
353371 client : Client | None = None ,
354372 ** extra_render_args ,
355373 ):
374+ self ._template_uses_thread_variables = template_uses_thread_variables (prompt_template )
356375 choice_strings = list (choice_scores .keys ())
357376 # Use configured default model if not specified
358377 if model is None :
@@ -384,6 +403,67 @@ def __init__(
384403 client = client ,
385404 )
386405
406+ @staticmethod
407+ def _get_trace_thread_method (trace ) -> Callable [..., object ] | None :
408+ if hasattr (trace , "get_thread" ) and callable (trace .get_thread ):
409+ return trace .get_thread
410+ return None
411+
412+ def _compute_thread_vars_sync (self , trace ) -> dict [str , object ]:
413+ method = self ._get_trace_thread_method (trace )
414+ if method is None :
415+ raise TypeError ("trace must implement async get_thread(options=None)" )
416+
417+ thread_awaitable = method ()
418+ if not inspect .isawaitable (thread_awaitable ):
419+ raise TypeError ("trace.get_thread() must return an awaitable" )
420+ try :
421+ asyncio .get_running_loop ()
422+ except RuntimeError :
423+ thread = asyncio .run (thread_awaitable )
424+ else :
425+ raise RuntimeError ("trace.get_thread() is async; use eval_async() when already inside an event loop" )
426+
427+ if not isinstance (thread , list ):
428+ thread = list (thread )
429+
430+ computed = compute_thread_template_vars (thread )
431+ return {name : computed [name ] for name in self ._thread_variable_names }
432+
433+ async def _compute_thread_vars_async (self , trace ) -> dict [str , object ]:
434+ method = self ._get_trace_thread_method (trace )
435+ if method is None :
436+ raise TypeError ("trace must implement async get_thread(options=None)" )
437+
438+ thread_awaitable = method ()
439+ if not inspect .isawaitable (thread_awaitable ):
440+ raise TypeError ("trace.get_thread() must return an awaitable" )
441+ thread = await thread_awaitable
442+
443+ if not isinstance (thread , list ):
444+ thread = list (thread )
445+
446+ computed = compute_thread_template_vars (thread )
447+ return {name : computed [name ] for name in self ._thread_variable_names }
448+
def _request_args(self, output, expected, **kwargs):
    """Build request arguments, adding thread variables when applicable.

    When a ``trace`` keyword argument is supplied and the prompt template
    was found to use thread variables, the thread variables are computed
    synchronously from the trace and merged with the caller's keyword
    arguments before delegating to the parent implementation.

    Args:
        output: Forwarded to the parent ``_request_args``.
        expected: Forwarded to the parent ``_request_args``.
        **kwargs: Extra render arguments; may include ``trace``. Any key
            that collides with a computed thread variable takes precedence.

    Returns:
        Whatever the parent ``_request_args`` returns.

    Raises:
        TypeError, RuntimeError: Propagated from
            ``_compute_thread_vars_sync``.
    """
    trace = kwargs.get("trace")
    thread_vars: dict[str, object] = {}
    if trace is not None and self._template_uses_thread_variables:
        thread_vars = self._compute_thread_vars_sync(trace)

    # Merge into one mapping first: later entries win, so explicit render
    # args override thread vars. Unpacking both mappings directly into the
    # call would raise TypeError on any duplicate key instead of overriding.
    render_args = {**thread_vars, **kwargs}
    return super()._request_args(output, expected, **render_args)
457+
async def _request_args_async(self, output, expected, **kwargs):
    """Async variant of ``_request_args`` that awaits the trace's thread.

    When a ``trace`` keyword argument is supplied and the prompt template
    was found to use thread variables, the thread variables are awaited
    from the trace and merged with the caller's keyword arguments.

    The parent's synchronous ``_request_args`` is called directly (rather
    than ``self._request_args``) so that this class's synchronous override,
    which resolves the thread with a blocking call, is not invoked from
    async code.

    Args:
        output: Forwarded to the parent ``_request_args``.
        expected: Forwarded to the parent ``_request_args``.
        **kwargs: Extra render arguments; may include ``trace``. Any key
            that collides with a computed thread variable takes precedence.

    Returns:
        Whatever the parent ``_request_args`` returns.

    Raises:
        TypeError: Propagated from ``_compute_thread_vars_async``.
    """
    trace = kwargs.get("trace")
    thread_vars: dict[str, object] = {}
    if trace is not None and self._template_uses_thread_variables:
        thread_vars = await self._compute_thread_vars_async(trace)

    # Merge into one mapping first: later entries win, so explicit render
    # args override thread vars. Unpacking both mappings directly into the
    # call would raise TypeError on any duplicate key instead of overriding.
    render_args = {**thread_vars, **kwargs}
    return super()._request_args(output, expected, **render_args)
466+
387467 @classmethod
388468 def from_spec (cls , name : str , spec : ModelGradedSpec , client : Client | None = None , ** kwargs ):
389469 spec_kwargs = {}
0 commit comments