diff --git a/message/pubsub.go b/message/pubsub.go index 258e1767e..75f874075 100644 --- a/message/pubsub.go +++ b/message/pubsub.go @@ -38,6 +38,17 @@ type Subscriber interface { Close() error } +// Stoppable is an optional interface that a Subscriber may implement to support +// graceful shutdown. When Stop is called, the subscriber should stop delivering +// new messages but continue to allow in-flight messages to be Acked or Nacked. +// +// The Router checks for this interface during shutdown. If the subscriber +// implements Stoppable, the Router calls Stop first, waits for running handlers +// to finish, and then calls Close. +type Stoppable interface { + Stop() error +} + // SubscribeInitializer is used to initialize subscribers. type SubscribeInitializer interface { // SubscribeInitialize can be called to initialize subscribe before consume. diff --git a/message/router.go b/message/router.go index b5c1a0b7d..e6d5118bf 100644 --- a/message/router.go +++ b/message/router.go @@ -295,9 +295,10 @@ func (r *Router) AddHandler( name: handlerName, logger: r.logger, - subscriber: subscriber, - subscribeTopic: subscribeTopic, - subscriberName: subscriberName, + subscriber: subscriber, + originalSubscriber: subscriber, + subscribeTopic: subscribeTopic, + subscriberName: subscriberName, publisher: publisher, publishTopic: publishTopic, @@ -619,9 +620,10 @@ type handler struct { name string logger watermill.LoggerAdapter - subscriber Subscriber - subscribeTopic string - subscriberName string + subscriber Subscriber + originalSubscriber Subscriber // before decoration, used for Stoppable check + subscribeTopic string + subscriberName string publisher Publisher publishTopic string @@ -790,14 +792,21 @@ func (h *handler) addHandlerContext(messages ...*Message) { func (h *handler) handleClose(ctx context.Context) { select { case <-h.routersCloseCh: - // for backward compatibility we are closing subscriber - h.logger.Debug("Waiting for subscriber to close", nil) + if stoppable, ok := h.originalSubscriber.(Stoppable); ok { + h.logger.Debug("Stopping subscriber", nil) + if err := stoppable.Stop(); err != nil { + h.logger.Error("Failed to stop subscriber", err, nil) + } + h.logger.Debug("Subscriber stopped, waiting for in-flight messages", nil) + h.runningHandlersWg.Wait() + } + + h.logger.Debug("Closing subscriber", nil) if err := h.subscriber.Close(); err != nil { h.logger.Error("Failed to close subscriber", err, nil) } h.logger.Debug("Subscriber closed", nil) case <-ctx.Done(): - // we are closing subscriber just when entire router is closed } h.stopFn() } diff --git a/message/router_test.go b/message/router_test.go index 4754f8618..9c2500ea1 100644 --- a/message/router_test.go +++ b/message/router_test.go @@ -1498,3 +1498,125 @@ func TestRouter_stopping_all_handlers_logs_error(t *testing.T) { logger.Captured(), ) } + +// stoppableSubscriber wraps a message.Subscriber and records Stop/Close call order. +type stoppableSubscriber struct { + message.Subscriber + mu sync.Mutex + calls []string + stopErr error +} + +func (s *stoppableSubscriber) Stop() error { + s.mu.Lock() + s.calls = append(s.calls, "stop") + s.mu.Unlock() + return s.stopErr +} + +func (s *stoppableSubscriber) Close() error { + s.mu.Lock() + s.calls = append(s.calls, "close") + s.mu.Unlock() + return s.Subscriber.Close() +} + +func TestRouter_stoppable_subscriber_stop_called_before_close(t *testing.T) { + t.Parallel() + + pubSub := gochannel.NewGoChannel( + gochannel.Config{Persistent: true}, + watermill.NewStdLogger(true, true), + ) + + sub := &stoppableSubscriber{Subscriber: pubSub} + logger := watermill.NewCaptureLogger() + + r, err := message.NewRouter(message.RouterConfig{}, logger) + require.NoError(t, err) + + handlerDone := make(chan struct{}) + + r.AddConsumerHandler( + "foo", + "subscribe_topic", + sub, + func(msg *message.Message) error { + close(handlerDone) + return nil + }, + ) + + go func() { + err := r.Run(context.Background()) + assert.NoError(t, err) + }() + <-r.Running() + + err = pubSub.Publish("subscribe_topic", message.NewMessage(watermill.NewUUID(), nil)) + require.NoError(t, err) + + <-handlerDone + + assert.NoError(t, r.Close()) + + sub.mu.Lock() + defer sub.mu.Unlock() + + require.Len(t, sub.calls, 2) + assert.Equal(t, "stop", sub.calls[0]) + assert.Equal(t, "close", sub.calls[1]) +} + +func TestRouter_stoppable_subscriber_stop_error_does_not_prevent_close(t *testing.T) { + t.Parallel() + + pubSub := gochannel.NewGoChannel( + gochannel.Config{Persistent: true}, + watermill.NewStdLogger(true, true), + ) + + sub := &stoppableSubscriber{ + Subscriber: pubSub, + stopErr: fmt.Errorf("stop failed"), + } + logger := watermill.NewCaptureLogger() + + r, err := message.NewRouter(message.RouterConfig{}, logger) + require.NoError(t, err) + + handlerDone := make(chan struct{}) + + r.AddConsumerHandler( + "foo", + "subscribe_topic", + sub, + func(msg *message.Message) error { + close(handlerDone) + return nil + }, + ) + + go func() { + err := r.Run(context.Background()) + assert.NoError(t, err) + }() + <-r.Running() + + err = pubSub.Publish("subscribe_topic", message.NewMessage(watermill.NewUUID(), nil)) + require.NoError(t, err) + + <-handlerDone + + assert.NoError(t, r.Close()) + + sub.mu.Lock() + defer sub.mu.Unlock() + + // Stop() error should not prevent Close() from being called. + require.Len(t, sub.calls, 2) + assert.Equal(t, "stop", sub.calls[0]) + assert.Equal(t, "close", sub.calls[1]) + + assert.True(t, logger.HasError(sub.stopErr)) +}