Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"github.com/braintrustdata/braintrust-sdk-go/api/datasets"
"github.com/braintrustdata/braintrust-sdk-go/api/experiments"
"github.com/braintrustdata/braintrust-sdk-go/api/functions"
"github.com/braintrustdata/braintrust-sdk-go/api/objects"
"github.com/braintrustdata/braintrust-sdk-go/api/projects"
"github.com/braintrustdata/braintrust-sdk-go/internal/https"
"github.com/braintrustdata/braintrust-sdk-go/logger"
Expand Down Expand Up @@ -86,3 +87,8 @@ func (a *API) Datasets() *datasets.API {
func (a *API) Functions() *functions.API {
return functions.New(a.client)
}

// Objects is used to access generic object APIs (e.g. /v1/{object_type}/{id}/fetch).
func (a *API) Objects() *objects.API {
return objects.New(a.client)
}
23 changes: 17 additions & 6 deletions api/functions/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,35 +102,46 @@ func (a *API) Invoke(ctx context.Context, functionID string, input any) (any, er
}

path := fmt.Sprintf("/v1/function/%s/invoke", functionID)
return a.invokePath(ctx, path, req)
}

// InvokeGlobal calls a global function by slug/type and returns the output.
func (a *API) InvokeGlobal(ctx context.Context, req InvokeGlobalParams) (any, error) {
if req.GlobalFunction == "" {
return nil, fmt.Errorf("global function is required")
}

return a.invokePath(ctx, "/function/invoke", req)
}

func (a *API) invokePath(ctx context.Context, path string, req any) (any, error) {
resp, err := a.client.POST(ctx, path, req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()

// Read the entire response body so we can parse it multiple ways
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}

// Parse response - try as object first, then as raw value
return decodeInvokeResponse(body)
}

func decodeInvokeResponse(body []byte) (any, error) {
var response map[string]any
if err := json.Unmarshal(body, &response); err == nil {
// Response is an object, extract output field if present
if output, ok := response["output"]; ok {
return output, nil
}
// If no output field, return the whole object
return response, nil
}

// Response is not an object, try parsing as raw JSON value (string, number, etc.)
var output any
if err := json.Unmarshal(body, &output); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}

return output, nil
}

Expand Down
12 changes: 12 additions & 0 deletions api/functions/functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/braintrustdata/braintrust-sdk-go/api/projects"
"github.com/braintrustdata/braintrust-sdk-go/internal/https"
"github.com/braintrustdata/braintrust-sdk-go/internal/vcr"
"github.com/braintrustdata/braintrust-sdk-go/logger"
)

const integrationTestProject = "go-sdk-tests"
Expand Down Expand Up @@ -318,3 +319,14 @@ func TestFunctions_Invoke_Validation(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "required")
}

func TestFunctions_InvokeGlobal_Validation(t *testing.T) {
t.Parallel()

ctx := context.Background()
client := New(https.NewClient("test-key", "https://example.com", logger.Discard()))

_, err := client.InvokeGlobal(ctx, InvokeGlobalParams{})
require.Error(t, err)
assert.Contains(t, err.Error(), "required")
}
8 changes: 8 additions & 0 deletions api/functions/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ type InvokeParams struct {
Input any `json:"input"`
}

// InvokeGlobalParams represents the request payload for invoking a global function.
type InvokeGlobalParams struct {
GlobalFunction string `json:"global_function"`
FunctionType string `json:"function_type,omitempty"`
Mode string `json:"mode,omitempty"`
Input any `json:"input,omitempty"`
}

// QueryResponse represents the response from querying functions.
type QueryResponse struct {
Objects []Function `json:"objects"`
Expand Down
42 changes: 42 additions & 0 deletions api/objects/objects.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package objects

import (
"context"
"encoding/json"
"fmt"

"github.com/braintrustdata/braintrust-sdk-go/internal/https"
)

// API provides methods for generic object operations.
type API struct {
client *https.Client
}

// New creates a new objects API client.
func New(client *https.Client) *API {
return &API{client: client}
}

// Fetch retrieves rows from a given object type and ID.
func (a *API) Fetch(ctx context.Context, objectType, objectID string, params FetchParams) (*FetchResponse, error) {
if objectType == "" {
return nil, fmt.Errorf("object type is required")
}
if objectID == "" {
return nil, fmt.Errorf("object ID is required")
}

path := fmt.Sprintf("/v1/%s/%s/fetch", objectType, objectID)
resp, err := a.client.POST(ctx, path, params)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()

var out FetchResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, fmt.Errorf("error decoding response: %w", err)
}
return &out, nil
}
77 changes: 77 additions & 0 deletions api/objects/objects_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package objects

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/braintrustdata/braintrust-sdk-go/api/datasets"
"github.com/braintrustdata/braintrust-sdk-go/api/projects"
"github.com/braintrustdata/braintrust-sdk-go/internal/https"
"github.com/braintrustdata/braintrust-sdk-go/internal/vcr"
"github.com/braintrustdata/braintrust-sdk-go/logger"
)

const integrationTestProject = "go-sdk-tests"

func TestObjects_Fetch_Integration(t *testing.T) {
t.Parallel()

ctx := context.Background()
client := vcr.GetHTTPSClient(t)
api := New(client)

// Create a project and dataset with events
projectsAPI := projects.New(client)
project, err := projectsAPI.Create(ctx, projects.CreateParams{Name: integrationTestProject})
require.NoError(t, err)

datasetsAPI := datasets.New(client)
dataset, err := datasetsAPI.Create(ctx, datasets.CreateParams{
ProjectID: project.ID,
Name: "test-objects-fetch",
})
require.NoError(t, err)
defer func() { _ = datasetsAPI.Delete(ctx, dataset.ID) }()

err = datasetsAPI.InsertEvents(ctx, dataset.ID, []datasets.Event{
{Input: map[string]any{"q": "1"}, Expected: map[string]any{"a": "1"}},
{Input: map[string]any{"q": "2"}, Expected: map[string]any{"a": "2"}},
})
require.NoError(t, err)

// Fetch via the generic objects API (retry for eventual consistency)
var rows []map[string]any
for i := 0; i < 3; i++ {
resp, err := api.Fetch(ctx, "dataset", dataset.ID, FetchParams{Limit: 10})
require.NoError(t, err)
require.NotNil(t, resp)

rows = resp.Events
if len(rows) == 0 {
rows = resp.Rows
}
if len(rows) >= 2 {
break
}
time.Sleep(500 * time.Millisecond)
}
assert.GreaterOrEqual(t, len(rows), 2)
}

func TestObjects_Fetch_Validation(t *testing.T) {
t.Parallel()

api := New(https.NewClient("test-key", "https://example.com", logger.Discard()))

_, err := api.Fetch(context.Background(), "", "obj-1", FetchParams{})
require.Error(t, err)
assert.Contains(t, err.Error(), "object type is required")

_, err = api.Fetch(context.Background(), "experiment", "", FetchParams{})
require.Error(t, err)
assert.Contains(t, err.Error(), "object ID is required")
}
Loading
Loading