diff --git a/Makefile b/Makefile index 5603a9a..b75352f 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: backend frontend build docker test test-frontend test-all clean dev-backend dev-frontend dev-bundle lint lint-frontend lint-backend security govuln vuln secrets audit hooks check docs e2e-fresh stop -BACKEND_ENV ?= LISTEN_ADDR=:9096 +BACKEND_ENV ?= LISTEN_ADDR=:9096 COOKIE_SECURE=false BIN_DIR ?= $(PWD)/bin BINARY ?= $(BIN_DIR)/warden diff --git a/internal/api/auth.go b/internal/api/auth.go index af3746d..b1c4d1c 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -103,7 +103,7 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { } // Set Cookie - http.SetCookie(w, &http.Cookie{ + http.SetCookie(w, &http.Cookie{ // #nosec G124 -- Secure defaults true; configurable for local HTTP dev Name: "auth_token", Value: token, Expires: expiresAt, @@ -113,11 +113,38 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { Secure: h.config.CookieSecure, }) + // Fetch full user details for the response (avatar, display name, etc.) + fullUser, _ := h.store.GetUser(user.ID) + avatar := "" + displayName := user.Username + email := "" + timezone := "UTC" + ssoProvider := "" + if fullUser != nil { + displayName = fullUser.DisplayName + if displayName == "" { + displayName = fullUser.Username + } + email = fullUser.Email + timezone = fullUser.Timezone + ssoProvider = fullUser.SSOProvider + avatar = fullUser.AvatarURL + if avatar == "" { + avatar = "https://ui-avatars.com/api/?name=" + url.QueryEscape(displayName) + "&background=random" + } + } + writeJSON(w, http.StatusOK, map[string]any{ "message": "logged in", "user": map[string]any{ - "username": user.Username, - "id": user.ID, + "username": user.Username, + "id": user.ID, + "role": user.Role, + "displayName": displayName, + "email": email, + "timezone": timezone, + "ssoProvider": ssoProvider, + "avatar": avatar, }, }) } @@ -129,7 +156,7 @@ func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) { } // Clear Cookie - http.SetCookie(w, &http.Cookie{ + http.SetCookie(w, &http.Cookie{ // #nosec G124 -- Secure defaults true; configurable for local HTTP dev Name: "auth_token", Value: "", Expires: time.Now().Add(-1 * time.Hour), @@ -183,6 +210,7 @@ func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) { "ssoProvider": user.SSOProvider, "avatar": avatar, "displayName": displayName, + "role": user.Role, }, }) } @@ -247,25 +275,45 @@ func (h *AuthHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, map[string]string{"message": "settings updated"}) } -// IsAuthenticated checks whether a request has a valid session cookie or API key -// without writing a response. Used by handlers that need optional auth checks. -func (h *AuthHandler) IsAuthenticated(r *http.Request) bool { +// AuthInfo holds identity information extracted from a request. +type AuthInfo struct { + Authenticated bool + UserID int64 + Role string +} + +// GetAuthInfo extracts authentication details from a request without writing a response. +// Returns user ID, role, and whether the request is authenticated. +func (h *AuthHandler) GetAuthInfo(r *http.Request) AuthInfo { // Check Bearer token authHeader := r.Header.Get("Authorization") if len(authHeader) > 7 && authHeader[:7] == "Bearer " { token := authHeader[7:] - valid, err := h.store.ValidateAPIKey(token) + valid, role, err := h.store.ValidateAPIKey(token) if err == nil && valid { - return true + return AuthInfo{Authenticated: true, UserID: APIKeyUserID, Role: role} } } // Check session cookie c, err := r.Cookie("auth_token") if err != nil { - return false + return AuthInfo{} } sess, err := h.store.GetSession(c.Value) - return err == nil && sess != nil + if err != nil || sess == nil { + return AuthInfo{} + } + role, err := h.store.GetUserRole(sess.UserID) + if err != nil { + return AuthInfo{} + } + return AuthInfo{Authenticated: true, UserID: sess.UserID, Role: role} +} + +// IsAuthenticated checks whether a request has a valid session cookie or API key +// without writing a response. Used by handlers that need optional auth checks. +func (h *AuthHandler) IsAuthenticated(r *http.Request) bool { + return h.GetAuthInfo(r).Authenticated } // Middleware @@ -276,11 +324,12 @@ func (h *AuthHandler) AuthMiddleware(next http.Handler) http.Handler { authHeader := r.Header.Get("Authorization") if len(authHeader) > 7 && authHeader[:7] == "Bearer " { token := authHeader[7:] - valid, err := h.store.ValidateAPIKey(token) + valid, role, err := h.store.ValidateAPIKey(token) if err == nil && valid { // Valid API Key - use special negative ID to distinguish from real users // SECURITY: APIKeyUserID (-1) prevents confusion with real user IDs ctx := context.WithValue(r.Context(), contextKeyUserID, APIKeyUserID) + ctx = context.WithValue(ctx, contextKeyUserRole, role) next.ServeHTTP(w, r.WithContext(ctx)) return } @@ -301,8 +350,16 @@ func (h *AuthHandler) AuthMiddleware(next http.Handler) http.Handler { return } - // 4. Inject UserID into Context + // 4. Fetch user role + role, err := h.store.GetUserRole(sess.UserID) + if err != nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + // 5. Inject UserID and Role into Context ctx := context.WithValue(r.Context(), contextKeyUserID, sess.UserID) + ctx = context.WithValue(ctx, contextKeyUserRole, role) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/internal/api/handlers_apikeys.go b/internal/api/handlers_apikeys.go index 4e01b25..6cf30a0 100644 --- a/internal/api/handlers_apikeys.go +++ b/internal/api/handlers_apikeys.go @@ -25,6 +25,9 @@ func NewAPIKeyHandler(store *db.Store) *APIKeyHandler { // @Success 200 {object} object{keys=[]db.APIKey} // @Router /api-keys [get] func (h *APIKeyHandler) ListKeys(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } keys, err := h.store.ListAPIKeys() if err != nil { writeError(w, http.StatusInternalServerError, "failed to list keys") @@ -44,8 +47,13 @@ func (h *APIKeyHandler) ListKeys(w http.ResponseWriter, r *http.Request) { // @Failure 400 {object} object{error=string} "Name is required" // @Router /api-keys [post] func (h *APIKeyHandler) CreateKey(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } + var req struct { Name string `json:"name"` + Role string `json:"role"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeError(w, http.StatusBadRequest, "invalid request") @@ -55,8 +63,15 @@ func (h *APIKeyHandler) CreateKey(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "name is required") return } + if req.Role != "" && !ValidRole(req.Role) { + writeError(w, http.StatusBadRequest, "invalid role") + return + } + if req.Role == "" { + req.Role = RoleEditor + } - rawKey, err := h.store.CreateAPIKey(req.Name) + rawKey, err := h.store.CreateAPIKey(req.Name, req.Role) if err != nil { writeError(w, http.StatusInternalServerError, "failed to create key") return @@ -79,6 +94,9 @@ func (h *APIKeyHandler) CreateKey(w http.ResponseWriter, r *http.Request) { // @Failure 400 {object} object{error=string} "Invalid ID" // @Router /api-keys/{id} [delete] func (h *APIKeyHandler) DeleteKey(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } idStr := chi.URLParam(r, "id") id, err := strconv.ParseInt(idStr, 10, 64) if err != nil { diff --git a/internal/api/handlers_apikeys_test.go b/internal/api/handlers_apikeys_test.go index fa129e2..edfd8bb 100644 --- a/internal/api/handlers_apikeys_test.go +++ b/internal/api/handlers_apikeys_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -16,6 +17,7 @@ func TestAPIKeysHandler(t *testing.T) { // List Empty req := httptest.NewRequest("GET", "/api/api-keys", nil) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) w := httptest.NewRecorder() h.ListKeys(w, req) if w.Code != http.StatusOK { @@ -26,6 +28,7 @@ func TestAPIKeysHandler(t *testing.T) { payload := map[string]string{"name": "TestKey"} body, _ := json.Marshal(payload) req = httptest.NewRequest("POST", "/api/api-keys", bytes.NewBuffer(body)) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) w = httptest.NewRecorder() h.CreateKey(w, req) diff --git a/internal/api/handlers_auth_test.go b/internal/api/handlers_auth_test.go index 4b13216..cc087a0 100644 --- a/internal/api/handlers_auth_test.go +++ b/internal/api/handlers_auth_test.go @@ -12,7 +12,7 @@ func TestAuthLogin(t *testing.T) { _, _, authH, _, s := setupTest(t) // Setup User - if err := s.CreateUser("admin", "correct-password", "UTC"); err != nil { + if err := s.CreateUser("admin", "correct-password", "UTC", "admin"); err != nil { t.Fatalf("Failed to create user: %v", err) } @@ -79,7 +79,7 @@ func TestAuthMeIntegration(t *testing.T) { _, _, _, router, s := setupTest(t) // Setup User - if err := s.CreateUser("admin", "correct-password", "UTC"); err != nil { + if err := s.CreateUser("admin", "correct-password", "UTC", "admin"); err != nil { t.Fatalf("Failed to create user: %v", err) } diff --git a/internal/api/handlers_crud.go b/internal/api/handlers_crud.go index 793a615..c088b55 100644 --- a/internal/api/handlers_crud.go +++ b/internal/api/handlers_crud.go @@ -66,6 +66,9 @@ const maxNameLength = 255 // @Failure 409 {object} object{error=string} "Group already exists" // @Router /groups [post] func (h *CRUDHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } var req struct { Name string `json:"name"` } @@ -116,6 +119,9 @@ func (h *CRUDHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "ID required" // @Router /groups/{id} [delete] func (h *CRUDHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") if id == "" { http.Error(w, "ID required", http.StatusBadRequest) @@ -141,6 +147,9 @@ func (h *CRUDHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "Name is required" // @Router /groups/{id} [put] func (h *CRUDHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") if id == "" { http.Error(w, "ID required", http.StatusBadRequest) @@ -188,6 +197,9 @@ func (h *CRUDHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { // @Failure 409 {string} string "Monitor name already exists" // @Router /monitors [post] func (h *CRUDHandler) CreateMonitor(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } var req struct { Name string `json:"name"` URL string `json:"url"` @@ -349,6 +361,9 @@ func (h *CRUDHandler) GetGroups(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "ID required" // @Router /monitors/{id} [put] func (h *CRUDHandler) UpdateMonitor(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") if id == "" { http.Error(w, "ID required", http.StatusBadRequest) @@ -407,6 +422,9 @@ func (h *CRUDHandler) UpdateMonitor(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "ID required" // @Router /monitors/{id} [delete] func (h *CRUDHandler) DeleteMonitor(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") if id == "" { http.Error(w, "ID required", http.StatusBadRequest) @@ -432,6 +450,9 @@ func (h *CRUDHandler) DeleteMonitor(w http.ResponseWriter, r *http.Request) { // @Failure 404 {object} object{error=string} "Monitor not found" // @Router /monitors/{id}/pause [post] func (h *CRUDHandler) PauseMonitor(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") if id == "" { writeError(w, http.StatusBadRequest, "ID required") @@ -462,6 +483,9 @@ func (h *CRUDHandler) PauseMonitor(w http.ResponseWriter, r *http.Request) { // @Failure 404 {object} object{error=string} "Monitor not found" // @Router /monitors/{id}/resume [post] func (h *CRUDHandler) ResumeMonitor(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") if id == "" { writeError(w, http.StatusBadRequest, "ID required") diff --git a/internal/api/handlers_incidents.go b/internal/api/handlers_incidents.go index d6ec8b8..f723f5f 100644 --- a/internal/api/handlers_incidents.go +++ b/internal/api/handlers_incidents.go @@ -87,6 +87,9 @@ func incidentToDTO(i db.Incident, updates []db.IncidentUpdate) IncidentResponseD // @Failure 400 {string} string "Invalid request body" // @Router /incidents [post] func (h *IncidentHandler) CreateIncident(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } var req struct { Title string `json:"title"` Description string `json:"description"` @@ -211,6 +214,9 @@ func (h *IncidentHandler) GetIncident(w http.ResponseWriter, r *http.Request) { // @Failure 404 {string} string "Incident not found" // @Router /incidents/{id} [put] func (h *IncidentHandler) UpdateIncident(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") existing, err := h.store.GetIncidentByID(id) @@ -293,6 +299,9 @@ func (h *IncidentHandler) UpdateIncident(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "Failed to delete incident" // @Router /incidents/{id} [delete] func (h *IncidentHandler) DeleteIncident(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") if err := h.store.DeleteIncident(id); err != nil { @@ -317,6 +326,9 @@ func (h *IncidentHandler) DeleteIncident(w http.ResponseWriter, r *http.Request) // @Failure 404 {string} string "Outage not found" // @Router /outages/{id}/promote [post] func (h *IncidentHandler) PromoteOutage(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } idStr := chi.URLParam(r, "id") outageID, err := strconv.ParseInt(idStr, 10, 64) if err != nil { @@ -413,6 +425,9 @@ func (h *IncidentHandler) PromoteOutage(w http.ResponseWriter, r *http.Request) // @Failure 404 {string} string "Incident not found" // @Router /incidents/{id}/visibility [patch] func (h *IncidentHandler) SetVisibility(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") incident, err := h.store.GetIncidentByID(id) @@ -461,6 +476,9 @@ func (h *IncidentHandler) SetVisibility(w http.ResponseWriter, r *http.Request) // @Failure 404 {string} string "Incident not found" // @Router /incidents/{id}/updates [post] func (h *IncidentHandler) AddUpdate(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") incident, err := h.store.GetIncidentByID(id) diff --git a/internal/api/handlers_incidents_test.go b/internal/api/handlers_incidents_test.go index 08b89da..8b5a976 100644 --- a/internal/api/handlers_incidents_test.go +++ b/internal/api/handlers_incidents_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -24,6 +25,7 @@ func TestIncidentHandler(t *testing.T) { } body, _ := json.Marshal(payload) req := httptest.NewRequest("POST", "/api/incidents", bytes.NewBuffer(body)) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) w := httptest.NewRecorder() h.CreateIncident(w, req) diff --git a/internal/api/handlers_maintenance.go b/internal/api/handlers_maintenance.go index d89d78a..f964463 100644 --- a/internal/api/handlers_maintenance.go +++ b/internal/api/handlers_maintenance.go @@ -45,6 +45,9 @@ type MaintenanceResponse struct { // @Failure 500 {string} string "Failed to schedule maintenance" // @Router /maintenance [post] func (h *MaintenanceHandler) CreateMaintenance(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } var req struct { Title string `json:"title"` Description string `json:"description"` @@ -175,6 +178,9 @@ func (h *MaintenanceHandler) GetMaintenance(w http.ResponseWriter, r *http.Reque // @Failure 500 {string} string "Failed to update maintenance" // @Router /maintenance/{id} [put] func (h *MaintenanceHandler) UpdateMaintenance(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") if id == "" { http.Error(w, "Maintenance ID required", http.StatusBadRequest) @@ -267,6 +273,9 @@ func (h *MaintenanceHandler) UpdateMaintenance(w http.ResponseWriter, r *http.Re // @Failure 500 {string} string "Failed to delete maintenance" // @Router /maintenance/{id} [delete] func (h *MaintenanceHandler) DeleteMaintenance(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") if id == "" { http.Error(w, "Maintenance ID required", http.StatusBadRequest) diff --git a/internal/api/handlers_maintenance_test.go b/internal/api/handlers_maintenance_test.go index aae6b4a..fc07bda 100644 --- a/internal/api/handlers_maintenance_test.go +++ b/internal/api/handlers_maintenance_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -31,6 +32,7 @@ func TestCreateMaintenance(t *testing.T) { } body, _ := json.Marshal(payload) req := httptest.NewRequest("POST", "/api/maintenance", bytes.NewBuffer(body)) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) w := httptest.NewRecorder() h.CreateMaintenance(w, req) diff --git a/internal/api/handlers_notifications.go b/internal/api/handlers_notifications.go index 966fd25..a2f058e 100644 --- a/internal/api/handlers_notifications.go +++ b/internal/api/handlers_notifications.go @@ -53,6 +53,9 @@ func (h *NotificationChannelsHandler) GetChannels(w http.ResponseWriter, r *http // @Failure 400 {string} string "Type and Name are required" // @Router /notifications/channels [post] func (h *NotificationChannelsHandler) CreateChannel(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } var body struct { Type string `json:"type"` Name string `json:"name"` @@ -122,6 +125,9 @@ func (h *NotificationChannelsHandler) CreateChannel(w http.ResponseWriter, r *ht // @Failure 400 {string} string "Missing ID" // @Router /notifications/channels/{id} [delete] func (h *NotificationChannelsHandler) DeleteChannel(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") if id == "" { http.Error(w, "Missing ID", http.StatusBadRequest) @@ -167,6 +173,9 @@ func extractWebhookURL(config map[string]interface{}) string { // UpdateChannel modifies an existing notification channel. func (h *NotificationChannelsHandler) UpdateChannel(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } id := chi.URLParam(r, "id") if id == "" { http.Error(w, "Missing ID", http.StatusBadRequest) @@ -226,6 +235,9 @@ func (h *NotificationChannelsHandler) UpdateChannel(w http.ResponseWriter, r *ht // TestChannel sends a test notification through the specified channel type and config. func (h *NotificationChannelsHandler) TestChannel(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } var body struct { Type string `json:"type"` Config map[string]interface{} `json:"config"` diff --git a/internal/api/handlers_notifications_test.go b/internal/api/handlers_notifications_test.go index 7a3e354..67892d4 100644 --- a/internal/api/handlers_notifications_test.go +++ b/internal/api/handlers_notifications_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -37,6 +38,7 @@ func TestGetChannels(t *testing.T) { } req, _ := http.NewRequest("GET", "/api/notifications/channels", nil) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) rr := httptest.NewRecorder() handler.GetChannels(rr, req) @@ -55,6 +57,13 @@ func TestGetChannels(t *testing.T) { } } +func testAdminRoleMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), contextKeyUserRole, RoleAdmin) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func TestCreateChannel(t *testing.T) { store := newTestStore(t) handler := NewNotificationChannelsHandler(store) @@ -68,6 +77,7 @@ func TestCreateChannel(t *testing.T) { body, _ := json.Marshal(payload) req, _ := http.NewRequest("POST", "/api/notifications/channels", bytes.NewBuffer(body)) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) rr := httptest.NewRecorder() handler.CreateChannel(rr, req) @@ -94,6 +104,7 @@ func TestDeleteChannel(t *testing.T) { // Setup CHI router to handle params r := chi.NewRouter() + r.Use(testAdminRoleMiddleware) r.Delete("/notifications/channels/{id}", handler.DeleteChannel) req, _ := http.NewRequest("DELETE", "/notifications/channels/nc1", nil) @@ -124,6 +135,7 @@ func TestUpdateChannel(t *testing.T) { } r := chi.NewRouter() + r.Use(testAdminRoleMiddleware) r.Put("/notifications/channels/{id}", handler.UpdateChannel) payload := map[string]interface{}{ @@ -166,6 +178,7 @@ func TestUpdateChannel_ValidationErrors(t *testing.T) { } r := chi.NewRouter() + r.Use(testAdminRoleMiddleware) r.Put("/notifications/channels/{id}", handler.UpdateChannel) tests := []struct { @@ -222,6 +235,7 @@ func TestCreateChannel_Webhook(t *testing.T) { body, _ := json.Marshal(payload) req, _ := http.NewRequest("POST", "/api/notifications/channels", bytes.NewBuffer(body)) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) rr := httptest.NewRecorder() handler.CreateChannel(rr, req) @@ -270,6 +284,7 @@ func TestCreateChannel_WebhookValidation(t *testing.T) { t.Run(tc.name, func(t *testing.T) { body, _ := json.Marshal(tc.payload) req, _ := http.NewRequest("POST", "/api/notifications/channels", bytes.NewBuffer(body)) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) rr := httptest.NewRecorder() handler.CreateChannel(rr, req) @@ -299,6 +314,7 @@ func TestTestChannel_Success(t *testing.T) { body, _ := json.Marshal(payload) req, _ := http.NewRequest("POST", "/api/notifications/channels/test", bytes.NewBuffer(body)) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) rr := httptest.NewRecorder() handler.TestChannel(rr, req) @@ -328,6 +344,7 @@ func TestTestChannel_ServerError(t *testing.T) { body, _ := json.Marshal(payload) req, _ := http.NewRequest("POST", "/api/notifications/channels/test", bytes.NewBuffer(body)) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) rr := httptest.NewRecorder() handler.TestChannel(rr, req) @@ -347,6 +364,7 @@ func TestTestChannel_MissingType(t *testing.T) { body, _ := json.Marshal(payload) req, _ := http.NewRequest("POST", "/api/notifications/channels/test", bytes.NewBuffer(body)) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) rr := httptest.NewRecorder() handler.TestChannel(rr, req) @@ -367,6 +385,7 @@ func TestTestChannel_InvalidURL(t *testing.T) { body, _ := json.Marshal(payload) req, _ := http.NewRequest("POST", "/api/notifications/channels/test", bytes.NewBuffer(body)) + req = req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) rr := httptest.NewRecorder() handler.TestChannel(rr, req) diff --git a/internal/api/handlers_settings.go b/internal/api/handlers_settings.go index 082e9c3..9696483 100644 --- a/internal/api/handlers_settings.go +++ b/internal/api/handlers_settings.go @@ -158,6 +158,9 @@ func (h *SettingsHandler) GetSettings(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "Invalid body" // @Router /settings [patch] func (h *SettingsHandler) UpdateSettings(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } var body map[string]string if err := json.NewDecoder(r.Body).Decode(&body); err != nil { http.Error(w, "Invalid body", http.StatusBadRequest) diff --git a/internal/api/handlers_settings_test.go b/internal/api/handlers_settings_test.go index 09c6c3d..f0bbff0 100644 --- a/internal/api/handlers_settings_test.go +++ b/internal/api/handlers_settings_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -11,6 +12,10 @@ import ( "github.com/projecthelena/warden/internal/uptime" ) +func withAdminCtx(req *http.Request) *http.Request { + return req.WithContext(context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin)) +} + func TestGetSettings(t *testing.T) { s, _ := db.NewStore(db.NewTestConfig()) m := uptime.NewManager(s) @@ -54,7 +59,7 @@ func TestUpdateSettings_MultipleSettings(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - h.UpdateSettings(w, req) + h.UpdateSettings(w, withAdminCtx(req)) if w.Code != http.StatusOK { t.Fatalf("Expected 200, got %d", w.Code) @@ -81,7 +86,7 @@ func TestUpdateSettings_InvalidBody(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - h.UpdateSettings(w, req) + h.UpdateSettings(w, withAdminCtx(req)) if w.Code != http.StatusBadRequest { t.Errorf("Expected 400 for invalid JSON, got %d", w.Code) @@ -102,7 +107,7 @@ func TestUpdateSettings_LatencyThresholdUpdatesManager(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - h.UpdateSettings(w, req) + h.UpdateSettings(w, withAdminCtx(req)) if w.Code != http.StatusOK { t.Fatalf("Expected 200, got %d", w.Code) diff --git a/internal/api/handlers_setup.go b/internal/api/handlers_setup.go index bb8c296..95546c1 100644 --- a/internal/api/handlers_setup.go +++ b/internal/api/handlers_setup.go @@ -133,8 +133,8 @@ func (h *Router) PerformSetup(w http.ResponseWriter, r *http.Request) { req.Timezone = "UTC" } - // Create User - if err := h.store.CreateUser(req.Username, req.Password, req.Timezone); err != nil { + // Create User (first user is always admin) + if err := h.store.CreateUser(req.Username, req.Password, req.Timezone, "admin"); err != nil { log.Printf("AUDIT: [SETUP] Failed to create user from IP %s: %v", sanitizeLog(clientIP), err) // #nosec G706 -- sanitized http.Error(w, "Failed to create user", http.StatusInternalServerError) return @@ -221,7 +221,7 @@ func (h *Router) PerformSetup(w http.ResponseWriter, r *http.Request) { } // Set auth cookie - http.SetCookie(w, &http.Cookie{ + http.SetCookie(w, &http.Cookie{ // #nosec G124 -- Secure defaults true; configurable for local HTTP dev Name: "auth_token", Value: token, Expires: expiresAt, diff --git a/internal/api/handlers_sso.go b/internal/api/handlers_sso.go index b99418d..321f0ba 100644 --- a/internal/api/handlers_sso.go +++ b/internal/api/handlers_sso.go @@ -122,7 +122,7 @@ func (h *SSOHandler) GoogleLogin(w http.ResponseWriter, r *http.Request) { // Store state in a short-lived cookie // SECURITY: Use SameSite=Strict to prevent CSRF attacks on OAuth flow - http.SetCookie(w, &http.Cookie{ + http.SetCookie(w, &http.Cookie{ // #nosec G124 -- Secure defaults true; configurable for local HTTP dev Name: "oauth_state", Value: state, MaxAge: 300, // 5 minutes @@ -156,7 +156,7 @@ func (h *SSOHandler) GoogleLogin(w http.ResponseWriter, r *http.Request) { // clearStateCookie clears the OAuth state cookie func (h *SSOHandler) clearStateCookie(w http.ResponseWriter) { - http.SetCookie(w, &http.Cookie{ + http.SetCookie(w, &http.Cookie{ // #nosec G124 -- Secure defaults true; configurable for local HTTP dev Name: "oauth_state", Value: "", MaxAge: -1, @@ -351,7 +351,7 @@ func (h *SSOHandler) GoogleCallback(w http.ResponseWriter, r *http.Request) { } // Set auth cookie - http.SetCookie(w, &http.Cookie{ + http.SetCookie(w, &http.Cookie{ // #nosec G124 -- Secure defaults true; configurable for local HTTP dev Name: "auth_token", Value: sessionToken, Expires: expiresAt, @@ -361,12 +361,19 @@ func (h *SSOHandler) GoogleCallback(w http.ResponseWriter, r *http.Request) { Secure: h.config.CookieSecure, }) - // Redirect to dashboard - http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect) + // Redirect based on user role + if user.Role == RoleStatusViewer { + http.Redirect(w, r, "/my-pages", http.StatusTemporaryRedirect) + } else { + http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect) + } } // TestSSOConfig tests if the SSO configuration is valid (admin only) func (h *SSOHandler) TestSSOConfig(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } clientID, _ := h.store.GetSetting("sso.google.client_id") clientSecret, _ := h.store.GetSetting("sso.google.client_secret") diff --git a/internal/api/handlers_status_pages.go b/internal/api/handlers_status_pages.go index e09928e..57d89e1 100644 --- a/internal/api/handlers_status_pages.go +++ b/internal/api/handlers_status_pages.go @@ -50,6 +50,7 @@ func (h *StatusPageHandler) GetAll(w http.ResponseWriter, r *http.Request) { // 3. Construct Unified List type StatusPageDTO struct { + ID int64 `json:"id"` Slug string `json:"slug"` Title string `json:"title"` GroupID *string `json:"groupId"` @@ -101,6 +102,7 @@ func (h *StatusPageHandler) GetAll(w http.ResponseWriter, r *http.Request) { HeaderArrangement: "stacked", } if globalPage != nil { + globalDTO.ID = globalPage.ID globalDTO.Title = globalPage.Title globalDTO.Public = globalPage.Public globalDTO.Enabled = globalPage.Enabled @@ -153,6 +155,7 @@ func (h *StatusPageHandler) GetAll(w http.ResponseWriter, r *http.Request) { } if cfg, ok := configMap[g.ID]; ok { + dto.ID = cfg.ID dto.Slug = cfg.Slug dto.Title = cfg.Title dto.Public = cfg.Public @@ -204,6 +207,9 @@ func (h *StatusPageHandler) GetAll(w http.ResponseWriter, r *http.Request) { // @Failure 400 {object} object{error=string} "Invalid request" // @Router /status-pages/{slug} [patch] func (h *StatusPageHandler) Toggle(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } slug := chi.URLParam(r, "slug") var req struct { Public bool `json:"public"` @@ -437,10 +443,30 @@ func (h *StatusPageHandler) GetPublicStatus(w http.ResponseWriter, r *http.Reque return } if !page.Public { - if !h.auth.IsAuthenticated(r) { + auth := h.auth.GetAuthInfo(r) + if !auth.Authenticated { writeError(w, http.StatusUnauthorized, "authentication required") return } + // status_viewer users can only see their assigned status pages + if auth.Role == RoleStatusViewer { + pages, err := h.store.GetUserStatusPages(auth.UserID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to check access") + return + } + hasAccess := false + for _, pid := range pages { + if pid == page.ID { + hasAccess = true + break + } + } + if !hasAccess { + writeError(w, http.StatusForbidden, "you do not have access to this status page") + return + } + } } // 2. Fetch Layout from DB (Groups + Monitors Metadata) @@ -867,8 +893,29 @@ func (h *StatusPageHandler) GetRSSFeed(w http.ResponseWriter, r *http.Request) { return } if !page.Public { - writeError(w, http.StatusNotFound, "status page not found") - return + auth := h.auth.GetAuthInfo(r) + if !auth.Authenticated { + writeError(w, http.StatusNotFound, "status page not found") + return + } + if auth.Role == RoleStatusViewer { + pages, err := h.store.GetUserStatusPages(auth.UserID) + if err != nil { + writeError(w, http.StatusNotFound, "status page not found") + return + } + hasAccess := false + for _, pid := range pages { + if pid == page.ID { + hasAccess = true + break + } + } + if !hasAccess { + writeError(w, http.StatusNotFound, "status page not found") + return + } + } } // 2. Build base URL from request @@ -982,6 +1029,64 @@ func (h *StatusPageHandler) GetRSSFeed(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(rss)) // #nosec G705 - all user content escaped via xmlEscape() } +// GetMyStatusPages returns status pages accessible to the logged-in user. +// For status_viewer: returns only assigned pages. For other roles: returns all enabled pages. +func (h *StatusPageHandler) GetMyStatusPages(w http.ResponseWriter, r *http.Request) { + userID, ok := r.Context().Value(contextKeyUserID).(int64) + if !ok { + writeError(w, http.StatusUnauthorized, "unauthorized") + return + } + role := getUserRole(r) + + type PageDTO struct { + ID int64 `json:"id"` + Slug string `json:"slug"` + Title string `json:"title"` + Public bool `json:"public"` + } + + var result []PageDTO + + if role == RoleStatusViewer { + pageIDs, err := h.store.GetUserStatusPages(userID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to fetch assigned pages") + return + } + if len(pageIDs) == 0 { + writeJSON(w, http.StatusOK, map[string]any{"pages": []PageDTO{}}) + return + } + pages, err := h.store.GetStatusPagesByIDs(pageIDs) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to fetch status pages") + return + } + for _, p := range pages { + if p.Enabled { + result = append(result, PageDTO{ID: p.ID, Slug: p.Slug, Title: p.Title, Public: p.Public}) + } + } + } else { + pages, err := h.store.GetStatusPages() + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to fetch status pages") + return + } + for _, p := range pages { + if p.Enabled { + result = append(result, PageDTO{ID: p.ID, Slug: p.Slug, Title: p.Title, Public: p.Public}) + } + } + } + + if result == nil { + result = []PageDTO{} + } + writeJSON(w, http.StatusOK, map[string]any{"pages": result}) +} + // xmlEscape escapes special XML characters func xmlEscape(s string) string { s = strings.ReplaceAll(s, "&", "&") diff --git a/internal/api/handlers_status_pages_test.go b/internal/api/handlers_status_pages_test.go index 0cae6cc..488369d 100644 --- a/internal/api/handlers_status_pages_test.go +++ b/internal/api/handlers_status_pages_test.go @@ -56,7 +56,7 @@ func seedPage(t *testing.T, store *db.Store, slug, title string, groupID *string // seedAuthUser creates a user and session, returning a cookie value for auth. func seedAuthUser(t *testing.T, store *db.Store, username, token string) { t.Helper() - if err := store.CreateUser(username, "password123", "UTC"); err != nil { + if err := store.CreateUser(username, "password123", "UTC", "admin"); err != nil { t.Fatalf("Failed to create user %s: %v", username, err) } user, err := store.Authenticate(username, "password123") @@ -80,7 +80,9 @@ func makeRequest(method, path, slug string, body interface{}) *http.Request { req := httptest.NewRequest(method, path, reqBody) rctx := chi.NewRouteContext() rctx.URLParams.Add("slug", slug) - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + ctx := context.WithValue(req.Context(), chi.RouteCtxKey, rctx) + ctx = context.WithValue(ctx, contextKeyUserRole, RoleAdmin) + req = req.WithContext(ctx) return req } @@ -362,7 +364,9 @@ func TestToggle_InvalidBody(t *testing.T) { req := httptest.NewRequest("PATCH", "/api/status-pages/test", bytes.NewBufferString("not json")) rctx := chi.NewRouteContext() rctx.URLParams.Add("slug", "test") - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + ctx := context.WithValue(req.Context(), chi.RouteCtxKey, rctx) + ctx = context.WithValue(ctx, contextKeyUserRole, RoleAdmin) + req = req.WithContext(ctx) w := httptest.NewRecorder() spH.Toggle(w, req) @@ -1682,3 +1686,267 @@ func TestPhase4_RSSFeedAtomSelfLink(t *testing.T) { t.Error("Expected rel='self' in Atom link") } } + +// ============================================================ +// STATUS VIEWER TESTS +// ============================================================ + +// seedAuthUserWithRole creates a user with a specific role and session. +func seedAuthUserWithRole(t *testing.T, store *db.Store, username, token, role string) int64 { + t.Helper() + if err := store.CreateUser(username, "password123", "UTC", role); err != nil { + t.Fatalf("Failed to create user %s: %v", username, err) + } + user, err := store.Authenticate(username, "password123") + if err != nil { + t.Fatalf("Failed to authenticate user %s: %v", username, err) + } + if err := store.CreateSession(user.ID, token, time.Now().Add(24*time.Hour)); err != nil { + t.Fatalf("Failed to create session for %s: %v", username, err) + } + return user.ID +} + +// makeAuthRequest creates a request with auth cookie, user ID, and role in context. +func makeAuthRequest(method, path string, body interface{}, userID int64, role string) *http.Request { + var reqBody *bytes.Buffer + if body != nil { + b, _ := json.Marshal(body) + reqBody = bytes.NewBuffer(b) + } else { + reqBody = bytes.NewBuffer(nil) + } + req := httptest.NewRequest(method, path, reqBody) + ctx := context.WithValue(req.Context(), contextKeyUserID, userID) + ctx = context.WithValue(ctx, contextKeyUserRole, role) + req = req.WithContext(ctx) + return req +} + +func TestGetAll_IncludesIDField(t *testing.T) { + store, spH := newStatusPageTestEnv(t) + + seedPage(t, store, "all", "Global Status", nil, true, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/api/status-pages", nil) + spH.GetAll(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", w.Code) + } + + body := decodeJSON(t, w) + pages := body["pages"].([]interface{}) + + for _, p := range pages { + page := p.(map[string]interface{}) + if page["slug"] == "all" { + id, ok := page["id"].(float64) + if !ok { + t.Error("Expected 'id' field in status page DTO") + return + } + if id <= 0 { + t.Errorf("Expected positive id for saved page, got %v", id) + } + return + } + } + t.Error("'all' page not found in GetAll response") +} + +func TestGetMyStatusPages_StatusViewerAssigned(t *testing.T) { + store, spH := newStatusPageTestEnv(t) + + // Create a status page + seedPage(t, store, "viewer-page", "Viewer Page", nil, false, true) + page, _ := store.GetStatusPageBySlug("viewer-page") + + // Create a status_viewer user and assign the page + userID := seedAuthUserWithRole(t, store, "linda", "linda-token", RoleStatusViewer) + if err := store.SetUserStatusPages(userID, []int64{page.ID}); err != nil { + t.Fatalf("Failed to assign pages: %v", err) + } + + // Call GetMyStatusPages + req := makeAuthRequest("GET", "/api/my/status-pages", nil, userID, RoleStatusViewer) + w := httptest.NewRecorder() + spH.GetMyStatusPages(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d (body: %s)", w.Code, w.Body.String()) + } + + body := decodeJSON(t, w) + pages := body["pages"].([]interface{}) + if len(pages) != 1 { + t.Fatalf("Expected 1 page, got %d", len(pages)) + } + + p := pages[0].(map[string]interface{}) + if p["slug"] != "viewer-page" { + t.Errorf("Expected slug 'viewer-page', got '%v'", p["slug"]) + } + if p["title"] != "Viewer Page" { + t.Errorf("Expected title 'Viewer Page', got '%v'", p["title"]) + } +} + +func TestGetMyStatusPages_StatusViewerNoAssignment(t *testing.T) { + _, spH := newStatusPageTestEnv(t) + + // Call GetMyStatusPages without assigned pages + req := makeAuthRequest("GET", "/api/my/status-pages", nil, 999, RoleStatusViewer) + w := httptest.NewRecorder() + spH.GetMyStatusPages(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", w.Code) + } + + body := decodeJSON(t, w) + pages := body["pages"].([]interface{}) + if len(pages) != 0 { + t.Errorf("Expected 0 pages for unassigned user, got %d", len(pages)) + } +} + +func TestGetMyStatusPages_StatusViewerDisabledPageFiltered(t *testing.T) { + store, spH := newStatusPageTestEnv(t) + + // Create a disabled page + seedPage(t, store, "disabled-sv", "Disabled Page", nil, false, false) + page, _ := store.GetStatusPageBySlug("disabled-sv") + + userID := seedAuthUserWithRole(t, store, "viewer2", "v2-token", RoleStatusViewer) + if err := store.SetUserStatusPages(userID, []int64{page.ID}); err != nil { + t.Fatalf("Failed to assign pages: %v", err) + } + + req := makeAuthRequest("GET", "/api/my/status-pages", nil, userID, RoleStatusViewer) + w := httptest.NewRecorder() + spH.GetMyStatusPages(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", w.Code) + } + + body := decodeJSON(t, w) + pages := body["pages"].([]interface{}) + if len(pages) != 0 { + t.Errorf("Expected 0 pages (disabled filtered out), got %d", len(pages)) + } +} + +func TestGetMyStatusPages_AdminSeesAllEnabled(t *testing.T) { + store, spH := newStatusPageTestEnv(t) + + seedPage(t, store, "admin-page-1", "Admin Page 1", nil, true, true) + seedPage(t, store, "admin-page-2", "Admin Page 2", nil, false, true) + seedPage(t, store, "admin-disabled", "Disabled", nil, true, false) + + userID := seedAuthUserWithRole(t, store, "adminuser", "admin-token", RoleAdmin) + + req := makeAuthRequest("GET", "/api/my/status-pages", nil, userID, RoleAdmin) + w := httptest.NewRecorder() + spH.GetMyStatusPages(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", w.Code) + } + + body := decodeJSON(t, w) + pages := body["pages"].([]interface{}) + // Should see 2 enabled pages, not the disabled one + if len(pages) != 2 { + t.Errorf("Expected 2 enabled pages for admin, got %d", len(pages)) + } +} + +func TestGetPublicStatus_StatusViewerAccessDenied(t *testing.T) { + store, spH := newStatusPageTestEnv(t) + + // Private, enabled page + seedPage(t, store, "private-sv", "Private SV", nil, false, true) + + // Create status_viewer user WITHOUT assignment + userID := seedAuthUserWithRole(t, store, "svuser", "sv-token", RoleStatusViewer) + + req := makeRequest("GET", "/api/s/private-sv", "private-sv", nil) + req.AddCookie(&http.Cookie{Name: "auth_token", Value: "sv-token"}) + // Inject role context (simulating middleware) + ctx := context.WithValue(req.Context(), contextKeyUserID, userID) + ctx = context.WithValue(ctx, contextKeyUserRole, RoleStatusViewer) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + spH.GetPublicStatus(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403 for unassigned status_viewer, got %d", w.Code) + } +} + +func TestGetPublicStatus_StatusViewerWithAccess(t *testing.T) { + store, spH := newStatusPageTestEnv(t) + + seedGroup(t, store, "g-sv-access", "SV Access Group") + seedPage(t, store, "sv-access", "SV Access", nil, false, true) + page, _ := store.GetStatusPageBySlug("sv-access") + + // Create status_viewer and assign page + userID := seedAuthUserWithRole(t, store, "svaccess", "sva-token", RoleStatusViewer) + if err := store.SetUserStatusPages(userID, []int64{page.ID}); err != nil { + t.Fatalf("Failed to assign pages: %v", err) + } + + req := makeRequest("GET", "/api/s/sv-access", "sv-access", nil) + req.AddCookie(&http.Cookie{Name: "auth_token", Value: "sva-token"}) + ctx := context.WithValue(req.Context(), contextKeyUserID, userID) + ctx = context.WithValue(ctx, contextKeyUserRole, RoleStatusViewer) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + spH.GetPublicStatus(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200 for assigned status_viewer, got %d (body: %s)", w.Code, w.Body.String()) + } +} + +func TestGetStatusPagesByIDs(t *testing.T) { + store, _ := newStatusPageTestEnv(t) + + seedPage(t, store, "byid-1", "By ID 1", nil, true, true) + seedPage(t, store, "byid-2", "By ID 2", nil, false, true) + + p1, _ := store.GetStatusPageBySlug("byid-1") + p2, _ := store.GetStatusPageBySlug("byid-2") + + pages, err := store.GetStatusPagesByIDs([]int64{p1.ID, p2.ID}) + if err != nil { + t.Fatalf("GetStatusPagesByIDs failed: %v", err) + } + if len(pages) != 2 { + t.Fatalf("Expected 2 pages, got %d", len(pages)) + } + + // Test empty input + pages, err = store.GetStatusPagesByIDs([]int64{}) + if err != nil { + t.Fatalf("GetStatusPagesByIDs empty failed: %v", err) + } + if pages != nil { + t.Errorf("Expected nil for empty IDs, got %v", pages) + } + + // Test non-existent IDs + pages, err = store.GetStatusPagesByIDs([]int64{99999}) + if err != nil { + t.Fatalf("GetStatusPagesByIDs nonexistent failed: %v", err) + } + if len(pages) != 0 { + t.Errorf("Expected 0 pages for nonexistent IDs, got %d", len(pages)) + } +} diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index 58e1ce2..866c7b8 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -14,6 +15,20 @@ import ( "github.com/go-chi/chi/v5" ) +// withAdminRole adds the admin role to a request's context for testing. +func withAdminRole(req *http.Request) *http.Request { + ctx := context.WithValue(req.Context(), contextKeyUserRole, RoleAdmin) + return req.WithContext(ctx) +} + +// adminRoleMiddleware is a test middleware that injects the admin role into context. +func adminRoleMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), contextKeyUserRole, RoleAdmin) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func setupTest(t *testing.T) (*CRUDHandler, *SettingsHandler, *AuthHandler, http.Handler, *db.Store) { store, _ := db.NewStore(db.NewTestConfig()) manager := uptime.NewManager(store) @@ -49,6 +64,7 @@ func TestUpdateMonitor(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Put("/api/monitors/{id}", crudH.UpdateMonitor) r.ServeHTTP(w, req) @@ -92,7 +108,7 @@ func TestUpdateSettings(t *testing.T) { // Settings handler doesn't use URL params, so we can call directly or via router handler := http.HandlerFunc(settingsH.UpdateSettings) - handler.ServeHTTP(w, req) + handler.ServeHTTP(w, withAdminRole(req)) if w.Code != http.StatusOK { t.Errorf("Expected 200, got %d", w.Code) @@ -121,6 +137,7 @@ func TestPauseMonitor(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors/{id}/pause", crudH.PauseMonitor) r.ServeHTTP(w, req) @@ -171,6 +188,7 @@ func TestResumeMonitor(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors/{id}/resume", crudH.ResumeMonitor) r.ServeHTTP(w, req) @@ -216,6 +234,7 @@ func TestPauseMonitor_NotFound(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors/{id}/pause", crudH.PauseMonitor) r.ServeHTTP(w, req) @@ -241,6 +260,7 @@ func TestResumeMonitor_NotFound(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors/{id}/resume", crudH.ResumeMonitor) r.ServeHTTP(w, req) @@ -266,6 +286,7 @@ func TestPauseMonitor_EmptyID(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors/{id}/pause", crudH.PauseMonitor) r.ServeHTTP(w, req) @@ -291,6 +312,7 @@ func TestPauseResumeMonitor_FullCycle(t *testing.T) { } r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors/{id}/pause", crudH.PauseMonitor) r.Post("/api/monitors/{id}/resume", crudH.ResumeMonitor) @@ -356,6 +378,7 @@ func TestPauseMonitor_AlreadyPaused(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors/{id}/pause", crudH.PauseMonitor) r.ServeHTTP(w, req) @@ -387,6 +410,7 @@ func TestResumeMonitor_AlreadyActive(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors/{id}/resume", crudH.ResumeMonitor) r.ServeHTTP(w, req) @@ -419,6 +443,7 @@ func TestPauseMonitor_UUIDStyleID(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors/{id}/pause", crudH.PauseMonitor) r.ServeHTTP(w, req) @@ -444,6 +469,7 @@ func TestPauseResumeMonitor_SequentialToggle(t *testing.T) { } r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors/{id}/pause", crudH.PauseMonitor) r.Post("/api/monitors/{id}/resume", crudH.ResumeMonitor) @@ -552,6 +578,7 @@ func TestCreateMonitor_NotifFatigueValidation(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors", crudH.CreateMonitor) r.ServeHTTP(w, req) @@ -629,6 +656,7 @@ func TestUpdateMonitor_NotifFatigueValidation(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Put("/api/monitors/{id}", crudH.UpdateMonitor) r.ServeHTTP(w, req) @@ -670,6 +698,7 @@ func TestGetUptime_IncludesOverrideFields(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Get("/api/uptime", uptimeH.GetHistory) r.ServeHTTP(w, req) @@ -761,6 +790,7 @@ func TestGetUptime_IncludesLatencyThreshold(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Get("/api/uptime", uptimeH.GetHistory) r.ServeHTTP(w, req) @@ -841,6 +871,7 @@ func TestCreateMonitor_WithRequestConfig(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors", crudH.CreateMonitor) r.ServeHTTP(w, req) @@ -909,6 +940,7 @@ func TestUpdateMonitor_WithRequestConfig(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Put("/api/monitors/{id}", crudH.UpdateMonitor) r.ServeHTTP(w, req) @@ -1019,6 +1051,7 @@ func TestValidateRequestConfig(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Post("/api/monitors", crudH.CreateMonitor) r.ServeHTTP(w, req) @@ -1077,6 +1110,7 @@ func TestGetUptime_IncludesRequestConfig(t *testing.T) { w := httptest.NewRecorder() r := chi.NewRouter() + r.Use(adminRoleMiddleware) r.Get("/api/uptime", uptimeH.GetHistory) r.ServeHTTP(w, req) diff --git a/internal/api/handlers_users.go b/internal/api/handlers_users.go new file mode 100644 index 0000000..2e4a591 --- /dev/null +++ b/internal/api/handlers_users.go @@ -0,0 +1,267 @@ +package api + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/projecthelena/warden/internal/db" +) + +type UserHandler struct { + store *db.Store +} + +func NewUserHandler(store *db.Store) *UserHandler { + return &UserHandler{store: store} +} + +// CreateUser creates a new user (admin only). +func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } + + var req struct { + Username string `json:"username"` + Password string `json:"password"` + Role string `json:"role"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid request") + return + } + + req.Username = strings.TrimSpace(req.Username) + if req.Username == "" || req.Password == "" { + writeError(w, http.StatusBadRequest, "username and password are required") + return + } + if len(req.Password) < 8 { + writeError(w, http.StatusBadRequest, "password must be at least 8 characters") + return + } + if req.Role == "" { + req.Role = RoleViewer + } + if !ValidRole(req.Role) { + writeError(w, http.StatusBadRequest, "invalid role") + return + } + + if err := h.store.CreateUser(req.Username, req.Password, "UTC", req.Role); err != nil { + if strings.Contains(err.Error(), "UNIQUE") || strings.Contains(err.Error(), "unique") || strings.Contains(err.Error(), "duplicate") { + writeError(w, http.StatusConflict, "username already exists") + return + } + writeError(w, http.StatusInternalServerError, "failed to create user") + return + } + + writeJSON(w, http.StatusCreated, map[string]string{"message": "user created"}) +} + +// ListUsers returns all users (admin only). +func (h *UserHandler) ListUsers(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } + + users, err := h.store.ListUsers() + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to list users") + return + } + + type UserDTO struct { + ID int64 `json:"id"` + Username string `json:"username"` + Role string `json:"role"` + Email string `json:"email,omitempty"` + SSOProvider string `json:"ssoProvider,omitempty"` + AvatarURL string `json:"avatar,omitempty"` + DisplayName string `json:"displayName,omitempty"` + CreatedAt string `json:"createdAt"` + } + + dtos := make([]UserDTO, 0, len(users)) + for _, u := range users { + dtos = append(dtos, UserDTO{ + ID: u.ID, + Username: u.Username, + Role: u.Role, + Email: u.Email, + SSOProvider: u.SSOProvider, + AvatarURL: u.AvatarURL, + DisplayName: u.DisplayName, + CreatedAt: u.CreatedAt.Format("2006-01-02T15:04:05Z"), + }) + } + + writeJSON(w, http.StatusOK, map[string]any{"users": dtos}) +} + +// UpdateUserRole changes a user's role (admin only). +func (h *UserHandler) UpdateUserRole(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } + + idStr := chi.URLParam(r, "id") + targetID, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid user ID") + return + } + + // Cannot change own role + currentUserID, _ := r.Context().Value(contextKeyUserID).(int64) + if targetID == currentUserID { + writeError(w, http.StatusBadRequest, "cannot change your own role") + return + } + + var req struct { + Role string `json:"role"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid request") + return + } + if !ValidRole(req.Role) { + writeError(w, http.StatusBadRequest, "invalid role") + return + } + + // If demoting from admin, ensure at least one admin remains + targetUser, err := h.store.GetUser(targetID) + if err != nil { + writeError(w, http.StatusNotFound, "user not found") + return + } + if targetUser.Role == RoleAdmin && req.Role != RoleAdmin { + count, err := h.store.CountAdmins() + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to check admin count") + return + } + if count <= 1 { + writeError(w, http.StatusBadRequest, "cannot remove the last admin") + return + } + } + + if err := h.store.UpdateUserRole(targetID, req.Role); err != nil { + writeError(w, http.StatusInternalServerError, "failed to update role") + return + } + + // Clean up status page assignments when role changes away from status_viewer + if targetUser.Role == RoleStatusViewer && req.Role != RoleStatusViewer { + _ = h.store.SetUserStatusPages(targetID, []int64{}) + } + + writeJSON(w, http.StatusOK, map[string]string{"message": "role updated"}) +} + +// DeleteUser removes a user (admin only). +func (h *UserHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } + + idStr := chi.URLParam(r, "id") + targetID, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid user ID") + return + } + + // Cannot delete self + currentUserID, _ := r.Context().Value(contextKeyUserID).(int64) + if targetID == currentUserID { + writeError(w, http.StatusBadRequest, "cannot delete yourself") + return + } + + // Cannot delete the last admin + targetUser, err := h.store.GetUser(targetID) + if err != nil { + writeError(w, http.StatusNotFound, "user not found") + return + } + if targetUser.Role == RoleAdmin { + count, err := h.store.CountAdmins() + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to check admin count") + return + } + if count <= 1 { + writeError(w, http.StatusBadRequest, "cannot delete the last admin") + return + } + } + + if err := h.store.DeleteUser(targetID); err != nil { + writeError(w, http.StatusInternalServerError, "failed to delete user") + return + } + + writeJSON(w, http.StatusOK, map[string]string{"message": "user deleted"}) +} + +// GetUserStatusPages returns the status page IDs assigned to a user (admin only). +func (h *UserHandler) GetUserStatusPages(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } + + idStr := chi.URLParam(r, "id") + userID, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid user ID") + return + } + + pageIDs, err := h.store.GetUserStatusPages(userID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to get status pages") + return + } + if pageIDs == nil { + pageIDs = []int64{} + } + + writeJSON(w, http.StatusOK, map[string]any{"statusPageIds": pageIDs}) +} + +// SetUserStatusPages replaces the status page assignments for a user (admin only). +func (h *UserHandler) SetUserStatusPages(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } + + idStr := chi.URLParam(r, "id") + userID, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid user ID") + return + } + + var req struct { + StatusPageIDs []int64 `json:"statusPageIds"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid request") + return + } + + if err := h.store.SetUserStatusPages(userID, req.StatusPageIDs); err != nil { + writeError(w, http.StatusInternalServerError, "failed to set status pages") + return + } + + writeJSON(w, http.StatusOK, map[string]string{"message": "status pages updated"}) +} diff --git a/internal/api/integration_test.go b/internal/api/integration_test.go index 90cebb8..7f221fc 100644 --- a/internal/api/integration_test.go +++ b/internal/api/integration_test.go @@ -30,6 +30,7 @@ func TestAPIKeyIntegrationFlow(t *testing.T) { // Use config with AdminSecret for testing cfg := config.Default() cfg.AdminSecret = integrationTestAdminSecret + cfg.CookieSecure = false // Tests use plain HTTP router := NewRouter(manager, store, &cfg) ts := httptest.NewServer(router) diff --git a/internal/api/rbac.go b/internal/api/rbac.go new file mode 100644 index 0000000..b7cc676 --- /dev/null +++ b/internal/api/rbac.go @@ -0,0 +1,78 @@ +package api + +import "net/http" + +// Role constants +const ( + RoleAdmin = "admin" + RoleEditor = "editor" + RoleViewer = "viewer" + RoleStatusViewer = "status_viewer" +) + +const contextKeyUserRole contextKey = "userRole" + +// roleLevel returns a numeric level for role hierarchy comparison. +// Higher level = more permissions. +func roleLevel(role string) int { + switch role { + case RoleAdmin: + return 4 + case RoleEditor: + return 3 + case RoleViewer: + return 2 + case RoleStatusViewer: + return 1 + default: + return 0 + } +} + +// hasPermission checks if userRole meets or exceeds the minimum required role. +func hasPermission(userRole, minimumRole string) bool { + return roleLevel(userRole) >= roleLevel(minimumRole) +} + +// getUserRole extracts the user's role from the request context. +func getUserRole(r *http.Request) string { + role, ok := r.Context().Value(contextKeyUserRole).(string) + if !ok { + return "" + } + return role +} + +// requireRole checks if the user has the minimum required role. +// Returns true if the user has permission, false if blocked (response already written). +func requireRole(w http.ResponseWriter, r *http.Request, minimumRole string) bool { + role := getUserRole(r) + if !hasPermission(role, minimumRole) { + writeError(w, http.StatusForbidden, "your role ("+role+") does not have permission for this action") + return false + } + return true +} + +// RequireViewerMiddleware blocks status_viewer users from dashboard endpoints. +// status_viewer users can only access /api/auth/me and their assigned status pages. +func RequireViewerMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + role := getUserRole(r) + if role == RoleStatusViewer { + writeError(w, http.StatusForbidden, "your role ("+role+") does not have permission for this action") + return + } + next.ServeHTTP(w, r) + }) +} + +// ValidRole returns true if the given role string is a valid RBAC role. +func ValidRole(role string) bool { + switch role { + case RoleAdmin, RoleEditor, RoleViewer, RoleStatusViewer: + return true + default: + return false + } +} diff --git a/internal/api/rbac_admin_editor_test.go b/internal/api/rbac_admin_editor_test.go new file mode 100644 index 0000000..08c44a3 --- /dev/null +++ b/internal/api/rbac_admin_editor_test.go @@ -0,0 +1,793 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/projecthelena/warden/internal/db" + "github.com/projecthelena/warden/internal/uptime" +) + +// ==================== Admin vs Editor boundary: admin-only endpoints ==================== +// These tests validate that editors are BLOCKED from admin-only endpoints +// and that admins are ALLOWED. + +func TestEditorCannotAccessAdminEndpoints(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + + // Create users for handler tests that need user IDs + _ = store.CreateUser("admin1", "pass", "UTC", "admin") + _ = store.CreateUser("target1", "pass", "UTC", "viewer") + admin, _ := store.Authenticate("admin1", "pass") + target, _ := store.Authenticate("target1", "pass") + + settingsH := NewSettingsHandler(store, manager) + userH := NewUserHandler(store) + apiKeyH := NewAPIKeyHandler(store) + + tests := []struct { + name string + method string + routePattern string + requestPath string + handler http.HandlerFunc + body any + }{ + // Settings + { + name: "PATCH /api/settings", method: "PATCH", + routePattern: "/api/settings", requestPath: "/api/settings", + handler: settingsH.UpdateSettings, body: map[string]string{"data_retention_days": "90"}, + }, + // Users CRUD + { + name: "POST /api/users (create user)", method: "POST", + routePattern: "/api/users", requestPath: "/api/users", + handler: userH.CreateUser, body: map[string]string{"username": "newuser", "password": "longpassword", "role": "viewer"}, + }, + { + name: "GET /api/users (list users)", method: "GET", + routePattern: "/api/users", requestPath: "/api/users", + handler: userH.ListUsers, + }, + { + name: "PATCH /api/users/{id}/role", method: "PATCH", + routePattern: "/api/users/{id}/role", requestPath: fmt.Sprintf("/api/users/%d/role", target.ID), + handler: userH.UpdateUserRole, body: map[string]string{"role": "editor"}, + }, + { + name: "DELETE /api/users/{id}", method: "DELETE", + routePattern: "/api/users/{id}", requestPath: fmt.Sprintf("/api/users/%d", target.ID), + handler: userH.DeleteUser, + }, + { + name: "GET /api/users/{id}/status-pages", method: "GET", + routePattern: "/api/users/{id}/status-pages", requestPath: fmt.Sprintf("/api/users/%d/status-pages", target.ID), + handler: userH.GetUserStatusPages, + }, + { + name: "PUT /api/users/{id}/status-pages", method: "PUT", + routePattern: "/api/users/{id}/status-pages", requestPath: fmt.Sprintf("/api/users/%d/status-pages", target.ID), + handler: userH.SetUserStatusPages, body: map[string]any{"statusPageIds": []int64{}}, + }, + // API Keys + { + name: "GET /api-keys (list keys)", method: "GET", + routePattern: "/api/api-keys", requestPath: "/api/api-keys", + handler: apiKeyH.ListKeys, + }, + { + name: "POST /api-keys (create key)", method: "POST", + routePattern: "/api/api-keys", requestPath: "/api/api-keys", + handler: apiKeyH.CreateKey, body: map[string]string{"name": "test-key"}, + }, + { + name: "DELETE /api-keys/{id}", method: "DELETE", + routePattern: "/api/api-keys/{id}", requestPath: "/api/api-keys/1", + handler: apiKeyH.DeleteKey, + }, + } + + for _, tc := range tests { + t.Run("editor_blocked_"+tc.name, func(t *testing.T) { + var bodyReader *bytes.Buffer + if tc.body != nil { + b, _ := json.Marshal(tc.body) + bodyReader = bytes.NewBuffer(b) + } else { + bodyReader = bytes.NewBuffer(nil) + } + + r := chi.NewRouter() + r.Use(roleAndUserMiddleware(RoleEditor, admin.ID)) + + switch tc.method { + case "GET": + r.Get(tc.routePattern, tc.handler) + case "POST": + r.Post(tc.routePattern, tc.handler) + case "PATCH": + r.Patch(tc.routePattern, tc.handler) + case "PUT": + r.Put(tc.routePattern, tc.handler) + case "DELETE": + r.Delete(tc.routePattern, tc.handler) + } + + req := httptest.NewRequest(tc.method, tc.requestPath, bodyReader) + if tc.body != nil { + req.Header.Set("Content-Type", "application/json") + } + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Editor should be forbidden from %s %s: expected 403, got %d. Body: %s", + tc.method, tc.requestPath, w.Code, w.Body.String()) + } + }) + } +} + +func TestAdminCanAccessAdminEndpoints(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + + // Create users for handler tests + _ = store.CreateUser("admin1", "pass", "UTC", "admin") + _ = store.CreateUser("admin2", "pass", "UTC", "admin") + _ = store.CreateUser("target1", "pass", "UTC", "viewer") + admin1, _ := store.Authenticate("admin1", "pass") + target, _ := store.Authenticate("target1", "pass") + + settingsH := NewSettingsHandler(store, manager) + userH := NewUserHandler(store) + apiKeyH := NewAPIKeyHandler(store) + + tests := []struct { + name string + method string + routePattern string // chi route pattern (with {id} params) + requestPath string // actual request URL + handler http.HandlerFunc + body any + wantStatus int + }{ + { + name: "PATCH /api/settings", method: "PATCH", + routePattern: "/api/settings", requestPath: "/api/settings", + handler: settingsH.UpdateSettings, body: map[string]string{"data_retention_days": "90"}, + wantStatus: http.StatusOK, + }, + { + name: "POST /api/users", method: "POST", + routePattern: "/api/users", requestPath: "/api/users", + handler: userH.CreateUser, body: map[string]string{"username": "newadminuser", "password": "longpassword", "role": "viewer"}, + wantStatus: http.StatusCreated, + }, + { + name: "GET /api/users", method: "GET", + routePattern: "/api/users", requestPath: "/api/users", + handler: userH.ListUsers, wantStatus: http.StatusOK, + }, + { + name: "PATCH /api/users/{id}/role", method: "PATCH", + routePattern: "/api/users/{id}/role", requestPath: fmt.Sprintf("/api/users/%d/role", target.ID), + handler: userH.UpdateUserRole, body: map[string]string{"role": "editor"}, + wantStatus: http.StatusOK, + }, + { + name: "GET /api/users/{id}/status-pages", method: "GET", + routePattern: "/api/users/{id}/status-pages", requestPath: fmt.Sprintf("/api/users/%d/status-pages", target.ID), + handler: userH.GetUserStatusPages, wantStatus: http.StatusOK, + }, + { + name: "PUT /api/users/{id}/status-pages", method: "PUT", + routePattern: "/api/users/{id}/status-pages", requestPath: fmt.Sprintf("/api/users/%d/status-pages", target.ID), + handler: userH.SetUserStatusPages, body: map[string]any{"statusPageIds": []int64{}}, + wantStatus: http.StatusOK, + }, + { + name: "GET /api-keys", method: "GET", + routePattern: "/api/api-keys", requestPath: "/api/api-keys", + handler: apiKeyH.ListKeys, wantStatus: http.StatusOK, + }, + { + name: "POST /api-keys", method: "POST", + routePattern: "/api/api-keys", requestPath: "/api/api-keys", + handler: apiKeyH.CreateKey, body: map[string]string{"name": "admin-test-key"}, + wantStatus: http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run("admin_allowed_"+tc.name, func(t *testing.T) { + var bodyReader *bytes.Buffer + if tc.body != nil { + b, _ := json.Marshal(tc.body) + bodyReader = bytes.NewBuffer(b) + } else { + bodyReader = bytes.NewBuffer(nil) + } + + r := chi.NewRouter() + r.Use(roleAndUserMiddleware(RoleAdmin, admin1.ID)) + + switch tc.method { + case "GET": + r.Get(tc.routePattern, tc.handler) + case "POST": + r.Post(tc.routePattern, tc.handler) + case "PATCH": + r.Patch(tc.routePattern, tc.handler) + case "PUT": + r.Put(tc.routePattern, tc.handler) + case "DELETE": + r.Delete(tc.routePattern, tc.handler) + } + + req := httptest.NewRequest(tc.method, tc.requestPath, bodyReader) + if tc.body != nil { + req.Header.Set("Content-Type", "application/json") + } + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != tc.wantStatus { + t.Errorf("Admin should be allowed %s %s: expected %d, got %d. Body: %s", + tc.method, tc.requestPath, tc.wantStatus, w.Code, w.Body.String()) + } + }) + } +} + +// ==================== Admin vs Editor boundary: editor-level endpoints ==================== +// These tests validate that editors CAN access editor-level endpoints +// and that admins can also access them (hierarchy). + +func TestEditorCanAccessEditorEndpoints(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + + crudH := NewCRUDHandler(store, manager) + incidentH := NewIncidentHandler(store) + maintH := NewMaintenanceHandler(store, manager) + notifH := NewNotificationChannelsHandler(store) + + // ── Groups ── + t.Run("editor can create group", func(t *testing.T) { + body, _ := json.Marshal(map[string]string{"name": "Editor Group"}) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleEditor)) + r.Post("/api/groups", crudH.CreateGroup) + + req := httptest.NewRequest("POST", "/api/groups", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected 201, got %d. Body: %s", w.Code, w.Body.String()) + } + }) + + t.Run("editor can create monitor", func(t *testing.T) { + body, _ := json.Marshal(map[string]any{ + "name": "Editor Monitor", + "url": "https://example.com", + "groupId": "g-editor-group", + "interval": 60, + }) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleEditor)) + r.Post("/api/monitors", crudH.CreateMonitor) + + req := httptest.NewRequest("POST", "/api/monitors", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected 201, got %d. Body: %s", w.Code, w.Body.String()) + } + }) + + // ── Incidents ── + t.Run("editor can create incident", func(t *testing.T) { + body, _ := json.Marshal(map[string]any{ + "title": "Test Incident", + "description": "Something happened", + "severity": "major", + "status": "investigating", + "startTime": time.Now().UTC().Format(time.RFC3339), + "affectedGroups": []string{}, + "public": true, + }) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleEditor)) + r.Post("/api/incidents", incidentH.CreateIncident) + + req := httptest.NewRequest("POST", "/api/incidents", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected 201, got %d. Body: %s", w.Code, w.Body.String()) + } + }) + + // ── Maintenance ── + t.Run("editor can create maintenance", func(t *testing.T) { + start := time.Now().Add(1 * time.Hour).UTC() + end := time.Now().Add(2 * time.Hour).UTC() + body, _ := json.Marshal(map[string]any{ + "title": "Maintenance Window", + "description": "Scheduled maintenance", + "status": "scheduled", + "startTime": start.Format(time.RFC3339), + "endTime": end.Format(time.RFC3339), + "affectedGroups": []string{}, + }) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleEditor)) + r.Post("/api/maintenance", maintH.CreateMaintenance) + + req := httptest.NewRequest("POST", "/api/maintenance", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected 201, got %d. Body: %s", w.Code, w.Body.String()) + } + }) + + // ── Notification Channels ── + t.Run("editor can create notification channel", func(t *testing.T) { + body, _ := json.Marshal(map[string]any{ + "type": "webhook", + "name": "Editor Webhook", + "config": map[string]any{"webhook_url": "https://example.com/hook"}, + "enabled": true, + }) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleEditor)) + r.Post("/api/notifications/channels", notifH.CreateChannel) + + req := httptest.NewRequest("POST", "/api/notifications/channels", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected 201, got %d. Body: %s", w.Code, w.Body.String()) + } + }) +} + +func TestAdminCanAccessEditorEndpoints(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + + crudH := NewCRUDHandler(store, manager) + incidentH := NewIncidentHandler(store) + maintH := NewMaintenanceHandler(store, manager) + notifH := NewNotificationChannelsHandler(store) + + // ── Groups ── + t.Run("admin can create group", func(t *testing.T) { + body, _ := json.Marshal(map[string]string{"name": "Admin Group"}) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleAdmin)) + r.Post("/api/groups", crudH.CreateGroup) + + req := httptest.NewRequest("POST", "/api/groups", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected 201, got %d. Body: %s", w.Code, w.Body.String()) + } + }) + + t.Run("admin can create monitor", func(t *testing.T) { + body, _ := json.Marshal(map[string]any{ + "name": "Admin Monitor", + "url": "https://example.com", + "groupId": "g-admin-group", + "interval": 60, + }) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleAdmin)) + r.Post("/api/monitors", crudH.CreateMonitor) + + req := httptest.NewRequest("POST", "/api/monitors", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected 201, got %d. Body: %s", w.Code, w.Body.String()) + } + }) + + // ── Incidents ── + t.Run("admin can create incident", func(t *testing.T) { + body, _ := json.Marshal(map[string]any{ + "title": "Admin Incident", + "description": "Admin created this", + "severity": "minor", + "status": "investigating", + "startTime": time.Now().UTC().Format(time.RFC3339), + "affectedGroups": []string{}, + "public": false, + }) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleAdmin)) + r.Post("/api/incidents", incidentH.CreateIncident) + + req := httptest.NewRequest("POST", "/api/incidents", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected 201, got %d. Body: %s", w.Code, w.Body.String()) + } + }) + + // ── Maintenance ── + t.Run("admin can create maintenance", func(t *testing.T) { + start := time.Now().Add(1 * time.Hour).UTC() + end := time.Now().Add(2 * time.Hour).UTC() + body, _ := json.Marshal(map[string]any{ + "title": "Admin Maintenance", + "description": "Admin scheduled", + "status": "scheduled", + "startTime": start.Format(time.RFC3339), + "endTime": end.Format(time.RFC3339), + "affectedGroups": []string{}, + }) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleAdmin)) + r.Post("/api/maintenance", maintH.CreateMaintenance) + + req := httptest.NewRequest("POST", "/api/maintenance", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected 201, got %d. Body: %s", w.Code, w.Body.String()) + } + }) + + // ── Notification Channels ── + t.Run("admin can create notification channel", func(t *testing.T) { + body, _ := json.Marshal(map[string]any{ + "type": "webhook", + "name": "Admin Webhook", + "config": map[string]any{"webhook_url": "https://example.com/admin-hook"}, + "enabled": true, + }) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleAdmin)) + r.Post("/api/notifications/channels", notifH.CreateChannel) + + req := httptest.NewRequest("POST", "/api/notifications/channels", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected 201, got %d. Body: %s", w.Code, w.Body.String()) + } + }) +} + +// ==================== Comprehensive permission matrix ==================== +// Table-driven test covering every role x endpoint combination at the admin/editor boundary. +// Uses bodyFn to generate unique names per role, avoiding name conflicts. + +func TestAdminEditorPermissionMatrix(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + + _ = store.CreateUser("admin_matrix", "pass", "UTC", "admin") + _ = store.CreateUser("target_matrix", "pass", "UTC", "viewer") + admin, _ := store.Authenticate("admin_matrix", "pass") + + crudH := NewCRUDHandler(store, manager) + settingsH := NewSettingsHandler(store, manager) + userH := NewUserHandler(store) + apiKeyH := NewAPIKeyHandler(store) + incidentH := NewIncidentHandler(store) + maintH := NewMaintenanceHandler(store, manager) + notifH := NewNotificationChannelsHandler(store) + + // Pre-create a group so monitors can reference it + { + body, _ := json.Marshal(map[string]string{"name": "Matrix Group"}) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleAdmin)) + r.Post("/api/groups", crudH.CreateGroup) + req := httptest.NewRequest("POST", "/api/groups", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + } + + // bodyFn generates a request body with role-specific names to avoid conflicts. + type bodyFn func(role string) any + + type testCase struct { + name string + method string + path string + handler http.HandlerFunc + bodyFn bodyFn + editorWant int + adminWant int + } + + start := time.Now().Add(1 * time.Hour).UTC() + end := time.Now().Add(2 * time.Hour).UTC() + + tests := []testCase{ + // ── Editor-level endpoints (editor=allow, admin=allow) ── + { + name: "create group", method: "POST", path: "/api/groups", + handler: crudH.CreateGroup, + bodyFn: func(role string) any { return map[string]string{"name": "Mx " + role + " Group"} }, + editorWant: http.StatusCreated, adminWant: http.StatusCreated, + }, + { + name: "create monitor", method: "POST", path: "/api/monitors", + handler: crudH.CreateMonitor, + bodyFn: func(role string) any { + return map[string]any{ + "name": "Mx " + role + " Monitor", "url": "https://example.com", + "groupId": "g-matrix-group", "interval": 60, + } + }, + editorWant: http.StatusCreated, adminWant: http.StatusCreated, + }, + { + name: "create incident", method: "POST", path: "/api/incidents", + handler: incidentH.CreateIncident, + bodyFn: func(_ string) any { + return map[string]any{ + "title": "Mx Incident", "description": "test", + "severity": "major", "status": "investigating", + "startTime": time.Now().UTC().Format(time.RFC3339), "affectedGroups": []string{}, "public": true, + } + }, + editorWant: http.StatusCreated, adminWant: http.StatusCreated, + }, + { + name: "create maintenance", method: "POST", path: "/api/maintenance", + handler: maintH.CreateMaintenance, + bodyFn: func(_ string) any { + return map[string]any{ + "title": "Mx Maintenance", "description": "test", + "status": "scheduled", "startTime": start.Format(time.RFC3339), + "endTime": end.Format(time.RFC3339), "affectedGroups": []string{}, + } + }, + editorWant: http.StatusCreated, adminWant: http.StatusCreated, + }, + { + name: "create notification channel", method: "POST", path: "/api/notifications/channels", + handler: notifH.CreateChannel, + bodyFn: func(role string) any { + return map[string]any{ + "type": "webhook", "name": "Mx " + role + " Webhook", + "config": map[string]any{"webhook_url": "https://example.com/hook"}, "enabled": true, + } + }, + editorWant: http.StatusCreated, adminWant: http.StatusCreated, + }, + + // ── Admin-only endpoints (editor=403, admin=allow) ── + { + name: "update settings", method: "PATCH", path: "/api/settings", + handler: settingsH.UpdateSettings, + bodyFn: func(_ string) any { return map[string]string{"data_retention_days": "90"} }, + editorWant: http.StatusForbidden, adminWant: http.StatusOK, + }, + { + name: "list users", method: "GET", path: "/api/users", + handler: userH.ListUsers, + bodyFn: nil, + editorWant: http.StatusForbidden, adminWant: http.StatusOK, + }, + { + name: "create user", method: "POST", path: "/api/users", + handler: userH.CreateUser, + bodyFn: func(role string) any { + return map[string]string{"username": "mx" + role, "password": "longpassword", "role": "viewer"} + }, + editorWant: http.StatusForbidden, adminWant: http.StatusCreated, + }, + { + name: "list api keys", method: "GET", path: "/api/api-keys", + handler: apiKeyH.ListKeys, + bodyFn: nil, + editorWant: http.StatusForbidden, adminWant: http.StatusOK, + }, + { + name: "create api key", method: "POST", path: "/api/api-keys", + handler: apiKeyH.CreateKey, + bodyFn: func(role string) any { + return map[string]string{"name": "mx-" + role + "-key"} + }, + editorWant: http.StatusForbidden, adminWant: http.StatusOK, + }, + } + + runTest := func(t *testing.T, tc testCase, role string, wantStatus int) { + t.Helper() + var bodyReader *bytes.Buffer + if tc.bodyFn != nil { + b, _ := json.Marshal(tc.bodyFn(role)) + bodyReader = bytes.NewBuffer(b) + } else { + bodyReader = bytes.NewBuffer(nil) + } + + r := chi.NewRouter() + r.Use(roleAndUserMiddleware(role, admin.ID)) + switch tc.method { + case "GET": + r.Get(tc.path, tc.handler) + case "POST": + r.Post(tc.path, tc.handler) + case "PATCH": + r.Patch(tc.path, tc.handler) + case "PUT": + r.Put(tc.path, tc.handler) + case "DELETE": + r.Delete(tc.path, tc.handler) + } + + req := httptest.NewRequest(tc.method, tc.path, bodyReader) + if tc.bodyFn != nil { + req.Header.Set("Content-Type", "application/json") + } + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != wantStatus { + t.Errorf("%s %s %s: expected %d, got %d. Body: %s", + role, tc.method, tc.path, wantStatus, w.Code, w.Body.String()) + } + } + + for _, tc := range tests { + t.Run(fmt.Sprintf("editor/%s", tc.name), func(t *testing.T) { + runTest(t, tc, RoleEditor, tc.editorWant) + }) + t.Run(fmt.Sprintf("admin/%s", tc.name), func(t *testing.T) { + runTest(t, tc, RoleAdmin, tc.adminWant) + }) + } +} + +// ==================== Viewer blocked from editor endpoints ==================== +// Ensures the editor level is a real boundary (viewer cannot do what editor can). + +func TestViewerCannotAccessEditorEndpoints(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + + crudH := NewCRUDHandler(store, manager) + incidentH := NewIncidentHandler(store) + maintH := NewMaintenanceHandler(store, manager) + notifH := NewNotificationChannelsHandler(store) + + start := time.Now().Add(1 * time.Hour).UTC() + end := time.Now().Add(2 * time.Hour).UTC() + + tests := []struct { + name string + method string + path string + handler http.HandlerFunc + body any + }{ + { + name: "create group", method: "POST", path: "/api/groups", + handler: crudH.CreateGroup, + body: map[string]string{"name": "Viewer Group"}, + }, + { + name: "create monitor", method: "POST", path: "/api/monitors", + handler: crudH.CreateMonitor, + body: map[string]any{ + "name": "Viewer Monitor", "url": "https://example.com", + "groupId": "g-default", "interval": 60, + }, + }, + { + name: "create incident", method: "POST", path: "/api/incidents", + handler: incidentH.CreateIncident, + body: map[string]any{ + "title": "Viewer Incident", "description": "test", + "severity": "major", "status": "investigating", + "startTime": time.Now().UTC().Format(time.RFC3339), "affectedGroups": []string{}, "public": true, + }, + }, + { + name: "create maintenance", method: "POST", path: "/api/maintenance", + handler: maintH.CreateMaintenance, + body: map[string]any{ + "title": "Viewer Maintenance", "description": "test", + "status": "scheduled", "startTime": start.Format(time.RFC3339), + "endTime": end.Format(time.RFC3339), "affectedGroups": []string{}, + }, + }, + { + name: "create notification channel", method: "POST", path: "/api/notifications/channels", + handler: notifH.CreateChannel, + body: map[string]any{ + "type": "webhook", "name": "Viewer Webhook", + "config": map[string]any{"webhook_url": "https://example.com/hook"}, "enabled": true, + }, + }, + } + + for _, tc := range tests { + t.Run("viewer_blocked_"+tc.name, func(t *testing.T) { + b, _ := json.Marshal(tc.body) + r := chi.NewRouter() + r.Use(roleMiddleware(RoleViewer)) + switch tc.method { + case "POST": + r.Post(tc.path, tc.handler) + case "PUT": + r.Put(tc.path, tc.handler) + case "PATCH": + r.Patch(tc.path, tc.handler) + case "DELETE": + r.Delete(tc.path, tc.handler) + } + + req := httptest.NewRequest(tc.method, tc.path, bytes.NewBuffer(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Viewer should be forbidden from %s: expected 403, got %d. Body: %s", + tc.name, w.Code, w.Body.String()) + } + }) + } +} diff --git a/internal/api/rbac_test.go b/internal/api/rbac_test.go new file mode 100644 index 0000000..b3d1bb5 --- /dev/null +++ b/internal/api/rbac_test.go @@ -0,0 +1,882 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/projecthelena/warden/internal/db" + "github.com/projecthelena/warden/internal/uptime" +) + +// roleMiddleware is a test middleware that injects a specific role into context. +func roleMiddleware(role string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), contextKeyUserRole, role) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// roleAndUserMiddleware injects both role and user ID into context. +func roleAndUserMiddleware(role string, userID int64) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), contextKeyUserRole, role) + ctx = context.WithValue(ctx, contextKeyUserID, userID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// ==================== requireRole tests ==================== + +func TestRequireRole_AdminCanAccessAdminEndpoint(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleAdmin) { + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleAdmin)) + r.Get("/test", handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } +} + +func TestRequireRole_EditorCanAccessEditorEndpoint(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } + w.WriteHeader(http.StatusOK) + }) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleEditor)) + r.Get("/test", handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } +} + +func TestRequireRole_AdminCanAccessEditorEndpoint(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } + w.WriteHeader(http.StatusOK) + }) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleAdmin)) + r.Get("/test", handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } +} + +func TestRequireRole_ViewerCannotAccessEditorEndpoint(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleEditor) { + return + } + w.WriteHeader(http.StatusOK) + }) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleViewer)) + r.Get("/test", handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403, got %d", w.Code) + } +} + +func TestRequireRole_StatusViewerCannotAccessViewerEndpoint(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleViewer) { + return + } + w.WriteHeader(http.StatusOK) + }) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleStatusViewer)) + r.Get("/test", handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403, got %d", w.Code) + } +} + +func TestRequireRole_EmptyRoleGetsForbidden(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleViewer) { + return + } + w.WriteHeader(http.StatusOK) + }) + + r := chi.NewRouter() + r.Use(roleMiddleware("")) + r.Get("/test", handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403, got %d", w.Code) + } +} + +func TestRequireRole_NoRoleInContext(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !requireRole(w, r, RoleViewer) { + return + } + w.WriteHeader(http.StatusOK) + }) + + // No middleware to inject role + r := chi.NewRouter() + r.Get("/test", handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403 with no role in context, got %d", w.Code) + } +} + +// ==================== hasPermission tests ==================== + +func TestHasPermission(t *testing.T) { + tests := []struct { + userRole string + minRole string + expected bool + description string + }{ + {RoleAdmin, RoleAdmin, true, "admin >= admin"}, + {RoleAdmin, RoleEditor, true, "admin >= editor"}, + {RoleAdmin, RoleViewer, true, "admin >= viewer"}, + {RoleAdmin, RoleStatusViewer, true, "admin >= status_viewer"}, + {RoleEditor, RoleEditor, true, "editor >= editor"}, + {RoleEditor, RoleViewer, true, "editor >= viewer"}, + {RoleEditor, RoleStatusViewer, true, "editor >= status_viewer"}, + {RoleEditor, RoleAdmin, false, "editor < admin"}, + {RoleViewer, RoleViewer, true, "viewer >= viewer"}, + {RoleViewer, RoleStatusViewer, true, "viewer >= status_viewer"}, + {RoleViewer, RoleEditor, false, "viewer < editor"}, + {RoleViewer, RoleAdmin, false, "viewer < admin"}, + {RoleStatusViewer, RoleStatusViewer, true, "status_viewer >= status_viewer"}, + {RoleStatusViewer, RoleViewer, false, "status_viewer < viewer"}, + {RoleStatusViewer, RoleEditor, false, "status_viewer < editor"}, + {RoleStatusViewer, RoleAdmin, false, "status_viewer < admin"}, + {"", RoleStatusViewer, false, "empty < status_viewer"}, + {"", RoleViewer, false, "empty < viewer"}, + {"invalid_role", RoleViewer, false, "invalid < viewer"}, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + result := hasPermission(tc.userRole, tc.minRole) + if result != tc.expected { + t.Errorf("hasPermission(%q, %q) = %v, want %v", tc.userRole, tc.minRole, result, tc.expected) + } + }) + } +} + +// ==================== ValidRole tests ==================== + +func TestValidRole(t *testing.T) { + validRoles := []string{"admin", "editor", "viewer", "status_viewer"} + for _, role := range validRoles { + t.Run(fmt.Sprintf("valid_%s", role), func(t *testing.T) { + if !ValidRole(role) { + t.Errorf("ValidRole(%q) = false, want true", role) + } + }) + } + + invalidRoles := []string{"", "superadmin", "Admin", "ADMIN", "root", "user", "moderator", "status-viewer"} + for _, role := range invalidRoles { + t.Run(fmt.Sprintf("invalid_%s", role), func(t *testing.T) { + if ValidRole(role) { + t.Errorf("ValidRole(%q) = true, want false", role) + } + }) + } +} + +// ==================== RequireViewerMiddleware tests ==================== + +func TestRequireViewerMiddleware_BlocksStatusViewer(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleStatusViewer)) + r.Use(RequireViewerMiddleware) + r.Get("/test", handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403 for status_viewer, got %d", w.Code) + } +} + +func TestRequireViewerMiddleware_AllowsViewer(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleViewer)) + r.Use(RequireViewerMiddleware) + r.Get("/test", handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200 for viewer, got %d", w.Code) + } +} + +func TestRequireViewerMiddleware_AllowsEditor(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleEditor)) + r.Use(RequireViewerMiddleware) + r.Get("/test", handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200 for editor, got %d", w.Code) + } +} + +func TestRequireViewerMiddleware_AllowsAdmin(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleAdmin)) + r.Use(RequireViewerMiddleware) + r.Get("/test", handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200 for admin, got %d", w.Code) + } +} + +// ==================== Handler role enforcement (integration-style) ==================== + +func TestViewerCanGetOverview(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + uptimeH := NewUptimeHandler(manager, store) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleViewer)) + r.Use(RequireViewerMiddleware) + r.Get("/api/overview", uptimeH.GetOverview) + + req := httptest.NewRequest("GET", "/api/overview", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected viewer to GET /api/overview (200), got %d. Body: %s", w.Code, w.Body.String()) + } +} + +func TestViewerCannotPostMonitors(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + crudH := NewCRUDHandler(store, manager) + + payload := map[string]interface{}{ + "name": "Test Monitor", + "url": "http://test.com", + "groupId": "g-default", + "interval": 60, + } + body, _ := json.Marshal(payload) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleViewer)) + r.Post("/api/monitors", crudH.CreateMonitor) + + req := httptest.NewRequest("POST", "/api/monitors", bytes.NewBuffer(body)) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected viewer to be forbidden from POST /api/monitors (403), got %d", w.Code) + } +} + +func TestEditorCanPostMonitors(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + crudH := NewCRUDHandler(store, manager) + + payload := map[string]interface{}{ + "name": "Editor Monitor", + "url": "http://editor-test.com", + "groupId": "g-default", + "interval": 60, + } + body, _ := json.Marshal(payload) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleEditor)) + r.Post("/api/monitors", crudH.CreateMonitor) + + req := httptest.NewRequest("POST", "/api/monitors", bytes.NewBuffer(body)) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected editor to POST /api/monitors (201), got %d. Body: %s", w.Code, w.Body.String()) + } +} + +func TestViewerCannotPatchSettings(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + settingsH := NewSettingsHandler(store, manager) + + payload := map[string]string{"data_retention_days": "90"} + body, _ := json.Marshal(payload) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleViewer)) + r.Patch("/api/settings", settingsH.UpdateSettings) + + req := httptest.NewRequest("PATCH", "/api/settings", bytes.NewBuffer(body)) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected viewer to be forbidden from PATCH /api/settings (403), got %d", w.Code) + } +} + +func TestEditorCannotPatchSettings(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + settingsH := NewSettingsHandler(store, manager) + + payload := map[string]string{"data_retention_days": "90"} + body, _ := json.Marshal(payload) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleEditor)) + r.Patch("/api/settings", settingsH.UpdateSettings) + + req := httptest.NewRequest("PATCH", "/api/settings", bytes.NewBuffer(body)) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected editor to be forbidden from PATCH /api/settings (403), got %d", w.Code) + } +} + +func TestAdminCanPatchSettings(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + manager := uptime.NewManager(store) + settingsH := NewSettingsHandler(store, manager) + + payload := map[string]string{"data_retention_days": "90"} + body, _ := json.Marshal(payload) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleAdmin)) + r.Patch("/api/settings", settingsH.UpdateSettings) + + req := httptest.NewRequest("PATCH", "/api/settings", bytes.NewBuffer(body)) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected admin to PATCH /api/settings (200), got %d. Body: %s", w.Code, w.Body.String()) + } +} + +// ==================== User management handler tests ==================== + +func TestListUsers_ReturnsUsersWithRoles(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + + // Create users with different roles + _ = store.CreateUser("admin1", "pass", "UTC", "admin") + _ = store.CreateUser("editor1", "pass", "UTC", "editor") + _ = store.CreateUser("viewer1", "pass", "UTC", "viewer") + + userH := NewUserHandler(store) + + r := chi.NewRouter() + r.Use(roleMiddleware(RoleAdmin)) + r.Get("/api/users", userH.ListUsers) + + req := httptest.NewRequest("GET", "/api/users", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d. Body: %s", w.Code, w.Body.String()) + } + + var resp struct { + Users []struct { + ID int64 `json:"id"` + Username string `json:"username"` + Role string `json:"role"` + } `json:"users"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if len(resp.Users) != 3 { + t.Fatalf("Expected 3 users, got %d", len(resp.Users)) + } + + roleMap := make(map[string]string) + for _, u := range resp.Users { + roleMap[u.Username] = u.Role + } + + if roleMap["admin1"] != "admin" { + t.Errorf("admin1: expected role 'admin', got %s", roleMap["admin1"]) + } + if roleMap["editor1"] != "editor" { + t.Errorf("editor1: expected role 'editor', got %s", roleMap["editor1"]) + } + if roleMap["viewer1"] != "viewer" { + t.Errorf("viewer1: expected role 'viewer', got %s", roleMap["viewer1"]) + } +} + +func TestListUsers_NonAdminForbidden(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + + userH := NewUserHandler(store) + + roles := []string{RoleEditor, RoleViewer, RoleStatusViewer} + for _, role := range roles { + t.Run(role, func(t *testing.T) { + r := chi.NewRouter() + r.Use(roleMiddleware(role)) + r.Get("/api/users", userH.ListUsers) + + req := httptest.NewRequest("GET", "/api/users", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403 for role %s, got %d", role, w.Code) + } + }) + } +} + +func TestUpdateUserRole_ChangesRole(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + + // Create admin (the requester, ID=1) and a target user + _ = store.CreateUser("admin_requester", "pass", "UTC", "admin") + _ = store.CreateUser("target_user", "pass", "UTC", "viewer") + + admin, _ := store.Authenticate("admin_requester", "pass") + target, _ := store.Authenticate("target_user", "pass") + + userH := NewUserHandler(store) + + payload := map[string]string{"role": "editor"} + body, _ := json.Marshal(payload) + + r := chi.NewRouter() + r.Use(roleAndUserMiddleware(RoleAdmin, admin.ID)) + r.Patch("/api/users/{id}/role", userH.UpdateUserRole) + + req := httptest.NewRequest("PATCH", fmt.Sprintf("/api/users/%d/role", target.ID), bytes.NewBuffer(body)) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d. Body: %s", w.Code, w.Body.String()) + } + + // Verify role changed + role, _ := store.GetUserRole(target.ID) + if role != "editor" { + t.Errorf("Expected role 'editor' after update, got %s", role) + } +} + +func TestUpdateUserRole_RejectsSelfChange(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + + _ = store.CreateUser("self_admin", "pass", "UTC", "admin") + admin, _ := store.Authenticate("self_admin", "pass") + + userH := NewUserHandler(store) + + payload := map[string]string{"role": "viewer"} + body, _ := json.Marshal(payload) + + r := chi.NewRouter() + r.Use(roleAndUserMiddleware(RoleAdmin, admin.ID)) + r.Patch("/api/users/{id}/role", userH.UpdateUserRole) + + // Try to change own role + req := httptest.NewRequest("PATCH", fmt.Sprintf("/api/users/%d/role", admin.ID), bytes.NewBuffer(body)) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected 400 for self-change, got %d. Body: %s", w.Code, w.Body.String()) + } + + var resp map[string]string + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp["error"] != "cannot change your own role" { + t.Errorf("Expected error 'cannot change your own role', got %s", resp["error"]) + } +} + +func TestUpdateUserRole_PreventsRemovingLastAdmin(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + + // Create the only admin and a separate requester admin + // We need two admins so the requester admin can try to demote the other one + // But if we only have one admin and try to demote them... let's test with two admins + // and then try to demote both to leave zero. + + _ = store.CreateUser("admin_a", "pass", "UTC", "admin") + _ = store.CreateUser("admin_b", "pass", "UTC", "admin") + adminA, _ := store.Authenticate("admin_a", "pass") + adminB, _ := store.Authenticate("admin_b", "pass") + + userH := NewUserHandler(store) + + // First, demote admin_b to editor (should succeed, admin_a remains) + payload := map[string]string{"role": "editor"} + body, _ := json.Marshal(payload) + + r := chi.NewRouter() + r.Use(roleAndUserMiddleware(RoleAdmin, adminA.ID)) + r.Patch("/api/users/{id}/role", userH.UpdateUserRole) + + req := httptest.NewRequest("PATCH", fmt.Sprintf("/api/users/%d/role", adminB.ID), bytes.NewBuffer(body)) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200 for demoting admin_b (admin_a remains), got %d. Body: %s", w.Code, w.Body.String()) + } + + // Now admin_a is the only admin. Create a new viewer to try to demote admin_a through them. + // Actually, we cannot demote admin_a since the self-check prevents it. + // Let's create a new admin_c and try to be the requester as admin_c demoting admin_a (the last admin). + // Wait - admin_a is the last admin, admin_b is now editor. + // We need a different admin to call the endpoint. Let's promote admin_b back temporarily. + + // Simpler approach: test with fresh data + store2, _ := db.NewStore(db.NewTestConfig()) + _ = store2.CreateUser("solo_admin", "pass", "UTC", "admin") + _ = store2.CreateUser("other_user", "pass", "UTC", "editor") + soloAdmin, _ := store2.Authenticate("solo_admin", "pass") + other, _ := store2.Authenticate("other_user", "pass") + + // Promote other to admin first so we have two admins, then demote solo_admin + _ = store2.UpdateUserRole(other.ID, "admin") + + userH2 := NewUserHandler(store2) + + // Now demote other (leaving solo_admin as only admin) - should succeed + payload2 := map[string]string{"role": "viewer"} + body2, _ := json.Marshal(payload2) + + r2 := chi.NewRouter() + r2.Use(roleAndUserMiddleware(RoleAdmin, soloAdmin.ID)) + r2.Patch("/api/users/{id}/role", userH2.UpdateUserRole) + + req2 := httptest.NewRequest("PATCH", fmt.Sprintf("/api/users/%d/role", other.ID), bytes.NewBuffer(body2)) + w2 := httptest.NewRecorder() + r2.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", w2.Code) + } + + // Now solo_admin is the ONLY admin. Try to have someone else demote them. + // We need a requester who is admin - but solo_admin is the only one and can't demote self. + // The practical test: create a scenario where the endpoint is called for the last admin. + // We'll use other_user as a "hacked" admin context to try to demote solo_admin. + r3 := chi.NewRouter() + r3.Use(roleAndUserMiddleware(RoleAdmin, other.ID)) // pretend other is admin in context + r3.Patch("/api/users/{id}/role", userH2.UpdateUserRole) + + payload3 := map[string]string{"role": "editor"} + body3, _ := json.Marshal(payload3) + req3 := httptest.NewRequest("PATCH", fmt.Sprintf("/api/users/%d/role", soloAdmin.ID), bytes.NewBuffer(body3)) + w3 := httptest.NewRecorder() + r3.ServeHTTP(w3, req3) + + if w3.Code != http.StatusBadRequest { + t.Errorf("Expected 400 when removing last admin, got %d. Body: %s", w3.Code, w3.Body.String()) + } + + var resp map[string]string + _ = json.Unmarshal(w3.Body.Bytes(), &resp) + if resp["error"] != "cannot remove the last admin" { + t.Errorf("Expected error 'cannot remove the last admin', got %s", resp["error"]) + } +} + +func TestUpdateUserRole_InvalidRole(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + + _ = store.CreateUser("admin_user", "pass", "UTC", "admin") + _ = store.CreateUser("target_user", "pass", "UTC", "viewer") + admin, _ := store.Authenticate("admin_user", "pass") + target, _ := store.Authenticate("target_user", "pass") + + userH := NewUserHandler(store) + + payload := map[string]string{"role": "superadmin"} + body, _ := json.Marshal(payload) + + r := chi.NewRouter() + r.Use(roleAndUserMiddleware(RoleAdmin, admin.ID)) + r.Patch("/api/users/{id}/role", userH.UpdateUserRole) + + req := httptest.NewRequest("PATCH", fmt.Sprintf("/api/users/%d/role", target.ID), bytes.NewBuffer(body)) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected 400 for invalid role, got %d. Body: %s", w.Code, w.Body.String()) + } +} + +func TestDeleteUser_Works(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + + _ = store.CreateUser("admin_user", "pass", "UTC", "admin") + _ = store.CreateUser("to_delete", "pass", "UTC", "editor") + admin, _ := store.Authenticate("admin_user", "pass") + target, _ := store.Authenticate("to_delete", "pass") + + userH := NewUserHandler(store) + + r := chi.NewRouter() + r.Use(roleAndUserMiddleware(RoleAdmin, admin.ID)) + r.Delete("/api/users/{id}", userH.DeleteUser) + + req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/users/%d", target.ID), nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d. Body: %s", w.Code, w.Body.String()) + } + + // Verify user is gone + _, err = store.GetUser(target.ID) + if err == nil { + t.Error("Expected error getting deleted user") + } +} + +func TestDeleteUser_RejectsSelfDeletion(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + + _ = store.CreateUser("self_delete", "pass", "UTC", "admin") + admin, _ := store.Authenticate("self_delete", "pass") + + userH := NewUserHandler(store) + + r := chi.NewRouter() + r.Use(roleAndUserMiddleware(RoleAdmin, admin.ID)) + r.Delete("/api/users/{id}", userH.DeleteUser) + + req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/users/%d", admin.ID), nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected 400 for self-deletion, got %d. Body: %s", w.Code, w.Body.String()) + } + + var resp map[string]string + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp["error"] != "cannot delete yourself" { + t.Errorf("Expected error 'cannot delete yourself', got %s", resp["error"]) + } +} + +func TestDeleteUser_PreventsDeleteLastAdmin(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + + // Create only one admin and an editor + _ = store.CreateUser("only_admin", "pass", "UTC", "admin") + _ = store.CreateUser("requester", "pass", "UTC", "editor") + onlyAdmin, _ := store.Authenticate("only_admin", "pass") + requester, _ := store.Authenticate("requester", "pass") + + userH := NewUserHandler(store) + + r := chi.NewRouter() + // Use requester context with admin role (simulating a scenario where this check matters) + r.Use(roleAndUserMiddleware(RoleAdmin, requester.ID)) + r.Delete("/api/users/{id}", userH.DeleteUser) + + req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/users/%d", onlyAdmin.ID), nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected 400 when deleting last admin, got %d. Body: %s", w.Code, w.Body.String()) + } + + var resp map[string]string + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp["error"] != "cannot delete the last admin" { + t.Errorf("Expected error 'cannot delete the last admin', got %s", resp["error"]) + } +} + +func TestDeleteUser_NonAdminForbidden(t *testing.T) { + store, err := db.NewStore(db.NewTestConfig()) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + + _ = store.CreateUser("target", "pass", "UTC", "viewer") + target, _ := store.Authenticate("target", "pass") + + userH := NewUserHandler(store) + + roles := []string{RoleEditor, RoleViewer} + for _, role := range roles { + t.Run(role, func(t *testing.T) { + r := chi.NewRouter() + r.Use(roleAndUserMiddleware(role, int64(9999))) + r.Delete("/api/users/{id}", userH.DeleteUser) + + req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/users/%d", target.ID), nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403 for role %s, got %d", role, w.Code) + } + }) + } +} diff --git a/internal/api/router.go b/internal/api/router.go index bd78721..b47c9b0 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -114,6 +114,7 @@ func NewRouter(manager *uptime.Manager, store *db.Store, cfg *config.Config) htt eventH := NewEventHandler(store, manager) statusPageH := NewStatusPageHandler(store, manager, authH) notifH := NewNotificationChannelsHandler(store) + userH := NewUserHandler(store) // Kubernetes health probes (unauthenticated, no rate limiting) r.Get("/healthz", Healthz) @@ -153,98 +154,90 @@ func NewRouter(manager *uptime.Manager, store *db.Store, cfg *config.Config) htt api.Group(func(protected chi.Router) { protected.Use(authH.AuthMiddleware) + + // Auth endpoints - accessible by ALL authenticated users (including status_viewer) protected.Get("/auth/me", authH.Me) protected.Patch("/auth/me", authH.UpdateUser) + protected.Get("/my/status-pages", statusPageH.GetMyStatusPages) + + // All other endpoints require at least viewer role (blocks status_viewer) + protected.Group(func(dashboard chi.Router) { + dashboard.Use(RequireViewerMiddleware) // Dashboard Overview - protected.Get("/overview", uptimeH.GetOverview) + dashboard.Get("/overview", uptimeH.GetOverview) // Groups - protected.Post("/groups", crudH.CreateGroup) - protected.Put("/groups/{id}", crudH.UpdateGroup) - protected.Delete("/groups/{id}", crudH.DeleteGroup) + dashboard.Post("/groups", crudH.CreateGroup) + dashboard.Put("/groups/{id}", crudH.UpdateGroup) + dashboard.Delete("/groups/{id}", crudH.DeleteGroup) // Monitors - // /uptime maps to GetHistory in handlers_uptime.go (returns list of monitors with history) - protected.Get("/uptime", uptimeH.GetHistory) - protected.Post("/monitors", crudH.CreateMonitor) - protected.Put("/monitors/{id}", crudH.UpdateMonitor) - protected.Delete("/monitors/{id}", crudH.DeleteMonitor) - protected.Post("/monitors/{id}/pause", crudH.PauseMonitor) - protected.Post("/monitors/{id}/resume", crudH.ResumeMonitor) - protected.Get("/monitors/{id}/uptime", uptimeH.GetMonitorUptime) - protected.Get("/monitors/{id}/latency", uptimeH.GetMonitorLatency) + dashboard.Get("/uptime", uptimeH.GetHistory) + dashboard.Post("/monitors", crudH.CreateMonitor) + dashboard.Put("/monitors/{id}", crudH.UpdateMonitor) + dashboard.Delete("/monitors/{id}", crudH.DeleteMonitor) + dashboard.Post("/monitors/{id}/pause", crudH.PauseMonitor) + dashboard.Post("/monitors/{id}/resume", crudH.ResumeMonitor) + dashboard.Get("/monitors/{id}/uptime", uptimeH.GetMonitorUptime) + dashboard.Get("/monitors/{id}/latency", uptimeH.GetMonitorLatency) // Incidents - protected.Get("/incidents", incidentH.GetIncidents) - protected.Post("/incidents", incidentH.CreateIncident) - protected.Get("/incidents/{id}", incidentH.GetIncident) - protected.Put("/incidents/{id}", incidentH.UpdateIncident) - protected.Delete("/incidents/{id}", incidentH.DeleteIncident) - protected.Patch("/incidents/{id}/visibility", incidentH.SetVisibility) - protected.Get("/incidents/{id}/updates", incidentH.GetUpdates) - protected.Post("/incidents/{id}/updates", incidentH.AddUpdate) + dashboard.Get("/incidents", incidentH.GetIncidents) + dashboard.Post("/incidents", incidentH.CreateIncident) + dashboard.Get("/incidents/{id}", incidentH.GetIncident) + dashboard.Put("/incidents/{id}", incidentH.UpdateIncident) + dashboard.Delete("/incidents/{id}", incidentH.DeleteIncident) + dashboard.Patch("/incidents/{id}/visibility", incidentH.SetVisibility) + dashboard.Get("/incidents/{id}/updates", incidentH.GetUpdates) + dashboard.Post("/incidents/{id}/updates", incidentH.AddUpdate) // Outages (promote to incident) - protected.Post("/outages/{id}/promote", incidentH.PromoteOutage) + dashboard.Post("/outages/{id}/promote", incidentH.PromoteOutage) // Maintenance - protected.Post("/maintenance", maintH.CreateMaintenance) - protected.Get("/maintenance", maintH.GetMaintenance) - protected.Put("/maintenance/{id}", maintH.UpdateMaintenance) - protected.Delete("/maintenance/{id}", maintH.DeleteMaintenance) + dashboard.Post("/maintenance", maintH.CreateMaintenance) + dashboard.Get("/maintenance", maintH.GetMaintenance) + dashboard.Put("/maintenance/{id}", maintH.UpdateMaintenance) + dashboard.Delete("/maintenance/{id}", maintH.DeleteMaintenance) // Settings - protected.Get("/settings", settingsH.GetSettings) - protected.Patch("/settings", settingsH.UpdateSettings) + dashboard.Get("/settings", settingsH.GetSettings) + dashboard.Patch("/settings", settingsH.UpdateSettings) // SSO Settings (admin only) - protected.Post("/settings/sso/test", ssoH.TestSSOConfig) + dashboard.Post("/settings/sso/test", ssoH.TestSSOConfig) // API Keys - protected.Get("/api-keys", apiKeyH.ListKeys) - protected.Post("/api-keys", apiKeyH.CreateKey) - protected.Delete("/api-keys/{id}", apiKeyH.DeleteKey) + dashboard.Get("/api-keys", apiKeyH.ListKeys) + dashboard.Post("/api-keys", apiKeyH.CreateKey) + dashboard.Delete("/api-keys/{id}", apiKeyH.DeleteKey) // Stats - protected.Get("/stats", statsH.GetStats) + dashboard.Get("/stats", statsH.GetStats) // Notifications - protected.Get("/notifications/channels", notifH.GetChannels) - protected.Post("/notifications/channels", notifH.CreateChannel) - protected.Post("/notifications/channels/test", notifH.TestChannel) - protected.Put("/notifications/channels/{id}", notifH.UpdateChannel) - protected.Delete("/notifications/channels/{id}", notifH.DeleteChannel) + dashboard.Get("/notifications/channels", notifH.GetChannels) + dashboard.Post("/notifications/channels", notifH.CreateChannel) + dashboard.Post("/notifications/channels/test", notifH.TestChannel) + dashboard.Put("/notifications/channels/{id}", notifH.UpdateChannel) + dashboard.Delete("/notifications/channels/{id}", notifH.DeleteChannel) // Events (for history) - protected.Get("/events", eventH.GetSystemEvents) + dashboard.Get("/events", eventH.GetSystemEvents) + + // Users (admin only) + dashboard.Post("/users", userH.CreateUser) + dashboard.Get("/users", userH.ListUsers) + dashboard.Patch("/users/{id}/role", userH.UpdateUserRole) + dashboard.Delete("/users/{id}", userH.DeleteUser) + dashboard.Get("/users/{id}/status-pages", userH.GetUserStatusPages) + dashboard.Put("/users/{id}/status-pages", userH.SetUserStatusPages) // Status Pages Management - protected.Get("/status-pages", statusPageH.GetAll) - // Note: Create/Upd/Del methods need to be verified in handlers_status_pages.go - // Based on GetAll, it likely has Toggle. - // Let's assume standard names or check Step 1189 view. - // Step 1189 shows: GetAll, Toggle, GetPublicStatus. - // It does NOT show CreateStatusPage, UpdateStatusPage, DeleteStatusPage explicitly in the view - // (view truncated? No, showed 1-284 which seemed to be whole file?). - // Wait, Step 1189 showed lines 1-284 for handlers_status_pages.go. - // It handled GetAll and Toggle and GetPublicStatus. - // There is NO Create/Delete? - // The store has UpsertStatusPage used in Toggle. - // Maybe there is no Create? Just Toggle? - // The routes in Step 1146 (original) were: - // protected.Post("/status-pages", apiRouter.CreateStatusPage) - // protected.Patch("/status-pages/{slug}", apiRouter.UpdateStatusPage) - // protected.Delete("/status-pages/{slug}", apiRouter.DeleteStatusPage) - - // If handlers_status_pages.go only has Toggle, then "UpdateStatusPage" mapping to Toggle is correct. - // What about Create/Delete? - // Maybe they were missing or I missed them in search? - // If they are missing, I should ommit or fix. - // Toggle does Upsert. So maybe Post -> Toggle? - protected.Patch("/status-pages/{slug}", statusPageH.Toggle) - - // If Create/Delete are missing, I'll comment them out for now to avoid compilation error. + dashboard.Get("/status-pages", statusPageH.GetAll) + dashboard.Patch("/status-pages/{slug}", statusPageH.Toggle) + }) }) }) diff --git a/internal/config/config.go b/internal/config/config.go index 2b6d7b4..3635f81 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -26,7 +26,7 @@ func Default() Config { ListenAddr: ":9096", DBType: DBTypeSQLite, DBPath: "warden.db", - CookieSecure: false, + CookieSecure: true, } } @@ -58,8 +58,8 @@ func Load() (*Config, error) { } } - if os.Getenv("COOKIE_SECURE") == "true" { - cfg.CookieSecure = true + if cs := os.Getenv("COOKIE_SECURE"); cs != "" { + cfg.CookieSecure = strings.EqualFold(cs, "true") } if secret := os.Getenv("ADMIN_SECRET"); secret != "" { diff --git a/internal/db/migrations/postgres/00018_rbac_roles.sql b/internal/db/migrations/postgres/00018_rbac_roles.sql new file mode 100644 index 0000000..d7a72b4 --- /dev/null +++ b/internal/db/migrations/postgres/00018_rbac_roles.sql @@ -0,0 +1,15 @@ +-- +goose Up +ALTER TABLE users ADD COLUMN role TEXT NOT NULL DEFAULT 'admin'; +ALTER TABLE api_keys ADD COLUMN role TEXT NOT NULL DEFAULT 'editor'; +CREATE TABLE user_status_pages ( + user_id INTEGER NOT NULL, + status_page_id INTEGER NOT NULL, + PRIMARY KEY (user_id, status_page_id), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + FOREIGN KEY (status_page_id) REFERENCES status_pages(id) ON DELETE CASCADE +); + +-- +goose Down +ALTER TABLE users DROP COLUMN IF EXISTS role; +ALTER TABLE api_keys DROP COLUMN IF EXISTS role; +DROP TABLE IF EXISTS user_status_pages; diff --git a/internal/db/migrations/sqlite/00018_rbac_roles.sql b/internal/db/migrations/sqlite/00018_rbac_roles.sql new file mode 100644 index 0000000..f707040 --- /dev/null +++ b/internal/db/migrations/sqlite/00018_rbac_roles.sql @@ -0,0 +1,13 @@ +-- +goose Up +ALTER TABLE users ADD COLUMN role TEXT NOT NULL DEFAULT 'admin'; +ALTER TABLE api_keys ADD COLUMN role TEXT NOT NULL DEFAULT 'editor'; +CREATE TABLE user_status_pages ( + user_id INTEGER NOT NULL, + status_page_id INTEGER NOT NULL, + PRIMARY KEY (user_id, status_page_id), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + FOREIGN KEY (status_page_id) REFERENCES status_pages(id) ON DELETE CASCADE +); + +-- +goose Down +DROP TABLE IF EXISTS user_status_pages; diff --git a/internal/db/store.go b/internal/db/store.go index 0eb5c1a..14965ca 100644 --- a/internal/db/store.go +++ b/internal/db/store.go @@ -198,6 +198,7 @@ func (s *Store) seed() error { // allowedResetTables is a whitelist of table names that can be dropped during reset. // SECURITY: This prevents potential SQL injection if table names were ever derived from user input. var allowedResetTables = map[string]bool{ + "user_status_pages": true, "users": true, "sessions": true, "groups": true, @@ -228,6 +229,7 @@ func (s *Store) Reset() error { } tables := []string{ + "user_status_pages", "users", "sessions", "groups", "monitors", "monitor_checks", "monitor_events", "status_pages", "api_keys", "settings", "monitor_outages", "notification_channels", "incidents", diff --git a/internal/db/store_api_keys.go b/internal/db/store_api_keys.go index c18a203..51b2791 100644 --- a/internal/db/store_api_keys.go +++ b/internal/db/store_api_keys.go @@ -13,11 +13,15 @@ type APIKey struct { ID int64 `json:"id"` KeyPrefix string `json:"keyPrefix"` Name string `json:"name"` + Role string `json:"role"` CreatedAt time.Time `json:"createdAt"` LastUsed *time.Time `json:"lastUsed,omitempty"` } -func (s *Store) CreateAPIKey(name string) (string, error) { +func (s *Store) CreateAPIKey(name, role string) (string, error) { + if role == "" { + role = "editor" + } // Generate random key with 256-bit entropy (32 bytes) // SECURITY: 256 bits provides adequate security strength for long-lived credentials keyBytes := make([]byte, 32) @@ -33,8 +37,8 @@ func (s *Store) CreateAPIKey(name string) (string, error) { return "", err } - _, err = s.db.Exec(s.rebind("INSERT INTO api_keys (key_prefix, key_hash, name) VALUES (?, ?, ?)"), - prefix, string(hash), name) + _, err = s.db.Exec(s.rebind("INSERT INTO api_keys (key_prefix, key_hash, name, role) VALUES (?, ?, ?, ?)"), + prefix, string(hash), name, role) if err != nil { return "", err } @@ -43,7 +47,7 @@ func (s *Store) CreateAPIKey(name string) (string, error) { } func (s *Store) ListAPIKeys() ([]APIKey, error) { - rows, err := s.db.Query("SELECT id, key_prefix, name, created_at, last_used_at FROM api_keys ORDER BY created_at DESC") + rows, err := s.db.Query("SELECT id, key_prefix, name, COALESCE(role, 'editor'), created_at, last_used_at FROM api_keys ORDER BY created_at DESC") if err != nil { return nil, err } @@ -53,7 +57,7 @@ func (s *Store) ListAPIKeys() ([]APIKey, error) { for rows.Next() { var k APIKey var lastUsed sql.NullTime - if err := rows.Scan(&k.ID, &k.KeyPrefix, &k.Name, &k.CreatedAt, &lastUsed); err != nil { + if err := rows.Scan(&k.ID, &k.KeyPrefix, &k.Name, &k.Role, &k.CreatedAt, &lastUsed); err != nil { return nil, err } if lastUsed.Valid { @@ -69,36 +73,35 @@ func (s *Store) DeleteAPIKey(id int64) error { return err } -func (s *Store) ValidateAPIKey(key string) (bool, error) { +// ValidateAPIKey validates the key and returns (valid, role, error). +func (s *Store) ValidateAPIKey(key string) (bool, string, error) { if len(key) < 12 { - return false, nil + return false, "", nil } prefix := key[:12] // Find candidates by prefix - rows, err := s.db.Query(s.rebind("SELECT id, key_hash FROM api_keys WHERE key_prefix = ?"), prefix) + rows, err := s.db.Query(s.rebind("SELECT id, key_hash, COALESCE(role, 'editor') FROM api_keys WHERE key_prefix = ?"), prefix) if err != nil { - return false, err + return false, "", err } defer func() { _ = rows.Close() }() for rows.Next() { var id int64 - var hash string - if err := rows.Scan(&id, &hash); err != nil { + var hash, role string + if err := rows.Scan(&id, &hash, &role); err != nil { continue } if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(key)); err == nil { // update last used async go func(keyId int64) { - // Create a new generic db execution context or ignore error - // Since we are inside store method, s.db is safe to use concurrently? sql.DB is threadsafe. _, _ = s.db.Exec(s.rebind("UPDATE api_keys SET last_used_at = CURRENT_TIMESTAMP WHERE id = ?"), keyId) }(id) - return true, nil + return true, role, nil } } - return false, nil + return false, "", nil } diff --git a/internal/db/store_api_keys_test.go b/internal/db/store_api_keys_test.go index 4ca7df3..8be234c 100644 --- a/internal/db/store_api_keys_test.go +++ b/internal/db/store_api_keys_test.go @@ -8,7 +8,7 @@ func TestAPIKeys(t *testing.T) { s := newTestStore(t) // Create - key, err := s.CreateAPIKey("Test Key") + key, err := s.CreateAPIKey("Test Key", "editor") if err != nil { t.Fatalf("CreateAPIKey failed: %v", err) } @@ -17,7 +17,7 @@ func TestAPIKeys(t *testing.T) { } // Validate Access - valid, err := s.ValidateAPIKey(key) + valid, _, err := s.ValidateAPIKey(key) if err != nil { t.Fatalf("ValidateAPIKey failed: %v", err) } @@ -26,7 +26,7 @@ func TestAPIKeys(t *testing.T) { } // Validate Fail - valid, _ = s.ValidateAPIKey("sk_live_WRONG") + valid, _, _ = s.ValidateAPIKey("sk_live_WRONG") if valid { t.Error("Expected invalid key to be rejected") } @@ -46,7 +46,7 @@ func TestAPIKeys(t *testing.T) { } // Verify Gone - valid, _ = s.ValidateAPIKey(key) + valid, _, _ = s.ValidateAPIKey(key) if valid { t.Error("Key should be invalid after deletion") } diff --git a/internal/db/store_multidb_test.go b/internal/db/store_multidb_test.go index d6ced58..4d80063 100644 --- a/internal/db/store_multidb_test.go +++ b/internal/db/store_multidb_test.go @@ -104,7 +104,7 @@ func TestMultiDB_MonitorCRUD(t *testing.T) { func TestMultiDB_UserCRUD(t *testing.T) { RunTestWithBothDBs(t, "UserCRUD", func(t *testing.T, s *Store) { // Create user - if err := s.CreateUser("testuser", "password123", "UTC"); err != nil { + if err := s.CreateUser("testuser", "password123", "UTC", "admin"); err != nil { t.Fatalf("CreateUser failed: %v", err) } @@ -141,7 +141,7 @@ func TestMultiDB_UserCRUD(t *testing.T) { func TestMultiDB_Sessions(t *testing.T) { RunTestWithBothDBs(t, "Sessions", func(t *testing.T, s *Store) { // Create user first - if err := s.CreateUser("sessionuser", "password123", "UTC"); err != nil { + if err := s.CreateUser("sessionuser", "password123", "UTC", "admin"); err != nil { t.Fatalf("CreateUser failed: %v", err) } user, _ := s.Authenticate("sessionuser", "password123") @@ -271,7 +271,7 @@ func TestMultiDB_Incidents(t *testing.T) { func TestMultiDB_APIKeys(t *testing.T) { RunTestWithBothDBs(t, "APIKeys", func(t *testing.T, s *Store) { // Create API key - key, err := s.CreateAPIKey("Test Key") + key, err := s.CreateAPIKey("Test Key", "editor") if err != nil { t.Fatalf("CreateAPIKey failed: %v", err) } @@ -280,7 +280,7 @@ func TestMultiDB_APIKeys(t *testing.T) { } // Validate key - valid, err := s.ValidateAPIKey(key) + valid, _, err := s.ValidateAPIKey(key) if err != nil { t.Fatalf("ValidateAPIKey failed: %v", err) } @@ -289,7 +289,7 @@ func TestMultiDB_APIKeys(t *testing.T) { } // Invalid key should fail - valid, _ = s.ValidateAPIKey("sk_live_INVALID") + valid, _, _ = s.ValidateAPIKey("sk_live_INVALID") if valid { t.Error("Expected invalid key to be rejected") } diff --git a/internal/db/store_status_pages.go b/internal/db/store_status_pages.go index 64ab5fe..89b1935 100644 --- a/internal/db/store_status_pages.go +++ b/internal/db/store_status_pages.go @@ -2,6 +2,7 @@ package db import ( "database/sql" + "strings" "time" ) @@ -170,6 +171,50 @@ func (s *Store) UpsertStatusPageFull(input StatusPageInput) error { return err } +// GetStatusPagesByIDs returns status pages matching the given IDs. +func (s *Store) GetStatusPagesByIDs(ids []int64) ([]StatusPage, error) { + if len(ids) == 0 { + return nil, nil + } + + // Build placeholders + placeholders := make([]string, len(ids)) + args := make([]interface{}, len(ids)) + for i, id := range ids { + placeholders[i] = "?" + args[i] = id + } + query := `SELECT id, slug, title, group_id, public, enabled, created_at, + COALESCE(description, ''), COALESCE(logo_url, ''), COALESCE(favicon_url, ''), COALESCE(accent_color, ''), COALESCE(theme, 'system'), + COALESCE(show_uptime_bars, TRUE), COALESCE(show_uptime_percentage, TRUE), COALESCE(show_incident_history, TRUE), + COALESCE(uptime_days_range, 90), COALESCE(header_content, 'logo-title'), COALESCE(header_alignment, 'center'), COALESCE(header_arrangement, 'inline') + FROM status_pages WHERE id IN (` + strings.Join(placeholders, ",") + `)` + + rows, err := s.db.Query(s.rebind(query), args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var pages []StatusPage + for rows.Next() { + var p StatusPage + var groupID sql.NullString + if err := rows.Scan(&p.ID, &p.Slug, &p.Title, &groupID, &p.Public, &p.Enabled, &p.CreatedAt, + &p.Description, &p.LogoURL, &p.FaviconURL, &p.AccentColor, &p.Theme, + &p.ShowUptimeBars, &p.ShowUptimePercentage, &p.ShowIncidentHistory, &p.UptimeDaysRange, + &p.HeaderContent, &p.HeaderAlignment, &p.HeaderArrangement); err != nil { + return nil, err + } + if groupID.Valid { + s := groupID.String + p.GroupID = &s + } + pages = append(pages, p) + } + return pages, nil +} + // ToggleStatusPage toggles the public status func (s *Store) ToggleStatusPage(slug string, public bool) error { _, err := s.db.Exec(s.rebind("UPDATE status_pages SET public = ? WHERE slug = ?"), public, slug) diff --git a/internal/db/store_users.go b/internal/db/store_users.go index 0ead2fa..996e317 100644 --- a/internal/db/store_users.go +++ b/internal/db/store_users.go @@ -28,6 +28,7 @@ type User struct { SSOID string AvatarURL string DisplayName string + Role string } type Session struct { @@ -40,8 +41,8 @@ func (s *Store) Authenticate(username, password string) (*User, error) { // username = strings.ToLower(strings.TrimSpace(username)) // REMOVED for Strict Mode username = strings.TrimSpace(username) // Only trim valid white space var u User - row := s.db.QueryRow(s.rebind("SELECT id, username, password_hash, created_at, COALESCE(timezone, 'UTC') FROM users WHERE username = ?"), username) - err := row.Scan(&u.ID, &u.Username, &u.Password, &u.CreatedAt, &u.Timezone) + row := s.db.QueryRow(s.rebind("SELECT id, username, password_hash, created_at, COALESCE(timezone, 'UTC'), COALESCE(role, 'admin') FROM users WHERE username = ?"), username) + err := row.Scan(&u.ID, &u.Username, &u.Password, &u.CreatedAt, &u.Timezone, &u.Role) if err == sql.ErrNoRows { return nil, ErrUserNotFound } @@ -77,8 +78,8 @@ func (s *Store) GetSession(token string) (*Session, error) { func (s *Store) GetUser(id int64) (*User, error) { var u User var email, ssoProvider, ssoID, avatarURL, displayName sql.NullString - row := s.db.QueryRow(s.rebind("SELECT id, username, created_at, COALESCE(timezone, 'UTC'), email, sso_provider, sso_id, avatar_url, display_name FROM users WHERE id = ?"), id) - err := row.Scan(&u.ID, &u.Username, &u.CreatedAt, &u.Timezone, &email, &ssoProvider, &ssoID, &avatarURL, &displayName) + row := s.db.QueryRow(s.rebind("SELECT id, username, created_at, COALESCE(timezone, 'UTC'), email, sso_provider, sso_id, avatar_url, display_name, COALESCE(role, 'admin') FROM users WHERE id = ?"), id) + err := row.Scan(&u.ID, &u.Username, &u.CreatedAt, &u.Timezone, &email, &ssoProvider, &ssoID, &avatarURL, &displayName, &u.Role) if err != nil { return nil, err } @@ -115,15 +116,17 @@ func (s *Store) IsSetupComplete() (bool, error) { return isComplete, err } -// CreateUser creates a new user. -func (s *Store) CreateUser(username, password, timezone string) error { +// CreateUser creates a new user with the specified role. +func (s *Store) CreateUser(username, password, timezone, role string) error { username = strings.ToLower(strings.TrimSpace(username)) + if role == "" { + role = "admin" + } hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return err } - // Using context if we wanted to enforce timeouts, but standard Exec is fine for now - _, err = s.db.Exec(s.rebind("INSERT INTO users (username, password_hash, timezone) VALUES (?, ?, ?)"), username, string(hash), timezone) + _, err = s.db.Exec(s.rebind("INSERT INTO users (username, password_hash, timezone, role) VALUES (?, ?, ?, ?)"), username, string(hash), timezone, role) return err } @@ -175,8 +178,8 @@ func (s *Store) DeleteUserSessions(userID int64, exceptToken string) error { func (s *Store) GetUserByEmail(email string) (*User, error) { var u User var emailVal, ssoProvider, ssoID, avatarURL, displayName sql.NullString - row := s.db.QueryRow(s.rebind("SELECT id, username, created_at, COALESCE(timezone, 'UTC'), email, sso_provider, sso_id, avatar_url, display_name FROM users WHERE email = ?"), email) - err := row.Scan(&u.ID, &u.Username, &u.CreatedAt, &u.Timezone, &emailVal, &ssoProvider, &ssoID, &avatarURL, &displayName) + row := s.db.QueryRow(s.rebind("SELECT id, username, created_at, COALESCE(timezone, 'UTC'), email, sso_provider, sso_id, avatar_url, display_name, COALESCE(role, 'admin') FROM users WHERE email = ?"), email) + err := row.Scan(&u.ID, &u.Username, &u.CreatedAt, &u.Timezone, &emailVal, &ssoProvider, &ssoID, &avatarURL, &displayName, &u.Role) if err == sql.ErrNoRows { return nil, ErrUserNotFound } @@ -208,8 +211,8 @@ func (s *Store) FindOrCreateSSOUser(provider, ssoID, email, name, avatarURL stri // First, try to find by SSO provider and ID var u User var emailVal, ssoProvider, ssoIDVal, avatarVal, displayNameVal sql.NullString - row := tx.QueryRow(s.rebind("SELECT id, username, created_at, COALESCE(timezone, 'UTC'), email, sso_provider, sso_id, avatar_url, display_name FROM users WHERE sso_provider = ? AND sso_id = ?"), provider, ssoID) - err = row.Scan(&u.ID, &u.Username, &u.CreatedAt, &u.Timezone, &emailVal, &ssoProvider, &ssoIDVal, &avatarVal, &displayNameVal) + row := tx.QueryRow(s.rebind("SELECT id, username, created_at, COALESCE(timezone, 'UTC'), email, sso_provider, sso_id, avatar_url, display_name, COALESCE(role, 'admin') FROM users WHERE sso_provider = ? AND sso_id = ?"), provider, ssoID) + err = row.Scan(&u.ID, &u.Username, &u.CreatedAt, &u.Timezone, &emailVal, &ssoProvider, &ssoIDVal, &avatarVal, &displayNameVal, &u.Role) if err == nil { // Found existing SSO user - update avatar and display_name if changed if avatarURL != "" || name != "" { @@ -234,8 +237,8 @@ func (s *Store) FindOrCreateSSOUser(provider, ssoID, email, name, avatarURL stri // Not found by SSO, try to find by email (within transaction) var existingUser User var existingEmailVal, existingSSOProvider, existingSSOID, existingAvatarURL, existingDisplayName sql.NullString - row = tx.QueryRow(s.rebind("SELECT id, username, created_at, COALESCE(timezone, 'UTC'), email, sso_provider, sso_id, avatar_url, display_name FROM users WHERE email = ?"), email) - err = row.Scan(&existingUser.ID, &existingUser.Username, &existingUser.CreatedAt, &existingUser.Timezone, &existingEmailVal, &existingSSOProvider, &existingSSOID, &existingAvatarURL, &existingDisplayName) + row = tx.QueryRow(s.rebind("SELECT id, username, created_at, COALESCE(timezone, 'UTC'), email, sso_provider, sso_id, avatar_url, display_name, COALESCE(role, 'admin') FROM users WHERE email = ?"), email) + err = row.Scan(&existingUser.ID, &existingUser.Username, &existingUser.CreatedAt, &existingUser.Timezone, &existingEmailVal, &existingSSOProvider, &existingSSOID, &existingAvatarURL, &existingDisplayName, &existingUser.Role) if err == nil { // Found user by email - check if they have a password // SECURITY: Do not automatically link SSO to existing accounts with passwords. @@ -308,13 +311,13 @@ func (s *Store) FindOrCreateSSOUser(provider, ssoID, email, name, avatarURL stri counter++ } - // Insert new user with empty password (SSO-only user) + // Insert new user with empty password (SSO-only user), default role = viewer var newID int64 if s.IsPostgres() { - err = tx.QueryRow("INSERT INTO users (username, password_hash, email, sso_provider, sso_id, avatar_url, display_name) VALUES ($1, '', $2, $3, $4, $5, $6) RETURNING id", + err = tx.QueryRow("INSERT INTO users (username, password_hash, email, sso_provider, sso_id, avatar_url, display_name, role) VALUES ($1, '', $2, $3, $4, $5, $6, 'viewer') RETURNING id", username, email, provider, ssoID, avatarURL, name).Scan(&newID) } else { - result, execErr := tx.Exec("INSERT INTO users (username, password_hash, email, sso_provider, sso_id, avatar_url, display_name) VALUES (?, '', ?, ?, ?, ?, ?)", + result, execErr := tx.Exec("INSERT INTO users (username, password_hash, email, sso_provider, sso_id, avatar_url, display_name, role) VALUES (?, '', ?, ?, ?, ?, ?, 'viewer')", username, email, provider, ssoID, avatarURL, name) if execErr != nil { return nil, execErr @@ -339,8 +342,101 @@ func (s *Store) FindOrCreateSSOUser(provider, ssoID, email, name, avatarURL stri AvatarURL: avatarURL, DisplayName: name, Timezone: "UTC", + Role: "viewer", }, nil } +// GetUserRole returns the role for a given user ID. +func (s *Store) GetUserRole(id int64) (string, error) { + var role string + err := s.db.QueryRow(s.rebind("SELECT COALESCE(role, 'admin') FROM users WHERE id = ?"), id).Scan(&role) + if err == sql.ErrNoRows { + return "", ErrUserNotFound + } + return role, err +} + +// ListUsers returns all users (passwords redacted). +func (s *Store) ListUsers() ([]User, error) { + rows, err := s.db.Query("SELECT id, username, created_at, COALESCE(timezone, 'UTC'), COALESCE(email, ''), COALESCE(sso_provider, ''), COALESCE(avatar_url, ''), COALESCE(display_name, ''), COALESCE(role, 'admin') FROM users ORDER BY id") + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var users []User + for rows.Next() { + var u User + if err := rows.Scan(&u.ID, &u.Username, &u.CreatedAt, &u.Timezone, &u.Email, &u.SSOProvider, &u.AvatarURL, &u.DisplayName, &u.Role); err != nil { + return nil, err + } + users = append(users, u) + } + return users, nil +} + +// UpdateUserRole changes a user's role. +func (s *Store) UpdateUserRole(id int64, role string) error { + _, err := s.db.Exec(s.rebind("UPDATE users SET role = ? WHERE id = ?"), role, id) + return err +} + +// DeleteUser removes a user by ID. +func (s *Store) DeleteUser(id int64) error { + // Also delete their sessions + _, _ = s.db.Exec(s.rebind("DELETE FROM sessions WHERE user_id = ?"), id) + _, err := s.db.Exec(s.rebind("DELETE FROM users WHERE id = ?"), id) + return err +} + +// CountAdmins returns the number of users with the admin role. +func (s *Store) CountAdmins() (int, error) { + var count int + err := s.db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count) + return count, err +} + +// GetUserStatusPages returns the status page IDs assigned to a user. +func (s *Store) GetUserStatusPages(userID int64) ([]int64, error) { + rows, err := s.db.Query(s.rebind("SELECT status_page_id FROM user_status_pages WHERE user_id = ?"), userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var ids []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, err + } + ids = append(ids, id) + } + return ids, nil +} + +// SetUserStatusPages replaces the status page assignments for a user. +func (s *Store) SetUserStatusPages(userID int64, pageIDs []int64) error { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + _, err = tx.Exec(s.rebind("DELETE FROM user_status_pages WHERE user_id = ?"), userID) + if err != nil { + return err + } + + for _, pid := range pageIDs { + _, err = tx.Exec(s.rebind("INSERT INTO user_status_pages (user_id, status_page_id) VALUES (?, ?)"), userID, pid) + if err != nil { + return err + } + } + + return tx.Commit() +} + // Just to avoid unused import error for context if not used var _ = context.Background diff --git a/internal/db/store_users_rbac_test.go b/internal/db/store_users_rbac_test.go new file mode 100644 index 0000000..28989a8 --- /dev/null +++ b/internal/db/store_users_rbac_test.go @@ -0,0 +1,592 @@ +package db + +import ( + "testing" + "time" +) + +func TestCreateUserWithRole(t *testing.T) { + s := newTestStore(t) + + tests := []struct { + username string + role string + }{ + {"admin_user", "admin"}, + {"editor_user", "editor"}, + {"viewer_user", "viewer"}, + {"sv_user", "status_viewer"}, + } + + for _, tc := range tests { + t.Run(tc.role, func(t *testing.T) { + err := s.CreateUser(tc.username, "password123", "UTC", tc.role) + if err != nil { + t.Fatalf("CreateUser(%s, %s) failed: %v", tc.username, tc.role, err) + } + + user, err := s.Authenticate(tc.username, "password123") + if err != nil { + t.Fatalf("Authenticate failed: %v", err) + } + if user.Role != tc.role { + t.Errorf("Expected role %s, got %s", tc.role, user.Role) + } + }) + } +} + +func TestCreateUserEmptyRoleDefaultsToAdmin(t *testing.T) { + s := newTestStore(t) + + err := s.CreateUser("default_role_user", "password123", "UTC", "") + if err != nil { + t.Fatalf("CreateUser failed: %v", err) + } + + user, err := s.Authenticate("default_role_user", "password123") + if err != nil { + t.Fatalf("Authenticate failed: %v", err) + } + if user.Role != "admin" { + t.Errorf("Expected default role 'admin', got %s", user.Role) + } +} + +func TestGetUserRole(t *testing.T) { + s := newTestStore(t) + + err := s.CreateUser("role_test", "password123", "UTC", "editor") + if err != nil { + t.Fatalf("CreateUser failed: %v", err) + } + + user, err := s.Authenticate("role_test", "password123") + if err != nil { + t.Fatalf("Authenticate failed: %v", err) + } + + role, err := s.GetUserRole(user.ID) + if err != nil { + t.Fatalf("GetUserRole failed: %v", err) + } + if role != "editor" { + t.Errorf("Expected role 'editor', got %s", role) + } +} + +func TestGetUserRole_NotFound(t *testing.T) { + s := newTestStore(t) + + _, err := s.GetUserRole(99999) + if err != ErrUserNotFound { + t.Errorf("Expected ErrUserNotFound, got %v", err) + } +} + +func TestListUsersWithRoles(t *testing.T) { + s := newTestStore(t) + + // Create users with different roles + err := s.CreateUser("admin1", "pass", "UTC", "admin") + if err != nil { + t.Fatalf("CreateUser admin1 failed: %v", err) + } + err = s.CreateUser("editor1", "pass", "UTC", "editor") + if err != nil { + t.Fatalf("CreateUser editor1 failed: %v", err) + } + err = s.CreateUser("viewer1", "pass", "UTC", "viewer") + if err != nil { + t.Fatalf("CreateUser viewer1 failed: %v", err) + } + err = s.CreateUser("sv1", "pass", "UTC", "status_viewer") + if err != nil { + t.Fatalf("CreateUser sv1 failed: %v", err) + } + + users, err := s.ListUsers() + if err != nil { + t.Fatalf("ListUsers failed: %v", err) + } + if len(users) != 4 { + t.Fatalf("Expected 4 users, got %d", len(users)) + } + + // Build a map for easy lookup + roleMap := make(map[string]string) + for _, u := range users { + roleMap[u.Username] = u.Role + } + + expected := map[string]string{ + "admin1": "admin", + "editor1": "editor", + "viewer1": "viewer", + "sv1": "status_viewer", + } + for username, expectedRole := range expected { + if got, ok := roleMap[username]; !ok { + t.Errorf("User %s not found in list", username) + } else if got != expectedRole { + t.Errorf("User %s: expected role %s, got %s", username, expectedRole, got) + } + } +} + +func TestUpdateUserRole(t *testing.T) { + s := newTestStore(t) + + err := s.CreateUser("role_change", "pass", "UTC", "viewer") + if err != nil { + t.Fatalf("CreateUser failed: %v", err) + } + + user, err := s.Authenticate("role_change", "pass") + if err != nil { + t.Fatalf("Authenticate failed: %v", err) + } + + // Verify initial role + role, err := s.GetUserRole(user.ID) + if err != nil { + t.Fatalf("GetUserRole failed: %v", err) + } + if role != "viewer" { + t.Errorf("Expected initial role 'viewer', got %s", role) + } + + // Update to editor + err = s.UpdateUserRole(user.ID, "editor") + if err != nil { + t.Fatalf("UpdateUserRole failed: %v", err) + } + + role, err = s.GetUserRole(user.ID) + if err != nil { + t.Fatalf("GetUserRole after update failed: %v", err) + } + if role != "editor" { + t.Errorf("Expected role 'editor' after update, got %s", role) + } + + // Update to admin + err = s.UpdateUserRole(user.ID, "admin") + if err != nil { + t.Fatalf("UpdateUserRole to admin failed: %v", err) + } + + role, err = s.GetUserRole(user.ID) + if err != nil { + t.Fatalf("GetUserRole after second update failed: %v", err) + } + if role != "admin" { + t.Errorf("Expected role 'admin' after update, got %s", role) + } + + // Also verify via GetUser + u, err := s.GetUser(user.ID) + if err != nil { + t.Fatalf("GetUser failed: %v", err) + } + if u.Role != "admin" { + t.Errorf("GetUser: expected role 'admin', got %s", u.Role) + } +} + +func TestDeleteUser(t *testing.T) { + s := newTestStore(t) + + err := s.CreateUser("to_delete", "pass", "UTC", "editor") + if err != nil { + t.Fatalf("CreateUser failed: %v", err) + } + + user, err := s.Authenticate("to_delete", "pass") + if err != nil { + t.Fatalf("Authenticate failed: %v", err) + } + + // Create a session for the user + token := "delete-test-token" + expires := time.Now().Add(1 * time.Hour) + if err := s.CreateSession(user.ID, token, expires); err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + + // Verify session exists + sess, err := s.GetSession(token) + if err != nil { + t.Fatalf("GetSession failed: %v", err) + } + if sess == nil { + t.Fatal("Expected session to exist before delete") + } + + // Delete the user + err = s.DeleteUser(user.ID) + if err != nil { + t.Fatalf("DeleteUser failed: %v", err) + } + + // Verify user is gone + _, err = s.GetUser(user.ID) + if err == nil { + t.Error("Expected error getting deleted user, got nil") + } + + // Verify sessions are cleaned up + sess, err = s.GetSession(token) + if err != nil { + t.Fatalf("GetSession after delete failed: %v", err) + } + if sess != nil { + t.Error("Expected session to be deleted with user") + } +} + +func TestCountAdmins(t *testing.T) { + s := newTestStore(t) + + // No users yet + count, err := s.CountAdmins() + if err != nil { + t.Fatalf("CountAdmins failed: %v", err) + } + if count != 0 { + t.Errorf("Expected 0 admins, got %d", count) + } + + // Create one admin + err = s.CreateUser("admin1", "pass", "UTC", "admin") + if err != nil { + t.Fatalf("CreateUser failed: %v", err) + } + count, err = s.CountAdmins() + if err != nil { + t.Fatalf("CountAdmins failed: %v", err) + } + if count != 1 { + t.Errorf("Expected 1 admin, got %d", count) + } + + // Create another admin + err = s.CreateUser("admin2", "pass", "UTC", "admin") + if err != nil { + t.Fatalf("CreateUser failed: %v", err) + } + count, err = s.CountAdmins() + if err != nil { + t.Fatalf("CountAdmins failed: %v", err) + } + if count != 2 { + t.Errorf("Expected 2 admins, got %d", count) + } + + // Create non-admin users - count should not change + err = s.CreateUser("editor1", "pass", "UTC", "editor") + if err != nil { + t.Fatalf("CreateUser editor failed: %v", err) + } + err = s.CreateUser("viewer1", "pass", "UTC", "viewer") + if err != nil { + t.Fatalf("CreateUser viewer failed: %v", err) + } + err = s.CreateUser("sv1", "pass", "UTC", "status_viewer") + if err != nil { + t.Fatalf("CreateUser status_viewer failed: %v", err) + } + + count, err = s.CountAdmins() + if err != nil { + t.Fatalf("CountAdmins failed: %v", err) + } + if count != 2 { + t.Errorf("Expected 2 admins (non-admin users should not be counted), got %d", count) + } +} + +func TestUserStatusPages(t *testing.T) { + s := newTestStore(t) + + // Create a user + err := s.CreateUser("sp_user", "pass", "UTC", "status_viewer") + if err != nil { + t.Fatalf("CreateUser failed: %v", err) + } + user, err := s.Authenticate("sp_user", "pass") + if err != nil { + t.Fatalf("Authenticate failed: %v", err) + } + + // Initially no status pages assigned + ids, err := s.GetUserStatusPages(user.ID) + if err != nil { + t.Fatalf("GetUserStatusPages failed: %v", err) + } + if len(ids) != 0 { + t.Errorf("Expected 0 status pages initially, got %d", len(ids)) + } + + // Create status pages for assignment + err = s.UpsertStatusPageFull(StatusPageInput{ + Slug: "test-page-1", + Title: "Test Page 1", + Enabled: true, + Theme: "system", + }) + if err != nil { + t.Fatalf("UpsertStatusPageFull 1 failed: %v", err) + } + err = s.UpsertStatusPageFull(StatusPageInput{ + Slug: "test-page-2", + Title: "Test Page 2", + Enabled: true, + Theme: "system", + }) + if err != nil { + t.Fatalf("UpsertStatusPageFull 2 failed: %v", err) + } + + sp1, err := s.GetStatusPageBySlug("test-page-1") + if err != nil { + t.Fatalf("GetStatusPageBySlug 1 failed: %v", err) + } + if sp1 == nil { + t.Fatal("Status page 1 not found") + return + } + + sp2, err := s.GetStatusPageBySlug("test-page-2") + if err != nil { + t.Fatalf("GetStatusPageBySlug 2 failed: %v", err) + } + if sp2 == nil { + t.Fatal("Status page 2 not found") + return + } + + sp1ID := sp1.ID + sp2ID := sp2.ID + + // Assign both status pages + err = s.SetUserStatusPages(user.ID, []int64{sp1ID, sp2ID}) + if err != nil { + t.Fatalf("SetUserStatusPages failed: %v", err) + } + + // Verify assignment + ids, err = s.GetUserStatusPages(user.ID) + if err != nil { + t.Fatalf("GetUserStatusPages failed: %v", err) + } + if len(ids) != 2 { + t.Fatalf("Expected 2 status pages, got %d", len(ids)) + } + + // Verify the IDs match (order may vary) + idSet := make(map[int64]bool) + for _, id := range ids { + idSet[id] = true + } + if !idSet[sp1ID] { + t.Errorf("Expected status page ID %d in assignments", sp1ID) + } + if !idSet[sp2ID] { + t.Errorf("Expected status page ID %d in assignments", sp2ID) + } + + // Replace with just one status page + err = s.SetUserStatusPages(user.ID, []int64{sp1ID}) + if err != nil { + t.Fatalf("SetUserStatusPages (replace) failed: %v", err) + } + + ids, err = s.GetUserStatusPages(user.ID) + if err != nil { + t.Fatalf("GetUserStatusPages after replace failed: %v", err) + } + if len(ids) != 1 { + t.Fatalf("Expected 1 status page after replace, got %d", len(ids)) + } + if ids[0] != sp1ID { + t.Errorf("Expected status page ID %d, got %d", sp1ID, ids[0]) + } + + // Clear all assignments + err = s.SetUserStatusPages(user.ID, []int64{}) + if err != nil { + t.Fatalf("SetUserStatusPages (clear) failed: %v", err) + } + + ids, err = s.GetUserStatusPages(user.ID) + if err != nil { + t.Fatalf("GetUserStatusPages after clear failed: %v", err) + } + if len(ids) != 0 { + t.Errorf("Expected 0 status pages after clear, got %d", len(ids)) + } +} + +func TestFindOrCreateSSOUser_DefaultViewerRole(t *testing.T) { + s := newTestStore(t) + + user, err := s.FindOrCreateSSOUser("google", "sso-id-123", "sso@example.com", "SSO User", "https://example.com/avatar.jpg", true) + if err != nil { + t.Fatalf("FindOrCreateSSOUser failed: %v", err) + } + + if user.Role != "viewer" { + t.Errorf("Expected new SSO user to have role 'viewer', got %s", user.Role) + } + + // Verify via GetUser + fetched, err := s.GetUser(user.ID) + if err != nil { + t.Fatalf("GetUser failed: %v", err) + } + if fetched.Role != "viewer" { + t.Errorf("GetUser: expected role 'viewer', got %s", fetched.Role) + } + + // Verify via GetUserRole + role, err := s.GetUserRole(user.ID) + if err != nil { + t.Fatalf("GetUserRole failed: %v", err) + } + if role != "viewer" { + t.Errorf("GetUserRole: expected 'viewer', got %s", role) + } +} + +func TestFindOrCreateSSOUser_ExistingUserKeepsRole(t *testing.T) { + s := newTestStore(t) + + // Create SSO user first time (gets viewer role) + user, err := s.FindOrCreateSSOUser("google", "sso-id-456", "keep@example.com", "Keeper", "", true) + if err != nil { + t.Fatalf("FindOrCreateSSOUser (first) failed: %v", err) + } + if user.Role != "viewer" { + t.Fatalf("Expected initial role 'viewer', got %s", user.Role) + } + + // Promote to admin + err = s.UpdateUserRole(user.ID, "admin") + if err != nil { + t.Fatalf("UpdateUserRole failed: %v", err) + } + + // Login again via SSO - should keep admin role + user2, err := s.FindOrCreateSSOUser("google", "sso-id-456", "keep@example.com", "Keeper", "", true) + if err != nil { + t.Fatalf("FindOrCreateSSOUser (second) failed: %v", err) + } + if user2.Role != "admin" { + t.Errorf("Expected existing SSO user to keep role 'admin', got %s", user2.Role) + } +} + +func TestAPIKeyWithRole(t *testing.T) { + s := newTestStore(t) + + // Create API key with editor role + key, err := s.CreateAPIKey("Editor Key", "editor") + if err != nil { + t.Fatalf("CreateAPIKey (editor) failed: %v", err) + } + + valid, role, err := s.ValidateAPIKey(key) + if err != nil { + t.Fatalf("ValidateAPIKey failed: %v", err) + } + if !valid { + t.Fatal("Expected key to be valid") + } + if role != "editor" { + t.Errorf("Expected role 'editor', got %s", role) + } + + // Create API key with viewer role + key2, err := s.CreateAPIKey("Viewer Key", "viewer") + if err != nil { + t.Fatalf("CreateAPIKey (viewer) failed: %v", err) + } + + valid, role, err = s.ValidateAPIKey(key2) + if err != nil { + t.Fatalf("ValidateAPIKey (viewer) failed: %v", err) + } + if !valid { + t.Fatal("Expected key to be valid") + } + if role != "viewer" { + t.Errorf("Expected role 'viewer', got %s", role) + } + + // Create API key with admin role + key3, err := s.CreateAPIKey("Admin Key", "admin") + if err != nil { + t.Fatalf("CreateAPIKey (admin) failed: %v", err) + } + + valid, role, err = s.ValidateAPIKey(key3) + if err != nil { + t.Fatalf("ValidateAPIKey (admin) failed: %v", err) + } + if !valid { + t.Fatal("Expected key to be valid") + } + if role != "admin" { + t.Errorf("Expected role 'admin', got %s", role) + } +} + +func TestAPIKeyDefaultRole(t *testing.T) { + s := newTestStore(t) + + // Create API key with empty role - should default to editor + key, err := s.CreateAPIKey("Default Role Key", "") + if err != nil { + t.Fatalf("CreateAPIKey (empty) failed: %v", err) + } + + valid, role, err := s.ValidateAPIKey(key) + if err != nil { + t.Fatalf("ValidateAPIKey failed: %v", err) + } + if !valid { + t.Fatal("Expected key to be valid") + } + if role != "editor" { + t.Errorf("Expected default role 'editor', got %s", role) + } +} + +func TestAPIKeyInvalidKey(t *testing.T) { + s := newTestStore(t) + + // Short key + valid, role, err := s.ValidateAPIKey("short") + if err != nil { + t.Fatalf("ValidateAPIKey (short) failed: %v", err) + } + if valid { + t.Error("Expected short key to be invalid") + } + if role != "" { + t.Errorf("Expected empty role for invalid key, got %s", role) + } + + // Non-existent key with proper prefix length + valid, role, err = s.ValidateAPIKey("sk_live_0000deadbeef0000") + if err != nil { + t.Fatalf("ValidateAPIKey (nonexistent) failed: %v", err) + } + if valid { + t.Error("Expected nonexistent key to be invalid") + } + if role != "" { + t.Errorf("Expected empty role for invalid key, got %s", role) + } +} diff --git a/internal/db/store_users_test.go b/internal/db/store_users_test.go index 0c24222..a1e6ee1 100644 --- a/internal/db/store_users_test.go +++ b/internal/db/store_users_test.go @@ -9,7 +9,7 @@ func TestUserLifecycle(t *testing.T) { s := newTestStore(t) // 1. Create User - err := s.CreateUser("admin", "secret123", "UTC") + err := s.CreateUser("admin", "secret123", "UTC", "admin") if err != nil { t.Fatalf("CreateUser failed: %v", err) } @@ -65,7 +65,7 @@ func TestUserLifecycle(t *testing.T) { func TestSessions(t *testing.T) { s := newTestStore(t) - _ = s.CreateUser("user1", "pass", "UTC") + _ = s.CreateUser("user1", "pass", "UTC", "admin") u, _ := s.Authenticate("user1", "pass") // Create Session @@ -107,7 +107,7 @@ func TestHasUsers(t *testing.T) { t.Error("Expected no users initially") } - _ = s.CreateUser("u", "p", "UTC") + _ = s.CreateUser("u", "p", "UTC", "admin") has, _ = s.HasUsers() if !has { diff --git a/internal/uptime/manager_test.go b/internal/uptime/manager_test.go index e63f70b..914d69c 100644 --- a/internal/uptime/manager_test.go +++ b/internal/uptime/manager_test.go @@ -418,7 +418,7 @@ func TestManager_UserTimezoneLoaded(t *testing.T) { m := NewManager(store) // Create a user with a specific timezone - if err := store.CreateUser("admin", "password123", "America/New_York"); err != nil { + if err := store.CreateUser("admin", "password123", "America/New_York", "admin"); err != nil { t.Fatalf("CreateUser failed: %v", err) } @@ -453,7 +453,7 @@ func TestManager_InvalidTimezoneHandling(t *testing.T) { // Create user with invalid timezone (edge case - shouldn't normally happen) // The CreateUser doesn't validate timezone, so we test the fallback - if err := store.CreateUser("admin", "password123", "Invalid/Timezone"); err != nil { + if err := store.CreateUser("admin", "password123", "Invalid/Timezone", "admin"); err != nil { t.Fatalf("CreateUser failed: %v", err) } diff --git a/web/package-lock.json b/web/package-lock.json index 00d3af3..37546ab 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -5159,9 +5159,9 @@ } }, "node_modules/flatted": { - "version": "3.4.1", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.1.tgz", - "integrity": "sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==", + "version": "3.4.2", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", + "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", "dev": true, "license": "ISC" }, diff --git a/web/playwright.config.ts b/web/playwright.config.ts index 1d70c24..6d884ef 100644 --- a/web/playwright.config.ts +++ b/web/playwright.config.ts @@ -39,6 +39,27 @@ export default defineConfig({ use: { ...devices['Desktop Chrome'] }, dependencies: ['status-pages'], }, + // RBAC tests run isolated (they reset the database) + { + name: 'rbac', + testMatch: /rbac\.spec\.ts/, + use: { ...devices['Desktop Chrome'] }, + dependencies: ['status-pages-full'], + }, + // Status viewer tests run isolated (they reset the database) + { + name: 'status-viewer', + testMatch: /status_viewer\.spec\.ts/, + use: { ...devices['Desktop Chrome'] }, + dependencies: ['rbac'], + }, + // Comprehensive RBAC role tests (reset DB, test all roles) + { + name: 'rbac-roles', + testMatch: /rbac_roles\.spec\.ts/, + use: { ...devices['Desktop Chrome'] }, + dependencies: ['status-viewer'], + }, // All other tests can run in parallel after auth tests complete { name: 'chromium', @@ -47,6 +68,9 @@ export default defineConfig({ /custom_setup\.spec\.ts/, /status_pages\.spec\.ts/, /status_pages_full\.spec\.ts/, + /rbac\.spec\.ts/, + /status_viewer\.spec\.ts/, + /rbac_roles\.spec\.ts/, ], use: { ...devices['Desktop Chrome'] }, dependencies: ['custom-setup'], diff --git a/web/src/App.tsx b/web/src/App.tsx index a0acc26..57a7bb9 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -46,10 +46,12 @@ import { } from "@/components/ui/alert-dialog" import { useDeleteGroupMutation } from "@/hooks/useMonitors"; +import { useRole } from "@/hooks/useRole"; function MonitorGroup({ group }: { group: Group }) { const mutation = useDeleteGroupMutation(); const navigate = useNavigate(); + const { canEdit } = useRole(); const handleDelete = async () => { await mutation.mutateAsync(group.id); @@ -63,7 +65,7 @@ function MonitorGroup({ group }: { group: Group }) {