Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 65 additions & 18 deletions pgxlisten.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"sync"
"time"

"github.com/jackc/pgx/v5"
Expand All @@ -27,16 +28,48 @@ type Listener struct {
// is lost. If set to 0, the default of 1 minute is used. A negative value disables the timeout entirely.
ReconnectDelay time.Duration

handlers map[string]Handler
handlers map[string]Handler
mux sync.RWMutex
hasStarted bool
conn *pgx.Conn
}

// Handle sets the handler for notifications sent to channel.
func (l *Listener) Handle(channel string, handler Handler) {
func (l *Listener) Handle(ctx context.Context, channel string, handler Handler) error {
l.mux.Lock()
defer l.mux.Unlock()
if l.handlers == nil {
l.handlers = make(map[string]Handler)
}

_, ok := l.handlers[channel]
l.handlers[channel] = handler
if l.hasStarted && !ok {
err := l.listenandbacklog(ctx, channel, handler)
return err
}
return nil
}

// Unhandle removes the handler for notifications sent to channel.
func (l *Listener) Unhandle(ctx context.Context, channel string) (bool, error) {
l.mux.Lock()
defer l.mux.Unlock()
if l.handlers == nil {
return false, nil
}
_, ok := l.handlers[channel]
if !ok {
return false, nil
}
delete(l.handlers, channel)
if l.hasStarted {
_, err := l.conn.Exec(ctx, "unlisten "+pgx.Identifier{channel}.Sanitize())
if err != nil {
return true, err
}
}
return true, nil
}

// Listen listens for and handles notifications. It will only return when ctx is cancelled or a fatal error occurs.
Expand Down Expand Up @@ -80,24 +113,34 @@ func (l *Listener) Listen(ctx context.Context) error {
}
}

func (l *Listener) listenandbacklog(ctx context.Context, channel string, handler Handler) error {
_, err := l.conn.Exec(ctx, "listen "+pgx.Identifier{channel}.Sanitize())
if err != nil {
return fmt.Errorf("listen %q: %w", channel, err)
}

if backlogHandler, ok := handler.(BacklogHandler); ok {
err := backlogHandler.HandleBacklog(ctx, channel, l.conn)
if err != nil {
l.logError(ctx, fmt.Errorf("handle backlog %q: %w", channel, err))
}
}
return nil
}

func (l *Listener) listen(ctx context.Context) error {
conn, err := l.Connect(ctx)
if err != nil {
return fmt.Errorf("connect: %w", err)
}
defer conn.Close(ctx)
l.hasStarted = true
l.conn = conn

for channel, handler := range l.handlers {
_, err := conn.Exec(ctx, "listen "+pgx.Identifier{channel}.Sanitize())
err = l.listenandbacklog(ctx, channel, handler)
if err != nil {
return fmt.Errorf("listen %q: %w", channel, err)
}

if backlogHandler, ok := handler.(BacklogHandler); ok {
err := backlogHandler.HandleBacklog(ctx, channel, conn)
if err != nil {
l.logError(ctx, fmt.Errorf("handle backlog %q: %w", channel, err))
}
return err
}
}

Expand All @@ -107,14 +150,18 @@ func (l *Listener) listen(ctx context.Context) error {
return fmt.Errorf("waiting for notification: %w", err)
}

if handler, ok := l.handlers[notification.Channel]; ok {
err := handler.HandleNotification(ctx, notification, conn)
if err != nil {
l.logError(ctx, fmt.Errorf("handle %s notification: %w", notification.Channel, err))
func() {
l.mux.RLock()
defer l.mux.RUnlock()
if handler, ok := l.handlers[notification.Channel]; ok {
err := handler.HandleNotification(ctx, notification, conn)
if err != nil {
l.logError(ctx, fmt.Errorf("handle %s notification: %w", notification.Channel, err))
}
} else {
l.logError(ctx, fmt.Errorf("missing handler: %s", notification.Channel))
}
} else {
l.logError(ctx, fmt.Errorf("missing handler: %s", notification.Channel))
}
}()
}

}
Expand Down
10 changes: 3 additions & 7 deletions pgxlisten_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ func TestListenerListenDispatchesNotifications(t *testing.T) {
fooChan := make(chan *pgconn.Notification)
barChan := make(chan *pgconn.Notification)

listener.Handle("foo", pgxlisten.HandlerFunc(func(ctx context.Context, notification *pgconn.Notification, conn *pgx.Conn) error {
listener.Handle(ctx, "foo", pgxlisten.HandlerFunc(func(ctx context.Context, notification *pgconn.Notification, conn *pgx.Conn) error {
select {
case fooChan <- notification:
case <-ctx.Done():
}
return nil
}))

listener.Handle("bar", pgxlisten.HandlerFunc(func(ctx context.Context, notification *pgconn.Notification, conn *pgx.Conn) error {
listener.Handle(ctx, "bar", pgxlisten.HandlerFunc(func(ctx context.Context, notification *pgconn.Notification, conn *pgx.Conn) error {
select {
case barChan <- notification:
case <-ctx.Done():
Expand Down Expand Up @@ -164,7 +164,7 @@ create table pgxlisten_test (id int primary key generated by default as identity
ch: fooChan,
}

listener.Handle("foo", handler)
listener.Handle(ctx, "foo", handler)

listenerCtx, listenerCtxCancel := context.WithCancel(ctx)
defer listenerCtxCancel()
Expand All @@ -178,10 +178,6 @@ create table pgxlisten_test (id int primary key generated by default as identity
// No way to know when Listener is ready so wait a little.
time.Sleep(2 * time.Second)

type notificationTest struct {
payload string
}

notificationMsgs := []string{"d", "e"}

// Send all notifications.
Expand Down