diff --git a/api/go.mod b/api/go.mod index aa2dc495..6b3c33fa 100644 --- a/api/go.mod +++ b/api/go.mod @@ -94,6 +94,5 @@ require ( sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 // indirect sigs.k8s.io/randfill v1.0.0 // indirect sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect - sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect sigs.k8s.io/yaml v1.6.0 // indirect ) diff --git a/api/go.sum b/api/go.sum index 4af514b6..e37b1483 100644 --- a/api/go.sum +++ b/api/go.sum @@ -255,16 +255,34 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= +k8s.io/api v0.29.11 h1:6FwDo33f1WX5Yu0RQTX9YAd3wth8Ik0B4SXQKsoQfbk= +k8s.io/api v0.29.11/go.mod h1:3TDAW1OpFbz/Yx5r0W06b6eiAfHEwtH61VYDzpTU4Ng= +k8s.io/api v0.30.8 h1:Y+yZRF3c1WC0MTkLe0qBkiLCquRNa4I21/iDioGMCbo= +k8s.io/api v0.30.8/go.mod h1:89IE5MzirZ5HHxU/Hq1/KWGqXkhXClu/FHGesFhQ0A4= k8s.io/api v0.31.0 h1:b9LiSjR2ym/SzTOlfMHm1tr7/21aD7fSkqgD/CVJBCo= k8s.io/api v0.31.0/go.mod h1:0YiFF+JfFxMM6+1hQei8FY8M7s1Mth+z/q7eF1aJkTE= +k8s.io/apimachinery v0.29.11 h1:55+6ue9advpA7T0sX2ZJDHCLKuiFfrAAR/39VQN9KEQ= +k8s.io/apimachinery v0.29.11/go.mod h1:i3FJVwhvSp/6n8Fl4K97PJEP8C+MM+aoDq4+ZJBf70Y= +k8s.io/apimachinery v0.30.8 h1:9jyTItYzmJc00cBDxZC5ArFNxUeKCwbw0m760iFUMKY= +k8s.io/apimachinery v0.30.8/go.mod h1:iexa2somDaxdnj7bha06bhb43Zpa6eWH8N8dbqVjTUc= k8s.io/apimachinery v0.31.0 h1:m9jOiSr3FoSSL5WO9bjm1n6B9KROYYgNZOb4tyZ1lBc= k8s.io/apimachinery v0.31.0/go.mod h1:rsPdaZJfTfLsNJSQzNHQvYoTmxhoOEofxtOsF3rtsMo= +k8s.io/client-go v0.29.11 h1:mBX7Ub0uqpLMwWz3J/AGS/xKOZsjr349qZ1vxVoL1l8= +k8s.io/client-go v0.29.11/go.mod h1:WOEoi/eLg2YEg3/yEd7YK3CNScYkM8AEScQadxUnaTE= +k8s.io/client-go v0.30.8 h1:fC1SQMZm7bSWiVv9ydN+nv+sqGVAxMdf/5eKUVffNJE= +k8s.io/client-go v0.30.8/go.mod h1:daF3UcGVqGPHvH5mn/ESkp/VoR8i9tg9IBfKr+AeFYo= k8s.io/client-go v0.31.0 h1:QqEJzNjbN2Yv1H79SsS+SWnXkBgVu4Pj3CJQgbx0gI8= k8s.io/client-go v0.31.0/go.mod h1:Y9wvC76g4fLjmU0BA+rV+h2cncoadjvjjkkIGoTLcGU= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= +k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 h1:BZqlfIlq5YbRMFko6/PM7FjZpUb45WallggurYhKGag= +k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340/go.mod h1:yD4MZYeKMBwQKVht279WycxKyM84kkAx2DPrTXaeb98= k8s.io/kube-openapi v0.0.0-20250710124328-f3f2b991d03b h1:MloQ9/bdJyIu9lb1PzujOPolHyvO06MXG5TUIj2mNAA= k8s.io/kube-openapi v0.0.0-20250710124328-f3f2b991d03b/go.mod h1:UZ2yyWbFTpuhSbFhv24aGNOdoRdJZgsIObGBUaYVsts= +k8s.io/metrics v0.29.11 h1:GtIYoM6uKUQIygnUpVEYE4X12xggwWmnYe443TZumcg= +k8s.io/metrics v0.29.11/go.mod h1:aroO9WuHjPwp4Z6ACFG7aCZf3m8Un+Vn3ogauOlJHrE= +k8s.io/metrics v0.30.8 h1:oM7kdmGjnahww89CLSF95XHC5dWtb05+cgplUeu/fqE= +k8s.io/metrics v0.30.8/go.mod h1:fbEC4z4Q5uwGXtfJIGQyoq3ndBCORYopDA0oXReDk2I= k8s.io/metrics v0.31.0 h1:s7Vu7W0oEZPTN8jgcoiWIXIZBmVxt7YP9MRVyIgMdOc= k8s.io/metrics v0.31.0/go.mod h1:UNsz6swyX8FWkDoKN9ixPF75TBREMbHZIKjD7fydaOY= k8s.io/utils v0.0.0-20250604170112-4c0f3b243397 h1:hwvWFiBzdWw1FhfY1FooPn3kzWuJ8tmbZBHi4zVsl1Y= @@ -276,6 +294,8 @@ sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU= sigs.k8s.io/randfill v1.0.0/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY= sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4= sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08= +sigs.k8s.io/structured-merge-diff/v6 v6.2.0 h1:msyqjP8Nyd5sF3QSmJouFSzcBIdwq4ct8d1/7VSBHIQ= +sigs.k8s.io/structured-merge-diff/v6 v6.2.0/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE= sigs.k8s.io/structured-merge-diff/v6 v6.3.0 h1:jTijUJbW353oVOd9oTlifJqOGEkUw2jB/fXCbTiQEco= sigs.k8s.io/structured-merge-diff/v6 v6.3.0/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE= sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs= diff --git a/api/internal/auth/handlers.go b/api/internal/auth/handlers.go index 40ec6d6c..70bfcfa5 100644 --- a/api/internal/auth/handlers.go +++ b/api/internal/auth/handlers.go @@ -102,6 +102,8 @@ package auth import ( "context" + "encoding/xml" + "fmt" "log" "net/http" "time" @@ -473,9 +475,18 @@ func (h *AuthHandler) SAMLMetadata(c *gin.Context) { // Generate metadata XML metadata := sp.Metadata() + // Marshal to XML bytes + metadataBytes, err := xml.Marshal(metadata) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": fmt.Sprintf("Failed to marshal metadata: %v", err), + }) + return + } + // Return XML with proper content type c.Header("Content-Type", "application/samlmetadata+xml") - c.String(http.StatusOK, string(metadata)) + c.String(http.StatusOK, string(metadataBytes)) } // PasswordChangeRequest represents a password change request diff --git a/api/internal/auth/middleware.go b/api/internal/auth/middleware.go index f311ffeb..f31ce7f9 100644 --- a/api/internal/auth/middleware.go +++ b/api/internal/auth/middleware.go @@ -133,7 +133,6 @@ package auth import ( - "context" "net/http" "strings" diff --git a/api/internal/auth/oidc.go b/api/internal/auth/oidc.go index 2f919df4..7441e69e 100644 --- a/api/internal/auth/oidc.go +++ b/api/internal/auth/oidc.go @@ -185,6 +185,7 @@ package auth import ( "context" + "crypto/tls" "encoding/json" "fmt" "log" diff --git a/api/internal/auth/saml.go b/api/internal/auth/saml.go index a0153b7f..87b2a6cf 100644 --- a/api/internal/auth/saml.go +++ b/api/internal/auth/saml.go @@ -554,6 +554,11 @@ func (sa *SAMLAuthenticator) GetMiddleware() *samlsp.Middleware { return sa.middleware } +// GetServiceProvider returns the SAML Service Provider instance. +func (sa *SAMLAuthenticator) GetServiceProvider() *saml.ServiceProvider { + return sa.serviceProvider +} + // ExtractUserFromAssertion extracts user information from a SAML assertion. // // After the IdP authenticates a user, it sends a SAML assertion containing: @@ -815,6 +820,70 @@ func (sa *SAMLAuthenticator) ExtractUserFromAssertion(assertion *saml.Assertion) return user, nil } +// ExtractUserFromAttributes extracts user information from SAML session attributes. +// +// This is a simpler version of ExtractUserFromAssertion that works with the +// attributes map returned by SessionWithAttributes.GetAttributes(). +// +// The attributes map contains key-value pairs from the SAML assertion's +// AttributeStatements, already parsed and ready to use. +func (sa *SAMLAuthenticator) ExtractUserFromAttributes(attributes samlsp.Attributes) (*UserInfo, error) { + if attributes == nil { + return nil, fmt.Errorf("attributes map is nil") + } + + // Initialize user object + user := &UserInfo{ + Attributes: make(map[string]interface{}), + } + + // Helper function to get first attribute value + getAttribute := func(key string) string { + return attributes.Get(key) + } + + // Helper function to get all attribute values + getAttributes := func(key string) []string { + if vals, ok := attributes[key]; ok { + return vals + } + return nil + } + + // Extract mapped attributes + if sa.config.AttributeMapping.Email != "" { + user.Email = getAttribute(sa.config.AttributeMapping.Email) + } + if sa.config.AttributeMapping.Username != "" { + user.Username = getAttribute(sa.config.AttributeMapping.Username) + } + if sa.config.AttributeMapping.FirstName != "" { + user.FirstName = getAttribute(sa.config.AttributeMapping.FirstName) + } + if sa.config.AttributeMapping.LastName != "" { + user.LastName = getAttribute(sa.config.AttributeMapping.LastName) + } + if sa.config.AttributeMapping.Groups != "" { + user.Groups = getAttributes(sa.config.AttributeMapping.Groups) + } + + // Store all attributes in Attributes map for custom use cases + for key, values := range attributes { + if len(values) == 1 { + user.Attributes[key] = values[0] + } else { + user.Attributes[key] = values + } + } + + // Validate required fields + if user.Username == "" { + return nil, fmt.Errorf("username not found in SAML attributes") + } + + return user, nil +} + // GinMiddleware returns a Gin middleware function that enforces SAML authentication. // // This middleware protects routes by requiring valid SAML authentication. It: @@ -930,40 +999,26 @@ func (sa *SAMLAuthenticator) GinMiddleware() gin.HandlerFunc { return } - // STEP 2: Extract assertion from session + // STEP 2: Extract attributes from session // - // The session contains the SAML assertion that was validated during login. - // We need to retrieve it to extract user attributes. + // The session contains the SAML attributes that were extracted during login. + // We need to retrieve them to get user information. // // WHY TYPE ASSERTION: Session interface doesn't expose attributes directly, // must cast to SessionWithAttributes to access GetAttributes() - assertion := session.(samlsp.SessionWithAttributes).GetAttributes() - if assertion == nil { - // Session exists but has no assertion - corrupted session + attributes := session.(samlsp.SessionWithAttributes).GetAttributes() + if attributes == nil { + // Session exists but has no attributes - corrupted session c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid SAML session"}) c.Abort() return } - // STEP 3: Convert to SAML Assertion type - // - // GetAttributes() returns interface{}, we need to type-assert to *saml.Assertion - // to access assertion fields like AttributeStatements. - // - // WHY TYPE CHECK: Ensures we have the correct type before using assertion methods - samlAssertion, ok := assertion.(*saml.Assertion) - if !ok { - // Assertion is not the expected type - should never happen in normal flow - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid assertion type"}) - c.Abort() - return - } - - // STEP 4: Extract user information from assertion + // STEP 3: Extract user information from attributes // - // Parse SAML assertion attributes and map them to StreamSpace user fields. + // Parse SAML attributes and map them to StreamSpace user fields. // This uses the configured AttributeMapping to translate IdP attributes. - user, err := sa.ExtractUserFromAssertion(samlAssertion) + user, err := sa.ExtractUserFromAttributes(attributes) if err != nil { // Failed to extract required user fields (usually missing username) // This indicates IdP misconfiguration or incorrect attribute mapping diff --git a/api/internal/middleware/compression.go b/api/internal/middleware/compression.go index beb76333..03088e77 100644 --- a/api/internal/middleware/compression.go +++ b/api/internal/middleware/compression.go @@ -47,6 +47,7 @@ package middleware import ( "compress/gzip" "io" + "net/http" "strings" "sync" @@ -133,7 +134,7 @@ func Gzip(level int) gin.HandlerFunc { } // shouldCompress determines if the response should be compressed -func shouldCompress(r *gin.Context.Request) bool { +func shouldCompress(r *http.Request) bool { // Check if client accepts gzip if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { return false diff --git a/api/internal/middleware/inputvalidation.go b/api/internal/middleware/inputvalidation.go index 8fa5317c..c2da5027 100644 --- a/api/internal/middleware/inputvalidation.go +++ b/api/internal/middleware/inputvalidation.go @@ -273,16 +273,16 @@ func (v *InputValidator) sanitizeMap(data map[string]interface{}) map[string]int result := make(map[string]interface{}) for key, value := range data { - switch v := value.(type) { + switch val := value.(type) { case string: // Sanitize string values using bluemonday - result[key] = v.sanitizer.Sanitize(v) + result[key] = v.sanitizer.Sanitize(val) case map[string]interface{}: // Recursively sanitize nested maps - result[key] = v.sanitizeMap(v) + result[key] = v.sanitizeMap(val) case []interface{}: // Sanitize arrays - result[key] = v.sanitizeArray(v) + result[key] = v.sanitizeArray(val) default: // Keep other types as-is (numbers, booleans, etc.) result[key] = value @@ -297,13 +297,13 @@ func (v *InputValidator) sanitizeArray(data []interface{}) []interface{} { result := make([]interface{}, len(data)) for i, value := range data { - switch v := value.(type) { + switch val := value.(type) { case string: - result[i] = v.sanitizer.Sanitize(v) + result[i] = v.sanitizer.Sanitize(val) case map[string]interface{}: - result[i] = v.sanitizeMap(v) + result[i] = v.sanitizeMap(val) case []interface{}: - result[i] = v.sanitizeArray(v) + result[i] = v.sanitizeArray(val) default: result[i] = value } diff --git a/api/internal/quota/enforcer.go b/api/internal/quota/enforcer.go index f9b1897c..c6b172d0 100644 --- a/api/internal/quota/enforcer.go +++ b/api/internal/quota/enforcer.go @@ -210,23 +210,28 @@ func (e *Enforcer) GetUserLimits(ctx context.Context, username string) (*Limits, if user.Quota.MaxSessions > 0 { limits.MaxSessions = user.Quota.MaxSessions } - if user.Quota.MaxCPUPerSession > 0 { - limits.MaxCPUPerSession = user.Quota.MaxCPUPerSession - } - if user.Quota.MaxMemoryPerSession > 0 { - limits.MaxMemoryPerSession = user.Quota.MaxMemoryPerSession - } - if user.Quota.MaxTotalCPU > 0 { - limits.MaxTotalCPU = user.Quota.MaxTotalCPU - } - if user.Quota.MaxTotalMemory > 0 { - limits.MaxTotalMemory = user.Quota.MaxTotalMemory + if user.Quota.MaxCPU != "" { + // Parse MaxCPU as total CPU (both per-session and total) + cpu, err := ParseResourceQuantity(user.Quota.MaxCPU, "cpu") + if err == nil && cpu > 0 { + limits.MaxCPUPerSession = cpu + limits.MaxTotalCPU = cpu + } } - if user.Quota.MaxStorage > 0 { - limits.MaxStorage = user.Quota.MaxStorage + if user.Quota.MaxMemory != "" { + // Parse MaxMemory as total memory (both per-session and total) + memory, err := ParseResourceQuantity(user.Quota.MaxMemory, "memory") + if err == nil && memory > 0 { + limits.MaxMemoryPerSession = memory + limits.MaxTotalMemory = memory + } } - if user.Quota.MaxGPUPerSession >= 0 { - limits.MaxGPUPerSession = user.Quota.MaxGPUPerSession + if user.Quota.MaxStorage != "" { + // Parse MaxStorage + storage, err := ParseResourceQuantity(user.Quota.MaxStorage, "memory") + if err == nil && storage > 0 { + limits.MaxStorage = storage + } } } @@ -243,23 +248,29 @@ func (e *Enforcer) GetUserLimits(ctx context.Context, username string) (*Limits, if group.Quota.MaxSessions > 0 && group.Quota.MaxSessions < limits.MaxSessions { limits.MaxSessions = group.Quota.MaxSessions } - if group.Quota.MaxCPUPerSession > 0 && group.Quota.MaxCPUPerSession < limits.MaxCPUPerSession { - limits.MaxCPUPerSession = group.Quota.MaxCPUPerSession - } - if group.Quota.MaxMemoryPerSession > 0 && group.Quota.MaxMemoryPerSession < limits.MaxMemoryPerSession { - limits.MaxMemoryPerSession = group.Quota.MaxMemoryPerSession - } - if group.Quota.MaxTotalCPU > 0 && group.Quota.MaxTotalCPU < limits.MaxTotalCPU { - limits.MaxTotalCPU = group.Quota.MaxTotalCPU - } - if group.Quota.MaxTotalMemory > 0 && group.Quota.MaxTotalMemory < limits.MaxTotalMemory { - limits.MaxTotalMemory = group.Quota.MaxTotalMemory + if group.Quota.MaxCPU != "" { + cpu, err := ParseResourceQuantity(group.Quota.MaxCPU, "cpu") + if err == nil && cpu > 0 && cpu < limits.MaxCPUPerSession { + limits.MaxCPUPerSession = cpu + } + if err == nil && cpu > 0 && cpu < limits.MaxTotalCPU { + limits.MaxTotalCPU = cpu + } } - if group.Quota.MaxStorage > 0 && group.Quota.MaxStorage < limits.MaxStorage { - limits.MaxStorage = group.Quota.MaxStorage + if group.Quota.MaxMemory != "" { + memory, err := ParseResourceQuantity(group.Quota.MaxMemory, "memory") + if err == nil && memory > 0 && memory < limits.MaxMemoryPerSession { + limits.MaxMemoryPerSession = memory + } + if err == nil && memory > 0 && memory < limits.MaxTotalMemory { + limits.MaxTotalMemory = memory + } } - if group.Quota.MaxGPUPerSession >= 0 && group.Quota.MaxGPUPerSession < limits.MaxGPUPerSession { - limits.MaxGPUPerSession = group.Quota.MaxGPUPerSession + if group.Quota.MaxStorage != "" { + storage, err := ParseResourceQuantity(group.Quota.MaxStorage, "memory") + if err == nil && storage > 0 && storage < limits.MaxStorage { + limits.MaxStorage = storage + } } } }