Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion api/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,5 @@ require (
sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 // indirect
sigs.k8s.io/randfill v1.0.0 // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect
sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect
sigs.k8s.io/yaml v1.6.0 // indirect
)
20 changes: 20 additions & 0 deletions api/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -255,16 +255,34 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
k8s.io/api v0.29.11 h1:6FwDo33f1WX5Yu0RQTX9YAd3wth8Ik0B4SXQKsoQfbk=
k8s.io/api v0.29.11/go.mod h1:3TDAW1OpFbz/Yx5r0W06b6eiAfHEwtH61VYDzpTU4Ng=
k8s.io/api v0.30.8 h1:Y+yZRF3c1WC0MTkLe0qBkiLCquRNa4I21/iDioGMCbo=
k8s.io/api v0.30.8/go.mod h1:89IE5MzirZ5HHxU/Hq1/KWGqXkhXClu/FHGesFhQ0A4=
k8s.io/api v0.31.0 h1:b9LiSjR2ym/SzTOlfMHm1tr7/21aD7fSkqgD/CVJBCo=
k8s.io/api v0.31.0/go.mod h1:0YiFF+JfFxMM6+1hQei8FY8M7s1Mth+z/q7eF1aJkTE=
k8s.io/apimachinery v0.29.11 h1:55+6ue9advpA7T0sX2ZJDHCLKuiFfrAAR/39VQN9KEQ=
k8s.io/apimachinery v0.29.11/go.mod h1:i3FJVwhvSp/6n8Fl4K97PJEP8C+MM+aoDq4+ZJBf70Y=
k8s.io/apimachinery v0.30.8 h1:9jyTItYzmJc00cBDxZC5ArFNxUeKCwbw0m760iFUMKY=
k8s.io/apimachinery v0.30.8/go.mod h1:iexa2somDaxdnj7bha06bhb43Zpa6eWH8N8dbqVjTUc=
k8s.io/apimachinery v0.31.0 h1:m9jOiSr3FoSSL5WO9bjm1n6B9KROYYgNZOb4tyZ1lBc=
k8s.io/apimachinery v0.31.0/go.mod h1:rsPdaZJfTfLsNJSQzNHQvYoTmxhoOEofxtOsF3rtsMo=
k8s.io/client-go v0.29.11 h1:mBX7Ub0uqpLMwWz3J/AGS/xKOZsjr349qZ1vxVoL1l8=
k8s.io/client-go v0.29.11/go.mod h1:WOEoi/eLg2YEg3/yEd7YK3CNScYkM8AEScQadxUnaTE=
k8s.io/client-go v0.30.8 h1:fC1SQMZm7bSWiVv9ydN+nv+sqGVAxMdf/5eKUVffNJE=
k8s.io/client-go v0.30.8/go.mod h1:daF3UcGVqGPHvH5mn/ESkp/VoR8i9tg9IBfKr+AeFYo=
k8s.io/client-go v0.31.0 h1:QqEJzNjbN2Yv1H79SsS+SWnXkBgVu4Pj3CJQgbx0gI8=
k8s.io/client-go v0.31.0/go.mod h1:Y9wvC76g4fLjmU0BA+rV+h2cncoadjvjjkkIGoTLcGU=
k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 h1:BZqlfIlq5YbRMFko6/PM7FjZpUb45WallggurYhKGag=
k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340/go.mod h1:yD4MZYeKMBwQKVht279WycxKyM84kkAx2DPrTXaeb98=
k8s.io/kube-openapi v0.0.0-20250710124328-f3f2b991d03b h1:MloQ9/bdJyIu9lb1PzujOPolHyvO06MXG5TUIj2mNAA=
k8s.io/kube-openapi v0.0.0-20250710124328-f3f2b991d03b/go.mod h1:UZ2yyWbFTpuhSbFhv24aGNOdoRdJZgsIObGBUaYVsts=
k8s.io/metrics v0.29.11 h1:GtIYoM6uKUQIygnUpVEYE4X12xggwWmnYe443TZumcg=
k8s.io/metrics v0.29.11/go.mod h1:aroO9WuHjPwp4Z6ACFG7aCZf3m8Un+Vn3ogauOlJHrE=
k8s.io/metrics v0.30.8 h1:oM7kdmGjnahww89CLSF95XHC5dWtb05+cgplUeu/fqE=
k8s.io/metrics v0.30.8/go.mod h1:fbEC4z4Q5uwGXtfJIGQyoq3ndBCORYopDA0oXReDk2I=
k8s.io/metrics v0.31.0 h1:s7Vu7W0oEZPTN8jgcoiWIXIZBmVxt7YP9MRVyIgMdOc=
k8s.io/metrics v0.31.0/go.mod h1:UNsz6swyX8FWkDoKN9ixPF75TBREMbHZIKjD7fydaOY=
k8s.io/utils v0.0.0-20250604170112-4c0f3b243397 h1:hwvWFiBzdWw1FhfY1FooPn3kzWuJ8tmbZBHi4zVsl1Y=
Expand All @@ -276,6 +294,8 @@ sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU=
sigs.k8s.io/randfill v1.0.0/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY=
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4=
sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08=
sigs.k8s.io/structured-merge-diff/v6 v6.2.0 h1:msyqjP8Nyd5sF3QSmJouFSzcBIdwq4ct8d1/7VSBHIQ=
sigs.k8s.io/structured-merge-diff/v6 v6.2.0/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE=
sigs.k8s.io/structured-merge-diff/v6 v6.3.0 h1:jTijUJbW353oVOd9oTlifJqOGEkUw2jB/fXCbTiQEco=
sigs.k8s.io/structured-merge-diff/v6 v6.3.0/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE=
sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
Expand Down
13 changes: 12 additions & 1 deletion api/internal/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ package auth

import (
"context"
"encoding/xml"
"fmt"
"log"
"net/http"
"time"
Expand Down Expand Up @@ -473,9 +475,18 @@ func (h *AuthHandler) SAMLMetadata(c *gin.Context) {
// Generate metadata XML
metadata := sp.Metadata()

// Marshal to XML bytes
metadataBytes, err := xml.Marshal(metadata)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to marshal metadata: %v", err),
})
return
}

// Return XML with proper content type
c.Header("Content-Type", "application/samlmetadata+xml")
c.String(http.StatusOK, string(metadata))
c.String(http.StatusOK, string(metadataBytes))
}

// PasswordChangeRequest represents a password change request
Expand Down
1 change: 0 additions & 1 deletion api/internal/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@
package auth

import (
"context"
"net/http"
"strings"

Expand Down
1 change: 1 addition & 0 deletions api/internal/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ package auth

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"log"
Expand Down
101 changes: 78 additions & 23 deletions api/internal/auth/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,11 @@ func (sa *SAMLAuthenticator) GetMiddleware() *samlsp.Middleware {
return sa.middleware
}

// GetServiceProvider returns the SAML Service Provider instance.
func (sa *SAMLAuthenticator) GetServiceProvider() *saml.ServiceProvider {
return sa.serviceProvider
}

// ExtractUserFromAssertion extracts user information from a SAML assertion.
//
// After the IdP authenticates a user, it sends a SAML assertion containing:
Expand Down Expand Up @@ -815,6 +820,70 @@ func (sa *SAMLAuthenticator) ExtractUserFromAssertion(assertion *saml.Assertion)
return user, nil
}

// ExtractUserFromAttributes extracts user information from SAML session attributes.
//
// This is a simpler version of ExtractUserFromAssertion that works with the
// attributes map returned by SessionWithAttributes.GetAttributes().
//
// The attributes map contains key-value pairs from the SAML assertion's
// AttributeStatements, already parsed and ready to use.
func (sa *SAMLAuthenticator) ExtractUserFromAttributes(attributes samlsp.Attributes) (*UserInfo, error) {
if attributes == nil {
return nil, fmt.Errorf("attributes map is nil")
}

// Initialize user object
user := &UserInfo{
Attributes: make(map[string]interface{}),
}

// Helper function to get first attribute value
getAttribute := func(key string) string {
return attributes.Get(key)
}

// Helper function to get all attribute values
getAttributes := func(key string) []string {
if vals, ok := attributes[key]; ok {
return vals
}
return nil
}

// Extract mapped attributes
if sa.config.AttributeMapping.Email != "" {
user.Email = getAttribute(sa.config.AttributeMapping.Email)
}
if sa.config.AttributeMapping.Username != "" {
user.Username = getAttribute(sa.config.AttributeMapping.Username)
}
if sa.config.AttributeMapping.FirstName != "" {
user.FirstName = getAttribute(sa.config.AttributeMapping.FirstName)
}
if sa.config.AttributeMapping.LastName != "" {
user.LastName = getAttribute(sa.config.AttributeMapping.LastName)
}
if sa.config.AttributeMapping.Groups != "" {
user.Groups = getAttributes(sa.config.AttributeMapping.Groups)
}

// Store all attributes in Attributes map for custom use cases
for key, values := range attributes {
if len(values) == 1 {
user.Attributes[key] = values[0]
} else {
user.Attributes[key] = values
}
}

// Validate required fields
if user.Username == "" {
return nil, fmt.Errorf("username not found in SAML attributes")
}

return user, nil
}

// GinMiddleware returns a Gin middleware function that enforces SAML authentication.
//
// This middleware protects routes by requiring valid SAML authentication. It:
Expand Down Expand Up @@ -930,40 +999,26 @@ func (sa *SAMLAuthenticator) GinMiddleware() gin.HandlerFunc {
return
}

// STEP 2: Extract assertion from session
// STEP 2: Extract attributes from session
//
// The session contains the SAML assertion that was validated during login.
// We need to retrieve it to extract user attributes.
// The session contains the SAML attributes that were extracted during login.
// We need to retrieve them to get user information.
//
// WHY TYPE ASSERTION: Session interface doesn't expose attributes directly,
// must cast to SessionWithAttributes to access GetAttributes()
assertion := session.(samlsp.SessionWithAttributes).GetAttributes()
if assertion == nil {
// Session exists but has no assertion - corrupted session
attributes := session.(samlsp.SessionWithAttributes).GetAttributes()
if attributes == nil {
// Session exists but has no attributes - corrupted session
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid SAML session"})
c.Abort()
return
}

// STEP 3: Convert to SAML Assertion type
//
// GetAttributes() returns interface{}, we need to type-assert to *saml.Assertion
// to access assertion fields like AttributeStatements.
//
// WHY TYPE CHECK: Ensures we have the correct type before using assertion methods
samlAssertion, ok := assertion.(*saml.Assertion)
if !ok {
// Assertion is not the expected type - should never happen in normal flow
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid assertion type"})
c.Abort()
return
}

// STEP 4: Extract user information from assertion
// STEP 3: Extract user information from attributes
//
// Parse SAML assertion attributes and map them to StreamSpace user fields.
// Parse SAML attributes and map them to StreamSpace user fields.
// This uses the configured AttributeMapping to translate IdP attributes.
user, err := sa.ExtractUserFromAssertion(samlAssertion)
user, err := sa.ExtractUserFromAttributes(attributes)
if err != nil {
// Failed to extract required user fields (usually missing username)
// This indicates IdP misconfiguration or incorrect attribute mapping
Expand Down
3 changes: 2 additions & 1 deletion api/internal/middleware/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ package middleware
import (
"compress/gzip"
"io"
"net/http"
"strings"
"sync"

Expand Down Expand Up @@ -133,7 +134,7 @@ func Gzip(level int) gin.HandlerFunc {
}

// shouldCompress determines if the response should be compressed
func shouldCompress(r *gin.Context.Request) bool {
func shouldCompress(r *http.Request) bool {
// Check if client accepts gzip
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
return false
Expand Down
16 changes: 8 additions & 8 deletions api/internal/middleware/inputvalidation.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,16 +273,16 @@ func (v *InputValidator) sanitizeMap(data map[string]interface{}) map[string]int
result := make(map[string]interface{})

for key, value := range data {
switch v := value.(type) {
switch val := value.(type) {
case string:
// Sanitize string values using bluemonday
result[key] = v.sanitizer.Sanitize(v)
result[key] = v.sanitizer.Sanitize(val)
case map[string]interface{}:
// Recursively sanitize nested maps
result[key] = v.sanitizeMap(v)
result[key] = v.sanitizeMap(val)
case []interface{}:
// Sanitize arrays
result[key] = v.sanitizeArray(v)
result[key] = v.sanitizeArray(val)
default:
// Keep other types as-is (numbers, booleans, etc.)
result[key] = value
Expand All @@ -297,13 +297,13 @@ func (v *InputValidator) sanitizeArray(data []interface{}) []interface{} {
result := make([]interface{}, len(data))

for i, value := range data {
switch v := value.(type) {
switch val := value.(type) {
case string:
result[i] = v.sanitizer.Sanitize(v)
result[i] = v.sanitizer.Sanitize(val)
case map[string]interface{}:
result[i] = v.sanitizeMap(v)
result[i] = v.sanitizeMap(val)
case []interface{}:
result[i] = v.sanitizeArray(v)
result[i] = v.sanitizeArray(val)
default:
result[i] = value
}
Expand Down
71 changes: 41 additions & 30 deletions api/internal/quota/enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,23 +210,28 @@ func (e *Enforcer) GetUserLimits(ctx context.Context, username string) (*Limits,
if user.Quota.MaxSessions > 0 {
limits.MaxSessions = user.Quota.MaxSessions
}
if user.Quota.MaxCPUPerSession > 0 {
limits.MaxCPUPerSession = user.Quota.MaxCPUPerSession
}
if user.Quota.MaxMemoryPerSession > 0 {
limits.MaxMemoryPerSession = user.Quota.MaxMemoryPerSession
}
if user.Quota.MaxTotalCPU > 0 {
limits.MaxTotalCPU = user.Quota.MaxTotalCPU
}
if user.Quota.MaxTotalMemory > 0 {
limits.MaxTotalMemory = user.Quota.MaxTotalMemory
if user.Quota.MaxCPU != "" {
// Parse MaxCPU as total CPU (both per-session and total)
cpu, err := ParseResourceQuantity(user.Quota.MaxCPU, "cpu")
if err == nil && cpu > 0 {
limits.MaxCPUPerSession = cpu
limits.MaxTotalCPU = cpu
}
}
if user.Quota.MaxStorage > 0 {
limits.MaxStorage = user.Quota.MaxStorage
if user.Quota.MaxMemory != "" {
// Parse MaxMemory as total memory (both per-session and total)
memory, err := ParseResourceQuantity(user.Quota.MaxMemory, "memory")
if err == nil && memory > 0 {
limits.MaxMemoryPerSession = memory
limits.MaxTotalMemory = memory
}
}
if user.Quota.MaxGPUPerSession >= 0 {
limits.MaxGPUPerSession = user.Quota.MaxGPUPerSession
if user.Quota.MaxStorage != "" {
// Parse MaxStorage
storage, err := ParseResourceQuantity(user.Quota.MaxStorage, "memory")
if err == nil && storage > 0 {
limits.MaxStorage = storage
}
}
}

Expand All @@ -243,23 +248,29 @@ func (e *Enforcer) GetUserLimits(ctx context.Context, username string) (*Limits,
if group.Quota.MaxSessions > 0 && group.Quota.MaxSessions < limits.MaxSessions {
limits.MaxSessions = group.Quota.MaxSessions
}
if group.Quota.MaxCPUPerSession > 0 && group.Quota.MaxCPUPerSession < limits.MaxCPUPerSession {
limits.MaxCPUPerSession = group.Quota.MaxCPUPerSession
}
if group.Quota.MaxMemoryPerSession > 0 && group.Quota.MaxMemoryPerSession < limits.MaxMemoryPerSession {
limits.MaxMemoryPerSession = group.Quota.MaxMemoryPerSession
}
if group.Quota.MaxTotalCPU > 0 && group.Quota.MaxTotalCPU < limits.MaxTotalCPU {
limits.MaxTotalCPU = group.Quota.MaxTotalCPU
}
if group.Quota.MaxTotalMemory > 0 && group.Quota.MaxTotalMemory < limits.MaxTotalMemory {
limits.MaxTotalMemory = group.Quota.MaxTotalMemory
if group.Quota.MaxCPU != "" {
cpu, err := ParseResourceQuantity(group.Quota.MaxCPU, "cpu")
if err == nil && cpu > 0 && cpu < limits.MaxCPUPerSession {
limits.MaxCPUPerSession = cpu
}
if err == nil && cpu > 0 && cpu < limits.MaxTotalCPU {
limits.MaxTotalCPU = cpu
}
}
if group.Quota.MaxStorage > 0 && group.Quota.MaxStorage < limits.MaxStorage {
limits.MaxStorage = group.Quota.MaxStorage
if group.Quota.MaxMemory != "" {
memory, err := ParseResourceQuantity(group.Quota.MaxMemory, "memory")
if err == nil && memory > 0 && memory < limits.MaxMemoryPerSession {
limits.MaxMemoryPerSession = memory
}
if err == nil && memory > 0 && memory < limits.MaxTotalMemory {
limits.MaxTotalMemory = memory
}
}
if group.Quota.MaxGPUPerSession >= 0 && group.Quota.MaxGPUPerSession < limits.MaxGPUPerSession {
limits.MaxGPUPerSession = group.Quota.MaxGPUPerSession
if group.Quota.MaxStorage != "" {
storage, err := ParseResourceQuantity(group.Quota.MaxStorage, "memory")
if err == nil && storage > 0 && storage < limits.MaxStorage {
limits.MaxStorage = storage
}
}
}
}
Expand Down
Loading