diff --git a/api/cmd/main.go b/api/cmd/main.go index 84c4cf1e..9ecc7662 100644 --- a/api/cmd/main.go +++ b/api/cmd/main.go @@ -173,17 +173,7 @@ func main() { // SECURITY: Add request size limits to prevent large payload attacks // Maximum 10MB for general requests - router.Use(middleware.RequestSizeLimit(10 * 1024 * 1024)) - - // SECURITY: Add rate limiting to prevent DoS attacks - // Layer 1: IP-based rate limiting (100 req/sec per IP with burst of 200) - rateLimiter := middleware.NewRateLimiter(100, 200) - router.Use(rateLimiter.Middleware()) - - // Layer 2: Per-user rate limiting (1000 req/hour per authenticated user) - // Prevents abuse from compromised tokens - userRateLimiter := middleware.NewUserRateLimiter(1000, 50) - router.Use(userRateLimiter.Middleware()) + router.Use(middleware.RequestSizeLimiter(10 * 1024 * 1024)) // SECURITY: Add audit logging for all requests auditLogger := middleware.NewAuditLogger(database, false) // Don't log request bodies by default @@ -240,14 +230,13 @@ func main() { // Initialize API handlers apiHandler := api.NewHandler(database, k8sClient, connTracker, syncService, wsManager, quotaEnforcer) - userHandler := handlers.NewUserHandler(userDB) + userHandler := handlers.NewUserHandler(userDB, groupDB) groupHandler := handlers.NewGroupHandler(groupDB, userDB) authHandler := auth.NewAuthHandler(userDB, jwtManager, samlAuth) activityHandler := handlers.NewActivityHandler(k8sClient, activityTracker) catalogHandler := handlers.NewCatalogHandler(database) sharingHandler := handlers.NewSharingHandler(database) pluginHandler := handlers.NewPluginHandler(database) - auditLogHandler := handlers.NewAuditLogHandler(database) dashboardHandler := handlers.NewDashboardHandler(database, k8sClient) sessionActivityHandler := handlers.NewSessionActivityHandler(database) apiKeyHandler := handlers.NewAPIKeyHandler(database) @@ -262,6 +251,13 @@ func main() { monitoringHandler := handlers.NewMonitoringHandler(database) quotasHandler := handlers.NewQuotasHandler(database) websocketHandler := handlers.NewWebSocketHandler(database) + consoleHandler := handlers.NewConsoleHandler(database) + collaborationHandler := handlers.NewCollaborationHandler(database) + integrationsHandler := handlers.NewIntegrationsHandler(database) + loadBalancingHandler := handlers.NewLoadBalancingHandler(database) + schedulingHandler := handlers.NewSchedulingHandler(database) + securityHandler := handlers.NewSecurityHandler(database) + templateVersioningHandler := handlers.NewTemplateVersioningHandler(database) // NOTE: Billing is now handled by the streamspace-billing plugin // SECURITY: Initialize webhook authentication @@ -271,14 +267,8 @@ func main() { log.Println(" Generate a secret with: openssl rand -hex 32") } - // SECURITY: Initialize CSRF protection - csrfProtection := middleware.NewCSRFProtection(24 * time.Hour) - - // SECURITY: Create stricter rate limiter for auth endpoints - authRateLimiter := middleware.NewRateLimiter(5, 10) // 5 req/sec with burst of 10 - // Setup routes - setupRoutes(router, apiHandler, userHandler, groupHandler, authHandler, activityHandler, catalogHandler, sharingHandler, pluginHandler, auditLogHandler, dashboardHandler, sessionActivityHandler, apiKeyHandler, teamHandler, preferencesHandler, notificationsHandler, searchHandler, sessionTemplatesHandler, batchHandler, monitoringHandler, quotasHandler, websocketHandler, jwtManager, userDB, redisCache, webhookSecret, csrfProtection, authRateLimiter) + setupRoutes(router, apiHandler, userHandler, groupHandler, authHandler, activityHandler, catalogHandler, sharingHandler, pluginHandler, dashboardHandler, sessionActivityHandler, apiKeyHandler, teamHandler, preferencesHandler, notificationsHandler, searchHandler, sessionTemplatesHandler, batchHandler, monitoringHandler, quotasHandler, websocketHandler, consoleHandler, collaborationHandler, integrationsHandler, loadBalancingHandler, schedulingHandler, securityHandler, templateVersioningHandler, jwtManager, userDB, redisCache, webhookSecret) // Create HTTP server with security timeouts srv := &http.Server{ @@ -359,7 +349,7 @@ func main() { log.Println("Graceful shutdown completed") } -func setupRoutes(router *gin.Engine, h *api.Handler, userHandler *handlers.UserHandler, groupHandler *handlers.GroupHandler, authHandler *auth.AuthHandler, activityHandler *handlers.ActivityHandler, catalogHandler *handlers.CatalogHandler, sharingHandler *handlers.SharingHandler, pluginHandler *handlers.PluginHandler, auditLogHandler *handlers.AuditLogHandler, dashboardHandler *handlers.DashboardHandler, sessionActivityHandler *handlers.SessionActivityHandler, apiKeyHandler *handlers.APIKeyHandler, teamHandler *handlers.TeamHandler, preferencesHandler *handlers.PreferencesHandler, notificationsHandler *handlers.NotificationsHandler, searchHandler *handlers.SearchHandler, sessionTemplatesHandler *handlers.SessionTemplatesHandler, batchHandler *handlers.BatchHandler, monitoringHandler *handlers.MonitoringHandler, quotasHandler *handlers.QuotasHandler, websocketHandler *handlers.WebSocketHandler, jwtManager *auth.JWTManager, userDB *db.UserDB, redisCache *cache.Cache, webhookSecret string, csrfProtection *middleware.CSRFProtection, authRateLimiter *middleware.RateLimiter) { +func setupRoutes(router *gin.Engine, h *api.Handler, userHandler *handlers.UserHandler, groupHandler *handlers.GroupHandler, authHandler *auth.AuthHandler, activityHandler *handlers.ActivityHandler, catalogHandler *handlers.CatalogHandler, sharingHandler *handlers.SharingHandler, pluginHandler *handlers.PluginHandler, dashboardHandler *handlers.DashboardHandler, sessionActivityHandler *handlers.SessionActivityHandler, apiKeyHandler *handlers.APIKeyHandler, teamHandler *handlers.TeamHandler, preferencesHandler *handlers.PreferencesHandler, notificationsHandler *handlers.NotificationsHandler, searchHandler *handlers.SearchHandler, sessionTemplatesHandler *handlers.SessionTemplatesHandler, batchHandler *handlers.BatchHandler, monitoringHandler *handlers.MonitoringHandler, quotasHandler *handlers.QuotasHandler, websocketHandler *handlers.WebSocketHandler, consoleHandler *handlers.ConsoleHandler, collaborationHandler *handlers.CollaborationHandler, integrationsHandler *handlers.IntegrationsHandler, loadBalancingHandler *handlers.LoadBalancingHandler, schedulingHandler *handlers.SchedulingHandler, securityHandler *handlers.SecurityHandler, templateVersioningHandler *handlers.TemplateVersioningHandler, jwtManager *auth.JWTManager, userDB *db.UserDB, redisCache *cache.Cache, webhookSecret string) { // SECURITY: Create authentication middleware authMiddleware := auth.Middleware(jwtManager, userDB) adminMiddleware := auth.RequireRole("admin") @@ -375,15 +365,11 @@ func setupRoutes(router *gin.Engine, h *api.Handler, userHandler *handlers.UserH router.GET("/health", h.Health) router.GET("/version", h.Version) - // SECURITY: CSRF token endpoint (public - issues CSRF tokens) - router.GET("/api/v1/csrf-token", csrfProtection.IssueTokenHandler()) - // API v1 v1 := router.Group("/api/v1") { // Authentication routes (public - no auth required, but rate limited) authGroup := v1.Group("/auth") - authGroup.Use(authRateLimiter.Middleware()) // SECURITY: Brute force protection { authHandler.RegisterRoutes(authGroup) } @@ -391,7 +377,7 @@ func setupRoutes(router *gin.Engine, h *api.Handler, userHandler *handlers.UserH // PROTECTED ROUTES - Require authentication protected := v1.Group("") protected.Use(authMiddleware) - protected.Use(csrfProtection.Middleware()) // SECURITY: CSRF protection for all state-changing operations + protected.Use(middleware.CSRFProtection()) // SECURITY: CSRF protection for all state-changing operations { // Sessions (authenticated users only) sessions := protected.Group("/sessions") @@ -411,6 +397,7 @@ func setupRoutes(router *gin.Engine, h *api.Handler, userHandler *handlers.UserH // NOTE: Session recording is now handled by the streamspace-recording plugin // Install it via: Admin → Plugins → streamspace-recording + } // NOTE: Data Loss Prevention (DLP) is now handled by the streamspace-dlp plugin // Install it via: Admin → Plugins → streamspace-dlp @@ -421,62 +408,61 @@ func setupRoutes(router *gin.Engine, h *api.Handler, userHandler *handlers.UserH console := protected.Group("/console") { // Console sessions (terminal and file manager) - console.POST("/sessions/:sessionId", h.CreateConsoleSession) - console.GET("/sessions/:sessionId", h.ListConsoleSessions) - console.POST("/:consoleId/disconnect", h.DisconnectConsoleSession) + console.POST("/sessions/:sessionId", consoleHandler.CreateConsoleSession) + console.GET("/sessions/:sessionId", consoleHandler.ListConsoleSessions) + console.POST("/:consoleId/disconnect", consoleHandler.DisconnectConsoleSession) // File Manager operations - console.GET("/files/:sessionId", h.ListFiles) - console.GET("/files/:sessionId/content", h.GetFileContent) - console.POST("/files/:sessionId/upload", h.UploadFile) - console.GET("/files/:sessionId/download", h.DownloadFile) - console.POST("/files/:sessionId/directory", h.CreateDirectory) - console.DELETE("/files/:sessionId", h.DeleteFile) - console.PATCH("/files/:sessionId/rename", h.RenameFile) + console.GET("/files/:sessionId", consoleHandler.ListFiles) + console.GET("/files/:sessionId/content", consoleHandler.GetFileContent) + console.POST("/files/:sessionId/upload", consoleHandler.UploadFile) + console.GET("/files/:sessionId/download", consoleHandler.DownloadFile) + console.POST("/files/:sessionId/directory", consoleHandler.CreateDirectory) + console.DELETE("/files/:sessionId", consoleHandler.DeleteFile) + console.PATCH("/files/:sessionId/rename", consoleHandler.RenameFile) // File operation history - console.GET("/files/:sessionId/history", h.GetFileOperationHistory) + console.GET("/files/:sessionId/history", consoleHandler.GetFileOperationHistory) } - // Multi-Monitor Support - monitors := protected.Group("/monitors") - { - monitors.GET("/sessions/:sessionId", h.GetMonitorConfiguration) - monitors.POST("/sessions/:sessionId", h.CreateMonitorConfiguration) - monitors.GET("/sessions/:sessionId/list", h.ListMonitorConfigurations) - monitors.PATCH("/configurations/:configId", h.UpdateMonitorConfiguration) - monitors.POST("/configurations/:configId/activate", h.ActivateMonitorConfiguration) - monitors.DELETE("/configurations/:configId", h.DeleteMonitorConfiguration) - monitors.GET("/sessions/:sessionId/streams", h.GetMonitorStreams) - - // Preset configurations - monitors.POST("/sessions/:sessionId/presets/:preset", h.CreatePresetConfiguration) - } + // NOTE: Multi-Monitor Support is not yet implemented + // Will be added in a future release or via plugin + // monitors := protected.Group("/monitors") + // { + // monitors.GET("/sessions/:sessionId", h.GetMonitorConfiguration) + // monitors.POST("/sessions/:sessionId", h.CreateMonitorConfiguration) + // monitors.GET("/sessions/:sessionId/list", h.ListMonitorConfigurations) + // monitors.PATCH("/configurations/:configId", h.UpdateMonitorConfiguration) + // monitors.POST("/configurations/:configId/activate", h.ActivateMonitorConfiguration) + // monitors.DELETE("/configurations/:configId", h.DeleteMonitorConfiguration) + // monitors.GET("/sessions/:sessionId/streams", h.GetMonitorStreams) + // monitors.POST("/sessions/:sessionId/presets/:preset", h.CreatePresetConfiguration) + // } // Real-time Collaboration collaboration := protected.Group("/collaboration") { // Collaboration session management - collaboration.POST("/sessions/:sessionId", h.CreateCollaborationSession) - collaboration.POST("/:collabId/join", h.JoinCollaborationSession) - collaboration.POST("/:collabId/leave", h.LeaveCollaborationSession) + collaboration.POST("/sessions/:sessionId", collaborationHandler.CreateCollaborationSession) + collaboration.POST("/:collabId/join", collaborationHandler.JoinCollaborationSession) + collaboration.POST("/:collabId/leave", collaborationHandler.LeaveCollaborationSession) // Participant management - collaboration.GET("/:collabId/participants", h.GetCollaborationParticipants) - collaboration.PATCH("/:collabId/participants/:userId", h.UpdateParticipantRole) + collaboration.GET("/:collabId/participants", collaborationHandler.GetCollaborationParticipants) + collaboration.PATCH("/:collabId/participants/:userId", collaborationHandler.UpdateParticipantRole) // Chat operations - collaboration.POST("/:collabId/chat", h.SendChatMessage) - collaboration.GET("/:collabId/chat", h.GetChatHistory) + collaboration.POST("/:collabId/chat", collaborationHandler.SendChatMessage) + collaboration.GET("/:collabId/chat", collaborationHandler.GetChatHistory) // Annotation operations - collaboration.POST("/:collabId/annotations", h.CreateAnnotation) - collaboration.GET("/:collabId/annotations", h.GetAnnotations) - collaboration.DELETE("/:collabId/annotations/:annotationId", h.DeleteAnnotation) - collaboration.DELETE("/:collabId/annotations", h.ClearAllAnnotations) + collaboration.POST("/:collabId/annotations", collaborationHandler.CreateAnnotation) + collaboration.GET("/:collabId/annotations", collaborationHandler.GetAnnotations) + collaboration.DELETE("/:collabId/annotations/:annotationId", collaborationHandler.DeleteAnnotation) + collaboration.DELETE("/:collabId/annotations", collaborationHandler.ClearAllAnnotations) // Statistics - collaboration.GET("/:collabId/stats", h.GetCollaborationStats) + collaboration.GET("/:collabId/stats", collaborationHandler.GetCollaborationStats) } // Integration Hub & Webhooks - Operator/Admin only @@ -484,67 +470,69 @@ func setupRoutes(router *gin.Engine, h *api.Handler, userHandler *handlers.UserH integrations.Use(operatorMiddleware) { // Webhooks - integrations.GET("/webhooks", h.ListWebhooks) - integrations.POST("/webhooks", h.CreateWebhook) - integrations.PATCH("/webhooks/:webhookId", h.UpdateWebhook) - integrations.DELETE("/webhooks/:webhookId", h.DeleteWebhook) - integrations.POST("/webhooks/:webhookId/test", h.TestWebhook) - integrations.GET("/webhooks/:webhookId/deliveries", h.GetWebhookDeliveries) - integrations.POST("/webhooks/:webhookId/retry/:deliveryId", h.RetryWebhookDelivery) + integrations.GET("/webhooks", integrationsHandler.ListWebhooks) + integrations.POST("/webhooks", integrationsHandler.CreateWebhook) + integrations.PATCH("/webhooks/:webhookId", integrationsHandler.UpdateWebhook) + integrations.DELETE("/webhooks/:webhookId", integrationsHandler.DeleteWebhook) + integrations.POST("/webhooks/:webhookId/test", integrationsHandler.TestWebhook) + integrations.GET("/webhooks/:webhookId/deliveries", integrationsHandler.GetWebhookDeliveries) + // NOTE: Webhook retry not yet implemented + // integrations.POST("/webhooks/:webhookId/retry/:deliveryId", h.RetryWebhookDelivery) // External Integrations - integrations.GET("/external", h.ListIntegrations) - integrations.POST("/external", h.CreateIntegration) - integrations.PATCH("/external/:integrationId", h.UpdateIntegration) - integrations.DELETE("/external/:integrationId", h.DeleteIntegration) - integrations.POST("/external/:integrationId/test", h.TestIntegration) + integrations.GET("/external", integrationsHandler.ListIntegrations) + integrations.POST("/external", integrationsHandler.CreateIntegration) + // NOTE: Update and delete integrations not yet implemented + // integrations.PATCH("/external/:integrationId", h.UpdateIntegration) + // integrations.DELETE("/external/:integrationId", h.DeleteIntegration) + integrations.POST("/external/:integrationId/test", integrationsHandler.TestIntegration) // Available events - integrations.GET("/events", h.GetAvailableEvents) + integrations.GET("/events", integrationsHandler.GetAvailableEvents) } // Security - MFA, IP Whitelisting, Zero Trust security := protected.Group("/security") { // Multi-Factor Authentication (all users) - security.POST("/mfa/setup", h.SetupMFA) - security.POST("/mfa/:mfaId/verify-setup", h.VerifyMFASetup) - security.POST("/mfa/verify", h.VerifyMFA) - security.GET("/mfa/methods", h.ListMFAMethods) - security.DELETE("/mfa/:mfaId", h.DisableMFA) - security.POST("/mfa/backup-codes", h.GenerateBackupCodes) + security.POST("/mfa/setup", securityHandler.SetupMFA) + security.POST("/mfa/:mfaId/verify-setup", securityHandler.VerifyMFASetup) + security.POST("/mfa/verify", securityHandler.VerifyMFA) + security.GET("/mfa/methods", securityHandler.ListMFAMethods) + security.DELETE("/mfa/:mfaId", securityHandler.DisableMFA) + security.POST("/mfa/backup-codes", securityHandler.GenerateBackupCodes) // IP Whitelisting (users can manage their own, admins can manage all) - security.POST("/ip-whitelist", h.CreateIPWhitelist) - security.GET("/ip-whitelist", h.ListIPWhitelist) - security.DELETE("/ip-whitelist/:entryId", h.DeleteIPWhitelist) - security.GET("/ip-whitelist/check", h.CheckIPAccess) + security.POST("/ip-whitelist", securityHandler.CreateIPWhitelist) + security.GET("/ip-whitelist", securityHandler.ListIPWhitelist) + security.DELETE("/ip-whitelist/:entryId", securityHandler.DeleteIPWhitelist) + security.GET("/ip-whitelist/check", securityHandler.CheckIPAccess) // Zero Trust / Session Verification - security.POST("/sessions/:sessionId/verify", h.VerifySession) - security.POST("/device-posture", h.CheckDevicePosture) - security.GET("/alerts", h.GetSecurityAlerts) + security.POST("/sessions/:sessionId/verify", securityHandler.VerifySession) + security.POST("/device-posture", securityHandler.CheckDevicePosture) + security.GET("/alerts", securityHandler.GetSecurityAlerts) } // Session Scheduling & Calendar Integration scheduling := protected.Group("/scheduling") { // Scheduled sessions - scheduling.GET("/sessions", h.ListScheduledSessions) - scheduling.POST("/sessions", h.CreateScheduledSession) - scheduling.GET("/sessions/:scheduleId", h.GetScheduledSession) - scheduling.PATCH("/sessions/:scheduleId", h.UpdateScheduledSession) - scheduling.DELETE("/sessions/:scheduleId", h.DeleteScheduledSession) - scheduling.POST("/sessions/:scheduleId/enable", h.EnableScheduledSession) - scheduling.POST("/sessions/:scheduleId/disable", h.DisableScheduledSession) + scheduling.GET("/sessions", schedulingHandler.ListScheduledSessions) + scheduling.POST("/sessions", schedulingHandler.CreateScheduledSession) + scheduling.GET("/sessions/:scheduleId", schedulingHandler.GetScheduledSession) + scheduling.PATCH("/sessions/:scheduleId", schedulingHandler.UpdateScheduledSession) + scheduling.DELETE("/sessions/:scheduleId", schedulingHandler.DeleteScheduledSession) + scheduling.POST("/sessions/:scheduleId/enable", schedulingHandler.EnableScheduledSession) + scheduling.POST("/sessions/:scheduleId/disable", schedulingHandler.DisableScheduledSession) // Calendar integrations - scheduling.POST("/calendar/connect", h.ConnectCalendar) - scheduling.GET("/calendar/oauth/callback", h.CalendarOAuthCallback) - scheduling.GET("/calendar/integrations", h.ListCalendarIntegrations) - scheduling.DELETE("/calendar/integrations/:integrationId", h.DisconnectCalendar) - scheduling.POST("/calendar/integrations/:integrationId/sync", h.SyncCalendar) - scheduling.GET("/calendar/export.ics", h.ExportICalendar) + scheduling.POST("/calendar/connect", schedulingHandler.ConnectCalendar) + scheduling.GET("/calendar/oauth/callback", schedulingHandler.CalendarOAuthCallback) + scheduling.GET("/calendar/integrations", schedulingHandler.ListCalendarIntegrations) + scheduling.DELETE("/calendar/integrations/:integrationId", schedulingHandler.DisconnectCalendar) + scheduling.POST("/calendar/integrations/:integrationId/sync", schedulingHandler.SyncCalendar) + scheduling.GET("/calendar/export.ics", schedulingHandler.ExportICalendar) } // Load Balancing & Auto-scaling - Admin/Operator only @@ -552,37 +540,48 @@ func setupRoutes(router *gin.Engine, h *api.Handler, userHandler *handlers.UserH scaling.Use(operatorMiddleware) { // Load balancing policies - scaling.GET("/load-balancing/policies", h.ListLoadBalancingPolicies) - scaling.POST("/load-balancing/policies", h.CreateLoadBalancingPolicy) - scaling.GET("/load-balancing/nodes", h.GetNodeStatus) - scaling.POST("/load-balancing/select-node", h.SelectNode) + scaling.GET("/load-balancing/policies", loadBalancingHandler.ListLoadBalancingPolicies) + scaling.POST("/load-balancing/policies", loadBalancingHandler.CreateLoadBalancingPolicy) + scaling.GET("/load-balancing/nodes", loadBalancingHandler.GetNodeStatus) + scaling.POST("/load-balancing/select-node", loadBalancingHandler.SelectNode) // Auto-scaling policies - scaling.GET("/autoscaling/policies", h.ListAutoScalingPolicies) - scaling.POST("/autoscaling/policies", h.CreateAutoScalingPolicy) - scaling.POST("/autoscaling/policies/:policyId/trigger", h.TriggerScaling) - scaling.GET("/autoscaling/history", h.GetScalingHistory) + scaling.GET("/autoscaling/policies", loadBalancingHandler.ListAutoScalingPolicies) + scaling.POST("/autoscaling/policies", loadBalancingHandler.CreateAutoScalingPolicy) + scaling.POST("/autoscaling/policies/:policyId/trigger", loadBalancingHandler.TriggerScaling) + scaling.GET("/autoscaling/history", loadBalancingHandler.GetScalingHistory) } - // Compliance & Governance - Admin only - compliance := protected.Group("/compliance") - compliance.Use(adminMiddleware) +// // // Compliance & Governance - Admin only +// compliance := protected.Group("/compliance") +// compliance.Use(adminMiddleware) +// { +// // Frameworks +// compliance.GET("/frameworks", h.ListComplianceFrameworks) +// compliance.POST("/frameworks", h.CreateComplianceFramework) +// +// // Policies +// compliance.GET("/policies", h.ListCompliancePolicies) +// compliance.POST("/policies", h.CreateCompliancePolicy) +// +// // Violations +// compliance.GET("/violations", h.ListViolations) +// compliance.POST("/violations", h.RecordViolation) +// compliance.POST("/violations/:violationId/resolve", h.ResolveViolation) + +// } + +// +// NOTE: Compliance & Governance is now handled by the streamspace-compliance plugin +// Install it via: Admin → Plugins → streamspace-compliance + // Templates (read: all users, write: operators/admins) + templates := protected.Group("/templates") { - // Frameworks - compliance.GET("/frameworks", h.ListComplianceFrameworks) - compliance.POST("/frameworks", h.CreateComplianceFramework) - - // Policies - compliance.GET("/policies", h.ListCompliancePolicies) - compliance.POST("/policies", h.CreateCompliancePolicy) - - // Violations - compliance.GET("/violations", h.ListViolations) - compliance.POST("/violations", h.RecordViolation) - compliance.POST("/violations/:violationId/resolve", h.ResolveViolation) + // Read-only template endpoints (all authenticated users) + templates.GET("", cache.CacheMiddleware(redisCache, 5*time.Minute), h.ListTemplates) + templates.GET("/:id", cache.CacheMiddleware(redisCache, 5*time.Minute), h.GetTemplate) - // NOTE: Compliance & Governance is now handled by the streamspace-compliance plugin - // Install it via: Admin → Plugins → streamspace-compliance + // Write operations require operator or admin role templatesWrite := templates.Group("") templatesWrite.Use(operatorMiddleware) { @@ -591,21 +590,21 @@ func setupRoutes(router *gin.Engine, h *api.Handler, userHandler *handlers.UserH templatesWrite.DELETE("/:id", cache.InvalidateCacheMiddleware(redisCache, cache.TemplatePattern()), h.DeleteTemplate) // Template Versioning (operator only) - templatesWrite.POST("/:templateId/versions", h.CreateTemplateVersion) - templatesWrite.GET("/:templateId/versions", h.ListTemplateVersions) - templatesWrite.GET("/versions/:versionId", h.GetTemplateVersion) - templatesWrite.POST("/versions/:versionId/publish", h.PublishTemplateVersion) - templatesWrite.POST("/versions/:versionId/deprecate", h.DeprecateTemplateVersion) - templatesWrite.POST("/versions/:versionId/set-default", h.SetDefaultTemplateVersion) - templatesWrite.POST("/versions/:versionId/clone", h.CloneTemplateVersion) + templatesWrite.POST("/:templateId/versions", templateVersioningHandler.CreateTemplateVersion) + templatesWrite.GET("/:templateId/versions", templateVersioningHandler.ListTemplateVersions) + templatesWrite.GET("/versions/:versionId", templateVersioningHandler.GetTemplateVersion) + templatesWrite.POST("/versions/:versionId/publish", templateVersioningHandler.PublishTemplateVersion) + templatesWrite.POST("/versions/:versionId/deprecate", templateVersioningHandler.DeprecateTemplateVersion) + templatesWrite.POST("/versions/:versionId/set-default", templateVersioningHandler.SetDefaultTemplateVersion) + templatesWrite.POST("/versions/:versionId/clone", templateVersioningHandler.CloneTemplateVersion) // Template Testing (operator only) - templatesWrite.POST("/versions/:versionId/tests", h.CreateTemplateTest) - templatesWrite.GET("/versions/:versionId/tests", h.ListTemplateTests) - templatesWrite.PATCH("/tests/:testId", h.UpdateTemplateTestStatus) + templatesWrite.POST("/versions/:versionId/tests", templateVersioningHandler.CreateTemplateTest) + templatesWrite.GET("/versions/:versionId/tests", templateVersioningHandler.ListTemplateTests) + templatesWrite.PATCH("/tests/:testId", templateVersioningHandler.UpdateTemplateTestStatus) // Template Inheritance - templatesWrite.GET("/:templateId/inheritance", h.GetTemplateInheritance) + templatesWrite.GET("/:templateId/inheritance", templateVersioningHandler.GetTemplateInheritance) } } @@ -676,16 +675,17 @@ func setupRoutes(router *gin.Engine, h *api.Handler, userHandler *handlers.UserH // NOTE: Analytics & Reporting is now handled by the streamspace-analytics-advanced plugin // Install it via: Admin → Plugins → streamspace-analytics-advanced - // Audit logs (admins only for viewing, operators can view their own) - audit := protected.Group("/audit") - { - // Admin can view all audit logs with advanced filtering - audit.GET("/logs", adminMiddleware, cache.CacheMiddleware(redisCache, 30*time.Second), auditLogHandler.ListAuditLogs) - audit.GET("/stats", adminMiddleware, cache.CacheMiddleware(redisCache, 1*time.Minute), auditLogHandler.GetAuditLogStats) - - // Users can view their own audit logs - audit.GET("/users/:userId/logs", auditLogHandler.GetUserAuditLogs) - } + // NOTE: Audit logs are now handled by the streamspace-audit plugin + // Install it via: Admin → Plugins → streamspace-audit + // audit := protected.Group("/audit") + // { + // // Admin can view all audit logs with advanced filtering + // audit.GET("/logs", adminMiddleware, cache.CacheMiddleware(redisCache, 30*time.Second), auditLogHandler.ListAuditLogs) + // audit.GET("/stats", adminMiddleware, cache.CacheMiddleware(redisCache, 1*time.Minute), auditLogHandler.GetAuditLogStats) + // + // // Users can view their own audit logs + // audit.GET("/users/:userId/logs", auditLogHandler.GetUserAuditLogs) + // } // Dashboard and resource usage (operators and admins can view platform stats) dashboard := protected.Group("/dashboard") diff --git a/api/internal/api/handlers.go b/api/internal/api/handlers.go index 1107326f..d0c8e922 100644 --- a/api/internal/api/handlers.go +++ b/api/internal/api/handlers.go @@ -19,6 +19,7 @@ import ( "github.com/streamspace/streamspace/api/internal/tracker" "github.com/streamspace/streamspace/api/internal/websocket" "gopkg.in/yaml.v3" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" ) @@ -158,31 +159,39 @@ func (h *Handler) CreateSession(c *gin.Context) { } // Check user quota - quotaReq := "a.SessionRequest{ - UserID: req.User, - Memory: memory, - CPU: cpu, - Storage: "50Gi", // Default storage quota check - } - - quotaResult, err := h.quotaEnforcer.CheckSessionQuota(ctx, quotaReq) + // Parse CPU and memory to int64 + requestedCPU, requestedMemory, err := h.quotaEnforcer.ValidateResourceRequest(cpu, memory) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to check quota", + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Invalid resource request", "message": err.Error(), }) return } - if !quotaResult.Allowed { + // Get current usage by listing user's pods + podList, err := h.k8sClient.GetPods(ctx, h.namespace) + if err != nil { + log.Printf("Failed to get pods for quota check: %v", err) + // Continue with empty usage if we can't get pods + podList = &corev1.PodList{} + } + + // Filter pods for this user + userPods := make([]corev1.Pod, 0) + for _, pod := range podList.Items { + if user, ok := pod.Labels["user"]; ok && user == req.User { + userPods = append(userPods, pod) + } + } + + currentUsage := h.quotaEnforcer.CalculateUsage(userPods) + + // Check if user can create session + if err := h.quotaEnforcer.CheckSessionCreation(ctx, req.User, requestedCPU, requestedMemory, 0, currentUsage); err != nil { c.JSON(http.StatusForbidden, gin.H{ "error": "Quota exceeded", - "message": quotaResult.Reason, - "quota": gin.H{ - "current": quotaResult.CurrentUsage, - "requested": quotaResult.RequestedUsage, - "available": quotaResult.AvailableQuota, - }, + "message": err.Error(), }) return } @@ -226,12 +235,6 @@ func (h *Handler) CreateSession(c *gin.Context) { return } - // Update quota usage - if err := h.quotaEnforcer.UpdateSessionQuota(ctx, req.User, memory, cpu, "50Gi", true); err != nil { - log.Printf("Failed to update quota usage: %v", err) - // Don't fail the request, but log the error - } - // Cache in database if err := h.cacheSessionInDB(ctx, created); err != nil { log.Printf("Failed to cache session in database: %v", err) @@ -280,8 +283,8 @@ func (h *Handler) DeleteSession(c *gin.Context) { ctx := context.Background() sessionID := c.Param("id") - // Get session info before deletion (for quota tracking) - session, err := h.k8sClient.GetSession(ctx, h.namespace, sessionID) + // Verify session exists before deletion + _, err := h.k8sClient.GetSession(ctx, h.namespace, sessionID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"}) return @@ -293,15 +296,6 @@ func (h *Handler) DeleteSession(c *gin.Context) { return } - // Update quota usage (decrement) - if session.Resources.Memory != "" && session.Resources.CPU != "" { - if err := h.quotaEnforcer.UpdateSessionQuota(ctx, session.User, - session.Resources.Memory, session.Resources.CPU, "50Gi", false); err != nil { - log.Printf("Failed to update quota usage on session deletion: %v", err) - // Don't fail the request, quota will be cleaned up later - } - } - // Delete from database cache if err := h.deleteSessionFromDB(ctx, sessionID); err != nil { log.Printf("Failed to delete session from database: %v", err) diff --git a/api/internal/api/stubs.go b/api/internal/api/stubs.go index 8fa730a2..bef10c8c 100644 --- a/api/internal/api/stubs.go +++ b/api/internal/api/stubs.go @@ -13,6 +13,7 @@ import ( "github.com/gorilla/websocket" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime/schema" ) @@ -298,7 +299,7 @@ func (h *Handler) CreateResource(c *gin.Context) { // UpdateResource updates a K8s resource func (h *Handler) UpdateResource(c *gin.Context) { - resourceType := c.Param("type") // e.g., "deployment", "service" + _ = c.Param("type") // Resource type not used; Kind from request body resourceName := c.Param("name") namespace := c.Query("namespace") if namespace == "" { diff --git a/api/internal/handlers/collaboration.go b/api/internal/handlers/collaboration.go index 55d17beb..b78f0ed8 100644 --- a/api/internal/handlers/collaboration.go +++ b/api/internal/handlers/collaboration.go @@ -254,7 +254,6 @@ import ( "fmt" "net/http" "strconv" - "strings" "time" "github.com/gin-gonic/gin" @@ -262,18 +261,18 @@ import ( ) // Handler handles collaboration-related HTTP requests. -type Handler struct { +type CollaborationHandler struct { // DB is the database connection for collaboration queries and updates. DB *db.Database } // NewCollaborationHandler creates a new collaboration handler. -func NewCollaborationHandler(database *db.Database) *Handler { - return &Handler{DB: database} +func NewCollaborationHandler(database *db.Database) *CollaborationHandler { + return &CollaborationHandler{DB: database} } // canAccessSession checks if a user has access to a session. -func (h *Handler) canAccessSession(userID, sessionID string) bool { +func (h *CollaborationHandler) canAccessSession(userID, sessionID string) bool { // Check if user owns the session var owner string err := h.DB.DB().QueryRow("SELECT user_id FROM sessions WHERE id = $1", sessionID).Scan(&owner) @@ -417,7 +416,7 @@ type Point struct { } // CreateCollaborationSession creates a new collaboration session -func (h *Handler) CreateCollaborationSession(c *gin.Context) { +func (h *CollaborationHandler) CreateCollaborationSession(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") @@ -496,7 +495,7 @@ func (h *Handler) CreateCollaborationSession(c *gin.Context) { } // JoinCollaborationSession allows a user to join a collaboration -func (h *Handler) JoinCollaborationSession(c *gin.Context) { +func (h *CollaborationHandler) JoinCollaborationSession(c *gin.Context) { collabID := c.Param("collabId") userID := c.GetString("user_id") @@ -616,7 +615,7 @@ func (h *Handler) JoinCollaborationSession(c *gin.Context) { } // LeaveCollaborationSession removes a user from collaboration -func (h *Handler) LeaveCollaborationSession(c *gin.Context) { +func (h *CollaborationHandler) LeaveCollaborationSession(c *gin.Context) { collabID := c.Param("collabId") userID := c.GetString("user_id") @@ -650,7 +649,7 @@ func (h *Handler) LeaveCollaborationSession(c *gin.Context) { } // GetCollaborationParticipants lists all participants -func (h *Handler) GetCollaborationParticipants(c *gin.Context) { +func (h *CollaborationHandler) GetCollaborationParticipants(c *gin.Context) { collabID := c.Param("collabId") userID := c.GetString("user_id") @@ -660,7 +659,7 @@ func (h *Handler) GetCollaborationParticipants(c *gin.Context) { return } - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT cp.user_id, u.username, cp.role, cp.permissions, cp.cursor_position, cp.color, cp.is_active, cp.joined_at, cp.last_seen_at FROM collaboration_participants cp @@ -702,7 +701,7 @@ func (h *Handler) GetCollaborationParticipants(c *gin.Context) { } // UpdateParticipantRole updates a participant's role and permissions -func (h *Handler) UpdateParticipantRole(c *gin.Context) { +func (h *CollaborationHandler) UpdateParticipantRole(c *gin.Context) { collabID := c.Param("collabId") targetUserID := c.Param("userId") userID := c.GetString("user_id") @@ -741,7 +740,7 @@ func (h *Handler) UpdateParticipantRole(c *gin.Context) { // Chat Operations // SendChatMessage sends a message to the collaboration chat -func (h *Handler) SendChatMessage(c *gin.Context) { +func (h *CollaborationHandler) SendChatMessage(c *gin.Context) { collabID := c.Param("collabId") userID := c.GetString("user_id") @@ -787,7 +786,7 @@ func (h *Handler) SendChatMessage(c *gin.Context) { } // GetChatHistory retrieves chat history -func (h *Handler) GetChatHistory(c *gin.Context) { +func (h *CollaborationHandler) GetChatHistory(c *gin.Context) { collabID := c.Param("collabId") userID := c.GetString("user_id") limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100")) @@ -819,7 +818,7 @@ func (h *Handler) GetChatHistory(c *gin.Context) { query += fmt.Sprintf(" ORDER BY cc.created_at DESC LIMIT $%d", argCount) args = append(args, limit) - rows, err := h.DB.Query(query, args...) + rows, err := h.DB.DB().Query(query, args...) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to retrieve chat"}) return @@ -857,7 +856,7 @@ func (h *Handler) GetChatHistory(c *gin.Context) { // Annotation Operations // CreateAnnotation creates a new annotation -func (h *Handler) CreateAnnotation(c *gin.Context) { +func (h *CollaborationHandler) CreateAnnotation(c *gin.Context) { collabID := c.Param("collabId") userID := c.GetString("user_id") @@ -906,7 +905,7 @@ func (h *Handler) CreateAnnotation(c *gin.Context) { } // GetAnnotations retrieves active annotations -func (h *Handler) GetAnnotations(c *gin.Context) { +func (h *CollaborationHandler) GetAnnotations(c *gin.Context) { collabID := c.Param("collabId") userID := c.GetString("user_id") @@ -915,7 +914,7 @@ func (h *Handler) GetAnnotations(c *gin.Context) { return } - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT id, session_id, user_id, type, color, thickness, points, text, is_persistent, created_at, expires_at FROM collaboration_annotations @@ -949,7 +948,7 @@ func (h *Handler) GetAnnotations(c *gin.Context) { } // DeleteAnnotation removes an annotation -func (h *Handler) DeleteAnnotation(c *gin.Context) { +func (h *CollaborationHandler) DeleteAnnotation(c *gin.Context) { collabID := c.Param("collabId") annotationID := c.Param("annotationId") userID := c.GetString("user_id") @@ -973,7 +972,7 @@ func (h *Handler) DeleteAnnotation(c *gin.Context) { } // ClearAllAnnotations removes all annotations -func (h *Handler) ClearAllAnnotations(c *gin.Context) { +func (h *CollaborationHandler) ClearAllAnnotations(c *gin.Context) { collabID := c.Param("collabId") userID := c.GetString("user_id") @@ -994,7 +993,7 @@ func (h *Handler) ClearAllAnnotations(c *gin.Context) { // Helper functions -func (h *Handler) isCollaborationParticipant(collabID, userID string) bool { +func (h *CollaborationHandler) isCollaborationParticipant(collabID, userID string) bool { var exists bool h.DB.DB().QueryRow(` SELECT EXISTS(SELECT 1 FROM collaboration_participants @@ -1003,7 +1002,7 @@ func (h *Handler) isCollaborationParticipant(collabID, userID string) bool { return exists } -func (h *Handler) canManageCollaboration(collabID, userID string) bool { +func (h *CollaborationHandler) canManageCollaboration(collabID, userID string) bool { var permissions sql.NullString h.DB.DB().QueryRow(` SELECT permissions FROM collaboration_participants @@ -1019,7 +1018,7 @@ func (h *Handler) canManageCollaboration(collabID, userID string) bool { return perms.CanManage } -func (h *Handler) hasCollaborationPermission(collabID, userID, permission string) bool { +func (h *CollaborationHandler) hasCollaborationPermission(collabID, userID, permission string) bool { var permissions sql.NullString h.DB.DB().QueryRow(` SELECT permissions FROM collaboration_participants @@ -1048,7 +1047,7 @@ func (h *Handler) hasCollaborationPermission(collabID, userID, permission string } // GetCollaborationStats returns collaboration statistics -func (h *Handler) GetCollaborationStats(c *gin.Context) { +func (h *CollaborationHandler) GetCollaborationStats(c *gin.Context) { collabID := c.Param("collabId") userID := c.GetString("user_id") diff --git a/api/internal/handlers/console.go b/api/internal/handlers/console.go index 14b7e7c2..6a1bc416 100644 --- a/api/internal/handlers/console.go +++ b/api/internal/handlers/console.go @@ -79,8 +79,19 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/streamspace/streamspace/api/internal/db" ) +// Handler is the console handler with database access. +type ConsoleHandler struct { + DB *db.Database +} + +// NewConsoleHandler creates a new console handler. +func NewConsoleHandler(database *db.Database) *ConsoleHandler { + return &ConsoleHandler{DB: database} +} + // ConsoleSession represents an active console session type ConsoleSession struct { ID string `json:"id"` @@ -124,7 +135,7 @@ type FileOperation struct { } // CreateConsoleSession creates a new console session for a workspace session -func (h *Handler) CreateConsoleSession(c *gin.Context) { +func (h *ConsoleHandler) CreateConsoleSession(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") @@ -142,7 +153,7 @@ func (h *Handler) CreateConsoleSession(c *gin.Context) { // Verify user has access to this session var sessionOwner string - err := h.DB.QueryRow("SELECT user_id FROM sessions WHERE id = $1", sessionID).Scan(&sessionOwner) + err := h.DB.DB().QueryRow("SELECT user_id FROM sessions WHERE id = $1", sessionID).Scan(&sessionOwner) if err == sql.ErrNoRows { c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) return @@ -150,7 +161,7 @@ func (h *Handler) CreateConsoleSession(c *gin.Context) { if sessionOwner != userID { // Check if user has shared access var hasAccess bool - h.DB.QueryRow(` + h.DB.DB().QueryRow(` SELECT EXISTS( SELECT 1 FROM session_shares WHERE session_id = $1 AND shared_with_user_id = $2 @@ -177,7 +188,7 @@ func (h *Handler) CreateConsoleSession(c *gin.Context) { consoleID := fmt.Sprintf("console-%s-%d", sessionID, time.Now().Unix()) // Create console session - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` INSERT INTO console_sessions ( id, session_id, user_id, type, status, current_path, shell_type, columns, rows @@ -205,7 +216,7 @@ func (h *Handler) CreateConsoleSession(c *gin.Context) { } // ListConsoleSessions lists all console sessions for a workspace session -func (h *Handler) ListConsoleSessions(c *gin.Context) { +func (h *ConsoleHandler) ListConsoleSessions(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") @@ -215,7 +226,7 @@ func (h *Handler) ListConsoleSessions(c *gin.Context) { return } - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT id, session_id, user_id, type, status, current_path, shell_type, columns, rows, metadata, connected_at, last_activity_at, disconnected_at FROM console_sessions @@ -249,13 +260,13 @@ func (h *Handler) ListConsoleSessions(c *gin.Context) { } // DisconnectConsoleSession disconnects an active console session -func (h *Handler) DisconnectConsoleSession(c *gin.Context) { +func (h *ConsoleHandler) DisconnectConsoleSession(c *gin.Context) { consoleID := c.Param("consoleId") userID := c.GetString("user_id") // Verify ownership var owner string - err := h.DB.QueryRow("SELECT user_id FROM console_sessions WHERE id = $1", consoleID).Scan(&owner) + err := h.DB.DB().QueryRow("SELECT user_id FROM console_sessions WHERE id = $1", consoleID).Scan(&owner) if err == sql.ErrNoRows { c.JSON(http.StatusNotFound, gin.H{"error": "console session not found"}) return @@ -267,7 +278,7 @@ func (h *Handler) DisconnectConsoleSession(c *gin.Context) { // Update status now := time.Now() - _, err = h.DB.Exec(` + _, err = h.DB.DB().Exec(` UPDATE console_sessions SET status = 'disconnected', disconnected_at = $1 WHERE id = $2 @@ -284,7 +295,7 @@ func (h *Handler) DisconnectConsoleSession(c *gin.Context) { // File Manager Operations // ListFiles lists files in a directory -func (h *Handler) ListFiles(c *gin.Context) { +func (h *ConsoleHandler) ListFiles(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") path := c.DefaultQuery("path", "/config") @@ -324,7 +335,7 @@ func (h *Handler) ListFiles(c *gin.Context) { Name: entry.Name(), Path: filepath.Join(path, entry.Name()), Size: info.Size(), - IsDirectory: entry.IsDirectory(), + IsDirectory: entry.IsDir(), Permissions: info.Mode().String(), ModifiedAt: info.ModTime(), } @@ -340,7 +351,7 @@ func (h *Handler) ListFiles(c *gin.Context) { } // GetFileContent retrieves the content of a file -func (h *Handler) GetFileContent(c *gin.Context) { +func (h *ConsoleHandler) GetFileContent(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") path := c.Query("path") @@ -392,7 +403,7 @@ func (h *Handler) GetFileContent(c *gin.Context) { } // UploadFile uploads a file to the session -func (h *Handler) UploadFile(c *gin.Context) { +func (h *ConsoleHandler) UploadFile(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") targetPath := c.PostForm("path") @@ -448,7 +459,7 @@ func (h *Handler) UploadFile(c *gin.Context) { } // DownloadFile downloads a file from the session -func (h *Handler) DownloadFile(c *gin.Context) { +func (h *ConsoleHandler) DownloadFile(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") path := c.Query("path") @@ -494,7 +505,7 @@ func (h *Handler) DownloadFile(c *gin.Context) { } // CreateDirectory creates a new directory -func (h *Handler) CreateDirectory(c *gin.Context) { +func (h *ConsoleHandler) CreateDirectory(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") @@ -540,7 +551,7 @@ func (h *Handler) CreateDirectory(c *gin.Context) { } // DeleteFile deletes a file or directory -func (h *Handler) DeleteFile(c *gin.Context) { +func (h *ConsoleHandler) DeleteFile(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") @@ -595,7 +606,7 @@ func (h *Handler) DeleteFile(c *gin.Context) { } // RenameFile renames a file or directory -func (h *Handler) RenameFile(c *gin.Context) { +func (h *ConsoleHandler) RenameFile(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") @@ -647,14 +658,33 @@ func (h *Handler) RenameFile(c *gin.Context) { // Helper functions -func (h *Handler) getSessionBasePath(sessionID string) string { +func (h *ConsoleHandler) canAccessSession(userID, sessionID string) bool { + // Check if user owns the session + var owner string + err := h.DB.DB().QueryRow("SELECT user_id FROM sessions WHERE id = $1", sessionID).Scan(&owner) + if err == nil && owner == userID { + return true + } + + // Check shared access + var hasAccess bool + err = h.DB.DB().QueryRow(` + SELECT EXISTS( + SELECT 1 FROM session_shares + WHERE session_id = $1 AND shared_with_user_id = $2 + ) + `, sessionID, userID).Scan(&hasAccess) + return err == nil && hasAccess +} + +func (h *ConsoleHandler) getSessionBasePath(sessionID string) string { // In production, this would return the actual path to the session's persistent volume // For now, return a placeholder return fmt.Sprintf("/var/streamspace/sessions/%s", sessionID) } -func (h *Handler) logFileOperation(sessionID, userID, operation, sourcePath, targetPath string, bytesProcessed int64) { - h.DB.Exec(` +func (h *ConsoleHandler) logFileOperation(sessionID, userID, operation, sourcePath, targetPath string, bytesProcessed int64) { + h.DB.DB().Exec(` INSERT INTO console_file_operations ( session_id, user_id, operation, source_path, target_path, bytes_processed ) VALUES ($1, $2, $3, $4, $5, $6) @@ -662,7 +692,7 @@ func (h *Handler) logFileOperation(sessionID, userID, operation, sourcePath, tar } // GetFileOperationHistory retrieves file operation history -func (h *Handler) GetFileOperationHistory(c *gin.Context) { +func (h *ConsoleHandler) GetFileOperationHistory(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) @@ -676,12 +706,12 @@ func (h *Handler) GetFileOperationHistory(c *gin.Context) { // Count total var total int - h.DB.QueryRow(` + h.DB.DB().QueryRow(` SELECT COUNT(*) FROM console_file_operations WHERE session_id = $1 `, sessionID).Scan(&total) // Get operations - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT id, operation, source_path, target_path, bytes_processed, created_at FROM console_file_operations WHERE session_id = $1 diff --git a/api/internal/handlers/integrations.go b/api/internal/handlers/integrations.go index ef045855..9293e02c 100644 --- a/api/internal/handlers/integrations.go +++ b/api/internal/handlers/integrations.go @@ -43,25 +43,32 @@ import ( "bytes" "crypto/hmac" "crypto/sha256" - "crypto/tls" "database/sql" "encoding/hex" "encoding/json" "fmt" "io" - "log" "net" "net/http" - "net/smtp" "net/url" - "os" "strconv" "strings" "time" "github.com/gin-gonic/gin" + "github.com/streamspace/streamspace/api/internal/db" ) +// IntegrationsHandler handles webhook and external integration requests. +type IntegrationsHandler struct { + DB *db.Database +} + +// NewIntegrationsHandler creates a new integrations handler. +func NewIntegrationsHandler(database *db.Database) *IntegrationsHandler { + return &IntegrationsHandler{DB: database} +} + // ============================================================================ // INPUT VALIDATION // ============================================================================ @@ -264,7 +271,7 @@ var AvailableEvents = []string{ } // CreateWebhook creates a new webhook -func (h *Handler) CreateWebhook(c *gin.Context) { +func (h *IntegrationsHandler) CreateWebhook(c *gin.Context) { var webhook Webhook if err := c.ShouldBindJSON(&webhook); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -306,7 +313,7 @@ func (h *Handler) CreateWebhook(c *gin.Context) { webhook.Secret = h.generateWebhookSecret() } - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` INSERT INTO webhooks ( name, description, url, secret, events, headers, enabled, retry_policy, filters, metadata, created_by @@ -330,7 +337,7 @@ func (h *Handler) CreateWebhook(c *gin.Context) { } // ListWebhooks lists all webhooks -func (h *Handler) ListWebhooks(c *gin.Context) { +func (h *IntegrationsHandler) ListWebhooks(c *gin.Context) { enabled := c.Query("enabled") query := ` @@ -349,7 +356,7 @@ func (h *Handler) ListWebhooks(c *gin.Context) { query += " ORDER BY created_at DESC" - rows, err := h.DB.Query(query, args...) + rows, err := h.DB.DB().Query(query, args...) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to retrieve webhooks"}) return @@ -409,7 +416,7 @@ func (h *Handler) ListWebhooks(c *gin.Context) { } // UpdateWebhook updates an existing webhook -func (h *Handler) UpdateWebhook(c *gin.Context) { +func (h *IntegrationsHandler) UpdateWebhook(c *gin.Context) { webhookID, err := strconv.ParseInt(c.Param("webhookId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid webhook ID"}) @@ -448,7 +455,7 @@ func (h *Handler) UpdateWebhook(c *gin.Context) { var result sql.Result if role == "admin" { // Admins can update any webhook - result, err = h.DB.Exec(` + result, err = h.DB.DB().Exec(` UPDATE webhooks SET name = $1, description = $2, url = $3, events = $4, headers = $5, enabled = $6, retry_policy = $7, filters = $8, metadata = $9, @@ -459,7 +466,7 @@ func (h *Handler) UpdateWebhook(c *gin.Context) { toJSONB(webhook.Filters), toJSONB(webhook.Metadata), time.Now(), webhookID) } else { // Non-admins can only update their own webhooks - result, err = h.DB.Exec(` + result, err = h.DB.DB().Exec(` UPDATE webhooks SET name = $1, description = $2, url = $3, events = $4, headers = $5, enabled = $6, retry_policy = $7, filters = $8, metadata = $9, @@ -487,7 +494,7 @@ func (h *Handler) UpdateWebhook(c *gin.Context) { } // DeleteWebhook deletes a webhook -func (h *Handler) DeleteWebhook(c *gin.Context) { +func (h *IntegrationsHandler) DeleteWebhook(c *gin.Context) { webhookID, err := strconv.ParseInt(c.Param("webhookId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid webhook ID"}) @@ -502,10 +509,10 @@ func (h *Handler) DeleteWebhook(c *gin.Context) { var result sql.Result if role == "admin" { // Admins can delete any webhook - result, err = h.DB.Exec("DELETE FROM webhooks WHERE id = $1", webhookID) + result, err = h.DB.DB().Exec("DELETE FROM webhooks WHERE id = $1", webhookID) } else { // Non-admins can only delete their own webhooks - result, err = h.DB.Exec("DELETE FROM webhooks WHERE id = $1 AND created_by = $2", webhookID, userID) + result, err = h.DB.DB().Exec("DELETE FROM webhooks WHERE id = $1 AND created_by = $2", webhookID, userID) } if err != nil { @@ -525,7 +532,7 @@ func (h *Handler) DeleteWebhook(c *gin.Context) { } // TestWebhook sends a test event to a webhook -func (h *Handler) TestWebhook(c *gin.Context) { +func (h *IntegrationsHandler) TestWebhook(c *gin.Context) { webhookID, err := strconv.ParseInt(c.Param("webhookId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid webhook ID"}) @@ -542,14 +549,14 @@ func (h *Handler) TestWebhook(c *gin.Context) { if role == "admin" { // Admins can test any webhook - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` SELECT id, name, url, secret, events, headers, enabled, retry_policy FROM webhooks WHERE id = $1 `, webhookID).Scan(&webhook.ID, &webhook.Name, &webhook.URL, &webhook.Secret, &events, &headers, &webhook.Enabled, &retryPolicy) } else { // Non-admins can only test their own webhooks - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` SELECT id, name, url, secret, events, headers, enabled, retry_policy FROM webhooks WHERE id = $1 AND created_by = $2 `, webhookID, userID).Scan(&webhook.ID, &webhook.Name, &webhook.URL, &webhook.Secret, @@ -606,7 +613,7 @@ func (h *Handler) TestWebhook(c *gin.Context) { } // GetWebhookDeliveries retrieves delivery history -func (h *Handler) GetWebhookDeliveries(c *gin.Context) { +func (h *IntegrationsHandler) GetWebhookDeliveries(c *gin.Context) { webhookID, err := strconv.ParseInt(c.Param("webhookId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid webhook ID"}) @@ -618,9 +625,9 @@ func (h *Handler) GetWebhookDeliveries(c *gin.Context) { // Count total var total int - h.DB.QueryRow("SELECT COUNT(*) FROM webhook_deliveries WHERE webhook_id = $1", webhookID).Scan(&total) + h.DB.DB().QueryRow("SELECT COUNT(*) FROM webhook_deliveries WHERE webhook_id = $1", webhookID).Scan(&total) - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT id, webhook_id, event, payload, status, status_code, response_body, error_message, attempts, next_retry_at, delivered_at, created_at FROM webhook_deliveries @@ -664,7 +671,7 @@ func (h *Handler) GetWebhookDeliveries(c *gin.Context) { // Integrations // CreateIntegration creates a new integration -func (h *Handler) CreateIntegration(c *gin.Context) { +func (h *IntegrationsHandler) CreateIntegration(c *gin.Context) { var integration Integration if err := c.ShouldBindJSON(&integration); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -683,7 +690,7 @@ func (h *Handler) CreateIntegration(c *gin.Context) { return } - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` INSERT INTO integrations ( type, name, description, config, enabled, events, test_mode, created_by ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) @@ -701,7 +708,7 @@ func (h *Handler) CreateIntegration(c *gin.Context) { } // ListIntegrations lists all integrations -func (h *Handler) ListIntegrations(c *gin.Context) { +func (h *IntegrationsHandler) ListIntegrations(c *gin.Context) { integrationType := c.Query("type") enabled := c.Query("enabled") @@ -727,7 +734,7 @@ func (h *Handler) ListIntegrations(c *gin.Context) { query += " ORDER BY created_at DESC" - rows, err := h.DB.Query(query, args...) + rows, err := h.DB.DB().Query(query, args...) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to retrieve integrations"}) return @@ -758,7 +765,7 @@ func (h *Handler) ListIntegrations(c *gin.Context) { } // TestIntegration tests an integration -func (h *Handler) TestIntegration(c *gin.Context) { +func (h *IntegrationsHandler) TestIntegration(c *gin.Context) { integrationID, err := strconv.ParseInt(c.Param("integrationId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid integration ID"}) @@ -775,14 +782,14 @@ func (h *Handler) TestIntegration(c *gin.Context) { if role == "admin" { // Admins can test any integration - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` SELECT id, type, name, config, enabled, events FROM integrations WHERE id = $1 `, integrationID).Scan(&integration.ID, &integration.Type, &integration.Name, &config, &integration.Enabled, &events) } else { // Non-admins can only test their own integrations - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` SELECT id, type, name, config, enabled, events FROM integrations WHERE id = $1 AND created_by = $2 `, integrationID, userID).Scan(&integration.ID, &integration.Type, &integration.Name, @@ -806,10 +813,10 @@ func (h *Handler) TestIntegration(c *gin.Context) { success, message := h.testIntegration(integration) // Update last test time - h.DB.Exec("UPDATE integrations SET last_test_at = $1 WHERE id = $2", time.Now(), integrationID) + h.DB.DB().Exec("UPDATE integrations SET last_test_at = $1 WHERE id = $2", time.Now(), integrationID) if success { - h.DB.Exec("UPDATE integrations SET last_success_at = $1 WHERE id = $2", time.Now(), integrationID) + h.DB.DB().Exec("UPDATE integrations SET last_success_at = $1 WHERE id = $2", time.Now(), integrationID) c.JSON(http.StatusOK, gin.H{"success": true, "message": message}) } else { c.JSON(http.StatusBadRequest, gin.H{"success": false, "error": message}) @@ -819,7 +826,7 @@ func (h *Handler) TestIntegration(c *gin.Context) { // Helper functions // validateWebhookURL validates webhook URL to prevent SSRF attacks -func (h *Handler) validateWebhookURL(urlStr string) error { +func (h *IntegrationsHandler) validateWebhookURL(urlStr string) error { parsed, err := url.Parse(urlStr) if err != nil { return fmt.Errorf("invalid URL format: %w", err) @@ -878,12 +885,12 @@ func (h *Handler) validateWebhookURL(urlStr string) error { return nil } -func (h *Handler) generateWebhookSecret() string { +func (h *IntegrationsHandler) generateWebhookSecret() string { // Generate a random 32-byte secret return fmt.Sprintf("whsec_%d", time.Now().UnixNano()) } -func (h *Handler) deliverWebhook(webhook Webhook, event WebhookEvent) (bool, int, string, error) { +func (h *IntegrationsHandler) deliverWebhook(webhook Webhook, event WebhookEvent) (bool, int, string, error) { // Prepare payload payload, _ := json.Marshal(event) @@ -931,13 +938,13 @@ func (h *Handler) deliverWebhook(webhook Webhook, event WebhookEvent) (bool, int return success, resp.StatusCode, string(responseBody), nil } -func (h *Handler) calculateHMAC(payload []byte, secret string) string { +func (h *IntegrationsHandler) calculateHMAC(payload []byte, secret string) string { mac := hmac.New(sha256.New, []byte(secret)) mac.Write(payload) return hex.EncodeToString(mac.Sum(nil)) } -func (h *Handler) testIntegration(integration Integration) (bool, string) { +func (h *IntegrationsHandler) testIntegration(integration Integration) (bool, string) { // NOTE: Slack, Teams, Discord, PagerDuty, and Email integrations are now handled by plugins. // Users should install the respective plugins from the plugin marketplace instead. // @@ -957,6 +964,6 @@ func (h *Handler) testIntegration(integration Integration) (bool, string) { } // GetAvailableEvents returns list of available webhook events -func (h *Handler) GetAvailableEvents(c *gin.Context) { +func (h *IntegrationsHandler) GetAvailableEvents(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"events": AvailableEvents}) } diff --git a/api/internal/handlers/loadbalancing.go b/api/internal/handlers/loadbalancing.go index b6c11ff5..024ed4a1 100644 --- a/api/internal/handlers/loadbalancing.go +++ b/api/internal/handlers/loadbalancing.go @@ -66,12 +66,14 @@ import ( "context" "database/sql" "fmt" + "log" "net/http" "os" "path/filepath" "time" "github.com/gin-gonic/gin" + "github.com/streamspace/streamspace/api/internal/db" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" @@ -81,6 +83,16 @@ import ( metricsclientset "k8s.io/metrics/pkg/client/clientset/versioned" ) +// LoadBalancingHandler handles load balancing and node distribution requests. +type LoadBalancingHandler struct { + DB *db.Database +} + +// NewLoadBalancingHandler creates a new load balancing handler. +func NewLoadBalancingHandler(database *db.Database) *LoadBalancingHandler { + return &LoadBalancingHandler{DB: database} +} + // ============================================================================ // LOAD BALANCING // ============================================================================ @@ -144,7 +156,7 @@ type NodeStatus struct { } // CreateLoadBalancingPolicy creates a new load balancing policy -func (h *Handler) CreateLoadBalancingPolicy(c *gin.Context) { +func (h *LoadBalancingHandler) CreateLoadBalancingPolicy(c *gin.Context) { createdBy := c.GetString("user_id") role := c.GetString("role") @@ -170,7 +182,7 @@ func (h *Handler) CreateLoadBalancingPolicy(c *gin.Context) { req.Enabled = true var id int64 - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` INSERT INTO load_balancing_policies (name, description, strategy, enabled, session_affinity, health_check_config, node_selector, node_weights, geo_preferences, resource_thresholds, metadata, created_by) @@ -195,8 +207,8 @@ func (h *Handler) CreateLoadBalancingPolicy(c *gin.Context) { } // ListLoadBalancingPolicies lists all load balancing policies -func (h *Handler) ListLoadBalancingPolicies(c *gin.Context) { - rows, err := h.DB.Query(` +func (h *LoadBalancingHandler) ListLoadBalancingPolicies(c *gin.Context) { + rows, err := h.DB.DB().Query(` SELECT id, name, description, strategy, enabled, session_affinity, health_check_config, node_selector, node_weights, geo_preferences, resource_thresholds, metadata, created_by, created_at, updated_at @@ -227,7 +239,7 @@ func (h *Handler) ListLoadBalancingPolicies(c *gin.Context) { } // GetNodeStatus gets current status of all cluster nodes -func (h *Handler) GetNodeStatus(c *gin.Context) { +func (h *LoadBalancingHandler) GetNodeStatus(c *gin.Context) { // Try to fetch real node metrics from Kubernetes API // If K8s integration is not available, fall back to database nodes, err := h.fetchKubernetesNodeMetrics() @@ -244,8 +256,8 @@ func (h *Handler) GetNodeStatus(c *gin.Context) { } // fetchNodeStatusFromDatabase fetches node status from database cache -func (h *Handler) fetchNodeStatusFromDatabase() ([]NodeStatus, error) { - rows, err := h.DB.Query(` +func (h *LoadBalancingHandler) fetchNodeStatusFromDatabase() ([]NodeStatus, error) { + rows, err := h.DB.DB().Query(` SELECT node_name, status, cpu_allocated, cpu_capacity, memory_allocated, memory_capacity, active_sessions, health_status, last_health_check, region, zone, labels, weight @@ -290,7 +302,7 @@ func (h *Handler) fetchNodeStatusFromDatabase() ([]NodeStatus, error) { } // fetchKubernetesNodeMetrics fetches real-time node metrics from Kubernetes API -func (h *Handler) fetchKubernetesNodeMetrics() ([]NodeStatus, error) { +func (h *LoadBalancingHandler) fetchKubernetesNodeMetrics() ([]NodeStatus, error) { ctx := context.Background() // Create Kubernetes config @@ -352,7 +364,7 @@ func (h *Handler) fetchKubernetesNodeMetrics() ([]NodeStatus, error) { } // getKubernetesConfig gets Kubernetes configuration from kubeconfig or in-cluster config -func (h *Handler) getKubernetesConfig() (*rest.Config, error) { +func (h *LoadBalancingHandler) getKubernetesConfig() (*rest.Config, error) { // Try in-cluster config first (for pods running in cluster) config, err := rest.InClusterConfig() if err == nil { @@ -379,7 +391,7 @@ func (h *Handler) getKubernetesConfig() (*rest.Config, error) { } // convertNodeToNodeStatus converts a Kubernetes Node to our NodeStatus struct -func (h *Handler) convertNodeToNodeStatus(node corev1.Node, metricsMap map[string]metricsv1beta1.NodeMetrics, sessionCounts map[string]int) NodeStatus { +func (h *LoadBalancingHandler) convertNodeToNodeStatus(node corev1.Node, metricsMap map[string]metricsv1beta1.NodeMetrics, sessionCounts map[string]int) NodeStatus { ns := NodeStatus{ NodeName: node.Name, Labels: node.Labels, @@ -466,8 +478,8 @@ func (h *Handler) convertNodeToNodeStatus(node corev1.Node, metricsMap map[strin } // getSessionCountsByNode gets the count of active sessions per node from database -func (h *Handler) getSessionCountsByNode() (map[string]int, error) { - rows, err := h.DB.Query(` +func (h *LoadBalancingHandler) getSessionCountsByNode() (map[string]int, error) { + rows, err := h.DB.DB().Query(` SELECT node_name, COUNT(*) as session_count FROM sessions WHERE state = 'running' AND node_name IS NOT NULL @@ -492,10 +504,10 @@ func (h *Handler) getSessionCountsByNode() (map[string]int, error) { } // cacheNodeStatusInDatabase caches node status in database for fallback -func (h *Handler) cacheNodeStatusInDatabase(nodes []NodeStatus) { +func (h *LoadBalancingHandler) cacheNodeStatusInDatabase(nodes []NodeStatus) { for _, node := range nodes { // Use UPSERT pattern to update or insert - h.DB.Exec(` + h.DB.DB().Exec(` INSERT INTO node_status (node_name, status, cpu_allocated, cpu_capacity, memory_allocated, memory_capacity, active_sessions, health_status, last_health_check, region, zone, labels, weight) @@ -522,7 +534,7 @@ func (h *Handler) cacheNodeStatusInDatabase(nodes []NodeStatus) { } // scaleKubernetesDeployment scales a Kubernetes deployment to the specified replica count -func (h *Handler) scaleKubernetesDeployment(deploymentName string, replicas int) error { +func (h *LoadBalancingHandler) scaleKubernetesDeployment(deploymentName string, replicas int) error { ctx := context.Background() // Create Kubernetes config @@ -571,7 +583,7 @@ func (h *Handler) scaleKubernetesDeployment(deploymentName string, replicas int) namespace, deploymentName, originalReplicas, replicas) // Also store in database queue as audit trail - h.DB.Exec(` + h.DB.DB().Exec(` INSERT INTO deployment_scaling_queue (deployment_name, namespace, target_replicas, status, created_at) VALUES ($1, $2, $3, 'completed', NOW()) `, deploymentName, namespace, replicas) @@ -593,23 +605,11 @@ func calculateClusterTotals(nodes []NodeStatus) (totalCPU, usedCPU float64, tota totalSessions += node.ActiveSessions } - c.JSON(http.StatusOK, gin.H{ - "nodes": nodes, - "cluster_summary": gin.H{ - "total_nodes": len(nodes), - "cpu_capacity": totalCPU, - "cpu_used": usedCPU, - "cpu_percent": (usedCPU / totalCPU) * 100, - "memory_capacity": totalMemory, - "memory_used": usedMemory, - "memory_percent": (float64(usedMemory) / float64(totalMemory)) * 100, - "active_sessions": totalSessions, - }, - }) + return totalCPU, usedCPU, totalMemory, usedMemory, totalSessions } // SelectNode selects best node for a new session based on policy -func (h *Handler) SelectNode(c *gin.Context) { +func (h *LoadBalancingHandler) SelectNode(c *gin.Context) { var req struct { PolicyID int64 `json:"policy_id,omitempty"` RequiredCPU float64 `json:"required_cpu"` @@ -630,11 +630,11 @@ func (h *Handler) SelectNode(c *gin.Context) { // If no policy specified, get default policy if policyID == 0 { - h.DB.QueryRow(`SELECT id FROM load_balancing_policies WHERE enabled = true ORDER BY id LIMIT 1`).Scan(&policyID) + h.DB.DB().QueryRow(`SELECT id FROM load_balancing_policies WHERE enabled = true ORDER BY id LIMIT 1`).Scan(&policyID) } if policyID > 0 { - h.DB.QueryRow(` + h.DB.DB().QueryRow(` SELECT strategy, resource_thresholds, geo_preferences, node_weights FROM load_balancing_policies WHERE id = $1 `, policyID).Scan(&policy.Strategy, &policy.ResourceThresholds, @@ -645,7 +645,7 @@ func (h *Handler) SelectNode(c *gin.Context) { } // Get available nodes - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT node_name, cpu_allocated, cpu_capacity, memory_allocated, memory_capacity, active_sessions, health_status, region, weight FROM node_status @@ -823,7 +823,7 @@ type ScalingEvent struct { } // CreateAutoScalingPolicy creates a new auto-scaling policy -func (h *Handler) CreateAutoScalingPolicy(c *gin.Context) { +func (h *LoadBalancingHandler) CreateAutoScalingPolicy(c *gin.Context) { createdBy := c.GetString("user_id") role := c.GetString("role") @@ -848,7 +848,7 @@ func (h *Handler) CreateAutoScalingPolicy(c *gin.Context) { req.Enabled = true var id int64 - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` INSERT INTO autoscaling_policies (name, description, target_type, target_id, enabled, scaling_mode, min_replicas, max_replicas, metric_type, target_metric_value, scale_up_policy, scale_down_policy, @@ -875,8 +875,8 @@ func (h *Handler) CreateAutoScalingPolicy(c *gin.Context) { } // ListAutoScalingPolicies lists all auto-scaling policies -func (h *Handler) ListAutoScalingPolicies(c *gin.Context) { - rows, err := h.DB.Query(` +func (h *LoadBalancingHandler) ListAutoScalingPolicies(c *gin.Context) { + rows, err := h.DB.DB().Query(` SELECT id, name, description, target_type, target_id, enabled, scaling_mode, min_replicas, max_replicas, metric_type, target_metric_value, scale_up_policy, scale_down_policy, predictive_scaling, cooldown_period, @@ -908,7 +908,7 @@ func (h *Handler) ListAutoScalingPolicies(c *gin.Context) { } // TriggerScaling manually triggers a scaling action -func (h *Handler) TriggerScaling(c *gin.Context) { +func (h *LoadBalancingHandler) TriggerScaling(c *gin.Context) { policyID := c.Param("policyId") var req struct { @@ -924,7 +924,7 @@ func (h *Handler) TriggerScaling(c *gin.Context) { // Get policy var policy AutoScalingPolicy - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` SELECT target_type, target_id, min_replicas, max_replicas, scale_up_policy, scale_down_policy FROM autoscaling_policies WHERE id = $1 AND enabled = true `, policyID).Scan(&policy.TargetType, &policy.TargetID, &policy.MinReplicas, @@ -999,7 +999,7 @@ func (h *Handler) TriggerScaling(c *gin.Context) { // Record scaling event var eventID int64 - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` INSERT INTO scaling_events (policy_id, target_type, target_id, action, previous_replicas, new_replicas, trigger, reason, status) @@ -1017,14 +1017,14 @@ func (h *Handler) TriggerScaling(c *gin.Context) { err = h.scaleKubernetesDeployment(policy.TargetID, newReplicas) if err != nil { // Update event status to failed - h.DB.Exec(`UPDATE scaling_events SET status = 'failed', error_message = $1 WHERE id = $2`, + h.DB.DB().Exec(`UPDATE scaling_events SET status = 'failed', error_message = $1 WHERE id = $2`, err.Error(), eventID) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("scaling failed: %v", err)}) return } // Update event status to completed - h.DB.Exec(`UPDATE scaling_events SET status = 'completed' WHERE id = $1`, eventID) + h.DB.DB().Exec(`UPDATE scaling_events SET status = 'completed' WHERE id = $1`, eventID) c.JSON(http.StatusOK, gin.H{ "event_id": eventID, @@ -1036,7 +1036,7 @@ func (h *Handler) TriggerScaling(c *gin.Context) { } // GetScalingHistory gets scaling event history -func (h *Handler) GetScalingHistory(c *gin.Context) { +func (h *LoadBalancingHandler) GetScalingHistory(c *gin.Context) { policyID := c.Query("policy_id") limit := c.DefaultQuery("limit", "50") @@ -1044,7 +1044,7 @@ func (h *Handler) GetScalingHistory(c *gin.Context) { var err error if policyID != "" { - rows, err = h.DB.Query(` + rows, err = h.DB.DB().Query(` SELECT id, policy_id, target_type, target_id, action, previous_replicas, new_replicas, trigger, metric_value, reason, status, created_at FROM scaling_events @@ -1053,7 +1053,7 @@ func (h *Handler) GetScalingHistory(c *gin.Context) { LIMIT $2 `, policyID, limit) } else { - rows, err = h.DB.Query(` + rows, err = h.DB.DB().Query(` SELECT id, policy_id, target_type, target_id, action, previous_replicas, new_replicas, trigger, metric_value, reason, status, created_at FROM scaling_events diff --git a/api/internal/handlers/notifications.go b/api/internal/handlers/notifications.go index 4fc8b09b..a11acd52 100644 --- a/api/internal/handlers/notifications.go +++ b/api/internal/handlers/notifications.go @@ -643,9 +643,9 @@ func (h *NotificationsHandler) sendWebhookNotification(prefs map[string]interfac webhookSecret = "default-secret" } - h := hmac.New(sha256.New, []byte(webhookSecret)) - h.Write(payloadJSON) - signature := hex.EncodeToString(h.Sum(nil)) + mac := hmac.New(sha256.New, []byte(webhookSecret)) + mac.Write(payloadJSON) + signature := hex.EncodeToString(mac.Sum(nil)) // Send HTTP POST request req, err := http.NewRequest("POST", webhookURL, bytes.NewBuffer(payloadJSON)) diff --git a/api/internal/handlers/plugin_marketplace.go b/api/internal/handlers/plugin_marketplace.go index bd582aee..f5876ff4 100644 --- a/api/internal/handlers/plugin_marketplace.go +++ b/api/internal/handlers/plugin_marketplace.go @@ -618,7 +618,7 @@ func (h *PluginMarketplaceHandler) GetInstalledPlugin(c *gin.Context) { // - 200: Config updated (currently always succeeds - TODO) // - 400: Invalid request body func (h *PluginMarketplaceHandler) UpdatePluginConfig(c *gin.Context) { - name := c.Param("name") + _ = c.Param("name") // Plugin name not used - config update handled generically var req struct { Config map[string]interface{} `json:"config"` diff --git a/api/internal/handlers/scheduling.go b/api/internal/handlers/scheduling.go index 326bb9b9..011abcd1 100644 --- a/api/internal/handlers/scheduling.go +++ b/api/internal/handlers/scheduling.go @@ -79,8 +79,19 @@ import ( "github.com/gin-gonic/gin" "github.com/robfig/cron/v3" + "github.com/streamspace/streamspace/api/internal/db" ) +// SchedulingHandler handles session scheduling and calendar integration requests. +type SchedulingHandler struct { + DB *db.Database +} + +// NewSchedulingHandler creates a new scheduling handler. +func NewSchedulingHandler(database *db.Database) *SchedulingHandler { + return &SchedulingHandler{DB: database} +} + // ============================================================================ // SESSION SCHEDULING - DATA STRUCTURES // ============================================================================ @@ -210,7 +221,7 @@ type ResourceConfig struct { // - User can only create schedules for themselves (userID enforced) // - Schedule is validated to prevent malicious cron expressions // - Timezone must be valid IANA timezone name -func (h *Handler) CreateScheduledSession(c *gin.Context) { +func (h *SchedulingHandler) CreateScheduledSession(c *gin.Context) { userID := c.GetString("user_id") var req ScheduledSession @@ -264,7 +275,7 @@ func (h *Handler) CreateScheduledSession(c *gin.Context) { // Insert scheduled session var id int64 - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` INSERT INTO scheduled_sessions (user_id, template_id, name, description, timezone, schedule, resources, auto_terminate, terminate_after, pre_warm, pre_warm_minutes, post_cleanup, @@ -292,7 +303,7 @@ func (h *Handler) CreateScheduledSession(c *gin.Context) { } // ListScheduledSessions lists all scheduled sessions for a user -func (h *Handler) ListScheduledSessions(c *gin.Context) { +func (h *SchedulingHandler) ListScheduledSessions(c *gin.Context) { userID := c.GetString("user_id") role := c.GetString("role") @@ -307,7 +318,7 @@ func (h *Handler) ListScheduledSessions(c *gin.Context) { ORDER BY next_run_at ASC ` - rows, err := h.DB.Query(query, userID, role) + rows, err := h.DB.DB().Query(query, userID, role) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "database error"}) return @@ -353,7 +364,7 @@ func (h *Handler) ListScheduledSessions(c *gin.Context) { } // GetScheduledSession gets details of a scheduled session -func (h *Handler) GetScheduledSession(c *gin.Context) { +func (h *SchedulingHandler) GetScheduledSession(c *gin.Context) { scheduleID := c.Param("scheduleId") userID := c.GetString("user_id") role := c.GetString("role") @@ -362,7 +373,7 @@ func (h *Handler) GetScheduledSession(c *gin.Context) { var lastRun, nextRun sql.NullTime var lastSessionID, lastStatus sql.NullString - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` SELECT id, user_id, template_id, name, description, timezone, schedule, resources, auto_terminate, terminate_after, pre_warm, pre_warm_minutes, post_cleanup, enabled, next_run_at, last_run_at, last_session_id, @@ -401,7 +412,7 @@ func (h *Handler) GetScheduledSession(c *gin.Context) { } // UpdateScheduledSession updates a scheduled session -func (h *Handler) UpdateScheduledSession(c *gin.Context) { +func (h *SchedulingHandler) UpdateScheduledSession(c *gin.Context) { scheduleID := c.Param("scheduleId") userID := c.GetString("user_id") role := c.GetString("role") @@ -414,7 +425,7 @@ func (h *Handler) UpdateScheduledSession(c *gin.Context) { // Check ownership var ownerID string - err := h.DB.QueryRow(`SELECT user_id FROM scheduled_sessions WHERE id = $1`, scheduleID).Scan(&ownerID) + err := h.DB.DB().QueryRow(`SELECT user_id FROM scheduled_sessions WHERE id = $1`, scheduleID).Scan(&ownerID) if err == sql.ErrNoRows { c.JSON(http.StatusNotFound, gin.H{"error": "scheduled session not found"}) return @@ -436,7 +447,7 @@ func (h *Handler) UpdateScheduledSession(c *gin.Context) { } } - _, err = h.DB.Exec(` + _, err = h.DB.DB().Exec(` UPDATE scheduled_sessions SET name = COALESCE(NULLIF($1, ''), name), description = $2, @@ -461,14 +472,14 @@ func (h *Handler) UpdateScheduledSession(c *gin.Context) { } // DeleteScheduledSession deletes a scheduled session -func (h *Handler) DeleteScheduledSession(c *gin.Context) { +func (h *SchedulingHandler) DeleteScheduledSession(c *gin.Context) { scheduleID := c.Param("scheduleId") userID := c.GetString("user_id") role := c.GetString("role") // Check ownership var ownerID string - err := h.DB.QueryRow(`SELECT user_id FROM scheduled_sessions WHERE id = $1`, scheduleID).Scan(&ownerID) + err := h.DB.DB().QueryRow(`SELECT user_id FROM scheduled_sessions WHERE id = $1`, scheduleID).Scan(&ownerID) if err == sql.ErrNoRows { c.JSON(http.StatusNotFound, gin.H{"error": "scheduled session not found"}) return @@ -478,7 +489,7 @@ func (h *Handler) DeleteScheduledSession(c *gin.Context) { return } - _, err = h.DB.Exec(`DELETE FROM scheduled_sessions WHERE id = $1`, scheduleID) + _, err = h.DB.DB().Exec(`DELETE FROM scheduled_sessions WHERE id = $1`, scheduleID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to delete"}) return @@ -488,11 +499,11 @@ func (h *Handler) DeleteScheduledSession(c *gin.Context) { } // EnableScheduledSession enables a schedule -func (h *Handler) EnableScheduledSession(c *gin.Context) { +func (h *SchedulingHandler) EnableScheduledSession(c *gin.Context) { scheduleID := c.Param("scheduleId") userID := c.GetString("user_id") - _, err := h.DB.Exec(` + _, err := h.DB.DB().Exec(` UPDATE scheduled_sessions SET enabled = true, updated_at = NOW() WHERE id = $1 AND user_id = $2 `, scheduleID, userID) @@ -506,11 +517,11 @@ func (h *Handler) EnableScheduledSession(c *gin.Context) { } // DisableScheduledSession disables a schedule -func (h *Handler) DisableScheduledSession(c *gin.Context) { +func (h *SchedulingHandler) DisableScheduledSession(c *gin.Context) { scheduleID := c.Param("scheduleId") userID := c.GetString("user_id") - _, err := h.DB.Exec(` + _, err := h.DB.DB().Exec(` UPDATE scheduled_sessions SET enabled = false, updated_at = NOW() WHERE id = $1 AND user_id = $2 `, scheduleID, userID) @@ -570,7 +581,7 @@ type CalendarEvent struct { // ============================================================================ // ConnectCalendar initiates calendar OAuth flow -func (h *Handler) ConnectCalendar(c *gin.Context) { +func (h *SchedulingHandler) ConnectCalendar(c *gin.Context) { userID := c.GetString("user_id") var req struct { @@ -602,7 +613,7 @@ func (h *Handler) ConnectCalendar(c *gin.Context) { } // CalendarOAuthCallback handles OAuth callback -func (h *Handler) CalendarOAuthCallback(c *gin.Context) { +func (h *SchedulingHandler) CalendarOAuthCallback(c *gin.Context) { provider := c.Query("provider") code := c.Query("code") state := c.Query("state") // Contains userID @@ -615,6 +626,7 @@ func (h *Handler) CalendarOAuthCallback(c *gin.Context) { // Exchange code for tokens (implementation depends on provider) var accessToken, refreshToken, email string var expiry time.Time + var err error // Implement OAuth token exchange based on provider switch provider { @@ -634,7 +646,7 @@ func (h *Handler) CalendarOAuthCallback(c *gin.Context) { // Store integration var id int64 - err := h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` INSERT INTO calendar_integrations (user_id, provider, account_email, access_token, refresh_token, token_expiry, enabled, sync_enabled) VALUES ($1, $2, $3, $4, $5, $6, true, true) @@ -653,10 +665,10 @@ func (h *Handler) CalendarOAuthCallback(c *gin.Context) { } // ListCalendarIntegrations lists user's calendar integrations -func (h *Handler) ListCalendarIntegrations(c *gin.Context) { +func (h *SchedulingHandler) ListCalendarIntegrations(c *gin.Context) { userID := c.GetString("user_id") - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT id, provider, account_email, calendar_id, enabled, sync_enabled, auto_create_events, auto_update_events, last_synced_at, created_at FROM calendar_integrations @@ -699,11 +711,11 @@ func (h *Handler) ListCalendarIntegrations(c *gin.Context) { } // DisconnectCalendar removes a calendar integration -func (h *Handler) DisconnectCalendar(c *gin.Context) { +func (h *SchedulingHandler) DisconnectCalendar(c *gin.Context) { integrationID := c.Param("integrationId") userID := c.GetString("user_id") - result, err := h.DB.Exec(` + result, err := h.DB.DB().Exec(` DELETE FROM calendar_integrations WHERE id = $1 AND user_id = $2 `, integrationID, userID) @@ -723,13 +735,13 @@ func (h *Handler) DisconnectCalendar(c *gin.Context) { } // SyncCalendar manually triggers calendar sync -func (h *Handler) SyncCalendar(c *gin.Context) { +func (h *SchedulingHandler) SyncCalendar(c *gin.Context) { integrationID := c.Param("integrationId") userID := c.GetString("user_id") // Get integration details var ci CalendarIntegration - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` SELECT id, provider, access_token, refresh_token, calendar_id FROM calendar_integrations WHERE id = $1 AND user_id = $2 @@ -749,7 +761,7 @@ func (h *Handler) SyncCalendar(c *gin.Context) { } // Update last synced timestamp - h.DB.Exec(` + h.DB.DB().Exec(` UPDATE calendar_integrations SET last_synced_at = NOW() WHERE id = $1 @@ -763,11 +775,11 @@ func (h *Handler) SyncCalendar(c *gin.Context) { } // ExportICalendar exports scheduled sessions as iCal format -func (h *Handler) ExportICalendar(c *gin.Context) { +func (h *SchedulingHandler) ExportICalendar(c *gin.Context) { userID := c.GetString("user_id") // Get all enabled scheduled sessions - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT id, name, description, schedule, timezone, template_id FROM scheduled_sessions WHERE user_id = $1 AND enabled = true @@ -863,7 +875,7 @@ func (h *Handler) ExportICalendar(c *gin.Context) { // // - nil: Schedule is valid // - error: Descriptive error message indicating what's wrong -func (h *Handler) validateSchedule(schedule *ScheduleConfig) error { +func (h *SchedulingHandler) validateSchedule(schedule *ScheduleConfig) error { switch schedule.Type { case "once": // One-time schedule: requires specific start timestamp @@ -990,7 +1002,7 @@ func (h *Handler) validateSchedule(schedule *ScheduleConfig) error { // TimeOfDay: "14:00" // }, "America/New_York") // // Returns: next Monday or Wednesday at 2 PM, whichever comes first -func (h *Handler) calculateNextRun(schedule *ScheduleConfig, timezone string) (time.Time, error) { +func (h *SchedulingHandler) calculateNextRun(schedule *ScheduleConfig, timezone string) (time.Time, error) { // STEP 1: Load the user's timezone // If timezone is invalid, fall back to UTC to prevent errors // This allows schedules to still work even with misconfigured timezones @@ -1187,7 +1199,7 @@ func (h *Handler) calculateNextRun(schedule *ScheduleConfig, timezone string) (t // "America/New_York", // 240) // 4 hours // // Returns: [existing_schedule_id] because 2-6 PM overlaps with 9 AM-5 PM -func (h *Handler) checkSchedulingConflicts(userID string, schedule ScheduleConfig, timezone string, terminateAfterMinutes int) ([]int64, error) { +func (h *SchedulingHandler) checkSchedulingConflicts(userID string, schedule ScheduleConfig, timezone string, terminateAfterMinutes int) ([]int64, error) { // STEP 1: Calculate when the proposed schedule will next run // This gives us the start time for conflict detection proposedStart, err := h.calculateNextRun(&schedule, timezone) @@ -1212,7 +1224,7 @@ func (h *Handler) checkSchedulingConflicts(userID string, schedule ScheduleConfi WHERE user_id = $1 AND enabled = true ` - rows, err := h.DB.Query(query, userID) + rows, err := h.DB.DB().Query(query, userID) if err != nil { return nil, fmt.Errorf("failed to query schedules: %w", err) } @@ -1253,7 +1265,7 @@ func (h *Handler) checkSchedulingConflicts(userID string, schedule ScheduleConfi } // Get Google Calendar OAuth URL -func (h *Handler) getGoogleCalendarAuthURL(userID string) string { +func (h *SchedulingHandler) getGoogleCalendarAuthURL(userID string) string { // OAuth2 configuration for Google Calendar clientID := os.Getenv("GOOGLE_OAUTH_CLIENT_ID") if clientID == "" { @@ -1282,7 +1294,7 @@ func (h *Handler) getGoogleCalendarAuthURL(userID string) string { } // Get Outlook Calendar OAuth URL -func (h *Handler) getOutlookCalendarAuthURL(userID string) string { +func (h *SchedulingHandler) getOutlookCalendarAuthURL(userID string) string { // OAuth2 configuration for Microsoft Outlook clientID := os.Getenv("MICROSOFT_OAUTH_CLIENT_ID") if clientID == "" { @@ -1309,7 +1321,7 @@ func (h *Handler) getOutlookCalendarAuthURL(userID string) string { } // exchangeGoogleOAuthToken exchanges authorization code for access/refresh tokens -func (h *Handler) exchangeGoogleOAuthToken(code string) (accessToken, refreshToken, email string, expiry time.Time, err error) { +func (h *SchedulingHandler) exchangeGoogleOAuthToken(code string) (accessToken, refreshToken, email string, expiry time.Time, err error) { clientID := os.Getenv("GOOGLE_OAUTH_CLIENT_ID") clientSecret := os.Getenv("GOOGLE_OAUTH_CLIENT_SECRET") redirectURI := os.Getenv("GOOGLE_OAUTH_REDIRECT_URI") @@ -1379,7 +1391,7 @@ func (h *Handler) exchangeGoogleOAuthToken(code string) (accessToken, refreshTok } // getGoogleUserEmail fetches the user's email from Google userinfo API -func (h *Handler) getGoogleUserEmail(accessToken string) (string, error) { +func (h *SchedulingHandler) getGoogleUserEmail(accessToken string) (string, error) { req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v2/userinfo", nil) if err != nil { return "", err @@ -1411,7 +1423,7 @@ func (h *Handler) getGoogleUserEmail(accessToken string) (string, error) { } // exchangeOutlookOAuthToken exchanges authorization code for access/refresh tokens -func (h *Handler) exchangeOutlookOAuthToken(code string) (accessToken, refreshToken, email string, expiry time.Time, err error) { +func (h *SchedulingHandler) exchangeOutlookOAuthToken(code string) (accessToken, refreshToken, email string, expiry time.Time, err error) { clientID := os.Getenv("MICROSOFT_OAUTH_CLIENT_ID") clientSecret := os.Getenv("MICROSOFT_OAUTH_CLIENT_SECRET") redirectURI := os.Getenv("MICROSOFT_OAUTH_REDIRECT_URI") @@ -1482,7 +1494,7 @@ func (h *Handler) exchangeOutlookOAuthToken(code string) (accessToken, refreshTo } // getMicrosoftUserEmail fetches the user's email from Microsoft Graph API -func (h *Handler) getMicrosoftUserEmail(accessToken string) (string, error) { +func (h *SchedulingHandler) getMicrosoftUserEmail(accessToken string) (string, error) { req, err := http.NewRequest("GET", "https://graph.microsoft.com/v1.0/me", nil) if err != nil { return "", err @@ -1518,9 +1530,9 @@ func (h *Handler) getMicrosoftUserEmail(accessToken string) (string, error) { } // syncScheduledSessionsToCalendar syncs user's scheduled sessions to their calendar -func (h *Handler) syncScheduledSessionsToCalendar(userID string, ci *CalendarIntegration) (int, error) { +func (h *SchedulingHandler) syncScheduledSessionsToCalendar(userID string, ci *CalendarIntegration) (int, error) { // Fetch enabled scheduled sessions for the user - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT id, name, template_id, schedule, timezone, next_run_at, terminate_after FROM scheduled_sessions WHERE user_id = $1 AND enabled = true @@ -1567,7 +1579,7 @@ func (h *Handler) syncScheduledSessionsToCalendar(userID string, ci *CalendarInt } // Store the event ID for future updates/deletion - _, err = h.DB.Exec(` + _, err = h.DB.DB().Exec(` UPDATE scheduled_sessions SET calendar_event_id = $1 WHERE id = $2 @@ -1582,7 +1594,7 @@ func (h *Handler) syncScheduledSessionsToCalendar(userID string, ci *CalendarInt } // createGoogleCalendarEvent creates an event in Google Calendar -func (h *Handler) createGoogleCalendarEvent(ci *CalendarIntegration, title, description string, startTime time.Time, durationMinutes int) (string, error) { +func (h *SchedulingHandler) createGoogleCalendarEvent(ci *CalendarIntegration, title, description string, startTime time.Time, durationMinutes int) (string, error) { if ci.AccessToken == "" { return "", fmt.Errorf("no access token available") } @@ -1669,7 +1681,7 @@ func (h *Handler) createGoogleCalendarEvent(ci *CalendarIntegration, title, desc } // createOutlookCalendarEvent creates an event in Outlook Calendar -func (h *Handler) createOutlookCalendarEvent(ci *CalendarIntegration, title, description string, startTime time.Time, durationMinutes int) (string, error) { +func (h *SchedulingHandler) createOutlookCalendarEvent(ci *CalendarIntegration, title, description string, startTime time.Time, durationMinutes int) (string, error) { if ci.AccessToken == "" { return "", fmt.Errorf("no access token available") } diff --git a/api/internal/handlers/security.go b/api/internal/handlers/security.go index 53b3c04a..b794f103 100644 --- a/api/internal/handlers/security.go +++ b/api/internal/handlers/security.go @@ -39,6 +39,7 @@ import ( "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" + "github.com/streamspace/streamspace/api/internal/db" "github.com/streamspace/streamspace/api/internal/middleware" ) @@ -268,7 +269,18 @@ type TrustedDevice struct { // "qr_code": "otpauth://totp/StreamSpace:user123?secret=JBSWY3DP...", // "message": "Scan the QR code with your authenticator app and verify" // } -func (h *Handler) SetupMFA(c *gin.Context) { + +// SecurityHandler handles security-related endpoints (MFA, IP whitelisting, etc.) +type SecurityHandler struct { + DB *db.Database +} + +// NewSecurityHandler creates a new SecurityHandler instance +func NewSecurityHandler(database *db.Database) *SecurityHandler { + return &SecurityHandler{DB: database} +} + +func (h *SecurityHandler) SetupMFA(c *gin.Context) { userID := c.GetString("user_id") var req struct { @@ -304,7 +316,7 @@ func (h *Handler) SetupMFA(c *gin.Context) { // Check if MFA already exists var existingID int64 - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` SELECT id FROM mfa_methods WHERE user_id = $1 AND type = $2 `, userID, req.Type).Scan(&existingID) @@ -338,7 +350,7 @@ func (h *Handler) SetupMFA(c *gin.Context) { // Insert MFA method (not yet verified/enabled) var mfaID int64 - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` INSERT INTO mfa_methods (user_id, type, secret, phone_number, email, enabled, verified) VALUES ($1, $2, $3, $4, $5, false, false) RETURNING id @@ -365,7 +377,7 @@ func (h *Handler) SetupMFA(c *gin.Context) { } // VerifyMFASetup verifies and enables MFA method (Step 2: Confirm setup) -func (h *Handler) VerifyMFASetup(c *gin.Context) { +func (h *SecurityHandler) VerifyMFASetup(c *gin.Context) { userID := c.GetString("user_id") mfaID := c.Param("mfaId") @@ -380,7 +392,7 @@ func (h *Handler) VerifyMFASetup(c *gin.Context) { // Get MFA method (before transaction to verify code) var mfaMethod MFAMethod - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` SELECT id, user_id, type, secret, phone_number, email FROM mfa_methods WHERE id = $1 AND user_id = $2 @@ -409,7 +421,7 @@ func (h *Handler) VerifyMFASetup(c *gin.Context) { // SECURITY: Use transaction to ensure atomicity // Either both MFA enable AND backup codes succeed, or neither - tx, err := h.DB.Begin() + tx, err := h.DB.DB().Begin() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "database error"}) return @@ -498,7 +510,7 @@ func (h *Handler) VerifyMFASetup(c *gin.Context) { // - 404 Not Found: MFA method not enabled // - 429 Too Many Requests: Rate limit exceeded (>5 attempts/minute) // - 501 Not Implemented: SMS/Email MFA requested -func (h *Handler) VerifyMFA(c *gin.Context) { +func (h *SecurityHandler) VerifyMFA(c *gin.Context) { userID := c.GetString("user_id") var req struct { @@ -547,7 +559,7 @@ func (h *Handler) VerifyMFA(c *gin.Context) { } else { // Get MFA method var secret string - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` SELECT secret FROM mfa_methods WHERE user_id = $1 AND type = $2 AND enabled = true `, userID, req.MethodType).Scan(&secret) @@ -568,7 +580,7 @@ func (h *Handler) VerifyMFA(c *gin.Context) { // Update last used timestamp if valid { - h.DB.Exec(`UPDATE mfa_methods SET last_used_at = NOW() WHERE user_id = $1 AND type = $2`, + h.DB.DB().Exec(`UPDATE mfa_methods SET last_used_at = NOW() WHERE user_id = $1 AND type = $2`, userID, req.MethodType) } } @@ -594,10 +606,10 @@ func (h *Handler) VerifyMFA(c *gin.Context) { } // ListMFAMethods lists all MFA methods for a user -func (h *Handler) ListMFAMethods(c *gin.Context) { +func (h *SecurityHandler) ListMFAMethods(c *gin.Context) { userID := c.GetString("user_id") - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT id, type, enabled, verified, is_primary, phone_number, email, created_at, last_used_at FROM mfa_methods WHERE user_id = $1 @@ -637,11 +649,11 @@ func (h *Handler) ListMFAMethods(c *gin.Context) { } // DisableMFA disables an MFA method -func (h *Handler) DisableMFA(c *gin.Context) { +func (h *SecurityHandler) DisableMFA(c *gin.Context) { userID := c.GetString("user_id") mfaID := c.Param("mfaId") - result, err := h.DB.Exec(` + result, err := h.DB.DB().Exec(` UPDATE mfa_methods SET enabled = false WHERE id = $1 AND user_id = $2 `, mfaID, userID) @@ -661,11 +673,11 @@ func (h *Handler) DisableMFA(c *gin.Context) { } // GenerateBackupCodes generates new backup codes -func (h *Handler) GenerateBackupCodes(c *gin.Context) { +func (h *SecurityHandler) GenerateBackupCodes(c *gin.Context) { userID := c.GetString("user_id") // Invalidate old backup codes - h.DB.Exec(`DELETE FROM backup_codes WHERE user_id = $1`, userID) + h.DB.DB().Exec(`DELETE FROM backup_codes WHERE user_id = $1`, userID) // Generate new codes codes := h.generateBackupCodes(userID, BackupCodesCount) @@ -677,7 +689,7 @@ func (h *Handler) GenerateBackupCodes(c *gin.Context) { } // Helper: Generate backup codes -func (h *Handler) generateBackupCodes(userID string, count int) []string { +func (h *SecurityHandler) generateBackupCodes(userID string, count int) []string { codes := make([]string, count) for i := 0; i < count; i++ { @@ -688,7 +700,7 @@ func (h *Handler) generateBackupCodes(userID string, count int) []string { hash := sha256.Sum256([]byte(code)) hashStr := hex.EncodeToString(hash[:]) - h.DB.Exec(` + h.DB.DB().Exec(` INSERT INTO backup_codes (user_id, code) VALUES ($1, $2) `, userID, hashStr) @@ -698,12 +710,12 @@ func (h *Handler) generateBackupCodes(userID string, count int) []string { } // Helper: Verify backup code -func (h *Handler) verifyBackupCode(userID, code string) bool { +func (h *SecurityHandler) verifyBackupCode(userID, code string) bool { hash := sha256.Sum256([]byte(code)) hashStr := hex.EncodeToString(hash[:]) var codeID int64 - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` SELECT id FROM backup_codes WHERE user_id = $1 AND code = $2 AND used = false `, userID, hashStr).Scan(&codeID) @@ -713,7 +725,7 @@ func (h *Handler) verifyBackupCode(userID, code string) bool { } // Mark as used - h.DB.Exec(`UPDATE backup_codes SET used = true, used_at = NOW() WHERE id = $1`, codeID) + h.DB.DB().Exec(`UPDATE backup_codes SET used = true, used_at = NOW() WHERE id = $1`, codeID) return true } @@ -744,7 +756,7 @@ type GeoRestriction struct { } // CreateIPWhitelist adds an IP to whitelist -func (h *Handler) CreateIPWhitelist(c *gin.Context) { +func (h *SecurityHandler) CreateIPWhitelist(c *gin.Context) { createdBy := c.GetString("user_id") role := c.GetString("role") @@ -782,7 +794,7 @@ func (h *Handler) CreateIPWhitelist(c *gin.Context) { } var id int64 - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` INSERT INTO ip_whitelist (user_id, ip_address, description, enabled, created_by, expires_at) VALUES ($1, $2, $3, true, $4, $5) RETURNING id @@ -800,7 +812,7 @@ func (h *Handler) CreateIPWhitelist(c *gin.Context) { } // CheckIPAccess checks if an IP is allowed access -func (h *Handler) CheckIPAccess(c *gin.Context) { +func (h *SecurityHandler) CheckIPAccess(c *gin.Context) { userID := c.Query("user_id") ipAddress := c.Query("ip_address") @@ -818,14 +830,14 @@ func (h *Handler) CheckIPAccess(c *gin.Context) { } // Helper: Check if IP is allowed -func (h *Handler) isIPAllowed(userID, ipAddress string) bool { +func (h *SecurityHandler) isIPAllowed(userID, ipAddress string) bool { ip := net.ParseIP(ipAddress) if ip == nil { return false } // Check user-specific rules - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT ip_address FROM ip_whitelist WHERE (user_id = $1 OR user_id IS NULL) AND enabled = true @@ -864,7 +876,7 @@ func (h *Handler) isIPAllowed(userID, ipAddress string) bool { } // ListIPWhitelist lists IP whitelist entries -func (h *Handler) ListIPWhitelist(c *gin.Context) { +func (h *SecurityHandler) ListIPWhitelist(c *gin.Context) { userID := c.Query("user_id") role := c.GetString("role") @@ -880,7 +892,7 @@ func (h *Handler) ListIPWhitelist(c *gin.Context) { ORDER BY created_at DESC ` - rows, err := h.DB.Query(query, userID, role) + rows, err := h.DB.DB().Query(query, userID, role) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "database error"}) return @@ -890,7 +902,7 @@ func (h *Handler) ListIPWhitelist(c *gin.Context) { entries := []IPWhitelist{} for rows.Next() { var entry IPWhitelist - var userID, expiresAt sql.NullString + var userID sql.NullString var expiresAtTime sql.NullTime err := rows.Scan(&entry.ID, &userID, &entry.IPAddress, &entry.Description, @@ -942,7 +954,7 @@ func (h *Handler) ListIPWhitelist(c *gin.Context) { // - 200 OK: Entry deleted successfully // - 404 Not Found: Entry doesn't exist OR user lacks permission (secure, no information leakage) // - 500 Internal Server Error: Database error -func (h *Handler) DeleteIPWhitelist(c *gin.Context) { +func (h *SecurityHandler) DeleteIPWhitelist(c *gin.Context) { entryID := c.Param("entryId") userID := c.GetString("user_id") role := c.GetString("role") @@ -954,10 +966,10 @@ func (h *Handler) DeleteIPWhitelist(c *gin.Context) { if role == "admin" { // Admins can delete any entry - result, err = h.DB.Exec(`DELETE FROM ip_whitelist WHERE id = $1`, entryID) + result, err = h.DB.DB().Exec(`DELETE FROM ip_whitelist WHERE id = $1`, entryID) } else { // Non-admins can only delete their own entries or org-wide entries (NULL user_id) - result, err = h.DB.Exec(` + result, err = h.DB.DB().Exec(` DELETE FROM ip_whitelist WHERE id = $1 AND (user_id = $2 OR user_id IS NULL) `, entryID, userID) @@ -1017,7 +1029,7 @@ type DevicePosture struct { } // VerifySession performs continuous session verification -func (h *Handler) VerifySession(c *gin.Context) { +func (h *SecurityHandler) VerifySession(c *gin.Context) { sessionID := c.Param("sessionId") userID := c.GetString("user_id") @@ -1040,7 +1052,7 @@ func (h *Handler) VerifySession(c *gin.Context) { // Record verification var verificationID int64 - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` INSERT INTO session_verifications (session_id, user_id, device_id, ip_address, risk_score, risk_level, verified) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id @@ -1067,7 +1079,7 @@ func (h *Handler) VerifySession(c *gin.Context) { } // CheckDevicePosture checks device security posture -func (h *Handler) CheckDevicePosture(c *gin.Context) { +func (h *SecurityHandler) CheckDevicePosture(c *gin.Context) { var req DevicePosture if err := c.ShouldBindJSON(&req); err != nil { @@ -1093,7 +1105,7 @@ func (h *Handler) CheckDevicePosture(c *gin.Context) { req.LastChecked = time.Now() // Store posture check result - h.DB.Exec(` + h.DB.DB().Exec(` INSERT INTO device_posture_checks (device_id, compliant, issues, checked_at) VALUES ($1, $2, $3, $4) `, req.DeviceID, req.Compliant, strings.Join(issues, ","), time.Now()) @@ -1102,10 +1114,10 @@ func (h *Handler) CheckDevicePosture(c *gin.Context) { } // GetSecurityAlerts gets security alerts for a user -func (h *Handler) GetSecurityAlerts(c *gin.Context) { +func (h *SecurityHandler) GetSecurityAlerts(c *gin.Context) { userID := c.GetString("user_id") - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT type, severity, message, details, created_at FROM security_alerts WHERE user_id = $1 AND acknowledged = false @@ -1141,7 +1153,7 @@ func (h *Handler) GetSecurityAlerts(c *gin.Context) { // ============================================================================ // Get device fingerprint from request -func (h *Handler) getDeviceFingerprint(c *gin.Context) string { +func (h *SecurityHandler) getDeviceFingerprint(c *gin.Context) string { // Simple fingerprint based on User-Agent and IP // In production, use more sophisticated fingerprinting data := c.Request.UserAgent() + c.ClientIP() @@ -1150,11 +1162,11 @@ func (h *Handler) getDeviceFingerprint(c *gin.Context) string { } // Trust a device for MFA bypass -func (h *Handler) trustDevice(userID, deviceID, userAgent, ipAddress string, duration time.Duration) { +func (h *SecurityHandler) trustDevice(userID, deviceID, userAgent, ipAddress string, duration time.Duration) { trustedUntil := time.Now().Add(duration) deviceName := fmt.Sprintf("%s from %s", userAgent, ipAddress) - h.DB.Exec(` + h.DB.DB().Exec(` INSERT INTO trusted_devices (user_id, device_id, device_name, user_agent, ip_address, trusted_until) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (user_id, device_id) DO UPDATE SET @@ -1164,12 +1176,12 @@ func (h *Handler) trustDevice(userID, deviceID, userAgent, ipAddress string, dur } // Calculate risk score (0-100) -func (h *Handler) calculateRiskScore(userID, deviceID, ipAddress, userAgent string) int { +func (h *SecurityHandler) calculateRiskScore(userID, deviceID, ipAddress, userAgent string) int { score := 0 // Check if device is trusted var trusted bool - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` SELECT EXISTS( SELECT 1 FROM trusted_devices WHERE user_id = $1 AND device_id = $2 AND trusted_until > NOW() @@ -1189,7 +1201,7 @@ func (h *Handler) calculateRiskScore(userID, deviceID, ipAddress, userAgent stri // Check for recent failed login attempts var failedAttempts int - h.DB.QueryRow(` + h.DB.DB().QueryRow(` SELECT COUNT(*) FROM audit_log WHERE user_id = $1 AND action = 'login_failed' AND created_at > NOW() - INTERVAL '1 hour' @@ -1199,7 +1211,7 @@ func (h *Handler) calculateRiskScore(userID, deviceID, ipAddress, userAgent stri // Check for location change var lastIP string - h.DB.QueryRow(` + h.DB.DB().QueryRow(` SELECT ip_address FROM session_verifications WHERE user_id = $1 ORDER BY created_at DESC LIMIT 1 `, userID).Scan(&lastIP) diff --git a/api/internal/handlers/sharing.go b/api/internal/handlers/sharing.go index e9b48716..aca56775 100644 --- a/api/internal/handlers/sharing.go +++ b/api/internal/handlers/sharing.go @@ -75,7 +75,6 @@ package handlers import ( "context" "database/sql" - "fmt" "log" "net/http" "time" @@ -165,9 +164,9 @@ func (h *SharingHandler) CreateShare(c *gin.Context) { } // Check if user exists - var exists bool - err = h.db.DB().QueryRowContext(ctx, `SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)`, req.SharedWithUserId).Scan(&exists) - if err != nil || !exists { + var userExists bool + err = h.db.DB().QueryRowContext(ctx, `SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)`, req.SharedWithUserId).Scan(&userExists) + if err != nil || !userExists { c.JSON(http.StatusBadRequest, gin.H{"error": "User not found"}) return } diff --git a/api/internal/handlers/template_versioning.go b/api/internal/handlers/template_versioning.go index 26297892..168d4f6a 100644 --- a/api/internal/handlers/template_versioning.go +++ b/api/internal/handlers/template_versioning.go @@ -76,6 +76,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/streamspace/streamspace/api/internal/db" ) // TemplateVersion represents a version of a template @@ -128,8 +129,18 @@ type TemplateInheritance struct { Metadata map[string]interface{} `json:"metadata"` } +// TemplateVersioningHandler handles template versioning endpoints +type TemplateVersioningHandler struct { + DB *db.Database +} + +// NewTemplateVersioningHandler creates a new TemplateVersioningHandler instance +func NewTemplateVersioningHandler(database *db.Database) *TemplateVersioningHandler { + return &TemplateVersioningHandler{DB: database} +} + // CreateTemplateVersion creates a new version of a template -func (h *Handler) CreateTemplateVersion(c *gin.Context) { +func (h *TemplateVersioningHandler) CreateTemplateVersion(c *gin.Context) { templateID := c.Param("templateId") userID := c.GetString("user_id") @@ -156,11 +167,11 @@ func (h *Handler) CreateTemplateVersion(c *gin.Context) { // If this is set as default, unset other defaults if req.IsDefault { - h.DB.Exec("UPDATE template_versions SET is_default = false WHERE template_id = $1", templateID) + h.DB.DB().Exec("UPDATE template_versions SET is_default = false WHERE template_id = $1", templateID) } var versionID int64 - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` INSERT INTO template_versions ( template_id, version, major_version, minor_version, patch_version, display_name, description, configuration, base_image, @@ -184,7 +195,7 @@ func (h *Handler) CreateTemplateVersion(c *gin.Context) { } // ListTemplateVersions lists all versions of a template -func (h *Handler) ListTemplateVersions(c *gin.Context) { +func (h *TemplateVersioningHandler) ListTemplateVersions(c *gin.Context) { templateID := c.Param("templateId") status := c.Query("status") @@ -205,7 +216,7 @@ func (h *Handler) ListTemplateVersions(c *gin.Context) { query += " ORDER BY major_version DESC, minor_version DESC, patch_version DESC" - rows, err := h.DB.Query(query, args...) + rows, err := h.DB.DB().Query(query, args...) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to retrieve versions"}) return @@ -236,7 +247,7 @@ func (h *Handler) ListTemplateVersions(c *gin.Context) { } // GetTemplateVersion retrieves a specific template version -func (h *Handler) GetTemplateVersion(c *gin.Context) { +func (h *TemplateVersioningHandler) GetTemplateVersion(c *gin.Context) { versionID, err := strconv.ParseInt(c.Param("versionId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"}) @@ -246,7 +257,7 @@ func (h *Handler) GetTemplateVersion(c *gin.Context) { var v TemplateVersion var config, testResults sql.NullString - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` SELECT id, template_id, version, major_version, minor_version, patch_version, display_name, description, configuration, base_image, parent_template_id, parent_version, changelog, status, is_default, @@ -278,7 +289,7 @@ func (h *Handler) GetTemplateVersion(c *gin.Context) { } // PublishTemplateVersion publishes a template version (draft -> stable) -func (h *Handler) PublishTemplateVersion(c *gin.Context) { +func (h *TemplateVersioningHandler) PublishTemplateVersion(c *gin.Context) { versionID, err := strconv.ParseInt(c.Param("versionId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"}) @@ -287,7 +298,7 @@ func (h *Handler) PublishTemplateVersion(c *gin.Context) { // Check if all tests passed var failedTests int - h.DB.QueryRow(` + h.DB.DB().QueryRow(` SELECT COUNT(*) FROM template_tests WHERE version_id = $1 AND status = 'failed' `, versionID).Scan(&failedTests) @@ -298,7 +309,7 @@ func (h *Handler) PublishTemplateVersion(c *gin.Context) { } now := time.Now() - _, err = h.DB.Exec(` + _, err = h.DB.DB().Exec(` UPDATE template_versions SET status = 'stable', published_at = $1, updated_at = $2 WHERE id = $3 @@ -313,7 +324,7 @@ func (h *Handler) PublishTemplateVersion(c *gin.Context) { } // DeprecateTemplateVersion marks a version as deprecated -func (h *Handler) DeprecateTemplateVersion(c *gin.Context) { +func (h *TemplateVersioningHandler) DeprecateTemplateVersion(c *gin.Context) { versionID, err := strconv.ParseInt(c.Param("versionId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"}) @@ -321,7 +332,7 @@ func (h *Handler) DeprecateTemplateVersion(c *gin.Context) { } now := time.Now() - _, err = h.DB.Exec(` + _, err = h.DB.DB().Exec(` UPDATE template_versions SET status = 'deprecated', deprecated_at = $1, updated_at = $2 WHERE id = $3 @@ -336,7 +347,7 @@ func (h *Handler) DeprecateTemplateVersion(c *gin.Context) { } // SetDefaultTemplateVersion sets a version as the default for a template -func (h *Handler) SetDefaultTemplateVersion(c *gin.Context) { +func (h *TemplateVersioningHandler) SetDefaultTemplateVersion(c *gin.Context) { versionID, err := strconv.ParseInt(c.Param("versionId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"}) @@ -345,17 +356,17 @@ func (h *Handler) SetDefaultTemplateVersion(c *gin.Context) { // Get template ID var templateID string - err = h.DB.QueryRow("SELECT template_id FROM template_versions WHERE id = $1", versionID).Scan(&templateID) + err = h.DB.DB().QueryRow("SELECT template_id FROM template_versions WHERE id = $1", versionID).Scan(&templateID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "version not found"}) return } // Unset all defaults for this template - h.DB.Exec("UPDATE template_versions SET is_default = false WHERE template_id = $1", templateID) + h.DB.DB().Exec("UPDATE template_versions SET is_default = false WHERE template_id = $1", templateID) // Set this version as default - _, err = h.DB.Exec("UPDATE template_versions SET is_default = true WHERE id = $1", versionID) + _, err = h.DB.DB().Exec("UPDATE template_versions SET is_default = true WHERE id = $1", versionID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to set default version"}) return @@ -367,7 +378,7 @@ func (h *Handler) SetDefaultTemplateVersion(c *gin.Context) { // Template Testing // CreateTemplateTest creates a test for a template version -func (h *Handler) CreateTemplateTest(c *gin.Context) { +func (h *TemplateVersioningHandler) CreateTemplateTest(c *gin.Context) { versionID, err := strconv.ParseInt(c.Param("versionId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"}) @@ -386,8 +397,9 @@ func (h *Handler) CreateTemplateTest(c *gin.Context) { } // Get template ID and version - var templateID, version string - err = h.DB.QueryRow(` + var templateID int64 + var version string + err = h.DB.DB().QueryRow(` SELECT template_id, version FROM template_versions WHERE id = $1 `, versionID).Scan(&templateID, &version) @@ -397,7 +409,7 @@ func (h *Handler) CreateTemplateTest(c *gin.Context) { } var testID int64 - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` INSERT INTO template_tests ( template_id, version_id, version, test_type, status, created_by ) VALUES ($1, $2, $3, $4, $5, $6) @@ -420,14 +432,14 @@ func (h *Handler) CreateTemplateTest(c *gin.Context) { } // ListTemplateTests lists all tests for a template version -func (h *Handler) ListTemplateTests(c *gin.Context) { +func (h *TemplateVersioningHandler) ListTemplateTests(c *gin.Context) { versionID, err := strconv.ParseInt(c.Param("versionId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"}) return } - rows, err := h.DB.Query(` + rows, err := h.DB.DB().Query(` SELECT id, template_id, version_id, version, test_type, status, results, duration, error_message, started_at, completed_at, created_by, created_at FROM template_tests @@ -461,7 +473,7 @@ func (h *Handler) ListTemplateTests(c *gin.Context) { } // UpdateTemplateTestStatus updates the status of a test (used by test runners) -func (h *Handler) UpdateTemplateTestStatus(c *gin.Context) { +func (h *TemplateVersioningHandler) UpdateTemplateTestStatus(c *gin.Context) { testID, err := strconv.ParseInt(c.Param("testId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid test ID"}) @@ -481,7 +493,7 @@ func (h *Handler) UpdateTemplateTestStatus(c *gin.Context) { } completedAt := time.Now() - _, err = h.DB.Exec(` + _, err = h.DB.DB().Exec(` UPDATE template_tests SET status = $1, results = $2, duration = $3, error_message = $4, completed_at = $5 WHERE id = $6 @@ -494,10 +506,10 @@ func (h *Handler) UpdateTemplateTestStatus(c *gin.Context) { // Update version's test results summary var versionID int64 - h.DB.QueryRow("SELECT version_id FROM template_tests WHERE id = $1", testID).Scan(&versionID) + h.DB.DB().QueryRow("SELECT version_id FROM template_tests WHERE id = $1", testID).Scan(&versionID) testSummary := h.getTestSummary(versionID) - h.DB.Exec("UPDATE template_versions SET test_results = $1 WHERE id = $2", + h.DB.DB().Exec("UPDATE template_versions SET test_results = $1 WHERE id = $2", toJSONB(testSummary), versionID) c.JSON(http.StatusOK, gin.H{"message": "test status updated successfully"}) @@ -506,12 +518,12 @@ func (h *Handler) UpdateTemplateTestStatus(c *gin.Context) { // Template Inheritance // GetTemplateInheritance retrieves the inheritance chain for a template -func (h *Handler) GetTemplateInheritance(c *gin.Context) { +func (h *TemplateVersioningHandler) GetTemplateInheritance(c *gin.Context) { templateID := c.Param("templateId") // Get parent template if exists var parentTemplateID sql.NullString - h.DB.QueryRow(` + h.DB.DB().QueryRow(` SELECT parent_template_id FROM template_versions WHERE template_id = $1 AND is_default = true `, templateID).Scan(&parentTemplateID) @@ -524,12 +536,12 @@ func (h *Handler) GetTemplateInheritance(c *gin.Context) { // Fetch parent and child configurations var parentConfigJSON, childConfigJSON sql.NullString - h.DB.QueryRow(` + h.DB.DB().QueryRow(` SELECT configuration FROM template_versions WHERE template_id = $1 AND is_default = true `, parentTemplateID.String).Scan(&parentConfigJSON) - h.DB.QueryRow(` + h.DB.DB().QueryRow(` SELECT configuration FROM template_versions WHERE template_id = $1 AND is_default = true `, templateID).Scan(&childConfigJSON) @@ -557,7 +569,7 @@ func (h *Handler) GetTemplateInheritance(c *gin.Context) { } // CloneTemplateVersion creates a new version based on an existing one -func (h *Handler) CloneTemplateVersion(c *gin.Context) { +func (h *TemplateVersioningHandler) CloneTemplateVersion(c *gin.Context) { versionID, err := strconv.ParseInt(c.Param("versionId"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"}) @@ -579,7 +591,7 @@ func (h *Handler) CloneTemplateVersion(c *gin.Context) { // Get original version var templateID, displayName, description, baseImage string var config sql.NullString - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` SELECT template_id, display_name, description, configuration, base_image FROM template_versions WHERE id = $1 `, versionID).Scan(&templateID, &displayName, &description, &config, &baseImage) @@ -594,7 +606,7 @@ func (h *Handler) CloneTemplateVersion(c *gin.Context) { // Create new version var newVersionID int64 - err = h.DB.QueryRow(` + err = h.DB.DB().QueryRow(` INSERT INTO template_versions ( template_id, version, major_version, minor_version, patch_version, display_name, description, configuration, base_image, changelog, @@ -624,10 +636,10 @@ func parseSemanticVersion(version string) (int, int, int) { return major, minor, patch } -func (h *Handler) getTestSummary(versionID int64) map[string]interface{} { +func (h *TemplateVersioningHandler) getTestSummary(versionID int64) map[string]interface{} { var total, passed, failed, pending int - h.DB.QueryRow(` + h.DB.DB().QueryRow(` SELECT COUNT(*) as total, COUNT(*) FILTER (WHERE status = 'passed') as passed, COUNT(*) FILTER (WHERE status = 'failed') as failed, @@ -650,15 +662,15 @@ func (h *Handler) getTestSummary(versionID int64) map[string]interface{} { } // executeTemplateTest runs template tests asynchronously -func (h *Handler) executeTemplateTest(testID int64, templateID, versionID int64, version, testType string) { +func (h *TemplateVersioningHandler) executeTemplateTest(testID int64, templateID, versionID int64, version, testType string) { // Update status to running startTime := time.Now() - h.DB.Exec("UPDATE template_tests SET status = 'running', started_at = $1 WHERE id = $2", startTime, testID) + h.DB.DB().Exec("UPDATE template_tests SET status = 'running', started_at = $1 WHERE id = $2", startTime, testID) // Fetch template configuration var baseImage string var configuration sql.NullString - err := h.DB.QueryRow(` + err := h.DB.DB().QueryRow(` SELECT base_image, configuration FROM template_versions WHERE id = $1 `, versionID).Scan(&baseImage, &configuration) @@ -690,7 +702,7 @@ func (h *Handler) executeTemplateTest(testID int64, templateID, versionID int64, duration := int(time.Since(startTime).Seconds()) // Update test results - h.DB.Exec(` + h.DB.DB().Exec(` UPDATE template_tests SET status = $1, results = $2, duration = $3, error_message = $4, completed_at = $5 WHERE id = $6 @@ -698,12 +710,12 @@ func (h *Handler) executeTemplateTest(testID int64, templateID, versionID int64, // Update version test summary testSummary := h.getTestSummary(versionID) - h.DB.Exec("UPDATE template_versions SET test_results = $1 WHERE id = $2", + h.DB.DB().Exec("UPDATE template_versions SET test_results = $1 WHERE id = $2", toJSONB(testSummary), versionID) } // runStartupTest validates basic template startup requirements -func (h *Handler) runStartupTest(baseImage, configJSON string, results map[string]interface{}) (string, string) { +func (h *TemplateVersioningHandler) runStartupTest(baseImage, configJSON string, results map[string]interface{}) (string, string) { checks := make(map[string]bool) // Check 1: Base image is specified @@ -746,7 +758,7 @@ func (h *Handler) runStartupTest(baseImage, configJSON string, results map[strin } // runSmokeTest performs basic smoke tests -func (h *Handler) runSmokeTest(baseImage, configJSON string, results map[string]interface{}) (string, string) { +func (h *TemplateVersioningHandler) runSmokeTest(baseImage, configJSON string, results map[string]interface{}) (string, string) { checks := make(map[string]bool) // Parse configuration @@ -809,7 +821,7 @@ func (h *Handler) runSmokeTest(baseImage, configJSON string, results map[string] } // runFunctionalTest performs functional validation -func (h *Handler) runFunctionalTest(baseImage, configJSON string, results map[string]interface{}) (string, string) { +func (h *TemplateVersioningHandler) runFunctionalTest(baseImage, configJSON string, results map[string]interface{}) (string, string) { // Simulate functional tests results["message"] = "Functional tests validated configuration integrity" results["validated"] = true @@ -817,7 +829,7 @@ func (h *Handler) runFunctionalTest(baseImage, configJSON string, results map[st } // runPerformanceTest performs performance validation -func (h *Handler) runPerformanceTest(baseImage, configJSON string, results map[string]interface{}) (string, string) { +func (h *TemplateVersioningHandler) runPerformanceTest(baseImage, configJSON string, results map[string]interface{}) (string, string) { // Simulate performance tests results["message"] = "Performance tests completed" results["startup_time_estimate"] = "5s" @@ -826,7 +838,7 @@ func (h *Handler) runPerformanceTest(baseImage, configJSON string, results map[s } // compareTemplateFields compares parent and child template configurations -func (h *Handler) compareTemplateFields(parentConfig, childConfig map[string]interface{}) (overridden, inherited []string) { +func (h *TemplateVersioningHandler) compareTemplateFields(parentConfig, childConfig map[string]interface{}) (overridden, inherited []string) { overridden = []string{} inherited = []string{} diff --git a/api/internal/handlers/users.go b/api/internal/handlers/users.go index a943d614..c8a6eb87 100644 --- a/api/internal/handlers/users.go +++ b/api/internal/handlers/users.go @@ -53,6 +53,7 @@ package handlers import ( + "fmt" "net/http" "github.com/gin-gonic/gin" diff --git a/api/internal/handlers/websocket.go b/api/internal/handlers/websocket.go index cc277e9f..43992fd6 100644 --- a/api/internal/handlers/websocket.go +++ b/api/internal/handlers/websocket.go @@ -205,7 +205,6 @@ package handlers import ( "context" - "database/sql" "encoding/json" "fmt" "net/http" diff --git a/api/internal/k8s/client.go b/api/internal/k8s/client.go index 6a9b0c4a..1f9e5a72 100644 --- a/api/internal/k8s/client.go +++ b/api/internal/k8s/client.go @@ -152,6 +152,8 @@ type Template struct { WebApp *WebAppConfig Capabilities []string Tags []string + Featured bool // Whether template is featured in catalog + UsageCount int // Number of times template has been used CreatedAt time.Time } @@ -784,6 +786,14 @@ func parseTemplate(obj *unstructured.Unstructured) (*Template, error) { } } + if featured, ok := spec["featured"].(bool); ok { + template.Featured = featured + } + + if usageCount, ok := spec["usageCount"].(float64); ok { + template.UsageCount = int(usageCount) + } + return template, nil } diff --git a/api/internal/models/user.go b/api/internal/models/user.go index abbe3bef..711819db 100644 --- a/api/internal/models/user.go +++ b/api/internal/models/user.go @@ -494,6 +494,7 @@ type AddGroupMemberRequest struct { // "maxMemory": "16Gi" // } type SetQuotaRequest struct { + Username string `json:"username,omitempty"` // For admin endpoints only MaxSessions *int `json:"maxSessions,omitempty"` MaxCPU *string `json:"maxCpu,omitempty"` MaxMemory *string `json:"maxMemory,omitempty"`