Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ func corsMiddleware() gin.HandlerFunc {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}

c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, Upgrade, Connection, Sec-WebSocket-Key, Sec-WebSocket-Version, Sec-WebSocket-Extensions, Sec-WebSocket-Protocol")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, PATCH, DELETE")

if c.Request.Method == "OPTIONS" {
Expand Down
16 changes: 7 additions & 9 deletions api/internal/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,13 @@ func NewAuthHandler(userDB *db.UserDB, jwtManager *JWTManager, samlAuth *SAMLAut

// RegisterRoutes registers authentication routes
func (h *AuthHandler) RegisterRoutes(router *gin.RouterGroup) {
auth := router.Group("/auth")
{
auth.POST("/login", h.Login)
auth.POST("/refresh", h.RefreshToken)
auth.POST("/logout", h.Logout)
auth.GET("/saml/login", h.SAMLLogin)
auth.POST("/saml/acs", h.SAMLCallback)
auth.GET("/saml/metadata", h.SAMLMetadata)
}
// Note: router is already /api/v1/auth from main.go
router.POST("/login", h.Login)
router.POST("/refresh", h.RefreshToken)
router.POST("/logout", h.Logout)
router.GET("/saml/login", h.SAMLLogin)
router.POST("/saml/acs", h.SAMLCallback)
router.GET("/saml/metadata", h.SAMLMetadata)
}

// LoginRequest represents a login request
Expand Down
24 changes: 24 additions & 0 deletions api/internal/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,17 @@ import (
// Middleware creates an authentication middleware
func Middleware(jwtManager *JWTManager, userDB *db.UserDB) gin.HandlerFunc {
return func(c *gin.Context) {
// Check if this is a WebSocket upgrade request
isWebSocket := c.GetHeader("Upgrade") == "websocket" && c.GetHeader("Connection") == "Upgrade"

// Extract token from Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// For WebSocket, abort without writing response (let upgrader handle it)
if isWebSocket {
c.AbortWithStatus(http.StatusUnauthorized)
return
}
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Authorization header required",
})
Expand All @@ -156,6 +164,10 @@ func Middleware(jwtManager *JWTManager, userDB *db.UserDB) gin.HandlerFunc {
// Check Bearer prefix
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
if isWebSocket {
c.AbortWithStatus(http.StatusUnauthorized)
return
}
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid authorization header format. Use: Bearer <token>",
})
Expand All @@ -168,6 +180,10 @@ func Middleware(jwtManager *JWTManager, userDB *db.UserDB) gin.HandlerFunc {
// Validate token
claims, err := jwtManager.ValidateToken(tokenString)
if err != nil {
if isWebSocket {
c.AbortWithStatus(http.StatusUnauthorized)
return
}
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid or expired token",
"message": err.Error(),
Expand All @@ -179,6 +195,10 @@ func Middleware(jwtManager *JWTManager, userDB *db.UserDB) gin.HandlerFunc {
// Verify user still exists and is active
user, err := userDB.GetUser(c.Request.Context(), claims.UserID)
if err != nil {
if isWebSocket {
c.AbortWithStatus(http.StatusUnauthorized)
return
}
c.JSON(http.StatusUnauthorized, gin.H{
"error": "User not found",
})
Expand All @@ -187,6 +207,10 @@ func Middleware(jwtManager *JWTManager, userDB *db.UserDB) gin.HandlerFunc {
}

if !user.Active {
if isWebSocket {
c.AbortWithStatus(http.StatusForbidden)
return
}
c.JSON(http.StatusForbidden, gin.H{
"error": "User account is disabled",
})
Expand Down
29 changes: 0 additions & 29 deletions api/internal/handlers/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@
package handlers

import (
"bytes"
"database/sql"
"fmt"
"io"
"net/http"
"regexp"

Expand Down Expand Up @@ -81,10 +79,6 @@ type SetupStatusResponse struct {
func (h *SetupHandler) GetSetupStatus(c *gin.Context) {
setupRequired, adminExists, hasPassword := h.isSetupRequired()

// Debug logging
fmt.Printf("DEBUG GetSetupStatus: setupRequired=%v, adminExists=%v, hasPassword=%v\n",
setupRequired, adminExists, hasPassword)

var message string
if setupRequired {
message = "Setup wizard is available - admin account needs password configuration"
Expand All @@ -101,9 +95,6 @@ func (h *SetupHandler) GetSetupStatus(c *gin.Context) {
Message: message,
}

// Debug logging
fmt.Printf("DEBUG GetSetupStatus response: %+v\n", response)

c.JSON(http.StatusOK, response)
}

Expand All @@ -116,22 +107,17 @@ func (h *SetupHandler) isSetupRequired() (bool, bool, bool) {
if err != nil {
if err == sql.ErrNoRows {
// Admin user doesn't exist yet
fmt.Printf("DEBUG isSetupRequired: Admin user not found (sql.ErrNoRows)\n")
return false, false, false
}
// Database error - don't allow setup
fmt.Printf("DEBUG isSetupRequired: Database error: %v\n", err)
return false, true, false
}

// Admin exists, check if password is set
hasPassword := passwordHash.Valid && passwordHash.String != ""
fmt.Printf("DEBUG isSetupRequired: Admin found - passwordHash.Valid=%v, passwordHash.String=%q, hasPassword=%v\n",
passwordHash.Valid, passwordHash.String, hasPassword)

// Setup required if admin exists but has no password
setupRequired := !hasPassword
fmt.Printf("DEBUG isSetupRequired: setupRequired=%v\n", setupRequired)
return setupRequired, true, hasPassword
}

Expand Down Expand Up @@ -175,10 +161,6 @@ func (h *SetupHandler) SetupAdmin(c *gin.Context) {
// Check if setup is allowed
setupRequired, adminExists, hasPassword := h.isSetupRequired()

// Debug logging
fmt.Printf("DEBUG SetupAdmin: setupRequired=%v, adminExists=%v, hasPassword=%v\n",
setupRequired, adminExists, hasPassword)

if !setupRequired {
if !adminExists {
c.JSON(http.StatusForbidden, gin.H{
Expand All @@ -196,27 +178,16 @@ func (h *SetupHandler) SetupAdmin(c *gin.Context) {
}
}

// Debug: Log request body
bodyBytes, _ := c.GetRawData()
fmt.Printf("DEBUG SetupAdmin request body: %s\n", string(bodyBytes))
fmt.Printf("DEBUG SetupAdmin Content-Type: %s\n", c.ContentType())
// Restore body for binding
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

// Parse and validate request
var req SetupAdminRequest
if err := c.ShouldBindJSON(&req); err != nil {
fmt.Printf("DEBUG SetupAdmin bind error: %v\n", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": "Invalid request format",
"details": err.Error(),
})
return
}

fmt.Printf("DEBUG SetupAdmin parsed: password=%d chars, email=%s\n",
len(req.Password), req.Email)

// Validate password confirmation
if req.Password != req.PasswordConfirm {
c.JSON(http.StatusBadRequest, gin.H{
Expand Down
32 changes: 19 additions & 13 deletions api/internal/handlers/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,25 +324,31 @@ func checkWebSocketOrigin(r *http.Request) bool {
return true
}

// Get allowed origins from environment variable
// Format: ALLOWED_ORIGINS=https://app.streamspace.io,https://streamspace.io
allowedOriginsEnv := os.Getenv("ALLOWED_ORIGINS")
// Get allowed origins from environment variable (same as CORS middleware)
// Format: CORS_ALLOWED_ORIGINS=https://app.streamspace.io,https://streamspace.io
allowedOriginsEnv := os.Getenv("CORS_ALLOWED_ORIGINS")

var allowedOrigins []string
if allowedOriginsEnv != "" {
allowedOrigins := strings.Split(allowedOriginsEnv, ",")
for _, allowed := range allowedOrigins {
if strings.TrimSpace(allowed) == origin {
return true
}
// Parse comma-separated list of origins
for _, allowedOrigin := range strings.Split(allowedOriginsEnv, ",") {
allowedOrigins = append(allowedOrigins, strings.TrimSpace(allowedOrigin))
}
}

// Check if origin matches the request host (same-origin)
requestHost := r.Host
if strings.HasPrefix(origin, "http://"+requestHost) || strings.HasPrefix(origin, "https://"+requestHost) {
return true
// If no origins specified, use localhost only for development (same as CORS middleware)
if len(allowedOrigins) == 0 {
allowedOrigins = []string{"http://localhost:3000", "http://localhost:8000"}
}

// Check if origin is in allowed list
for _, allowed := range allowedOrigins {
if origin == allowed {
return true
}
}

// Allow localhost and 127.0.0.1 for development
// Also allow any localhost or 127.0.0.1 origin for development
if strings.Contains(origin, "localhost") || strings.Contains(origin, "127.0.0.1") {
return true
}
Expand Down
16 changes: 15 additions & 1 deletion api/internal/middleware/inputvalidation.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@
package middleware

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strings"
Expand Down Expand Up @@ -114,8 +117,19 @@ func (v *InputValidator) SanitizeJSONMiddleware() gin.HandlerFunc {
return
}

// Read and preserve the request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
c.Next()
return
}

// Restore the body for handlers to read
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

// Try to parse as JSON map
var data map[string]interface{}
if err := c.ShouldBindJSON(&data); err != nil {
if err := json.Unmarshal(bodyBytes, &data); err != nil {
// If it's not a map, let it pass to the handler which will validate properly
c.Next()
return
Expand Down
Loading