Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions message/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ type Subscriber interface {
Close() error
}

// Stoppable is an optional interface that a Subscriber may implement to support
// graceful shutdown. When Stop is called, the subscriber should stop delivering
// new messages but continue to allow in-flight messages to be Acked or Nacked.
//
// The Router checks for this interface during shutdown. If the subscriber
// implements Stoppable, the Router calls Stop first, waits for running handlers
// to finish, and then calls Close.
type Stoppable interface {
Stop() error
}

// SubscribeInitializer is used to initialize subscribers.
type SubscribeInitializer interface {
// SubscribeInitialize can be called to initialize subscribe before consume.
Expand Down
27 changes: 18 additions & 9 deletions message/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,10 @@ func (r *Router) AddHandler(
name: handlerName,
logger: r.logger,

subscriber: subscriber,
subscribeTopic: subscribeTopic,
subscriberName: subscriberName,
subscriber: subscriber,
originalSubscriber: subscriber,
subscribeTopic: subscribeTopic,
subscriberName: subscriberName,

publisher: publisher,
publishTopic: publishTopic,
Expand Down Expand Up @@ -619,9 +620,10 @@ type handler struct {
name string
logger watermill.LoggerAdapter

subscriber Subscriber
subscribeTopic string
subscriberName string
subscriber Subscriber
originalSubscriber Subscriber // before decoration, used for Stoppable check
subscribeTopic string
subscriberName string

publisher Publisher
publishTopic string
Expand Down Expand Up @@ -790,14 +792,21 @@ func (h *handler) addHandlerContext(messages ...*Message) {
func (h *handler) handleClose(ctx context.Context) {
select {
case <-h.routersCloseCh:
// for backward compatibility we are closing subscriber
h.logger.Debug("Waiting for subscriber to close", nil)
if stoppable, ok := h.originalSubscriber.(Stoppable); ok {
h.logger.Debug("Stopping subscriber", nil)
if err := stoppable.Stop(); err != nil {
h.logger.Error("Failed to stop subscriber", err, nil)
}
h.logger.Debug("Subscriber stopped, waiting for in-flight messages", nil)
h.runningHandlersWg.Wait()
}

h.logger.Debug("Closing subscriber", nil)
if err := h.subscriber.Close(); err != nil {
h.logger.Error("Failed to close subscriber", err, nil)
}
h.logger.Debug("Subscriber closed", nil)
case <-ctx.Done():
// we are closing subscriber just when entire router is closed
}
h.stopFn()
}
Expand Down
122 changes: 122 additions & 0 deletions message/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1498,3 +1498,125 @@ func TestRouter_stopping_all_handlers_logs_error(t *testing.T) {
logger.Captured(),
)
}

// stoppableSubscriber wraps a message.Subscriber and records Stop/Close call order.
type stoppableSubscriber struct {
message.Subscriber
mu sync.Mutex
calls []string
stopErr error
}

func (s *stoppableSubscriber) Stop() error {
s.mu.Lock()
s.calls = append(s.calls, "stop")
s.mu.Unlock()
return s.stopErr
}

func (s *stoppableSubscriber) Close() error {
s.mu.Lock()
s.calls = append(s.calls, "close")
s.mu.Unlock()
return s.Subscriber.Close()
}

func TestRouter_stoppable_subscriber_stop_called_before_close(t *testing.T) {
t.Parallel()

pubSub := gochannel.NewGoChannel(
gochannel.Config{Persistent: true},
watermill.NewStdLogger(true, true),
)

sub := &stoppableSubscriber{Subscriber: pubSub}
logger := watermill.NewCaptureLogger()

r, err := message.NewRouter(message.RouterConfig{}, logger)
require.NoError(t, err)

handlerDone := make(chan struct{})

r.AddConsumerHandler(
"foo",
"subscribe_topic",
sub,
func(msg *message.Message) error {
close(handlerDone)
return nil
},
)

go func() {
err := r.Run(context.Background())
assert.NoError(t, err)
}()
<-r.Running()

err = pubSub.Publish("subscribe_topic", message.NewMessage(watermill.NewUUID(), nil))
require.NoError(t, err)

<-handlerDone

assert.NoError(t, r.Close())

sub.mu.Lock()
defer sub.mu.Unlock()

require.Len(t, sub.calls, 2)
assert.Equal(t, "stop", sub.calls[0])
assert.Equal(t, "close", sub.calls[1])
}

func TestRouter_stoppable_subscriber_stop_error_does_not_prevent_close(t *testing.T) {
t.Parallel()

pubSub := gochannel.NewGoChannel(
gochannel.Config{Persistent: true},
watermill.NewStdLogger(true, true),
)

sub := &stoppableSubscriber{
Subscriber: pubSub,
stopErr: fmt.Errorf("stop failed"),
}
logger := watermill.NewCaptureLogger()

r, err := message.NewRouter(message.RouterConfig{}, logger)
require.NoError(t, err)

handlerDone := make(chan struct{})

r.AddConsumerHandler(
"foo",
"subscribe_topic",
sub,
func(msg *message.Message) error {
close(handlerDone)
return nil
},
)

go func() {
err := r.Run(context.Background())
assert.NoError(t, err)
}()
<-r.Running()

err = pubSub.Publish("subscribe_topic", message.NewMessage(watermill.NewUUID(), nil))
require.NoError(t, err)

<-handlerDone

assert.NoError(t, r.Close())

sub.mu.Lock()
defer sub.mu.Unlock()

// Stop() error should not prevent Close() from being called.
require.Len(t, sub.calls, 2)
assert.Equal(t, "stop", sub.calls[0])
assert.Equal(t, "close", sub.calls[1])

assert.True(t, logger.HasError(sub.stopErr))
}