diff --git a/cmd/entries/delete.go b/cmd/entries/delete.go index eba687f..34acd00 100644 --- a/cmd/entries/delete.go +++ b/cmd/entries/delete.go @@ -161,6 +161,8 @@ func DeleteCmd() *cobra.Command { os.Exit(1) } + db.SaveLastAction(storage.UndoAction{Type: storage.ActionDelete, ProjectName: selectedEntry.ProjectName, Entry: selectedEntry}) + fmt.Println() ui.PrintSuccess(ui.EmojiSuccess, "Entry deleted successfully") ui.NewlineBelow() diff --git a/cmd/entries/manual.go b/cmd/entries/manual.go index 99428af..babdb45 100644 --- a/cmd/entries/manual.go +++ b/cmd/entries/manual.go @@ -360,6 +360,8 @@ func ManualCmd() *cobra.Command { os.Exit(1) } + db.SaveLastAction(storage.UndoAction{Type: storage.ActionManual, EntryID: entry.ID, ProjectName: entry.ProjectName}) + duration := entry.Duration() fmt.Println() ui.PrintSuccess(ui.EmojiSuccess, fmt.Sprintf("Created manual entry for %s", ui.Bold(entry.ProjectName))) diff --git a/cmd/root.go b/cmd/root.go index e1ae07d..999fbf2 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -39,6 +39,7 @@ Track time effortlessly with automatic project detection and simple commands.`, cmd.Flags().BoolP("version", "v", false, "version for tmpo") // Utilities + cmd.AddCommand(utilities.UndoCmd()) cmd.AddCommand(utilities.VersionCmd()) // Tracking diff --git a/cmd/tracking/pause.go b/cmd/tracking/pause.go index 001a652..f84bb60 100644 --- a/cmd/tracking/pause.go +++ b/cmd/tracking/pause.go @@ -44,6 +44,8 @@ func PauseCmd() *cobra.Command { os.Exit(1) } + db.SaveLastAction(storage.UndoAction{Type: storage.ActionPause, EntryID: running.ID, ProjectName: running.ProjectName}) + duration := time.Since(running.StartTime) ui.PrintSuccess(ui.EmojiStop, fmt.Sprintf("Paused tracking %s", ui.Bold(running.ProjectName))) diff --git a/cmd/tracking/resume.go b/cmd/tracking/resume.go index ea2fe9c..c7d5fd9 100644 --- a/cmd/tracking/resume.go +++ b/cmd/tracking/resume.go @@ -68,6 +68,8 @@ func ResumeCmd() *cobra.Command { os.Exit(1) } + db.SaveLastAction(storage.UndoAction{Type: storage.ActionResume, EntryID: entry.ID, ProjectName: entry.ProjectName}) + ui.PrintSuccess(ui.EmojiStart, fmt.Sprintf("Resumed tracking time for %s", ui.Bold(entry.ProjectName))) if entry.Description != "" { diff --git a/cmd/tracking/start.go b/cmd/tracking/start.go index 5efee1b..34fbb3d 100644 --- a/cmd/tracking/start.go +++ b/cmd/tracking/start.go @@ -74,6 +74,8 @@ func StartCmd() *cobra.Command { os.Exit(1) } + db.SaveLastAction(storage.UndoAction{Type: storage.ActionStart, EntryID: entry.ID, ProjectName: entry.ProjectName}) + ui.PrintSuccess(ui.EmojiStart, fmt.Sprintf("Started tracking time for %s", ui.Bold(entry.ProjectName))) // communicate config source to user diff --git a/cmd/tracking/stop.go b/cmd/tracking/stop.go index d186fcd..9c72f25 100644 --- a/cmd/tracking/stop.go +++ b/cmd/tracking/stop.go @@ -43,6 +43,8 @@ func StopCmd() *cobra.Command { os.Exit(1) } + db.SaveLastAction(storage.UndoAction{Type: storage.ActionStop, EntryID: running.ID, ProjectName: running.ProjectName}) + duration := time.Since(running.StartTime) ui.PrintSuccess(ui.EmojiStop, fmt.Sprintf("Stopped tracking %s", ui.Bold(running.ProjectName))) diff --git a/cmd/utilities/undo.go b/cmd/utilities/undo.go new file mode 100644 index 0000000..7894220 --- /dev/null +++ b/cmd/utilities/undo.go @@ -0,0 +1,110 @@ +package utilities + +import ( + "fmt" + "os" + + "github.com/DylanDevelops/tmpo/internal/storage" + "github.com/DylanDevelops/tmpo/internal/ui" + "github.com/manifoldco/promptui" + "github.com/spf13/cobra" +) + +var actionDescriptions = map[storage.ActionType]string{ + storage.ActionStop: "Stopped tracking", + storage.ActionPause: "Paused tracking", + storage.ActionStart: "Started tracking", + storage.ActionResume: "Resumed tracking", + storage.ActionManual: "Created manual entry for", + storage.ActionDelete: "Deleted entry for", +} + +func UndoCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "undo", + Short: "Undo the previous action", + Long: `Undo the previous action in case of a mistake or in need of a rollback.`, + Run: func(cmd *cobra.Command, args []string) { + ui.NewlineAbove() + + db, err := storage.Initialize() + if err != nil { + ui.PrintError(ui.EmojiError, fmt.Sprintf("%v", err)) + os.Exit(1) + } + defer db.Close() + + action, err := db.GetLastAction() + if err != nil { + ui.PrintError(ui.EmojiError, fmt.Sprintf("%v", err)) + os.Exit(1) + } + + if action == nil { + ui.PrintWarning(ui.EmojiWarning, "Nothing to undo.") + ui.NewlineBelow() + return + } + + ui.PrintInfo(0, ui.EmojiUndo+" Last action", undoActionDescription(action)) + fmt.Println() + + confirmPrompt := promptui.Prompt{ + Label: "Undo this action? [y/N]", + IsConfirm: true, + } + if _, err := confirmPrompt.Run(); err != nil { + ui.PrintWarning(ui.EmojiWarning, "Undo cancelled.") + ui.NewlineBelow() + return + } + + if err := applyUndo(db, action); err != nil { + ui.PrintError(ui.EmojiError, fmt.Sprintf("undo failed: %v", err)) + ui.NewlineBelow() + os.Exit(1) + } + + // not fatal if fails + db.ClearLastAction() + + ui.PrintSuccess(ui.EmojiUndo, "Undo successful.") + ui.NewlineBelow() + }, + } + + return cmd +} + +func undoActionDescription(action *storage.UndoAction) string { + if prefix, ok := actionDescriptions[action.Type]; ok { + return fmt.Sprintf("%s %s", prefix, ui.Bold(action.ProjectName)) + } + return fmt.Sprintf("Unknown action: %s", action.Type) +} + +func applyUndo(db *storage.Database, action *storage.UndoAction) error { + switch action.Type { + case storage.ActionStop, storage.ActionPause: + running, err := db.GetRunningEntry() + if err != nil { + return fmt.Errorf("checking for running entry: %w", err) + } + if running != nil { + return fmt.Errorf("a timer is already running for %s — stop it first with 'tmpo stop'", running.ProjectName) + } + return db.UncompleteEntry(action.EntryID) + + case storage.ActionStart, storage.ActionResume, storage.ActionManual: + return db.DeleteTimeEntry(action.EntryID) + + case storage.ActionDelete: + if action.Entry == nil { + return fmt.Errorf("no entry snapshot available to restore") + } + return db.RestoreDeletedEntry(action.Entry) + + default: + return fmt.Errorf("unknown action type: %s", action.Type) + } +} diff --git a/cmd/utilities/undo_test.go b/cmd/utilities/undo_test.go new file mode 100644 index 0000000..3993f4d --- /dev/null +++ b/cmd/utilities/undo_test.go @@ -0,0 +1,101 @@ +package utilities + +import ( + "testing" + + "github.com/DylanDevelops/tmpo/internal/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupUndoTestDB(t *testing.T) *storage.Database { + t.Helper() + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + t.Setenv("USERPROFILE", tmpHome) + t.Setenv("TMPO_DEV", "1") + db, err := storage.Initialize() + require.NoError(t, err) + t.Cleanup(func() { db.Close() }) + return db +} + +func TestUndoActionDescription(t *testing.T) { + tests := []struct { + actionType storage.ActionType + contains string + }{ + {storage.ActionStop, "Stopped tracking"}, + {storage.ActionPause, "Paused tracking"}, + {storage.ActionStart, "Started tracking"}, + {storage.ActionResume, "Resumed tracking"}, + {storage.ActionManual, "Created manual entry for"}, + {storage.ActionDelete, "Deleted entry for"}, + } + + for _, tt := range tests { + t.Run(string(tt.actionType), func(t *testing.T) { + action := &storage.UndoAction{Type: tt.actionType, ProjectName: "proj"} + desc := undoActionDescription(action) + assert.Contains(t, desc, tt.contains) + assert.Contains(t, desc, "proj") + }) + } +} + +func TestUndoActionDescription_Unknown(t *testing.T) { + action := &storage.UndoAction{Type: "something_new", ProjectName: "proj"} + desc := undoActionDescription(action) + assert.Contains(t, desc, "Unknown action") + assert.Contains(t, desc, "something_new") +} + +func TestApplyUndo_Stop_ErrorWhenTimerAlreadyRunning(t *testing.T) { + db := setupUndoTestDB(t) + + stopped, err := db.CreateEntry("proj", "", nil, nil) + require.NoError(t, err) + require.NoError(t, db.StopEntry(stopped.ID)) + + _, err = db.CreateEntry("other", "", nil, nil) + require.NoError(t, err) + + action := &storage.UndoAction{Type: storage.ActionStop, EntryID: stopped.ID, ProjectName: "proj"} + err = applyUndo(db, action) + assert.Error(t, err) + assert.Contains(t, err.Error(), "timer is already running") +} + +func TestApplyUndo_Pause_ErrorWhenTimerAlreadyRunning(t *testing.T) { + db := setupUndoTestDB(t) + + stopped, err := db.CreateEntry("proj", "", nil, nil) + require.NoError(t, err) + require.NoError(t, db.StopEntry(stopped.ID)) + + _, err = db.CreateEntry("other", "", nil, nil) + require.NoError(t, err) + + action := &storage.UndoAction{Type: storage.ActionPause, EntryID: stopped.ID, ProjectName: "proj"} + err = applyUndo(db, action) + assert.Error(t, err) + assert.Contains(t, err.Error(), "timer is already running") +} + +func TestApplyUndo_Delete_ErrorWhenNoSnapshot(t *testing.T) { + db := setupUndoTestDB(t) + + action := &storage.UndoAction{Type: storage.ActionDelete, ProjectName: "proj", Entry: nil} + err := applyUndo(db, action) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no entry snapshot") +} + +func TestApplyUndo_UnknownType_ReturnsError(t *testing.T) { + db := setupUndoTestDB(t) + + action := &storage.UndoAction{Type: "bogus", ProjectName: "proj"} + err := applyUndo(db, action) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown action type") +} diff --git a/internal/storage/undo.go b/internal/storage/undo.go new file mode 100644 index 0000000..8c0af86 --- /dev/null +++ b/internal/storage/undo.go @@ -0,0 +1,115 @@ +package storage + +import ( + "database/sql" + "encoding/json" + "fmt" + "time" +) + +type ActionType string + +const ( + ActionStop ActionType = "stop" + ActionStart ActionType = "start" + ActionPause ActionType = "pause" + ActionResume ActionType = "resume" + ActionDelete ActionType = "delete" + ActionManual ActionType = "manual" +) + +const lastActionKey = "last_action" + +type UndoAction struct { + Type ActionType `json:"type"` + EntryID int64 `json:"entry_id,omitempty"` + ProjectName string `json:"project_name,omitempty"` + Entry *TimeEntry `json:"entry,omitempty"` +} + +func (d *Database) SaveLastAction(action UndoAction) error { + data, err := json.Marshal(action) + if err != nil { + return fmt.Errorf("failed to serialize action: %w", err) + } + _, err = d.db.Exec( + "INSERT OR REPLACE INTO settings (key, value, updated_at) VALUES (?, ?, ?)", + lastActionKey, + string(data), + time.Now().UTC(), + ) + if err != nil { + return fmt.Errorf("failed to save last action: %w", err) + } + return nil +} + +func (d *Database) GetLastAction() (*UndoAction, error) { + var value string + err := d.db.QueryRow("SELECT value FROM settings WHERE key = ?", lastActionKey).Scan(&value) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get last action: %w", err) + } + var action UndoAction + if err := json.Unmarshal([]byte(value), &action); err != nil { + return nil, fmt.Errorf("failed to parse last action: %w", err) + } + return &action, nil +} + +func (d *Database) ClearLastAction() error { + _, err := d.db.Exec("DELETE FROM settings WHERE key = ?", lastActionKey) + if err != nil { + return fmt.Errorf("failed to clear last action: %w", err) + } + return nil +} + +// UncompleteEntry clears the end_time of an entry, resuming it as a running timer. +func (d *Database) UncompleteEntry(id int64) error { + result, err := d.db.Exec("UPDATE time_entries SET end_time = NULL WHERE id = ?", id) + if err != nil { + return fmt.Errorf("failed to uncomplete entry: %w", err) + } + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to uncomplete entry: %w", err) + } + if rows == 0 { + return fmt.Errorf("entry %d not found", id) + } + return nil +} + +// RestoreDeletedEntry re-inserts a previously deleted entry preserving its original ID. +func (d *Database) RestoreDeletedEntry(entry *TimeEntry) error { + var endTime sql.NullTime + if entry.EndTime != nil { + endTime = sql.NullTime{Time: entry.EndTime.UTC(), Valid: true} + } + var rate sql.NullFloat64 + if entry.HourlyRate != nil { + rate = sql.NullFloat64{Float64: *entry.HourlyRate, Valid: true} + } + var milestone sql.NullString + if entry.MilestoneName != nil { + milestone = sql.NullString{String: *entry.MilestoneName, Valid: true} + } + _, err := d.db.Exec( + "INSERT INTO time_entries (id, project_name, start_time, end_time, description, hourly_rate, milestone_name) VALUES (?, ?, ?, ?, ?, ?, ?)", + entry.ID, + entry.ProjectName, + entry.StartTime.UTC(), + endTime, + entry.Description, + rate, + milestone, + ) + if err != nil { + return fmt.Errorf("failed to restore entry: %w", err) + } + return nil +} diff --git a/internal/storage/undo_test.go b/internal/storage/undo_test.go new file mode 100644 index 0000000..22963d0 --- /dev/null +++ b/internal/storage/undo_test.go @@ -0,0 +1,162 @@ +package storage + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupUndoTestDB(t *testing.T) *Database { + t.Helper() + db := setupTestDB(t) + _, err := db.db.Exec(` + CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + updated_at DATETIME NOT NULL + ) + `) + require.NoError(t, err) + return db +} + +func TestSaveAndGetLastAction(t *testing.T) { + db := setupUndoTestDB(t) + defer db.Close() + + action := UndoAction{Type: ActionStop, EntryID: 42, ProjectName: "myproject"} + require.NoError(t, db.SaveLastAction(action)) + + got, err := db.GetLastAction() + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, ActionStop, got.Type) + assert.Equal(t, int64(42), got.EntryID) + assert.Equal(t, "myproject", got.ProjectName) +} + +func TestGetLastAction_WhenNone(t *testing.T) { + db := setupUndoTestDB(t) + defer db.Close() + + got, err := db.GetLastAction() + require.NoError(t, err) + assert.Nil(t, got) +} + +func TestSaveLastAction_Overwrites(t *testing.T) { + db := setupUndoTestDB(t) + defer db.Close() + + require.NoError(t, db.SaveLastAction(UndoAction{Type: ActionStop, EntryID: 1, ProjectName: "first"})) + require.NoError(t, db.SaveLastAction(UndoAction{Type: ActionStart, EntryID: 2, ProjectName: "second"})) + + got, err := db.GetLastAction() + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, ActionStart, got.Type) + assert.Equal(t, int64(2), got.EntryID) +} + +func TestClearLastAction(t *testing.T) { + db := setupUndoTestDB(t) + defer db.Close() + + require.NoError(t, db.SaveLastAction(UndoAction{Type: ActionStop, EntryID: 1, ProjectName: "proj"})) + require.NoError(t, db.ClearLastAction()) + + got, err := db.GetLastAction() + require.NoError(t, err) + assert.Nil(t, got) +} + +func TestSaveLastAction_PreservesEntrySnapshot(t *testing.T) { + db := setupUndoTestDB(t) + defer db.Close() + + rate := 75.0 + milestone := "v1" + end := time.Now() + entry := &TimeEntry{ + ID: 99, + ProjectName: "proj", + StartTime: time.Now().Add(-time.Hour), + EndTime: &end, + Description: "some work", + HourlyRate: &rate, + MilestoneName: &milestone, + } + + require.NoError(t, db.SaveLastAction(UndoAction{Type: ActionDelete, ProjectName: "proj", Entry: entry})) + + got, err := db.GetLastAction() + require.NoError(t, err) + require.NotNil(t, got) + require.NotNil(t, got.Entry) + assert.Equal(t, int64(99), got.Entry.ID) + assert.Equal(t, "some work", got.Entry.Description) + assert.Equal(t, 75.0, *got.Entry.HourlyRate) + assert.Equal(t, "v1", *got.Entry.MilestoneName) +} + +func TestUncompleteEntry(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + entry, err := db.CreateEntry("proj", "", nil, nil) + require.NoError(t, err) + require.NoError(t, db.StopEntry(entry.ID)) + + stopped, err := db.GetEntry(entry.ID) + require.NoError(t, err) + assert.NotNil(t, stopped.EndTime) + + require.NoError(t, db.UncompleteEntry(entry.ID)) + + running, err := db.GetEntry(entry.ID) + require.NoError(t, err) + assert.Nil(t, running.EndTime) +} + +func TestRestoreDeletedEntry(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + start := time.Now().Add(-2 * time.Hour) + end := time.Now().Add(-time.Hour) + rate := 100.0 + milestone := "m1" + original, err := db.CreateManualEntry("proj", "work", start, end, &rate, &milestone) + require.NoError(t, err) + + require.NoError(t, db.DeleteTimeEntry(original.ID)) + + require.NoError(t, db.RestoreDeletedEntry(original)) + + restored, err := db.GetEntry(original.ID) + require.NoError(t, err) + assert.Equal(t, original.ID, restored.ID) + assert.Equal(t, original.ProjectName, restored.ProjectName) + assert.Equal(t, original.Description, restored.Description) + assert.Equal(t, *original.HourlyRate, *restored.HourlyRate) + assert.Equal(t, *original.MilestoneName, *restored.MilestoneName) + assert.NotNil(t, restored.EndTime) +} + +func TestRestoreDeletedEntry_RunningEntry(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + original, err := db.CreateEntry("proj", "active", nil, nil) + require.NoError(t, err) + require.NoError(t, db.DeleteTimeEntry(original.ID)) + + require.NoError(t, db.RestoreDeletedEntry(original)) + + restored, err := db.GetEntry(original.ID) + require.NoError(t, err) + assert.Equal(t, original.ID, restored.ID) + assert.Nil(t, restored.EndTime) +} diff --git a/internal/ui/ui.go b/internal/ui/ui.go index 5e9cd14..f3f89fe 100644 --- a/internal/ui/ui.go +++ b/internal/ui/ui.go @@ -48,6 +48,7 @@ const ( EmojiError = "❌" EmojiWarning = "⚠️" EmojiInfo = "ℹ️" + EmojiUndo = "↩️" ) func Success(message string) string {