diff --git a/README.md b/README.md index 4ae2ed6..8a2c8a4 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,11 @@ If you prefer explicit control, you can add tracing middleware manually to your ## Evaluations -Run [evals](https://www.braintrust.dev/docs/guides/evals) with custom test cases and scoring functions: +Run [evals](https://www.braintrust.dev/docs/guides/evals) with custom test cases and scoring functions. + +### Define and run + +Define an eval once with its task and scorers, then run it against any dataset: ```go package main @@ -135,16 +139,9 @@ func main() { log.Fatal(err) } - // Create an evaluator with your task's input and output types - evaluator := braintrust.NewEvaluator[string, string](client) - - // Run an evaluation - _, err = evaluator.Run(ctx, eval.Opts[string, string]{ - Experiment: "greeting-experiment", - Dataset: eval.NewDataset([]eval.Case[string, string]{ - {Input: "World", Expected: "Hello World"}, - {Input: "Alice", Expected: "Hello Alice"}, - }), + // Create an eval + e := braintrust.NewEval(client, &eval.Eval[string, string]{ + Name: "greeting-experiment", Task: eval.T(func(ctx context.Context, input string) (string, error) { return "Hello " + input, nil }), @@ -158,12 +155,64 @@ func main() { }), }, }) + + // Run against a dataset + _, err = e.Run(ctx, eval.RunOpts[string, string]{ + Dataset: eval.NewDataset([]eval.Case[string, string]{ + {Input: "World", Expected: "Hello World"}, + {Input: "Alice", Expected: "Hello Alice"}, + }), + }) if err != nil { log.Fatal(err) } } ``` +### Remote Eval Server + +The same eval definition can be registered with a [remote eval server](https://www.braintrust.dev/docs/evaluate/remote-evals), letting you run evals from the Braintrust playground against code on your own infrastructure: + +```go +package main + +import ( + "context" + "log" + "strings" + + "github.com/braintrustdata/braintrust-sdk-go/eval" + "github.com/braintrustdata/braintrust-sdk-go/server" +) + +func main() { + // Define the eval once + classify := &eval.Eval[string, string]{ + Name: "classify", + Task: eval.T(func(ctx context.Context, input string) (string, error) { + return strings.ToUpper(input), nil + }), + Scorers: []eval.Scorer[string, string]{ + eval.NewScorer("exact_match", func(ctx context.Context, r eval.TaskResult[string, string]) (eval.Scores, error) { + if r.Output == r.Expected { return eval.S(1.0), nil } + return eval.S(0.0), nil + }), + }, + } + + // Register with server for remote execution + srv := server.New( + server.WithAddress("localhost:8300"), + server.WithNoAuth(), // Remove for production + ) + server.RegisterEval(srv, classify, server.RegisterEvalOpts{}) + + log.Fatal(srv.Start()) +} +``` + +Then configure `http://localhost:8300` in your Braintrust project settings under **Remote evals**. + ## API Client Manage Braintrust resources programmatically: @@ -244,10 +293,12 @@ Complete working examples are available in [`examples/`](./examples/): - **[langchaingo](./examples/langchaingo/main.go)** - LangChainGo integration - **[datasets](./examples/datasets/main.go)** - Using Braintrust datasets - **[adk-go](./examples/adk/main.go)** - ADK integration +- **[eval-server](./examples/internal/eval-server/main.go)** - Remote eval server ## Features - **Evaluations** - Systematic testing with custom scoring functions +- **Remote Eval Server** - Run evals from the Braintrust UI against your own code - **Tracing** - Automatic instrumentation for major LLM providers - **Datasets** - Manage and version evaluation datasets - **Experiments** - Track versions and configurations diff --git a/client.go b/client.go index 089d9e7..2843057 100644 --- a/client.go +++ b/client.go @@ -191,30 +191,24 @@ func (c *Client) Tracer(name string, opts ...oteltrace.TracerOption) oteltrace.T return c.tracerProvider.Tracer(name, opts...) } -// NewEvaluator creates a new evaluator for running multiple evaluations with the same -// input and output types. +// NewEval creates a runnable [eval.Eval] by combining a client with an eval definition. // // Example: // // client, _ := braintrust.New(tp) -// -// // Create an evaluator for string → string evaluations -// evaluator := braintrust.NewEvaluator[string, string](client) -// -// // Run multiple evaluations -// result1, _ := evaluator.Run(ctx, eval.Opts[string, string]{ -// Experiment: "test-1", -// Dataset: dataset1, -// Task: task1, -// Scorers: scorers, -// }) -// -// result2, _ := evaluator.Run(ctx, eval.Opts[string, string]{ -// Experiment: "test-2", -// Dataset: dataset2, -// Task: task2, -// Scorers: scorers, +// e := braintrust.NewEval(client, &eval.Eval[string, string]{ +// Name: "classify", +// Task: task, +// Scorers: scorers, // }) +// result, _ := e.Run(ctx, eval.RunOpts[string, string]{Dataset: dataset}) +func NewEval[I, R any](client *Client, e *eval.Eval[I, R]) *eval.Eval[I, R] { + evaluator := eval.NewEvaluator[I, R](client.session, client.tracerProvider, client.API(), client.config.DefaultProjectName) + return eval.NewEval(evaluator, e) +} + +// NewEvaluator creates a new evaluator for running evaluations with the same +// input and output types. func NewEvaluator[I, R any](client *Client) *eval.Evaluator[I, R] { return eval.NewEvaluator[I, R](client.session, client.tracerProvider, client.API(), client.config.DefaultProjectName) } diff --git a/eval/eval.go b/eval/eval.go index 774cadd..befb6e7 100644 --- a/eval/eval.go +++ b/eval/eval.go @@ -52,12 +52,6 @@ var ( errCaseIterator = errors.New("case iterator error") ) -var ( - // braintrust "span_attributes" for each type of eval span. - evalSpanAttrs = map[string]any{"type": "eval"} - taskSpanAttrs = map[string]any{"type": "task"} -) - // Opts defines the options for running an evaluation. // I is the input type and R is the result/output type. // @@ -83,6 +77,141 @@ type Opts[I, R any] struct { Update bool // If true, append to existing experiment (default: false) Parallelism int // Number of goroutines (default: 1) Quiet bool // Suppress result output (default: false) + + // OnCaseComplete is called after each case completes (task + scorers). + // It is called from worker goroutines and must be safe for concurrent use. + // Optional — nil means no callback. + OnCaseComplete func(CaseProgress) + + // SpanParent overrides the parent attribute set on eval spans. + // When zero, the default "experiment_id:" parent is used. + // The remote eval server sets this to link spans to a playground context. + SpanParent bttrace.Parent + + // Generation is propagated from the parent context (e.g. a Braintrust playground + // invocation) and injected into braintrust.span_attributes on every span. + // The Braintrust backend uses it to link eval spans back to the triggering context. + Generation any +} + +// CaseProgress contains the result of a single completed evaluation case. +// It is passed to the [Opts.OnCaseComplete] callback. +type CaseProgress struct { + Output any + Scores map[string]float64 + Error error + // ID is the eval span ID, used to correlate SSE progress events with OTLP span data. + ID string + // Origin contains dataset provenance when the case came from a dataset. + Origin map[string]any +} + +// Eval defines an evaluation: the task to run and the scorers to apply. +// Create one with [braintrust.NewEval], then call [Eval.Run] to execute it +// or pass it to a remote eval server. +type Eval[I, R any] struct { + // Name is the eval name. Used as the default experiment name and as + // the registration key when registered with a remote eval server. + Name string + + // Task is the function under evaluation. + Task TaskFunc[I, R] + + // Scorers are the scoring functions applied to each task result. + Scorers []Scorer[I, R] + + // ProjectName is the Braintrust project for this eval. + // Optional; falls back to the default project from the client. + ProjectName string + + // evaluator holds the infrastructure (session, tracer, API client) + // needed to run the eval. Set by NewEval / braintrust.NewEval. + evaluator *Evaluator[I, R] +} + +// NewEval creates a runnable Eval by attaching an [Evaluator] as the default +// runner. Users should call braintrust.NewEval rather than this directly. +func NewEval[I, R any](evaluator *Evaluator[I, R], e *Eval[I, R]) *Eval[I, R] { + e.evaluator = evaluator + return e +} + +// Run executes the evaluation using the default [Evaluator]. +func (e *Eval[I, R]) Run(ctx context.Context, opts RunOpts[I, R]) (*Result, error) { + return e.evaluator.Run(ctx, mergeOpts(e, opts)) +} + +// RunOpts configures a single evaluation run. These vary per invocation; +// the [Eval] definition stays the same. +type RunOpts[I, R any] struct { + // Experiment overrides the experiment name. Defaults to [Eval.Name]. + Experiment string + + // ProjectName overrides the project name. Defaults to [Eval.ProjectName]. + ProjectName string + + // Dataset is the test cases to evaluate against. Required. + Dataset Dataset[I, R] + + // Tags to apply to the experiment. + Tags []string + + // Metadata to attach to the experiment. + Metadata Metadata + + // Update appends to an existing experiment when true (default: false). + Update bool + + // Parallelism is the number of goroutines (default: 1). + Parallelism int + + // Quiet suppresses result output (default: false). + Quiet bool + + // OnCaseComplete is called after each case completes (task + scorers). + // It is called from worker goroutines and must be safe for concurrent use. + // Optional — nil means no callback. + OnCaseComplete func(CaseProgress) + + // SpanParent overrides the parent attribute set on eval spans. + // When zero, the default "experiment_id:" parent is used. + // The remote eval server sets this to link spans to a playground context. + SpanParent bttrace.Parent + + // Generation is propagated from the parent context (e.g. a Braintrust playground + // invocation) and injected into braintrust.span_attributes on every span. + // The Braintrust backend uses it to link eval spans back to the triggering context. + Generation any +} + +// mergeOpts combines an Eval definition with RunOpts into an Opts for +// backward-compatible delegation to the existing run() function. +func mergeOpts[I, R any](ev *Eval[I, R], ro RunOpts[I, R]) Opts[I, R] { + experiment := ro.Experiment + if experiment == "" { + experiment = ev.Name + } + + projectName := ro.ProjectName + if projectName == "" { + projectName = ev.ProjectName + } + + return Opts[I, R]{ + Experiment: experiment, + Dataset: ro.Dataset, + Task: ev.Task, + Scorers: ev.Scorers, + ProjectName: projectName, + Tags: ro.Tags, + Metadata: ro.Metadata, + Update: ro.Update, + Parallelism: ro.Parallelism, + Quiet: ro.Quiet, + OnCaseComplete: ro.OnCaseComplete, + SpanParent: ro.SpanParent, + Generation: ro.Generation, + } } // Case represents a single test case in an evaluation. @@ -174,6 +303,16 @@ func (r *Result) ID() string { return r.key.experimentID } +// ProjectID returns the project ID. +func (r *Result) ProjectID() string { + return r.key.projectID +} + +// ProjectName returns the project name. +func (r *Result) ProjectName() string { + return r.key.projectName +} + // String returns a string representaton of the result for printing on the console. // // The format it prints will change and shouldn't be relied on for programmatic use. @@ -233,6 +372,8 @@ type eval[I, R any] struct { startSpanOpt oteltrace.SpanStartOption goroutines int quiet bool + onCaseComplete func(CaseProgress) + generation any } // nextCase is a wrapper for sending cases through a channel. @@ -255,9 +396,17 @@ func newEval[I, R any]( scorers []Scorer[I, R], parallelism int, quiet bool, + onCaseComplete func(CaseProgress), + spanParent bttrace.Parent, + generation any, ) *eval[I, R] { - // Build parent span option + // Build parent span option. Use explicit override if provided (e.g. from + // the remote eval server linking spans to a playground), otherwise default + // to the experiment ID. parent := bttrace.NewParent(bttrace.ParentTypeExperimentID, experimentID) + if !spanParent.IsZero() { + parent = spanParent + } startSpanOpt := oteltrace.WithAttributes(parent.Attr()) // Extract dataset ID from dataset @@ -269,6 +418,11 @@ func newEval[I, R any]( goroutines = 1 } + // Default to noop so callers don't need nil checks + if onCaseComplete == nil { + onCaseComplete = func(CaseProgress) {} + } + return &eval[I, R]{ session: s, parent: parent, @@ -284,6 +438,8 @@ func newEval[I, R any]( startSpanOpt: startSpanOpt, goroutines: goroutines, quiet: quiet, + onCaseComplete: onCaseComplete, + generation: generation, } } @@ -321,9 +477,22 @@ func newEvalOpts[I, R any](ctx context.Context, s *auth.Session, tp *trace.Trace opts.Scorers, opts.Parallelism, opts.Quiet, + opts.OnCaseComplete, + opts.SpanParent, + opts.Generation, ), nil } +// spanAttrs builds span_attributes for the given span type, injecting +// generation when set (used by the remote eval server). +func (e *eval[I, R]) spanAttrs(spanType string) map[string]any { + attrs := map[string]any{"type": spanType} + if e.generation != nil { + attrs["generation"] = e.generation + } + return attrs +} + func (e *eval[I, R]) run(ctx context.Context) (*Result, error) { start := time.Now() if e.experimentID == "" { @@ -415,7 +584,7 @@ func (e *eval[I, R]) runNextCase(ctx context.Context, nextCase nextCase[I, R]) e func (e *eval[I, R]) runCase(ctx context.Context, span oteltrace.Span, c Case[I, R]) error { // Set all non-output attributes upfront so they're captured even if the task fails. attrs := map[string]any{ - "braintrust.span_attributes": evalSpanAttrs, + "braintrust.span_attributes": e.spanAttrs("eval"), "braintrust.input_json": c.Input, "braintrust.expected": c.Expected, } @@ -438,8 +607,25 @@ func (e *eval[I, R]) runCase(ctx context.Context, span oteltrace.Span, c Case[I, span.SetAttributes(attribute.StringSlice("braintrust.tags", c.Tags)) } + // Use the eval span ID as the progress event ID, matching Ruby's protocol. + // The UI correlates SSE progress events with OTLP span data using this ID. + spanID := span.SpanContext().SpanID().String() + + // Build origin for progress tracking when case came from a dataset. + var origin map[string]any + if c.ID != "" && c.XactID != "" { + origin = map[string]any{ + "object_type": "dataset", + "object_id": e.datasetID, + "id": c.ID, + "created": c.Created, + "_xact_id": c.XactID, + } + } + taskResult, err := e.runTask(ctx, span, c) if err != nil { + e.onCaseComplete(CaseProgress{Error: err, ID: spanID, Origin: origin}) span.SetStatus(codes.Error, err.Error()) return err } @@ -448,7 +634,18 @@ func (e *eval[I, R]) runCase(ctx context.Context, span oteltrace.Span, c Case[I, return err } - _, err = e.runScorers(ctx, taskResult) + scores, err := e.runScorers(ctx, taskResult) + scoreMap := make(map[string]float64, len(scores)) + for _, s := range scores { + scoreMap[s.Name] = s.Score + } + e.onCaseComplete(CaseProgress{ + Output: taskResult.Output, + Scores: scoreMap, + Error: err, + ID: spanID, + Origin: origin, + }) if err != nil { span.SetStatus(codes.Error, err.Error()) return err @@ -466,7 +663,7 @@ func (e *eval[I, R]) runTask(ctx context.Context, evalSpan oteltrace.Span, c Cas attrs := map[string]any{ "braintrust.input_json": c.Input, "braintrust.expected": c.Expected, - "braintrust.span_attributes": taskSpanAttrs, + "braintrust.span_attributes": e.spanAttrs("task"), } var encodeErrs []error @@ -534,12 +731,10 @@ func (e *eval[I, R]) runScorer(ctx context.Context, scorer Scorer[I, R], taskRes ctx, span := e.tracer.Start(ctx, scorer.Name(), e.startSpanOpt) defer span.End() - spanAttrs := map[string]any{ - "type": "score", - "name": scorer.Name(), - "purpose": "scorer", - } - if err := setJSONAttr(span, "braintrust.span_attributes", spanAttrs); err != nil { + scorerAttrs := e.spanAttrs("score") + scorerAttrs["name"] = scorer.Name() + scorerAttrs["purpose"] = "scorer" + if err := setJSONAttr(span, "braintrust.span_attributes", scorerAttrs); err != nil { return nil, err } @@ -749,6 +944,9 @@ func testNewEval[I, R any]( task, scorers, parallelism, - true, // quiet=true for tests + true, // quiet=true for tests + nil, // no callback for tests + bttrace.Parent{}, // no parent override + nil, // no generation ) } diff --git a/eval/eval_integration_test.go b/eval/eval_integration_test.go index 9637eda..1a2318a 100644 --- a/eval/eval_integration_test.go +++ b/eval/eval_integration_test.go @@ -727,3 +727,48 @@ func TestEval_NoProjectName(t *testing.T) { assert.Nil(t, result) assert.Contains(t, err.Error(), "project name is required") } + +// TestEvalRun_Integration tests Eval.Run with a reusable Eval definition. +func TestEvalRun_Integration(t *testing.T) { + session, apiClient := setupIntegrationTest(t) + t.Parallel() + + ctx := context.Background() + cfg := &config.Config{ + DefaultProjectName: integrationTestProject, + } + + // Define a reusable eval + classify := &Eval[string, string]{ + Name: "classify", + Task: T(func(ctx context.Context, input string) (string, error) { + return "category-" + input, nil + }), + Scorers: []Scorer[string, string]{ + NewScorer("exact_match", func(ctx context.Context, r TaskResult[string, string]) (Scores, error) { + if r.Output == r.Expected { + return S(1.0), nil + } + return S(0.0), nil + }), + }, + } + + tp := trace.NewTracerProvider() + defer func() { _ = tp.Shutdown(ctx) }() + + evaluator := NewEvaluator[string, string](session, tp, apiClient, cfg.DefaultProjectName) + e := NewEval(evaluator, classify) + result, err := e.Run(ctx, RunOpts[string, string]{ + Dataset: NewDataset([]Case[string, string]{ + {Input: "apple", Expected: "category-apple"}, + {Input: "banana", Expected: "category-banana"}, + }), + Quiet: true, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "classify", result.Name()) + assert.NotEmpty(t, result.ID()) +} diff --git a/eval/eval_runopts_test.go b/eval/eval_runopts_test.go new file mode 100644 index 0000000..bef562b --- /dev/null +++ b/eval/eval_runopts_test.go @@ -0,0 +1,212 @@ +package eval + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + bttrace "github.com/braintrustdata/braintrust-sdk-go/trace" +) + +func TestMergeOpts_ExperimentDefaultsToEvalName(t *testing.T) { + t.Parallel() + + ev := &Eval[testInput, testOutput]{ + Name: "my-eval", + Task: T(func(_ context.Context, in testInput) (testOutput, error) { + return testOutput{}, nil + }), + } + + opts := mergeOpts(ev, RunOpts[testInput, testOutput]{}) + assert.Equal(t, "my-eval", opts.Experiment) +} + +func TestMergeOpts_ExperimentOverride(t *testing.T) { + t.Parallel() + + ev := &Eval[testInput, testOutput]{ + Name: "my-eval", + Task: T(func(_ context.Context, in testInput) (testOutput, error) { + return testOutput{}, nil + }), + } + + opts := mergeOpts(ev, RunOpts[testInput, testOutput]{ + Experiment: "custom-experiment", + }) + assert.Equal(t, "custom-experiment", opts.Experiment) +} + +func TestMergeOpts_FieldResolution(t *testing.T) { + t.Parallel() + + task := T(func(_ context.Context, in testInput) (testOutput, error) { + return testOutput{Result: in.Value}, nil + }) + scorer := NewScorer("s", func(_ context.Context, _ TaskResult[testInput, testOutput]) (Scores, error) { + return S(1.0), nil + }) + + ev := &Eval[testInput, testOutput]{ + Name: "eval-name", + Task: task, + Scorers: []Scorer[testInput, testOutput]{scorer}, + ProjectName: "eval-project", + } + + dataset := NewDataset([]Case[testInput, testOutput]{ + {Input: testInput{Value: "x"}}, + }) + callback := func(CaseProgress) {} + + ro := RunOpts[testInput, testOutput]{ + Dataset: dataset, + Tags: []string{"tag1"}, + Metadata: map[string]any{"k": "v"}, + Update: true, + Parallelism: 4, + Quiet: true, + OnCaseComplete: callback, + SpanParent: bttrace.NewParent(bttrace.ParentTypePlaygroundID, "pg-1"), + Generation: 42, + } + + opts := mergeOpts(ev, ro) + + // Definition fields come from Eval + assert.Equal(t, "eval-name", opts.Experiment) + assert.NotNil(t, opts.Task) + assert.Len(t, opts.Scorers, 1) + assert.Equal(t, "eval-project", opts.ProjectName) + + // Runtime fields come from RunOpts + assert.Equal(t, dataset, opts.Dataset) + assert.Equal(t, []string{"tag1"}, opts.Tags) + assert.Equal(t, Metadata{"k": "v"}, opts.Metadata) + assert.True(t, opts.Update) + assert.Equal(t, 4, opts.Parallelism) + assert.True(t, opts.Quiet) + assert.NotNil(t, opts.OnCaseComplete) + assert.Equal(t, bttrace.NewParent(bttrace.ParentTypePlaygroundID, "pg-1"), opts.SpanParent) + assert.Equal(t, 42, opts.Generation) +} + +func TestMergeOpts_ProjectNameOverride(t *testing.T) { + t.Parallel() + + ev := &Eval[testInput, testOutput]{ + Name: "my-eval", + ProjectName: "default-project", + Task: T(func(_ context.Context, in testInput) (testOutput, error) { + return testOutput{}, nil + }), + } + + // Without override: uses Eval.ProjectName + opts := mergeOpts(ev, RunOpts[testInput, testOutput]{}) + assert.Equal(t, "default-project", opts.ProjectName) + + // With override: uses RunOpts.ProjectName + opts = mergeOpts(ev, RunOpts[testInput, testOutput]{ + ProjectName: "override-project", + }) + assert.Equal(t, "override-project", opts.ProjectName) +} + +func TestMergeOpts_EvalReuse(t *testing.T) { + t.Parallel() + + ev := &Eval[testInput, testOutput]{ + Name: "reusable-eval", + Task: T(func(_ context.Context, in testInput) (testOutput, error) { + return testOutput{Result: in.Value}, nil + }), + Scorers: []Scorer[testInput, testOutput]{ + NewScorer("s", func(_ context.Context, _ TaskResult[testInput, testOutput]) (Scores, error) { + return S(1.0), nil + }), + }, + ProjectName: "base-project", + } + + // Same Eval, different RunOpts + opts1 := mergeOpts(ev, RunOpts[testInput, testOutput]{ + Experiment: "run-1", + Tags: []string{"nightly"}, + }) + opts2 := mergeOpts(ev, RunOpts[testInput, testOutput]{ + Experiment: "run-2", + ProjectName: "other-project", + Parallelism: 8, + }) + + // Each merge produces independent Opts + assert.Equal(t, "run-1", opts1.Experiment) + assert.Equal(t, "base-project", opts1.ProjectName) + assert.Equal(t, []string{"nightly"}, opts1.Tags) + assert.Equal(t, 0, opts1.Parallelism) + + assert.Equal(t, "run-2", opts2.Experiment) + assert.Equal(t, "other-project", opts2.ProjectName) + assert.Nil(t, opts2.Tags) + assert.Equal(t, 8, opts2.Parallelism) + + // Definition is unchanged + assert.Equal(t, "reusable-eval", ev.Name) + assert.Equal(t, "base-project", ev.ProjectName) +} + +// TestEvalRun_Success verifies that Eval.Run produces the same span structure as +// the equivalent Evaluator.Run call. It uses the same testNewEval path as other +// unit tests to avoid needing a real API client for experiment registration. +func TestEvalRun_Success(t *testing.T) { + t.Parallel() + + task := T(func(_ context.Context, in testInput) (testOutput, error) { + return testOutput{Result: "output-" + in.Value}, nil + }) + scorer := NewScorer("accuracy", func(_ context.Context, r TaskResult[testInput, testOutput]) (Scores, error) { + return S(0.95), nil + }) + + ev := &Eval[testInput, testOutput]{ + Name: "test-eval", + Task: task, + Scorers: []Scorer[testInput, testOutput]{scorer}, + ProjectName: "eval-project", + } + + dataset := NewDataset([]Case[testInput, testOutput]{ + {Input: testInput{Value: "a"}, Expected: testOutput{Result: "expected-a"}}, + }) + + ro := RunOpts[testInput, testOutput]{ + Dataset: dataset, + Quiet: true, + } + + // Verify mergeOpts produces the right Opts, then run via testNewEval + // (the same path used by all other unit tests). + opts := mergeOpts(ev, ro) + assert.Equal(t, "test-eval", opts.Experiment) + assert.Equal(t, "eval-project", opts.ProjectName) + assert.NotNil(t, opts.Task) + assert.Len(t, opts.Scorers, 1) + assert.True(t, opts.Quiet) + + // Run the eval using the unit test helper to verify end-to-end span output. + ute := newUnitTestEval(t, dataset, opts.Task, opts.Scorers, 1) + result, err := ute.eval.run(context.Background()) + require.NoError(t, err) + require.NotNil(t, result) + + // Verify spans: task + scorer + eval = 3 + spans := ute.exporter.Flush() + require.Len(t, spans, 3) + spans[0].AssertNameIs("task") + spans[1].AssertNameIs("accuracy") + spans[2].AssertNameIs("eval") +} diff --git a/eval/eval_test.go b/eval/eval_test.go index e77f94d..e76270e 100644 --- a/eval/eval_test.go +++ b/eval/eval_test.go @@ -3,7 +3,9 @@ package eval import ( "context" "errors" + "fmt" "io" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -1157,6 +1159,327 @@ func TestEval_ParentPropagation(t *testing.T) { assert.Equal(scorerParent, trace.Parent{Type: trace.ParentTypeExperimentID, ID: result.ID()}) } +func TestOnCaseComplete_Callback(t *testing.T) { + t.Parallel() + + cases := NewDataset([]Case[testInput, testOutput]{ + {Input: testInput{Value: "test1"}, Expected: testOutput{Result: "expected1"}}, + {Input: testInput{Value: "test2"}, Expected: testOutput{Result: "expected2"}}, + }) + + task := T(func(ctx context.Context, input testInput) (testOutput, error) { + return testOutput{Result: "output-" + input.Value}, nil + }) + + scorer := NewScorer("accuracy", func(ctx context.Context, result TaskResult[testInput, testOutput]) (Scores, error) { + return S(0.75), nil + }) + + // Track callback invocations + var mu sync.Mutex + var progresses []CaseProgress + callback := func(cp CaseProgress) { + mu.Lock() + progresses = append(progresses, cp) + mu.Unlock() + } + + // Create eval manually with the callback + tp, _ := oteltest.Setup(t) + tracer := tp.Tracer(t.Name()) + session := tests.NewSession(t) + + e := newEval( + session, tracer, + "exp-callback", "callback-experiment", + "proj-callback", "callback-project", + cases, task, + []Scorer[testInput, testOutput]{scorer}, + 1, true, callback, trace.Parent{}, nil, + ) + + result, err := e.run(context.Background()) + require.NoError(t, err) + assert.NotNil(t, result) + + mu.Lock() + defer mu.Unlock() + require.Len(t, progresses, 2) + + // Both should have scores and no errors + for _, p := range progresses { + assert.NoError(t, p.Error) + assert.NotNil(t, p.Scores) + assert.Equal(t, 0.75, p.Scores["accuracy"]) + assert.NotNil(t, p.Output) + } +} + +func TestOnCaseComplete_CallbackOnError(t *testing.T) { + t.Parallel() + + cases := NewDataset([]Case[testInput, testOutput]{ + {Input: testInput{Value: "will-fail"}}, + }) + + task := T(func(ctx context.Context, input testInput) (testOutput, error) { + return testOutput{}, errors.New("task failed") + }) + + var called bool + var capturedProgress CaseProgress + callback := func(cp CaseProgress) { + called = true + capturedProgress = cp + } + + tp, _ := oteltest.Setup(t) + tracer := tp.Tracer(t.Name()) + session := tests.NewSession(t) + + e := newEval( + session, tracer, + "exp-err", "err-experiment", + "proj-err", "err-project", + cases, task, + nil, 1, true, callback, trace.Parent{}, nil, + ) + + _, _ = e.run(context.Background()) + + assert.True(t, called) + assert.Error(t, capturedProgress.Error) + assert.Contains(t, capturedProgress.Error.Error(), "task failed") +} + +func TestOnCaseComplete_NilCallback(t *testing.T) { + t.Parallel() + + // Ensure nil callback doesn't panic + cases := NewDataset([]Case[testInput, testOutput]{ + {Input: testInput{Value: "test"}}, + }) + + task := T(func(ctx context.Context, input testInput) (testOutput, error) { + return testOutput{Result: "ok"}, nil + }) + + tp, _ := oteltest.Setup(t) + tracer := tp.Tracer(t.Name()) + session := tests.NewSession(t) + + e := newEval( + session, tracer, + "exp-nil", "nil-experiment", + "proj-nil", "nil-project", + cases, task, + nil, 1, true, nil, trace.Parent{}, nil, + ) + + result, err := e.run(context.Background()) + require.NoError(t, err) + assert.NotNil(t, result) +} + +func TestOnCaseComplete_Parallel(t *testing.T) { + t.Parallel() + + // 20 cases with parallelism=4 to exercise concurrent callback invocation + var inputCases []Case[testInput, testOutput] + for i := 0; i < 20; i++ { + inputCases = append(inputCases, Case[testInput, testOutput]{ + Input: testInput{Value: fmt.Sprintf("test%d", i)}, + Expected: testOutput{Result: fmt.Sprintf("output-test%d", i)}, + }) + } + cases := NewDataset(inputCases) + + task := T(func(ctx context.Context, input testInput) (testOutput, error) { + return testOutput{Result: "output-" + input.Value}, nil + }) + + scorer := NewScorer("score", func(ctx context.Context, result TaskResult[testInput, testOutput]) (Scores, error) { + return S(1.0), nil + }) + + var mu sync.Mutex + var progresses []CaseProgress + callback := func(cp CaseProgress) { + mu.Lock() + progresses = append(progresses, cp) + mu.Unlock() + } + + tp, _ := oteltest.Setup(t) + tracer := tp.Tracer(t.Name()) + session := tests.NewSession(t) + + e := newEval( + session, tracer, + "exp-parallel", "parallel-experiment", + "proj-parallel", "parallel-project", + cases, task, + []Scorer[testInput, testOutput]{scorer}, + 4, true, callback, trace.Parent{}, nil, + ) + + result, err := e.run(context.Background()) + require.NoError(t, err) + assert.NotNil(t, result) + + mu.Lock() + defer mu.Unlock() + require.Len(t, progresses, 20, "callback should fire for all 20 cases") + + for _, p := range progresses { + assert.NoError(t, p.Error) + assert.Equal(t, 1.0, p.Scores["score"]) + assert.NotNil(t, p.Output) + } +} + +func TestCaseProgress_IDIsSpanID(t *testing.T) { + t.Parallel() + + cases := NewDataset([]Case[testInput, testOutput]{ + {Input: testInput{Value: "test"}}, + }) + + task := T(func(ctx context.Context, input testInput) (testOutput, error) { + return testOutput{Result: "ok"}, nil + }) + + var capturedProgress CaseProgress + callback := func(cp CaseProgress) { + capturedProgress = cp + } + + tp, _ := oteltest.Setup(t) + tracer := tp.Tracer(t.Name()) + session := tests.NewSession(t) + + e := newEval( + session, tracer, + "exp-id", "id-experiment", + "proj-id", "id-project", + cases, task, + nil, 1, true, callback, trace.Parent{}, nil, + ) + + _, err := e.run(context.Background()) + require.NoError(t, err) + + // ID should be a 16-character hex span ID + assert.NotEmpty(t, capturedProgress.ID) + assert.Regexp(t, `^[0-9a-f]{16}$`, capturedProgress.ID) +} + +func TestCaseProgress_OriginFromDataset(t *testing.T) { + t.Parallel() + + cases := NewDataset([]Case[testInput, testOutput]{ + { + Input: testInput{Value: "test"}, + ID: "case-123", + XactID: "xact-456", + Created: "2024-01-15T10:30:00Z", + }, + }) + + task := T(func(ctx context.Context, input testInput) (testOutput, error) { + return testOutput{Result: "ok"}, nil + }) + + var capturedProgress CaseProgress + callback := func(cp CaseProgress) { + capturedProgress = cp + } + + tp, _ := oteltest.Setup(t) + tracer := tp.Tracer(t.Name()) + session := tests.NewSession(t) + + e := newEval( + session, tracer, + "exp-origin", "origin-experiment", + "proj-origin", "origin-project", + cases, task, + nil, 1, true, callback, trace.Parent{}, nil, + ) + + _, err := e.run(context.Background()) + require.NoError(t, err) + + require.NotNil(t, capturedProgress.Origin) + assert.Equal(t, "dataset", capturedProgress.Origin["object_type"]) + assert.Equal(t, "case-123", capturedProgress.Origin["id"]) + assert.Equal(t, "xact-456", capturedProgress.Origin["_xact_id"]) +} + +func TestCaseProgress_OriginNilWithoutDatasetID(t *testing.T) { + t.Parallel() + + cases := NewDataset([]Case[testInput, testOutput]{ + {Input: testInput{Value: "test"}}, + }) + + task := T(func(ctx context.Context, input testInput) (testOutput, error) { + return testOutput{Result: "ok"}, nil + }) + + var capturedProgress CaseProgress + callback := func(cp CaseProgress) { + capturedProgress = cp + } + + tp, _ := oteltest.Setup(t) + tracer := tp.Tracer(t.Name()) + session := tests.NewSession(t) + + e := newEval( + session, tracer, + "exp-no-origin", "no-origin-experiment", + "proj-no-origin", "no-origin-project", + cases, task, + nil, 1, true, callback, trace.Parent{}, nil, + ) + + _, err := e.run(context.Background()) + require.NoError(t, err) + + assert.Nil(t, capturedProgress.Origin) +} + +func TestResult_ProjectIDAndProjectName(t *testing.T) { + t.Parallel() + + cases := NewDataset([]Case[testInput, testOutput]{ + {Input: testInput{Value: "test"}}, + }) + + task := T(func(ctx context.Context, input testInput) (testOutput, error) { + return testOutput{Result: "ok"}, nil + }) + + tp, _ := oteltest.Setup(t) + tracer := tp.Tracer(t.Name()) + session := tests.NewSession(t) + + e := newEval( + session, tracer, + "exp-proj", "project-experiment", + "proj-abc123", "test-project-name", + cases, task, + nil, 1, true, nil, trace.Parent{}, nil, + ) + + result, err := e.run(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "proj-abc123", result.ProjectID()) + assert.Equal(t, "test-project-name", result.ProjectName()) +} + func TestTaskOutput_UserData(t *testing.T) { t.Parallel() @@ -1308,5 +1631,97 @@ func TestEval_EvalSpanAttrsOnTaskFailure(t *testing.T) { } } assert.Equal(t, []string{"tag1"}, foundTags) +} + +func TestSpanParentOverride(t *testing.T) { + t.Parallel() + + cases := NewDataset([]Case[testInput, testOutput]{ + {Input: testInput{Value: "x"}, Expected: testOutput{Result: "y"}}, + }) + task := T(func(ctx context.Context, input testInput) (testOutput, error) { + return testOutput{Result: input.Value}, nil + }) + scorer := NewScorer[testInput, testOutput]("s", func(_ context.Context, _ TaskResult[testInput, testOutput]) (Scores, error) { + return S(1.0), nil + }) + + tp, exporter := oteltest.Setup(t) + tracer := tp.Tracer(t.Name()) + session := tests.NewSession(t) + + e := newEval( + session, tracer, + "exp-id", "exp-name", + "proj-id", "proj-name", + cases, task, + []Scorer[testInput, testOutput]{scorer}, + 1, true, nil, + trace.NewParent(trace.ParentTypePlaygroundID, "pg-999"), // SpanParent override + 42, // Generation + ) + + result, err := e.run(context.Background()) + require.NoError(t, err) + require.NotNil(t, result) + + spans := exporter.Flush() + require.NotEmpty(t, spans) + + // Every span should have the overridden parent attribute + for _, span := range spans { + parentVal := span.Attr("braintrust.parent").Value.AsString() + assert.Equal(t, "playground_id:pg-999", parentVal, + "span %q should have overridden parent", span.Name()) + } + + // Eval span should have generation in span_attributes + evalSpan := spans[len(spans)-1] + evalSpan.AssertNameIs("eval") + spanAttrsJSON := evalSpan.Attr("braintrust.span_attributes").Value.AsString() + assert.Contains(t, spanAttrsJSON, `"generation"`) + assert.Contains(t, spanAttrsJSON, `42`) +} + +func TestSpanParentDefault(t *testing.T) { + t.Parallel() + + cases := NewDataset([]Case[testInput, testOutput]{ + {Input: testInput{Value: "x"}, Expected: testOutput{Result: "y"}}, + }) + task := T(func(ctx context.Context, input testInput) (testOutput, error) { + return testOutput{Result: input.Value}, nil + }) + tp, exporter := oteltest.Setup(t) + tracer := tp.Tracer(t.Name()) + session := tests.NewSession(t) + + e := newEval( + session, tracer, + "exp-id", "exp-name", + "proj-id", "proj-name", + cases, task, nil, + 1, true, nil, + trace.Parent{}, nil, // no override, no generation + ) + + result, err := e.run(context.Background()) + require.NoError(t, err) + require.NotNil(t, result) + + spans := exporter.Flush() + require.NotEmpty(t, spans) + + // Every span should have the default experiment_id parent + for _, span := range spans { + parentVal := span.Attr("braintrust.parent").Value.AsString() + assert.Equal(t, "experiment_id:exp-id", parentVal, + "span %q should have default experiment parent", span.Name()) + } + + // Eval span should NOT have generation in span_attributes + evalSpan := spans[len(spans)-1] + spanAttrsJSON := evalSpan.Attr("braintrust.span_attributes").Value.AsString() + assert.NotContains(t, spanAttrsJSON, `"generation"`) } diff --git a/eval/example_test.go b/eval/example_test.go index b99979f..15538b8 100644 --- a/eval/example_test.go +++ b/eval/example_test.go @@ -61,3 +61,49 @@ func Example() { fmt.Printf("Evaluation complete: %s\n", result.Name()) } + +// Example_evalDefinition demonstrates how to define a reusable evaluation +// that can be run locally or registered with a remote eval server. +func Example_evalDefinition() { + ctx := context.Background() + + // Create tracer provider + tp := trace.NewTracerProvider() + defer func() { _ = tp.Shutdown(ctx) }() + + // Create Braintrust client + client, err := braintrust.New(tp, braintrust.WithProject("test-project")) + if err != nil { + log.Fatal(err) + } + + // Create a runnable eval + e := braintrust.NewEval(client, &eval.Eval[string, string]{ + Name: "classify", + Task: eval.T(func(ctx context.Context, input string) (string, error) { + return input + "!", nil + }), + Scorers: []eval.Scorer[string, string]{ + eval.NewScorer("exact-match", func(ctx context.Context, result eval.TaskResult[string, string]) (eval.Scores, error) { + if result.Output == result.Expected { + return eval.S(1.0), nil + } + return eval.S(0.0), nil + }), + }, + ProjectName: "test-project", + }) + + // Run it + result, err := e.Run(ctx, eval.RunOpts[string, string]{ + Dataset: eval.NewDataset([]eval.Case[string, string]{ + {Input: "hello", Expected: "hello!"}, + }), + Quiet: true, + }) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Evaluation complete: %s\n", result.Name()) +} diff --git a/eval/testdata/cassettes/TestEvalRun_Integration.yaml b/eval/testdata/cassettes/TestEvalRun_Integration.yaml new file mode 100644 index 0000000..c890002 --- /dev/null +++ b/eval/testdata/cassettes/TestEvalRun_Integration.yaml @@ -0,0 +1,183 @@ +--- +version: 2 +interactions: + - id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 23 + transfer_encoding: [] + trailer: {} + host: api.braintrust.dev + remote_addr: "" + request_uri: "" + body: '{"name":"go-sdk-tests"}' + form: {} + headers: + Content-Type: + - application/json + url: https://api.braintrust.dev/v1/project + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + transfer_encoding: [] + trailer: {} + content_length: -1 + uncompressed: true + body: '{"id":"c1357431-5a26-4ef4-baac-c40fc13553ef","org_id":"5d7c97d7-fef1-4cb7-bda6-7e3756a0ca8e","name":"go-sdk-tests","description":null,"created":"2025-11-04T16:07:47.679Z","deleted_at":null,"user_id":"f2ddc4e6-a51a-4a60-9734-9af4ea05c6ef","settings":null}' + headers: + Access-Control-Allow-Credentials: + - "true" + Access-Control-Expose-Headers: + - x-bt-cursor,x-bt-found-existing,x-bt-query-plan,x-bt-api-duration-ms,x-bt-brainstore-duration-ms + Content-Type: + - application/json; charset=utf-8 + Date: + - Tue, 31 Mar 2026 16:02:55 GMT + Etag: + - W/"fe-zwAhFWuHmpM5HSpE8rG2KFNGNzI" + Vary: + - Origin, Accept-Encoding + Via: + - 1.1 7d7f7790ad8ab9e81e905351df020944.cloudfront.net (CloudFront), 1.1 abcdd9ead509c6f31d96ed9f797fd698.cloudfront.net (CloudFront) + X-Amz-Apigw-Id: + - bGJ3ZEqTIAMEh9w= + X-Amz-Cf-Id: + - A63FFfxFgWRApSDdXlSrIzhRdy8u9evnbFC4W1blhVTaUKmRMsfbAQ== + X-Amz-Cf-Pop: + - CMH68-P1 + - CMH68-P1 + X-Amzn-Requestid: + - aec7efd6-bd6a-471f-942e-12eeb24fe9b9 + X-Amzn-Trace-Id: + - Root=1-69cbf02e-00c49d6b3fcc61fb71aef1ac;Parent=3f4e12eaf309b983;Sampled=0;Lineage=1:fc3b4ff1:0 + X-Bt-Found-Existing: + - "true" + X-Bt-Internal-Trace-Id: + - 69cbf02f00000000224cee26d6fb2154 + X-Cache: + - Miss from cloudfront + status: 200 OK + code: 200 + duration: 264.397677ms + - id: 1 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + transfer_encoding: [] + trailer: {} + host: api.braintrust.dev + remote_addr: "" + request_uri: "" + body: "" + form: {} + headers: {} + url: https://api.braintrust.dev/api/apikey/login + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + transfer_encoding: [] + trailer: {} + content_length: -1 + uncompressed: true + body: '{"org_info":[{"id":"5d7c97d7-fef1-4cb7-bda6-7e3756a0ca8e","name":"braintrustdata.com","api_url":"https://api.braintrust.dev","git_metadata":{"fields":["commit","branch","tag","author_name","author_email","commit_message","commit_time","dirty"],"collect":"some"},"is_universal_api":true,"proxy_url":"https://api.braintrust.dev","realtime_url":"wss://realtime.braintrustapi.com"}]}' + headers: + Access-Control-Allow-Credentials: + - "true" + Access-Control-Expose-Headers: + - x-bt-cursor,x-bt-found-existing,x-bt-query-plan,x-bt-api-duration-ms,x-bt-brainstore-duration-ms + Content-Type: + - application/json; charset=utf-8 + Date: + - Tue, 31 Mar 2026 16:02:55 GMT + Etag: + - W/"18b-OPCBBHzVVuCPaglXVbFjmsFzOoE" + Vary: + - Origin, Accept-Encoding + Via: + - 1.1 ade0cadf195b634f1ce60fe31eb474a2.cloudfront.net (CloudFront), 1.1 abcdd9ead509c6f31d96ed9f797fd698.cloudfront.net (CloudFront) + X-Amz-Apigw-Id: + - bGJ3ZHNXoAMECzA= + X-Amz-Cf-Id: + - qJ3xnwAA2-83oCl844FN8KycLApEVsunPLp27TgaXWxh03Tssm5OpA== + X-Amz-Cf-Pop: + - CMH68-P1 + - CMH68-P1 + X-Amzn-Requestid: + - c46b9c94-043a-4cce-a700-4d930e73e55b + X-Amzn-Trace-Id: + - Root=1-69cbf02e-202957480080e3655c74eb33;Parent=25bc53d80538962a;Sampled=0;Lineage=1:fc3b4ff1:0 + X-Bt-Internal-Trace-Id: + - 69cbf02f00000000247f8dd332bca120 + X-Cache: + - Miss from cloudfront + status: 200 OK + code: 200 + duration: 277.481017ms + - id: 2 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 89 + transfer_encoding: [] + trailer: {} + host: api.braintrust.dev + remote_addr: "" + request_uri: "" + body: '{"project_id":"c1357431-5a26-4ef4-baac-c40fc13553ef","name":"classify","ensure_new":true}' + form: {} + headers: + Content-Type: + - application/json + url: https://api.braintrust.dev/v1/experiment + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + transfer_encoding: [] + trailer: {} + content_length: -1 + uncompressed: true + body: '{"id":"332cd801-ab6a-4713-9a46-e71729db8de4","project_id":"c1357431-5a26-4ef4-baac-c40fc13553ef","name":"classify","description":null,"created":"2026-03-31T16:02:55.305Z","repo_info":null,"commit":null,"base_exp_id":null,"deleted_at":null,"dataset_id":null,"dataset_version":null,"parameters_id":null,"parameters_version":null,"public":false,"user_id":"c755328d-f64a-4737-a984-e83c088cd9f7","metadata":null,"tags":null}' + headers: + Access-Control-Allow-Credentials: + - "true" + Access-Control-Expose-Headers: + - x-bt-cursor,x-bt-found-existing,x-bt-query-plan,x-bt-api-duration-ms,x-bt-brainstore-duration-ms + Content-Type: + - application/json; charset=utf-8 + Date: + - Tue, 31 Mar 2026 16:02:55 GMT + Etag: + - W/"1a3-O6UJpZ1ut7UppcvvRlH5NWXjJxM" + Vary: + - Origin, Accept-Encoding + Via: + - 1.1 a239c31f56936d8dde678cf491dbaa28.cloudfront.net (CloudFront), 1.1 abcdd9ead509c6f31d96ed9f797fd698.cloudfront.net (CloudFront) + X-Amz-Apigw-Id: + - bGJ3bGxWIAMEOUw= + X-Amz-Cf-Id: + - zwutbRqzzUoav-R4NaNbF9tRALrtg9g4EWjTgj8l7PwR9uYw0sJV5g== + X-Amz-Cf-Pop: + - CMH68-P1 + - CMH68-P1 + X-Amzn-Requestid: + - 335e5011-dad5-4730-ba33-c0c00a54838f + X-Amzn-Trace-Id: + - Root=1-69cbf02f-4f0723641ae77a48592172b2;Parent=1db626a8bcc7a791;Sampled=0;Lineage=1:fc3b4ff1:0 + X-Bt-Internal-Trace-Id: + - 69cbf02f000000006bc6baafdb05fbb6 + X-Cache: + - Miss from cloudfront + status: 200 OK + code: 200 + duration: 162.639032ms diff --git a/examples/README.md b/examples/README.md index 14197c9..23919f8 100644 --- a/examples/README.md +++ b/examples/README.md @@ -29,6 +29,12 @@ Examples for other AI providers and client libraries: - **[langchaingo/](langchaingo/)** - Trace LangChainGo multi-turn conversations - **[adk-go/](adk/)** - Trace agents built with the ADK framework +## Remote Evals + +Run evaluations triggered from the Braintrust UI against your own infrastructure: + +- **[eval-server](internal/eval-server/)** - Remote eval server exposing evaluators over HTTP + ## Advanced Features (30 minutes) More specialized use cases and integrations: diff --git a/examples/eval-server/eval-server.go b/examples/eval-server/eval-server.go new file mode 100644 index 0000000..1d6e6ca --- /dev/null +++ b/examples/eval-server/eval-server.go @@ -0,0 +1,80 @@ +// This example demonstrates running a remote eval server that exposes +// evaluators to the Braintrust UI. The Braintrust playground can then +// trigger evaluations against your locally running code. +// +// Start the server: +// +// go run examples/eval-server/eval-server.go +// +// Then configure the endpoint (http://localhost:8300) in your Braintrust +// project settings under Remote evals. +package main + +import ( + "context" + "log" + "strings" + + "github.com/braintrustdata/braintrust-sdk-go/eval" + "github.com/braintrustdata/braintrust-sdk-go/server" +) + +func main() { + // Create the eval server. + // Use WithNoAuth() for local development; remove for production. + // Use WithTracerProvider(tp) to share an OpenTelemetry TracerProvider so + // user-instrumented spans (LLM clients, custom spans) appear in the same + // trace as eval spans. + srv := server.New( + server.WithAddress("localhost:8300"), + server.WithNoAuth(), + ) + + // Define a simple task: classify food items. + classifyTask := eval.T(func(ctx context.Context, input string) (string, error) { + input = strings.ToLower(input) + switch { + case strings.Contains(input, "apple") || strings.Contains(input, "banana"): + return "fruit", nil + case strings.Contains(input, "carrot") || strings.Contains(input, "lettuce"): + return "vegetable", nil + default: + return "unknown", nil + } + }) + + // Define scorers. Scorers run server-side as part of the evaluation and + // their results are exported to Braintrust via OTLP spans. + exactMatch := eval.NewScorer("exact_match", + func(ctx context.Context, r eval.TaskResult[string, string]) (eval.Scores, error) { + if r.Output == r.Expected { + return eval.S(1.0), nil + } + return eval.S(0.0), nil + }, + ) + + // Register the evaluator with the server. + server.RegisterEval(srv, &eval.Eval[string, string]{ + Name: "food-classifier", + Task: classifyTask, + Scorers: []eval.Scorer[string, string]{exactMatch}, + ProjectName: "go-sdk-examples", + }, server.RegisterEvalOpts{ + Parameters: &server.Parameters{ + Schema: map[string]server.ParameterDef{ + "model": { + Type: "string", + Default: "rule-based", + Description: "Classification model to use", + }, + }, + }, + }) + + log.Printf("Eval server starting on localhost:8300") + log.Printf("Registered evaluators: food-classifier") + log.Printf("Health check: http://localhost:8300/") + log.Printf("List evals: http://localhost:8300/list") + log.Fatal(srv.Start()) +} diff --git a/examples/internal/README.md b/examples/internal/README.md index 39946a8..deb19eb 100644 --- a/examples/internal/README.md +++ b/examples/internal/README.md @@ -30,6 +30,7 @@ Comprehensive examples testing all features for each AI provider: - **[rewrite/](rewrite/)** - Manual tracing and evaluator API testing - **[email-evals/](email-evals/)** - Realistic eval example with complex scoring - **[eval-updates/](eval-updates/)** - Testing Update option for appending to experiments +- **[eval-server/](eval-server/)** - Remote eval server exposing evaluators over HTTP - **[temporal/](temporal/)** - Temporal workflow distributed tracing (worker + client) ## For Learning diff --git a/examples/internal/eval-server/main.go b/examples/internal/eval-server/main.go new file mode 100644 index 0000000..f9eb0b9 --- /dev/null +++ b/examples/internal/eval-server/main.go @@ -0,0 +1,94 @@ +// This example demonstrates running a remote eval server that exposes +// evaluators to the Braintrust UI. The Braintrust playground can then +// trigger evaluations against your locally running code. +// +// Start the server: +// +// go run examples/internal/eval-server/main.go +// +// Then configure the endpoint (http://localhost:8300) in your Braintrust +// project settings under Remote evals. +package main + +import ( + "context" + "log" + "strings" + + "github.com/braintrustdata/braintrust-sdk-go/eval" + "github.com/braintrustdata/braintrust-sdk-go/server" +) + +func main() { + // Create the eval server. + // Use WithNoAuth() for local development; remove for production. + // Use WithTracerProvider(tp) to include user-instrumented spans (LLM clients, + // custom spans) in eval traces. + srv := server.New( + server.WithAddress("localhost:8300"), + server.WithNoAuth(), + ) + + // Define a simple task: classify food items. + classifyTask := eval.T(func(ctx context.Context, input string) (string, error) { + input = strings.ToLower(input) + switch { + case strings.Contains(input, "apple") || strings.Contains(input, "banana"): + return "fruit", nil + case strings.Contains(input, "carrot") || strings.Contains(input, "lettuce"): + return "vegetable", nil + default: + return "unknown", nil + } + }) + + // Define a scorer: exact match. + exactMatch := eval.NewScorer("exact_match", + func(ctx context.Context, r eval.TaskResult[string, string]) (eval.Scores, error) { + if r.Output == r.Expected { + return eval.S(1.0), nil + } + return eval.S(0.0), nil + }, + ) + + // Define a scorer: checks output is a valid food category. + validCategory := eval.NewScorer("valid_category", + func(ctx context.Context, r eval.TaskResult[string, string]) (eval.Scores, error) { + switch r.Output { + case "fruit", "vegetable", "grain", "protein", "dairy", "unknown": + return eval.S(1.0), nil + default: + return eval.S(0.0), nil + } + }, + ) + + // Define the eval. + foodClassifier := &eval.Eval[string, string]{ + Name: "food-classifier", + Task: classifyTask, + Scorers: []eval.Scorer[string, string]{exactMatch, validCategory}, + ProjectName: "go-sdk-examples", + } + + // Register with the server. + server.RegisterEval(srv, foodClassifier, server.RegisterEvalOpts{ + Parameters: &server.Parameters{ + Schema: map[string]server.ParameterDef{ + "model": { + Type: "string", + Default: "rule-based", + Description: "Classification model to use", + }, + }, + }, + }, + ) + + log.Printf("Eval server starting on localhost:8300") + log.Printf("Registered evaluators: food-classifier") + log.Printf("Health check: http://localhost:8300/") + log.Printf("List evals: http://localhost:8300/list") + log.Fatal(srv.Start()) +} diff --git a/server/auth.go b/server/auth.go new file mode 100644 index 0000000..ee3b3fd --- /dev/null +++ b/server/auth.go @@ -0,0 +1,222 @@ +package server + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "sync" + + "github.com/braintrustdata/braintrust-sdk-go/api" + "github.com/braintrustdata/braintrust-sdk-go/internal/auth" + "github.com/braintrustdata/braintrust-sdk-go/internal/https" + "github.com/braintrustdata/braintrust-sdk-go/logger" +) + +type contextKey string + +const authContextKey contextKey = "braintrust.auth" + +// authResult holds the validated auth context for a request. +type authResult struct { + session *auth.Session + api *api.API +} + +// newAuthResult creates a session, logs in, and builds an API client. +// This is the shared auth flow used by both per-request auth and no-auth mode. +func newAuthResult(ctx context.Context, apiKey, appURL, apiURL, orgName string, log logger.Logger) (*authResult, error) { + httpClient := https.NewClient(apiKey, appURL, log) + session, err := auth.NewSession(ctx, auth.Options{ + APIKey: apiKey, + AppURL: appURL, + AppPublicURL: appURL, + APIURL: apiURL, + OrgName: orgName, + Logger: log, + Client: httpClient, + }) + if err != nil { + return nil, fmt.Errorf("failed to create session: %w", err) + } + + if err := session.Login(ctx); err != nil { + session.Close() + return nil, fmt.Errorf("authentication failed: %w", err) + } + + apiInfo := session.APIInfo() + apiClient := api.NewClient(apiInfo.APIKey, api.WithAPIURL(apiInfo.APIURL), api.WithLogger(log)) + + return &authResult{session: session, api: apiClient}, nil +} + +// authCache is an LRU cache of authenticated sessions. +type authCache struct { + mu sync.Mutex + entries map[string]*authResult + order []string + maxSize int + appURL string + log logger.Logger +} + +func newAuthCache(appURL string, maxSize int, log logger.Logger) *authCache { + return &authCache{ + entries: make(map[string]*authResult), + maxSize: maxSize, + appURL: appURL, + log: log, + } +} + +// cacheKey builds a collision-free cache key from request auth headers. +// Uses length-prefixing to prevent "a:b"+"c" == "a"+"b:c" collisions. +func cacheKey(token, orgName string) string { + return fmt.Sprintf("%d:%s:%s", len(token), token, orgName) +} + +// getOrCreate returns a cached auth result or creates a new one. +func (c *authCache) getOrCreate(ctx context.Context, token, orgName string) (*authResult, error) { + key := cacheKey(token, orgName) + + c.mu.Lock() + if result, ok := c.entries[key]; ok { + c.moveToEnd(key) + c.mu.Unlock() + return result, nil + } + c.mu.Unlock() + + // Create new session outside the lock to avoid holding it during network calls + result, err := c.createSession(ctx, token, orgName) + if err != nil { + return nil, err + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Re-check: another goroutine may have inserted this key while we were unlocked + if existing, ok := c.entries[key]; ok { + // Discard the session we just created; use the existing one + result.session.Close() + return existing, nil + } + + // Evict oldest if at capacity + if len(c.entries) >= c.maxSize && len(c.order) > 0 { + oldest := c.order[0] + if evicted, ok := c.entries[oldest]; ok { + evicted.session.Close() + } + delete(c.entries, oldest) + c.order = c.order[1:] + } + + c.entries[key] = result + c.order = append(c.order, key) + return result, nil +} + +// createSession creates and validates a new auth session. +func (c *authCache) createSession(ctx context.Context, token, orgName string) (*authResult, error) { + return newAuthResult(ctx, token, c.appURL, "", orgName, c.log) +} + +// evict removes a cache entry by token and org name, closing its session. +// Called when a cached session produces auth errors during eval execution. +func (c *authCache) evict(token, orgName string) { + key := cacheKey(token, orgName) + c.mu.Lock() + defer c.mu.Unlock() + + if entry, ok := c.entries[key]; ok { + entry.session.Close() + delete(c.entries, key) + for i, k := range c.order { + if k == key { + c.order = append(c.order[:i], c.order[i+1:]...) + break + } + } + } +} + +// moveToEnd moves a key to the end of the LRU order. +// O(n) scan over the order slice; fine for defaultAuthCacheMax (64). +// Must be called with c.mu held. +func (c *authCache) moveToEnd(key string) { + for i, k := range c.order { + if k == key { + c.order = append(c.order[:i], c.order[i+1:]...) + c.order = append(c.order, key) + return + } + } +} + +// isAuthError returns true if the error chain contains an HTTP 401 or 403. +func isAuthError(err error) bool { + var httpErr *https.HTTPError + if errors.As(err, &httpErr) { + return httpErr.StatusCode == 401 || httpErr.StatusCode == 403 + } + return false +} + +// extractToken extracts the auth token from request headers. +func extractToken(r *http.Request) string { + // Prefer x-bt-auth-token + if token := strings.TrimSpace(r.Header.Get("X-Bt-Auth-Token")); token != "" { + return token + } + // Fall back to Authorization: Bearer + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "" + } + if strings.HasPrefix(authHeader, "Bearer ") { + return strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) + } + return strings.TrimSpace(authHeader) +} + +// extractOrgName extracts the organization name from request headers. +func extractOrgName(r *http.Request) string { + return r.Header.Get("X-Bt-Org-Name") +} + +// authFromContext retrieves the auth result from request context. +func authFromContext(ctx context.Context) *authResult { + result, _ := ctx.Value(authContextKey).(*authResult) + return result +} + +// authMiddleware validates request auth and injects the auth result into context. +func (s *Server) authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if s.noAuth { + next.ServeHTTP(w, r) + return + } + + token := extractToken(r) + if token == "" { + http.Error(w, `{"error":"missing authentication token"}`, http.StatusUnauthorized) + return + } + + orgName := extractOrgName(r) + result, err := s.authCache.getOrCreate(r.Context(), token, orgName) + if err != nil { + s.logger.Warn("authentication failed", "error", err) + http.Error(w, `{"error":"authentication failed"}`, http.StatusUnauthorized) + return + } + + ctx := context.WithValue(r.Context(), authContextKey, result) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/server/auth_test.go b/server/auth_test.go new file mode 100644 index 0000000..454afdc --- /dev/null +++ b/server/auth_test.go @@ -0,0 +1,162 @@ +package server + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/braintrustdata/braintrust-sdk-go/internal/auth" + "github.com/braintrustdata/braintrust-sdk-go/internal/https" +) + +func TestExtractToken(t *testing.T) { + tests := []struct { + name string + headers map[string]string + expected string + }{ + { + name: "x-bt-auth-token header", + headers: map[string]string{"X-Bt-Auth-Token": "token123"}, + expected: "token123", + }, + { + name: "bearer authorization", + headers: map[string]string{"Authorization": "Bearer token456"}, + expected: "token456", + }, + { + name: "plain authorization", + headers: map[string]string{"Authorization": "token789"}, + expected: "token789", + }, + { + name: "x-bt-auth-token takes priority", + headers: map[string]string{"X-Bt-Auth-Token": "preferred", "Authorization": "Bearer fallback"}, + expected: "preferred", + }, + { + name: "no token", + headers: map[string]string{}, + expected: "", + }, + { + name: "bearer with extra whitespace", + headers: map[string]string{"Authorization": "Bearer token-ws "}, + expected: "token-ws", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + for k, v := range tt.headers { + r.Header.Set(k, v) + } + assert.Equal(t, tt.expected, extractToken(r)) + }) + } +} + +func TestExtractOrgName(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("X-Bt-Org-Name", "my-org") + assert.Equal(t, "my-org", extractOrgName(r)) +} + +func TestAuthMiddleware_NoAuth(t *testing.T) { + srv := New(WithNoAuth()) + handler := srv.authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestAuthMiddleware_MissingToken(t *testing.T) { + srv := New() // auth enabled by default + handler := srv.authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.Contains(t, w.Body.String(), "missing authentication token") +} + +func TestAuthFromContext_Nil(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + result := authFromContext(r.Context()) + assert.Nil(t, result) +} + +func TestAuthFromContext_WrongType(t *testing.T) { + ctx := context.WithValue(context.Background(), authContextKey, "not-an-authResult") + result := authFromContext(ctx) + assert.Nil(t, result) +} + +func TestAuthCacheKey(t *testing.T) { + key := cacheKey("token", "org") + assert.Equal(t, "5:token:org", key) +} + +func TestAuthCacheKey_NoCollision(t *testing.T) { + // "a:b" + "c" should differ from "a" + "b:c" + key1 := cacheKey("a:b", "c") + key2 := cacheKey("a", "b:c") + assert.NotEqual(t, key1, key2) +} + +func TestAuthCache_Evict(t *testing.T) { + cache := newAuthCache("https://app.test", 64, nil) + token := "test-token" + orgName := "test-org" + + // Manually insert a fake entry + key := cacheKey(token, orgName) + session := auth.NewTestSession("key", "org-id", orgName, "https://api.test", "https://app.test", "https://app.test", nil) + cache.entries[key] = &authResult{session: session, api: nil} + cache.order = append(cache.order, key) + + // Verify it's there + assert.Len(t, cache.entries, 1) + assert.Len(t, cache.order, 1) + + // Evict it + cache.evict(token, orgName) + + // Verify it's gone + assert.Len(t, cache.entries, 0) + assert.Len(t, cache.order, 0) +} + +func TestAuthCache_EvictNonExistent(t *testing.T) { + cache := newAuthCache("https://app.test", 64, nil) + + // Evicting a key that doesn't exist should be a no-op + cache.evict("no-such-token", "no-such-org") + + assert.Len(t, cache.entries, 0) + assert.Len(t, cache.order, 0) +} + +func TestIsAuthError(t *testing.T) { + assert.True(t, isAuthError(&https.HTTPError{StatusCode: 401})) + assert.True(t, isAuthError(&https.HTTPError{StatusCode: 403})) + assert.True(t, isAuthError(fmt.Errorf("wrapped: %w", &https.HTTPError{StatusCode: 401}))) + assert.False(t, isAuthError(&https.HTTPError{StatusCode: 500})) + assert.False(t, isAuthError(fmt.Errorf("some other error"))) + assert.False(t, isAuthError(nil)) +} diff --git a/server/cors.go b/server/cors.go new file mode 100644 index 0000000..ed16e35 --- /dev/null +++ b/server/cors.go @@ -0,0 +1,57 @@ +package server + +import ( + "net/http" + "regexp" + "strings" +) + +// allowedOriginPattern matches braintrust.dev and braintrustdata.dev origins +// including subdomains and preview deployments. HTTP is allowed alongside HTTPS +// to support local development proxies. +var allowedOriginPattern = regexp.MustCompile(`^https?://([\w-]+\.)*(braintrust|braintrustdata)\.dev$`) + +var corsAllowHeaders = strings.Join([]string{ + "Content-Type", + "Authorization", + "X-Api-Key", + "X-Bt-Auth-Token", + "X-Bt-Parent", + "X-Bt-Org-Name", + "X-Bt-Project-Id", + "X-Bt-Cursor", + "X-Bt-Found-Existing-Experiment", + "X-Bt-Span-Id", + "X-Bt-Span-Export", + "X-Bt-Use-Gateway", +}, ", ") + +const ( + corsAllowMethods = "GET, POST, OPTIONS" + corsMaxAge = "86400" +) + +// corsMiddleware wraps an http.Handler with CORS support for Braintrust origins. +func corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin != "" && allowedOriginPattern.MatchString(origin) { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", corsAllowMethods) + w.Header().Set("Access-Control-Allow-Headers", corsAllowHeaders) + w.Header().Set("Access-Control-Expose-Headers", corsAllowHeaders) + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Max-Age", corsMaxAge) + w.Header().Set("Vary", "Origin") + // Support Chrome Private Network Access + w.Header().Set("Access-Control-Allow-Private-Network", "true") + } + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/server/cors_test.go b/server/cors_test.go new file mode 100644 index 0000000..ea2ce51 --- /dev/null +++ b/server/cors_test.go @@ -0,0 +1,82 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCORS_AllowedOrigins(t *testing.T) { + tests := []struct { + origin string + allowed bool + }{ + {"https://www.braintrust.dev", true}, + {"https://braintrust.dev", true}, + {"https://app.braintrust.dev", true}, + {"https://preview-123.braintrust.dev", true}, + {"https://braintrustdata.dev", true}, + {"https://app.braintrustdata.dev", true}, + {"http://braintrust.dev", true}, + {"https://evil.com", false}, + {"https://braintrust.dev.evil.com", false}, + {"", false}, + } + + handler := corsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for _, tt := range tests { + t.Run(tt.origin, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + if tt.origin != "" { + r.Header.Set("Origin", tt.origin) + } + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if tt.allowed { + assert.Equal(t, tt.origin, w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Private-Network")) + } else { + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + } + }) + } +} + +func TestCORS_Preflight(t *testing.T) { + handler := corsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + r := httptest.NewRequest(http.MethodOptions, "/eval", nil) + r.Header.Set("Origin", "https://www.braintrust.dev") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assert.Equal(t, http.StatusNoContent, w.Code) + assert.Equal(t, "https://www.braintrust.dev", w.Header().Get("Access-Control-Allow-Origin")) + assert.Contains(t, w.Header().Get("Access-Control-Allow-Headers"), "X-Bt-Auth-Token") + assert.Contains(t, w.Header().Get("Access-Control-Allow-Headers"), "X-Bt-Use-Gateway") + assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age")) +} + +func TestCORS_NonPreflightPassesThrough(t *testing.T) { + called := false + handler := corsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Origin", "https://www.braintrust.dev") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assert.True(t, called) + assert.Equal(t, http.StatusOK, w.Code) +} diff --git a/server/register.go b/server/register.go new file mode 100644 index 0000000..8383cb4 --- /dev/null +++ b/server/register.go @@ -0,0 +1,319 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + sdktrace "go.opentelemetry.io/otel/sdk/trace" + + "github.com/braintrustdata/braintrust-sdk-go/eval" + bttrace "github.com/braintrustdata/braintrust-sdk-go/trace" +) + +// registeredEval is the non-generic interface stored in the server's evaluator map. +// It hides the type parameters behind JSON-based I/O. +type registeredEval interface { + scorerNames() []string + parameters() *Parameters + projectName() string + run(ctx context.Context, cfg *evalRunConfig) error +} + +// evalRunConfig holds the per-request configuration for running an evaluation. +type evalRunConfig struct { + req *EvalRequest + auth *authResult + sse *sseWriter + noAuth bool + tracerProvider *sdktrace.TracerProvider // nil means create per-request +} + +// RegisterEvalOpts configures a registered evaluator. +type RegisterEvalOpts struct { + // Parameters defines the parameter schema shown in the Braintrust UI. + Parameters *Parameters + + // ProjectName is the default project for this evaluator. + ProjectName string +} + +// RegisterEval adds an eval definition to the server. The type parameters I and R +// are the input and result types of the evaluation. Go does not allow generic +// methods on non-generic types, so this is a package-level function. +// +// Example: +// +// classify := &eval.Eval[string, string]{ +// Name: "classify", +// Task: eval.T(classifyTask), +// Scorers: []eval.Scorer[string, string]{scorer}, +// } +// server.RegisterEval(srv, classify, server.RegisterEvalOpts{}) +func RegisterEval[I, R any](s *Server, ev *eval.Eval[I, R], opts RegisterEvalOpts) { + impl := ®isteredEvalImpl[I, R]{ + def: ev, + opts: opts, + } + + s.evalsMu.Lock() + defer s.evalsMu.Unlock() + s.evaluators[ev.Name] = impl +} + +// registeredEvalImpl implements registeredEval by wrapping an [eval.Eval] definition. +type registeredEvalImpl[I, R any] struct { + def *eval.Eval[I, R] + opts RegisterEvalOpts +} + +func (r *registeredEvalImpl[I, R]) scorerNames() []string { + names := make([]string, len(r.def.Scorers)) + for i, s := range r.def.Scorers { + names[i] = s.Name() + } + return names +} + +func (r *registeredEvalImpl[I, R]) parameters() *Parameters { + return r.opts.Parameters +} + +func (r *registeredEvalImpl[I, R]) projectName() string { + if r.opts.ProjectName != "" { + return r.opts.ProjectName + } + return r.def.ProjectName +} + +func (r *registeredEvalImpl[I, R]) run(ctx context.Context, cfg *evalRunConfig) error { + req := cfg.req + + // Resolve inline data into typed cases + dataset, err := r.resolveDataset(ctx, cfg) + if err != nil { + return fmt.Errorf("failed to resolve dataset: %w", err) + } + + // Determine experiment name + experimentName := req.ExperimentName + if experimentName == "" { + experimentName = r.def.Name + } + + // Use the shared TracerProvider if one was provided, otherwise create a + // per-request provider. A shared provider allows user-instrumented code + // (LLM clients, custom spans) to appear in the same trace as eval spans. + tp := cfg.tracerProvider + if tp == nil { + tp = sdktrace.NewTracerProvider() + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _ = tp.Shutdown(shutdownCtx) + }() + + // Per-request provider needs its own Braintrust span processor + traceCfg := bttrace.Config{ + DefaultProjectName: r.projectName(), + } + if err := bttrace.AddSpanProcessor(tp, cfg.auth.session, traceCfg); err != nil { + return fmt.Errorf("failed to setup tracing: %w", err) + } + } + + apiClient := cfg.auth.api + + // Cancel eval if the client disconnects (SSE write fails) + evalCtx, cancelEval := context.WithCancel(ctx) + defer cancelEval() + + // Track scores across cases for the summary + var scoresMu sync.Mutex + scoreSums := make(map[string]float64) + scoreCounts := make(map[string]int) + + // Build OnCaseComplete callback to stream progress via SSE + onComplete := func(cp eval.CaseProgress) { + // Skip score accumulation and progress on error + if cp.Error != nil { + return + } + + // Accumulate scores for summary + scoresMu.Lock() + for name, val := range cp.Scores { + scoreSums[name] += val + scoreCounts[name]++ + } + scoresMu.Unlock() + + // JSON-encode just the output, matching Ruby's protocol. + // Scores are delivered via OTLP spans, not SSE progress events. + outputJSON, _ := json.Marshal(cp.Output) + + // Stream progress event; cancel eval if write fails (client disconnected) + if err := cfg.sse.writeProgress(progressEvent{ + ID: cp.ID, + ObjectType: "task", + Name: r.def.Name, + Format: "code", + OutputType: "completion", + Event: "json_delta", + Data: string(outputJSON), + Origin: cp.Origin, + }); err != nil { + cancelEval() + return + } + + // Signal per-cell completion so the UI marks the task as done. + if err := cfg.sse.writeProgress(progressEvent{ + ID: cp.ID, + ObjectType: "task", + Name: r.def.Name, + Format: "code", + OutputType: "completion", + Event: "done", + Data: "", + Origin: cp.Origin, + }); err != nil { + cancelEval() + return + } + + } + + // Resolve parent span context from the request (links traces to the playground) + var spanParent bttrace.Parent + var generation any + if req.Parent != nil && req.Parent.ObjectID != "" { + // Always use "playground_id" as the parent type, matching Ruby/Java behavior. + // The request sends object_type "playground_logs" but the span parent must + // be "playground_id" for the UI to find the spans. + spanParent = bttrace.NewParent("playground_id", req.Parent.ObjectID) + // Extract generation from propagated_event.span_attributes.generation + if len(req.Parent.PropagatedEvent) > 0 { + var pe struct { + SpanAttributes struct { + Generation any `json:"generation"` + } `json:"span_attributes"` + } + if json.Unmarshal(req.Parent.PropagatedEvent, &pe) == nil { + generation = pe.SpanAttributes.Generation + } + } + } + + // Create a per-request evaluator with the caller's session, not the + // default evaluator on the Eval, so traces are attributed to the user + // who triggered the request. + evaluator := eval.NewEvaluator[I, R](cfg.auth.session, tp, apiClient, r.projectName()) + e := eval.NewEval(evaluator, r.def) + result, evalErr := e.Run(evalCtx, eval.RunOpts[I, R]{ + Experiment: experimentName, + Dataset: dataset, + ProjectName: r.projectName(), + Update: true, + Quiet: true, + OnCaseComplete: onComplete, + SpanParent: spanParent, + Generation: generation, + }) + + // Flush traces before sending summary so the UI can poll for scores immediately. + flushCtx, flushCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer flushCancel() + _ = tp.ForceFlush(flushCtx) + + // Build average scores for summary + avgScores := make(map[string]float64, len(scoreSums)) + scoresMu.Lock() + for name, sum := range scoreSums { + if count := scoreCounts[name]; count > 0 { + avgScores[name] = sum / float64(count) + } + } + scoresMu.Unlock() + + // Send summary event + summary := summaryEvent{ + Scores: avgScores, + } + if result != nil { + summary.ExperimentID = result.ID() + summary.ExperimentName = result.Name() + summary.ProjectName = result.ProjectName() + summary.ProjectID = result.ProjectID() + if permalink, err := result.Permalink(); err == nil && permalink != "" { + summary.ExperimentURL = permalink + } + } + if err := cfg.sse.writeSummary(summary); err != nil { + return fmt.Errorf("failed to write summary: %w", err) + } + + return evalErr +} + +// resolveDataset resolves the request data into a typed Dataset. +func (r *registeredEvalImpl[I, R]) resolveDataset(ctx context.Context, cfg *evalRunConfig) (eval.Dataset[I, R], error) { + data := cfg.req.Data + + sourceCount := 0 + if len(data.Data) > 0 { + sourceCount++ + } + if data.DatasetID != "" { + sourceCount++ + } + if data.DatasetName != "" { + sourceCount++ + } + if sourceCount != 1 { + return nil, fmt.Errorf("exactly one of data, dataset_id, or dataset_name must be specified") + } + + // Inline data + if len(data.Data) > 0 { + return r.parseInlineData(data.Data) + } + + // Dataset by ID or name requires an API client + if cfg.auth == nil { + return nil, fmt.Errorf("dataset resolution requires authentication") + } + + // Use the authenticated API client via evaluator's Datasets() + evaluator := eval.NewEvaluator[I, R](cfg.auth.session, nil, cfg.auth.api, r.projectName()) + dsAPI := evaluator.Datasets() + + if data.DatasetID != "" { + return dsAPI.Get(ctx, data.DatasetID) + } + + return dsAPI.Query(ctx, eval.DatasetQueryOpts{ + Name: data.DatasetName, + }) +} + +// parseInlineData unmarshals raw JSON into typed Cases. +func (r *registeredEvalImpl[I, R]) parseInlineData(raw json.RawMessage) (eval.Dataset[I, R], error) { + var rawItems []json.RawMessage + if err := json.Unmarshal(raw, &rawItems); err != nil { + return nil, fmt.Errorf("failed to parse inline data array: %w", err) + } + + cases := make([]eval.Case[I, R], 0, len(rawItems)) + for i, item := range rawItems { + var c eval.Case[I, R] + if err := json.Unmarshal(item, &c); err != nil { + return nil, fmt.Errorf("failed to parse case %d: %w", i, err) + } + cases = append(cases, c) + } + + return eval.NewDataset(cases), nil +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..dc55228 --- /dev/null +++ b/server/server.go @@ -0,0 +1,313 @@ +// Package server provides a remote eval HTTP server for the Braintrust SDK. +// +// The server exposes locally-registered evaluators over HTTP, allowing the +// Braintrust UI to trigger evaluations against code running on your infrastructure. +// Results are streamed back via Server-Sent Events (SSE). +// +// Example: +// +// classify := &eval.Eval[string, string]{ +// Name: "classify", +// Task: eval.T(classifyTask), +// Scorers: []eval.Scorer[string, string]{scorer}, +// } +// +// srv := server.New(server.WithAddress(":8300")) +// server.RegisterEval(srv, classify, server.RegisterEvalOpts{}) +// log.Fatal(srv.Start()) +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + sdktrace "go.opentelemetry.io/otel/sdk/trace" + + "github.com/braintrustdata/braintrust-sdk-go/config" + "github.com/braintrustdata/braintrust-sdk-go/logger" +) + +const ( + defaultAddr = "localhost:8300" + defaultAppURL = "https://www.braintrust.dev" + defaultAuthCacheMax = 64 + + // defaultReadHeaderTimeout limits how long the server waits for request + // headers after accepting a connection. Protects against slowloris attacks. + // Does not affect eval duration — only the initial request setup. + defaultReadHeaderTimeout = 10 * time.Second + + // defaultIdleTimeout limits how long idle keep-alive connections stay open + // between requests. Does not affect active SSE streams. + defaultIdleTimeout = 120 * time.Second +) + +// Server is an HTTP server that exposes registered evaluators to the Braintrust UI. +type Server struct { + evalsMu sync.RWMutex + evaluators map[string]registeredEval + + serverMu sync.Mutex + httpServer *http.Server + + logger logger.Logger + tracerProvider *sdktrace.TracerProvider // optional, user-provided + addr string + appURL string + noAuth bool + authCache *authCache + defaultAuth *authResult // used in no-auth mode, built from env config +} + +// Option configures the server. +type Option func(*Server) + +// WithAddress sets the listen address (default "localhost:8300"). +func WithAddress(addr string) Option { + return func(s *Server) { + s.addr = addr + } +} + +// WithLogger sets a custom logger. +func WithLogger(l logger.Logger) Option { + return func(s *Server) { + s.logger = l + } +} + +// WithAppURL sets the Braintrust app URL for auth validation (default "https://www.braintrust.dev"). +func WithAppURL(url string) Option { + return func(s *Server) { + s.appURL = url + } +} + +// WithTracerProvider sets a custom OpenTelemetry TracerProvider for the server. +// When provided, all eval spans flow through this provider, so user-instrumented +// code (LLM clients, custom spans, etc.) appears in the same trace as eval spans. +// When nil (the default), a per-request TracerProvider is created internally. +func WithTracerProvider(tp *sdktrace.TracerProvider) Option { + return func(s *Server) { + s.tracerProvider = tp + } +} + +// WithNoAuth disables authentication. Use only for local development. +func WithNoAuth() Option { + return func(s *Server) { + s.noAuth = true + } +} + +// New creates a new eval server. +func New(opts ...Option) *Server { + s := &Server{ + evaluators: make(map[string]registeredEval), + addr: defaultAddr, + appURL: defaultAppURL, + logger: logger.NewDefaultLogger(), + } + + for _, opt := range opts { + opt(s) + } + + s.authCache = newAuthCache(s.appURL, defaultAuthCacheMax, s.logger) + + // In no-auth mode, build a default auth result from environment config + if s.noAuth { + if err := s.initDefaultAuth(); err != nil { + s.logger.Warn("no-auth mode: could not create default auth from env", "error", err) + } + } + + return s +} + +// initDefaultAuth creates a default auth result from environment variables. +// This is used in no-auth mode so the server can run evaluations locally. +func (s *Server) initDefaultAuth() error { + cfg := config.FromEnv() + if cfg.APIKey == "" { + return fmt.Errorf("BRAINTRUST_API_KEY is required for no-auth mode") + } + + appURL := cfg.AppURL + if appURL == "" { + appURL = s.appURL + } + + result, err := newAuthResult(context.Background(), cfg.APIKey, appURL, cfg.APIURL, cfg.OrgName, s.logger) + if err != nil { + return err + } + + s.defaultAuth = result + return nil +} + +// Handler returns the server's http.Handler for embedding in a custom server. +func (s *Server) Handler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("GET /{$}", s.handleHealth) + mux.HandleFunc("GET /list", s.handleList) + mux.HandleFunc("POST /list", s.handleList) + mux.HandleFunc("POST /eval", s.handleEval) + + // Apply middleware: CORS → Auth → Router + handler := s.authMiddleware(mux) + handler = corsMiddleware(handler) + return handler +} + +// Start starts the HTTP server and blocks until it is shut down. +func (s *Server) Start() error { + srv := &http.Server{ + Addr: s.addr, + Handler: s.Handler(), + ReadHeaderTimeout: defaultReadHeaderTimeout, + IdleTimeout: defaultIdleTimeout, + } + + s.serverMu.Lock() + s.httpServer = srv + s.serverMu.Unlock() + + s.logger.Info("eval server listening", "addr", s.addr, "no_auth", s.noAuth) + return srv.ListenAndServe() +} + +// Shutdown gracefully shuts down the server. +func (s *Server) Shutdown(ctx context.Context) error { + s.serverMu.Lock() + srv := s.httpServer + s.serverMu.Unlock() + + // Drain active requests before closing the default session, + // so in-flight evals using it can complete. + var err error + if srv != nil { + err = srv.Shutdown(ctx) + } + + if s.defaultAuth != nil { + s.defaultAuth.session.Close() + } + + return err +} + +// handleHealth responds to GET / with a health check. +func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"ok"}`)) +} + +// handleList responds to GET/POST /list with registered evaluators. +func (s *Server) handleList(w http.ResponseWriter, _ *http.Request) { + s.evalsMu.RLock() + defer s.evalsMu.RUnlock() + + resp := make(listResponse, len(s.evaluators)) + for name, e := range s.evaluators { + info := evalInfo{ + Scores: make([]scoreInfo, 0), + } + for _, sn := range e.scorerNames() { + info.Scores = append(info.Scores, scoreInfo{Name: sn}) + } + if params := e.parameters(); params != nil { + wireSchema := make(map[string]wireParameterDef, len(params.Schema)) + for k, v := range params.Schema { + wireSchema[k] = wireParameterDef{ + Type: "data", + Schema: schemaField{Type: v.Type}, + Default: v.Default, + Description: v.Description, + } + } + info.Parameters = ¶metersMeta{ + Type: "braintrust.staticParameters", + Schema: wireSchema, + Source: nil, + } + } + resp[name] = info + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + s.logger.Error("failed to encode list response", "error", err) + } +} + +// maxRequestBodyBytes limits the size of eval request bodies (10MB). +const maxRequestBodyBytes = 10 * 1024 * 1024 + +// handleEval handles POST /eval by running an evaluation and streaming SSE. +func (s *Server) handleEval(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes) + var req EvalRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest) + return + } + + if req.Name == "" { + http.Error(w, `{"error":"name is required"}`, http.StatusBadRequest) + return + } + + s.evalsMu.RLock() + evaluator, ok := s.evaluators[req.Name] + s.evalsMu.RUnlock() + + if !ok { + http.Error(w, fmt.Sprintf(`{"error":"evaluator %q not found"}`, req.Name), http.StatusNotFound) + return + } + + // Get auth from context, or use default auth in no-auth mode + ar := authFromContext(r.Context()) + if ar == nil && s.noAuth { + ar = s.defaultAuth + } + if ar == nil { + http.Error(w, `{"error":"authentication required"}`, http.StatusUnauthorized) + return + } + + // Create SSE writer + sse, err := newSSEWriter(w) + if err != nil { + http.Error(w, fmt.Sprintf(`{"error":"%s"}`, err), http.StatusInternalServerError) + return + } + + cfg := &evalRunConfig{ + req: &req, + auth: ar, + sse: sse, + noAuth: s.noAuth, + tracerProvider: s.tracerProvider, + } + + // Run the evaluation + if err := evaluator.run(r.Context(), cfg); err != nil { + s.logger.Error("eval failed", "evaluator", req.Name, "error", err) + // Evict cached session on auth errors so the next request gets a fresh login + if isAuthError(err) && !s.noAuth { + token := extractToken(r) + orgName := extractOrgName(r) + s.authCache.evict(token, orgName) + } + _ = sse.writeError(err) + } + + _ = sse.writeDone() +} diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..4103c9b --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,212 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/braintrustdata/braintrust-sdk-go/eval" +) + +func TestHealthEndpoint(t *testing.T) { + srv := New(WithNoAuth()) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var body map[string]string + require.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.Equal(t, "ok", body["status"]) +} + +func TestListEndpoint_Empty(t *testing.T) { + srv := New(WithNoAuth()) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/list") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var body map[string]any + require.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.Empty(t, body) +} + +func TestListEndpoint_WithEvaluators(t *testing.T) { + srv := New(WithNoAuth()) + + task := eval.T(func(ctx context.Context, input string) (string, error) { + return input, nil + }) + scorer := eval.NewScorer("exact_match", func(ctx context.Context, r eval.TaskResult[string, string]) (eval.Scores, error) { + if r.Output == r.Expected { + return eval.S(1.0), nil + } + return eval.S(0.0), nil + }) + + RegisterEval(srv, &eval.Eval[string, string]{ + Name: "my-eval", + Task: task, + Scorers: []eval.Scorer[string, string]{scorer}, + }, RegisterEvalOpts{}) + + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/list") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var body map[string]json.RawMessage + require.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.Contains(t, body, "my-eval") + + var info evalInfo + require.NoError(t, json.Unmarshal(body["my-eval"], &info)) + require.Len(t, info.Scores, 1) + assert.Equal(t, "exact_match", info.Scores[0].Name) +} + +func TestListEndpoint_WithParameters(t *testing.T) { + srv := New(WithNoAuth()) + + task := eval.T(func(ctx context.Context, input string) (string, error) { + return input, nil + }) + scorer := eval.NewScorer("score", func(ctx context.Context, r eval.TaskResult[string, string]) (eval.Scores, error) { + return eval.S(1.0), nil + }) + + RegisterEval(srv, &eval.Eval[string, string]{ + Name: "param-eval", + Task: task, + Scorers: []eval.Scorer[string, string]{scorer}, + }, RegisterEvalOpts{ + Parameters: &Parameters{ + Schema: map[string]ParameterDef{ + "model": { + Type: "string", + Default: "gpt-4", + Description: "Model to use", + }, + }, + }, + }) + + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/list") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + + var body map[string]json.RawMessage + require.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + + var info struct { + Parameters *parametersMeta `json:"parameters"` + } + require.NoError(t, json.Unmarshal(body["param-eval"], &info)) + require.NotNil(t, info.Parameters) + assert.Equal(t, "braintrust.staticParameters", info.Parameters.Type) + assert.Contains(t, info.Parameters.Schema, "model") + + // Verify wire format: each parameter must have type "data" and nested schema. + modelParam := info.Parameters.Schema["model"] + assert.Equal(t, "data", modelParam.Type) + assert.Equal(t, "string", modelParam.Schema.Type) + assert.Equal(t, "gpt-4", modelParam.Default) + assert.Equal(t, "Model to use", modelParam.Description) +} + +func TestEvalEndpoint_MissingName(t *testing.T) { + srv := New(WithNoAuth()) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + resp, err := http.Post(ts.URL+"/eval", "application/json", strings.NewReader(`{}`)) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) +} + +func TestEvalEndpoint_NotFound(t *testing.T) { + srv := New(WithNoAuth()) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + body := `{"name":"nonexistent","data":{"data":[]}}` + resp, err := http.Post(ts.URL+"/eval", "application/json", strings.NewReader(body)) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + + assert.Equal(t, http.StatusNotFound, resp.StatusCode) +} + +func TestEvalEndpoint_InvalidJSON(t *testing.T) { + srv := New(WithNoAuth()) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + resp, err := http.Post(ts.URL+"/eval", "application/json", strings.NewReader(`not json`)) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) +} + +func TestServerNew_Defaults(t *testing.T) { + srv := New() + assert.Equal(t, "localhost:8300", srv.addr) + assert.Equal(t, "https://www.braintrust.dev", srv.appURL) + assert.False(t, srv.noAuth) + assert.NotNil(t, srv.evaluators) + assert.NotNil(t, srv.authCache) +} + +func TestServerNew_Options(t *testing.T) { + srv := New( + WithAddress(":9000"), + WithAppURL("https://custom.example.com"), + WithNoAuth(), + ) + assert.Equal(t, ":9000", srv.addr) + assert.Equal(t, "https://custom.example.com", srv.appURL) + assert.True(t, srv.noAuth) +} + +func TestServerShutdown_NilServer(t *testing.T) { + srv := New() + err := srv.Shutdown(context.Background()) + assert.NoError(t, err) +} + +func TestListEndpoint_POST(t *testing.T) { + srv := New(WithNoAuth()) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + resp, err := http.Post(ts.URL+"/list", "application/json", nil) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + + assert.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/server/sse.go b/server/sse.go new file mode 100644 index 0000000..6afe7b2 --- /dev/null +++ b/server/sse.go @@ -0,0 +1,85 @@ +package server + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" +) + +// sseWriter writes Server-Sent Events to an http.ResponseWriter. +// It is safe for concurrent use from multiple goroutines. +type sseWriter struct { + w http.ResponseWriter + flusher http.Flusher + mu sync.Mutex +} + +// newSSEWriter creates an SSE writer and sets the required response headers. +// Returns an error if the ResponseWriter does not support flushing. +func newSSEWriter(w http.ResponseWriter) (*sseWriter, error) { + flusher, ok := w.(http.Flusher) + if !ok { + return nil, fmt.Errorf("response writer does not support flushing (required for SSE)") + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + flusher.Flush() + + return &sseWriter{w: w, flusher: flusher}, nil +} + +// writeEvent writes a single SSE event with the given type and data. +func (s *sseWriter) writeEvent(event string, data any) error { + // Marshal data outside the mutex to reduce contention + var dataStr string + switch v := data.(type) { + case string: + dataStr = v + case []byte: + dataStr = string(v) + default: + dataBytes, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal SSE data: %w", err) + } + dataStr = string(dataBytes) + } + + // SSE data field cannot contain bare newlines; each line needs its own "data:" prefix. + // Replace newlines with SSE multi-line format. + dataStr = strings.ReplaceAll(dataStr, "\n", "\ndata: ") + + s.mu.Lock() + defer s.mu.Unlock() + + if _, err := fmt.Fprintf(s.w, "event: %s\ndata: %s\n\n", event, dataStr); err != nil { + return fmt.Errorf("failed to write SSE event: %w", err) + } + s.flusher.Flush() + return nil +} + +// writeProgress writes a progress event for a completed evaluation case. +func (s *sseWriter) writeProgress(data progressEvent) error { + return s.writeEvent("progress", data) +} + +// writeSummary writes the final summary event with aggregated scores. +func (s *sseWriter) writeSummary(data summaryEvent) error { + return s.writeEvent("summary", data) +} + +// writeDone writes the terminal done event. +func (s *sseWriter) writeDone() error { + return s.writeEvent("done", "") +} + +// writeError writes an error event. +func (s *sseWriter) writeError(err error) error { + return s.writeEvent("error", map[string]string{"error": err.Error()}) +} diff --git a/server/sse_test.go b/server/sse_test.go new file mode 100644 index 0000000..c9b1a80 --- /dev/null +++ b/server/sse_test.go @@ -0,0 +1,123 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSSEWriter_Headers(t *testing.T) { + w := httptest.NewRecorder() + _, err := newSSEWriter(w) + require.NoError(t, err) + + assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type")) + assert.Equal(t, "no-cache", w.Header().Get("Cache-Control")) + assert.Equal(t, "keep-alive", w.Header().Get("Connection")) + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestSSEWriter_WriteEvent(t *testing.T) { + w := httptest.NewRecorder() + sse, err := newSSEWriter(w) + require.NoError(t, err) + + err = sse.writeEvent("progress", map[string]string{"key": "value"}) + require.NoError(t, err) + + body := w.Body.String() + assert.Contains(t, body, "event: progress\n") + assert.Contains(t, body, `data: {"key":"value"}`) + assert.True(t, strings.HasSuffix(body, "\n\n")) +} + +func TestSSEWriter_WriteProgress(t *testing.T) { + w := httptest.NewRecorder() + sse, err := newSSEWriter(w) + require.NoError(t, err) + + err = sse.writeProgress(progressEvent{ + ID: "abc123def456", + ObjectType: "task", + Name: "test-eval", + Event: "json_delta", + Data: "hello", + Origin: map[string]any{ + "object_type": "dataset", + "object_id": "ds-123", + "id": "row-456", + }, + }) + require.NoError(t, err) + + body := w.Body.String() + assert.Contains(t, body, "event: progress\n") + assert.Contains(t, body, `"id":"abc123def456"`) + assert.Contains(t, body, `"object_type":"task"`) + assert.Contains(t, body, `"name":"test-eval"`) + assert.Contains(t, body, `"object_id":"ds-123"`) +} + +func TestSSEWriter_WriteSummary(t *testing.T) { + w := httptest.NewRecorder() + sse, err := newSSEWriter(w) + require.NoError(t, err) + + err = sse.writeSummary(summaryEvent{ + ExperimentID: "exp-id-1", + ExperimentName: "exp-1", + ProjectName: "my-project", + ProjectID: "proj-456", + Scores: map[string]float64{"accuracy": 0.95}, + }) + require.NoError(t, err) + + body := w.Body.String() + assert.Contains(t, body, "event: summary\n") + assert.Contains(t, body, `"experiment_name":"exp-1"`) + assert.Contains(t, body, `"project_name":"my-project"`) + assert.Contains(t, body, `"project_id":"proj-456"`) + assert.Contains(t, body, `"accuracy":0.95`) +} + +func TestSSEWriter_WriteDone(t *testing.T) { + w := httptest.NewRecorder() + sse, err := newSSEWriter(w) + require.NoError(t, err) + + err = sse.writeDone() + require.NoError(t, err) + + body := w.Body.String() + assert.Contains(t, body, "event: done\n") + assert.Contains(t, body, "data: ") +} + +func TestSSEWriter_WriteError(t *testing.T) { + w := httptest.NewRecorder() + sse, err := newSSEWriter(w) + require.NoError(t, err) + + err = sse.writeError(assert.AnError) + require.NoError(t, err) + + body := w.Body.String() + assert.Contains(t, body, "event: error\n") + assert.Contains(t, body, "assert.AnError") +} + +func TestSSEWriter_StringData(t *testing.T) { + w := httptest.NewRecorder() + sse, err := newSSEWriter(w) + require.NoError(t, err) + + err = sse.writeEvent("test", "plain string") + require.NoError(t, err) + + body := w.Body.String() + assert.Contains(t, body, "data: plain string\n") +} diff --git a/server/types.go b/server/types.go new file mode 100644 index 0000000..141b892 --- /dev/null +++ b/server/types.go @@ -0,0 +1,114 @@ +package server + +import "encoding/json" + +// EvalRequest is the request body for POST /eval. +type EvalRequest struct { + // Name is the registered evaluator name (required). + Name string `json:"name"` + + // Data specifies the evaluation dataset (required). + Data EvalData `json:"data"` + + // ExperimentName overrides the experiment name (optional). + ExperimentName string `json:"experiment_name,omitempty"` + + // ProjectID overrides the project ID (optional). + ProjectID string `json:"project_id,omitempty"` + + // Parent specifies the parent span for tracing (optional). + Parent *ParentInfo `json:"parent,omitempty"` +} + +// EvalData specifies where evaluation data comes from. +// Exactly one of Data, DatasetID, or DatasetName must be set. +type EvalData struct { + // Data is an inline array of test cases. + Data json.RawMessage `json:"data,omitempty"` + + // DatasetID loads a dataset by ID. + DatasetID string `json:"dataset_id,omitempty"` + + // DatasetName loads a dataset by name (optionally scoped by ProjectName). + DatasetName string `json:"dataset_name,omitempty"` + + // ProjectName scopes DatasetName lookups (optional). + ProjectName string `json:"project_name,omitempty"` +} + +// ParentInfo specifies parent span context for tracing. +type ParentInfo struct { + ObjectType string `json:"object_type,omitempty"` + ObjectID string `json:"object_id,omitempty"` + PropagatedEvent json.RawMessage `json:"propagated_event,omitempty"` +} + +// Parameters defines the parameter schema for an evaluator, displayed in the Braintrust UI. +type Parameters struct { + Schema map[string]ParameterDef `json:"schema"` +} + +// ParameterDef defines a single parameter for the Braintrust UI. +type ParameterDef struct { + Type string `json:"type"` + Default any `json:"default,omitempty"` + Description string `json:"description,omitempty"` +} + +// listResponse is the response for GET/POST /list. +type listResponse map[string]evalInfo + +// evalInfo describes a registered evaluator in the list response. +type evalInfo struct { + Scores []scoreInfo `json:"scores"` + Parameters *parametersMeta `json:"parameters,omitempty"` +} + +// scoreInfo describes a scorer in the list response. +type scoreInfo struct { + Name string `json:"name"` +} + +// parametersMeta wraps parameters with the protocol-required metadata. +type parametersMeta struct { + Type string `json:"type"` + Schema map[string]wireParameterDef `json:"schema"` + Source *string `json:"source"` +} + +// wireParameterDef is the wire format for a parameter in the dev server protocol. +// Each parameter is wrapped with type "data" and a nested schema object. +type wireParameterDef struct { + Type string `json:"type"` + Schema schemaField `json:"schema"` + Default any `json:"default,omitempty"` + Description string `json:"description,omitempty"` +} + +// schemaField is the inner schema for a wire parameter definition. +type schemaField struct { + Type string `json:"type"` +} + +// progressEvent is an SSE progress event sent per evaluation case. +// The id field is required by the Braintrust UI (SSEProgressEventData schema). +type progressEvent struct { + ID string `json:"id"` + ObjectType string `json:"object_type"` + Name string `json:"name"` + Format string `json:"format"` + OutputType string `json:"output_type"` + Event string `json:"event"` + Data any `json:"data,omitempty"` + Origin map[string]any `json:"origin,omitempty"` +} + +// summaryEvent is the final SSE event with aggregated results. +type summaryEvent struct { + ExperimentID string `json:"experiment_id,omitempty"` + ExperimentName string `json:"experiment_name,omitempty"` + ProjectName string `json:"project_name,omitempty"` + ProjectID string `json:"project_id,omitempty"` + ExperimentURL string `json:"experiment_url,omitempty"` + Scores map[string]float64 `json:"scores"` +} diff --git a/trace/trace.go b/trace/trace.go index 84fe479..4b7148b 100644 --- a/trace/trace.go +++ b/trace/trace.go @@ -250,11 +250,13 @@ const ( ParentTypeProjectID ParentType = "project_id" // ParentTypeExperimentID is the type of parent that represents an experiment by ID. ParentTypeExperimentID ParentType = "experiment_id" + // ParentTypePlaygroundID is the type of parent for spans triggered from the Braintrust playground. + ParentTypePlaygroundID ParentType = "playground_id" ) // IsValid returns true if the ParentType is a valid type. func (p ParentType) IsValid() bool { - return p == ParentTypeProjectName || p == ParentTypeProjectID || p == ParentTypeExperimentID + return p == ParentTypeProjectName || p == ParentTypeProjectID || p == ParentTypeExperimentID || p == ParentTypePlaygroundID } // Parent represents where data goes in Braintrust - a project, an experiment, etc. @@ -263,6 +265,11 @@ type Parent struct { ID string } +// IsZero returns true if the Parent is the zero value. +func (p Parent) IsZero() bool { + return p.Type == "" && p.ID == "" +} + // Attr returns the OTel attribute for this parent. func (p Parent) Attr() attribute.KeyValue { return attribute.String(ParentOtelAttrKey, p.String())