diff --git a/pgxlisten.go b/pgxlisten.go index 5b8d1f0..16b9f00 100644 --- a/pgxlisten.go +++ b/pgxlisten.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "sync" "time" "github.com/jackc/pgx/v5" @@ -27,16 +28,48 @@ type Listener struct { // is lost. If set to 0, the default of 1 minute is used. A negative value disables the timeout entirely. ReconnectDelay time.Duration - handlers map[string]Handler + handlers map[string]Handler + mux sync.RWMutex + hasStarted bool + conn *pgx.Conn } // Handle sets the handler for notifications sent to channel. -func (l *Listener) Handle(channel string, handler Handler) { +func (l *Listener) Handle(ctx context.Context, channel string, handler Handler) error { + l.mux.Lock() + defer l.mux.Unlock() if l.handlers == nil { l.handlers = make(map[string]Handler) } + _, ok := l.handlers[channel] l.handlers[channel] = handler + if l.hasStarted && !ok { + err := l.listenandbacklog(ctx, channel, handler) + return err + } + return nil +} + +// Unhandle removes the handler for notifications sent to channel. +func (l *Listener) Unhandle(ctx context.Context, channel string) (bool, error) { + l.mux.Lock() + defer l.mux.Unlock() + if l.handlers == nil { + return false, nil + } + _, ok := l.handlers[channel] + if !ok { + return false, nil + } + delete(l.handlers, channel) + if l.hasStarted { + _, err := l.conn.Exec(ctx, "unlisten "+pgx.Identifier{channel}.Sanitize()) + if err != nil { + return true, err + } + } + return true, nil } // Listen listens for and handles notifications. It will only return when ctx is cancelled or a fatal error occurs. @@ -80,24 +113,34 @@ func (l *Listener) Listen(ctx context.Context) error { } } +func (l *Listener) listenandbacklog(ctx context.Context, channel string, handler Handler) error { + _, err := l.conn.Exec(ctx, "listen "+pgx.Identifier{channel}.Sanitize()) + if err != nil { + return fmt.Errorf("listen %q: %w", channel, err) + } + + if backlogHandler, ok := handler.(BacklogHandler); ok { + err := backlogHandler.HandleBacklog(ctx, channel, l.conn) + if err != nil { + l.logError(ctx, fmt.Errorf("handle backlog %q: %w", channel, err)) + } + } + return nil +} + func (l *Listener) listen(ctx context.Context) error { conn, err := l.Connect(ctx) if err != nil { return fmt.Errorf("connect: %w", err) } defer conn.Close(ctx) + l.hasStarted = true + l.conn = conn for channel, handler := range l.handlers { - _, err := conn.Exec(ctx, "listen "+pgx.Identifier{channel}.Sanitize()) + err = l.listenandbacklog(ctx, channel, handler) if err != nil { - return fmt.Errorf("listen %q: %w", channel, err) - } - - if backlogHandler, ok := handler.(BacklogHandler); ok { - err := backlogHandler.HandleBacklog(ctx, channel, conn) - if err != nil { - l.logError(ctx, fmt.Errorf("handle backlog %q: %w", channel, err)) - } + return err } } @@ -107,14 +150,18 @@ func (l *Listener) listen(ctx context.Context) error { return fmt.Errorf("waiting for notification: %w", err) } - if handler, ok := l.handlers[notification.Channel]; ok { - err := handler.HandleNotification(ctx, notification, conn) - if err != nil { - l.logError(ctx, fmt.Errorf("handle %s notification: %w", notification.Channel, err)) + func() { + l.mux.RLock() + defer l.mux.RUnlock() + if handler, ok := l.handlers[notification.Channel]; ok { + err := handler.HandleNotification(ctx, notification, conn) + if err != nil { + l.logError(ctx, fmt.Errorf("handle %s notification: %w", notification.Channel, err)) + } + } else { + l.logError(ctx, fmt.Errorf("missing handler: %s", notification.Channel)) } - } else { - l.logError(ctx, fmt.Errorf("missing handler: %s", notification.Channel)) - } + }() } } diff --git a/pgxlisten_test.go b/pgxlisten_test.go index 78bd8e3..e418a73 100644 --- a/pgxlisten_test.go +++ b/pgxlisten_test.go @@ -29,7 +29,7 @@ func TestListenerListenDispatchesNotifications(t *testing.T) { fooChan := make(chan *pgconn.Notification) barChan := make(chan *pgconn.Notification) - listener.Handle("foo", pgxlisten.HandlerFunc(func(ctx context.Context, notification *pgconn.Notification, conn *pgx.Conn) error { + listener.Handle(ctx, "foo", pgxlisten.HandlerFunc(func(ctx context.Context, notification *pgconn.Notification, conn *pgx.Conn) error { select { case fooChan <- notification: case <-ctx.Done(): @@ -37,7 +37,7 @@ func TestListenerListenDispatchesNotifications(t *testing.T) { return nil })) - listener.Handle("bar", pgxlisten.HandlerFunc(func(ctx context.Context, notification *pgconn.Notification, conn *pgx.Conn) error { + listener.Handle(ctx, "bar", pgxlisten.HandlerFunc(func(ctx context.Context, notification *pgconn.Notification, conn *pgx.Conn) error { select { case barChan <- notification: case <-ctx.Done(): @@ -164,7 +164,7 @@ create table pgxlisten_test (id int primary key generated by default as identity ch: fooChan, } - listener.Handle("foo", handler) + listener.Handle(ctx, "foo", handler) listenerCtx, listenerCtxCancel := context.WithCancel(ctx) defer listenerCtxCancel() @@ -178,10 +178,6 @@ create table pgxlisten_test (id int primary key generated by default as identity // No way to know when Listener is ready so wait a little. time.Sleep(2 * time.Second) - type notificationTest struct { - payload string - } - notificationMsgs := []string{"d", "e"} // Send all notifications.