diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index 34834ce..e8db6b7 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -40,6 +40,8 @@ jobs: docker compose logs - name: Run integration tests + env: + TEST_FN_DISABLED: true run: | cd test go test -v -p=1 ./... -cover -coverpkg=../... -coverprofile cover.out && go tool cover -func cover.out diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2a263d0..f2a5955 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,6 +23,8 @@ jobs: go-version: 1.25.1 - name: Test + env: + TEST_FN_DISABLED: true run: go test -v -p=1 ./... -cover -coverprofile cover.out && go tool cover -func cover.out - name: Build diff --git a/Readme.md b/Readme.md index 412a91b..cc43fd0 100644 --- a/Readme.md +++ b/Readme.md @@ -435,70 +435,69 @@ Each step contains: - **Compensation**: A rollback operation that undoes the action if later steps fail Steps execute sequentially. If any step fails, all previous steps are automatically -compensated in reverse order, ensuring system consistency +compensated, ensuring system consistency. # drawing Example: ```go -// Use StepBuilder for more complex configuration -// This approach provides access to all library features: -// - Panic recovery -// - Retry policies -// - Custom backoff strategies -// - Jitter for load distribution steps := []saga.Step{ - saga.NewStep("first_step"). - WithAction( - // Add action with decorators - saga.NewAction(func(ctx context.Context) error { - // Simulate error to demonstrate retry - return fmt.Errorf("first_step_Error") - }). - // Protection against panics — important for production! - // If the action panics, the panic will be caught - // and returned as an error with ErrPanicRecovered - WithPanicRecovery(). - // Add retry for action - WithRetry( - // 2 attempts, 1s between attempts - saga.NewBaseRetryOpt(2, 1*time.Second). - // Return all errors which arise during retries - WithReturnAllAroseErr(), ), ). - // Add compensation - WithCompensation( - saga.NewCompensation(func(ctx context.Context, aroseErr error) error { - // Compensation logic. - // aroseErr — error from action that triggered compensation - // This can be useful for logging or strategy selection - return nil - }). - // Compensation can also have retry logic - WithRetry( - saga.NewAdvanceRetryPolicy( - 2, // max attempts - 1*time.Second, // initial delay - saga.NewExponentialBackoff(), // exponential backoff - ). - // Jitter prevents "thundering herd" - WithJitter( - // random delay - saga.NewFullJitter(), - ). - // maximum delay - WithMaxDelay(10 * time.Second), ), ),} + saga.NewStep("first_step"). + WithAction( + // Add action with decorators + saga.NewAction(func(ctx context.Context, track saga.Track) error { + err := fmt.Errorf("first_step_Error") + return err + }). + // Protection against panics — important for production! + // If the action panics, the panic will be caught + // and returned as an error with ErrPanicRecovered + WithPanicRecovery(). + // Add retry for action + WithRetry( + // 2 attempts, 1s between attempts + saga.NewBaseRetryOpt(2, 1*time.Second), + ), + ). + // Add compensation + WithCompensation( + saga.NewCompensation(func(ctx context.Context, track saga.Track) error { + // Compensation logic. + // Use track.GetData() to inspect what failed + data := track.GetData() + if len(data.Action.Errors) > 0 { + log.Printf("Compensating for error: %v", data.Action.Errors[0]) + } + return performCompensation(ctx) + }). + // Compensation can also have retry logic + WithRetry( + saga.NewAdvanceRetryPolicy( + 2, // max attempts + 1*time.Second, // initial delay + saga.NewExponentialBackoff(), // exponential backoff + ). + // Jitter prevents "thundering herd" during mass failures + WithJitter( + saga.NewFullJitter(), // random delay + ). + // maximum delay cap + WithMaxDelay(10*time.Second), + ), + ). + WithCompensationRequired(), +} // Execute the saga // // With this approach: -// 1. If action fails, there will be 2 attempts with exponential backoff +// 1. If action fails, it will be retried according to the retry policy // 2. If all attempts fail, compensations will run -// 3. Compensations will also retry on failure -// 4. Jitter distributes load during mass failures -err := saga.NewSaga(steps).Execute(context.Background()) - +// 3. Compensations will also retry on failure with exponential backoff +// 4. Jitter distributes load during mass failure scenarios +result, err := saga.NewSaga(steps).Execute(context.Background()) if err != nil { - // Handle error + // Handle the `Result` and errors } ``` diff --git a/examples/sage_test.go b/examples/sage_test.go index beebbec..617f44c 100644 --- a/examples/sage_test.go +++ b/examples/sage_test.go @@ -2,6 +2,7 @@ package examples import ( "context" + "errors" "fmt" "testing" "time" @@ -17,12 +18,23 @@ import ( func Test_Saga_example(t *testing.T) { t.Skip() + var ( + // ErrPaymentFailed is an example error type for demonstration + ErrPaymentFailed = fmt.Errorf("payment failed") + + // refundPayment is an example compensation function + refundPayment = func(ctx context.Context) error { + // Implementation would refund a payment + return nil + } + ) + t.Run("first_example: simple declarative approach", func(t *testing.T) { // Create saga steps as simple structs // This approach is ideal when: - // - You have simple actions, compensation and retries logic + // - You have simple actions and compensation logic // - You want maximum readability - // - You don't need additional decorators + // - You don't need additional decorators (retry, panic recovery) steps := []saga.Step{ { // Name is used for logging and debugging @@ -31,42 +43,64 @@ func Test_Saga_example(t *testing.T) { // Action — the main function of the step // Executes business logic and returns an error on failure - Action: func(ctx context.Context) error { + // The track parameter provides access to execution context: + // - track.GetData() — retrieve step execution data + // - track.SetFailedOnError(err) — record errors + // - track.AddError(err) — append errors without changing status + Action: func(ctx context.Context, track saga.Track) error { // This could be: // - Database query via mtx.Transactor // - External API call // - Any other operation + // + // Use track to record intermediate errors: + // if err := someOperation(ctx); err != nil { + // track.SetFailedOnError(err) + // return err + // } return nil }, // Compensation — rollback function // Called if subsequent steps fail // Important: compensation must be idempotent! - Compensation: func(ctx context.Context, aroseErr error) error { - // aroseErr — the error that triggered compensation - // This allows making different decisions based on error type + // The track contains information about the failed action: + // - track.GetData().Action.Errors — errors from the action + // - track.GetData().Action.Status — status of the action + Compensation: func(ctx context.Context, track saga.Track) error { + // Get execution data to make decisions based on error type + data := track.GetData() - // Example: if errors.Is(aroseErr, ErrPaymentFailed) { - // return refundPayment(ctx) - // } + // Example: conditional compensation based on error + if len(data.Action.Errors) > 0 { + if errors.Is(data.Action.Errors[0], ErrPaymentFailed) { + // Handle specific error type + return refundPayment(ctx) + } + } + + // Default compensation logic return nil }, - // CompensationOnFail determines whether this step needs compensation + // CompensationRequired determines whether this step needs compensation // true: if step changes state and requires rollback // false: for read-only operations or non-compensatable actions (email, notifications) - CompensationOnFail: true, + CompensationRequired: true, }, } - // Create and execute the saga + // Create and execute the Saga. // Saga automatically manages the order: actions execute sequentially, // on error, compensations run in reverse order - err := saga.NewSaga(steps).Execute(context.Background()) + result, err := saga.NewSaga(steps).Execute(context.Background()) - // Handle the error + // Handle the result if err != nil { - // Important: err may contain both execution errors and compensation errors + // err contains detailed information about failures + // Use result to get detailed step-by-step execution data + t.Logf("Saga failed: %v\n", err) + fmt.Printf("Result status: %s\n", result.Status) } }) @@ -81,9 +115,12 @@ func Test_Saga_example(t *testing.T) { saga.NewStep("first_step"). WithAction( // Add action with decorators - saga.NewAction(func(ctx context.Context) error { + saga.NewAction(func(ctx context.Context, track saga.Track) error { // Simulate error to demonstrate retry - return fmt.Errorf("first_step_Error") + // Record the error in track + err := fmt.Errorf("first_step_Error") + track.SetFailedOnError(err) + return err }). // Protection against panics — important for production! // If the action panics, the panic will be caught @@ -92,17 +129,22 @@ func Test_Saga_example(t *testing.T) { // Add retry for action WithRetry( // 2 attempts, 1s between attempts - saga.NewBaseRetryOpt(2, 1*time.Second). - // Return all errors which arise during retries - WithReturnAllAroseErr(), + saga.NewBaseRetryOpt(2, 1*time.Second), ), ). // Add compensation WithCompensation( - saga.NewCompensation(func(ctx context.Context, aroseErr error) error { + saga.NewCompensation(func(ctx context.Context, track saga.Track) error { // Compensation logic. - // aroseErr — error from action that triggered compensation - // This can be useful for logging or strategy selection + // Get data to understand what failed + data := track.GetData() + + // Log the error that triggered compensation + if len(data.Action.Errors) > 0 { + fmt.Printf("Compensating for error: %v\n", data.Action.Errors[0]) + } + + // Perform compensation return nil }). // Compensation can also have retry logic @@ -120,20 +162,24 @@ func Test_Saga_example(t *testing.T) { // maximum delay WithMaxDelay(10 * time.Second), ), - ), + ). + // Mark that this step requires compensation + WithCompensationRequired(), } // Execute the saga // // With this approach: - // 1. If action fails, there will be 2 attempts with exponential backoff + // 1. If action fails, there will be 2 attempts with fixed delay // 2. If all attempts fail, compensations will run - // 3. Compensations will also retry on failure + // 3. Compensations will also retry on failure with exponential backoff // 4. Jitter distributes load during mass failures - err := saga.NewSaga(steps).Execute(context.Background()) + result, err := saga.NewSaga(steps).Execute(context.Background()) if err != nil { - // Handle error + // Handle error with full context + fmt.Printf("Saga execution failed: %v\n", err) + fmt.Printf("Result status: %s\n", result.Status) } }) } diff --git a/internal/testtool/assert.go b/internal/testtool/assert.go deleted file mode 100644 index 73d80fe..0000000 --- a/internal/testtool/assert.go +++ /dev/null @@ -1,37 +0,0 @@ -package testtool - -import ( - "testing" -) - -// AssertTrue was added to avoid to use external dependencies for mocking -func AssertTrue(t *testing.T, val bool) { - t.Helper() - if !val { - t.Fatalf("expected true [current value: %v]", val) - } -} - -// AssertFalse was added to avoid to use external dependencies for mocking -func AssertFalse(t *testing.T, val bool) { - t.Helper() - if val { - t.Fatalf("expected false [current value: %v]", val) - } -} - -// AssertNoError was added to avoid to use external dependencies for mocking -func AssertNoError(t *testing.T, err error) { - t.Helper() - if err != nil { - t.Fatalf("error arose: %v", err) - } -} - -// AssertError was added to avoid to use external dependencies for mocking -func AssertError(t *testing.T, err error) { - t.Helper() - if err == nil { - t.Fatalf("error expected") - } -} diff --git a/internal/testtool/assert/assert.go b/internal/testtool/assert/assert.go new file mode 100644 index 0000000..ee9a64b --- /dev/null +++ b/internal/testtool/assert/assert.go @@ -0,0 +1,61 @@ +package assert + +import ( + "errors" + "testing" +) + +// True was added to avoid to use external dependencies for mocking +func True(t *testing.T, val bool) { + t.Helper() + if !val { + t.Fatalf("expected true [current value: %v]", val) + } +} + +// False was added to avoid to use external dependencies for mocking +func False(t *testing.T, val bool) { + t.Helper() + if val { + t.Fatalf("expected false [current value: %v]", val) + } +} + +// Error was added to avoid to use external dependencies for mocking +func Error(t *testing.T, err error) { + t.Helper() + if err == nil { + t.Fatalf("error expected") + } +} + +// NoError was added to avoid to use external dependencies for mocking +func NoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("error arose: %v", err) + } +} + +func Equal[T comparable](t *testing.T, expected, target T) { + t.Helper() + if expected != target { + t.Fatalf("%v != %v", expected, target) + } +} + +// ErrorIs asserts that at least one of the errors in err's chain matches target. +// This is a wrapper for errors.Is. +func ErrorIs(t *testing.T, err, target error) { + t.Helper() + if !errors.Is(err, target) { + t.Fatalf("[%v] is not [%v]", err, target) + } +} + +func ErrorIsNot(t *testing.T, err, target error) { + t.Helper() + if errors.Is(err, target) { + t.Fatalf("[%v] is [%v]", err, target) + } +} diff --git a/internal/testtool/error.go b/internal/testtool/error.go index a43aede..ff8be89 100644 --- a/internal/testtool/error.go +++ b/internal/testtool/error.go @@ -3,6 +3,10 @@ package testtool import "fmt" var ( - // ErrExpTest - errors for tests. - ErrExpTest = fmt.Errorf("exp_test_error") + // ErrExpTestA is an error for tests. + ErrExpTestA = fmt.Errorf("exp_test_error_A") + // ErrExpTestB is an errors for tests. + ErrExpTestB = fmt.Errorf("exp_test_error_B") + // ErrExpTestС is an errors for tests. + ErrExpTestC = fmt.Errorf("exp_test_error_С") ) diff --git a/internal/testtool/fn.go b/internal/testtool/fn.go new file mode 100644 index 0000000..2f03178 --- /dev/null +++ b/internal/testtool/fn.go @@ -0,0 +1,30 @@ +package testtool + +import ( + "os" + "strings" + "testing" +) + +const ( + envTestLoggerDisabled = "TEST_FN_DISABLED" +) + +var ( + disableTestLogger = false +) + +func init() { + dtl := os.Getenv(envTestLoggerDisabled) + if strings.TrimSpace(strings.ToLower(dtl)) == "true" { + disableTestLogger = true + } +} + +func TestFn(t *testing.T, fn func()) { + t.Helper() + if disableTestLogger { + return + } + fn() +} diff --git a/internal/testtool/log.go b/internal/testtool/log.go deleted file mode 100644 index 46014a3..0000000 --- a/internal/testtool/log.go +++ /dev/null @@ -1,8 +0,0 @@ -package testtool - -import "testing" - -func LogError(t testing.TB, err error) { - t.Helper() - t.Logf("test error output: \n{\n%v\n}", err) -} diff --git a/mtx/transactor_test.go b/mtx/transactor_test.go index 5eccd11..2299e08 100644 --- a/mtx/transactor_test.go +++ b/mtx/transactor_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - "github.com/kozmod/oniontx/internal/testtool" + "github.com/kozmod/oniontx/internal/testtool/assert" ) func Test_CtxOperator(t *testing.T) { @@ -21,8 +21,8 @@ func Test_CtxOperator(t *testing.T) { ) ctx = o.Inject(ctx, &c) extracted, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, extracted == &c) + assert.True(t, ok) + assert.True(t, extracted == &c) }) t.Run("extract_value", func(t *testing.T) { var ( @@ -37,8 +37,8 @@ func Test_CtxOperator(t *testing.T) { ) ctx = o.Inject(ctx, c) extracted, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, extracted == c) + assert.True(t, ok) + assert.True(t, extracted == c) }) t.Run("extract_nil_value", func(t *testing.T) { var ( @@ -53,8 +53,8 @@ func Test_CtxOperator(t *testing.T) { ) ctx = o.Inject(ctx, c) extracted, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, extracted == c) + assert.True(t, ok) + assert.True(t, extracted == c) }) }) @@ -84,13 +84,13 @@ func Test_Transactor(t *testing.T) { //nolint: dupl ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := tr.TryGetTx(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, &c == tx) + assert.True(t, ok) + assert.True(t, &c == tx) return nil }) - testtool.AssertNoError(t, err) - testtool.AssertTrue(t, beginnerCalled) - testtool.AssertTrue(t, commitCalled) + assert.NoError(t, err) + assert.True(t, beginnerCalled) + assert.True(t, commitCalled) }) t.Run("TxBeginner", func(t *testing.T) { var ( @@ -114,13 +114,13 @@ func Test_Transactor(t *testing.T) { //nolint: dupl ) err := tr.WithinTx(ctx, func(ctx context.Context) error { beginner := tr.TxBeginner() - testtool.AssertTrue(t, beginner != nil) - testtool.AssertTrue(t, &b == beginner) + assert.True(t, beginner != nil) + assert.True(t, &b == beginner) return nil }) - testtool.AssertNoError(t, err) - testtool.AssertTrue(t, beginnerCalled) - testtool.AssertTrue(t, commitCalled) + assert.NoError(t, err) + assert.True(t, beginnerCalled) + assert.True(t, commitCalled) }) t.Run("WithinTx", func(t *testing.T) { t.Run("success_commit", func(t *testing.T) { @@ -145,13 +145,13 @@ func Test_Transactor(t *testing.T) { //nolint: dupl ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, &c == tx) + assert.True(t, ok) + assert.True(t, &c == tx) return nil }) - testtool.AssertNoError(t, err) - testtool.AssertTrue(t, beginnerCalled) - testtool.AssertTrue(t, commitCalled) + assert.NoError(t, err) + assert.True(t, beginnerCalled) + assert.True(t, commitCalled) }) t.Run("success_and_not_commit_with_exists_tx", func(t *testing.T) { var ( @@ -170,12 +170,12 @@ func Test_Transactor(t *testing.T) { //nolint: dupl ctx = o.Inject(ctx, &c) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, &c == tx) + assert.True(t, ok) + assert.True(t, &c == tx) return nil }) - testtool.AssertNoError(t, err) - testtool.AssertTrue(t, !commitCalled) + assert.NoError(t, err) + assert.True(t, !commitCalled) }) t.Run("failed_commit", func(t *testing.T) { var ( @@ -199,13 +199,13 @@ func Test_Transactor(t *testing.T) { //nolint: dupl ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, &c == tx) + assert.True(t, ok) + assert.True(t, &c == tx) return nil }) - testtool.AssertTrue(t, errors.Is(err, ErrCommitFailed)) - testtool.AssertTrue(t, errors.Is(err, expError)) - testtool.AssertTrue(t, commitCalled) + assert.True(t, errors.Is(err, ErrCommitFailed)) + assert.True(t, errors.Is(err, expError)) + assert.True(t, commitCalled) }) t.Run("success_rollback", func(t *testing.T) { var ( @@ -231,14 +231,14 @@ func Test_Transactor(t *testing.T) { //nolint: dupl ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, &c == tx) + assert.True(t, ok) + assert.True(t, &c == tx) return expError }) - testtool.AssertTrue(t, errors.Is(err, ErrRollbackSuccess)) - testtool.AssertTrue(t, errors.Is(err, expError)) - testtool.AssertTrue(t, rollbackCalled) - testtool.AssertTrue(t, beginCalled) + assert.True(t, errors.Is(err, ErrRollbackSuccess)) + assert.True(t, errors.Is(err, expError)) + assert.True(t, rollbackCalled) + assert.True(t, beginCalled) }) t.Run("failed_rollback", func(t *testing.T) { var ( @@ -265,15 +265,15 @@ func Test_Transactor(t *testing.T) { //nolint: dupl ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, &c == tx) + assert.True(t, ok) + assert.True(t, &c == tx) return transactorError }) - testtool.AssertTrue(t, errors.Is(err, ErrRollbackFailed)) - testtool.AssertTrue(t, errors.Is(err, transactorError)) - testtool.AssertTrue(t, errors.Is(err, rollbackErr)) - testtool.AssertTrue(t, rollbackCalled) - testtool.AssertTrue(t, beginCalled) + assert.True(t, errors.Is(err, ErrRollbackFailed)) + assert.True(t, errors.Is(err, transactorError)) + assert.True(t, errors.Is(err, rollbackErr)) + assert.True(t, rollbackCalled) + assert.True(t, beginCalled) }) t.Run("success_panic_rollback", func(t *testing.T) { var ( @@ -298,15 +298,15 @@ func Test_Transactor(t *testing.T) { //nolint: dupl ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, &c == tx) + assert.True(t, ok) + assert.True(t, &c == tx) panic(expPanic) }) - testtool.AssertTrue(t, errors.Is(err, ErrRollbackSuccess)) - testtool.AssertTrue(t, errors.Is(err, ErrPanicRecovered)) - testtool.AssertTrue(t, strings.Contains(err.Error(), expPanic)) - testtool.AssertTrue(t, rollbackCalled) - testtool.AssertTrue(t, beginCalled) + assert.True(t, errors.Is(err, ErrRollbackSuccess)) + assert.True(t, errors.Is(err, ErrPanicRecovered)) + assert.True(t, strings.Contains(err.Error(), expPanic)) + assert.True(t, rollbackCalled) + assert.True(t, beginCalled) }) t.Run("failed_panic_rollback", func(t *testing.T) { const ( @@ -335,16 +335,16 @@ func Test_Transactor(t *testing.T) { //nolint: dupl ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, &c == tx) + assert.True(t, ok) + assert.True(t, &c == tx) panic(expPanicMsg) }) - testtool.AssertTrue(t, errors.Is(err, ErrRollbackFailed)) - testtool.AssertTrue(t, errors.Is(err, ErrPanicRecovered)) - testtool.AssertTrue(t, errors.Is(err, rollbackErr)) - testtool.AssertTrue(t, strings.Contains(err.Error(), expPanicMsg)) - testtool.AssertTrue(t, rollbackCalled) - testtool.AssertTrue(t, beginCalled) + assert.True(t, errors.Is(err, ErrRollbackFailed)) + assert.True(t, errors.Is(err, ErrPanicRecovered)) + assert.True(t, errors.Is(err, rollbackErr)) + assert.True(t, strings.Contains(err.Error(), expPanicMsg)) + assert.True(t, rollbackCalled) + assert.True(t, beginCalled) }) t.Run("failed_begin_tx", func(t *testing.T) { var ( @@ -363,12 +363,12 @@ func Test_Transactor(t *testing.T) { //nolint: dupl ) err := tr.WithinTx(ctx, func(ctx context.Context) error { _, ok := o.Extract(ctx) - testtool.AssertFalse(t, ok) + assert.False(t, ok) return nil }) - testtool.AssertTrue(t, errors.Is(err, ErrBeginTx)) - testtool.AssertTrue(t, errors.Is(err, expError)) - testtool.AssertTrue(t, beginCalled) + assert.True(t, errors.Is(err, ErrBeginTx)) + assert.True(t, errors.Is(err, expError)) + assert.True(t, beginCalled) }) t.Run("error_when_beginner_is_nil", func(t *testing.T) { var ( @@ -379,7 +379,7 @@ func Test_Transactor(t *testing.T) { //nolint: dupl err := tr.WithinTx(ctx, func(ctx context.Context) error { return nil }) - testtool.AssertTrue(t, errors.Is(err, ErrNilTxBeginner)) + assert.True(t, errors.Is(err, ErrNilTxBeginner)) }) t.Run("error_when_operator_is_nil", func(t *testing.T) { var ( @@ -394,7 +394,7 @@ func Test_Transactor(t *testing.T) { //nolint: dupl err := tr.WithinTx(ctx, func(ctx context.Context) error { return nil }) - testtool.AssertTrue(t, errors.Is(err, ErrNilTxOperator)) + assert.True(t, errors.Is(err, ErrNilTxOperator)) }) }) } @@ -431,11 +431,11 @@ func Test_Transactor_recursive_call(t *testing.T) { //nolint: dupl assertTopLvl = func(ctx context.Context) { // tool.Assert that rollback was called on the recursion "top" level. - testtool.AssertTrue(t, isLvlEqual(ctx, ctxValTopLvl)) + assert.True(t, isLvlEqual(ctx, ctxValTopLvl)) // tool.Assert that rollback call wasn't called on the "second" recursion level. - testtool.AssertFalse(t, isLvlEqual(ctx, ctxValSecondLvl)) + assert.False(t, isLvlEqual(ctx, ctxValSecondLvl)) // tool.Assert that rollback call wasn't called on the "third" recursion level. - testtool.AssertFalse(t, isLvlEqual(ctx, ctxValThirdLvl)) + assert.False(t, isLvlEqual(ctx, ctxValThirdLvl)) } ) @@ -493,22 +493,22 @@ func Test_Transactor_recursive_call(t *testing.T) { //nolint: dupl err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) // inject "second" level variable in context.Context. ctx = injectLvl(ctx, ctxValSecondLvl) return tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) return expError }) }) - testtool.AssertTrue(t, errors.Is(err, ErrRollbackSuccess)) - testtool.AssertTrue(t, errors.Is(err, expError)) - testtool.AssertTrue(t, rollbackCalled == 1) - testtool.AssertTrue(t, beginCalled == 1) + assert.True(t, errors.Is(err, ErrRollbackSuccess)) + assert.True(t, errors.Is(err, expError)) + assert.True(t, rollbackCalled == 1) + assert.True(t, beginCalled == 1) }) t.Run("success_and_commit_on_top_lvl_func", func(t *testing.T) { @@ -525,32 +525,32 @@ func Test_Transactor_recursive_call(t *testing.T) { //nolint: dupl err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) // inject "second" level variable in context.Context. ctx = injectLvl(ctx, ctxValSecondLvl) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) return nil }) - testtool.AssertNoError(t, err) + assert.NoError(t, err) err = tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) return nil }) - testtool.AssertTrue(t, err == nil) + assert.True(t, err == nil) return err }) - testtool.AssertTrue(t, beginCalled == 1) - testtool.AssertNoError(t, err) - testtool.AssertTrue(t, commitCalled == 1) + assert.True(t, beginCalled == 1) + assert.NoError(t, err) + assert.True(t, commitCalled == 1) }) t.Run("error_and_rollback_on_high_lvl_when_error_on_low_lvl_func", func(t *testing.T) { defer t.Cleanup(cleanup) @@ -568,35 +568,35 @@ func Test_Transactor_recursive_call(t *testing.T) { //nolint: dupl err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) // inject "second" level variable in context.Context. ctx = injectLvl(ctx, ctxValSecondLvl) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) // inject "third" level variable in context.Context. ctx = injectLvl(ctx, ctxValThirdLvl) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) return expError }) return err }) - testtool.AssertError(t, err) + assert.Error(t, err) return err }) - testtool.AssertTrue(t, errors.Is(err, ErrRollbackSuccess)) - testtool.AssertTrue(t, errors.Is(err, expError)) - testtool.AssertTrue(t, beginCalled == 1) - testtool.AssertTrue(t, commitCalled == 0) - testtool.AssertTrue(t, rollbackCalled == 1) + assert.True(t, errors.Is(err, ErrRollbackSuccess)) + assert.True(t, errors.Is(err, expError)) + assert.True(t, beginCalled == 1) + assert.True(t, commitCalled == 0) + assert.True(t, rollbackCalled == 1) }) t.Run("panic", func(t *testing.T) { const ( @@ -618,37 +618,37 @@ func Test_Transactor_recursive_call(t *testing.T) { //nolint: dupl err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) // inject "second" level variable in context.Context. ctx = injectLvl(ctx, ctxValSecondLvl) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) // inject "second" level variable in context.Context. ctx = injectLvl(ctx, ctxValThirdLvl) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) panic(lowLvlPanicMsg) }) - testtool.AssertError(t, err) + assert.Error(t, err) return err }) - testtool.AssertError(t, err) + assert.Error(t, err) return err }) - testtool.AssertTrue(t, errors.Is(err, ErrRollbackSuccess)) - testtool.AssertTrue(t, errors.Is(err, ErrPanicRecovered)) - testtool.AssertTrue(t, strings.Contains(err.Error(), lowLvlPanicMsg)) - testtool.AssertTrue(t, beginCalled == 1) - testtool.AssertTrue(t, commitCalled == 0) - testtool.AssertTrue(t, rollbackCalled == 1) + assert.True(t, errors.Is(err, ErrRollbackSuccess)) + assert.True(t, errors.Is(err, ErrPanicRecovered)) + assert.True(t, strings.Contains(err.Error(), lowLvlPanicMsg)) + assert.True(t, beginCalled == 1) + assert.True(t, commitCalled == 0) + assert.True(t, rollbackCalled == 1) }) t.Run("error_and_rollback_on_high_lvl_when_panic_on_middle_lvl_override_low_lvl", func(t *testing.T) { @@ -665,38 +665,38 @@ func Test_Transactor_recursive_call(t *testing.T) { //nolint: dupl err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) // inject "second" level variable in context.Context. ctx = injectLvl(ctx, ctxValSecondLvl) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) // inject "second" level variable in context.Context. ctx = injectLvl(ctx, ctxValThirdLvl) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) - testtool.AssertTrue(t, ok) - testtool.AssertTrue(t, c == tx) + assert.True(t, ok) + assert.True(t, c == tx) panic(lowLvlPanicMsg) }) - testtool.AssertError(t, err) + assert.Error(t, err) panic(middleLvlPanicMsg) }) - testtool.AssertTrue(t, err != nil) + assert.True(t, err != nil) return err }) - testtool.AssertTrue(t, errors.Is(err, ErrRollbackSuccess)) - testtool.AssertTrue(t, errors.Is(err, ErrPanicRecovered)) - testtool.AssertFalse(t, strings.Contains(err.Error(), lowLvlPanicMsg)) - testtool.AssertTrue(t, strings.Contains(err.Error(), middleLvlPanicMsg)) - testtool.AssertTrue(t, beginCalled == 1) - testtool.AssertTrue(t, commitCalled == 0) - testtool.AssertTrue(t, rollbackCalled == 1) + assert.True(t, errors.Is(err, ErrRollbackSuccess)) + assert.True(t, errors.Is(err, ErrPanicRecovered)) + assert.False(t, strings.Contains(err.Error(), lowLvlPanicMsg)) + assert.True(t, strings.Contains(err.Error(), middleLvlPanicMsg)) + assert.True(t, beginCalled == 1) + assert.True(t, commitCalled == 0) + assert.True(t, rollbackCalled == 1) }) }) } diff --git a/saga/actions.go b/saga/actions.go index b5db8d6..e5bf70c 100644 --- a/saga/actions.go +++ b/saga/actions.go @@ -7,7 +7,7 @@ import ( // ActionFunc represents a function that performs an action and can return an error. // It is designed to be used in workflows, sagas, or any operation that might need // additional behavior like panic recovery or retries. -type ActionFunc func(ctx context.Context) error +type ActionFunc func(ctx context.Context, track Track) error // NewAction creates a new ActionFunc. func NewAction(afn ActionFunc) ActionFunc { @@ -43,55 +43,59 @@ func (a ActionFunc) WithRetry(opt RetryPolicy) ActionFunc { // // Example: // -// action := someAction.WithBeforeHook(func(ctx context.Context) error { +// action := someAction.WithBeforeHook(func(ctx context.Context, _ Track) error { // log.Println("starting action") // return validateInput(ctx) // }) -func (a ActionFunc) WithBeforeHook(before func(ctx context.Context) error) ActionFunc { - return func(ctx context.Context) error { - err := before(ctx) +func (a ActionFunc) WithBeforeHook(before func(ctx context.Context, track Track) error) ActionFunc { + return func(ctx context.Context, track Track) error { + err := before(ctx, track) if err != nil { return err } - return a(ctx) + return a(ctx, track) } } // WithAfterHook adds an after-hook to the ActionFunc. -// The hook executes after the main action and receives both the context -// and the error returned by the action (nil if successful). -// The hook can inspect, log, or transform the error. +// The hook executes after the main action and receives the execution track, +// which contains information about the action's outcome (success/failure and errors). +// The hook can inspect, log, or modify the track state. // Returns a new ActionFunc with the after-hook applied. // // Use cases: // - Logging success/failure // - Metrics collection // - Resource cleanup -// - Error wrapping/enrichment +// - Error enrichment via track.AddError() or track.SetFailedOnError() // // Example: // -// action := someAction.WithAfterHook(func(ctx context.Context, err error) error { -// if err != nil { -// log.Printf("action failed: %v", err) -// return fmt.Errorf("operation failed: %w", err) +// action := someAction.WithAfterHook(func(ctx context.Context, track Track) error { +// data := track.GetData() +// if data.Action.Status == ExecutionStatusFail { +// log.Printf("action failed with errors: %v", data.Action.Errors) +// return nil // } // log.Println("action completed successfully") // return nil // }) -func (a ActionFunc) WithAfterHook(after func(ctx context.Context, aroseError error) error) ActionFunc { - return func(ctx context.Context) error { - err := a(ctx) - err = after(ctx, err) +func (a ActionFunc) WithAfterHook(after func(ctx context.Context, track Track) error) ActionFunc { + return func(ctx context.Context, track Track) error { + err := a(ctx, track) + if err != nil { + track.SetFailedOnError(err) + } + err = after(ctx, track) if err != nil { return err } - return err + return nil } } // WithWrapper adds a custom wrapper to the ActionFunc that can modify its behavior. -// The wrapper receives the context and the original ActionFunc, and returns an error. +// The wrapper receives the context, the track, and the original ActionFunc, and returns an error. // This provides maximum flexibility for cross-cutting concerns that don't fit into // the standard before/after hook pattern. // @@ -107,29 +111,18 @@ func (a ActionFunc) WithAfterHook(after func(ctx context.Context, aroseError err // The wrapper must call the provided ActionFunc to execute the original logic, // but can add behavior before, after, or around it. // -// Example with timing: +// Example: // -// action := someAction.WithWrapper(func(ctx context.Context, action ActionFunc) error { +// action := someAction.WithWrapper(func(ctx context.Context, track Track, action ActionFunc) error { // start := time.Now() -// err := action(ctx) +// err := action(ctx, track) // duration := time.Since(start) // metrics.RecordActionDuration(duration, err) // return err // }) -// -// Example with circuit breaker: -// -// action := someAction.WithWrapper(func(ctx context.Context, action ActionFunc) error { -// if circuit.IsOpen() { -// return ErrCircuitOpen -// } -// err := action(ctx) -// circuit.RecordResult(err) -// return err -// }) -func (a ActionFunc) WithWrapper(wrapper func(ctx context.Context, action ActionFunc) error) ActionFunc { - return func(ctx context.Context) error { - return wrapper(ctx, a) +func (a ActionFunc) WithWrapper(wrapper func(ctx context.Context, track Track, action ActionFunc) error) ActionFunc { + return func(ctx context.Context, track Track) error { + return wrapper(ctx, track, a) } } @@ -137,7 +130,7 @@ func (a ActionFunc) WithWrapper(wrapper func(ctx context.Context, action ActionF // when an error occurs in the main action. // It receives both the context and the error that triggered the compensation, // allowing it to make decisions based on the specific error that occurred. -type CompensationFunc func(ctx context.Context, actionErr error) error +type CompensationFunc func(ctx context.Context, track Track) error // NewCompensation creates a new CompensationFunc. func NewCompensation(afn CompensationFunc) CompensationFunc { @@ -150,11 +143,11 @@ func NewCompensation(afn CompensationFunc) CompensationFunc { // The original aroseErr is preserved and passed through to the compensation function. // Returns a new CompensationFunc with panic recovery enabled. func (c CompensationFunc) WithPanicRecovery() CompensationFunc { - return func(ctx context.Context, aroseErr error) error { - fn := func(ctx context.Context) error { - return c(ctx, aroseErr) + return func(ctx context.Context, track Track) error { + fn := func(ctx context.Context, track Track) error { + return c(ctx, track) } - return WithPanicRecovery(fn)(ctx) + return WithPanicRecovery(fn)(ctx, track) } } @@ -164,16 +157,17 @@ func (c CompensationFunc) WithPanicRecovery() CompensationFunc { // If all retry attempts fail, the behavior depends on RetryPolicy.ReturnAllAroseErr. // Returns a new CompensationFunc with retry logic enabled. func (c CompensationFunc) WithRetry(opt RetryPolicy) CompensationFunc { - return func(ctx context.Context, actionError error) error { - fn := func(ctx context.Context) error { - return c(ctx, actionError) + return func(ctx context.Context, track Track) error { + fn := func(ctx context.Context, track Track) error { + return c(ctx, track) } - return WithRetry(opt, fn)(ctx) + return WithRetry(opt, fn)(ctx, track) } } // WithBeforeHook adds a before-hook to the CompensationFunc. -// The hook executes before the compensation and receives the original action error. +// The hook executes before the compensation and receives the track, +// which contains information about the original action's failure. // If the hook returns an error, the compensation is skipped and the error is returned. // Returns a new CompensationFunc with the before-hook applied. // @@ -181,29 +175,32 @@ func (c CompensationFunc) WithRetry(opt RetryPolicy) CompensationFunc { // - Check if compensation is needed based on error type // - Logging compensation attempts // - Pre-compensation validation +// - Conditional compensation based on step state // // Example: // -// compensation := someCompensation.WithBeforeHook(func(ctx context.Context, actionErr error) error { -// if errors.Is(actionErr, ErrTemporaryFailure) { -// return nil // Compensate for temporary failures +// compensation := someCompensation.WithBeforeHook(func(ctx context.Context, track Track) error { +// data := track.GetData() +// if data.Action.Status == ExecutionStatusFail { +// // Only compensate if the action actually failed +// return nil // } -// return ErrSkipCompensation // Skip compensation for permanent failures +// return ErrSkipCompensation // }) -func (c CompensationFunc) WithBeforeHook(before func(ctx context.Context, actionErr error) error) CompensationFunc { - return func(ctx context.Context, actionErr error) error { - err := before(ctx, actionErr) +func (c CompensationFunc) WithBeforeHook(before func(ctx context.Context, track Track) error) CompensationFunc { + return func(ctx context.Context, track Track) error { + err := before(ctx, track) if err != nil { return err } - return c(ctx, actionErr) + return c(ctx, track) } } // WithAfterHook adds an after-hook to the CompensationFunc. -// The hook executes after the compensation and receives both the original action error -// and the error returned by the compensation (nil if successful). -// The hook can inspect, log, or transform the compensation error. +// The hook executes after the compensation and receives the track, +// which contains information about both the original action and the compensation outcome. +// The hook can inspect, log, or modify the compensation error via track. // Returns a new CompensationFunc with the after-hook applied. // // Use cases: @@ -214,28 +211,33 @@ func (c CompensationFunc) WithBeforeHook(before func(ctx context.Context, action // // Example: // -// compensation := someCompensation.WithAfterHook(func(ctx context.Context, actionErr, compensationErr error) error { -// if compensationErr != nil { -// log.Printf("CRITICAL: Compensation failed for action error %v: %v", actionErr, compensationErr) -// monitoring.Alert(ctx, "compensation_failed", compensationErr) -// return fmt.Errorf("compensation error (original: %w): %w", actionErr, compensationErr) +// compensation := someCompensation.WithAfterHook(func(ctx context.Context, track Track) error { +// data := track.GetData() +// if data.Compensation.Status == ExecutionStatusFail { +// log.Printf("CRITICAL: Compensation failed: %v", data.Compensation.Errors) +// monitoring.Alert(ctx, "compensation_failed", data.Compensation.Errors) +// return fmt.Errorf("compensation failed: %w", data.Compensation.Errors[0]) // } -// log.Printf("Compensation successful for action error: %v", actionErr) +// log.Printf("Compensation successful for action: %s", data.StepName) // return nil // }) -func (c CompensationFunc) WithAfterHook(after func(ctx context.Context, actionErr, aroseErr error) error) CompensationFunc { - return func(ctx context.Context, actionErr error) error { - aroseErr := c(ctx, actionErr) - err := after(ctx, actionErr, aroseErr) +func (c CompensationFunc) WithAfterHook(after func(ctx context.Context, track Track) error) CompensationFunc { + return func(ctx context.Context, track Track) error { + err := c(ctx, track) + if err != nil { + track.SetFailedOnError(err) + } + err = after(ctx, track) if err != nil { return err } - return err + + return nil } } // WithWrapper adds a custom wrapper to the CompensationFunc that can modify its behavior. -// The wrapper receives the context, the original action error, and the CompensationFunc, +// The wrapper receives the context, the track, and the CompensationFunc, // and returns an error. This provides maximum flexibility for cross-cutting concerns // specific to compensation logic. // @@ -244,24 +246,24 @@ func (c CompensationFunc) WithAfterHook(after func(ctx context.Context, actionEr // - Circuit breaking specifically for compensations // - Rate limiting compensation calls // - Distributed tracing with error context -// - Conditional compensation based on error type +// - Conditional compensation based on step state // - Compensation attempt counting and alerting // - Dead letter queue integration for failed compensations // -// The wrapper must call the provided CompensationFunc with the appropriate parameters -// to execute the original compensation logic. +// The wrapper must call the provided CompensationFunc to execute the original compensation logic. // // Example with error classification: // -// compensation := someCompensation.WithWrapper(func(ctx context.Context, actionErr error, comp CompensationFunc) error { -// if errors.Is(actionErr, ErrNonCritical) { -// log.Printf("Skipping compensation for non-critical error: %v", actionErr) +// compensation := someCompensation.WithWrapper(func(ctx context.Context, track Track, comp CompensationFunc) error { +// data := track.GetData() +// if errors.Is(data.Action.Errors[0], ErrNonCritical) { +// log.Printf("Skipping compensation for non-critical error") // return nil // } -// return comp(ctx, actionErr) +// return comp(ctx, track) // }) -func (c CompensationFunc) WithWrapper(wrapper func(ctx context.Context, actionErr error, comp CompensationFunc) error) CompensationFunc { - return func(ctx context.Context, actionErr error) error { - return wrapper(ctx, actionErr, c) +func (c CompensationFunc) WithWrapper(wrapper func(ctx context.Context, track Track, comp CompensationFunc) error) CompensationFunc { + return func(ctx context.Context, track Track) error { + return wrapper(ctx, track, c) } } diff --git a/saga/jitter_test.go b/saga/jitter_test.go index cafef6c..2820e9d 100644 --- a/saga/jitter_test.go +++ b/saga/jitter_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/kozmod/oniontx/internal/testtool" + "github.com/kozmod/oniontx/internal/testtool/assert" ) func Test_retry_jitter(t *testing.T) { @@ -15,7 +15,7 @@ func Test_retry_jitter(t *testing.T) { ) delay := jitter.Jitter(base) - testtool.AssertTrue(t, delay < base) + assert.True(t, delay < base) }) t.Run("full_jitter", func(t *testing.T) { var ( @@ -24,6 +24,22 @@ func Test_retry_jitter(t *testing.T) { ) delay := jitter.Jitter(base) - testtool.AssertTrue(t, delay < base) + assert.True(t, delay < base) + }) +} + +func Test_jitter(t *testing.T) { + t.Run("full_jitter_v1", func(t *testing.T) { + var ( + baseTime = 10 * time.Nanosecond + fullJitter = NewFullJitter() + ) + jitter := fullJitter.Jitter(baseTime) + if jitter > baseTime { + t.Fatalf("jitter is greater than base time[jitter: %v, base_time: %v]", jitter, baseTime) + } + if jitter < 0 { + t.Fatalf("jitter is less than zero[jitter: %v]", jitter) + } }) } diff --git a/saga/recovery.go b/saga/recovery.go index 057c74a..e68f563 100644 --- a/saga/recovery.go +++ b/saga/recovery.go @@ -22,14 +22,14 @@ import ( // err := safeFn(ctx) // // err will wrap "panic [something went wrong]" and ErrPanicRecovered -func WithPanicRecovery(fn func(ctx context.Context) error) func(ctx context.Context) error { - return func(ctx context.Context) (err error) { +func WithPanicRecovery(fn func(ctx context.Context, track Track) error) func(context.Context, Track) error { + return func(ctx context.Context, track Track) (err error) { defer func() { if p := recover(); p != nil { err = errors.Join(ErrPanicRecovered, tool.WrapPanic(p)) } }() - err = fn(ctx) + err = fn(ctx, track) return err } } diff --git a/saga/result.go b/saga/result.go new file mode 100644 index 0000000..331393b --- /dev/null +++ b/saga/result.go @@ -0,0 +1,145 @@ +package saga + +import ( + "errors" + "fmt" + "strings" +) + +// StageStatus represents the overall outcome of a saga execution. +type StageStatus string + +const ( + // StageResultUnknown indicates the saga result cannot be determined. + StageResultUnknown StageStatus = "Unknown" + // StageResultFail indicates the saga failed and no compensation was applied + // (or compensation also failed). + StageResultFail StageStatus = "Fail" + // StageResultSuccess indicates all actions completed successfully + StageResultSuccess StageStatus = "Success" + // StageResultCompensated indicates some actions failed and successful + // compensations were applied. + StageResultCompensated StageStatus = "Compensated" +) + +// Result contains the complete execution report of a saga. +type Result struct { + Steps []StepData + Status StageStatus +} + +// String returns a human-readable representation of the Result. +func (r Result) String() string { + var builder strings.Builder + + builder.WriteString(fmt.Sprintf("Status: %s\n", r.Status)) + builder.WriteString(fmt.Sprintf("Steps(%d):\n", len(r.Steps))) + + for i, track := range r.Steps { + builder.WriteString(fmt.Sprintf(" [%d] %s\n", i+1, track.String())) + } + + return builder.String() +} + +// prepareResult analyzes execution tracks and produces a final Result. +// It evaluates the state of all steps and determines the overall saga outcome +// based on action failures, compensation requirements, and compensation outcomes. +// +// The function implements the following logic: +// - If no actions failed -> StageResultSuccess +// - If any compensation that was required to run failed -> StageResultFail +// - If there were failed actions requiring compensation but all compensations +// succeeded -> StageResultCompensated +// - Special case: when no compensations were required, no successful steps, +// and no successful compensations -> StageResultFail +// +// Returns: +// - Result: aggregated execution data for all steps +// - error: descriptive error with categorized lists of failed/compensated steps +func prepareResult(tracks []*executionTrack) (Result, error) { + var ( + result = Result{ + Steps: make([]StepData, 0, len(tracks)), + Status: StageResultUnknown, + } + failed = make([]string, 0, len(tracks)) + compensated = make([]string, 0, len(tracks)) + compensationNotRequired = make([]string, 0, len(tracks)) + failedWithCompensationReq = make([]string, 0, len(tracks)) + failedWithCompensationReqFailed = make([]string, 0, len(tracks)) + hasSuccessfulStep = false + + prepareStateStrFn = func(position uint32, name string) string { + return fmt.Sprintf("%d#%s", position, name) + } + + resultErrorFn = func(err error) error { + const comma = ", " + return fmt.Errorf( + "state failed - failed [%s], compensated [%s], compensation not required [%s], failed requiring compensation [%s]: %w", + strings.Join(failed, comma), + strings.Join(compensated, comma), + strings.Join(compensationNotRequired, comma), + strings.Join(failedWithCompensationReq, comma), + err, + ) + } + ) + + for _, tr := range tracks { + data := tr.GetData() + result.Steps = append(result.Steps, data) + + stepID := prepareStateStrFn(data.StepPosition, data.StepName) + + if data.Action.Status == ExecutionStatusSuccess { + hasSuccessfulStep = true + } + + if data.Action.Status == ExecutionStatusFail { + failed = append(failed, stepID) + + if data.CompensationRequired { + failedWithCompensationReq = append(failedWithCompensationReq, stepID) + if data.Compensation.Status == ExecutionStatusSuccess { + compensated = append(compensated, stepID) + } else { + failedWithCompensationReqFailed = append(failedWithCompensationReqFailed, stepID) + } + } + continue + } + + if data.Action.Status == ExecutionStatusSuccess { + switch data.Compensation.Status { + case ExecutionStatusSuccess: + compensated = append(compensated, stepID) + case ExecutionStatusUnset: + if !data.CompensationRequired { + compensationNotRequired = append(compensationNotRequired, stepID) + } + } + } + } + + switch { + case len(failed) == 0: + result.Status = StageResultSuccess + return result, nil + + case len(failedWithCompensationReqFailed) > 0: + result.Status = StageResultFail + return result, resultErrorFn(errors.Join(ErrActionFailed, ErrCompensationFailed)) + + case len(failedWithCompensationReq) == 0 && !hasSuccessfulStep && len(compensated) == 0: + // Edge case: no required compensations, no successful steps, no successful compensations + // This indicates a failure scenario where no meaningful recovery occurred + result.Status = StageResultFail + return result, resultErrorFn(errors.Join(ErrActionFailed, ErrCompensationFailed)) + + default: + result.Status = StageResultCompensated + return result, resultErrorFn(ErrActionFailed) + } +} diff --git a/saga/retry.go b/saga/retry.go index de7b4fa..58fc45f 100644 --- a/saga/retry.go +++ b/saga/retry.go @@ -19,10 +19,6 @@ type ( RetryPolicy interface { Attempts() uint32 Delay(attempt uint32) time.Duration - - // ReturnAllAroseErr indicates whether to return all collected errors - // from failed attempts (true) or just the last error (false). - ReturnAllAroseErr() bool } // Backoff defines the interface for backoff strategy calculation. @@ -38,10 +34,9 @@ type ( // baseRetryPolicy provides common fields and basic implementation for retry options. type baseRetryPolicy struct { - attempts uint32 - delay time.Duration - maxDelay time.Duration - returnAllAroseErr bool + attempts uint32 + delay time.Duration + maxDelay time.Duration } // Attempts returns the configured maximum number of retry attempts. @@ -49,11 +44,6 @@ func (o baseRetryPolicy) Attempts() uint32 { return o.attempts } -// ReturnAllAroseErr returns the configured error aggregation behavior. -func (o baseRetryPolicy) ReturnAllAroseErr() bool { - return o.returnAllAroseErr -} - // Delay returns a constant delay duration regardless of attempt number. func (o baseRetryPolicy) Delay(_ uint32) time.Duration { return o.delay @@ -76,12 +66,6 @@ func NewBaseRetryOpt(attempts uint32, delay time.Duration) *BaseRetryPolicy { } } -// WithReturnAllAroseErr enables returning all errors from failed attempts. -func (o BaseRetryPolicy) WithReturnAllAroseErr() BaseRetryPolicy { - o.returnAllAroseErr = true - return o -} - // AdvanceRetryPolicy provides configurable retry behavior with pluggable // backoff and jitter strategies. This allows for flexible composition of // different retry algorithms. @@ -104,12 +88,6 @@ func NewAdvanceRetryPolicy(attempts uint32, delay time.Duration, backoff Backoff } } -// WithReturnAllAroseErr enables returning all errors from failed attempts. -func (o AdvanceRetryPolicy) WithReturnAllAroseErr() AdvanceRetryPolicy { - o.baseRetryPolicy.returnAllAroseErr = true - return o -} - // WithJitter adds jitter to the retry policy. func (o AdvanceRetryPolicy) WithJitter(jitter Jitter) AdvanceRetryPolicy { o.jitter = jitter @@ -127,11 +105,6 @@ func (o AdvanceRetryPolicy) Attempts() uint32 { return o.attempts } -// ReturnAllAroseErr returns the configured error aggregation behavior. -func (o AdvanceRetryPolicy) ReturnAllAroseErr() bool { - return o.returnAllAroseErr -} - // Delay returns a constant delay duration regardless of attempt number. func (o AdvanceRetryPolicy) Delay(i uint32) time.Duration { var ( @@ -159,51 +132,55 @@ func (o AdvanceRetryPolicy) Delay(i uint32) time.Duration { // - Before each attempt, it waits for the delay provided by opt.Delay(attempt) // - The function stops retrying on first successful execution // - Context cancellation is respected between attempts (checked before each attempt) -// - All errors from failed attempts are collected +// - All Errors from failed attempts are collected // // If all attempts fail, behavior depends on ReturnAllAroseErr(): -// - if true - returns all errors via errors.Join(...) +// - if true - returns all Errors via Errors.Join(...) // - if false - returns only the last error that occurred -func WithRetry(opt RetryPolicy, fn func(ctx context.Context) error) func(ctx context.Context) error { - return func(ctx context.Context) error { +func WithRetry(opt RetryPolicy, fn func(ctx context.Context, track Track) error) func(context.Context, Track) error { + return func(ctx context.Context, track Track) error { // first call var ( - attempts = opt.Attempts() - retryErrs []error + attempts = opt.Attempts() ) - err := fn(ctx) + err := fn(ctx, track) switch { case err == nil: + track.SetStatus(ExecutionStatusSuccess) return nil case attempts == 0: return err case err != nil: - err = fmt.Errorf("action error: %w", err) - retryErrs = append(retryErrs, err) + track.SetFailedOnError(err) + } // retries stop: for i := uint32(0); i < attempts; i++ { + track.call() + track.setParentError(fmt.Errorf("retry [%d]", i)) select { case <-ctx.Done(): - err = errors.Join(ErrRetryContextDone, ctx.Err()) - retryErrs = append(retryErrs, fmt.Errorf("retry [%d]: %w", i, err)) + track.SetFailedOnError( + errors.Join(ErrRetryContextDone, ctx.Err()), + ) break stop default: - err = fn(ctx) + err = fn(ctx, track) if err == nil { + track.SetStatus(ExecutionStatusSuccess) break stop } - retryErrs = append(retryErrs, fmt.Errorf("retry [%d]: %w", i, err)) - time.Sleep(opt.Delay(i)) + track.SetFailedOnError(err) + if i < attempts-1 { + time.Sleep(opt.Delay(i)) + } } } - if opt.ReturnAllAroseErr() { - return errors.Join(retryErrs...) - } - return err + track.setParentError(nil) + return nil } } diff --git a/saga/retry_test.go b/saga/retry_test.go index 6c7f80f..8bceb87 100644 --- a/saga/retry_test.go +++ b/saga/retry_test.go @@ -2,12 +2,11 @@ package saga import ( "context" - "errors" - "fmt" "testing" "time" "github.com/kozmod/oniontx/internal/testtool" + "github.com/kozmod/oniontx/internal/testtool/assert" ) func Test_backoff(t *testing.T) { @@ -23,22 +22,6 @@ func Test_backoff(t *testing.T) { }) } -func Test_jitter(t *testing.T) { - t.Run("full_jitter_v1", func(t *testing.T) { - var ( - baseTime = 10 * time.Nanosecond - fullJitter = NewFullJitter() - ) - jitter := fullJitter.Jitter(baseTime) - if jitter > baseTime { - t.Fatalf("jitter is greater than base time[jitter: %v, base_time: %v]", jitter, baseTime) - } - if jitter < 0 { - t.Fatalf("jitter is less than zero[jitter: %v]", jitter) - } - }) -} - func Test_Saga_retry(t *testing.T) { var ( ctx = context.Background() @@ -54,121 +37,94 @@ func Test_Saga_retry(t *testing.T) { Name: "step0", Action: WithRetry( NewBaseRetryOpt(3, time.Nanosecond), - func(ctx context.Context) error { + func(ctx context.Context, _ Track) error { actionCalls++ errCounter++ if errCounter < 3 { - return testtool.ErrExpTest + return testtool.ErrExpTestA } return nil }), }, } - err := NewSaga(steps).Execute(ctx) - testtool.AssertNoError(t, err) - testtool.AssertTrue(t, actionCalls == 3) - }) - t.Run("success_with_return_all_errors", func(t *testing.T) { - var ( - errCounter = 0 - actionCalls = 0 - - secondExpErr = fmt.Errorf("some_err_2") - ) - steps := []Step{ - NewStep("step0"). - WithAction( - WithRetry( - NewBaseRetryOpt(5, time.Nanosecond). - WithReturnAllAroseErr(), - func(ctx context.Context) error { - actionCalls++ - errCounter++ - switch actionCalls { - case 1: - return testtool.ErrExpTest - case 2: - return secondExpErr - case 3: - return nil - default: - t.Fatalf("should not happen") - return nil - } - }), - ), + resp, err := NewSaga(steps).Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, StageResultSuccess, resp.Status) + assert.Equal(t, 3, actionCalls) + assert.Equal(t, ExecutionStatusSuccess, resp.Steps[0].Action.Status) + assert.Equal(t, 3, resp.Steps[0].Action.Calls) + assert.Equal(t, 2, len(resp.Steps[0].Action.Errors)) + for _, e := range resp.Steps[0].Action.Errors { + assert.ErrorIs(t, e, testtool.ErrExpTestA) } - - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, actionCalls == 3) - testtool.AssertTrue(t, errors.Is(err, testtool.ErrExpTest)) - testtool.AssertTrue(t, errors.Is(err, secondExpErr)) - testtool.AssertTrue(t, errors.Is(err, ErrActionFailed)) - - testtool.LogError(t, err) }) - }) - t.Run("builders", func(t *testing.T) { - t.Run("success_ActionFunc", func(t *testing.T) { - var ( - errCounter = 0 - actionCalls = 0 - ) - steps := []Step{ - { - Name: "step0", - Action: ActionFunc(func(ctx context.Context) error { - actionCalls++ - errCounter++ - if errCounter < 3 { - return testtool.ErrExpTest - } - return nil - }).WithRetry( - NewBaseRetryOpt(3, time.Nanosecond), - ), - }, - } + t.Run("builders", func(t *testing.T) { + t.Run("success_ActionFunc", func(t *testing.T) { + var ( + errCounter = 0 + actionCalls = 0 + ) + steps := []Step{ + { + Name: "step0", + Action: ActionFunc(func(ctx context.Context, _ Track) error { + actionCalls++ + errCounter++ + if errCounter < 3 { + return testtool.ErrExpTestA + } + return nil + }).WithRetry( + NewBaseRetryOpt(3, time.Nanosecond), + ), + }, + } - err := NewSaga(steps).Execute(ctx) - testtool.AssertNoError(t, err) - testtool.AssertTrue(t, actionCalls == 3) - }) - t.Run("success_CompensationFunc", func(t *testing.T) { - var ( - errCounter = 0 - actionCalls = 0 - compensationCalls = 0 - ) - steps := []Step{ - { - Name: "step0", - Action: ActionFunc(func(ctx context.Context) error { - actionCalls++ - return testtool.ErrExpTest - }), - Compensation: CompensationFunc(func(ctx context.Context, aroseErr error) error { - compensationCalls++ - errCounter++ - if errCounter < 3 { - return testtool.ErrExpTest - } - return nil - }).WithRetry( - NewBaseRetryOpt(3, time.Nanosecond), - ), - CompensationOnFail: true, - }, - } + resp, err := NewSaga(steps).Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, StageResultSuccess, resp.Status) + assert.Equal(t, 3, actionCalls) + assert.Equal(t, ExecutionStatusSuccess, resp.Steps[0].Action.Status) + assert.Equal(t, 3, resp.Steps[0].Action.Calls) + assert.Equal(t, 2, len(resp.Steps[0].Action.Errors)) + for _, e := range resp.Steps[0].Action.Errors { + assert.ErrorIs(t, e, testtool.ErrExpTestA) + } + }) + t.Run("success_CompensationFunc", func(t *testing.T) { + var ( + errCounter = 0 + actionCalls = 0 + compensationCalls = 0 + ) + steps := []Step{ + { + Name: "step0", + Action: ActionFunc(func(ctx context.Context, track Track) error { + actionCalls++ + return testtool.ErrExpTestA + }), + Compensation: CompensationFunc(func(ctx context.Context, track Track) error { + compensationCalls++ + errCounter++ + if errCounter < 3 { + return testtool.ErrExpTestA + } + return nil + }).WithRetry( + NewBaseRetryOpt(3, time.Nanosecond), + ), + CompensationRequired: true, + }, + } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, ErrActionFailed)) - testtool.AssertTrue(t, errors.Is(err, ErrCompensationSuccess)) - testtool.AssertTrue(t, actionCalls == 1) - testtool.AssertTrue(t, compensationCalls == 3) + resp, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultCompensated, resp.Status) + assert.Equal(t, 1, actionCalls) + assert.Equal(t, 3, compensationCalls) + }) }) }) } diff --git a/saga/saga.go b/saga/saga.go index 06f22a3..d5abb15 100644 --- a/saga/saga.go +++ b/saga/saga.go @@ -17,13 +17,6 @@ var ( // completed action, and the compensation logic itself encounters an error. ErrCompensationFailed = fmt.Errorf("compensation failed") - // ErrCompensationSuccess indicates that a compensation was executed successfully. - // This error can be used to signal that compensation logic has been applied, - // which might be useful for logging or monitoring purposes. - // Note: Despite being an error type, it represents a successful compensation - // execution, not a failure. - ErrCompensationSuccess = fmt.Errorf("compensation executed") - // ErrPanicRecovered is returned when a panic is recovered and converted to an error. // It wraps the original panic value to provide more context about what caused // the panic. This allows panics to be handled gracefully without crashing @@ -49,6 +42,15 @@ var ( ErrRetryContextDone = fmt.Errorf("retry context done") ) +type Track interface { + call() + setParentError(error) + + SetStatus(ExecutionStatus) + SetFailedOnError(error) + GetData() StepData +} + // Saga coordinates a distributed transaction using the `Saga` pattern. type Saga struct { steps []Step @@ -64,75 +66,100 @@ func NewSaga(steps []Step) *Saga { // Execute runs all Saga steps. // // If any step fails, compensating actions are triggered for all successfully completed steps. -func (s *Saga) Execute(ctx context.Context) error { - var completedSteps []Step +func (s *Saga) Execute(ctx context.Context) (Result, error) { + var ( + tracks []*executionTrack + completedTrack []*executionTrack + err error + ) +stop: for i, step := range s.steps { + var ( + tr = newExecutionTrack( + uint32(i), + step, + ) + ) + + tr.actionTrack() + tracks = append(tracks, tr) select { case <-ctx.Done(): - return errors.Join(ErrExecuteActionsContextDone, ctx.Err()) + tr.SetFailedOnError( + fmt.Errorf("action failed [%d#%s]: %w", i, tr.StepName, + errors.Join(ctx.Err(), ErrExecuteActionsContextDone), + ), + ) + break stop default: if step.Action == nil { + tr.action.Status = ExecutionStatusUnset continue } - if step.CompensationOnFail { - completedSteps = append(completedSteps, step) + if step.CompensationRequired { + completedTrack = append(completedTrack, tr) } - err := step.Action(ctx) - if err != nil { - err = fmt.Errorf("action failed [%d#%s]: %w", i, step.Name, errors.Join(ErrActionFailed, err)) + tr.call() + err = step.Action(ctx, tr) + + switch status := tr.getStatus(); { + case err == nil && status != ExecutionStatusFail: + tr.SetStatus(ExecutionStatusSuccess) + case err != nil || status == ExecutionStatusFail: + if err != nil { + err = errors.Join(err, ErrActionFailed) + tr.SetFailedOnError( + fmt.Errorf("action failed [%d#%s]: %w", i, tr.StepName, err), + ) + } // Run compensation when error arise. - return s.compensate(ctx, completedSteps, err) + s.compensate(ctx, completedTrack) + break stop } - if !step.CompensationOnFail { - completedSteps = append(completedSteps, step) + if !step.CompensationRequired { + completedTrack = append(completedTrack, tr) } } } - return nil + result, err := prepareResult(tracks) + return result, err } // compensate triggers compensating actions for all steps in reverse order. -func (s *Saga) compensate(ctx context.Context, completedSteps []Step, originalErr error) error { - var ( - compensationErrors []error - compensationsExecuted int32 - ) - +func (s *Saga) compensate(ctx context.Context, tracks []*executionTrack) { stop: - for i, step := range completedSteps { + for i, tr := range tracks { + if tr.compensationFn == nil { + continue + } + tr.compensationTrack() select { case <-ctx.Done(): - compensationErrors = append(compensationErrors, errors.Join(ErrExecuteCompensationContextDone, ctx.Err())) + tr.SetFailedOnError( + fmt.Errorf("compensation failed [%d#%s]: %w", i, tr.StepName, + errors.Join(ctx.Err(), ErrExecuteCompensationContextDone), + ), + ) break stop default: - if step.Compensation == nil { - continue + tr.call() + err := tr.compensationFn(ctx, tr) + switch { + case err == nil: + // Determine final status based on error count vs calls + if uint32(len(tr.current.Errors)) == tr.current.Calls { + tr.SetStatus(ExecutionStatusFail) + } else { + tr.SetStatus(ExecutionStatusSuccess) + } + case err != nil: + tr.SetFailedOnError(fmt.Errorf("compensation failed [%d#%s]: %w", i, tr.StepName, err)) } - - if err := step.Compensation(ctx, originalErr); err != nil { - compensationErrors = append( - compensationErrors, - fmt.Errorf("compensation failed [%d#%s]: %w", i, step.Name, err), - ) - } - compensationsExecuted++ } } - - var err error - if len(compensationErrors) > 0 { - err = errors.Join(errors.Join(compensationErrors...), originalErr) - } - - if err != nil { - return errors.Join(ErrCompensationFailed, err) - - } - - return errors.Join(ErrCompensationSuccess, originalErr) } diff --git a/saga/saga_test.go b/saga/saga_test.go index 60c2f2b..8f3ced3 100644 --- a/saga/saga_test.go +++ b/saga/saga_test.go @@ -5,13 +5,16 @@ import ( "errors" "fmt" "slices" + "strings" "testing" "time" "github.com/kozmod/oniontx/internal/testtool" + "github.com/kozmod/oniontx/internal/testtool/assert" ) // nolint: dupl + func TestSaga_Execute(t *testing.T) { var ( ctx = context.Background() @@ -26,11 +29,11 @@ func TestSaga_Execute(t *testing.T) { steps := []Step{ { Name: "step0", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ Track) error { executedActions = append(executedActions, "action1") return nil }, - Compensation: func(ctx context.Context, _ error) error { + Compensation: func(ctx context.Context, _ Track) error { executedCompensation = append(executedCompensation, "comp1") t.Fatalf("should not have been called") return nil @@ -38,11 +41,11 @@ func TestSaga_Execute(t *testing.T) { }, { Name: "step1", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ Track) error { executedActions = append(executedActions, "action2") return nil }, - Compensation: func(ctx context.Context, _ error) error { + Compensation: func(ctx context.Context, _ Track) error { executedCompensation = append(executedCompensation, "comp2") t.Fatalf("should not have been called") return nil @@ -50,10 +53,25 @@ func TestSaga_Execute(t *testing.T) { }, } - err := NewSaga(steps).Execute(ctx) - testtool.AssertNoError(t, err) - testtool.AssertTrue(t, slices.Equal([]string{"action1", "action2"}, executedActions)) - testtool.AssertTrue(t, len(executedCompensation) == 0) + res, err := NewSaga(steps).Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, StageResultSuccess, res.Status) + assert.Equal(t, 2, len(res.Steps)) + + assert.Equal(t, "step0", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Action.Status) + assert.Equal(t, 0, res.Steps[0].Compensation.Calls) + + assert.Equal(t, "step1", res.Steps[1].StepName) + assert.Equal(t, 1, res.Steps[1].StepPosition) + assert.Equal(t, 1, res.Steps[1].Action.Calls) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[1].Action.Status) + assert.Equal(t, 0, res.Steps[1].Compensation.Calls) + + assert.True(t, slices.Equal([]string{"action1", "action2"}, executedActions)) + assert.True(t, len(executedCompensation) == 0) }) t.Run("success_compensation_on_step1", func(t *testing.T) { @@ -65,22 +83,22 @@ func TestSaga_Execute(t *testing.T) { steps := []Step{ { Name: "step0", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ Track) error { executedActions = append(executedActions, "action1") return nil }, - Compensation: func(ctx context.Context, _ error) error { + Compensation: func(ctx context.Context, _ Track) error { executedCompensation = append(executedCompensation, "comp1") return nil }, }, { Name: "step1", - Action: NewAction(func(ctx context.Context) error { + Action: NewAction(func(ctx context.Context, _ Track) error { executedActions = append(executedActions, "action2") - return testtool.ErrExpTest + return testtool.ErrExpTestA }), - Compensation: NewCompensation(func(ctx context.Context, aroseErr error) error { + Compensation: NewCompensation(func(ctx context.Context, _ Track) error { executedCompensation = append(executedCompensation, "comp2") t.Fatalf("should not have been called") return nil @@ -88,11 +106,24 @@ func TestSaga_Execute(t *testing.T) { }, } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, testtool.ErrExpTest)) - testtool.AssertTrue(t, slices.Equal([]string{"action1", "action2"}, executedActions)) - testtool.AssertTrue(t, slices.Equal([]string{"comp1"}, executedCompensation)) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultCompensated, res.Status) + assert.Equal(t, 2, len(res.Steps)) + + assert.Equal(t, "step0", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 1, res.Steps[0].Compensation.Calls) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Compensation.Status) + + assert.Equal(t, "step1", res.Steps[1].StepName) + assert.Equal(t, 1, res.Steps[1].StepPosition) + assert.Equal(t, 1, res.Steps[1].Action.Calls) + assert.Equal(t, 0, res.Steps[1].Compensation.Calls) + + assert.True(t, slices.Equal([]string{"action1", "action2"}, executedActions)) + assert.True(t, slices.Equal([]string{"comp1"}, executedCompensation)) }) t.Run("compensation_on_fail", func(t *testing.T) { @@ -105,11 +136,11 @@ func TestSaga_Execute(t *testing.T) { steps := []Step{ { Name: "step0", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ Track) error { executedActions = append(executedActions, "action1") - return testtool.ErrExpTest + return testtool.ErrExpTestA }, - Compensation: func(ctx context.Context, aroseErr error) error { + Compensation: func(ctx context.Context, _ Track) error { executedCompensation = append(executedCompensation, "comp1") t.Fatalf("should not have been called") return nil @@ -117,11 +148,22 @@ func TestSaga_Execute(t *testing.T) { }, } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, testtool.ErrExpTest)) - testtool.AssertTrue(t, slices.Equal([]string{"action1"}, executedActions)) - testtool.AssertTrue(t, len(executedCompensation) == 0) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultFail, res.Status) + assert.Equal(t, 1, len(res.Steps)) + + assert.Equal(t, "step0", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + assert.ErrorIs(t, res.Steps[0].Action.Errors[0], testtool.ErrExpTestA) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 0, res.Steps[0].Compensation.Calls) + assert.Equal(t, ExecutionStatusUncalled, res.Steps[0].Compensation.Status) + + assert.True(t, slices.Equal([]string{"action1"}, executedActions)) + assert.True(t, len(executedCompensation) == 0) }) t.Run("added", func(t *testing.T) { var ( @@ -132,23 +174,34 @@ func TestSaga_Execute(t *testing.T) { steps := []Step{ { Name: "step0", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ Track) error { executedActions = append(executedActions, "action1") - return testtool.ErrExpTest + return testtool.ErrExpTestA }, - Compensation: func(ctx context.Context, aroseErr error) error { + Compensation: func(ctx context.Context, _ Track) error { executedCompensation = append(executedCompensation, "comp1") return nil }, - CompensationOnFail: true, + CompensationRequired: true, }, } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, testtool.ErrExpTest)) - testtool.AssertTrue(t, slices.Equal([]string{"action1"}, executedActions)) - testtool.AssertTrue(t, slices.Equal([]string{"comp1"}, executedCompensation)) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultCompensated, res.Status) + assert.Equal(t, 1, len(res.Steps)) + + assert.Equal(t, "step0", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + assert.ErrorIs(t, res.Steps[0].Action.Errors[0], testtool.ErrExpTestA) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 1, res.Steps[0].Compensation.Calls) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Compensation.Status) + + assert.True(t, slices.Equal([]string{"action1"}, executedActions)) + assert.True(t, slices.Equal([]string{"comp1"}, executedCompensation)) }) }) } @@ -162,59 +215,80 @@ func Test_Saga_panic_recovery(t *testing.T) { steps := []Step{ { Name: "step0", - Action: WithPanicRecovery(func(ctx context.Context) error { + Action: WithPanicRecovery(func(ctx context.Context, _ Track) error { panic("panic_v1!") }), }, } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, ErrPanicRecovered)) - testtool.AssertTrue(t, errors.Is(err, ErrActionFailed)) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultFail, res.Status) + assert.Equal(t, 1, len(res.Steps)) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + assert.ErrorIs(t, res.Steps[0].Action.Errors[0], ErrPanicRecovered) + assert.Equal(t, ExecutionStatusUnset, res.Steps[0].Compensation.Status) + assert.Equal(t, 0, res.Steps[0].Compensation.Calls) + assert.Equal(t, 0, len(res.Steps[0].Compensation.Errors)) - testtool.LogError(t, err) }) }) - t.Run("builders", func(t *testing.T) { + t.Run("builder_stile", func(t *testing.T) { t.Run("success_ActionFunc", func(t *testing.T) { steps := []Step{ - { - Name: "step0", - Action: ActionFunc(func(ctx context.Context) error { - panic("panic_v2!") - }).WithPanicRecovery(), - }, + NewStep("step0"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + panic("panic_v2!") + }).WithPanicRecovery(), + ), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, ErrPanicRecovered)) - testtool.AssertTrue(t, errors.Is(err, ErrActionFailed)) - - testtool.LogError(t, err) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultFail, res.Status) + assert.Equal(t, 1, len(res.Steps)) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + assert.ErrorIs(t, res.Steps[0].Action.Errors[0], ErrPanicRecovered) + assert.Equal(t, ExecutionStatusUnset, res.Steps[0].Compensation.Status) + assert.Equal(t, 0, res.Steps[0].Compensation.Calls) + assert.Equal(t, 0, len(res.Steps[0].Compensation.Errors)) }) t.Run("success_CompensationFunc", func(t *testing.T) { steps := []Step{ { Name: "step0", - Action: ActionFunc(func(ctx context.Context) error { - return testtool.ErrExpTest + Action: ActionFunc(func(ctx context.Context, _ Track) error { + return testtool.ErrExpTestA }), - Compensation: CompensationFunc(func(ctx context.Context, aroseErr error) error { + Compensation: CompensationFunc(func(ctx context.Context, track Track) error { + str := track.GetData() + assert.Equal(t, 1, len(str.Action.Errors)) + assert.Equal(t, 1, str.Action.Calls) + assert.Equal(t, ExecutionStatusFail, str.Action.Status) + assert.Error(t, str.Action.Errors[0]) + assert.ErrorIs(t, str.Action.Errors[0], testtool.ErrExpTestA) + panic("panic_v3!") }).WithPanicRecovery(), - CompensationOnFail: true, + CompensationRequired: true, }, } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, ErrPanicRecovered)) - testtool.AssertTrue(t, errors.Is(err, ErrCompensationFailed)) - - testtool.LogError(t, err) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultFail, res.Status) + assert.Equal(t, 1, len(res.Steps)) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + assert.ErrorIs(t, res.Steps[0].Action.Errors[0], testtool.ErrExpTestA) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Compensation.Status) + assert.Equal(t, 1, res.Steps[0].Compensation.Calls) + assert.Equal(t, 1, len(res.Steps[0].Compensation.Errors)) + assert.ErrorIs(t, res.Steps[0].Compensation.Errors[0], ErrPanicRecovered) }) }) } @@ -225,141 +299,172 @@ func Test_actions_v2(t *testing.T) { ) t.Run("success_actions", func(t *testing.T) { - var ( - executedActions []string - executedCompensation []string - ) - steps := []Step{ NewStep("step0"). - WithAction(func(ctx context.Context) error { - executedActions = append(executedActions, "action1") + WithAction(func(ctx context.Context, _ Track) error { return nil }). - WithCompensation(func(ctx context.Context, aroseErr error) error { - executedCompensation = append(executedCompensation, "comp1") + WithCompensation(func(ctx context.Context, _ Track) error { t.Fatalf("should not have been called") return nil }), NewStep("step1"). - WithAction(func(ctx context.Context) error { - executedActions = append(executedActions, "action2") + WithAction(func(ctx context.Context, _ Track) error { return nil }). - WithCompensation(func(ctx context.Context, aroseErr error) error { - executedCompensation = append(executedCompensation, "comp2") + WithCompensation(func(ctx context.Context, _ Track) error { t.Fatalf("should not have been called") return nil }), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertNoError(t, err) - testtool.AssertTrue(t, slices.Equal([]string{"action1", "action2"}, executedActions)) - testtool.AssertTrue(t, len(executedCompensation) == 0) + res, err := NewSaga(steps).Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, StageResultSuccess, res.Status) + assert.Equal(t, 2, len(res.Steps)) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Action.Status) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 0, len(res.Steps[0].Action.Errors)) + assert.Equal(t, ExecutionStatusUncalled, res.Steps[0].Compensation.Status) + assert.Equal(t, 0, res.Steps[0].Compensation.Calls) + assert.Equal(t, 0, len(res.Steps[0].Compensation.Errors)) + + assert.Equal(t, ExecutionStatusSuccess, res.Steps[1].Action.Status) + assert.Equal(t, 1, res.Steps[1].Action.Calls) + assert.Equal(t, 0, len(res.Steps[1].Action.Errors)) + assert.Equal(t, ExecutionStatusUncalled, res.Steps[1].Compensation.Status) + assert.Equal(t, 0, res.Steps[1].Compensation.Calls) + assert.Equal(t, 0, len(res.Steps[1].Compensation.Errors)) }) } - func Test_execute_context(t *testing.T) { t.Run("action_ctx_cancel", func(t *testing.T) { var ( - ctx, cancel = context.WithCancel(context.Background()) - executedActions []string + ctx, cancel = context.WithCancel(context.Background()) + calls = make([]string, 0, 2) ) steps := []Step{ NewStep("step0"). WithAction(nil), NewStep("step1"). - WithAction(func(ctx context.Context) error { - executedActions = append(executedActions, "action1") + WithAction(func(ctx context.Context, _ Track) error { + calls = append(calls, "action1") return nil }), NewStep("step2"). - WithAction(func(ctx context.Context) error { - executedActions = append(executedActions, "action2") + WithAction(func(ctx context.Context, _ Track) error { + calls = append(calls, "action2") cancel() // cancel context for test return nil }), NewStep("step3"). - WithAction(func(ctx context.Context) error { - executedActions = append(executedActions, "action3") + WithAction(func(ctx context.Context, _ Track) error { + calls = append(calls, "action3") t.Fatalf("should not have been called") return nil }), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, ErrExecuteActionsContextDone)) - testtool.AssertTrue(t, slices.Equal([]string{"action1", "action2"}, executedActions)) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + + assert.Equal(t, StageResultCompensated, res.Status) + assert.Equal(t, 4, len(res.Steps)) + assert.Equal(t, "step0", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusUnset, res.Steps[0].Action.Status) + assert.Equal(t, ExecutionStatusUnset, res.Steps[0].Compensation.Status) + assert.Equal(t, false, res.Steps[0].CompensationRequired) + + assert.Equal(t, "step1", res.Steps[1].StepName) + assert.Equal(t, 1, res.Steps[1].StepPosition) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[1].Action.Status) + assert.Equal(t, ExecutionStatusUnset, res.Steps[1].Compensation.Status) + assert.Equal(t, false, res.Steps[1].CompensationRequired) + + assert.Equal(t, "step2", res.Steps[2].StepName) + assert.Equal(t, 2, res.Steps[2].StepPosition) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[2].Action.Status) + assert.Equal(t, ExecutionStatusUnset, res.Steps[2].Compensation.Status) + assert.Equal(t, false, res.Steps[2].CompensationRequired) + + assert.Equal(t, "step3", res.Steps[3].StepName) + assert.Equal(t, 3, res.Steps[3].StepPosition) + assert.Equal(t, ExecutionStatusFail, res.Steps[3].Action.Status) + assert.Equal(t, 1, len(res.Steps[3].Action.Errors)) + assert.ErrorIs(t, res.Steps[3].Action.Errors[0], ErrExecuteActionsContextDone) + assert.Equal(t, ExecutionStatusUnset, res.Steps[3].Compensation.Status) + assert.Equal(t, 0, len(res.Steps[3].Compensation.Errors)) + assert.Equal(t, false, res.Steps[3].CompensationRequired) + + assert.True(t, slices.Equal([]string{"action1", "action2"}, calls)) }) t.Run("retry_ctx_cancel", func(t *testing.T) { var ( ctx, cancel = context.WithCancel(context.Background()) - executed []string - actionCalls = 1 ) steps := []Step{ NewStep("step0"). - WithAction(nil), - NewStep("step1"). WithAction( - NewAction(func(ctx context.Context) error { - executed = append(executed, "action1") - switch { - case actionCalls == 1: - actionCalls++ - return testtool.ErrExpTest - case actionCalls == 2: - actionCalls++ - return testtool.ErrExpTest - case actionCalls >= 3: - actionCalls++ + NewAction(func(ctx context.Context, track Track) error { + data := track.GetData() + if data.Action.Calls >= 2 { cancel() // cancel context for test - return testtool.ErrExpTest } - return nil - }).WithRetry(NewBaseRetryOpt(4, 1*time.Nanosecond)), + return testtool.ErrExpTestA + }).WithRetry(NewBaseRetryOpt(10, 1*time.Nanosecond)), ), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, ErrRetryContextDone)) - testtool.AssertTrue(t, 4 == actionCalls) // 3 + first execution - testtool.AssertTrue(t, slices.Equal([]string{"action1", "action1", "action1"}, executed)) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) - testtool.LogError(t, err) + assert.Equal(t, StageResultFail, res.Status) + assert.Equal(t, 1, len(res.Steps)) + assert.Equal(t, "step0", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, ExecutionStatusUnset, res.Steps[0].Compensation.Status) + assert.Equal(t, 3, len(res.Steps[0].Action.Errors)) + assert.ErrorIs(t, res.Steps[0].Action.Errors[0], testtool.ErrExpTestA) + assert.ErrorIs(t, res.Steps[0].Action.Errors[1], testtool.ErrExpTestA) + assert.ErrorIs(t, res.Steps[0].Action.Errors[2], ErrRetryContextDone) }) - + // t.Run("compensation_ctx_cancel", func(t *testing.T) { var ( ctx, cancel = context.WithCancel(context.Background()) - executed []string ) steps := []Step{ - NewStep("step1"). + NewStep("step0"). WithAction( - NewAction(func(ctx context.Context) error { - executed = append(executed, "action1") + NewAction(func(ctx context.Context, _ Track) error { cancel() // cancel context for test - return testtool.ErrExpTest + return testtool.ErrExpTestA }), ).WithCompensation( - NewCompensation(func(ctx context.Context, aroseErr error) error { + NewCompensation(func(ctx context.Context, _ Track) error { t.Fatalf("should not have been called") return nil }), - ).WithCompensationOnFail(), + ).WithCompensationRequired(), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, ErrExecuteCompensationContextDone)) - testtool.AssertTrue(t, slices.Equal([]string{"action1"}, executed)) - - testtool.LogError(t, err) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultFail, res.Status) + assert.Equal(t, 1, len(res.Steps)) + assert.Equal(t, "step0", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + assert.ErrorIs(t, res.Steps[0].Action.Errors[0], testtool.ErrExpTestA) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Compensation.Status) + assert.Equal(t, 0, res.Steps[0].Compensation.Calls) + assert.Equal(t, 1, len(res.Steps[0].Compensation.Errors)) + assert.ErrorIs(t, res.Steps[0].Compensation.Errors[0], ErrExecuteCompensationContextDone) }) } @@ -369,63 +474,77 @@ func Test_hooks(t *testing.T) { t.Run("before", func(t *testing.T) { var ( ctx = context.Background() - executed []string + executed = make([]string, 0, 3) ) steps := []Step{ NewStep("step1"). WithAction( - NewAction(func(ctx context.Context) error { + NewAction(func(ctx context.Context, _ Track) error { executed = append(executed, "action1") - return testtool.ErrExpTest - }).WithBeforeHook(func(ctx context.Context) error { + return testtool.ErrExpTestA + }).WithBeforeHook(func(ctx context.Context, _ Track) error { executed = append(executed, "hook1") return nil - }).WithBeforeHook(func(ctx context.Context) error { + }).WithBeforeHook(func(ctx context.Context, _ Track) error { executed = append(executed, "hook2") return nil }), ), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, slices.Equal([]string{"hook2", "hook1", "action1"}, executed)) - testtool.LogError(t, err) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + + assert.Equal(t, StageResultFail, res.Status) + assert.Equal(t, 1, len(res.Steps)) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + assert.ErrorIs(t, res.Steps[0].Action.Errors[0], testtool.ErrExpTestA) + + assert.True(t, slices.Equal([]string{"hook2", "hook1", "action1"}, executed)) }) t.Run("before_with_retry", func(t *testing.T) { var ( ctx = context.Background() - executed []string + executed = make([]string, 0, 8) ) steps := []Step{ NewStep("step1"). WithAction( - NewAction(func(ctx context.Context) error { + NewAction(func(ctx context.Context, _ Track) error { executed = append(executed, "action1") - return testtool.ErrExpTest - }).WithBeforeHook(func(ctx context.Context) error { + return testtool.ErrExpTestA + }).WithBeforeHook(func(ctx context.Context, _ Track) error { executed = append(executed, "hook1") return nil - }).WithBeforeHook(func(ctx context.Context) error { + }).WithBeforeHook(func(ctx context.Context, _ Track) error { executed = append(executed, "hook2") return nil }).WithRetry(NewBaseRetryOpt(1, 1*time.Nanosecond)). - WithBeforeHook(func(ctx context.Context) error { + WithBeforeHook(func(ctx context.Context, _ Track) error { executed = append(executed, "retry_hook1") return nil }).WithBeforeHook( - func(ctx context.Context) error { + func(ctx context.Context, _ Track) error { executed = append(executed, "retry_hook2") return nil }), ), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, testtool.ErrExpTest)) - testtool.AssertTrue(t, + + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultFail, res.Status) + assert.Equal(t, 1, len(res.Steps)) + assert.Equal(t, 2, res.Steps[0].Action.Calls) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 2, len(res.Steps[0].Action.Errors)) + assert.ErrorIs(t, res.Steps[0].Action.Errors[0], testtool.ErrExpTestA) + assert.ErrorIs(t, res.Steps[0].Action.Errors[1], testtool.ErrExpTestA) + assert.True(t, slices.Equal( []string{ "retry_hook2", "retry_hook1", // retry hooks @@ -434,79 +553,198 @@ func Test_hooks(t *testing.T) { }, executed), ) - - testtool.LogError(t, err) }) + + var ( + errHook1 = fmt.Errorf("error_hook1") + errHook2 = fmt.Errorf("error_hook2") + errHook3 = fmt.Errorf("error_hook_3") + errHook4 = fmt.Errorf("error_hook_4") + ) + t.Run("after", func(t *testing.T) { var ( ctx = context.Background() - executed []string + executed = make([]string, 0, 3) ) steps := []Step{ NewStep("step1"). WithAction( - NewAction(func(ctx context.Context) error { + NewAction(func(ctx context.Context, track Track) error { executed = append(executed, "action1") - return testtool.ErrExpTest - }).WithAfterHook(func(ctx context.Context, aroseError error) error { - testtool.AssertError(t, aroseError) - testtool.AssertTrue(t, errors.Is(aroseError, testtool.ErrExpTest)) + return testtool.ErrExpTestA + }).WithAfterHook(func(ctx context.Context, track Track) error { executed = append(executed, "hook1") - return aroseError - }).WithAfterHook(func(ctx context.Context, aroseError error) error { - testtool.AssertError(t, aroseError) - testtool.AssertTrue(t, errors.Is(aroseError, testtool.ErrExpTest)) + + data := track.GetData() + assert.Equal(t, 1, data.Action.Calls) + assert.Equal(t, 1, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) + + data.Action.Errors = append(data.Action.Errors, errHook1) + return errHook1 + }).WithAfterHook(func(ctx context.Context, track Track) error { executed = append(executed, "hook2") - return aroseError + + data := track.GetData() + assert.Equal(t, 1, data.Action.Calls) + assert.Equal(t, 2, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) + assert.ErrorIs(t, data.Action.Errors[1], errHook1) + return errHook2 }), ), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, testtool.ErrExpTest)) - testtool.AssertTrue(t, slices.Equal([]string{"action1", "hook1", "hook2"}, executed)) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultFail, res.Status) + assert.Equal(t, 1, len(res.Steps)) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 3, len(res.Steps[0].Action.Errors)) + assert.ErrorIs(t, res.Steps[0].Action.Errors[0], testtool.ErrExpTestA) + assert.ErrorIs(t, res.Steps[0].Action.Errors[1], errHook1) + assert.ErrorIs(t, res.Steps[0].Action.Errors[2], errHook2) + + assert.True(t, slices.Equal([]string{"action1", "hook1", "hook2"}, executed)) - testtool.LogError(t, err) }) - t.Run("after_with_retry", func(t *testing.T) { + t.Run("after_with_retry___complicated_v1", func(t *testing.T) { var ( ctx = context.Background() - executed []string + executed = make([]string, 0, 11) + + checkRetryStr = func(i uint8, err error) bool { + return strings.Contains(err.Error(), fmt.Sprintf("retry [%d]", i)) + } ) steps := []Step{ NewStep("step1"). WithAction( - NewAction(func(ctx context.Context) error { + NewAction(func(ctx context.Context, track Track) error { executed = append(executed, "action1") - return testtool.ErrExpTest - }).WithAfterHook(func(ctx context.Context, aroseError error) error { - testtool.AssertError(t, aroseError) - testtool.AssertTrue(t, errors.Is(aroseError, testtool.ErrExpTest)) + + data := track.GetData() + switch data.Action.Calls { + case 1: + return testtool.ErrExpTestA + case 2: + return testtool.ErrExpTestB + case 3: + return testtool.ErrExpTestC + } + t.Fatalf("should not have been called") + return nil + + }).WithAfterHook(func(ctx context.Context, track Track) error { executed = append(executed, "hook1") - return aroseError - }).WithAfterHook(func(ctx context.Context, aroseError error) error { - testtool.AssertError(t, aroseError) - testtool.AssertTrue(t, errors.Is(aroseError, testtool.ErrExpTest)) + + data := track.GetData() + switch data.Action.Calls { + case 1: + assert.Equal(t, 1, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) + case 2: + assert.Equal(t, 4, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) + assert.ErrorIs(t, data.Action.Errors[1], errHook1) + assert.ErrorIs(t, data.Action.Errors[2], errHook2) + assert.ErrorIs(t, data.Action.Errors[3], testtool.ErrExpTestB) + assert.True(t, checkRetryStr(0, data.Action.Errors[3])) + case 3: + assert.Equal(t, 7, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) + assert.ErrorIs(t, data.Action.Errors[1], errHook1) + assert.ErrorIs(t, data.Action.Errors[2], errHook2) + assert.ErrorIs(t, data.Action.Errors[3], testtool.ErrExpTestB) + assert.True(t, checkRetryStr(0, data.Action.Errors[3])) + assert.ErrorIs(t, data.Action.Errors[4], errHook1) + assert.True(t, checkRetryStr(0, data.Action.Errors[4])) + assert.ErrorIs(t, data.Action.Errors[5], errHook2) + assert.True(t, checkRetryStr(0, data.Action.Errors[5])) + assert.ErrorIs(t, data.Action.Errors[6], testtool.ErrExpTestC) + assert.True(t, checkRetryStr(1, data.Action.Errors[6])) + case 4: + t.Fatalf("should not have been called") + } + return errHook1 + }).WithAfterHook(func(ctx context.Context, track Track) error { executed = append(executed, "hook2") - return aroseError + + data := track.GetData() + switch data.Action.Calls { + case 1: + assert.Equal(t, 2, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) + assert.ErrorIs(t, data.Action.Errors[1], errHook1) + case 2: + assert.Equal(t, 5, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) + assert.ErrorIs(t, data.Action.Errors[1], errHook1) + assert.ErrorIs(t, data.Action.Errors[2], errHook2) + assert.ErrorIs(t, data.Action.Errors[3], testtool.ErrExpTestB) + assert.ErrorIs(t, data.Action.Errors[4], errHook1) + case 3: + assert.Equal(t, 8, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) + assert.ErrorIs(t, data.Action.Errors[1], errHook1) + assert.ErrorIs(t, data.Action.Errors[2], errHook2) + assert.ErrorIs(t, data.Action.Errors[3], testtool.ErrExpTestB) + assert.True(t, checkRetryStr(0, data.Action.Errors[3])) + assert.ErrorIs(t, data.Action.Errors[4], errHook1) + assert.True(t, checkRetryStr(0, data.Action.Errors[4])) + assert.ErrorIs(t, data.Action.Errors[5], errHook2) + assert.True(t, checkRetryStr(0, data.Action.Errors[5])) + assert.ErrorIs(t, data.Action.Errors[6], testtool.ErrExpTestC) + assert.True(t, checkRetryStr(1, data.Action.Errors[6])) + assert.ErrorIs(t, data.Action.Errors[7], errHook1) + assert.True(t, checkRetryStr(1, data.Action.Errors[7])) + case 4: + t.Fatalf("should not have been called") + } + + return errHook2 }).WithRetry(NewBaseRetryOpt(2, 1*time.Nanosecond)). - WithAfterHook(func(ctx context.Context, aroseError error) error { - testtool.AssertTrue(t, errors.Is(aroseError, testtool.ErrExpTest)) + WithAfterHook(func(ctx context.Context, track Track) error { executed = append(executed, "retry_hook1") - return aroseError + + data := track.GetData() + assert.Equal(t, 3, data.Action.Calls) + assert.Equal(t, 9, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[8], errHook2) + assert.True(t, checkRetryStr(1, data.Action.Errors[8])) + return errHook3 }). - WithAfterHook(func(ctx context.Context, aroseError error) error { - testtool.AssertTrue(t, errors.Is(aroseError, testtool.ErrExpTest)) + WithAfterHook(func(ctx context.Context, track Track) error { executed = append(executed, "retry_hook2") - return aroseError + + data := track.GetData() + assert.Equal(t, 3, data.Action.Calls) + assert.Equal(t, 10, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[8], errHook2) + assert.True(t, checkRetryStr(1, data.Action.Errors[8])) + assert.ErrorIs(t, data.Action.Errors[9], errHook3) + return errHook4 }), ), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultFail, res.Status) + + testtool.TestFn(t, func() { + t.Logf("\nresult:\n%v", res) + t.Logf("\nexecution error: %v", err) + step := res.Steps[0] + t.Logf("\nstep [%d#%s] action errors:", step.StepPosition, step.StepName) + for i, e := range res.Steps[0].Action.Errors { + fmt.Printf("%d: %v\n", i, e) + } + }) + + assert.True(t, slices.Equal( []string{ "action1", "hook1", "hook2", // call @@ -517,7 +755,25 @@ func Test_hooks(t *testing.T) { executed), ) - testtool.LogError(t, err) + assert.Equal(t, 1, len(res.Steps)) + + action := res.Steps[0].Action + assert.Equal(t, 3, res.Steps[0].Action.Calls) + assert.Equal(t, ExecutionStatusFail, action.Status) + assert.Equal(t, 11, len(action.Errors)) + assert.ErrorIs(t, action.Errors[0], testtool.ErrExpTestA) + assert.ErrorIs(t, action.Errors[1], errHook1) + assert.ErrorIs(t, action.Errors[2], errHook2) + assert.ErrorIs(t, action.Errors[3], testtool.ErrExpTestB) + assert.True(t, checkRetryStr(0, action.Errors[3])) + assert.ErrorIs(t, action.Errors[4], errHook1) + assert.True(t, checkRetryStr(0, action.Errors[4])) + assert.ErrorIs(t, action.Errors[5], errHook2) + assert.True(t, checkRetryStr(0, action.Errors[5])) + assert.ErrorIs(t, action.Errors[6], testtool.ErrExpTestC) + assert.True(t, checkRetryStr(1, action.Errors[6])) + assert.ErrorIs(t, action.Errors[7], errHook1) + assert.True(t, checkRetryStr(1, action.Errors[7])) }) }) t.Run("compensation_hooks", func(t *testing.T) { @@ -530,34 +786,61 @@ func Test_hooks(t *testing.T) { steps := []Step{ NewStep("step1"). WithAction( - NewAction(func(ctx context.Context) error { + NewAction(func(ctx context.Context, _ Track) error { executed = append(executed, "action1") - return testtool.ErrExpTest + return testtool.ErrExpTestA }), ).WithCompensation( - NewCompensation(func(ctx context.Context, actionErr error) error { + NewCompensation(func(ctx context.Context, track Track) error { executed = append(executed, "comp1") + + data := track.GetData() + assert.Equal(t, 1, data.Action.Calls) + assert.Equal(t, 1, data.Compensation.Calls) + assert.Equal(t, 1, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) + return nil - }).WithBeforeHook(func(ctx context.Context, actionErr error) error { - testtool.AssertError(t, actionErr) - testtool.AssertTrue(t, errors.Is(actionErr, testtool.ErrExpTest)) + }).WithBeforeHook(func(ctx context.Context, track Track) error { executed = append(executed, "comp_hook1") + + data := track.GetData() + assert.Equal(t, 1, data.Action.Calls) + assert.Equal(t, 1, data.Compensation.Calls) + assert.Equal(t, 1, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) return nil - }).WithBeforeHook(func(ctx context.Context, actionErr error) error { - testtool.AssertError(t, actionErr) - testtool.AssertTrue(t, errors.Is(actionErr, testtool.ErrExpTest)) + }).WithBeforeHook(func(ctx context.Context, track Track) error { executed = append(executed, "comp_hook2") + + data := track.GetData() + assert.Equal(t, 1, data.Action.Calls) + assert.Equal(t, 1, data.Compensation.Calls) + assert.Equal(t, 1, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) return nil }), - ).WithCompensationOnFail(), + ).WithCompensationRequired(), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, testtool.ErrExpTest)) - testtool.AssertTrue(t, errors.Is(err, ErrCompensationSuccess)) - testtool.AssertTrue(t, slices.Equal([]string{"action1", "comp_hook2", "comp_hook1", "comp1"}, executed)) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultCompensated, res.Status) + + assert.True(t, slices.Equal([]string{"action1", "comp_hook2", "comp_hook1", "comp1"}, executed)) + + assert.Equal(t, 1, len(res.Steps)) + + action := res.Steps[0].Action + assert.Equal(t, ExecutionStatusFail, action.Status) + assert.Equal(t, 1, len(action.Errors)) + assert.ErrorIs(t, action.Errors[0], testtool.ErrExpTestA) + + compensation := res.Steps[0].Compensation + assert.Equal(t, 1, compensation.Calls) + assert.Equal(t, ExecutionStatusSuccess, compensation.Status) + assert.Equal(t, ExecutionStatusSuccess, compensation.Status) + assert.Equal(t, 0, len(compensation.Errors)) - testtool.LogError(t, err) }) t.Run("after", func(t *testing.T) { var ( @@ -570,115 +853,562 @@ func Test_hooks(t *testing.T) { steps := []Step{ NewStep("step1"). WithAction( - NewAction(func(ctx context.Context) error { + NewAction(func(ctx context.Context, track Track) error { executed = append(executed, "action1") - return testtool.ErrExpTest + return testtool.ErrExpTestA }), ).WithCompensation( - NewCompensation(func(ctx context.Context, actionErr error) error { + NewCompensation(func(ctx context.Context, track Track) error { executed = append(executed, "comp1") return compErr - }).WithAfterHook(func(ctx context.Context, actionErr, previousErr error) error { - testtool.AssertError(t, actionErr) - testtool.AssertTrue(t, errors.Is(actionErr, testtool.ErrExpTest)) - testtool.AssertTrue(t, errors.Is(previousErr, compErr)) + }).WithAfterHook(func(ctx context.Context, track Track) error { executed = append(executed, "comp_hook1") + + data := track.GetData() + assert.Equal(t, 1, data.Action.Calls) + assert.Equal(t, 1, data.Compensation.Calls) + assert.Equal(t, 1, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) + assert.Equal(t, 1, len(data.Compensation.Errors)) + assert.ErrorIs(t, data.Compensation.Errors[0], compErr) + return previousHookErr - }).WithAfterHook(func(ctx context.Context, actionErr, previousErr error) error { - testtool.AssertError(t, actionErr) - testtool.AssertTrue(t, errors.Is(actionErr, testtool.ErrExpTest)) - testtool.AssertTrue(t, errors.Is(previousErr, previousHookErr)) + }).WithAfterHook(func(ctx context.Context, track Track) error { executed = append(executed, "comp_hook2") + + data := track.GetData() + assert.Equal(t, 1, data.Action.Calls) + assert.Equal(t, 1, data.Compensation.Calls) + assert.Equal(t, 1, len(data.Action.Errors)) + assert.ErrorIs(t, data.Action.Errors[0], testtool.ErrExpTestA) + assert.Equal(t, 2, len(data.Compensation.Errors)) + assert.ErrorIs(t, data.Compensation.Errors[0], compErr) + assert.Equal(t, 2, len(data.Compensation.Errors)) + assert.ErrorIs(t, data.Compensation.Errors[1], previousHookErr) return nil }), - ).WithCompensationOnFail(), + ).WithCompensationRequired(), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, testtool.ErrExpTest)) - testtool.AssertTrue(t, errors.Is(err, ErrCompensationSuccess)) - testtool.AssertTrue(t, slices.Equal([]string{"action1", "comp1", "comp_hook1", "comp_hook2"}, executed)) + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultCompensated, res.Status) + + assert.True(t, slices.Equal([]string{"action1", "comp1", "comp_hook1", "comp_hook2"}, executed)) - testtool.LogError(t, err) }) }) } func Test_wrapper(t *testing.T) { + var ( + ctx = context.Background() + ) t.Run("action", func(t *testing.T) { var ( - ctx = context.Background() - executed []string + calls = make([]string, 0, 3) ) steps := []Step{ NewStep("step1"). WithAction( - NewAction(func(ctx context.Context) error { - executed = append(executed, "action1") + NewAction(func(ctx context.Context, _ Track) error { + calls = append(calls, "action1") return nil - }).WithWrapper(func(ctx context.Context, action ActionFunc) error { - executed = append(executed, "before1") - err := action(ctx) - testtool.AssertNoError(t, err) - executed = append(executed, "after1") + }).WithWrapper(func(ctx context.Context, track Track, action ActionFunc) error { + calls = append(calls, "before1") + err := action(ctx, track) + assert.NoError(t, err) + calls = append(calls, "after1") return nil }), ), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertNoError(t, err) - testtool.AssertTrue(t, slices.Equal([]string{"before1", "action1", "after1"}, executed)) + res, err := NewSaga(steps).Execute(ctx) + assert.NoError(t, err) + + assert.Equal(t, StageResultSuccess, res.Status) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Action.Status) + assert.Equal(t, 0, res.Steps[0].Compensation.Calls) + assert.Equal(t, ExecutionStatusUnset, res.Steps[0].Compensation.Status) + + assert.True(t, slices.Equal([]string{"before1", "action1", "after1"}, calls)) }) t.Run("compensation", func(t *testing.T) { var ( - ctx = context.Background() - expErr = testtool.ErrExpTest - executed []string + expErr = testtool.ErrExpTestA + calls = make([]string, 0, 6) ) steps := []Step{ NewStep("step1"). WithAction( - NewAction(func(ctx context.Context) error { - executed = append(executed, "action1") + NewAction(func(ctx context.Context, _ Track) error { + calls = append(calls, "action1") return expErr - }).WithWrapper(func(ctx context.Context, action ActionFunc) error { - executed = append(executed, "before_action1") - err := action(ctx) // call action - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, expErr)) - executed = append(executed, "after_action1") + }).WithWrapper(func(ctx context.Context, track Track, action ActionFunc) error { + calls = append(calls, "before_action1") + err := action(ctx, track) // call action + assert.Error(t, err) + assert.ErrorIs(t, err, expErr) + calls = append(calls, "after_action1") return err }), ).WithCompensation( - NewCompensation(func(ctx context.Context, actionErr error) error { - executed = append(executed, "com1") - testtool.AssertError(t, actionErr) - testtool.AssertTrue(t, errors.Is(actionErr, expErr)) + NewCompensation(func(ctx context.Context, track Track) error { + calls = append(calls, "com1") + actionErr := track.GetData().Action.Errors[0] + assert.Error(t, actionErr) + assert.ErrorIs(t, actionErr, expErr) return nil - }).WithWrapper(func(ctx context.Context, actionErr error, comp CompensationFunc) error { - executed = append(executed, "before_comp1") - testtool.AssertError(t, actionErr) - testtool.AssertTrue(t, errors.Is(actionErr, expErr)) - err := comp(ctx, actionErr) // call compensation - testtool.AssertNoError(t, err) - executed = append(executed, "after_comp1") + }).WithWrapper(func(ctx context.Context, track Track, comp CompensationFunc) error { + calls = append(calls, "before_comp1") + actionErr := track.GetData().Action.Errors[0] + assert.Error(t, actionErr) + assert.ErrorIs(t, actionErr, expErr) + err := comp(ctx, track) // call compensation + assert.NoError(t, err) + calls = append(calls, "after_comp1") return nil }), - ).WithCompensationOnFail(), + ).WithCompensationRequired(), } - err := NewSaga(steps).Execute(ctx) - testtool.AssertError(t, err) - testtool.AssertTrue(t, errors.Is(err, ErrActionFailed)) - testtool.AssertTrue(t, + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultCompensated, res.Status) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 1, res.Steps[0].Compensation.Calls) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Compensation.Status) + assert.Equal(t, 0, len(res.Steps[0].Compensation.Errors)) + + assert.True(t, slices.Equal( []string{ "before_action1", "action1", "after_action1", "before_comp1", "com1", "after_comp1", - }, executed)) + }, calls)) + }) +} + +func Test_steps(t *testing.T) { + t.Run("action", func(t *testing.T) { + var ( + ctx = context.Background() + ) + t.Run("success_v1", func(t *testing.T) { + steps := []Step{ + NewStep("step1"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + return nil + }), + ), + } + + res, err := NewSaga(steps).Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, StageResultSuccess, res.Status) + + assert.Equal(t, 1, len(res.Steps)) + + assert.Equal(t, "step1", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Action.Status) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 0, len(res.Steps[0].Action.Errors)) + + testtool.TestFn(t, func() { + t.Log(res) + }) + }) + t.Run("fail_v1", func(t *testing.T) { + steps := []Step{ + NewStep("step1"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + return testtool.ErrExpTestA + }), + ), + } + + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultFail, res.Status) + assert.True(t, errors.Is(err, ErrActionFailed)) + assert.Equal(t, 1, len(res.Steps)) + assert.Equal(t, "step1", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, ExecutionStatusUnset, res.Steps[0].Compensation.Status) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + assert.ErrorIs(t, res.Steps[0].Action.Errors[0], testtool.ErrExpTestA) + + testtool.TestFn(t, func() { + t.Log(res) + }) + }) + t.Run("fail_v2", func(t *testing.T) { + steps := []Step{ + NewStep("step1"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + return nil + }), + ), + NewStep("step2"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + return testtool.ErrExpTestA + }), + ), + } + + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultCompensated, res.Status) + assert.True(t, errors.Is(err, ErrActionFailed)) + + assert.Equal(t, 2, len(res.Steps)) + + assert.Equal(t, "step1", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Action.Status) + assert.Equal(t, ExecutionStatusUnset, res.Steps[0].Compensation.Status) + assert.Equal(t, false, res.Steps[0].CompensationRequired) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 0, len(res.Steps[0].Action.Errors)) + + assert.Equal(t, "step2", res.Steps[1].StepName) + assert.Equal(t, 1, res.Steps[1].StepPosition) + assert.Equal(t, ExecutionStatusFail, res.Steps[1].Action.Status) + assert.Equal(t, ExecutionStatusUnset, res.Steps[1].Compensation.Status) + assert.Equal(t, false, res.Steps[1].CompensationRequired) + assert.Equal(t, 1, res.Steps[1].Action.Calls) + assert.Equal(t, 1, len(res.Steps[1].Action.Errors)) + assert.ErrorIs(t, res.Steps[1].Action.Errors[0], testtool.ErrExpTestA) + + testtool.TestFn(t, func() { + t.Log(res) + }) + }) + }) + + t.Run("compensation", func(t *testing.T) { + var ( + ctx = context.Background() + ) + t.Run("success_v1", func(t *testing.T) { + steps := []Step{ + NewStep("step1"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + return testtool.ErrExpTestA + }), + ). + WithCompensation( + NewCompensation(func(ctx context.Context, track Track) error { + str := track.GetData() + assert.Equal(t, "step1", str.StepName) + assert.Equal(t, 0, str.StepPosition) + assert.Equal(t, ExecutionStatusFail, str.Action.Status) + assert.Equal(t, 1, str.Action.Calls) + assert.Equal(t, 1, len(str.Action.Errors)) + + assert.Equal(t, ExecutionStatusUncalled, str.Compensation.Status) + assert.Equal(t, 1, str.Compensation.Calls) + assert.Equal(t, 0, len(str.Compensation.Errors)) + return nil + }), + ).WithCompensationRequired(), + } + + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultCompensated, res.Status) + assert.True(t, errors.Is(err, ErrActionFailed)) + + assert.Equal(t, 1, len(res.Steps)) + + assert.Equal(t, "step1", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Compensation.Status) + assert.Equal(t, 1, res.Steps[0].Compensation.Calls) + assert.Equal(t, 0, len(res.Steps[0].Compensation.Errors)) + + testtool.TestFn(t, func() { + t.Log(res) + }) + }) + t.Run("compensate_v1", func(t *testing.T) { + steps := []Step{ + NewStep("step1"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + return testtool.ErrExpTestA + }), + ). + WithCompensation( + NewCompensation(func(ctx context.Context, track Track) error { + str := track.GetData() + assert.Equal(t, "step1", str.StepName) + assert.Equal(t, 0, str.StepPosition) + assert.Equal(t, ExecutionStatusFail, str.Action.Status) + assert.Equal(t, 1, str.Action.Calls) + assert.Equal(t, 1, len(str.Action.Errors)) + assert.ErrorIs(t, str.Action.Errors[0], testtool.ErrExpTestA) + + assert.Equal(t, ExecutionStatusUncalled, str.Compensation.Status) + assert.Equal(t, 1, str.Compensation.Calls) + assert.Equal(t, 0, len(str.Compensation.Errors)) + return nil + }), + ).WithCompensationRequired(), + } + + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultCompensated, res.Status) + assert.True(t, errors.Is(err, ErrActionFailed)) + + assert.Equal(t, 1, len(res.Steps)) + + assert.Equal(t, "step1", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Compensation.Status) + assert.Equal(t, 1, res.Steps[0].Compensation.Calls) + assert.Equal(t, 0, len(res.Steps[0].Compensation.Errors)) + + testtool.TestFn(t, func() { + t.Log(res) + }) + }) + t.Run("compensate_v2", func(t *testing.T) { + steps := []Step{ + NewStep("step1"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + return nil + }), + ), + NewStep("step2"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + return testtool.ErrExpTestA + }), + ). + WithCompensation( + NewCompensation(func(ctx context.Context, track Track) error { + return nil + }), + ), + } + + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultCompensated, res.Status) + assert.True(t, errors.Is(err, ErrActionFailed)) + + assert.Equal(t, 2, len(res.Steps)) + + assert.Equal(t, "step1", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Action.Status) + assert.Equal(t, ExecutionStatusUnset, res.Steps[0].Compensation.Status) + assert.Equal(t, false, res.Steps[0].CompensationRequired) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 0, len(res.Steps[0].Action.Errors)) + + assert.Equal(t, "step2", res.Steps[1].StepName) + assert.Equal(t, 1, res.Steps[1].StepPosition) + assert.Equal(t, ExecutionStatusFail, res.Steps[1].Action.Status) + assert.Equal(t, ExecutionStatusUncalled, res.Steps[1].Compensation.Status) + assert.Equal(t, false, res.Steps[1].CompensationRequired) + assert.Equal(t, 1, res.Steps[1].Action.Calls) + assert.Equal(t, 1, len(res.Steps[1].Action.Errors)) + assert.ErrorIs(t, res.Steps[1].Action.Errors[0], testtool.ErrExpTestA) + + testtool.TestFn(t, func() { + t.Log(res) + }) + }) + t.Run("fail_v1", func(t *testing.T) { + steps := []Step{ + NewStep("step1"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + return testtool.ErrExpTestA + }), + ). + WithCompensation( + NewCompensation(func(ctx context.Context, track Track) error { + str := track.GetData() + assert.Equal(t, "step1", str.StepName) + assert.Equal(t, 0, str.StepPosition) + assert.Equal(t, ExecutionStatusFail, str.Action.Status) + assert.Equal(t, 1, str.Action.Calls) + assert.Equal(t, 1, len(str.Action.Errors)) + assert.ErrorIs(t, str.Action.Errors[0], testtool.ErrExpTestA) + + assert.Equal(t, ExecutionStatusUncalled, str.Compensation.Status) + assert.Equal(t, 1, str.Compensation.Calls) + assert.Equal(t, 0, len(str.Compensation.Errors)) + return testtool.ErrExpTestB + }), + ).WithCompensationRequired(), + } + + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultFail, res.Status) + assert.ErrorIs(t, err, ErrActionFailed) + assert.ErrorIs(t, err, ErrCompensationFailed) + + assert.Equal(t, 1, len(res.Steps)) + + assert.Equal(t, "step1", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 1, res.Steps[0].Action.Calls) + assert.Equal(t, 1, len(res.Steps[0].Action.Errors)) + + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Compensation.Status) + assert.Equal(t, 1, res.Steps[0].Compensation.Calls) + assert.Equal(t, 1, len(res.Steps[0].Compensation.Errors)) + assert.ErrorIs(t, res.Steps[0].Compensation.Errors[0], testtool.ErrExpTestB) + + testtool.TestFn(t, func() { + t.Log(res) + }) + }) + }) +} + +func Test_retry(t *testing.T) { + var ( + ctx = context.Background() + ) + t.Run("compensation", func(t *testing.T) { + var ( + retries = uint32(4) + ) + steps := []Step{ + NewStep("step1"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + return testtool.ErrExpTestA + }). + WithRetry(NewBaseRetryOpt(retries, 5*time.Nanosecond)), + ). + WithCompensation( + NewCompensation(func(ctx context.Context, track Track) error { + str := track.GetData() + if str.Compensation.Calls < retries+1 { + return fmt.Errorf("comp err [%d]: %w", len(str.Compensation.Errors), testtool.ErrExpTestA) + } + return nil + }).WithRetry(NewBaseRetryOpt(retries, 5*time.Nanosecond)), + ).WithCompensationRequired(), + } + + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultCompensated, res.Status) + assert.ErrorIs(t, err, ErrActionFailed) + assert.ErrorIsNot(t, err, ErrCompensationFailed) + + assert.Equal(t, 1, len(res.Steps)) + + assert.Equal(t, "step1", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 5, res.Steps[0].Action.Calls) + assert.Equal(t, 5, len(res.Steps[0].Action.Errors)) + + for _, e := range res.Steps[0].Action.Errors { + assert.ErrorIs(t, e, testtool.ErrExpTestA) + } + + assert.Equal(t, ExecutionStatusSuccess, res.Steps[0].Compensation.Status) + assert.Equal(t, 5, res.Steps[0].Compensation.Calls) + assert.Equal(t, 4, len(res.Steps[0].Compensation.Errors)) + for _, e := range res.Steps[0].Compensation.Errors { + assert.ErrorIs(t, e, testtool.ErrExpTestA) + } + + testtool.TestFn(t, func() { + t.Log( + res, + "+ error:", err, + "\n + Action errors: ", res.Steps[0].Action.Errors, + "\n + Compensation errors: ", res.Steps[0].Compensation.Errors, + ) + }) + + }) + t.Run("compensation", func(t *testing.T) { + steps := []Step{ + NewStep("step1"). + WithAction( + NewAction(func(ctx context.Context, _ Track) error { + return testtool.ErrExpTestA + }). + WithRetry(NewBaseRetryOpt(4, 5*time.Nanosecond)), + ). + WithCompensation( + NewCompensation(func(ctx context.Context, track Track) error { + str := track.GetData() + return fmt.Errorf("comp err [%d]: %w", len(str.Compensation.Errors), testtool.ErrExpTestB) + }).WithRetry(NewBaseRetryOpt(4, 5*time.Nanosecond)), + ).WithCompensationRequired(), + } + + res, err := NewSaga(steps).Execute(ctx) + assert.Error(t, err) + assert.Equal(t, StageResultFail, res.Status) + assert.ErrorIs(t, err, ErrActionFailed) + assert.ErrorIs(t, err, ErrCompensationFailed) + + assert.Equal(t, 1, len(res.Steps)) + + assert.Equal(t, "step1", res.Steps[0].StepName) + assert.Equal(t, 0, res.Steps[0].StepPosition) + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Action.Status) + assert.Equal(t, 5, res.Steps[0].Action.Calls) + assert.Equal(t, 5, len(res.Steps[0].Action.Errors)) + + for _, e := range res.Steps[0].Action.Errors { + assert.ErrorIs(t, e, testtool.ErrExpTestA) + } + + assert.Equal(t, ExecutionStatusFail, res.Steps[0].Compensation.Status) + assert.Equal(t, 5, res.Steps[0].Compensation.Calls) + assert.Equal(t, 5, len(res.Steps[0].Compensation.Errors)) + for _, e := range res.Steps[0].Compensation.Errors { + assert.ErrorIs(t, e, testtool.ErrExpTestB) + } + + testtool.TestFn(t, func() { + t.Log( + res, + "+ error:", err, + "\n + Action errors: ", res.Steps[0].Action.Errors, + "\n + Compensation errors: ", res.Steps[0].Compensation.Errors, + ) + }) + }) } diff --git a/saga/step.go b/saga/step.go index 9d5f8d8..3073f1e 100644 --- a/saga/step.go +++ b/saga/step.go @@ -24,8 +24,8 @@ type Step struct { // multiple times in failure scenarios. Compensation CompensationFunc - // CompensationOnFail needs to add the current compensation to the list of compensations. - CompensationOnFail bool + // CompensationRequired needs to add the current compensation to the list of compensations. + CompensationRequired bool } // NewStep creates a new Step with the given name. @@ -45,8 +45,8 @@ func (s Step) WithCompensation(fn CompensationFunc) Step { return s } -// WithCompensationOnFail enables compensation for this step on failure. -func (s Step) WithCompensationOnFail() Step { - s.CompensationOnFail = true +// WithCompensationRequired enables compensation for this step on failure. +func (s Step) WithCompensationRequired() Step { + s.CompensationRequired = true return s } diff --git a/saga/track.go b/saga/track.go new file mode 100644 index 0000000..8baaa34 --- /dev/null +++ b/saga/track.go @@ -0,0 +1,196 @@ +package saga + +import ( + "fmt" + "strings" + "sync" +) + +// ExecutionStatus represents the current state of an action or compensation execution. +type ExecutionStatus string + +const ( + // ExecutionStatusSuccess indicates the operation completed successfully. + ExecutionStatusSuccess ExecutionStatus = "Success" + // ExecutionStatusFail indicates the operation failed. + ExecutionStatusFail ExecutionStatus = "Fail" + // ExecutionStatusUncalled indicates the operation has not been invoked. + ExecutionStatusUncalled ExecutionStatus = "Uncalled" + // ExecutionStatusUnset indicates the operation is not configured (e.g., nil function). + ExecutionStatusUnset ExecutionStatus = "Unset" +) + +// StepData contains the complete execution history for a single saga step. +// It includes information about both the main action and its compensation. +type StepData struct { + StepPosition uint32 + StepName string + + Action ExecutionData + Compensation ExecutionData + CompensationRequired bool +} + +// String returns a human-readable representation of the StepData. +func (s StepData) String() string { + return fmt.Sprintf("Step %d: %s | Action: %s | Compensation: %s", + s.StepPosition, + s.StepName, + s.Action.String(), + s.Compensation.String()) +} + +// ExecutionData holds execution details for a single operation (action or compensation). +type ExecutionData struct { + Calls uint32 + Errors []error + Status ExecutionStatus +} + +// String returns a compact representation of ExecutionData. +func (t ExecutionData) String() string { + var builder strings.Builder + + builder.WriteString(fmt.Sprintf("{Status: %s, Calls: %d", t.Status, t.Calls)) + if len(t.Errors) > 0 { + builder.WriteString(fmt.Sprintf(", Errors: %d", len(t.Errors))) + // @TODO: add errors output + //if len(t.Errors) == 1 { + // builder.WriteString(fmt.Sprintf(" [%v]", t.Errors[0])) + //} + } + + builder.WriteString("}") + return builder.String() +} + +// Clone creates a deep copy of ExecutionData. +func (t ExecutionData) Clone() ExecutionData { + errors := make([]error, len(t.Errors)) + copy(errors, t.Errors) + return ExecutionData{ + Calls: t.Calls, + Errors: errors, + Status: t.Status, + } +} + +// executionTrack manages the execution state for a single saga step. +// It implements the Track interface and provides thread-safe state management. +type executionTrack struct { + StepName string + StepPosition uint32 + + mx *sync.RWMutex + + action *ExecutionData + compensation *ExecutionData + current *ExecutionData + + compensationOnFail bool + compensationFn CompensationFunc + parentErr error +} + +// newExecutionTrack creates a new executionTrack for a given step. +func newExecutionTrack(position uint32, step Step) *executionTrack { + track := executionTrack{ + StepName: step.Name, + StepPosition: position, + mx: new(sync.RWMutex), + action: &ExecutionData{ + Status: ExecutionStatusUncalled, + }, + compensation: &ExecutionData{ + Status: ExecutionStatusUncalled, + }, + compensationOnFail: step.CompensationRequired, + compensationFn: step.Compensation, + } + + if step.Compensation == nil { + track.compensation.Status = ExecutionStatusUnset + } + + if step.Action == nil { + track.action.Status = ExecutionStatusUnset + } + + return &track +} + +// actionTrack switches the current execution context to the action. +// Returns the track for method chaining. +func (t *executionTrack) actionTrack() *executionTrack { + t.mx.Lock() + defer t.mx.Unlock() + t.current = t.action + return t +} + +// compensationTrack switches the current execution context to the compensation. +// Returns the track for method chaining. +func (t *executionTrack) compensationTrack() *executionTrack { + t.mx.Lock() + defer t.mx.Unlock() + t.current = t.compensation + return t +} + +// call increments the call counter for the current execution context. +// Should be called before each invocation of the operation. +func (t *executionTrack) call() { + t.mx.Lock() + defer t.mx.Unlock() + t.current.Calls = t.current.Calls + 1 +} + +// setParentError sets a parent error that will be wrapped with subsequent errors. +// Used to provide context about which retry attempt or operation triggered an error. +func (t *executionTrack) setParentError(err error) { + t.mx.Lock() + defer t.mx.Unlock() + t.parentErr = err +} + +// SetStatus sets the status of the current execution context. +func (t *executionTrack) SetStatus(status ExecutionStatus) { + t.mx.Lock() + defer t.mx.Unlock() + t.current.Status = status +} + +// getStatus returns the status of the current execution context. +func (t *executionTrack) getStatus() ExecutionStatus { + t.mx.RLock() + defer t.mx.RUnlock() + return t.current.Status +} + +// SetFailedOnError marks the current execution as failed and records the error. +// If a parent error exists, it will be wrapped with the new error. +func (t *executionTrack) SetFailedOnError(err error) { + t.mx.Lock() + defer t.mx.Unlock() + if err == nil { + return + } + if t.parentErr != nil { + err = fmt.Errorf("%w: %w", t.parentErr, err) + } + t.current.Status = ExecutionStatusFail + t.current.Errors = append(t.current.Errors, err) +} + +// GetData returns a snapshot of the current execution state for this step. +func (t *executionTrack) GetData() StepData { + t.mx.RLock() + defer t.mx.RUnlock() + return StepData{ + StepName: t.StepName, + StepPosition: t.StepPosition, + Action: t.action.Clone(), + Compensation: t.compensation.Clone(), + CompensationRequired: t.compensationOnFail, + } +} diff --git a/test/integration/internal/saga/concurrent_test.go b/test/integration/internal/saga/concurrent_test.go index 52ae97e..a516062 100644 --- a/test/integration/internal/saga/concurrent_test.go +++ b/test/integration/internal/saga/concurrent_test.go @@ -30,7 +30,7 @@ func Test_Concurrent(t *testing.T) { steps := []saga.Step{ { Name: "step0", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { wg.Go(func() { mx.Lock() defer mx.Unlock() @@ -38,14 +38,14 @@ func Test_Concurrent(t *testing.T) { }) return nil }, - Compensation: func(ctx context.Context, aroseErr error) error { + Compensation: func(ctx context.Context, _ saga.Track) error { executedCompensation = append(executedCompensation, "comp0") return nil }, }, { Name: "step1", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { wg.Go(func() { mx.Lock() defer mx.Unlock() @@ -53,14 +53,14 @@ func Test_Concurrent(t *testing.T) { }) return nil }, - Compensation: func(ctx context.Context, aroseErr error) error { + Compensation: func(ctx context.Context, _ saga.Track) error { executedCompensation = append(executedCompensation, "comp1") return nil }, }, { Name: "step2", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { wg.Go(func() { mx.Lock() defer mx.Unlock() @@ -69,14 +69,14 @@ func Test_Concurrent(t *testing.T) { }) return nil }, - Compensation: func(ctx context.Context, aroseErr error) error { + Compensation: func(ctx context.Context, _ saga.Track) error { executedCompensation = append(executedCompensation, "comp2") return nil }, }, { Name: "check_async_sttep", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { wg.Wait() close(errChan) @@ -90,9 +90,9 @@ func Test_Concurrent(t *testing.T) { }, } - err := saga.NewSaga(steps).Execute(ctx) + res, err := saga.NewSaga(steps).Execute(ctx) assert.Error(t, err) - assert.ErrorIs(t, err, entity.ErrExpected) + assert.Equal(t, saga.StageResultCompensated, res.Status) assert.ElementsMatch(t, []string{"action0", "action1", "action2"}, executedActions) assert.ElementsMatch(t, []string{"comp0", "comp1", "comp2"}, executedCompensation) }) diff --git a/test/integration/internal/saga/facade_multi_test.go b/test/integration/internal/saga/facade_multi_test.go index d3462b7..2cea387 100644 --- a/test/integration/internal/saga/facade_multi_test.go +++ b/test/integration/internal/saga/facade_multi_test.go @@ -69,38 +69,38 @@ func Test_Saga_multi_Facade(t *testing.T) { mongoTransactor = mongo.NewTransactor(mongo.NewMongo(mongoClient)) mongoRepo = mongo.NewRepository(mongoCollectionA, mongoTransactor, false) ) - err := saga.NewSaga([]saga.Step{ + res, err := saga.NewSaga([]saga.Step{ { Name: "step_sql_0", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { err := sqlTransactor.WithinTx(ctx, func(ctx context.Context) error { return sqlRepo.Insert(ctx, sqlTextRecord) }) assert.NoError(t, err) return nil }, - Compensation: func(ctx context.Context, _ error) error { + Compensation: func(ctx context.Context, _ saga.Track) error { assert.Fail(t, "should not call (sql)") return nil }, }, { Name: "step_mongo_0", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { err := mongoTransactor.WithinTx(ctx, func(ctx context.Context) error { return mongoRepo.Save(ctx, mongoTestDataValA) }) assert.NoError(t, err) return nil }, - Compensation: func(ctx context.Context, _ error) error { + Compensation: func(ctx context.Context, _ saga.Track) error { assert.Fail(t, "should not call (mongo)") return nil }, }, { Name: "step_check_all", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { records, err := stdlib.GetTextRecords(sqlDB) assert.NoError(t, err) assert.Len(t, records, 1) @@ -116,6 +116,11 @@ func Test_Saga_multi_Facade(t *testing.T) { }).Execute(ctx) assert.NoError(t, err) + assert.Equal(t, saga.StageResultSuccess, res.Status) + + testtool.TestFn(t, func() { + printResult(t, res, err) + }) }) t.Run("success_compensation", func(t *testing.T) { @@ -128,40 +133,41 @@ func Test_Saga_multi_Facade(t *testing.T) { mongoTransactor = mongo.NewTransactor(mongo.NewMongo(mongoClient)) mongoRepo = mongo.NewRepository(mongoCollectionA, mongoTransactor, false) ) - err := saga.NewSaga([]saga.Step{ + res, err := saga.NewSaga([]saga.Step{ { Name: "step_sql_0", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { err := sqlTransactor.WithinTx(ctx, func(ctx context.Context) error { return sqlRepo.Insert(ctx, sqlTextRecord) }) assert.NoError(t, err) - return nil + return err }, - Compensation: func(ctx context.Context, aroseErr error) error { - assert.Error(t, aroseErr) - assert.ErrorIs(t, aroseErr, entity.ErrExpected) + Compensation: func(ctx context.Context, track saga.Track) error { + data := track.GetData() + assert.Len(t, data.Action.Errors, 0) err := sqlTransactor.WithinTx(ctx, func(ctx context.Context) error { return sqlRepo.Delete(ctx, sqlTextRecord) }) assert.NoError(t, err) - return nil + return err }, }, { Name: "step_mongo_0", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { err := mongoTransactor.WithinTx(ctx, func(ctx context.Context) error { return mongoRepo.Save(ctx, mongoTestDataValA) }) assert.NoError(t, err) - return nil + return err }, - Compensation: func(ctx context.Context, aroseErr error) error { - assert.Error(t, aroseErr) - assert.ErrorIs(t, aroseErr, entity.ErrExpected) + Compensation: func(ctx context.Context, track saga.Track) error { + data := track.GetData() + assert.Len(t, data.Action.Errors, 0) + t.Log(data) err := mongoRepo.Delete(ctx, mongoTestDataValA) assert.NoError(t, err) return err @@ -169,7 +175,7 @@ func Test_Saga_multi_Facade(t *testing.T) { }, { Name: "step_check_all", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { records, err := stdlib.GetTextRecords(sqlDB) assert.NoError(t, err) assert.Len(t, records, 1) @@ -184,16 +190,19 @@ func Test_Saga_multi_Facade(t *testing.T) { }, { Name: "step_error", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { return entity.ErrExpected }, }, }).Execute(ctx) assert.Error(t, err) - assert.ErrorIs(t, err, entity.ErrExpected) + assert.ErrorIs(t, err, saga.ErrActionFailed) + assert.Equal(t, saga.StageResultCompensated, res.Status) - testtool.LogError(t, err) + testtool.TestFn(t, func() { + printResult(t, res, err) + }) { records, err := stdlib.GetTextRecords(sqlDB) @@ -209,7 +218,7 @@ func Test_Saga_multi_Facade(t *testing.T) { t.Run("success_compensation_in_single_action", func(t *testing.T) { t.Cleanup(cleanupFn) - t.Log("using `CompensationOnFail` flag") + t.Log("using `CompensationRequired` flag") var ( sqlTransactor = stdlib.NewTransactor(sqlDB) @@ -218,10 +227,10 @@ func Test_Saga_multi_Facade(t *testing.T) { mongoTransactor = mongo.NewTransactor(mongo.NewMongo(mongoClient)) mongoRepo = mongo.NewRepository(mongoCollectionA, mongoTransactor, false) ) - err := saga.NewSaga([]saga.Step{ + res, err := saga.NewSaga([]saga.Step{ { Name: "step_sql_0", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { // The parent [Transactor] which maintain SQL transactions. err := sqlTransactor.WithinTx(ctx, func(ctx context.Context) error { err := sqlRepo.Insert(ctx, sqlTextRecord) @@ -249,8 +258,8 @@ func Test_Saga_multi_Facade(t *testing.T) { return err }, // Need to add current compensation to list of compensations. - CompensationOnFail: true, - Compensation: func(ctx context.Context, aroseErr error) error { + CompensationRequired: true, + Compensation: func(ctx context.Context, track saga.Track) error { // check Mongo entities (commit). data, err := mongo.GetDataByID(ctx, t, mongoCollectionA, mongoTestID) assert.NoError(t, err) @@ -264,7 +273,8 @@ func Test_Saga_multi_Facade(t *testing.T) { // Compensation logic. // // Check an error type and call compensation only for Mongo. - if aroseErr != nil && errors.Is(aroseErr, entity.ErrExpected) { + trackData := track.GetData() + if len(trackData.Action.Errors) > 0 && errors.Is(trackData.Action.Errors[0], entity.ErrExpected) { err = mongoRepo.Delete(ctx, mongoTestDataValA) assert.NoError(t, err) return err @@ -275,9 +285,12 @@ func Test_Saga_multi_Facade(t *testing.T) { }, }).Execute(ctx) - assert.ErrorIs(t, err, entity.ErrExpected) + assert.ErrorIs(t, err, saga.ErrActionFailed) + assert.Equal(t, saga.StageResultCompensated, res.Status) - testtool.LogError(t, err) + testtool.TestFn(t, func() { + printResult(t, res, err) + }) { records, err := stdlib.GetTextRecords(sqlDB) diff --git a/test/integration/internal/saga/facade_stdlib_test.go b/test/integration/internal/saga/facade_stdlib_test.go index 8b5cdf6..502daef 100644 --- a/test/integration/internal/saga/facade_stdlib_test.go +++ b/test/integration/internal/saga/facade_stdlib_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/kozmod/oniontx/internal/testtool" - "github.com/kozmod/oniontx/mtx" "github.com/kozmod/oniontx/saga" "github.com/kozmod/oniontx/test/integration/internal/entity" "github.com/kozmod/oniontx/test/integration/internal/stdlib" @@ -41,19 +40,20 @@ func Test_Saga_stdlib_Facade(t *testing.T) { repoA = stdlib.NewTextRepository(transactor, false) repoB = stdlib.NewTextRepository(transactor, true) ) - err := saga.NewSaga([]saga.Step{ + res, err := saga.NewSaga([]saga.Step{ { Name: "step_0", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { err := transactor.WithinTx(ctx, func(ctx context.Context) error { return repoA.Insert(ctx, textRecord) }) assert.NoError(t, err) return nil }, - Compensation: func(ctx context.Context, aroseErr error) error { - assert.Error(t, aroseErr) - assert.ErrorIs(t, aroseErr, entity.ErrExpected) + Compensation: func(ctx context.Context, track saga.Track) error { + //data := track.GetData() + //assert.Len(t, data.Action.Errors, 1) + //assert.ErrorIs(t, data.Action.Errors[0], entity.ErrExpected) err := repoA.Delete(ctx, textRecord) assert.NoError(t, err) @@ -62,7 +62,7 @@ func Test_Saga_stdlib_Facade(t *testing.T) { }, { Name: "step_1", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { records, err := stdlib.GetTextRecords(db) assert.NoError(t, err) assert.Len(t, records, 1) @@ -73,7 +73,7 @@ func Test_Saga_stdlib_Facade(t *testing.T) { }, { Name: "step_2", - Action: func(ctx context.Context) error { + Action: func(ctx context.Context, _ saga.Track) error { err := transactor.WithinTx(ctx, func(ctx context.Context) error { err := repoA.Insert(ctx, textRecord) if err != nil { @@ -96,13 +96,26 @@ func Test_Saga_stdlib_Facade(t *testing.T) { }).Execute(ctx) assert.Error(t, err) + assert.ErrorIs(t, err, saga.ErrActionFailed) - testtool.LogError(t, err) + assert.Equal(t, saga.StageResultCompensated, res.Status) + assert.Len(t, res.Steps, 3) - assert.ErrorIs(t, err, entity.ErrExpected) - assert.ErrorIs(t, err, saga.ErrActionFailed) - assert.ErrorIs(t, err, mtx.ErrRollbackSuccess) - assert.ErrorIs(t, err, saga.ErrCompensationSuccess) + assert.Equal(t, saga.ExecutionStatusSuccess, res.Steps[0].Action.Status) + assert.Equal(t, saga.ExecutionStatusSuccess, res.Steps[0].Compensation.Status) + + assert.Equal(t, saga.ExecutionStatusSuccess, res.Steps[1].Action.Status) + assert.Equal(t, saga.ExecutionStatusUnset, res.Steps[1].Compensation.Status) + + assert.Equal(t, saga.ExecutionStatusFail, res.Steps[2].Action.Status) + assert.Equal(t, saga.ExecutionStatusUnset, res.Steps[2].Compensation.Status) + + assert.Len(t, res.Steps[2].Action.Errors, 1) + assert.ErrorIs(t, res.Steps[2].Action.Errors[0], saga.ErrActionFailed) + + testtool.TestFn(t, func() { + printResult(t, res, err) + }) { records, err := stdlib.GetTextRecords(db) diff --git a/test/integration/internal/saga/helper.go b/test/integration/internal/saga/helper.go new file mode 100644 index 0000000..2d7d11f --- /dev/null +++ b/test/integration/internal/saga/helper.go @@ -0,0 +1,47 @@ +package saga + +import ( + "fmt" + "testing" + + "github.com/kozmod/oniontx/saga" +) + +func printResult(t *testing.T, res saga.Result, err error) { + t.Helper() + t.Logf("\nresult:\n%v", res) + fmt.Printf("\nexecution error: %v\n", err) + + for _, step := range res.Steps { + fmt.Printf("-----") + fmt.Printf("\nstep [%d#%s]:\n", step.StepPosition, step.StepName) + + switch { + case len(step.Action.Errors) > 0: + fmt.Printf(" action errors (%d):\n", len(step.Action.Errors)) + for i, e := range step.Action.Errors { + fmt.Printf(" %d: %v\n", i, e) + } + default: + fmt.Printf(" action errors: none\n") + } + + switch { + case len(step.Compensation.Errors) > 0: + fmt.Printf(" compensation errors (%d):\n", len(step.Compensation.Errors)) + for i, e := range step.Compensation.Errors { + fmt.Printf(" %d: %v\n", i, e) + } + default: + fmt.Printf(" compensation errors: none\n") + } + + fmt.Printf(" -----\n") + fmt.Printf(" action status: %v\n", step.Action.Status) + fmt.Printf(" compensation status: %v\n", step.Compensation.Status) + fmt.Printf(" compensation required: %v\n", step.CompensationRequired) + fmt.Printf(" action calls: %d\n", step.Action.Calls) + fmt.Printf(" compensation calls: %d\n", step.Compensation.Calls) + fmt.Printf("-----\n") + } +}