From 69d15aa5cc348d4c9a4573f7776e5468a214b073 Mon Sep 17 00:00:00 2001 From: Jared Jakacky Date: Tue, 19 May 2026 12:31:13 -0500 Subject: [PATCH] fix(runtime): tighten shutdown and failure events --- runtime.go | 13 +++++- runtime_test.go | 76 ++++++++++++++++++++++++++++++++- servekitservice/service_test.go | 60 ++++++++++++++++++++++++++ 3 files changed, 147 insertions(+), 2 deletions(-) diff --git a/runtime.go b/runtime.go index dcfaf35..82833cc 100644 --- a/runtime.go +++ b/runtime.go @@ -471,9 +471,17 @@ func (r *Runtime) Shutdown(ctx context.Context) error { if err := r.DrainAllBestEffort(ctx); err != nil { errs = append(errs, err) } + if err := ctx.Err(); err != nil { + errs = append(errs, err) + return errors.Join(errs...) + } if err := r.WaitAllIdle(ctx); err != nil { errs = append(errs, err) } + if err := ctx.Err(); err != nil { + errs = append(errs, err) + return errors.Join(errs...) + } if err := r.StopAll(ctx); err != nil { errs = append(errs, err) } @@ -1022,6 +1030,7 @@ func (r *Runtime) failWorkerWithCommandAttempt(ctx context.Context, name, comman state := r.workerStates[name] before := state + suppressFailureEvent := command == "" && before.lifecycle == StateFailed && before.lastFailure != nil now := time.Now() state.lastTransition = &LifecycleTransition{ From: state.lifecycle, @@ -1044,7 +1053,9 @@ func (r *Runtime) failWorkerWithCommandAttempt(ctx context.Context, name, comman r.observeTransition(ctx, obs.worker, obs.from, obs.to, obs.at) } r.observeReadinessChange(ctx, obs.worker, obs.beforeReady, obs.afterReady, obs.at) - r.observeFailure(ctx, name, command, err, panicked, now, dispatchID, attempt) + if !suppressFailureEvent { + r.observeFailure(ctx, name, command, err, panicked, now, dispatchID, attempt) + } r.observeRuntimeChanges(ctx, obs.before, obs.after) } diff --git a/runtime_test.go b/runtime_test.go index 3dcd7d7..d13c459 100644 --- a/runtime_test.go +++ b/runtime_test.go @@ -235,7 +235,7 @@ func TestLifecycleAttemptFailureDoesNotEmitDuplicateFailureEvent(t *testing.T) { } } -func TestFailedWorkerFailureDoesNotEmitNoopTransition(t *testing.T) { +func TestFailedWorkerFailureDoesNotEmitDuplicateLifecycleFailureEvent(t *testing.T) { observer := &recordingObserver{} reportedFailure := errors.New("reported during start") returnedFailure := errors.New("returned after report") @@ -263,6 +263,14 @@ func TestFailedWorkerFailureDoesNotEmitNoopTransition(t *testing.T) { if !errors.Is(err, returnedFailure) { t.Fatalf("Start error = %v, want %v", err, returnedFailure) } + snapshot := requireWorker(t, rt, "worker") + status := snapshot.Status + if status.State != StateFailed { + t.Fatalf("worker state = %s, want %s", status.State, StateFailed) + } + if status.LastFailure == nil || status.LastFailure.Message != returnedFailure.Error() { + t.Fatalf("LastFailure = %#v, want %q", status.LastFailure, returnedFailure.Error()) + } var failedToFailed int for _, event := range observer.transitions { @@ -273,6 +281,15 @@ func TestFailedWorkerFailureDoesNotEmitNoopTransition(t *testing.T) { if failedToFailed != 0 { t.Fatalf("failed -> failed worker transitions = %d, want 0", failedToFailed) } + if got := len(observer.failures); got != 1 { + t.Fatalf("failure events = %d, want 1", got) + } + if !errors.Is(observer.failures[0].Err, reportedFailure) { + t.Fatalf("failure event error = %v, want %v", observer.failures[0].Err, reportedFailure) + } + if observer.failures[0].Command != "" { + t.Fatalf("failure event command = %q, want empty", observer.failures[0].Command) + } } func TestCommandRetryEmitsFailurePerFailedAttemptAndSuccessfulCommandEnd(t *testing.T) { @@ -663,6 +680,63 @@ func TestShutdownDrainsWaitsForInFlightCommandsAndStopsInReverseOrder(t *testing } } +func TestShutdownTimeoutBeforeStopAllDoesNotFailWorker(t *testing.T) { + entered := make(chan struct{}) + release := make(chan struct{}) + var stops int + rt := newTestRuntime(t) + if err := rt.Register( + WorkerSpec{ + Name: "worker", + Worker: testWorker{ + stop: func(ctx context.Context) error { + stops++ + return ctx.Err() + }, + }, + }, + WithCommand("block", CommandHandlerFunc(func(context.Context, CommandRequest) (CommandResult, error) { + close(entered) + <-release + return CommandResult{}, nil + })), + ); err != nil { + t.Fatalf("Register returned error: %v", err) + } + if err := rt.Start(context.Background(), "worker"); err != nil { + t.Fatalf("Start returned error: %v", err) + } + + dispatchDone := make(chan error, 1) + go func() { + _, err := rt.Dispatch(context.Background(), CommandRequest{Worker: "worker", Name: "block"}) + dispatchDone <- err + }() + <-entered + + shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + err := rt.Shutdown(shutdownCtx) + close(release) + if dispatchErr := <-dispatchDone; dispatchErr != nil { + t.Fatalf("Dispatch returned error: %v", dispatchErr) + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Shutdown error = %v, want DeadlineExceeded", err) + } + if stops != 0 { + t.Fatalf("stop calls = %d, want 0", stops) + } + snapshot := requireWorker(t, rt, "worker") + status := snapshot.Status + if status.State != StateDraining { + t.Fatalf("worker state = %s, want %s", status.State, StateDraining) + } + if status.LastFailure != nil { + t.Fatalf("LastFailure = %#v, want nil", status.LastFailure) + } +} + func TestDrainMarksWorkerUnreadyAndRejectsDispatch(t *testing.T) { rt := newTestRuntime(t) if err := rt.Register( diff --git a/servekitservice/service_test.go b/servekitservice/service_test.go index f2b5573..b9e7596 100644 --- a/servekitservice/service_test.go +++ b/servekitservice/service_test.go @@ -361,6 +361,66 @@ func TestShutdownWorkersUsesStopFallbackAfterIdleWaitTimeout(t *testing.T) { } } +func TestShutdownWorkersFallbackDoesNotLeaveDeadlineFailure(t *testing.T) { + t.Parallel() + + entered := make(chan struct{}) + release := make(chan struct{}) + var stops atomic.Int32 + rt := newTestRuntime(t) + if err := rt.Register( + workerkit.WorkerSpec{ + Name: "worker", + Worker: testWorker{ + stop: func(ctx context.Context) error { + stops.Add(1) + return ctx.Err() + }, + }, + }, + workerkit.WithCommand("block", workerkit.CommandHandlerFunc(func(context.Context, workerkit.CommandRequest) (workerkit.CommandResult, error) { + close(entered) + <-release + return workerkit.CommandResult{}, nil + })), + ); err != nil { + t.Fatalf("Register returned error: %v", err) + } + if err := rt.Start(context.Background(), "worker"); err != nil { + t.Fatalf("Start returned error: %v", err) + } + dispatchDone := make(chan error, 1) + go func() { + _, err := rt.Dispatch(context.Background(), workerkit.CommandRequest{Worker: "worker", Name: "block"}) + dispatchDone <- err + }() + <-entered + + service := newTestService(t, rt, WithShutdownTimeout(time.Millisecond)) + err := service.shutdownWorkers(context.Background()) + close(release) + if dispatchErr := <-dispatchDone; dispatchErr != nil { + t.Fatalf("Dispatch returned error: %v", dispatchErr) + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("shutdownWorkers error = %v, want DeadlineExceeded", err) + } + if got := stops.Load(); got != 1 { + t.Fatalf("stop calls = %d, want 1", got) + } + snapshot, ok := rt.Worker("worker") + if !ok { + t.Fatal("worker missing") + } + status := snapshot.Status + if status.State != workerkit.StateStopped { + t.Fatalf("worker state = %s, want %s", status.State, workerkit.StateStopped) + } + if status.LastFailure != nil { + t.Fatalf("LastFailure = %#v, want nil", status.LastFailure) + } +} + func TestRunRejectsInvalidService(t *testing.T) { t.Parallel()