Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 58 additions & 13 deletions py/src/braintrust/integrations/dspy/test_dspy.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,73 @@ def test_dspy_callback(memory_logger):

# Check logged spans
spans = memory_logger.pop()
assert len(spans) >= 2 # Should have module span and LM span
assert len(spans) >= 4 # Should have module, adapter format, LM, and adapter parse spans

# Find LM span by checking span_attributes
lm_spans = [s for s in spans if s.get("span_attributes", {}).get("name") == "dspy.lm"]
assert len(lm_spans) >= 1
spans_by_name = {span["span_attributes"]["name"]: span for span in spans}

lm_span = lm_spans[0]
# Verify metadata
lm_span = spans_by_name["dspy.lm"]
assert "metadata" in lm_span
assert "model" in lm_span["metadata"]
assert MODEL in lm_span["metadata"]["model"]

# Verify input/output
assert "input" in lm_span
assert "output" in lm_span

# Find module span
module_spans = [s for s in spans if "module" in s.get("span_attributes", {}).get("name", "")]
assert len(module_spans) >= 1
format_span = spans_by_name["dspy.adapter.format"]
parse_span = spans_by_name["dspy.adapter.parse"]

assert format_span["metadata"]["adapter_class"].endswith("ChatAdapter")
assert "signature" in format_span["input"]
assert "demos" in format_span["input"]
assert "inputs" in format_span["input"]
assert isinstance(format_span["output"], list)

assert parse_span["metadata"]["adapter_class"].endswith("ChatAdapter")
assert "signature" in parse_span["input"]
assert "completion" in parse_span["input"]
assert isinstance(parse_span["output"], dict)

# Verify spans are nested under the broader DSPy execution
span_ids = {span["span_id"] for span in spans}
assert lm_span.get("span_parents")
assert format_span.get("span_parents")
assert parse_span.get("span_parents")
assert format_span["span_parents"][0] in span_ids
assert lm_span["span_parents"][0] in span_ids
assert parse_span["span_parents"][0] in span_ids


def test_dspy_adapter_callbacks(memory_logger):
"""Adapter format/parse callbacks should log spans without an LM call."""
assert not memory_logger.pop()

dspy.configure(callbacks=[BraintrustDSpyCallback()])

signature = dspy.make_signature("question -> answer")
adapter = dspy.ChatAdapter()
formatted = adapter.format(
signature,
demos=[{"question": "1+1", "answer": "2"}],
inputs={"question": "2+2"},
)
parsed = adapter.parse(signature, "[[ ## answer ## ]]\n4")

assert formatted
assert parsed == {"answer": "4"}

spans = memory_logger.pop()
assert len(spans) == 2

spans_by_name = {span["span_attributes"]["name"]: span for span in spans}
format_span = spans_by_name["dspy.adapter.format"]
parse_span = spans_by_name["dspy.adapter.parse"]

assert format_span["metadata"]["adapter_class"].endswith("ChatAdapter")
assert format_span["input"]["inputs"] == {"question": "2+2"}
assert format_span["output"] == formatted

# Verify span parenting (LM span should have parent)
assert lm_span.get("span_parents") # LM span should have parent
assert parse_span["metadata"]["adapter_class"].endswith("ChatAdapter")
assert parse_span["input"]["completion"] == "[[ ## answer ## ]]\n4"
assert parse_span["output"] == parsed


class TestPatchDSPy:
Expand Down
152 changes: 107 additions & 45 deletions py/src/braintrust/integrations/dspy/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class BraintrustDSpyCallback(BaseCallback):

The callback creates Braintrust spans for:
- DSPy module executions (Predict, ChainOfThought, ReAct, etc.)
- Adapter formatting and parsing steps
- LLM calls with latency metrics
- Tool calls
- Evaluation runs
Expand Down Expand Up @@ -121,19 +122,13 @@ def on_lm_start(
span.set_current()
self._spans[call_id] = span

def on_lm_end(
def _end_span(
self,
call_id: str,
outputs: dict[str, Any] | None,
outputs: Any | None,
exception: Exception | None = None,
):
"""Log the end of a language model call.

Args:
call_id: Unique identifier for this call
outputs: Output from the LM, or None if there was an exception
exception: Exception raised during execution, if any
"""
"""Pop span by call_id, log outputs/exception, and end it."""
span = self._spans.pop(call_id, None)
if not span:
return
Expand All @@ -151,6 +146,21 @@ def on_lm_end(
span.unset_current()
span.end()

def on_lm_end(
self,
call_id: str,
outputs: dict[str, Any] | None,
exception: Exception | None = None,
):
"""Log the end of a language model call.

Args:
call_id: Unique identifier for this call
outputs: Output from the LM, or None if there was an exception
exception: Exception raised during execution, if any
"""
self._end_span(call_id, outputs, exception)

def on_module_start(
self,
call_id: str,
Expand Down Expand Up @@ -193,28 +203,95 @@ def on_module_end(
outputs: Output from the module, or None if there was an exception
exception: Exception raised during execution, if any
"""
span = self._spans.pop(call_id, None)
if not span:
return
if outputs is not None:
if hasattr(outputs, "toDict"):
outputs = outputs.toDict()
elif hasattr(outputs, "__dict__"):
outputs = outputs.__dict__
self._end_span(call_id, outputs, exception)

def _start_adapter_span(
self,
call_id: str,
instance: Any,
inputs: dict[str, Any],
span_name: str,
):
"""Create and store a span for an adapter format/parse call."""
cls = instance.__class__
metadata = {"adapter_class": f"{cls.__module__}.{cls.__name__}"}

try:
log_data = {}
if exception:
log_data["error"] = exception
if outputs is not None:
if hasattr(outputs, "toDict"):
output_dict = outputs.toDict()
elif hasattr(outputs, "__dict__"):
output_dict = outputs.__dict__
else:
output_dict = outputs
log_data["output"] = output_dict
parent = current_span()
parent_export = parent.export() if parent else None

if log_data:
span.log(**log_data)
finally:
span.unset_current()
span.end()
span = start_span(
name=span_name,
input=inputs,
metadata=metadata,
parent=parent_export,
)
span.set_current()
self._spans[call_id] = span

def on_adapter_format_start(
self,
call_id: str,
instance: Any,
inputs: dict[str, Any],
):
"""Log the start of an adapter format call.

Args:
call_id: Unique identifier for this call
instance: The Adapter instance being called
inputs: Input parameters to the adapter's format() method
"""
self._start_adapter_span(call_id, instance, inputs, "dspy.adapter.format")

def on_adapter_format_end(
self,
call_id: str,
outputs: list[dict[str, Any]] | None,
exception: Exception | None = None,
):
"""Log the end of an adapter format call.

Args:
call_id: Unique identifier for this call
outputs: Output from the adapter's format() method, or None if there was an exception
exception: Exception raised during execution, if any
"""
self._end_span(call_id, outputs, exception)

def on_adapter_parse_start(
self,
call_id: str,
instance: Any,
inputs: dict[str, Any],
):
"""Log the start of an adapter parse call.

Args:
call_id: Unique identifier for this call
instance: The Adapter instance being called
inputs: Input parameters to the adapter's parse() method
"""
self._start_adapter_span(call_id, instance, inputs, "dspy.adapter.parse")

def on_adapter_parse_end(
self,
call_id: str,
outputs: dict[str, Any] | None,
exception: Exception | None = None,
):
"""Log the end of an adapter parse call.

Args:
call_id: Unique identifier for this call
outputs: Output from the adapter's parse() method, or None if there was an exception
exception: Exception raised during execution, if any
"""
self._end_span(call_id, outputs, exception)

def on_tool_start(
self,
Expand Down Expand Up @@ -262,22 +339,7 @@ def on_tool_end(
outputs: Output from the tool, or None if there was an exception
exception: Exception raised during execution, if any
"""
span = self._spans.pop(call_id, None)
if not span:
return

try:
log_data = {}
if exception:
log_data["error"] = exception
if outputs is not None:
log_data["output"] = outputs

if log_data:
span.log(**log_data)
finally:
span.unset_current()
span.end()
self._end_span(call_id, outputs, exception)

def on_evaluate_start(
self,
Expand Down
Loading