Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 89 additions & 7 deletions api/internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ func (h *Handler) CreateSession(c *gin.Context) {

var req struct {
User string `json:"user" binding:"required"`
Template string `json:"template" binding:"required"`
Template string `json:"template"`
ApplicationId string `json:"applicationId"`
Resources *struct {
Memory string `json:"memory"`
CPU string `json:"cpu"`
Expand All @@ -395,16 +396,97 @@ func (h *Handler) CreateSession(c *gin.Context) {
return
}

// Step 1: Verify Kubernetes Template CRD exists
// Step 1: Resolve template name from application ID or direct template name
// If applicationId is provided, look up the application to get the template name
// This provides better error messages and validation
templateName := req.Template

if req.ApplicationId != "" {
// Look up the installed application in the database
var appTemplateName, appDisplayName, installStatus, installMessage string
var enabled bool
err := h.db.DB().QueryRowContext(ctx, `
SELECT
COALESCE(ct.name, '') as template_name,
ia.display_name,
ia.enabled,
COALESCE(ia.install_status, 'unknown') as install_status,
COALESCE(ia.install_message, '') as install_message
FROM installed_applications ia
LEFT JOIN catalog_templates ct ON ia.catalog_template_id = ct.id
WHERE ia.id = $1
`, req.ApplicationId).Scan(&appTemplateName, &appDisplayName, &enabled, &installStatus, &installMessage)

if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "Application not found",
"message": fmt.Sprintf("No application found with ID: %s", req.ApplicationId),
})
return
}

// Check if the application is enabled
if !enabled {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Application disabled",
"message": fmt.Sprintf("The application '%s' is currently disabled", appDisplayName),
})
return
}

// Check installation status
if installStatus == "failed" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Application installation failed",
"message": fmt.Sprintf("The application '%s' failed to install: %s", appDisplayName, installMessage),
})
return
}

if installStatus == "pending" || installStatus == "creating" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Application still installing",
"message": fmt.Sprintf("The application '%s' is still being installed. Please wait and try again.", appDisplayName),
})
return
}

// Validate template name was found
if appTemplateName == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Application configuration error",
"message": fmt.Sprintf("The application '%s' does not have a valid template configuration", appDisplayName),
})
return
}

templateName = appTemplateName
} else if req.Template == "" {
// Neither applicationId nor template provided
c.JSON(http.StatusBadRequest, gin.H{
"error": "Missing required field",
"message": "Either 'applicationId' or 'template' must be provided",
})
return
}

// Step 2: Verify Kubernetes Template CRD exists
// The template must be created during application installation (see handlers/applications.go)
// Without a valid template, the session cannot be created
template, err := h.k8sClient.GetTemplate(ctx, h.namespace, req.Template)
template, err := h.k8sClient.GetTemplate(ctx, h.namespace, templateName)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Template not found: %s. Please ensure the application is properly installed.", req.Template)})
// Provide a more helpful error message
errorMsg := fmt.Sprintf("Template not found: %s.", templateName)
if req.ApplicationId != "" {
errorMsg += " The application may still be installing or the Kubernetes controller may not be running."
} else {
errorMsg += " Please ensure the application is properly installed."
}
c.JSON(http.StatusBadRequest, gin.H{"error": errorMsg})
return
}

// Step 2: Determine resource allocation (memory/CPU)
// Step 3: Determine resource allocation (memory/CPU)
// Priority: request > template defaults > system defaults
memory := "2Gi" // System default
cpu := "1000m" // System default (1 core)
Expand All @@ -426,7 +508,7 @@ func (h *Handler) CreateSession(c *gin.Context) {
}
}

// Step 3: Validate and parse resource specifications
// Step 4: Validate and parse resource specifications
// Convert human-readable formats (e.g., "2Gi", "500m") to int64 for quota checking
requestedCPU, requestedMemory, err := h.quotaEnforcer.ValidateResourceRequest(cpu, memory)
if err != nil {
Expand All @@ -437,7 +519,7 @@ func (h *Handler) CreateSession(c *gin.Context) {
return
}

// Step 4: Check user quota before creating session
// Step 5: Check user quota before creating session
// Get current resource usage by listing all pods belonging to this user
podList, err := h.k8sClient.GetPods(ctx, h.namespace)
if err != nil {
Expand Down
69 changes: 51 additions & 18 deletions api/internal/db/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ func (s *SessionDB) CreateSession(ctx context.Context, session *Session) error {
session.Memory, session.CPU, session.PersistentHome, session.IdleTimeout, session.MaxSessionDuration,
session.CreatedAt, session.UpdatedAt, session.LastConnection, session.LastDisconnect, session.LastActivity,
)
return err
if err != nil {
return fmt.Errorf("failed to create session %s for user %s: %w", session.ID, session.UserID, err)
}
return nil
}

// GetSession retrieves a session by ID.
Expand Down Expand Up @@ -109,7 +112,7 @@ func (s *SessionDB) GetSession(ctx context.Context, sessionID string) (*Session,
if err == sql.ErrNoRows {
return nil, fmt.Errorf("session not found: %s", sessionID)
}
return nil, err
return nil, fmt.Errorf("failed to get session %s: %w", sessionID, err)
}

return session, nil
Expand Down Expand Up @@ -150,11 +153,15 @@ func (s *SessionDB) ListSessionsByUser(ctx context.Context, userID string) ([]*S

rows, err := s.db.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to list sessions for user %s: %w", userID, err)
}
defer rows.Close()

return s.scanSessions(rows)
sessions, err := s.scanSessions(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan sessions for user %s: %w", userID, err)
}
return sessions, nil
}

// ListSessionsByState retrieves all sessions with a specific state.
Expand All @@ -174,11 +181,15 @@ func (s *SessionDB) ListSessionsByState(ctx context.Context, state string) ([]*S

rows, err := s.db.QueryContext(ctx, query, state)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to list sessions with state %s: %w", state, err)
}
defer rows.Close()

return s.scanSessions(rows)
sessions, err := s.scanSessions(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan sessions with state %s: %w", state, err)
}
return sessions, nil
}

// UpdateSessionState updates the state of a session.
Expand All @@ -191,7 +202,7 @@ func (s *SessionDB) UpdateSessionState(ctx context.Context, sessionID, state str

result, err := s.db.ExecContext(ctx, query, state, time.Now(), sessionID)
if err != nil {
return err
return fmt.Errorf("failed to update state to %s for session %s: %w", state, sessionID, err)
}

rows, _ := result.RowsAffected()
Expand All @@ -211,7 +222,10 @@ func (s *SessionDB) UpdateSessionURL(ctx context.Context, sessionID, url string)
`

_, err := s.db.ExecContext(ctx, query, url, time.Now(), sessionID)
return err
if err != nil {
return fmt.Errorf("failed to update URL for session %s: %w", sessionID, err)
}
return nil
}

// UpdateSessionStatus updates session state, URL, and pod name from controller status events.
Expand All @@ -224,7 +238,7 @@ func (s *SessionDB) UpdateSessionStatus(ctx context.Context, sessionID, state, u

result, err := s.db.ExecContext(ctx, query, state, url, podName, time.Now(), sessionID)
if err != nil {
return err
return fmt.Errorf("failed to update status for session %s (state=%s, url=%s, pod=%s): %w", sessionID, state, url, podName, err)
}

rows, _ := result.RowsAffected()
Expand All @@ -244,7 +258,10 @@ func (s *SessionDB) UpdateLastActivity(ctx context.Context, sessionID string) er
`

_, err := s.db.ExecContext(ctx, query, time.Now(), sessionID)
return err
if err != nil {
return fmt.Errorf("failed to update last activity for session %s: %w", sessionID, err)
}
return nil
}

// UpdateActiveConnections updates the connection count for a session.
Expand All @@ -257,7 +274,10 @@ func (s *SessionDB) UpdateActiveConnections(ctx context.Context, sessionID strin
`

_, err := s.db.ExecContext(ctx, query, count, now, sessionID)
return err
if err != nil {
return fmt.Errorf("failed to update active connections to %d for session %s: %w", count, sessionID, err)
}
return nil
}

// DeleteSession marks a session as deleted.
Expand All @@ -269,13 +289,19 @@ func (s *SessionDB) DeleteSession(ctx context.Context, sessionID string) error {
`

_, err := s.db.ExecContext(ctx, query, time.Now(), sessionID)
return err
if err != nil {
return fmt.Errorf("failed to mark session %s as deleted: %w", sessionID, err)
}
return nil
}

// HardDeleteSession permanently removes a session from the database.
func (s *SessionDB) HardDeleteSession(ctx context.Context, sessionID string) error {
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE id = $1", sessionID)
return err
if err != nil {
return fmt.Errorf("failed to permanently delete session %s: %w", sessionID, err)
}
return nil
}

// CountSessionsByUser returns the number of active sessions for a user.
Expand All @@ -285,7 +311,10 @@ func (s *SessionDB) CountSessionsByUser(ctx context.Context, userID string) (int
SELECT COUNT(*) FROM sessions
WHERE user_id = $1 AND state IN ('running', 'pending', 'hibernated')
`, userID).Scan(&count)
return count, err
if err != nil {
return 0, fmt.Errorf("failed to count sessions for user %s: %w", userID, err)
}
return count, nil
}

// GetIdleSessions returns sessions that have been idle beyond their timeout.
Expand Down Expand Up @@ -313,11 +342,15 @@ func (s *SessionDB) GetIdleSessions(ctx context.Context) ([]*Session, error) {
func (s *SessionDB) querySessions(ctx context.Context, query string, args ...interface{}) ([]*Session, error) {
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to execute session query: %w", err)
}
defer rows.Close()

return s.scanSessions(rows)
sessions, err := s.scanSessions(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan session results: %w", err)
}
return sessions, nil
}

// scanSessions scans rows into Session structs.
Expand All @@ -333,13 +366,13 @@ func (s *SessionDB) scanSessions(rows *sql.Rows) ([]*Session, error) {
&session.CreatedAt, &session.UpdatedAt, &session.LastConnection, &session.LastDisconnect, &session.LastActivity,
)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to scan session row: %w", err)
}
sessions = append(sessions, session)
}

if err := rows.Err(); err != nil {
return nil, err
return nil, fmt.Errorf("error iterating session rows: %w", err)
}

return sessions, nil
Expand Down
Loading
Loading