diff --git a/go.mod b/go.mod index e01846e..39c4cfd 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/go-pdf/fpdf v0.9.0 github.com/jackc/pgx/v5 v5.10.0 github.com/nats-io/nats.go v1.52.0 + github.com/writer/cerebro/sdk/go/cerebroapi v0.0.0-20260617165858-b79a38afb59c golang.org/x/crypto v0.53.0 google.golang.org/protobuf v1.36.11 ) diff --git a/go.sum b/go.sum index 355af9f..e1d4b36 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/writer/cerebro/sdk/go/cerebroapi v0.0.0-20260617165858-b79a38afb59c h1:Q4lzmWx5TNGG6wVyXFkc50QYk55wfzr4CZeC10m6XsE= +github.com/writer/cerebro/sdk/go/cerebroapi v0.0.0-20260617165858-b79a38afb59c/go.mod h1:Pj6RePoAtQMzYIZuYdS02A0RZInNozjLDKmEig8JlZM= golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto= golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio= golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM= diff --git a/internal/cerebroclaims/builder.go b/internal/cerebroclaims/builder.go index 15389a4..c1b6002 100644 --- a/internal/cerebroclaims/builder.go +++ b/internal/cerebroclaims/builder.go @@ -11,6 +11,7 @@ import ( cerebrov1 "github.com/writer/aperio/gen/cerebro/v1" "github.com/writer/aperio/internal/cerebroclient" + sdkclaims "github.com/writer/cerebro/sdk/go/cerebroapi/claims" ) type Payload struct { @@ -75,21 +76,21 @@ func Build(input BuildInput) ([]cerebroclient.Claim, error) { } } claims := []cerebroclient.Claim{ - existsClaim(finding, input.Payload, attributes), - existsClaim(target, input.Payload, map[string]string{"provider": provider}), - existsClaim(integration, input.Payload, map[string]string{"provider": provider}), - relationClaim(finding, "affects", target, input.Payload), - relationClaim(finding, "observed_by", integration, input.Payload), - attributeClaim(finding, "title", title, input.Payload), - attributeClaim(finding, "provider", provider, input.Payload), + sdkclaims.Exists(finding, claimSource(input.Payload, attributes)), + sdkclaims.Exists(target, claimSource(input.Payload, map[string]string{"provider": provider})), + sdkclaims.Exists(integration, claimSource(input.Payload, map[string]string{"provider": provider})), + sdkclaims.Relation(finding, "affects", target, claimSource(input.Payload, nil)), + sdkclaims.Relation(finding, "observed_by", integration, claimSource(input.Payload, nil)), + sdkclaims.Attribute(finding, "title", title, claimSource(input.Payload, nil)), + sdkclaims.Attribute(finding, "provider", provider, claimSource(input.Payload, nil)), } for _, key := range []string{"severity", "riskScore", "status", "ruleId"} { if value := firstString(input.Payload.Record[key]); value != "" { - claims = append(claims, attributeClaim(finding, key, value, input.Payload)) + claims = append(claims, sdkclaims.Attribute(finding, key, value, claimSource(input.Payload, nil))) } } if description := firstString(input.Payload.Record["description"]); description != "" { - claims = append(claims, attributeClaim(finding, "description", description, input.Payload)) + claims = append(claims, sdkclaims.Attribute(finding, "description", description, claimSource(input.Payload, nil))) } return claims, nil } @@ -142,45 +143,14 @@ func EncodeExternalID(value string) string { return builder.String() } -func claimBase(payload Payload, attributes map[string]string) cerebroclient.Claim { - return cerebroclient.Claim{ - Status: "asserted", +func claimSource(payload Payload, attributes map[string]string) sdkclaims.Source { + return sdkclaims.Source{ SourceEventID: firstString(payload.Record["sourceEventId"]), ObservedAt: payload.OccurredAt, Attributes: attributes, } } -func existsClaim(subject cerebroclient.EntityRef, payload Payload, attributes map[string]string) cerebroclient.Claim { - claim := claimBase(payload, attributes) - claim.SubjectURN = subject.URN - claim.SubjectRef = subject - claim.Predicate = "exists" - claim.ClaimType = "existence" - return claim -} - -func attributeClaim(subject cerebroclient.EntityRef, predicate string, value string, payload Payload) cerebroclient.Claim { - claim := claimBase(payload, nil) - claim.SubjectURN = subject.URN - claim.SubjectRef = subject - claim.Predicate = predicate - claim.ObjectValue = value - claim.ClaimType = "attribute" - return claim -} - -func relationClaim(subject cerebroclient.EntityRef, predicate string, object cerebroclient.EntityRef, payload Payload) cerebroclient.Claim { - claim := claimBase(payload, nil) - claim.SubjectURN = subject.URN - claim.SubjectRef = subject - claim.Predicate = predicate - claim.ObjectURN = object.URN - claim.ObjectRef = &object - claim.ClaimType = "relation" - return claim -} - func firstString(values ...any) string { for _, value := range values { switch typed := value.(type) { diff --git a/internal/cerebroclient/client.go b/internal/cerebroclient/client.go index c01a0c8..7ec7b2b 100644 --- a/internal/cerebroclient/client.go +++ b/internal/cerebroclient/client.go @@ -1,426 +1,57 @@ package cerebroclient import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" "net/http" - "net/url" - "os" - "strconv" "strings" - "time" + + "github.com/writer/cerebro/sdk/go/cerebroapi" ) const ( DefaultRuntimeID = "writer-aperio-saas-dr" DefaultSourceID = "aperio_saas_dr" - defaultTimeout = 15 * time.Second userAgent = "aperio-cerebro-client" ) -// Config contains the minimal hosted Cerebro connection settings Aperio needs. -// The browser-facing Aperio session remains cookie-based; this credential is -// the server-side integration principal used for Cerebro HTTP and MCP calls. -type Config struct { - BaseURL string - MCPURL string - APIKey string - TenantID string - RuntimeID string - SourceID string - Timeout time.Duration -} - -func ConfigFromEnv() Config { - return Config{ - BaseURL: trimEnv("CEREBRO_BASE_URL"), - MCPURL: trimEnv("CEREBRO_MCP_URL"), - APIKey: firstEnv("CEREBRO_API_KEY", "CEREBRO_TOKEN"), - TenantID: trimEnv("CEREBRO_TENANT_ID"), - RuntimeID: envDefault("CEREBRO_SOURCE_RUNTIME_ID", DefaultRuntimeID), - SourceID: envDefault("CEREBRO_SOURCE_ID", DefaultSourceID), - Timeout: durationEnv("CEREBRO_HTTP_TIMEOUT_SECONDS", defaultTimeout), - } -} - -func (c Config) MCPServerURL() string { - if strings.TrimSpace(c.MCPURL) != "" { - mcpURL, err := parseBaseURL(c.MCPURL) - if err != nil { - return "" - } - return mcpURL.String() - } - baseURL, err := parseBaseURL(c.BaseURL) - if err != nil { - return "" - } - value := *baseURL - value.Path, value.RawPath = mcpPathFromBase(value.Path, value.EscapedPath()) - value.RawQuery = "" - value.Fragment = "" - return value.String() -} +type Config = cerebroapi.Config +type Option = cerebroapi.Option +type HTTPError = cerebroapi.HTTPError -func mcpPathFromBase(path string, rawPath string) (string, string) { - path = strings.TrimRight(path, "/") - rawPath = strings.TrimRight(rawPath, "/") - suffix := "/api/v1/mcp" - if strings.HasSuffix(path, "/api") { - suffix = "/v1/mcp" - } else if hasVersionedAPISuffix(path) { - suffix = "/mcp" - } - return path + suffix, rawPath + suffix +type Client struct { + *cerebroapi.Client } -func hasVersionedAPISuffix(path string) bool { - versionStart := strings.LastIndex(path, "/") - if versionStart < 0 || !strings.HasSuffix(path[:versionStart], "/api") { - return false +func ConfigFromEnv() Config { + config := cerebroapi.ConfigFromEnv() + if strings.TrimSpace(config.RuntimeID) == "" { + config.RuntimeID = DefaultRuntimeID } - version := path[versionStart+1:] - if len(version) < 2 || version[0] != 'v' { - return false + if strings.TrimSpace(config.SourceID) == "" { + config.SourceID = DefaultSourceID } - for _, ch := range version[1:] { - if ch < '0' || ch > '9' { - return false - } + if strings.TrimSpace(config.UserAgent) == "" { + config.UserAgent = userAgent } - return true -} - -func (c Config) Enabled() bool { - return strings.TrimSpace(c.BaseURL) != "" && - strings.TrimSpace(c.APIKey) != "" && - strings.TrimSpace(c.TenantID) != "" -} - -func (c Config) DefaultRuntime() SourceRuntime { - return SourceRuntime{ - ID: envOrValue(c.RuntimeID, DefaultRuntimeID), - SourceID: envOrValue(c.SourceID, DefaultSourceID), - TenantID: strings.TrimSpace(c.TenantID), - Config: map[string]string{ + if len(config.RuntimeConfig) == 0 { + config.RuntimeConfig = map[string]string{ "surface": "aperio_saas_dr", "owner": "aperio", - }, + } } + return config } -type Client struct { - baseURL *url.URL - apiKey string - tenantID string - httpClient *http.Client -} - -type Option func(*Client) - func WithHTTPClient(httpClient *http.Client) Option { - return func(c *Client) { - if httpClient != nil { - c.httpClient = httpClient - } - } + return cerebroapi.WithHTTPClient(httpClient) } func New(config Config, options ...Option) (*Client, error) { - baseURL, err := parseBaseURL(config.BaseURL) - if err != nil { - return nil, err - } - apiKey := strings.TrimSpace(config.APIKey) - if apiKey == "" { - return nil, errors.New("cerebro API key is required") - } - tenantID := strings.TrimSpace(config.TenantID) - if tenantID == "" { - return nil, errors.New("cerebro tenant id is required") - } - timeout := config.Timeout - if timeout <= 0 { - timeout = defaultTimeout - } - client := &Client{ - baseURL: baseURL, - apiKey: apiKey, - tenantID: tenantID, - httpClient: &http.Client{ - Timeout: timeout, - CheckRedirect: blockRedirects, - }, - } - for _, option := range options { - option(client) - } - if client.httpClient.CheckRedirect == nil { - transportClient := *client.httpClient - transportClient.CheckRedirect = blockRedirects - client.httpClient = &transportClient - } - return client, nil -} - -func (c *Client) GetSourceRuntime(ctx context.Context, runtimeID string) (*SourceRuntime, error) { - runtimeID = strings.TrimSpace(runtimeID) - if runtimeID == "" { - return nil, errors.New("cerebro runtime id is required") - } - var response sourceRuntimeResponse - if err := c.doJSON(ctx, http.MethodGet, c.runtimeURL(runtimeID), nil, &response); err != nil { - return nil, err - } - return runtimeFromResponse(response) -} - -func (c *Client) PutSourceRuntime(ctx context.Context, runtime SourceRuntime) (*SourceRuntime, error) { - runtime.ID = strings.TrimSpace(runtime.ID) - if runtime.ID == "" { - return nil, errors.New("cerebro runtime id is required") - } - if strings.TrimSpace(runtime.TenantID) == "" { - runtime.TenantID = c.tenantID - } - if strings.TrimSpace(runtime.SourceID) == "" { - return nil, errors.New("cerebro source id is required") - } - var response sourceRuntimeResponse - body := map[string]SourceRuntime{"runtime": runtime} - if err := c.doJSON(ctx, http.MethodPut, c.runtimeURL(runtime.ID), body, &response); err != nil { - return nil, err - } - return runtimeFromResponse(response) -} - -func (c *Client) EnsureDefaultRuntime(ctx context.Context, config Config) (*SourceRuntime, error) { - return c.PutSourceRuntime(ctx, config.DefaultRuntime()) -} - -func (c *Client) ListClaims(ctx context.Context, request ListClaimsRequest) (*ListClaimsResponse, error) { - request.RuntimeID = strings.TrimSpace(request.RuntimeID) - if request.RuntimeID == "" { - return nil, errors.New("cerebro runtime id is required") - } - query := url.Values{} - addQuery(query, "claim_id", request.ClaimID) - addQuery(query, "subject_urn", request.SubjectURN) - addQuery(query, "predicate", request.Predicate) - addQuery(query, "object_urn", request.ObjectURN) - addQuery(query, "object_value", request.ObjectValue) - addQuery(query, "claim_type", request.ClaimType) - addQuery(query, "status", request.Status) - addQuery(query, "source_event_id", request.SourceEventID) - if request.Limit > 0 { - query.Set("limit", strconv.FormatUint(uint64(request.Limit), 10)) - } - var response ListClaimsResponse - if err := c.doJSON(ctx, http.MethodGet, withQuery(c.runtimeURL(request.RuntimeID, "claims"), query), nil, &response); err != nil { - return nil, err - } - return &response, nil -} - -func (c *Client) WriteClaims(ctx context.Context, request WriteClaimsRequest) (*WriteClaimsResponse, error) { - request.RuntimeID = strings.TrimSpace(request.RuntimeID) - if request.RuntimeID == "" { - return nil, errors.New("cerebro runtime id is required") - } - if len(request.Claims) == 0 { - return nil, errors.New("at least one cerebro claim is required") - } - var response WriteClaimsResponse - if err := c.doJSON(ctx, http.MethodPost, c.runtimeURL(request.RuntimeID, "claims"), request, &response); err != nil { - return nil, err - } - return &response, nil -} - -func (c *Client) GetEntityNeighborhood(ctx context.Context, rootURN string, limit uint32) (*EntityNeighborhood, error) { - rootURN = strings.TrimSpace(rootURN) - if rootURN == "" { - return nil, errors.New("cerebro graph root urn is required") + if strings.TrimSpace(config.UserAgent) == "" { + config.UserAgent = userAgent } - query := url.Values{"root_urn": []string{rootURN}} - if limit > 0 { - query.Set("limit", strconv.FormatUint(uint64(limit), 10)) - } - var response EntityNeighborhood - if err := c.doJSON(ctx, http.MethodGet, withQuery(c.urlFor("platform", "graph", "neighborhood"), query), nil, &response); err != nil { - return nil, err - } - return &response, nil -} - -func (c *Client) doJSON(ctx context.Context, method string, endpoint string, body any, target any) error { - var bodyReader io.Reader - if body != nil { - encoded, err := json.Marshal(body) - if err != nil { - return err - } - bodyReader = bytes.NewReader(encoded) - } - req, err := http.NewRequestWithContext(ctx, method, endpoint, bodyReader) - if err != nil { - return err - } - req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", "Bearer "+c.apiKey) - req.Header.Set("User-Agent", userAgent) - req.Header.Set("X-Cerebro-Tenant", c.tenantID) - if body != nil { - req.Header.Set("Content-Type", "application/json") - } - resp, err := c.httpClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - return newHTTPError(resp) - } - if target == nil { - _, _ = io.Copy(io.Discard, resp.Body) - return nil - } - if err := json.NewDecoder(resp.Body).Decode(target); err != nil { - return fmt.Errorf("decode cerebro response: %w", err) - } - return nil -} - -func (c *Client) runtimeURL(runtimeID string, suffix ...string) string { - segments := []string{"source-runtimes", strings.TrimSpace(runtimeID)} - segments = append(segments, suffix...) - return c.urlFor(segments...) -} - -func (c *Client) urlFor(segments ...string) string { - value := *c.baseURL - pathPrefix := strings.TrimRight(value.Path, "/") - rawPathPrefix := strings.TrimRight(value.EscapedPath(), "/") - escaped := make([]string, 0, len(segments)) - for _, segment := range segments { - escaped = append(escaped, url.PathEscape(segment)) - } - value.Path = pathPrefix + "/" + strings.Join(segments, "/") - value.RawPath = rawPathPrefix + "/" + strings.Join(escaped, "/") - value.RawQuery = "" - value.Fragment = "" - return value.String() -} - -func addQuery(query url.Values, key string, value string) { - if trimmed := strings.TrimSpace(value); trimmed != "" { - query.Set(key, trimmed) - } -} - -func withQuery(endpoint string, query url.Values) string { - if len(query) == 0 { - return endpoint - } - parsed, err := url.Parse(endpoint) - if err != nil { - return endpoint - } - parsed.RawQuery = query.Encode() - return parsed.String() -} - -func parseBaseURL(raw string) (*url.URL, error) { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return nil, errors.New("cerebro base url is required") - } - parsed, err := url.Parse(trimmed) + client, err := cerebroapi.New(config, options...) if err != nil { - return nil, fmt.Errorf("parse cerebro base url: %w", err) - } - if parsed.Scheme != "http" && parsed.Scheme != "https" { - return nil, errors.New("cerebro base url must use http or https") - } - if parsed.Host == "" { - return nil, errors.New("cerebro base url must include a host") - } - if parsed.User != nil || parsed.RawQuery != "" || parsed.Fragment != "" { - return nil, errors.New("cerebro base url must not include credentials, query, or fragment") - } - return parsed, nil -} - -func blockRedirects(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse -} - -func runtimeFromResponse(response sourceRuntimeResponse) (*SourceRuntime, error) { - if response.Runtime == nil { - return nil, errors.New("cerebro response missing runtime") - } - return response.Runtime, nil -} - -type HTTPError struct { - StatusCode int - Status string - Body string -} - -func (e HTTPError) Error() string { - if e.Body == "" { - return fmt.Sprintf("cerebro http error: %d %s", e.StatusCode, e.Status) - } - return fmt.Sprintf("cerebro http error: %d %s: %s", e.StatusCode, e.Status, e.Body) -} - -func newHTTPError(resp *http.Response) error { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) - return HTTPError{ - StatusCode: resp.StatusCode, - Status: strings.TrimSpace(resp.Status), - Body: strings.TrimSpace(string(body)), - } -} - -func trimEnv(key string) string { - return strings.TrimSpace(os.Getenv(key)) -} - -func firstEnv(keys ...string) string { - for _, key := range keys { - if value := trimEnv(key); value != "" { - return value - } - } - return "" -} - -func envDefault(key string, fallback string) string { - return envOrValue(trimEnv(key), fallback) -} - -func envOrValue(value string, fallback string) string { - trimmed := strings.TrimSpace(value) - if trimmed == "" { - return fallback - } - return trimmed -} - -func durationEnv(key string, fallback time.Duration) time.Duration { - value := trimEnv(key) - if value == "" { - return fallback - } - seconds, err := strconv.Atoi(value) - if err != nil || seconds <= 0 { - return fallback + return nil, err } - return time.Duration(seconds) * time.Second + return &Client{Client: client}, nil } diff --git a/internal/cerebroclient/types.go b/internal/cerebroclient/types.go index 259fb21..1ae921f 100644 --- a/internal/cerebroclient/types.go +++ b/internal/cerebroclient/types.go @@ -1,83 +1,14 @@ package cerebroclient -type SourceRuntime struct { - ID string `json:"id,omitempty"` - SourceID string `json:"source_id"` - TenantID string `json:"tenant_id"` - Config map[string]string `json:"config,omitempty"` -} - -type EntityRef struct { - URN string `json:"urn"` - EntityType string `json:"entity_type"` - Label string `json:"label"` -} - -type Claim struct { - ID string `json:"id,omitempty"` - SubjectURN string `json:"subject_urn"` - SubjectRef EntityRef `json:"subject_ref"` - Predicate string `json:"predicate"` - ObjectURN string `json:"object_urn,omitempty"` - ObjectRef *EntityRef `json:"object_ref,omitempty"` - ObjectValue string `json:"object_value,omitempty"` - ClaimType string `json:"claim_type"` - Status string `json:"status"` - SourceEventID string `json:"source_event_id,omitempty"` - ObservedAt string `json:"observed_at,omitempty"` - ValidFrom string `json:"valid_from,omitempty"` - ValidTo string `json:"valid_to,omitempty"` - Attributes map[string]string `json:"attributes,omitempty"` -} - -type ListClaimsRequest struct { - RuntimeID string - ClaimID string - SubjectURN string - Predicate string - ObjectURN string - ObjectValue string - ClaimType string - Status string - SourceEventID string - Limit uint32 -} - -type ListClaimsResponse struct { - Claims []Claim `json:"claims"` -} - -type WriteClaimsRequest struct { - RuntimeID string `json:"runtime_id"` - Claims []Claim `json:"claims"` - ReplaceExisting bool `json:"replace_existing,omitempty"` -} - -type WriteClaimsResponse struct { - ClaimsWritten uint32 `json:"claims_written"` - EntitiesUpserted uint32 `json:"entities_upserted"` - RelationLinksProjected uint32 `json:"relation_links_projected"` - ClaimsRetracted uint32 `json:"claims_retracted"` -} - -type sourceRuntimeResponse struct { - Runtime *SourceRuntime `json:"runtime"` -} - -type GraphEntity struct { - URN string `json:"urn"` - EntityType string `json:"entity_type"` - Label string `json:"label"` -} - -type GraphRelation struct { - FromURN string `json:"from_urn"` - Relation string `json:"relation"` - ToURN string `json:"to_urn"` -} - -type EntityNeighborhood struct { - Root *GraphEntity `json:"root,omitempty"` - Neighbors []GraphEntity `json:"neighbors,omitempty"` - Relations []GraphRelation `json:"relations,omitempty"` -} +import "github.com/writer/cerebro/sdk/go/cerebroapi" + +type SourceRuntime = cerebroapi.SourceRuntime +type EntityRef = cerebroapi.EntityRef +type Claim = cerebroapi.Claim +type ListClaimsRequest = cerebroapi.ListClaimsRequest +type ListClaimsResponse = cerebroapi.ListClaimsResponse +type WriteClaimsRequest = cerebroapi.WriteClaimsRequest +type WriteClaimsResponse = cerebroapi.WriteClaimsResponse +type GraphEntity = cerebroapi.GraphEntity +type GraphRelation = cerebroapi.GraphRelation +type EntityNeighborhood = cerebroapi.EntityNeighborhood