diff --git a/cmd/oauth-test-client/main.go b/cmd/oauth-test-client/main.go index 7181e7e..da8a66f 100644 --- a/cmd/oauth-test-client/main.go +++ b/cmd/oauth-test-client/main.go @@ -11,9 +11,9 @@ import ( // Configuration // In a real app, these would come from environment variables const ( - AuthServerURL = "https://auth-server-4nmm.onrender.com" // Replace with your Render URL - ClientID = "your-client-id-here" // You'll get this after registering the client - ClientSecret = "your-client-secret-here" // You'll get this after registering the client + AuthServerURL = "https://auth-server-4nmm.onrender.com" // Replace with your Render URL + ClientID = "your-client-id-here" // You'll get this after registering the client + ClientSecret = "your-client-secret-here" // You'll get this after registering the client RedirectURI = "http://localhost:3000/callback" AppPort = ":3000" ) diff --git a/internal/config/redis.go b/internal/config/redis.go index 0a4d04e..c8e15dc 100644 --- a/internal/config/redis.go +++ b/internal/config/redis.go @@ -9,7 +9,6 @@ import ( func InitRedis(cfg *Config) *redis.Client { ctx := context.Background() // Local context variable - opt, err := redis.ParseURL(cfg.Redis.URL) if err != nil { log.Fatal("Failed to parse Redis URL:", err) diff --git a/internal/handler/auth_handler.go b/internal/handler/auth_handler.go index 2ccafe6..9f20cfb 100644 --- a/internal/handler/auth_handler.go +++ b/internal/handler/auth_handler.go @@ -449,6 +449,9 @@ func (h *AuthHandler) GetSessions(c *gin.Context) { return } + currentSessionID, _ := c.Get("sessionID") + currentID, _ := currentSessionID.(string) + // Convert to response format sessionResponses := make([]dto.SessionResponse, len(sessions)) for i, session := range sessions { @@ -458,7 +461,7 @@ func (h *AuthHandler) GetSessions(c *gin.Context) { UserAgent: session.UserAgent, CreatedAt: session.CreatedAt.Format("2006-01-02 15:04:05"), ExpiresAt: session.ExpiresAt.Format("2006-01-02 15:04:05"), - IsCurrent: false, // TODO: Determine if this is the current session + IsCurrent: session.ID == currentID, } } diff --git a/internal/handler/auth_handler_protected_test.go b/internal/handler/auth_handler_protected_test.go index 78c7cb2..d7a5819 100644 --- a/internal/handler/auth_handler_protected_test.go +++ b/internal/handler/auth_handler_protected_test.go @@ -39,7 +39,11 @@ func TestAuthHandler_GetMe(t *testing.T) { } // Generate Token - token, _ := tokenService.GenerateAccessToken(user) + token, err := tokenService.GenerateAccessToken(user, "test-session-id") + assert.NoError(t, err) + if err != nil { + t.FailNow() + } req, _ := http.NewRequest(http.MethodGet, "/api/auth/me", nil) req.Header.Set("Authorization", "Bearer "+token) diff --git a/internal/handler/auth_handler_test.go b/internal/handler/auth_handler_test.go index 3510251..7d921d6 100644 --- a/internal/handler/auth_handler_test.go +++ b/internal/handler/auth_handler_test.go @@ -8,8 +8,11 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/roshankumar0036singh/auth-server/internal/config" "github.com/roshankumar0036singh/auth-server/internal/dto" "github.com/roshankumar0036singh/auth-server/internal/handler" + "github.com/roshankumar0036singh/auth-server/internal/middleware" + "github.com/roshankumar0036singh/auth-server/internal/service" "github.com/roshankumar0036singh/auth-server/internal/testutils" "github.com/stretchr/testify/assert" ) @@ -94,3 +97,160 @@ func TestAuthHandler_Login(t *testing.T) { } // TODO: Add tests for Protected Routes using middleware + +func TestAuthHandler_GetSessions_CurrentSessionFlag(t *testing.T) { + authService, _, mr := testutils.SetupIntegrationTest(t) + defer mr.Close() + + authHandler := handler.NewAuthHandler(authService, nil) + + cfg := &config.Config{ + JWT: config.JWTConfig{ + AccessSecret: "secret", + RefreshSecret: "refresh-secret", + }, + } + tokenService := service.NewTokenService(cfg) + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(middleware.AuthMiddleware(tokenService)) + + r.GET("/api/auth/sessions", authHandler.GetSessions) + + // Create user + regReq := &dto.RegisterRequest{ + Email: "sessions@example.com", + Password: "Password123!", + FirstName: "Session", + LastName: "Test", + } + + _, err := authService.Register(regReq) + assert.NoError(t, err) + + // Create a session via login + loginResp, err := authService.Login( + &dto.LoginRequest{ + Email: "sessions@example.com", + Password: "Password123!", + }, + "127.0.0.1", + "test-agent", + ) + + assert.NoError(t, err) + + claims, err := tokenService.ValidateAccessToken(loginResp.AccessToken) + assert.NoError(t, err) + + expectedSessionID := claims.SessionID + + // Call sessions endpoint using the access token + req, _ := http.NewRequest( + http.MethodGet, + "/api/auth/sessions", + nil, + ) + + req.Header.Set( + "Authorization", + "Bearer "+loginResp.AccessToken, + ) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var resp map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &resp) + assert.NoError(t, err) + + data := resp["data"].([]interface{}) + + foundExpectedSession := false + + for _, item := range data { + session := item.(map[string]interface{}) + + sessionID := session["id"].(string) + isCurrent := session["isCurrent"].(bool) + + if sessionID == expectedSessionID { + assert.True(t, isCurrent, "expected session used by request token to be current") + foundExpectedSession = true + } + } + + assert.True(t, foundExpectedSession, "expected session ID not found in response") + +} + +func TestAuthHandler_GetSessions_NoSessionIDInContext(t *testing.T) { + authService, _, mr := testutils.SetupIntegrationTest(t) + defer mr.Close() + + authHandler := handler.NewAuthHandler(authService, nil) + + gin.SetMode(gin.TestMode) + r := gin.New() + + // Register user + regReq := &dto.RegisterRequest{ + Email: "nosession@example.com", + Password: "Password123!", + FirstName: "No", + LastName: "Session", + } + + user, err := authService.Register(regReq) + assert.NoError(t, err) + + userID := user.ID + + // Intentionally set only userID, not sessionID + r.GET("/api/auth/sessions", func(c *gin.Context) { + c.Set("userID", userID) + authHandler.GetSessions(c) + }) + + // Create a session + _, err = authService.Login( + &dto.LoginRequest{ + Email: regReq.Email, + Password: regReq.Password, + }, + "127.0.0.1", + "test-agent", + ) + assert.NoError(t, err) + + req, _ := http.NewRequest( + http.MethodGet, + "/api/auth/sessions", + nil, + ) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var resp map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &resp) + assert.NoError(t, err) + + data := resp["data"].([]interface{}) + assert.NotEmpty(t, data, "expected at least one session after login") + + for _, item := range data { + session := item.(map[string]interface{}) + + assert.False( + t, + session["isCurrent"].(bool), + "expected no session to be marked current when sessionID is missing", + ) + } +} diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 568071c..8e19e50 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -72,4 +72,5 @@ func setContextUser(c *gin.Context, claims *service.JWTClaims) { c.Set("userID", claims.UserID) c.Set("email", claims.Email) c.Set("role", claims.Role) + c.Set("sessionID", claims.SessionID) } diff --git a/internal/repository/token_repository.go b/internal/repository/token_repository.go index 91a882c..3595cb7 100644 --- a/internal/repository/token_repository.go +++ b/internal/repository/token_repository.go @@ -23,10 +23,12 @@ func (r *TokenRepository) CreateRefreshToken(token *models.RefreshToken) error { return r.db.Create(token).Error } +const tokenQuery = "token = ?" + // FindRefreshToken finds a refresh token by token string func (r *TokenRepository) FindRefreshToken(tokenString string) (*models.RefreshToken, error) { var token models.RefreshToken - if err := r.db.Where("token = ?", tokenString).First(&token).Error; err != nil { + if err := r.db.Where(tokenQuery, tokenString).First(&token).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrRefreshTokenNotFound } @@ -61,7 +63,7 @@ func (r *TokenRepository) FindUserRefreshTokens(userID string) ([]models.Refresh // RevokeRefreshToken marks a refresh token as revoked func (r *TokenRepository) RevokeRefreshToken(tokenString string) error { result := r.db.Model(&models.RefreshToken{}). - Where("token = ?", tokenString). + Where(tokenQuery, tokenString). Update("is_revoked", true) if result.Error != nil { @@ -126,3 +128,25 @@ func (r *TokenRepository) CountUserActiveSessions(userID string) (int64, error) Count(&count).Error return count, err } + +func (r *TokenRepository) RotateRefreshToken(oldToken string, newToken *models.RefreshToken) error { + return r.db.Transaction(func(tx *gorm.DB) error { + result := tx.Model(&models.RefreshToken{}). + Where("token = ? AND is_revoked = ?", oldToken, false). + Update("is_revoked", true) + + if result.Error != nil { + return result.Error + } + + if result.RowsAffected == 0 { + return ErrRefreshTokenNotFound + } + + if err := tx.Create(newToken).Error; err != nil { + return err + } + + return nil + }) +} diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 31a594d..d838c6a 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -308,36 +308,15 @@ func (s *AuthService) VerifyLoginMFA(email, code, ipAddress, userAgent string) ( return nil, errors.New("invalid TOTP code") } - // Generate tokens - accessToken, err := s.tokenService.GenerateAccessToken(user) + response, err := s.createLoginResponse(user, ipAddress, userAgent) if err != nil { - return nil, errors.New(errGenAccessToken) - } - - refreshTokenString, err := s.tokenService.GenerateRefreshToken(user) - if err != nil { - return nil, errors.New(errGenRefreshToken) - } - - refreshToken := &models.RefreshToken{ - UserID: user.ID, - Token: refreshTokenString, - ExpiresAt: time.Now().Add(7 * 24 * time.Hour), - IPAddress: ipAddress, - UserAgent: userAgent, - } - - if err := s.tokenRepo.CreateRefreshToken(refreshToken); err != nil { - return nil, errors.New(errStoreRefreshToken) + return nil, err } s.auditService.LogEvent(&user.ID, "USER_LOGIN_SUCCESS_MFA", "USER", user.ID, ipAddress, userAgent, nil) - return &dto.LoginResponse{ - AccessToken: accessToken, - RefreshToken: refreshTokenString, - User: user.ToPublic(), - }, nil + return response, nil + } // Register creates a new user account and sends verification email @@ -498,43 +477,20 @@ func (s *AuthService) Login(req *dto.LoginRequest, ipAddress, userAgent string) return nil, errors.New("mfa_required") } - // Generate tokens - accessToken, err := s.tokenService.GenerateAccessToken(user) - if err != nil { - return nil, errors.New(errGenAccessToken) - } - - refreshTokenString, err := s.tokenService.GenerateRefreshToken(user) - if err != nil { - return nil, errors.New(errGenRefreshToken) - } - - // Store refresh token - refreshToken := &models.RefreshToken{ - UserID: user.ID, - Token: refreshTokenString, - ExpiresAt: time.Now().Add(7 * 24 * time.Hour), // TODO: Align with config - IPAddress: ipAddress, - UserAgent: userAgent, - } - - if err := s.tokenRepo.CreateRefreshToken(refreshToken); err != nil { - return nil, errors.New(errStoreRefreshToken) - } - // Update last login if err := s.userRepo.Update(user.ID, map[string]interface{}{"last_login_at": time.Now()}); err != nil { log.Printf("Failed to update last login for user %s: %v", user.ID, err) } + response, err := s.createLoginResponse(user, ipAddress, userAgent) + if err != nil { + return nil, err + } + // Audit Log s.auditService.LogEvent(&user.ID, "USER_LOGIN_SUCCESS", "USER", user.ID, ipAddress, userAgent, nil) - return &dto.LoginResponse{ - AccessToken: accessToken, - RefreshToken: refreshTokenString, - User: user.ToPublic(), - }, nil + return response, nil } // LoginWithOAuth handles login or registration via OAuth provider @@ -577,36 +533,15 @@ func (s *AuthService) LoginWithOAuth(email, oauthID, firstName, lastName, provid } } - // Generate tokens - accessToken, err := s.tokenService.GenerateAccessToken(user) + response, err := s.createLoginResponse(user, ipAddress, userAgent) if err != nil { - return nil, errors.New(errGenAccessToken) - } - - refreshTokenString, err := s.tokenService.GenerateRefreshToken(user) - if err != nil { - return nil, errors.New(errGenRefreshToken) - } - - refreshToken := &models.RefreshToken{ - UserID: user.ID, - Token: refreshTokenString, - ExpiresAt: time.Now().Add(7 * 24 * time.Hour), - IPAddress: ipAddress, - UserAgent: userAgent, - } - - if err := s.tokenRepo.CreateRefreshToken(refreshToken); err != nil { - return nil, errors.New(errStoreRefreshToken) + return nil, err } s.auditService.LogEvent(&user.ID, "USER_LOGIN_SUCCESS_OAUTH", "USER", user.ID, ipAddress, userAgent, nil) - return &dto.LoginResponse{ - AccessToken: accessToken, - RefreshToken: refreshTokenString, - User: user.ToPublic(), - }, nil + return response, nil + } func (s *AuthService) handleFailedLogin(user *models.User, email string, ctx context.Context) { @@ -669,23 +604,12 @@ func (s *AuthService) RefreshAccessToken(refreshTokenString string, ipAddress, u return nil, errors.New(errUserNotFound) } - // Generate new access token - newAccessToken, err := s.tokenService.GenerateAccessToken(user) - if err != nil { - return nil, errors.New(errGenAccessToken) - } - // Token rotation: Generate new refresh token newRefreshTokenString, err := s.tokenService.GenerateRefreshToken(user) if err != nil { return nil, errors.New(errGenRefreshToken) } - // Revoke old refresh token - if err := s.tokenRepo.RevokeRefreshToken(refreshTokenString); err != nil { - log.Printf("Warning: Failed to revoke old refresh token: %v", err) - } - // Store new refresh token newRefreshToken := &models.RefreshToken{ UserID: user.ID, @@ -695,8 +619,18 @@ func (s *AuthService) RefreshAccessToken(refreshTokenString string, ipAddress, u UserAgent: userAgent, } - if err := s.tokenRepo.CreateRefreshToken(newRefreshToken); err != nil { - log.Printf("Warning: Failed to store new refresh token: %v", err) + // Generate new access token + newAccessToken, err := s.tokenService.GenerateAccessToken(user, newRefreshToken.ID) + if err != nil { + return nil, errors.New(errGenAccessToken) + } + + // transaction handling creation and rotation of refresh tokens + if err := s.tokenRepo.RotateRefreshToken( + refreshTokenString, + newRefreshToken, + ); err != nil { + return nil, errors.New("failed to rotate refresh token") } return &dto.TokenRefreshResponse{ @@ -781,3 +715,38 @@ func (s *AuthService) RevokeSession(userID, tokenID string) error { return nil } + +func (s *AuthService) createLoginResponse( + user *models.User, + ipAddress string, + userAgent string, +) (*dto.LoginResponse, error) { + + refreshTokenString, err := s.tokenService.GenerateRefreshToken(user) + if err != nil { + return nil, errors.New("failed to generate refresh token") + } + + refreshToken := &models.RefreshToken{ + UserID: user.ID, + Token: refreshTokenString, + ExpiresAt: time.Now().Add(7 * 24 * time.Hour), + IPAddress: ipAddress, + UserAgent: userAgent, + } + + if err := s.tokenRepo.CreateRefreshToken(refreshToken); err != nil { + return nil, errors.New(errStoreRefreshToken) + } + + accessToken, err := s.tokenService.GenerateAccessToken(user, refreshToken.ID) + if err != nil { + return nil, errors.New(errGenAccessToken) + } + + return &dto.LoginResponse{ + AccessToken: accessToken, + RefreshToken: refreshTokenString, + User: user.ToPublic(), + }, nil +} diff --git a/internal/service/token_service.go b/internal/service/token_service.go index df089a4..8618317 100644 --- a/internal/service/token_service.go +++ b/internal/service/token_service.go @@ -19,20 +19,22 @@ func NewTokenService(cfg *config.Config) *TokenService { // JWTClaims custom claims for JWT type JWTClaims struct { - UserID string `json:"sub"` - Email string `json:"email"` - Role string `json:"role"` + UserID string `json:"sub"` + SessionID string `json:"session_id"` + Email string `json:"email"` + Role string `json:"role"` jwt.RegisteredClaims } // GenerateAccessToken generates a new JWT access token -func (s *TokenService) GenerateAccessToken(user *models.User) (string, error) { +func (s *TokenService) GenerateAccessToken(user *models.User, sessionID string) (string, error) { expirationTime := time.Now().Add(15 * time.Minute) // 15 minutes claims := &JWTClaims{ - UserID: user.ID, - Email: user.Email, - Role: user.Role, + UserID: user.ID, + Email: user.Email, + Role: user.Role, + SessionID: sessionID, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expirationTime), IssuedAt: jwt.NewNumericDate(time.Now()), diff --git a/internal/service/token_service_test.go b/internal/service/token_service_test.go index 05332f3..1a085b0 100644 --- a/internal/service/token_service_test.go +++ b/internal/service/token_service_test.go @@ -24,7 +24,9 @@ func TestTokenService_GenerateAccessToken(t *testing.T) { Role: "user", } - token, err := svc.GenerateAccessToken(user) + sessionID := "session-123" + + token, err := svc.GenerateAccessToken(user, sessionID) assert.NoError(t, err) assert.NotEmpty(t, token)