diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..1f7a1ba --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,75 @@ +version: "2" + +linters: + # consider to add: cyclop + # find sane settings for revive + default: none + enable: + - bodyclose + - containedctx + - dupl + - err113 + - errcheck + - exhaustive + - fatcontext + - forcetypeassert + - gocheckcompilerdirectives + - gochecknoglobals + - gochecknoinits + - goconst + - gocritic + - godoclint + - godox + - gosec + - govet + - iface + - intrange + - makezero + - mirror + - musttag + - nestif + - nilerr + - nilnesserr + - nilnil + - noctx + - paralleltest + - perfsprint + - prealloc + - staticcheck + - tagalign + - tagliatelle + - testableexamples + - testpackage + - thelper + - tparallel + - unconvert + - unparam + - unused + - usestdlibvars + - usetesting + - wastedassign + - wrapcheck + + settings: + iface: + enable: + - identical # Identifies interfaces in the same package that have identical method sets. + - unused # Identifies interfaces that are not used anywhere in the same package where the interface is defined. + - opaque # Identifies functions that return interfaces, but the actual returned value is always a single concrete implementation. + + tagliatelle: + case: + rules: + json: camel + exclusions: + rules: + - path: _test\.go + linters: + - err113 + - errcheck + - gosec + - nilnil + +formatters: + enable: + - gofmt diff --git a/application/healthcheck.go b/application/healthcheck.go index 595507a..8dfabc0 100644 --- a/application/healthcheck.go +++ b/application/healthcheck.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "net/http" + + "github.com/mishankov/platforma/log" ) type healther interface { @@ -24,5 +26,8 @@ func (h *HealthCheckHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(health) + err := json.NewEncoder(w).Encode(health) + if err != nil { + log.ErrorContext(r.Context(), "failed to decode response to json", "error", err) + } } diff --git a/auth/errors.go b/auth/errors.go index acd5f29..03d7810 100644 --- a/auth/errors.go +++ b/auth/errors.go @@ -2,7 +2,16 @@ package auth import "errors" -var ErrInvalidUsername = errors.New("invalid username") -var ErrInvalidPassword = errors.New("invalid password") -var ErrWrongUserOrPassword = errors.New("wrong user or password") -var ErrCurrentPasswordIncorrect = errors.New("current password is incorrect") +var ( + ErrUserNotFound = errors.New("user not found") + ErrWrongUserOrPassword = errors.New("wrong user or password") + + ErrInvalidUsername = errors.New("invalid username") + ErrShortUsername = errors.New("short username") + ErrLongUsername = errors.New("long username") + + ErrInvalidPassword = errors.New("invalid password") + ErrShortPassword = errors.New("short password") + ErrLongPassword = errors.New("long password") + ErrCurrentPasswordIncorrect = errors.New("current password is incorrect") +) diff --git a/auth/handler_get.go b/auth/handler_get.go index 7712171..3e21d72 100644 --- a/auth/handler_get.go +++ b/auth/handler_get.go @@ -3,6 +3,8 @@ package auth import ( "encoding/json" "net/http" + + "github.com/mishankov/platforma/log" ) type GetHandler struct { @@ -48,5 +50,8 @@ func (h *GetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { Username: user.Username, } - json.NewEncoder(w).Encode(resp) + err = json.NewEncoder(w).Encode(resp) + if err != nil { + log.ErrorContext(ctx, "failed to decode response to json", "error", err) + } } diff --git a/auth/middleware.go b/auth/middleware.go index 5d272c9..693cd7d 100644 --- a/auth/middleware.go +++ b/auth/middleware.go @@ -2,6 +2,7 @@ package auth import ( "context" + "errors" "net/http" "github.com/mishankov/platforma/log" @@ -29,20 +30,23 @@ func (m *AuthenticationMiddleware) Wrap(next http.Handler) http.Handler { } user, err := m.userService.GetFromSession(r.Context(), cookie.Value) - if err != nil { - http.Error(w, "failed to get user", http.StatusInternalServerError) + if errors.Is(err, ErrUserNotFound) { + w.WriteHeader(http.StatusUnauthorized) return } - if user == nil { - w.WriteHeader(http.StatusUnauthorized) + if err != nil { + http.Error(w, "failed to get user", http.StatusInternalServerError) return } - ctxWithUserId := context.WithValue(r.Context(), log.UserIdKey, user.ID) - ctxWithUser := context.WithValue(ctxWithUserId, UserContextKey, user) - requestWithUser := r.WithContext(ctxWithUser) + newRequest := r + if user != nil { + ctxWithUserId := context.WithValue(r.Context(), log.UserIdKey, user.ID) + ctxWithUser := context.WithValue(ctxWithUserId, UserContextKey, user) + newRequest = r.WithContext(ctxWithUser) + } - next.ServeHTTP(w, requestWithUser) + next.ServeHTTP(w, newRequest) }) } diff --git a/auth/middleware_test.go b/auth/middleware_test.go index a512db3..fa79da5 100644 --- a/auth/middleware_test.go +++ b/auth/middleware_test.go @@ -11,6 +11,8 @@ import ( ) func TestAuthenticationMiddleware_ValidSession(t *testing.T) { + t.Parallel() + userSvc := &mockUserService{ users: map[string]*auth.User{ "valid-session-id": {ID: "user-id", Username: "testuser"}, @@ -23,7 +25,7 @@ func TestAuthenticationMiddleware_ValidSession(t *testing.T) { w.WriteHeader(http.StatusOK) })) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{Name: "session", Value: "valid-session-id"}) w := httptest.NewRecorder() @@ -35,6 +37,8 @@ func TestAuthenticationMiddleware_ValidSession(t *testing.T) { } func TestAuthenticationMiddleware_NoSessionCookie(t *testing.T) { + t.Parallel() + userSvc := &mockUserService{ cookieName: "session", } @@ -44,7 +48,7 @@ func TestAuthenticationMiddleware_NoSessionCookie(t *testing.T) { t.Fatal("handler should not be called when authentication fails") })) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) @@ -55,6 +59,8 @@ func TestAuthenticationMiddleware_NoSessionCookie(t *testing.T) { } func TestAuthenticationMiddleware_InvalidSession(t *testing.T) { + t.Parallel() + userSvc := &mockUserService{ users: map[string]*auth.User{}, cookieName: "session", @@ -65,7 +71,7 @@ func TestAuthenticationMiddleware_InvalidSession(t *testing.T) { t.Fatal("handler should not be called when authentication fails") })) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{Name: "session", Value: "invalid-session-id"}) w := httptest.NewRecorder() @@ -77,6 +83,8 @@ func TestAuthenticationMiddleware_InvalidSession(t *testing.T) { } func TestAuthenticationMiddleware_UserServiceError(t *testing.T) { + t.Parallel() + userSvc := &mockUserService{ error: errors.New("database error"), cookieName: "session", @@ -87,7 +95,7 @@ func TestAuthenticationMiddleware_UserServiceError(t *testing.T) { t.Fatal("handler should not be called when authentication fails") })) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{Name: "session", Value: "session-id"}) w := httptest.NewRecorder() @@ -99,6 +107,8 @@ func TestAuthenticationMiddleware_UserServiceError(t *testing.T) { } func TestAuthenticationMiddleware_UserNotFound(t *testing.T) { + t.Parallel() + userSvc := &mockUserService{ users: map[string]*auth.User{}, cookieName: "session", @@ -109,7 +119,7 @@ func TestAuthenticationMiddleware_UserNotFound(t *testing.T) { t.Fatal("handler should not be called when authentication fails") })) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{Name: "session", Value: "session-id"}) w := httptest.NewRecorder() @@ -130,10 +140,11 @@ func (m *mockUserService) GetFromSession(ctx context.Context, sessionId string) if m.error != nil { return nil, m.error } + if user, ok := m.users[sessionId]; ok { return user, nil } - return nil, nil + return nil, auth.ErrUserNotFound } func (m *mockUserService) CookieName() string { diff --git a/auth/model.go b/auth/model.go index 60b1bfd..3dd16f1 100644 --- a/auth/model.go +++ b/auth/model.go @@ -11,11 +11,11 @@ const ( ) type User struct { - ID string `json:"id" db:"id"` - Username string `json:"username" db:"username"` - Password string `json:"password" db:"password"` - Salt string `json:"salt" db:"salt"` - Created time.Time `json:"created" db:"created"` - Updated time.Time `json:"updated" db:"updated"` - Status Status `json:"status" db:"status"` + ID string `db:"id" json:"id"` + Username string `db:"username" json:"username"` + Password string `db:"password" json:"password"` + Salt string `db:"salt" json:"salt"` + Created time.Time `db:"created" json:"created"` + Updated time.Time `db:"updated" json:"updated"` + Status Status `db:"status" json:"status"` } diff --git a/auth/repository.go b/auth/repository.go index b90607b..293ef19 100644 --- a/auth/repository.go +++ b/auth/repository.go @@ -3,6 +3,7 @@ package auth import ( "context" "database/sql" + "fmt" "github.com/mishankov/platforma/database" ) @@ -46,7 +47,7 @@ func (r *Repository) Get(ctx context.Context, id string) (*User, error) { var user User err := r.db.GetContext(ctx, &user, "SELECT * FROM users WHERE id = $1", id) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get user by id: %w", err) } return &user, nil } @@ -55,7 +56,7 @@ func (r *Repository) GetByUsername(ctx context.Context, username string) (*User, var user User err := r.db.GetContext(ctx, &user, "SELECT * FROM users WHERE username = $1", username) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get user by username: %w", err) } return &user, nil } @@ -66,7 +67,10 @@ func (r *Repository) Create(ctx context.Context, user *User) error { VALUES (:id, :username, :password, :salt, :created, :updated, :status) ` _, err := r.db.NamedExecContext(ctx, query, user) - return err + if err != nil { + return fmt.Errorf("failed to create user: %w", err) + } + return nil } func (r *Repository) UpdatePassword(ctx context.Context, id, password, salt string) error { @@ -76,5 +80,8 @@ func (r *Repository) UpdatePassword(ctx context.Context, id, password, salt stri WHERE id = $3 ` _, err := r.db.ExecContext(ctx, query, password, salt, id) - return err + if err != nil { + return fmt.Errorf("failed to update password: %w", err) + } + return nil } diff --git a/auth/service.go b/auth/service.go index e71082a..6f34923 100644 --- a/auth/service.go +++ b/auth/service.go @@ -52,7 +52,11 @@ func NewService(repo repository, authStorage authStorage, sessionCookieName stri } func (s *Service) Get(ctx context.Context, id string) (*User, error) { - return s.repo.Get(ctx, id) + user, err := s.repo.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + return user, nil } func (s *Service) GetFromSession(ctx context.Context, sessionId string) (*User, error) { @@ -62,7 +66,7 @@ func (s *Service) GetFromSession(ctx context.Context, sessionId string) (*User, } if userId == "" { - return nil, nil + return nil, ErrUserNotFound } return s.Get(ctx, userId) @@ -83,7 +87,7 @@ func (s *Service) CreateWithLoginAndPassword(ctx context.Context, username, pass hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password+":"+salt), bcrypt.DefaultCost) if err != nil { - return err + return fmt.Errorf("failed to generate password hash: %w", err) } user := &User{ @@ -96,7 +100,11 @@ func (s *Service) CreateWithLoginAndPassword(ctx context.Context, username, pass Status: StatusActive, } - return s.repo.Create(ctx, user) + err = s.repo.Create(ctx, user) + if err != nil { + return fmt.Errorf("failed to create user: %w", err) + } + return nil } func (s *Service) CreateSessionFromUsernameAndPassword(ctx context.Context, username, password string) (string, error) { @@ -115,11 +123,18 @@ func (s *Service) CreateSessionFromUsernameAndPassword(ctx context.Context, user return "", fmt.Errorf("failed to get session: %w", err) } - return session, err + if err != nil { + return "", fmt.Errorf("failed to create session: %w", err) + } + return session, nil } func (s *Service) DeleteSession(ctx context.Context, sessionId string) error { - return s.authStorage.DeleteSession(ctx, sessionId) + err := s.authStorage.DeleteSession(ctx, sessionId) + if err != nil { + return fmt.Errorf("failed to delete session: %w", err) + } + return nil } func (s *Service) CookieName() string { @@ -129,7 +144,7 @@ func (s *Service) CookieName() string { func (s *Service) ChangePassword(ctx context.Context, currentPassword, newPassword string) error { user := UserFromContext(ctx) if user == nil { - return errors.New("user not found") + return ErrUserNotFound } if s.passwordValidator != nil { @@ -149,19 +164,23 @@ func (s *Service) ChangePassword(ctx context.Context, currentPassword, newPasswo newSalt := uuid.New().String() hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword+":"+newSalt), bcrypt.DefaultCost) if err != nil { - return err + return fmt.Errorf("failed to generate password hash: %w", err) } - return s.repo.UpdatePassword(ctx, user.ID, string(hashedPassword), newSalt) + err = s.repo.UpdatePassword(ctx, user.ID, string(hashedPassword), newSalt) + if err != nil { + return fmt.Errorf("failed to update password: %w", err) + } + return nil } func defaultPasswordValidator(password string) error { if len(password) < 8 { - return errors.New("short password") + return ErrShortPassword } if len(password) > 100 { - return errors.New("long password") + return ErrLongPassword } return nil @@ -169,11 +188,11 @@ func defaultPasswordValidator(password string) error { func defaultUsernameValidator(username string) error { if len(username) < 5 { - return errors.New("short username") + return ErrShortUsername } if len(username) > 20 { - return errors.New("long username") + return ErrLongUsername } return nil diff --git a/database/database.go b/database/database.go index 72c82d6..17b23d5 100644 --- a/database/database.go +++ b/database/database.go @@ -2,6 +2,7 @@ package database import ( "context" + "fmt" "slices" "time" @@ -20,7 +21,7 @@ type Database struct { func New(connection string) (*Database, error) { db, err := sqlx.Connect("postgres", connection) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to connect to database: %w", err) } return &Database{DB: db, repositories: make(map[string]any), migrators: make(map[string]shemer)}, nil } @@ -35,14 +36,14 @@ func (db *Database) RegisterRepository(name string, repository any) { func (db *Database) Migrate(ctx context.Context) error { if _, err := db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS platforma_migrations (repository TEXT, id TEXT, timestamp TIMESTAMP)"); err != nil { - return err + return fmt.Errorf("failed to create migrations table: %w", err) } // Select data from platforma_migrations table var migrationsState []migrations err := db.SelectContext(ctx, &migrationsState, "SELECT * FROM platforma_migrations") if err != nil { - return err + return fmt.Errorf("failed to select migrations state: %w", err) } appliedMigrations := []Migration{} @@ -59,7 +60,7 @@ func (db *Database) Migrate(ctx context.Context) error { if !repoHasMigrations { for _, query := range repoSchema.Queries { if _, err := db.ExecContext(ctx, query); err != nil { - migrationErr = err + migrationErr = fmt.Errorf("failed to execute schema query: %w", err) break } log.InfoContext(ctx, "schema applied", "repository", repoName) @@ -67,13 +68,13 @@ func (db *Database) Migrate(ctx context.Context) error { // Log that schema applied if _, err := db.ExecContext(ctx, "INSERT INTO platforma_migrations (repository, timestamp) VALUES ($1, $2)", repoName, time.Now()); err != nil { - return err + return fmt.Errorf("failed to insert migration record: %w", err) } // If schema is applied, log that all migrations are also applied for _, migration := range repoMigrations { if _, err := db.ExecContext(ctx, "INSERT INTO platforma_migrations (repository, id, timestamp) VALUES ($1, $2, $3)", repoName, migration.ID, time.Now()); err != nil { - return err + return fmt.Errorf("failed to insert migration record: %w", err) } } @@ -93,7 +94,7 @@ func (db *Database) Migrate(ctx context.Context) error { } if _, err := db.ExecContext(ctx, migration.Up); err != nil { - migrationErr = err + migrationErr = fmt.Errorf("failed to apply migration %s for repository %s: %w", migration.ID, repoName, err) log.ErrorContext(ctx, "failed to apply migration for repository", "migration", migration.ID, "repository", repoName) break } @@ -103,7 +104,7 @@ func (db *Database) Migrate(ctx context.Context) error { // Log that migration applied if _, err := db.ExecContext(ctx, "INSERT INTO platforma_migrations (repository, id, timestamp) VALUES ($1, $2, $3)", repoName, migration.ID, time.Now()); err != nil { - return err + return fmt.Errorf("failed to insert migration record: %w", err) } } @@ -116,11 +117,11 @@ func (db *Database) Migrate(ctx context.Context) error { for _, migration := range slices.Backward(appliedMigrations) { if _, err := db.ExecContext(ctx, migration.Down); err != nil { log.ErrorContext(ctx, "failed to rollback migration %s for repository %s", migration.ID, migration.repository) - return err + return fmt.Errorf("failed to rollback migration %s for repository %s: %w", migration.ID, migration.repository, err) } if _, err := db.ExecContext(ctx, "DELETE FROM platforma_migrations WHERE repository = $1 AND id = $2", migration.repository, migration.ID); err != nil { - return err + return fmt.Errorf("failed to delete migration record: %w", err) } } } diff --git a/httpclient/httpclient.go b/httpclient/httpclient.go index 8cf08da..3a4e945 100644 --- a/httpclient/httpclient.go +++ b/httpclient/httpclient.go @@ -1,6 +1,7 @@ package httpclient import ( + "fmt" "net/http" "strings" "time" @@ -25,11 +26,11 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { resp, err := c.client.Do(req) if err != nil { log.DebugContext(req.Context(), "request failed", "error", err) - return nil, err + return nil, fmt.Errorf("failed to execute request: %w", err) } log.DebugContext(req.Context(), "request made", "status", resp.Status, "headers", maskedHeaders(resp.Header)) - return resp, err + return resp, nil } func maskedHeaders(headers http.Header) http.Header { diff --git a/httpclient/httpclient_test.go b/httpclient/httpclient_test.go index 2d2eb01..3d142de 100644 --- a/httpclient/httpclient_test.go +++ b/httpclient/httpclient_test.go @@ -13,6 +13,8 @@ import ( const timeout = 10 * time.Second func TestNew(t *testing.T) { + t.Parallel() + client := httpclient.New(timeout) if client == nil { t.Error("New() should return a non-nil client") @@ -20,6 +22,8 @@ func TestNew(t *testing.T) { } func TestDo_Success(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("test response")) @@ -27,7 +31,7 @@ func TestDo_Success(t *testing.T) { defer server.Close() client := httpclient.New(timeout) - req, err := http.NewRequestWithContext(context.Background(), "GET", server.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) if err != nil { t.Fatalf("failed to create request: %v", err) } @@ -36,7 +40,11 @@ func TestDo_Success(t *testing.T) { if err != nil { t.Fatalf("Do() failed: %v", err) } - defer resp.Body.Close() + defer func() { + if resp.Body != nil { + resp.Body.Close() + } + }() if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) @@ -44,14 +52,16 @@ func TestDo_Success(t *testing.T) { } func TestDo_Error(t *testing.T) { + t.Parallel() + // Create a request to a non-existent server client := httpclient.New(timeout) - req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost:9999/nonexistent", nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost:9999/nonexistent", nil) if err != nil { t.Fatalf("failed to create request: %v", err) } - _, err = client.Do(req) + _, err = client.Do(req) //nolint:bodyclose // Response if nill here, no neen to close body if err == nil { t.Error("expected error for non-existent server, got nil") } diff --git a/httpserver/httpserver.go b/httpserver/httpserver.go index 6e5cccc..26aa6e9 100644 --- a/httpserver/httpserver.go +++ b/httpserver/httpserver.go @@ -3,6 +3,7 @@ package httpserver import ( "context" "errors" + "fmt" "net/http" "os" "os/signal" @@ -47,8 +48,9 @@ func (s *HttpServer) UseFunc(middlewareFuncs ...func(http.Handler) http.Handler) func (s *HttpServer) Run(ctx context.Context) error { server := &http.Server{ - Addr: ":" + s.port, - Handler: wrapHandlerInMiddleware(s.mux, s.middlewares), + Addr: ":" + s.port, + Handler: wrapHandlerInMiddleware(s.mux, s.middlewares), + ReadHeaderTimeout: 1 * time.Second, } go func() { @@ -69,7 +71,7 @@ func (s *HttpServer) Run(ctx context.Context) error { if err := server.Shutdown(shutdownCtx); err != nil { log.ErrorContext(ctx, "HTTP shutdown error", "error", err) - return err + return fmt.Errorf("failed to shutdown server: %w", err) } log.InfoContext(ctx, "graceful shutdown completed.") diff --git a/httpserver/httpserver_test.go b/httpserver/httpserver_test.go index f0e1607..21a4589 100644 --- a/httpserver/httpserver_test.go +++ b/httpserver/httpserver_test.go @@ -18,6 +18,8 @@ func (h *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func TestHttpServer_ShutdownCompletesBeforeTimeout(t *testing.T) { + t.Parallel() + // Create a test HTTP server directly to test shutdown behavior server := &http.Server{ Addr: ":8080", @@ -53,6 +55,8 @@ func TestHttpServer_ShutdownCompletesBeforeTimeout(t *testing.T) { } func TestHttpServer_ShutdownWithNoActiveConnections(t *testing.T) { + t.Parallel() + // Create HttpServer instance to test the integration httpServer := httpserver.New("8081", 3*time.Second) httpServer.Handle("/test", &testHandler{}) @@ -90,6 +94,8 @@ func TestHttpServer_ShutdownWithNoActiveConnections(t *testing.T) { } func TestHttpServer_Healthcheck(t *testing.T) { + t.Parallel() + server := httpserver.New("8083", 5*time.Second) result := server.Healthcheck(context.Background()) diff --git a/httpserver/recover_test.go b/httpserver/recover_test.go index b8f7d0b..88ef043 100644 --- a/httpserver/recover_test.go +++ b/httpserver/recover_test.go @@ -26,6 +26,8 @@ func (h *normalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func TestRecoverMiddleware_NormalOperation(t *testing.T) { + t.Parallel() + // Setup middleware := httpserver.NewRecoverMiddleware() handler := &normalHandler{} @@ -52,6 +54,8 @@ func TestRecoverMiddleware_NormalOperation(t *testing.T) { } func TestRecoverMiddleware_PanicRecovery(t *testing.T) { + t.Parallel() + middleware := httpserver.NewRecoverMiddleware() handler := &panicHandler{panicMessage: "test panic"} wrappedHandler := middleware.Wrap(handler) @@ -78,6 +82,8 @@ func TestRecoverMiddleware_PanicRecovery(t *testing.T) { } func TestRecoverMiddleware_ErrorResponse(t *testing.T) { + t.Parallel() + middleware := httpserver.NewRecoverMiddleware() handler := &panicHandler{panicMessage: "specific error for testing"} wrappedHandler := middleware.Wrap(handler) @@ -111,12 +117,14 @@ func TestRecoverMiddleware_ErrorResponse(t *testing.T) { } func TestRecoverMiddleware_MultiplePanics(t *testing.T) { + t.Parallel() + middleware := httpserver.NewRecoverMiddleware() handler := &panicHandler{panicMessage: "first panic"} wrappedHandler := middleware.Wrap(handler) // Test multiple requests to ensure middleware continues to work - for i := 0; i < 3; i++ { + for i := range 3 { req := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() diff --git a/httpserver/traceid.go b/httpserver/traceid.go index 7593437..1795c52 100644 --- a/httpserver/traceid.go +++ b/httpserver/traceid.go @@ -15,7 +15,7 @@ type TraceId struct { header string } -// NewTraceId returns a new TraceId middleware. +// NewTraceIdMiddleware returns a new TraceId middleware. // If key is nil, log.TraceIdKey is used. // If header is empty, "Platforma-Trace-Id" is used. func NewTraceIdMiddleware(contextKey any, header string) *TraceId { diff --git a/internal/cli/generate.go b/internal/cli/generate.go index 9714e7b..d7eaf6d 100644 --- a/internal/cli/generate.go +++ b/internal/cli/generate.go @@ -57,14 +57,17 @@ func generateCommand(args []string) { } func writeFromTemplate(folder, file, templatePath string, data any) error { - os.MkdirAll(folder, 0755) + err := os.MkdirAll(folder, 0750) + if err != nil { + return fmt.Errorf("failed to create directory %s: %w", folder, err) + } // Get the directory of the current CLI package _, filename, _, _ := runtime.Caller(0) cliDir := filepath.Dir(filename) fullTemplatePath := filepath.Join(cliDir, templatePath) - templateContent, err := os.ReadFile(fullTemplatePath) + templateContent, err := os.ReadFile(fullTemplatePath) //nolint:gosec // Known path in compile time if err != nil { return fmt.Errorf("failed to read template %s: %w", fullTemplatePath, err) } @@ -80,9 +83,9 @@ func writeFromTemplate(folder, file, templatePath string, data any) error { return fmt.Errorf("failed to execute template %s: %w", fullTemplatePath, err) } - err = os.WriteFile(filepath.Join(folder, file), buf.Bytes(), 0644) + err = os.WriteFile(filepath.Join(folder, file), buf.Bytes(), 0600) if err != nil { - return err + return fmt.Errorf("failed to write file %s: %w", filepath.Join(folder, file), err) } return nil diff --git a/log/log.go b/log/log.go index b25f408..86d90b9 100644 --- a/log/log.go +++ b/log/log.go @@ -2,6 +2,7 @@ package log import ( "context" + "fmt" "io" "log/slog" ) @@ -18,7 +19,7 @@ type logger interface { ErrorContext(ctx context.Context, msg string, args ...any) } -var Logger logger = slog.Default() +var Logger logger = slog.Default() //nolint:gochecknoglobals // SetDefault sets the default logger used by the package-level logging functions. func SetDefault(l logger) { @@ -35,14 +36,6 @@ const ( UserIdKey contextKey = "userId" ) -var defaultKeys = []contextKey{ - DomainNameKey, - TraceIdKey, - ServiceNameKey, - StartupTaskKey, - UserIdKey, -} - type contextHandler struct { slog.Handler additionKeys map[string]any @@ -50,6 +43,14 @@ type contextHandler struct { // Handle processes the log record by adding context values before passing it to the underlying handler. func (h *contextHandler) Handle(ctx context.Context, r slog.Record) error { + var defaultKeys = []contextKey{ + DomainNameKey, + TraceIdKey, + ServiceNameKey, + StartupTaskKey, + UserIdKey, + } + for _, key := range defaultKeys { if value, ok := ctx.Value(key).(string); ok { r.AddAttrs(slog.String(string(key), value)) @@ -62,7 +63,11 @@ func (h *contextHandler) Handle(ctx context.Context, r slog.Record) error { } } - return h.Handler.Handle(ctx, r) + err := h.Handler.Handle(ctx, r) + if err != nil { + return fmt.Errorf("failed to handle log record: %w", err) + } + return nil } // New creates a new slog.Logger with the specified type (json/text), log level, and additional context keys to include. diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index c36d790..ffe1dc6 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -2,6 +2,7 @@ package scheduler import ( "context" + "fmt" "time" "github.com/mishankov/platforma/application" @@ -43,7 +44,7 @@ func (s *Scheduler) Run(ctx context.Context) error { log.InfoContext(runCtx, "scheduler task finished") case <-ctx.Done(): - return ctx.Err() + return fmt.Errorf("scheduler context canceled: %w", ctx.Err()) } } } diff --git a/scheduler/scheduler_test.go b/scheduler/scheduler_test.go index dcbce78..13e91eb 100644 --- a/scheduler/scheduler_test.go +++ b/scheduler/scheduler_test.go @@ -16,7 +16,7 @@ func TestSuccessRun(t *testing.T) { buf := bytes.Buffer{} s := scheduler.New(1*time.Second, application.RunnerFunc(func(ctx context.Context) error { - buf.Write([]byte("1")) + buf.WriteString("1") return nil })) @@ -34,7 +34,7 @@ func TestErrorRun(t *testing.T) { buf := bytes.Buffer{} s := scheduler.New(1*time.Second, application.RunnerFunc(func(ctx context.Context) error { - buf.Write([]byte("1")) + buf.WriteString("1") return errors.New("some error") })) @@ -52,7 +52,7 @@ func TestContextDecline(t *testing.T) { buf := bytes.Buffer{} s := scheduler.New(1*time.Second, application.RunnerFunc(func(ctx context.Context) error { - buf.Write([]byte("1")) + buf.WriteString("1") return nil })) diff --git a/session/model.go b/session/model.go index 0c02b9a..c8efe23 100644 --- a/session/model.go +++ b/session/model.go @@ -3,10 +3,10 @@ package session import "time" type Session struct { - ID string `json:"id" db:"id"` - User string `json:"user" db:"user"` - Created time.Time `json:"created" db:"created"` - Expires time.Time `json:"expires" db:"expires"` + ID string `db:"id" json:"id"` + User string `db:"user" json:"user"` + Created time.Time `db:"created" json:"created"` + Expires time.Time `db:"expires" json:"expires"` } func (s *Session) IsExpired() bool { diff --git a/session/repository.go b/session/repository.go index f049b52..c8bfc9d 100644 --- a/session/repository.go +++ b/session/repository.go @@ -3,6 +3,7 @@ package session import ( "context" "database/sql" + "fmt" "github.com/mishankov/platforma/database" ) @@ -43,7 +44,7 @@ func (r *Repository) Get(ctx context.Context, id string) (*Session, error) { var session Session err := r.db.GetContext(ctx, &session, "SELECT * FROM sessions WHERE id = $1", id) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get session by id: %w", err) } return &session, nil } @@ -52,7 +53,7 @@ func (r *Repository) GetByUserId(ctx context.Context, userID string) (*Session, var session Session err := r.db.GetContext(ctx, &session, "SELECT * FROM sessions WHERE \"user\" = $1", userID) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get session by user id: %w", err) } return &session, nil } @@ -63,7 +64,10 @@ func (r *Repository) Create(ctx context.Context, session *Session) error { VALUES (:id, :user, :created, :expires) ` _, err := r.db.NamedExecContext(ctx, query, session) - return err + if err != nil { + return fmt.Errorf("failed to create session: %w", err) + } + return nil } func (r *Repository) Delete(ctx context.Context, id string) error { @@ -71,6 +75,8 @@ func (r *Repository) Delete(ctx context.Context, id string) error { DELETE FROM sessions WHERE id = $1 ` _, err := r.db.ExecContext(ctx, query, id) - - return err + if err != nil { + return fmt.Errorf("failed to delete session: %w", err) + } + return nil }