Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions pgutils/connector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package pgutils

import (
"context"
"errors"
"fmt"
"log"
"net/url"

"database/sql"
"database/sql/driver"

"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)

type baseConnectionStringProvider interface {
getBaseConnectionString(ctx context.Context) (string, error)
}

type PostgresqlConnector struct {
baseConnectionStringProvider
searchPath string
}

func (conn *PostgresqlConnector) WithSearchPath(searchPath string) *PostgresqlConnector {
return &PostgresqlConnector{
baseConnectionStringProvider: conn.baseConnectionStringProvider,
searchPath: searchPath,
}
}

func (conn *PostgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
dsn, err := conn.GetConnectionString(ctx)
if err != nil {
return nil, fmt.Errorf("get connection string: %w", err)
}
pqConnector, err := pq.NewConnector(dsn)
if err != nil {
return nil, fmt.Errorf("create pq connector: %w", err)
}

return pqConnector.Connect(ctx)
}

func (conn *PostgresqlConnector) GetConnectionString(ctx context.Context) (string, error) {
dsn, err := conn.getBaseConnectionString(ctx)
if err != nil {
return "", fmt.Errorf("get base connection string: %w", err)
}
if conn.searchPath == "" {
return dsn, nil
}

// Add search path
u, err := url.Parse(dsn)
if err != nil {
return "", fmt.Errorf("parse DSN URL: %w", err)
}
q := u.Query()
if v := q.Get("search_path"); v != "" {
return "", fmt.Errorf("search_path already set to %q", v)
}
q.Set("search_path", conn.searchPath) // url.Values will percent-encode commas as needed
u.RawQuery = q.Encode()
return u.String(), nil
}

func (c *PostgresqlConnector) Driver() driver.Driver {
return &pq.Driver{}
}

type staticConnectionStringProvider struct {
connectionString string
}

func (p *staticConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) {
return p.connectionString, nil
}

func NewPostgresqlConnectorFromConnectionString(connectionString string) *PostgresqlConnector {
return &PostgresqlConnector{
baseConnectionStringProvider: &staticConnectionStringProvider{connectionString},
}
}

type IAMAuthConfig struct {
RDSEndpoint string
User string
Database string
}

type iamAuthConnectionStringProvider struct {
IAMAuthConfig

region string
creds aws.CredentialsProvider
}

func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) {
authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.region, p.User, p.creds)
if err != nil {
return "", fmt.Errorf("building auth token: %w", err)
}
log.Printf("Signing RDS IAM token for user: %s", p.User)

dsnURL := &url.URL{
Scheme: "postgresql",
User: url.UserPassword(p.User, authToken),
Host: p.RDSEndpoint,
Path: "/" + p.Database,
}

return dsnURL.String(), nil
}

func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) (*PostgresqlConnector, error) {
if cfg.RDSEndpoint == "" || cfg.User == "" || cfg.Database == "" {
return nil, errors.New("RDS endpoint, user, and database are required")
}

awsCfg, err := awsconfig.LoadDefaultConfig(ctx)
if err != nil {
return nil, fmt.Errorf("load AWS config: %w", err)
}

if awsCfg.Region == "" {
return nil, errors.New("AWS region is not configured")
}

return &PostgresqlConnector{
baseConnectionStringProvider: &iamAuthConnectionStringProvider{
IAMAuthConfig: *cfg,
region: awsCfg.Region,
creds: awsCfg.Credentials,
},
}, nil
}

// Provides missing sqlx.OpenDB
func OpenDB(conn *PostgresqlConnector) *sqlx.DB {
sqlDB := sql.OpenDB(conn)
return sqlx.NewDb(sqlDB, "postgres")
}

177 changes: 177 additions & 0 deletions pgutils/listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package pgutils

import (
"context"
"fmt"
"log"
"time"

"github.com/lib/pq"
)

// This listener exists because we cannot rely on lib/pq's built-in reconnect
// logic when using RDS IAM authentication. IAM auth uses short-lived tokens in
// the connection string; lib/pq will happily reconnect forever with the *same*
// DSN, but once the token expires those reconnect attempts can never succeed.
//
// To handle this, we treat pq.Listener as a disposable object and rebuild it
// whenever the connection is lost:
//
// - PostgresqlConnector.GetConnectionString(ctx) returns a fresh DSN, which
// includes a new IAM token.
// - makeListener() uses that DSN to construct a new pq.Listener and issues
// LISTEN pgChannelName.
//
// A goroutine runs the event loop: it drains listener.Notify,
// invokes the caller's callback for each notification, periodically calls
// Ping() to surface dead sockets, and handles reconnects with backoff when
// needed.
//
// Listen() orchestrates the lifecycle: it watches pq.Listener events,
// triggers reconnects via a small buffered channel, applies exponential
// backoff when creating new listeners fails, and ensures notifications
// that arrive while we are rebuilding the listener are intentionally
// dropped.
//
// Callers only need to pass a context, channel name, and callback; Listen
// hides the IAM token refresh, reconnection, and backoff mechanics.
//
// This is the path a postgres notification takes:
// Postgres NOTIFY
// -> pq.Listener internals
// -> listener.Notify channel
// -> Listen loop
// -> callback(notification)
// -> your business logic

func listenerEventToString(t pq.ListenerEventType) string {
switch t {
case pq.ListenerEventConnected:
return "connected"
case pq.ListenerEventDisconnected:
return "disconnected"
case pq.ListenerEventReconnected:
return "reconnected"
case pq.ListenerEventConnectionAttemptFailed:
return "connection failed"
default:
return fmt.Sprintf("Unknown: (%d)", t)
}
}

// Listen subscribes to a Postgres LISTEN channel (pgChannelName) and invokes callback for
// each notification. It automatically reconnects with backoff and pings periodically to
// surface dead sockets. If an onClose callback is provided, it is called once when the
// listener goroutine exits.
//
// Notifications that arrive while the listener is being rebuilt are intentionally dropped.
//
// The callback is invoked from the listener goroutine; it MUST NOT block
// for long periods. If you need to do heavy work, offload it to another
// goroutine.
func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string, callback func(*pq.Notification), onClose func()) error {
if callback == nil {
return fmt.Errorf("listener callback cannot be nil")
}

reconnectEventCh := make(chan struct{}, 1) // We just need a single reconnect event to trigger, so buffer size of 1

makeListener := func() (*pq.Listener, error) {
url, err := conn.GetConnectionString(ctx)
if err != nil {
return nil, fmt.Errorf("get url: %w", err)
}

cb := func(t pq.ListenerEventType, e error) {
eventType := listenerEventToString(t)
log.Printf("Postgres listener (%s): %s (err=%v)", pgChannelName, eventType, e)
if t == pq.ListenerEventDisconnected || t == pq.ListenerEventConnectionAttemptFailed {
select {
case reconnectEventCh <- struct{}{}:
default:
}
}
}

listener := pq.NewListener(url, time.Second, 30*time.Second, cb)
if err := listener.Listen(pgChannelName); err != nil {
_ = listener.Close()
return nil, fmt.Errorf("listen %q: %w", pgChannelName, err)
}
return listener, nil
}

// Build the first listener eagerly so callers learn about init failures immediately.
listener, err := makeListener()
if err != nil {
return err
}

go func() {
defer func() {
log.Printf("Postgres listener (%s): shutting down (ctx err=%v)", pgChannelName, ctx.Err())
if onClose != nil {
onClose()
}
if listener != nil {
_ = listener.Close()
}
}()

backoff := time.Second
const maxBackoff = 30 * time.Second

ping := time.NewTicker(60 * time.Second)
defer ping.Stop()

for {
// Rebuild listener with a fresh URL when needed.
for listener == nil {
listener, err = makeListener()
if err != nil {
log.Printf("listener create failed: %v (retry in %s)", err, backoff)
select {
case <-time.After(backoff):
if backoff < maxBackoff {
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
}
continue
case <-ctx.Done():
return
}
}
backoff = time.Second // reset on success
}

select {
case n, ok := <-listener.Notify:
if !ok {
return
}
if n == nil {
// Seen right after reconnects sometimes.
continue
}
callback(n)

case <-ping.C:
// Nudge connection to surface dead sockets.
go listener.Ping()

case <-reconnectEventCh:
// Lib/pq listener has entered a state where we want to reconnect with a new pq.Listener.
_ = listener.Close()
listener = nil

case <-ctx.Done():
return
}
}
}()

return nil
}