From 232975622382c30a1f9eddf5a01543670c11a550 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Wed, 6 May 2026 16:20:06 +0000 Subject: [PATCH] gocached: support namespaces Remove the special-case concept of "global writes" and instead allow callers to provide a namespace mapping function that makes policy decisions about which clients are globally trusted and which clients should be isolated from each other. As a result, all clients are now able to write, but not necessarily into the shared global namespace. All clients can still read from the global namespace, as well as their own. The Namespaces table already existed in the schema, but we drop the lowercase constraint. However, there are some breaking changes in the package API, WithJWTAuth now takes issuer URLs only and policy moves into the new WithNamespaceMapping option. cmd/gocached implements the spirit of the old API in terms of a namespace mapping function, with the main difference that it now allows writes if you don't have the global claims, but just into your own isolated namespace. Updates tailscale/corp#38092 Signed-off-by: Tom Proctor --- cmd/gocached/gocached.go | 31 ++-- gocached/gocached.go | 321 +++++++++++++++++++++++------------- gocached/gocached_test.go | 331 +++++++++++++++++++++++--------------- 3 files changed, 434 insertions(+), 249 deletions(-) diff --git a/cmd/gocached/gocached.go b/cmd/gocached/gocached.go index 01551d4..8e0f8bc 100644 --- a/cmd/gocached/gocached.go +++ b/cmd/gocached/gocached.go @@ -9,7 +9,6 @@ import ( "flag" "fmt" "log" - "maps" "net" "net/http" "os" @@ -77,15 +76,27 @@ func main() { log.Fatal("must specify --jwt-claim at least once when --jwt-issuer is set") } - globalClaims := map[string]string{} - maps.Copy(globalClaims, jwtClaims) - maps.Copy(globalClaims, globalJWTClaims) - - opts = append(opts, gocached.WithJWTAuth(gocached.JWTIssuerConfig{ - Issuer: *jwtIssuer, - RequiredClaims: jwtClaims, - GlobalWriteClaims: globalClaims, - })) + opts = append(opts, + gocached.WithJWTAuth(*jwtIssuer), + gocached.WithNamespaceMapping(func(claims map[string]any) (gocached.Namespace, error) { + var ns gocached.Namespace + for k, want := range jwtClaims { + if got := claims[k]; got != want { + return "", fmt.Errorf("claim %q = %v, want %v", k, got, want) + } + if ns != "" { + ns += "," + } + ns += gocached.Namespace(fmt.Sprintf("%s=%s", k, want)) + } + for k, want := range globalJWTClaims { + if got := claims[k]; got != want { + return ns, nil + } + } + return gocached.GlobalNamespace, nil + }), + ) } srv, err := gocached.NewServer(opts...) diff --git a/gocached/gocached.go b/gocached/gocached.go index e7b394f..3f7af5b 100644 --- a/gocached/gocached.go +++ b/gocached/gocached.go @@ -191,7 +191,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_blobs_sha256 ON Blobs(SHA256); CREATE TABLE IF NOT EXISTS Namespaces ( NamespaceID INTEGER PRIMARY KEY AUTOINCREMENT, - Namespace TEXT NOT NULL UNIQUE CHECK (Namespace = lower(Namespace)) + Namespace TEXT NOT NULL UNIQUE ) STRICT; -- BlobShardStats persists per-shard usage histograms across restarts so @@ -230,6 +230,15 @@ func openDB(dbDir string) (*sql.DB, error) { db.SetMaxOpenConns(numConns) db.SetMaxIdleConns(numConns) db.SetConnMaxLifetime(0) // no limit + var ddl string + err = db.QueryRow(`SELECT sql FROM sqlite_master WHERE type='table' AND name='Namespaces'`).Scan(&ddl) + if err == nil && strings.Contains(ddl, "lower(Namespace)") { + // Drop the Namespaces table _only_ if it has the lowercase constraint so it + // can be recreated without it. + if _, err := db.Exec(`DROP TABLE Namespaces`); err != nil { + return nil, fmt.Errorf("dropping Namespaces lowercase constraint: %w", err) + } + } if _, err := db.Exec(schema); err != nil { return nil, err } @@ -365,6 +374,24 @@ func (srv *Server) start() error { Buckets: prometheus.DefBuckets, }) + // Fill the namespace ID cache. + rows, err := srv.db.Query("SELECT NamespaceID, Namespace FROM Namespaces") + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var id int64 + var ns string + if err := rows.Scan(&id, &ns); err != nil { + return err + } + srv.namespaces[Namespace(ns)] = id + } + if err := rows.Err(); err != nil { + return err + } + reg := prometheus.NewRegistry() reg.MustRegister( collectors.NewGoCollector(), @@ -444,17 +471,13 @@ func (srv *Server) start() error { len(srv.shardStats), srv.numShards(), srv.lastUsage.Load().All(), bytesFmt(srv.maxSize)) if len(srv.jwtIssuers) > 0 { - issuerURLs := make([]string, 0, len(srv.jwtIssuers)) - for iss := range srv.jwtIssuers { - issuerURLs = append(issuerURLs, iss) - } - srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, issuerURLs) + srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, srv.jwtIssuers) if err := srv.jwtValidator.RunUpdateJWKSLoop(srv.shutdownCtx); err != nil { return fmt.Errorf("failed to fetch JWKS for JWT validator: %w", err) } - for iss, entry := range srv.jwtIssuers { - srv.logf("gocached: using JWT issuer %q with required claims %v, global write claims %v", iss, entry.requiredClaims, entry.globalWriteClaims) + for _, iss := range srv.jwtIssuers { + srv.logf("gocached: using JWT issuer %q", iss) } go srv.runCleanSessionsLoop() @@ -568,40 +591,44 @@ func WithShardPrefixLen(n int) ServerOption { } } -// JWTIssuerConfig configures a single OIDC issuer for JWT-based authentication. -type JWTIssuerConfig struct { - // Issuer is the OIDC issuer URL. It must be a reachable HTTP(S) server - // that serves its JWKS via a URL discoverable at - // /.well-known/openid-configuration. - Issuer string +// Namespace identifies a logical partition of the cache where each peer is +// equally trusted. Every session is associated with exactly one Namespace to +// which it can read and write; sessions for non-global namespaces also read +// from [GlobalNamespace]. See [WithNamespaceMapping]. It may only contain +// characters from the set [a-zA-Z0-9._~:/@+|=-]. +type Namespace string - // RequiredClaims are claims that any JWT from this issuer must have to - // start a session. All key-value pairs must match exactly. - RequiredClaims map[string]string +// GlobalNamespace is a trusted namespace that all sessions can read from. Only +// sessions explicitly mapped to GlobalNamespace can write to it. +const GlobalNamespace Namespace = "" - // GlobalWriteClaims are claims that a JWT from this issuer must have to - // write to the cache's global namespace. It should be a superset of - // RequiredClaims. - GlobalWriteClaims map[string]string +// WithJWTAuth enables JWT-based authentication for the server. Each issuer +// must be a reachable HTTP(S) server that serves its JWKS via a URL +// discoverable at /.well-known/openid-configuration. JWTs presented for token +// exchange must pass the standard signature/issuer/audience/expiry checks +// against one of these issuers. If [WithNamespaceMapping] is provided, then +// it may still be rejected if the mapping function returns an error for its +// claims. No requests other than token exchange are allowed without +// authentication. It may be called multiple times; issuers accumulate. +func WithJWTAuth(issuers ...string) ServerOption { + return func(srv *Server) { + srv.jwtIssuers = append(srv.jwtIssuers, issuers...) + } } -// WithJWTAuth enables JWT-based authentication for the server. Each issuer must -// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at -// /.well-known/openid-configuration, and any JWT presented to the server must -// exactly match the issuer's required claims to start a session. No requests are -// allowed without authentication if JWT auth is enabled. It can be called multiple -// times; configs accumulate. -func WithJWTAuth(issuers ...JWTIssuerConfig) ServerOption { +// WithNamespaceMapping sets the function that makes policy decisions based on +// a JWT's claims. It is called once per token exchange after the JWT's +// signature and standard claims have been validated. It should return an error +// if the claims are not authorized, and otherwise return which [Namespace] the +// session is allowed to read and write in. See [Namespace] for character set +// constraints. All authorized sessions are allowed to read from the +// [GlobalNamespace] regardless of the namespace returned. Check claims["iss"] +// to switch on per-issuer rules. If JWT auth is enabled but no mapping +// function is provided, all sessions will read and write in the +// [GlobalNamespace]. +func WithNamespaceMapping(fn func(claims map[string]any) (Namespace, error)) ServerOption { return func(srv *Server) { - if srv.jwtIssuers == nil { - srv.jwtIssuers = make(map[string]*jwtIssuerConfig) - } - for _, ic := range issuers { - srv.jwtIssuers[ic.Issuer] = &jwtIssuerConfig{ - requiredClaims: ic.RequiredClaims, - globalWriteClaims: ic.GlobalWriteClaims, - } - } + srv.namespaceMapping = fn } } @@ -613,12 +640,21 @@ func NewServer(opts ...ServerOption) (*Server, error) { shutdownCtx: context.Background(), logf: log.Printf, sessions: make(map[string]*sessionData), + namespaces: make(map[Namespace]int64), clock: time.Now, } for _, opt := range opts { opt(srv) } + if len(srv.jwtIssuers) > 0 && srv.namespaceMapping == nil { + // If JWT auth is enabled, but not namespace mapping, every session is in + // the global namespace. + srv.namespaceMapping = func(claims map[string]any) (Namespace, error) { + return GlobalNamespace, nil + } + } + err := srv.start() if err != nil { return nil, err @@ -668,13 +704,15 @@ type Server struct { shutdownCtx context.Context shutdownCancel context.CancelFunc - jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty - jwtIssuers map[string]*jwtIssuerConfig // keyed by issuer URL + jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty + jwtIssuers []string // accepted issuer URLs + namespaceMapping func(claims map[string]any) (Namespace, error) // required when jwtIssuers is non-empty mu sync.RWMutex // guards following fields in this block sessions map[string]*sessionData // maps access token -> session data. accessDirty map[actionKey]int64 // action -> accessTime accessFlushTimer *time.Timer // nil if no flush is scheduled + namespaces map[Namespace]int64 // cached namespace string -> NamespaceID // sqliteWriteMu serializes access to SQLite. In theory the SQLite driver // should serialize access with our 5000ms busy timeout, but empirically we @@ -763,18 +801,13 @@ type Server struct { } } -// jwtIssuerConfig holds per-issuer claim requirements for JWT auth. -type jwtIssuerConfig struct { - requiredClaims map[string]string - globalWriteClaims map[string]string -} - // sessionData corresponds to a specific access token, and is only used if JWT // auth is enabled. type sessionData struct { - expiry time.Time // Session valid until. - globalNSWrite bool // Whether this session can write to the cache's global namespace. - claims map[string]any // Claims from the JWT used to create this session, stored for debug. + expiry time.Time // Session valid until. + namespaceID int64 // The namespace this session writes to. 0 means GlobalNamespace; non-zero sessions also read from 0. + namespace Namespace // Namespace this session writes to, stored for debug. + claims map[string]any // Claims from the JWT used to create this session, stored for debug. mu sync.Mutex // Guards stats. stats stats @@ -884,12 +917,11 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if r.Method == "PUT" { - if sessionData != nil && !sessionData.globalNSWrite { - // TODO(tomhjp): support per-namespace writes. - http.Error(w, "forbidden", http.StatusForbidden) - return + var writeNS int64 + if sessionData != nil { + writeNS = sessionData.namespaceID } - srv.handlePut(w, r, reqStats) + srv.handlePut(w, r, reqStats, writeNS) return } if r.Method != "GET" && r.Method != "HEAD" { @@ -897,7 +929,7 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if strings.HasPrefix(r.URL.Path, "/action/") { - srv.handleGetAction(w, r, reqStats) + srv.handleGetAction(w, r, reqStats, sessionData) return } if sessionData != nil && r.URL.Path == "/session/stats" { @@ -966,7 +998,7 @@ func getHexSuffix(r *http.Request, prefix string) (hexSuffix string, ok bool) { // actionKey is the comparable value type for the (NamespaceID, ActionID) // primary key tuple used in the SQLite Actions table. type actionKey struct { - NamespaceID int // 0 for global + NamespaceID int64 // 0 for global ActionID string } @@ -988,7 +1020,27 @@ func validHex(x string) bool { // we do a DB write to update it. const relAtimeSeconds = 60 * 60 * 24 // 1 day -func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats *stats) { +const getFromGlobalNamespace = ` +SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime, a.NamespaceID +FROM Actions a, Blobs b +WHERE a.NamespaceID = 0 + AND a.ActionID = ? + AND a.BlobID = b.BlobID +` + +// Get hits from the global namespace first so shared cache mtime is bumped +// with higher priority than namespaced cache. +const getFromSessionNamespace = ` +SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime, a.NamespaceID +FROM Actions a, Blobs b +WHERE a.NamespaceID IN (0, ?) + AND a.ActionID = ? + AND a.BlobID = b.BlobID +ORDER BY CASE a.NamespaceID WHEN 0 THEN 0 ELSE 1 END +LIMIT 1 +` + +func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats *stats, sessionData *sessionData) { srv.m.ActiveGets.Add(1) defer srv.m.ActiveGets.Add(-1) @@ -1014,19 +1066,24 @@ func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats return } - var sha256hex string - var storedSize, uncompressedSize int64 - var smallData sql.NullString - var altObjectID string - var accessTime int64 - var actionKey = actionKey{ - NamespaceID: 0, // global for now; TODO(bradfitz): support namespac - ActionID: actionID, + var ( + sha256hex string + storedSize, uncompressedSize int64 + smallData sql.NullString + altObjectID string + accessTime int64 + err error + actionKey = actionKey{ + ActionID: actionID, + } + ) + if sessionData != nil && sessionData.namespaceID != 0 { + err = srv.db.QueryRow(getFromSessionNamespace, sessionData.namespaceID, actionID).Scan( + &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime, &actionKey.NamespaceID) + } else { + err = srv.db.QueryRow(getFromGlobalNamespace, actionID).Scan( + &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime, &actionKey.NamespaceID) } - err := srv.db.QueryRow( - "SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime FROM Actions a, Blobs b WHERE a.NameSpaceID = ? AND a.ActionID = ? AND a.BlobID = b.BlobID", - actionKey.NamespaceID, actionKey.ActionID).Scan( - &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime) if err != nil { if errors.Is(err, sql.ErrNoRows) { http.Error(w, "not found", http.StatusNotFound) @@ -1240,7 +1297,7 @@ func (srv *Server) getObjectFromDiskOrPeer(_ context.Context, sha256hex string, return f, nil } -func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) { +func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats, namespaceID int64) { s.m.ActivePuts.Add(1) defer s.m.ActivePuts.Add(-1) @@ -1317,13 +1374,12 @@ func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) // Insert or update the action in the database. nowUnix := s.now().Unix() altObjectID := "" - namespace := 0 // global for now; TODO(bradfitz): support namespaces if sha256hex != outputID { altObjectID = outputID } res, err := s.db.Exec(`INSERT OR IGNORE INTO Actions (NamespaceID, ActionID, BlobID, AltOutputID, CreateTime, AccessTime) VALUES (?, ?, ?, ?, ?, ?)`, - namespace, + namespaceID, actionID, blobID, altObjectID, @@ -1384,23 +1440,44 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { return } - globalNSWrite, err := srv.evaluateClaims(jwtClaims) + ns, err := srv.namespaceMapping(jwtClaims) if err != nil { srv.m.AuthErrs.Add(1) if srv.verbose { - srv.logf("token exchange: %v", err) + srv.logf("token exchange: namespace func error: %v", err) } http.Error(w, "unauthorized", http.StatusUnauthorized) return } + if err := validateNamespace(ns); err != nil { + srv.m.AuthErrs.Add(1) + if srv.verbose { + srv.logf("token exchange: invalid namespace from claims: %v", err) + } + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var namespaceID int64 + if ns != GlobalNamespace { + namespaceID, err = srv.resolveNamespaceID(ns) + if err != nil { + srv.m.AuthErrs.Add(1) + srv.logf("token exchange: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + } + const ttl = time.Hour // 52 base32 characters, 256 bits of entropy. accessToken := tokenPrefix + strings.ToLower(rand.Text()+rand.Text()) srv.addSessionData(accessToken, &sessionData{ - expiry: srv.now().UTC().Add(ttl), - globalNSWrite: globalNSWrite, - claims: jwtClaims, + expiry: srv.now().UTC().Add(ttl), + namespaceID: namespaceID, + namespace: ns, + claims: jwtClaims, }) resp := map[string]any{ @@ -1418,38 +1495,61 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { srv.m.Auths.Add(1) } -func (srv *Server) evaluateClaims(claims map[string]any) (globalNSWrite bool, _ error) { - iss, _ := claims["iss"].(string) - cfg, ok := srv.jwtIssuers[iss] - if !ok { - return false, fmt.Errorf("got claims %v; unknown issuer %q", claims, iss) - } +// namespaceAllowedBytes is the set of non-alphanumeric bytes permitted in a +// namespace. It is chosen to cover characters common in JWT identity claims: +// issuer URLs, emails, and provider-structured "sub" values such as +// "repo:org/repo:environment:prod" and "auth0|abc123". It is deliberately +// ASCII-only so SQLite BINARY comparison matches Go string equality byte for +// byte, avoiding Unicode normalization divergence. +const namespaceAllowedBytes = "._~:/@+|=-" - if missing := findMissingClaims(cfg.requiredClaims, claims); len(missing) > 0 { - return false, fmt.Errorf("got claims %v; missing required claims: %v", claims, missing) +func validateNamespace(ns Namespace) error { + if ns == GlobalNamespace { + return nil + } + for i := 0; i < len(ns); i++ { + c := ns[i] + switch { + case c >= 'A' && c <= 'Z', c >= 'a' && c <= 'z', c >= '0' && c <= '9': + case strings.IndexByte(namespaceAllowedBytes, c) >= 0: + default: + return fmt.Errorf("namespace contains disallowed byte %#x at index %d", c, i) + } } + return nil +} - if missing := findMissingClaims(cfg.globalWriteClaims, claims); len(missing) == 0 { - return true, nil - } else if srv.verbose { - srv.logf("token exchange: missing global namespace write claims: %v", missing) +// resolveNamespaceID returns the integer ID for the given Namespace, +// inserting a row in the Namespaces table if one doesn't already exist. +func (srv *Server) resolveNamespaceID(ns Namespace) (int64, error) { + // If it's not a new namespace, we only need to consult our cache of IDs. + srv.mu.Lock() + id, ok := srv.namespaces[ns] + srv.mu.Unlock() + if ok { + return id, nil } - return false, nil -} + srv.sqliteWriteMu.Lock() + defer srv.sqliteWriteMu.Unlock() -func findMissingClaims(wantClaims map[string]string, gotClaims map[string]any) map[string]any { - if wantClaims == nil { - return nil + srv.mu.Lock() + defer srv.mu.Unlock() + + // Check if we lost a race now that we have both locks. + if id, ok = srv.namespaces[ns]; ok { + return id, nil } - missing := make(map[string]any) - for k, want := range wantClaims { - if got, ok := gotClaims[k]; !ok || got != want { - missing[k] = want - } + err := srv.db.QueryRow(`INSERT INTO Namespaces (Namespace) VALUES (?) + RETURNING NamespaceID;`, ns).Scan(&id) + if err != nil { + return 0, fmt.Errorf("resolving namespace %q: %w", ns, err) } - return missing + + srv.namespaces[ns] = id + + return id, nil } func (srv *Server) handleSessionStats(w http.ResponseWriter, sessionData *sessionData) { @@ -2524,10 +2624,11 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { for _, v := range srv.sessions { v.mu.Lock() sessions = append(sessions, &sessionData{ - expiry: v.expiry, - globalNSWrite: v.globalNSWrite, - claims: v.claims, - stats: v.stats, + expiry: v.expiry, + namespaceID: v.namespaceID, + namespace: v.namespace, + claims: v.claims, + stats: v.stats, }) v.mu.Unlock() } @@ -2535,15 +2636,13 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "

gocached sessions

\n") - for iss, cfg := range srv.jwtIssuers { + for _, iss := range srv.jwtIssuers { fmt.Fprintf(w, "

JWT issuer: %s

\n", iss) - fmt.Fprintf(w, "

JWT claims required: %v

\n", cfg.requiredClaims) - fmt.Fprintf(w, "

JWT global write claims required: %v

\n", cfg.globalWriteClaims) } fmt.Fprintf(w, "

Number of sessions: %d

\n", len(sessions)) fmt.Fprintf(w, "\n") - fmt.Fprintf(w, "\n") + fmt.Fprintf(w, "\n") slices.SortFunc(sessions, func(a, b *sessionData) int { return a.stats.LastUsed.Compare(b.stats.LastUsed) }) @@ -2552,10 +2651,14 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { if !d.stats.LastUsed.IsZero() { lastUsed = durFmt(time.Since(d.stats.LastUsed)) + " ago" } + nsLabel := "(global)" + if d.namespaceID != 0 { + nsLabel = fmt.Sprintf("%q (id=%d)", d.namespace, d.namespaceID) + } statsJSON, _ := json.MarshalIndent(d.stats, "", " ") claimsJSON, _ := json.MarshalIndent(d.claims, "", " ") - fmt.Fprintf(w, "\n", - lastUsed, d.expiry.Format(time.RFC3339), d.globalNSWrite, statsJSON, claimsJSON) + fmt.Fprintf(w, "\n", + lastUsed, d.expiry.Format(time.RFC3339), nsLabel, statsJSON, claimsJSON) } fmt.Fprintf(w, "
Last usedExpiry timeGlobal writeStatsClaims
Last usedExpiry timeNamespaceStatsClaims
%s%s%v
%s
%s
%s%s%s
%s
%s
\n") } diff --git a/gocached/gocached_test.go b/gocached/gocached_test.go index 2d08f36..858a3ca 100644 --- a/gocached/gocached_test.go +++ b/gocached/gocached_test.go @@ -38,6 +38,23 @@ import ( // value in SQLite to store bytes, as it's common. const sha256OfEmpty = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" +// Package-level ECDSA P-256 keys reused across tests that need OIDC signing +// keys. Generating these is the single most expensive step in the JWT tests, +// so we share them rather than regenerating per test. +var ( + testKey1 = mustGenerateTestKey() + testKey2 = mustGenerateTestKey() + testKey3 = mustGenerateTestKey() +) + +func mustGenerateTestKey() *ecdsa.PrivateKey { + k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(fmt.Sprintf("generating test ECDSA key: %v", err)) + } + return k +} + type tester struct { t testing.TB srv *Server @@ -1150,41 +1167,17 @@ func TestClientConnReuse(t *testing.T) { } func TestExchangeToken(t *testing.T) { - // Generate private keys outside of the loop for speed. - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating OIDC server private key: %v", err) - } - otherPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating OIDC server private key: %v", err) - } - wantClaims := map[string]string{ - "sub": "user123", - } - wantGlobalClaims := map[string]string{ - "sub": "user123", - "ref": "refs/heads/main", - } + privateKey := testKey1 + otherPrivateKey := testKey2 for name, tc := range map[string]struct { mutateClaims func(jwt.MapClaims) signingKey *ecdsa.PrivateKey wantStatusCode int - wantWrite bool }{ // Base case: no mutation. - "valid_read": { + "valid": { wantStatusCode: http.StatusOK, - wantWrite: false, - }, - // Additional claim needed for write scope. - "valid_write": { - mutateClaims: func(cl jwt.MapClaims) { - cl["ref"] = "refs/heads/main" - }, - wantStatusCode: http.StatusOK, - wantWrite: true, }, // Every other test makes one mutation from the base case that should cause failure. "missing_sub": { @@ -1231,10 +1224,12 @@ func TestExchangeToken(t *testing.T) { t.Run(name, func(t *testing.T) { issuer, createJWT := startOIDCServer(t, privateKey.Public()) st := newServerTester(t, - WithJWTAuth(JWTIssuerConfig{ - Issuer: issuer, - RequiredClaims: wantClaims, - GlobalWriteClaims: wantGlobalClaims, + WithJWTAuth(issuer), + WithNamespaceMapping(func(claims map[string]any) (Namespace, error) { + if claims["sub"] != "user123" { + return "", fmt.Errorf("sub = %v, want user123", claims["sub"]) + } + return GlobalNamespace, nil }), ) @@ -1307,14 +1302,8 @@ func TestExchangeToken(t *testing.T) { cl.AccessToken = d.AccessToken st.wantGetMiss(cl, "abc123") - if tc.wantWrite { - st.wantPut(cl, "abc123", "def456", "data789") - st.wantGet(cl, "abc123", "def456", "data789") - } else { - if _, err := cl.Put(t.Context(), "abc123", "def456", 0, nil); err == nil { - t.Fatalf("Put without write scope succeeded unexpectedly") - } - } + st.wantPut(cl, "abc123", "def456", "data789") + st.wantGet(cl, "abc123", "def456", "data789") // Check session stats. reqStats, err := http.NewRequest("GET", st.hs.URL+"/session/stats", nil) @@ -1342,102 +1331,91 @@ func TestExchangeToken(t *testing.T) { if stats.Gets == 0 { t.Errorf("expected non-zero gets in session stats") } - if stats.Puts == 0 && tc.wantWrite { + if stats.Puts == 0 { t.Errorf("expected non-zero puts in session stats") } }) } } -func TestMultiIssuerAuth(t *testing.T) { - // Generate separate keys for each issuer. - keyA, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key A: %v", err) - } - keyB, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key B: %v", err) - } - keyC, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key C: %v", err) +func TestExchangeTokenNamespaceValidation(t *testing.T) { + for name, tc := range map[string]struct { + namespace string + wantStatusCode int + }{ + "simple": {"user123", http.StatusOK}, + "structured_sub": {"repo:octo-org/octo-repo:environment:prod", http.StatusOK}, + "auth0_sub": {"auth0|507f1f77bcf86cd799439020", http.StatusOK}, + "email": {"alice+ci@example.com", http.StatusOK}, + "space": {"has space", http.StatusUnauthorized}, + "non_ascii": {"héllo", http.StatusUnauthorized}, + "control_char": {"bad\tns", http.StatusUnauthorized}, + "html_meta": {"