Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 109 additions & 1 deletion api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package main
import (
"context"
"crypto/sha256"
"crypto/subtle"
"encoding/json"
"encoding/hex"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -73,10 +75,24 @@ type authConfig struct {
bearerJWKSRefreshInterval time.Duration
bearerJWKSRefreshTimeout time.Duration
bearerJWKSRateLimit time.Duration
bearerStaticPrincipals []staticBearerPrincipal
bearerJWKSInitErr error
bearerJWKSInitLock sync.Mutex
bearerJWKSLastAttempt time.Time
bearerJWKS *keyfunc.JWKS
configErr error
}

type staticBearerPrincipal struct {
ID string
Email string
Teams []string
Type principalType
Subject string
Issuer string
Scopes []string
TokenHash string
parsedHash []byte
}

type principal struct {
Expand Down Expand Up @@ -104,6 +120,7 @@ func newAuthConfig() authConfig {
if mode == authModeAuto {
bearerDefaultType = principalTypeService
}
staticPrincipals, configErr := parseStaticBearerPrincipals(os.Getenv("SPRITZ_AUTH_BEARER_STATIC_PRINCIPALS_JSON"))
return authConfig{
mode: mode,
headerID: envOrDefault("SPRITZ_AUTH_HEADER_ID", "X-Spritz-User-Id"),
Expand Down Expand Up @@ -137,6 +154,8 @@ func newAuthConfig() authConfig {
bearerJWKSRefreshInterval: parseDurationEnv("SPRITZ_AUTH_BEARER_JWKS_REFRESH_INTERVAL", 5*time.Minute),
bearerJWKSRefreshTimeout: parseDurationEnv("SPRITZ_AUTH_BEARER_JWKS_REFRESH_TIMEOUT", 5*time.Second),
bearerJWKSRateLimit: parseDurationEnv("SPRITZ_AUTH_BEARER_JWKS_RATE_LIMIT", 10*time.Second),
bearerStaticPrincipals: staticPrincipals,
configErr: configErr,
}
}

Expand Down Expand Up @@ -267,7 +286,7 @@ func (a *authConfig) principal(r *http.Request) (principal, error) {
a.isAdmin(id, teams),
), nil
}
if a.bearerIntrospectionURL == "" && a.bearerJWKSURL == "" {
if a.bearerIntrospectionURL == "" && a.bearerJWKSURL == "" && len(a.bearerStaticPrincipals) == 0 {
return principal{}, errUnauthenticated
}
return a.principalFromBearer(r)
Expand All @@ -288,6 +307,9 @@ func (a *authConfig) principalFromBearer(r *http.Request) (principal, error) {
if token == "" {
return principal{}, errUnauthenticated
}
if resolved, ok := a.principalFromStaticBearerToken(token); ok {
return resolved, nil
}
if a.bearerJWKSURL != "" {
if resolved, err := a.principalFromJWT(r.Context(), token); err == nil {
return resolved, nil
Expand Down Expand Up @@ -370,6 +392,32 @@ func (a *authConfig) introspectToken(ctx context.Context, token string) (princip
), nil
}

func (a *authConfig) principalFromStaticBearerToken(token string) (principal, bool) {
if len(a.bearerStaticPrincipals) == 0 {
return principal{}, false
}
sum := sha256.Sum256([]byte(strings.TrimSpace(token)))
for _, candidate := range a.bearerStaticPrincipals {
if len(candidate.parsedHash) != len(sum) {
continue
}
if subtle.ConstantTimeCompare(candidate.parsedHash, sum[:]) != 1 {
continue
}
return finalizePrincipal(
candidate.ID,
candidate.Email,
candidate.Teams,
firstNonEmpty(candidate.Subject, candidate.ID),
candidate.Issuer,
candidate.Type,
candidate.Scopes,
a.isAdmin(candidate.ID, candidate.Teams),
), true
}
return principal{}, false
}

func (a *authConfig) jwks() (*keyfunc.JWKS, error) {
if a.bearerJWKS != nil {
return a.bearerJWKS, nil
Expand Down Expand Up @@ -586,6 +634,66 @@ func ownerLabelValue(id string) string {
return fmt.Sprintf("owner-%x", sum[:16])
}

func parseStaticBearerPrincipals(raw string) ([]staticBearerPrincipal, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, nil
}
var payload []struct {
ID string `json:"id"`
Email string `json:"email"`
Teams []string `json:"teams"`
Type string `json:"type"`
Subject string `json:"subject"`
Issuer string `json:"issuer"`
Scopes []string `json:"scopes"`
TokenHash string `json:"tokenHash"`
}
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
return nil, fmt.Errorf("invalid SPRITZ_AUTH_BEARER_STATIC_PRINCIPALS_JSON: %w", err)
}
principals := make([]staticBearerPrincipal, 0, len(payload))
for index, item := range payload {
id := strings.TrimSpace(item.ID)
if id == "" {
return nil, fmt.Errorf("invalid SPRITZ_AUTH_BEARER_STATIC_PRINCIPALS_JSON: principals[%d].id is required", index)
}
hashHex := strings.ToLower(strings.TrimSpace(item.TokenHash))
if hashHex == "" {
return nil, fmt.Errorf("invalid SPRITZ_AUTH_BEARER_STATIC_PRINCIPALS_JSON: principals[%d].tokenHash is required", index)
}
hashBytes, err := hex.DecodeString(hashHex)
if err != nil || len(hashBytes) != sha256.Size {
return nil, fmt.Errorf("invalid SPRITZ_AUTH_BEARER_STATIC_PRINCIPALS_JSON: principals[%d].tokenHash must be a sha256 hex digest", index)
}
principalTypeValue := normalizePrincipalType(item.Type, principalTypeService)
if len(item.Scopes) == 0 {
return nil, fmt.Errorf("invalid SPRITZ_AUTH_BEARER_STATIC_PRINCIPALS_JSON: principals[%d].scopes is required", index)
}
principals = append(principals, staticBearerPrincipal{
ID: id,
Email: strings.TrimSpace(item.Email),
Teams: dedupeStrings(item.Teams),
Type: principalTypeValue,
Subject: strings.TrimSpace(item.Subject),
Issuer: strings.TrimSpace(item.Issuer),
Scopes: dedupeStrings(item.Scopes),
TokenHash: hashHex,
parsedHash: hashBytes,
})
}
return principals, nil
}

func firstNonEmpty(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}

func envOrDefault(key, fallback string) string {
value := strings.TrimSpace(os.Getenv(key))
if value == "" {
Expand Down
104 changes: 104 additions & 0 deletions api/auth_middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package main

import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -408,3 +410,105 @@ func TestBearerAuthDoesNotGrantAdminFromTypeClaim(t *testing.T) {
t.Fatalf("expected bearer admin claim to remain non-admin")
}
}

func TestBearerAuthAcceptsStaticServicePrincipal(t *testing.T) {
t.Setenv("SPRITZ_AUTH_MODE", "auto")
t.Setenv(
"SPRITZ_AUTH_BEARER_STATIC_PRINCIPALS_JSON",
`[{"id":"zenobot-staging","tokenHash":"`+sha256HexForTest("spritz-static-token")+`","scopes":["spritz.instances.create","spritz.instances.assign_owner"]}]`,
)

s := &server{auth: newAuthConfig()}
e := echo.New()
secured := e.Group("", s.authMiddleware())
secured.GET("/api/spritzes", func(c echo.Context) error {
p, ok := principalFromContext(c)
if !ok {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "missing principal"})
}
return c.JSON(http.StatusOK, map[string]any{
"id": p.ID,
"type": p.Type,
"scopes": p.Scopes,
})
})

req := httptest.NewRequest(http.MethodGet, "/api/spritzes", nil)
req.Header.Set("Authorization", "Bearer spritz-static-token")
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
}

payload := map[string]any{}
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if payload["id"] != "zenobot-staging" {
t.Fatalf("expected static principal id, got %#v", payload["id"])
}
if payload["type"] != string(principalTypeService) {
t.Fatalf("expected service principal type, got %#v", payload["type"])
}
scopes, _ := payload["scopes"].([]any)
if len(scopes) != 2 {
t.Fatalf("expected two scopes, got %#v", payload["scopes"])
}
}

func TestBearerAuthFallsBackWhenStaticPrincipalMisses(t *testing.T) {
introspection := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"sub": "user-123",
"email": "user@example.com",
})
}))
defer introspection.Close()

t.Setenv("SPRITZ_AUTH_MODE", "auto")
t.Setenv("SPRITZ_AUTH_BEARER_INTROSPECTION_URL", introspection.URL)
t.Setenv("SPRITZ_AUTH_BEARER_ID_PATHS", "sub")
t.Setenv("SPRITZ_AUTH_BEARER_EMAIL_PATHS", "email")
t.Setenv(
"SPRITZ_AUTH_BEARER_STATIC_PRINCIPALS_JSON",
`[{"id":"zenobot-staging","tokenHash":"`+sha256HexForTest("some-other-token")+`","scopes":["spritz.instances.create"]}]`,
)

s := &server{auth: newAuthConfig()}
e := echo.New()
secured := e.Group("", s.authMiddleware())
secured.GET("/api/spritzes", func(c echo.Context) error {
p, ok := principalFromContext(c)
if !ok {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "missing principal"})
}
return c.JSON(http.StatusOK, map[string]any{
"id": p.ID,
"email": p.Email,
})
})

req := httptest.NewRequest(http.MethodGet, "/api/spritzes", nil)
req.Header.Set("Authorization", "Bearer test-token")
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
}

payload := map[string]any{}
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if payload["id"] != "user-123" {
t.Fatalf("expected introspection fallback principal id, got %#v", payload["id"])
}
}

func sha256HexForTest(value string) string {
sum := sha256.Sum256([]byte(strings.TrimSpace(value)))
return hex.EncodeToString(sum[:])
}
4 changes: 4 additions & 0 deletions api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ func main() {
}

auth := newAuthConfig()
if auth.configErr != nil {
fmt.Fprintf(os.Stderr, "invalid auth config: %v\n", auth.configErr)
os.Exit(1)
}
ingressDefaults := newIngressDefaults()
terminal := newTerminalConfig()
acp := newACPConfig()
Expand Down
10 changes: 10 additions & 0 deletions helm/spritz/templates/api-deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,16 @@ spec:
- name: SPRITZ_AUTH_BEARER_DEFAULT_TYPE
value: {{ .Values.api.auth.bearer.defaultType | quote }}
{{- end }}
{{- if .Values.api.auth.bearer.staticPrincipalsSecret.name }}
- name: SPRITZ_AUTH_BEARER_STATIC_PRINCIPALS_JSON
valueFrom:
secretKeyRef:
name: {{ .Values.api.auth.bearer.staticPrincipalsSecret.name | quote }}
key: {{ .Values.api.auth.bearer.staticPrincipalsSecret.key | quote }}
{{- else if .Values.api.auth.bearer.staticPrincipalsJson }}
- name: SPRITZ_AUTH_BEARER_STATIC_PRINCIPALS_JSON
value: {{ .Values.api.auth.bearer.staticPrincipalsJson | quote }}
{{- end }}
{{- if .Values.api.auth.bearer.jwks.url }}
- name: SPRITZ_AUTH_BEARER_JWKS_URL
value: {{ .Values.api.auth.bearer.jwks.url | quote }}
Expand Down
4 changes: 4 additions & 0 deletions helm/spritz/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ api:
- scopes
- scp
defaultType: ""
staticPrincipalsJson: ""
staticPrincipalsSecret:
name: ""
key: SPRITZ_AUTH_BEARER_STATIC_PRINCIPALS_JSON
provisioners:
defaultPresetId: ""
allowedPresetIds: []
Expand Down
Loading