diff --git a/py/src/braintrust/integrations/dspy/test_dspy.py b/py/src/braintrust/integrations/dspy/test_dspy.py index f6a0f1da..3e5a44b6 100644 --- a/py/src/braintrust/integrations/dspy/test_dspy.py +++ b/py/src/braintrust/integrations/dspy/test_dspy.py @@ -38,28 +38,73 @@ def test_dspy_callback(memory_logger): # Check logged spans spans = memory_logger.pop() - assert len(spans) >= 2 # Should have module span and LM span + assert len(spans) >= 4 # Should have module, adapter format, LM, and adapter parse spans - # Find LM span by checking span_attributes - lm_spans = [s for s in spans if s.get("span_attributes", {}).get("name") == "dspy.lm"] - assert len(lm_spans) >= 1 + spans_by_name = {span["span_attributes"]["name"]: span for span in spans} - lm_span = lm_spans[0] - # Verify metadata + lm_span = spans_by_name["dspy.lm"] assert "metadata" in lm_span assert "model" in lm_span["metadata"] assert MODEL in lm_span["metadata"]["model"] - - # Verify input/output assert "input" in lm_span assert "output" in lm_span - # Find module span - module_spans = [s for s in spans if "module" in s.get("span_attributes", {}).get("name", "")] - assert len(module_spans) >= 1 + format_span = spans_by_name["dspy.adapter.format"] + parse_span = spans_by_name["dspy.adapter.parse"] + + assert format_span["metadata"]["adapter_class"].endswith("ChatAdapter") + assert "signature" in format_span["input"] + assert "demos" in format_span["input"] + assert "inputs" in format_span["input"] + assert isinstance(format_span["output"], list) + + assert parse_span["metadata"]["adapter_class"].endswith("ChatAdapter") + assert "signature" in parse_span["input"] + assert "completion" in parse_span["input"] + assert isinstance(parse_span["output"], dict) + + # Verify spans are nested under the broader DSPy execution + span_ids = {span["span_id"] for span in spans} + assert lm_span.get("span_parents") + assert format_span.get("span_parents") + assert parse_span.get("span_parents") + assert format_span["span_parents"][0] in span_ids + assert lm_span["span_parents"][0] in span_ids + assert parse_span["span_parents"][0] in span_ids + + +def test_dspy_adapter_callbacks(memory_logger): + """Adapter format/parse callbacks should log spans without an LM call.""" + assert not memory_logger.pop() + + dspy.configure(callbacks=[BraintrustDSpyCallback()]) + + signature = dspy.make_signature("question -> answer") + adapter = dspy.ChatAdapter() + formatted = adapter.format( + signature, + demos=[{"question": "1+1", "answer": "2"}], + inputs={"question": "2+2"}, + ) + parsed = adapter.parse(signature, "[[ ## answer ## ]]\n4") + + assert formatted + assert parsed == {"answer": "4"} + + spans = memory_logger.pop() + assert len(spans) == 2 + + spans_by_name = {span["span_attributes"]["name"]: span for span in spans} + format_span = spans_by_name["dspy.adapter.format"] + parse_span = spans_by_name["dspy.adapter.parse"] + + assert format_span["metadata"]["adapter_class"].endswith("ChatAdapter") + assert format_span["input"]["inputs"] == {"question": "2+2"} + assert format_span["output"] == formatted - # Verify span parenting (LM span should have parent) - assert lm_span.get("span_parents") # LM span should have parent + assert parse_span["metadata"]["adapter_class"].endswith("ChatAdapter") + assert parse_span["input"]["completion"] == "[[ ## answer ## ]]\n4" + assert parse_span["output"] == parsed class TestPatchDSPy: diff --git a/py/src/braintrust/integrations/dspy/tracing.py b/py/src/braintrust/integrations/dspy/tracing.py index e771edb9..dbd376f7 100644 --- a/py/src/braintrust/integrations/dspy/tracing.py +++ b/py/src/braintrust/integrations/dspy/tracing.py @@ -68,6 +68,7 @@ class BraintrustDSpyCallback(BaseCallback): The callback creates Braintrust spans for: - DSPy module executions (Predict, ChainOfThought, ReAct, etc.) + - Adapter formatting and parsing steps - LLM calls with latency metrics - Tool calls - Evaluation runs @@ -121,19 +122,13 @@ def on_lm_start( span.set_current() self._spans[call_id] = span - def on_lm_end( + def _end_span( self, call_id: str, - outputs: dict[str, Any] | None, + outputs: Any | None, exception: Exception | None = None, ): - """Log the end of a language model call. - - Args: - call_id: Unique identifier for this call - outputs: Output from the LM, or None if there was an exception - exception: Exception raised during execution, if any - """ + """Pop span by call_id, log outputs/exception, and end it.""" span = self._spans.pop(call_id, None) if not span: return @@ -151,6 +146,21 @@ def on_lm_end( span.unset_current() span.end() + def on_lm_end( + self, + call_id: str, + outputs: dict[str, Any] | None, + exception: Exception | None = None, + ): + """Log the end of a language model call. + + Args: + call_id: Unique identifier for this call + outputs: Output from the LM, or None if there was an exception + exception: Exception raised during execution, if any + """ + self._end_span(call_id, outputs, exception) + def on_module_start( self, call_id: str, @@ -193,28 +203,95 @@ def on_module_end( outputs: Output from the module, or None if there was an exception exception: Exception raised during execution, if any """ - span = self._spans.pop(call_id, None) - if not span: - return + if outputs is not None: + if hasattr(outputs, "toDict"): + outputs = outputs.toDict() + elif hasattr(outputs, "__dict__"): + outputs = outputs.__dict__ + self._end_span(call_id, outputs, exception) + + def _start_adapter_span( + self, + call_id: str, + instance: Any, + inputs: dict[str, Any], + span_name: str, + ): + """Create and store a span for an adapter format/parse call.""" + cls = instance.__class__ + metadata = {"adapter_class": f"{cls.__module__}.{cls.__name__}"} - try: - log_data = {} - if exception: - log_data["error"] = exception - if outputs is not None: - if hasattr(outputs, "toDict"): - output_dict = outputs.toDict() - elif hasattr(outputs, "__dict__"): - output_dict = outputs.__dict__ - else: - output_dict = outputs - log_data["output"] = output_dict + parent = current_span() + parent_export = parent.export() if parent else None - if log_data: - span.log(**log_data) - finally: - span.unset_current() - span.end() + span = start_span( + name=span_name, + input=inputs, + metadata=metadata, + parent=parent_export, + ) + span.set_current() + self._spans[call_id] = span + + def on_adapter_format_start( + self, + call_id: str, + instance: Any, + inputs: dict[str, Any], + ): + """Log the start of an adapter format call. + + Args: + call_id: Unique identifier for this call + instance: The Adapter instance being called + inputs: Input parameters to the adapter's format() method + """ + self._start_adapter_span(call_id, instance, inputs, "dspy.adapter.format") + + def on_adapter_format_end( + self, + call_id: str, + outputs: list[dict[str, Any]] | None, + exception: Exception | None = None, + ): + """Log the end of an adapter format call. + + Args: + call_id: Unique identifier for this call + outputs: Output from the adapter's format() method, or None if there was an exception + exception: Exception raised during execution, if any + """ + self._end_span(call_id, outputs, exception) + + def on_adapter_parse_start( + self, + call_id: str, + instance: Any, + inputs: dict[str, Any], + ): + """Log the start of an adapter parse call. + + Args: + call_id: Unique identifier for this call + instance: The Adapter instance being called + inputs: Input parameters to the adapter's parse() method + """ + self._start_adapter_span(call_id, instance, inputs, "dspy.adapter.parse") + + def on_adapter_parse_end( + self, + call_id: str, + outputs: dict[str, Any] | None, + exception: Exception | None = None, + ): + """Log the end of an adapter parse call. + + Args: + call_id: Unique identifier for this call + outputs: Output from the adapter's parse() method, or None if there was an exception + exception: Exception raised during execution, if any + """ + self._end_span(call_id, outputs, exception) def on_tool_start( self, @@ -262,22 +339,7 @@ def on_tool_end( outputs: Output from the tool, or None if there was an exception exception: Exception raised during execution, if any """ - span = self._spans.pop(call_id, None) - if not span: - return - - try: - log_data = {} - if exception: - log_data["error"] = exception - if outputs is not None: - log_data["output"] = outputs - - if log_data: - span.log(**log_data) - finally: - span.unset_current() - span.end() + self._end_span(call_id, outputs, exception) def on_evaluate_start( self,