diff --git a/cmd/gocached/gocached.go b/cmd/gocached/gocached.go index 01551d4..8e0f8bc 100644 --- a/cmd/gocached/gocached.go +++ b/cmd/gocached/gocached.go @@ -9,7 +9,6 @@ import ( "flag" "fmt" "log" - "maps" "net" "net/http" "os" @@ -77,15 +76,27 @@ func main() { log.Fatal("must specify --jwt-claim at least once when --jwt-issuer is set") } - globalClaims := map[string]string{} - maps.Copy(globalClaims, jwtClaims) - maps.Copy(globalClaims, globalJWTClaims) - - opts = append(opts, gocached.WithJWTAuth(gocached.JWTIssuerConfig{ - Issuer: *jwtIssuer, - RequiredClaims: jwtClaims, - GlobalWriteClaims: globalClaims, - })) + opts = append(opts, + gocached.WithJWTAuth(*jwtIssuer), + gocached.WithNamespaceMapping(func(claims map[string]any) (gocached.Namespace, error) { + var ns gocached.Namespace + for k, want := range jwtClaims { + if got := claims[k]; got != want { + return "", fmt.Errorf("claim %q = %v, want %v", k, got, want) + } + if ns != "" { + ns += "," + } + ns += gocached.Namespace(fmt.Sprintf("%s=%s", k, want)) + } + for k, want := range globalJWTClaims { + if got := claims[k]; got != want { + return ns, nil + } + } + return gocached.GlobalNamespace, nil + }), + ) } srv, err := gocached.NewServer(opts...) diff --git a/gocached/gocached.go b/gocached/gocached.go index e7b394f..3f7af5b 100644 --- a/gocached/gocached.go +++ b/gocached/gocached.go @@ -191,7 +191,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_blobs_sha256 ON Blobs(SHA256); CREATE TABLE IF NOT EXISTS Namespaces ( NamespaceID INTEGER PRIMARY KEY AUTOINCREMENT, - Namespace TEXT NOT NULL UNIQUE CHECK (Namespace = lower(Namespace)) + Namespace TEXT NOT NULL UNIQUE ) STRICT; -- BlobShardStats persists per-shard usage histograms across restarts so @@ -230,6 +230,15 @@ func openDB(dbDir string) (*sql.DB, error) { db.SetMaxOpenConns(numConns) db.SetMaxIdleConns(numConns) db.SetConnMaxLifetime(0) // no limit + var ddl string + err = db.QueryRow(`SELECT sql FROM sqlite_master WHERE type='table' AND name='Namespaces'`).Scan(&ddl) + if err == nil && strings.Contains(ddl, "lower(Namespace)") { + // Drop the Namespaces table _only_ if it has the lowercase constraint so it + // can be recreated without it. + if _, err := db.Exec(`DROP TABLE Namespaces`); err != nil { + return nil, fmt.Errorf("dropping Namespaces lowercase constraint: %w", err) + } + } if _, err := db.Exec(schema); err != nil { return nil, err } @@ -365,6 +374,24 @@ func (srv *Server) start() error { Buckets: prometheus.DefBuckets, }) + // Fill the namespace ID cache. + rows, err := srv.db.Query("SELECT NamespaceID, Namespace FROM Namespaces") + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var id int64 + var ns string + if err := rows.Scan(&id, &ns); err != nil { + return err + } + srv.namespaces[Namespace(ns)] = id + } + if err := rows.Err(); err != nil { + return err + } + reg := prometheus.NewRegistry() reg.MustRegister( collectors.NewGoCollector(), @@ -444,17 +471,13 @@ func (srv *Server) start() error { len(srv.shardStats), srv.numShards(), srv.lastUsage.Load().All(), bytesFmt(srv.maxSize)) if len(srv.jwtIssuers) > 0 { - issuerURLs := make([]string, 0, len(srv.jwtIssuers)) - for iss := range srv.jwtIssuers { - issuerURLs = append(issuerURLs, iss) - } - srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, issuerURLs) + srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, srv.jwtIssuers) if err := srv.jwtValidator.RunUpdateJWKSLoop(srv.shutdownCtx); err != nil { return fmt.Errorf("failed to fetch JWKS for JWT validator: %w", err) } - for iss, entry := range srv.jwtIssuers { - srv.logf("gocached: using JWT issuer %q with required claims %v, global write claims %v", iss, entry.requiredClaims, entry.globalWriteClaims) + for _, iss := range srv.jwtIssuers { + srv.logf("gocached: using JWT issuer %q", iss) } go srv.runCleanSessionsLoop() @@ -568,40 +591,44 @@ func WithShardPrefixLen(n int) ServerOption { } } -// JWTIssuerConfig configures a single OIDC issuer for JWT-based authentication. -type JWTIssuerConfig struct { - // Issuer is the OIDC issuer URL. It must be a reachable HTTP(S) server - // that serves its JWKS via a URL discoverable at - // /.well-known/openid-configuration. - Issuer string +// Namespace identifies a logical partition of the cache where each peer is +// equally trusted. Every session is associated with exactly one Namespace to +// which it can read and write; sessions for non-global namespaces also read +// from [GlobalNamespace]. See [WithNamespaceMapping]. It may only contain +// characters from the set [a-zA-Z0-9._~:/@+|=-]. +type Namespace string - // RequiredClaims are claims that any JWT from this issuer must have to - // start a session. All key-value pairs must match exactly. - RequiredClaims map[string]string +// GlobalNamespace is a trusted namespace that all sessions can read from. Only +// sessions explicitly mapped to GlobalNamespace can write to it. +const GlobalNamespace Namespace = "" - // GlobalWriteClaims are claims that a JWT from this issuer must have to - // write to the cache's global namespace. It should be a superset of - // RequiredClaims. - GlobalWriteClaims map[string]string +// WithJWTAuth enables JWT-based authentication for the server. Each issuer +// must be a reachable HTTP(S) server that serves its JWKS via a URL +// discoverable at /.well-known/openid-configuration. JWTs presented for token +// exchange must pass the standard signature/issuer/audience/expiry checks +// against one of these issuers. If [WithNamespaceMapping] is provided, then +// it may still be rejected if the mapping function returns an error for its +// claims. No requests other than token exchange are allowed without +// authentication. It may be called multiple times; issuers accumulate. +func WithJWTAuth(issuers ...string) ServerOption { + return func(srv *Server) { + srv.jwtIssuers = append(srv.jwtIssuers, issuers...) + } } -// WithJWTAuth enables JWT-based authentication for the server. Each issuer must -// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at -// /.well-known/openid-configuration, and any JWT presented to the server must -// exactly match the issuer's required claims to start a session. No requests are -// allowed without authentication if JWT auth is enabled. It can be called multiple -// times; configs accumulate. -func WithJWTAuth(issuers ...JWTIssuerConfig) ServerOption { +// WithNamespaceMapping sets the function that makes policy decisions based on +// a JWT's claims. It is called once per token exchange after the JWT's +// signature and standard claims have been validated. It should return an error +// if the claims are not authorized, and otherwise return which [Namespace] the +// session is allowed to read and write in. See [Namespace] for character set +// constraints. All authorized sessions are allowed to read from the +// [GlobalNamespace] regardless of the namespace returned. Check claims["iss"] +// to switch on per-issuer rules. If JWT auth is enabled but no mapping +// function is provided, all sessions will read and write in the +// [GlobalNamespace]. +func WithNamespaceMapping(fn func(claims map[string]any) (Namespace, error)) ServerOption { return func(srv *Server) { - if srv.jwtIssuers == nil { - srv.jwtIssuers = make(map[string]*jwtIssuerConfig) - } - for _, ic := range issuers { - srv.jwtIssuers[ic.Issuer] = &jwtIssuerConfig{ - requiredClaims: ic.RequiredClaims, - globalWriteClaims: ic.GlobalWriteClaims, - } - } + srv.namespaceMapping = fn } } @@ -613,12 +640,21 @@ func NewServer(opts ...ServerOption) (*Server, error) { shutdownCtx: context.Background(), logf: log.Printf, sessions: make(map[string]*sessionData), + namespaces: make(map[Namespace]int64), clock: time.Now, } for _, opt := range opts { opt(srv) } + if len(srv.jwtIssuers) > 0 && srv.namespaceMapping == nil { + // If JWT auth is enabled, but not namespace mapping, every session is in + // the global namespace. + srv.namespaceMapping = func(claims map[string]any) (Namespace, error) { + return GlobalNamespace, nil + } + } + err := srv.start() if err != nil { return nil, err @@ -668,13 +704,15 @@ type Server struct { shutdownCtx context.Context shutdownCancel context.CancelFunc - jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty - jwtIssuers map[string]*jwtIssuerConfig // keyed by issuer URL + jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty + jwtIssuers []string // accepted issuer URLs + namespaceMapping func(claims map[string]any) (Namespace, error) // required when jwtIssuers is non-empty mu sync.RWMutex // guards following fields in this block sessions map[string]*sessionData // maps access token -> session data. accessDirty map[actionKey]int64 // action -> accessTime accessFlushTimer *time.Timer // nil if no flush is scheduled + namespaces map[Namespace]int64 // cached namespace string -> NamespaceID // sqliteWriteMu serializes access to SQLite. In theory the SQLite driver // should serialize access with our 5000ms busy timeout, but empirically we @@ -763,18 +801,13 @@ type Server struct { } } -// jwtIssuerConfig holds per-issuer claim requirements for JWT auth. -type jwtIssuerConfig struct { - requiredClaims map[string]string - globalWriteClaims map[string]string -} - // sessionData corresponds to a specific access token, and is only used if JWT // auth is enabled. type sessionData struct { - expiry time.Time // Session valid until. - globalNSWrite bool // Whether this session can write to the cache's global namespace. - claims map[string]any // Claims from the JWT used to create this session, stored for debug. + expiry time.Time // Session valid until. + namespaceID int64 // The namespace this session writes to. 0 means GlobalNamespace; non-zero sessions also read from 0. + namespace Namespace // Namespace this session writes to, stored for debug. + claims map[string]any // Claims from the JWT used to create this session, stored for debug. mu sync.Mutex // Guards stats. stats stats @@ -884,12 +917,11 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if r.Method == "PUT" { - if sessionData != nil && !sessionData.globalNSWrite { - // TODO(tomhjp): support per-namespace writes. - http.Error(w, "forbidden", http.StatusForbidden) - return + var writeNS int64 + if sessionData != nil { + writeNS = sessionData.namespaceID } - srv.handlePut(w, r, reqStats) + srv.handlePut(w, r, reqStats, writeNS) return } if r.Method != "GET" && r.Method != "HEAD" { @@ -897,7 +929,7 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if strings.HasPrefix(r.URL.Path, "/action/") { - srv.handleGetAction(w, r, reqStats) + srv.handleGetAction(w, r, reqStats, sessionData) return } if sessionData != nil && r.URL.Path == "/session/stats" { @@ -966,7 +998,7 @@ func getHexSuffix(r *http.Request, prefix string) (hexSuffix string, ok bool) { // actionKey is the comparable value type for the (NamespaceID, ActionID) // primary key tuple used in the SQLite Actions table. type actionKey struct { - NamespaceID int // 0 for global + NamespaceID int64 // 0 for global ActionID string } @@ -988,7 +1020,27 @@ func validHex(x string) bool { // we do a DB write to update it. const relAtimeSeconds = 60 * 60 * 24 // 1 day -func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats *stats) { +const getFromGlobalNamespace = ` +SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime, a.NamespaceID +FROM Actions a, Blobs b +WHERE a.NamespaceID = 0 + AND a.ActionID = ? + AND a.BlobID = b.BlobID +` + +// Get hits from the global namespace first so shared cache mtime is bumped +// with higher priority than namespaced cache. +const getFromSessionNamespace = ` +SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime, a.NamespaceID +FROM Actions a, Blobs b +WHERE a.NamespaceID IN (0, ?) + AND a.ActionID = ? + AND a.BlobID = b.BlobID +ORDER BY CASE a.NamespaceID WHEN 0 THEN 0 ELSE 1 END +LIMIT 1 +` + +func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats *stats, sessionData *sessionData) { srv.m.ActiveGets.Add(1) defer srv.m.ActiveGets.Add(-1) @@ -1014,19 +1066,24 @@ func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats return } - var sha256hex string - var storedSize, uncompressedSize int64 - var smallData sql.NullString - var altObjectID string - var accessTime int64 - var actionKey = actionKey{ - NamespaceID: 0, // global for now; TODO(bradfitz): support namespac - ActionID: actionID, + var ( + sha256hex string + storedSize, uncompressedSize int64 + smallData sql.NullString + altObjectID string + accessTime int64 + err error + actionKey = actionKey{ + ActionID: actionID, + } + ) + if sessionData != nil && sessionData.namespaceID != 0 { + err = srv.db.QueryRow(getFromSessionNamespace, sessionData.namespaceID, actionID).Scan( + &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime, &actionKey.NamespaceID) + } else { + err = srv.db.QueryRow(getFromGlobalNamespace, actionID).Scan( + &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime, &actionKey.NamespaceID) } - err := srv.db.QueryRow( - "SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime FROM Actions a, Blobs b WHERE a.NameSpaceID = ? AND a.ActionID = ? AND a.BlobID = b.BlobID", - actionKey.NamespaceID, actionKey.ActionID).Scan( - &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime) if err != nil { if errors.Is(err, sql.ErrNoRows) { http.Error(w, "not found", http.StatusNotFound) @@ -1240,7 +1297,7 @@ func (srv *Server) getObjectFromDiskOrPeer(_ context.Context, sha256hex string, return f, nil } -func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) { +func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats, namespaceID int64) { s.m.ActivePuts.Add(1) defer s.m.ActivePuts.Add(-1) @@ -1317,13 +1374,12 @@ func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) // Insert or update the action in the database. nowUnix := s.now().Unix() altObjectID := "" - namespace := 0 // global for now; TODO(bradfitz): support namespaces if sha256hex != outputID { altObjectID = outputID } res, err := s.db.Exec(`INSERT OR IGNORE INTO Actions (NamespaceID, ActionID, BlobID, AltOutputID, CreateTime, AccessTime) VALUES (?, ?, ?, ?, ?, ?)`, - namespace, + namespaceID, actionID, blobID, altObjectID, @@ -1384,23 +1440,44 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { return } - globalNSWrite, err := srv.evaluateClaims(jwtClaims) + ns, err := srv.namespaceMapping(jwtClaims) if err != nil { srv.m.AuthErrs.Add(1) if srv.verbose { - srv.logf("token exchange: %v", err) + srv.logf("token exchange: namespace func error: %v", err) } http.Error(w, "unauthorized", http.StatusUnauthorized) return } + if err := validateNamespace(ns); err != nil { + srv.m.AuthErrs.Add(1) + if srv.verbose { + srv.logf("token exchange: invalid namespace from claims: %v", err) + } + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var namespaceID int64 + if ns != GlobalNamespace { + namespaceID, err = srv.resolveNamespaceID(ns) + if err != nil { + srv.m.AuthErrs.Add(1) + srv.logf("token exchange: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + } + const ttl = time.Hour // 52 base32 characters, 256 bits of entropy. accessToken := tokenPrefix + strings.ToLower(rand.Text()+rand.Text()) srv.addSessionData(accessToken, &sessionData{ - expiry: srv.now().UTC().Add(ttl), - globalNSWrite: globalNSWrite, - claims: jwtClaims, + expiry: srv.now().UTC().Add(ttl), + namespaceID: namespaceID, + namespace: ns, + claims: jwtClaims, }) resp := map[string]any{ @@ -1418,38 +1495,61 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { srv.m.Auths.Add(1) } -func (srv *Server) evaluateClaims(claims map[string]any) (globalNSWrite bool, _ error) { - iss, _ := claims["iss"].(string) - cfg, ok := srv.jwtIssuers[iss] - if !ok { - return false, fmt.Errorf("got claims %v; unknown issuer %q", claims, iss) - } +// namespaceAllowedBytes is the set of non-alphanumeric bytes permitted in a +// namespace. It is chosen to cover characters common in JWT identity claims: +// issuer URLs, emails, and provider-structured "sub" values such as +// "repo:org/repo:environment:prod" and "auth0|abc123". It is deliberately +// ASCII-only so SQLite BINARY comparison matches Go string equality byte for +// byte, avoiding Unicode normalization divergence. +const namespaceAllowedBytes = "._~:/@+|=-" - if missing := findMissingClaims(cfg.requiredClaims, claims); len(missing) > 0 { - return false, fmt.Errorf("got claims %v; missing required claims: %v", claims, missing) +func validateNamespace(ns Namespace) error { + if ns == GlobalNamespace { + return nil + } + for i := 0; i < len(ns); i++ { + c := ns[i] + switch { + case c >= 'A' && c <= 'Z', c >= 'a' && c <= 'z', c >= '0' && c <= '9': + case strings.IndexByte(namespaceAllowedBytes, c) >= 0: + default: + return fmt.Errorf("namespace contains disallowed byte %#x at index %d", c, i) + } } + return nil +} - if missing := findMissingClaims(cfg.globalWriteClaims, claims); len(missing) == 0 { - return true, nil - } else if srv.verbose { - srv.logf("token exchange: missing global namespace write claims: %v", missing) +// resolveNamespaceID returns the integer ID for the given Namespace, +// inserting a row in the Namespaces table if one doesn't already exist. +func (srv *Server) resolveNamespaceID(ns Namespace) (int64, error) { + // If it's not a new namespace, we only need to consult our cache of IDs. + srv.mu.Lock() + id, ok := srv.namespaces[ns] + srv.mu.Unlock() + if ok { + return id, nil } - return false, nil -} + srv.sqliteWriteMu.Lock() + defer srv.sqliteWriteMu.Unlock() -func findMissingClaims(wantClaims map[string]string, gotClaims map[string]any) map[string]any { - if wantClaims == nil { - return nil + srv.mu.Lock() + defer srv.mu.Unlock() + + // Check if we lost a race now that we have both locks. + if id, ok = srv.namespaces[ns]; ok { + return id, nil } - missing := make(map[string]any) - for k, want := range wantClaims { - if got, ok := gotClaims[k]; !ok || got != want { - missing[k] = want - } + err := srv.db.QueryRow(`INSERT INTO Namespaces (Namespace) VALUES (?) + RETURNING NamespaceID;`, ns).Scan(&id) + if err != nil { + return 0, fmt.Errorf("resolving namespace %q: %w", ns, err) } - return missing + + srv.namespaces[ns] = id + + return id, nil } func (srv *Server) handleSessionStats(w http.ResponseWriter, sessionData *sessionData) { @@ -2524,10 +2624,11 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { for _, v := range srv.sessions { v.mu.Lock() sessions = append(sessions, &sessionData{ - expiry: v.expiry, - globalNSWrite: v.globalNSWrite, - claims: v.claims, - stats: v.stats, + expiry: v.expiry, + namespaceID: v.namespaceID, + namespace: v.namespace, + claims: v.claims, + stats: v.stats, }) v.mu.Unlock() } @@ -2535,15 +2636,13 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "

gocached sessions

\n") - for iss, cfg := range srv.jwtIssuers { + for _, iss := range srv.jwtIssuers { fmt.Fprintf(w, "

JWT issuer: %s

\n", iss) - fmt.Fprintf(w, "

JWT claims required: %v

\n", cfg.requiredClaims) - fmt.Fprintf(w, "

JWT global write claims required: %v

\n", cfg.globalWriteClaims) } fmt.Fprintf(w, "

Number of sessions: %d

\n", len(sessions)) fmt.Fprintf(w, "\n") - fmt.Fprintf(w, "\n") + fmt.Fprintf(w, "\n") slices.SortFunc(sessions, func(a, b *sessionData) int { return a.stats.LastUsed.Compare(b.stats.LastUsed) }) @@ -2552,10 +2651,14 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { if !d.stats.LastUsed.IsZero() { lastUsed = durFmt(time.Since(d.stats.LastUsed)) + " ago" } + nsLabel := "(global)" + if d.namespaceID != 0 { + nsLabel = fmt.Sprintf("%q (id=%d)", d.namespace, d.namespaceID) + } statsJSON, _ := json.MarshalIndent(d.stats, "", " ") claimsJSON, _ := json.MarshalIndent(d.claims, "", " ") - fmt.Fprintf(w, "\n", - lastUsed, d.expiry.Format(time.RFC3339), d.globalNSWrite, statsJSON, claimsJSON) + fmt.Fprintf(w, "\n", + lastUsed, d.expiry.Format(time.RFC3339), nsLabel, statsJSON, claimsJSON) } fmt.Fprintf(w, "
Last usedExpiry timeGlobal writeStatsClaims
Last usedExpiry timeNamespaceStatsClaims
%s%s%v
%s
%s
%s%s%s
%s
%s
\n") } diff --git a/gocached/gocached_test.go b/gocached/gocached_test.go index 2d08f36..858a3ca 100644 --- a/gocached/gocached_test.go +++ b/gocached/gocached_test.go @@ -38,6 +38,23 @@ import ( // value in SQLite to store bytes, as it's common. const sha256OfEmpty = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" +// Package-level ECDSA P-256 keys reused across tests that need OIDC signing +// keys. Generating these is the single most expensive step in the JWT tests, +// so we share them rather than regenerating per test. +var ( + testKey1 = mustGenerateTestKey() + testKey2 = mustGenerateTestKey() + testKey3 = mustGenerateTestKey() +) + +func mustGenerateTestKey() *ecdsa.PrivateKey { + k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(fmt.Sprintf("generating test ECDSA key: %v", err)) + } + return k +} + type tester struct { t testing.TB srv *Server @@ -1150,41 +1167,17 @@ func TestClientConnReuse(t *testing.T) { } func TestExchangeToken(t *testing.T) { - // Generate private keys outside of the loop for speed. - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating OIDC server private key: %v", err) - } - otherPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating OIDC server private key: %v", err) - } - wantClaims := map[string]string{ - "sub": "user123", - } - wantGlobalClaims := map[string]string{ - "sub": "user123", - "ref": "refs/heads/main", - } + privateKey := testKey1 + otherPrivateKey := testKey2 for name, tc := range map[string]struct { mutateClaims func(jwt.MapClaims) signingKey *ecdsa.PrivateKey wantStatusCode int - wantWrite bool }{ // Base case: no mutation. - "valid_read": { + "valid": { wantStatusCode: http.StatusOK, - wantWrite: false, - }, - // Additional claim needed for write scope. - "valid_write": { - mutateClaims: func(cl jwt.MapClaims) { - cl["ref"] = "refs/heads/main" - }, - wantStatusCode: http.StatusOK, - wantWrite: true, }, // Every other test makes one mutation from the base case that should cause failure. "missing_sub": { @@ -1231,10 +1224,12 @@ func TestExchangeToken(t *testing.T) { t.Run(name, func(t *testing.T) { issuer, createJWT := startOIDCServer(t, privateKey.Public()) st := newServerTester(t, - WithJWTAuth(JWTIssuerConfig{ - Issuer: issuer, - RequiredClaims: wantClaims, - GlobalWriteClaims: wantGlobalClaims, + WithJWTAuth(issuer), + WithNamespaceMapping(func(claims map[string]any) (Namespace, error) { + if claims["sub"] != "user123" { + return "", fmt.Errorf("sub = %v, want user123", claims["sub"]) + } + return GlobalNamespace, nil }), ) @@ -1307,14 +1302,8 @@ func TestExchangeToken(t *testing.T) { cl.AccessToken = d.AccessToken st.wantGetMiss(cl, "abc123") - if tc.wantWrite { - st.wantPut(cl, "abc123", "def456", "data789") - st.wantGet(cl, "abc123", "def456", "data789") - } else { - if _, err := cl.Put(t.Context(), "abc123", "def456", 0, nil); err == nil { - t.Fatalf("Put without write scope succeeded unexpectedly") - } - } + st.wantPut(cl, "abc123", "def456", "data789") + st.wantGet(cl, "abc123", "def456", "data789") // Check session stats. reqStats, err := http.NewRequest("GET", st.hs.URL+"/session/stats", nil) @@ -1342,102 +1331,91 @@ func TestExchangeToken(t *testing.T) { if stats.Gets == 0 { t.Errorf("expected non-zero gets in session stats") } - if stats.Puts == 0 && tc.wantWrite { + if stats.Puts == 0 { t.Errorf("expected non-zero puts in session stats") } }) } } -func TestMultiIssuerAuth(t *testing.T) { - // Generate separate keys for each issuer. - keyA, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key A: %v", err) - } - keyB, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key B: %v", err) - } - keyC, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key C: %v", err) +func TestExchangeTokenNamespaceValidation(t *testing.T) { + for name, tc := range map[string]struct { + namespace string + wantStatusCode int + }{ + "simple": {"user123", http.StatusOK}, + "structured_sub": {"repo:octo-org/octo-repo:environment:prod", http.StatusOK}, + "auth0_sub": {"auth0|507f1f77bcf86cd799439020", http.StatusOK}, + "email": {"alice+ci@example.com", http.StatusOK}, + "space": {"has space", http.StatusUnauthorized}, + "non_ascii": {"héllo", http.StatusUnauthorized}, + "control_char": {"bad\tns", http.StatusUnauthorized}, + "html_meta": {"