Skip to content

Commit ec903da

Browse files
committed
fix(dspy): add adapter callback tracing
Implement DSPy adapter format and parse callbacks in BraintrustDSpyCallback so adapter-level prompt formatting and output parsing are traced alongside module and LM spans. Add focused tests for the new adapter spans and extend the end-to-end DSPy integration test to assert the new callback coverage. Closes #176
1 parent a656bab commit ec903da

File tree

2 files changed

+165
-58
lines changed

2 files changed

+165
-58
lines changed

py/src/braintrust/integrations/dspy/test_dspy.py

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,28 +38,73 @@ def test_dspy_callback(memory_logger):
3838

3939
# Check logged spans
4040
spans = memory_logger.pop()
41-
assert len(spans) >= 2 # Should have module span and LM span
41+
assert len(spans) >= 4 # Should have module, adapter format, LM, and adapter parse spans
4242

43-
# Find LM span by checking span_attributes
44-
lm_spans = [s for s in spans if s.get("span_attributes", {}).get("name") == "dspy.lm"]
45-
assert len(lm_spans) >= 1
43+
spans_by_name = {span["span_attributes"]["name"]: span for span in spans}
4644

47-
lm_span = lm_spans[0]
48-
# Verify metadata
45+
lm_span = spans_by_name["dspy.lm"]
4946
assert "metadata" in lm_span
5047
assert "model" in lm_span["metadata"]
5148
assert MODEL in lm_span["metadata"]["model"]
52-
53-
# Verify input/output
5449
assert "input" in lm_span
5550
assert "output" in lm_span
5651

57-
# Find module span
58-
module_spans = [s for s in spans if "module" in s.get("span_attributes", {}).get("name", "")]
59-
assert len(module_spans) >= 1
52+
format_span = spans_by_name["dspy.adapter.format"]
53+
parse_span = spans_by_name["dspy.adapter.parse"]
54+
55+
assert format_span["metadata"]["adapter_class"].endswith("ChatAdapter")
56+
assert "signature" in format_span["input"]
57+
assert "demos" in format_span["input"]
58+
assert "inputs" in format_span["input"]
59+
assert isinstance(format_span["output"], list)
60+
61+
assert parse_span["metadata"]["adapter_class"].endswith("ChatAdapter")
62+
assert "signature" in parse_span["input"]
63+
assert "completion" in parse_span["input"]
64+
assert isinstance(parse_span["output"], dict)
65+
66+
# Verify spans are nested under the broader DSPy execution
67+
span_ids = {span["span_id"] for span in spans}
68+
assert lm_span.get("span_parents")
69+
assert format_span.get("span_parents")
70+
assert parse_span.get("span_parents")
71+
assert format_span["span_parents"][0] in span_ids
72+
assert lm_span["span_parents"][0] in span_ids
73+
assert parse_span["span_parents"][0] in span_ids
74+
75+
76+
def test_dspy_adapter_callbacks(memory_logger):
77+
"""Adapter format/parse callbacks should log spans without an LM call."""
78+
assert not memory_logger.pop()
79+
80+
dspy.configure(callbacks=[BraintrustDSpyCallback()])
81+
82+
signature = dspy.Signature("question -> answer")
83+
adapter = dspy.ChatAdapter()
84+
formatted = adapter.format(
85+
signature,
86+
demos=[{"question": "1+1", "answer": "2"}],
87+
inputs={"question": "2+2"},
88+
)
89+
parsed = adapter.parse(signature, "[[ ## answer ## ]]\n4")
90+
91+
assert formatted
92+
assert parsed == {"answer": "4"}
93+
94+
spans = memory_logger.pop()
95+
assert len(spans) == 2
96+
97+
spans_by_name = {span["span_attributes"]["name"]: span for span in spans}
98+
format_span = spans_by_name["dspy.adapter.format"]
99+
parse_span = spans_by_name["dspy.adapter.parse"]
100+
101+
assert format_span["metadata"]["adapter_class"].endswith("ChatAdapter")
102+
assert format_span["input"]["inputs"] == {"question": "2+2"}
103+
assert format_span["output"] == formatted
60104

61-
# Verify span parenting (LM span should have parent)
62-
assert lm_span.get("span_parents") # LM span should have parent
105+
assert parse_span["metadata"]["adapter_class"].endswith("ChatAdapter")
106+
assert parse_span["input"]["completion"] == "[[ ## answer ## ]]\n4"
107+
assert parse_span["output"] == parsed
63108

64109

65110
class TestPatchDSPy:

py/src/braintrust/integrations/dspy/tracing.py

Lines changed: 107 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class BraintrustDSpyCallback(BaseCallback):
6868
6969
The callback creates Braintrust spans for:
7070
- DSPy module executions (Predict, ChainOfThought, ReAct, etc.)
71+
- Adapter formatting and parsing steps
7172
- LLM calls with latency metrics
7273
- Tool calls
7374
- Evaluation runs
@@ -121,19 +122,13 @@ def on_lm_start(
121122
span.set_current()
122123
self._spans[call_id] = span
123124

124-
def on_lm_end(
125+
def _end_span(
125126
self,
126127
call_id: str,
127-
outputs: dict[str, Any] | None,
128+
outputs: Any | None,
128129
exception: Exception | None = None,
129130
):
130-
"""Log the end of a language model call.
131-
132-
Args:
133-
call_id: Unique identifier for this call
134-
outputs: Output from the LM, or None if there was an exception
135-
exception: Exception raised during execution, if any
136-
"""
131+
"""Pop span by call_id, log outputs/exception, and end it."""
137132
span = self._spans.pop(call_id, None)
138133
if not span:
139134
return
@@ -151,6 +146,21 @@ def on_lm_end(
151146
span.unset_current()
152147
span.end()
153148

149+
def on_lm_end(
150+
self,
151+
call_id: str,
152+
outputs: dict[str, Any] | None,
153+
exception: Exception | None = None,
154+
):
155+
"""Log the end of a language model call.
156+
157+
Args:
158+
call_id: Unique identifier for this call
159+
outputs: Output from the LM, or None if there was an exception
160+
exception: Exception raised during execution, if any
161+
"""
162+
self._end_span(call_id, outputs, exception)
163+
154164
def on_module_start(
155165
self,
156166
call_id: str,
@@ -193,28 +203,95 @@ def on_module_end(
193203
outputs: Output from the module, or None if there was an exception
194204
exception: Exception raised during execution, if any
195205
"""
196-
span = self._spans.pop(call_id, None)
197-
if not span:
198-
return
206+
if outputs is not None:
207+
if hasattr(outputs, "toDict"):
208+
outputs = outputs.toDict()
209+
elif hasattr(outputs, "__dict__"):
210+
outputs = outputs.__dict__
211+
self._end_span(call_id, outputs, exception)
212+
213+
def _start_adapter_span(
214+
self,
215+
call_id: str,
216+
instance: Any,
217+
inputs: dict[str, Any],
218+
span_name: str,
219+
):
220+
"""Create and store a span for an adapter format/parse call."""
221+
cls = instance.__class__
222+
metadata = {"adapter_class": f"{cls.__module__}.{cls.__name__}"}
199223

200-
try:
201-
log_data = {}
202-
if exception:
203-
log_data["error"] = exception
204-
if outputs is not None:
205-
if hasattr(outputs, "toDict"):
206-
output_dict = outputs.toDict()
207-
elif hasattr(outputs, "__dict__"):
208-
output_dict = outputs.__dict__
209-
else:
210-
output_dict = outputs
211-
log_data["output"] = output_dict
224+
parent = current_span()
225+
parent_export = parent.export() if parent else None
212226

213-
if log_data:
214-
span.log(**log_data)
215-
finally:
216-
span.unset_current()
217-
span.end()
227+
span = start_span(
228+
name=span_name,
229+
input=inputs,
230+
metadata=metadata,
231+
parent=parent_export,
232+
)
233+
span.set_current()
234+
self._spans[call_id] = span
235+
236+
def on_adapter_format_start(
237+
self,
238+
call_id: str,
239+
instance: Any,
240+
inputs: dict[str, Any],
241+
):
242+
"""Log the start of an adapter format call.
243+
244+
Args:
245+
call_id: Unique identifier for this call
246+
instance: The Adapter instance being called
247+
inputs: Input parameters to the adapter's format() method
248+
"""
249+
self._start_adapter_span(call_id, instance, inputs, "dspy.adapter.format")
250+
251+
def on_adapter_format_end(
252+
self,
253+
call_id: str,
254+
outputs: list[dict[str, Any]] | None,
255+
exception: Exception | None = None,
256+
):
257+
"""Log the end of an adapter format call.
258+
259+
Args:
260+
call_id: Unique identifier for this call
261+
outputs: Output from the adapter's format() method, or None if there was an exception
262+
exception: Exception raised during execution, if any
263+
"""
264+
self._end_span(call_id, outputs, exception)
265+
266+
def on_adapter_parse_start(
267+
self,
268+
call_id: str,
269+
instance: Any,
270+
inputs: dict[str, Any],
271+
):
272+
"""Log the start of an adapter parse call.
273+
274+
Args:
275+
call_id: Unique identifier for this call
276+
instance: The Adapter instance being called
277+
inputs: Input parameters to the adapter's parse() method
278+
"""
279+
self._start_adapter_span(call_id, instance, inputs, "dspy.adapter.parse")
280+
281+
def on_adapter_parse_end(
282+
self,
283+
call_id: str,
284+
outputs: dict[str, Any] | None,
285+
exception: Exception | None = None,
286+
):
287+
"""Log the end of an adapter parse call.
288+
289+
Args:
290+
call_id: Unique identifier for this call
291+
outputs: Output from the adapter's parse() method, or None if there was an exception
292+
exception: Exception raised during execution, if any
293+
"""
294+
self._end_span(call_id, outputs, exception)
218295

219296
def on_tool_start(
220297
self,
@@ -262,22 +339,7 @@ def on_tool_end(
262339
outputs: Output from the tool, or None if there was an exception
263340
exception: Exception raised during execution, if any
264341
"""
265-
span = self._spans.pop(call_id, None)
266-
if not span:
267-
return
268-
269-
try:
270-
log_data = {}
271-
if exception:
272-
log_data["error"] = exception
273-
if outputs is not None:
274-
log_data["output"] = outputs
275-
276-
if log_data:
277-
span.log(**log_data)
278-
finally:
279-
span.unset_current()
280-
span.end()
342+
self._end_span(call_id, outputs, exception)
281343

282344
def on_evaluate_start(
283345
self,

0 commit comments

Comments
 (0)