diff --git a/pgutils/connector.go b/pgutils/connector.go
new file mode 100644
index 0000000..8a0951e
--- /dev/null
+++ b/pgutils/connector.go
@@ -0,0 +1,191 @@
+package pgutils
+
+import (
+	"context"
+	"database/sql"
+	"database/sql/driver"
+	"errors"
+	"fmt"
+	"log"
+	"net/url"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	awsconfig "github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
+	"github.com/jmoiron/sqlx"
+	"github.com/lib/pq"
+)
+
+// baseConnectionStringProvider supplies the connection string a connector starts from,
+// before any per-connector option such as search_path is applied.
+type baseConnectionStringProvider interface {
+	getBaseConnectionString(ctx context.Context) (string, error)
+}
+
+// PostgresqlConnector is a driver.Connector that asks its provider for a connection string
+// on every Connect call, so each new connection is opened with a freshly built DSN.
+type PostgresqlConnector struct {
+	baseConnectionStringProvider
+	searchPath string
+}
+
+// Compile-time check that PostgresqlConnector can be passed to sql.OpenDB.
+var _ driver.Connector = (*PostgresqlConnector)(nil)
+
+// WithSearchPath returns a new connector that shares the receiver's provider and adds
+// search_path=searchPath to the connection strings it produces. The receiver is not modified.
+func (conn *PostgresqlConnector) WithSearchPath(searchPath string) *PostgresqlConnector {
+	return &PostgresqlConnector{
+		baseConnectionStringProvider: conn.baseConnectionStringProvider,
+		searchPath:                   searchPath,
+	}
+}
+
+// Connect implements driver.Connector. It builds a connection string, creates a lib/pq
+// connector from it, and opens one connection.
+func (conn *PostgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
+	dsn, err := conn.GetConnectionString(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("get connection string: %w", err)
+	}
+	pqConnector, err := pq.NewConnector(dsn)
+	if err != nil {
+		return nil, fmt.Errorf("create pq connector: %w", err)
+	}
+
+	return pqConnector.Connect(ctx)
+}
+
+// GetConnectionString returns the provider's connection string. When a search path was
+// configured with WithSearchPath it is appended as the search_path query parameter, which
+// requires a URL-style DSN (postgres:// or postgresql://) that does not already set
+// search_path; otherwise an error is returned.
+func (conn *PostgresqlConnector) GetConnectionString(ctx context.Context) (string, error) {
+	dsn, err := conn.getBaseConnectionString(ctx)
+	if err != nil {
+		return "", fmt.Errorf("get base connection string: %w", err)
+	}
+	if conn.searchPath == "" {
+		return dsn, nil
+	}
+
+	// Add search path
+	u, err := url.Parse(dsn)
+	if err != nil {
+		// *url.Error quotes the whole input, which here may carry a password or IAM token;
+		// keep only the underlying cause so credentials cannot reach logs via this error.
+		var urlErr *url.Error
+		if errors.As(err, &urlErr) {
+			err = urlErr.Err
+		}
+		return "", fmt.Errorf("parse DSN URL: %w", err)
+	}
+	// url.Parse also accepts a keyword/value DSN ("host=... user=...") as a bare path; adding
+	// a query string to that would yield a DSN lib/pq cannot interpret, so reject it here.
+	if u.Scheme != "postgres" && u.Scheme != "postgresql" {
+		return "", fmt.Errorf("search_path requires a postgres:// or postgresql:// DSN, got scheme %q", u.Scheme)
+	}
+	q := u.Query()
+	if v := q.Get("search_path"); v != "" {
+		return "", fmt.Errorf("search_path already set to %q", v)
+	}
+	q.Set("search_path", conn.searchPath) // url.Values will percent-encode commas as needed
+	u.RawQuery = q.Encode()
+	return u.String(), nil
+}
+
+// Driver implements driver.Connector and returns the lib/pq driver.
+func (conn *PostgresqlConnector) Driver() driver.Driver {
+	return &pq.Driver{}
+}
+
+// staticConnectionStringProvider always returns the connection string it was built with.
+type staticConnectionStringProvider struct {
+	connectionString string
+}
+
+func (p *staticConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) {
+	return p.connectionString, nil
+}
+
+// NewPostgresqlConnectorFromConnectionString returns a connector that uses connectionString
+// as the base DSN for every connection.
+func NewPostgresqlConnectorFromConnectionString(connectionString string) *PostgresqlConnector {
+	return &PostgresqlConnector{
+		baseConnectionStringProvider: &staticConnectionStringProvider{connectionString},
+	}
+}
+
+// IAMAuthConfig identifies the database to connect to with RDS IAM authentication.
+// All three fields are required by NewPostgresqlConnectorWithIAMAuth.
+type IAMAuthConfig struct {
+	// RDSEndpoint is passed to auth.BuildAuthToken and used as the DSN host; the AWS SDK
+	// documents this endpoint as host:port.
+	RDSEndpoint string
+	User        string
+	Database    string
+}
+
+// iamAuthConnectionStringProvider builds a DSN whose password is an RDS IAM auth token
+// generated on each call.
+type iamAuthConnectionStringProvider struct {
+	IAMAuthConfig
+
+	region string
+	creds  aws.CredentialsProvider
+}
+
+func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) {
+	authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.region, p.User, p.creds)
+	if err != nil {
+		return "", fmt.Errorf("building auth token: %w", err)
+	}
+	log.Printf("Signing RDS IAM token for user: %s", p.User)
+
+	dsnURL := &url.URL{
+		Scheme: "postgresql",
+		User:   url.UserPassword(p.User, authToken),
+		Host:   p.RDSEndpoint,
+		Path:   "/" + p.Database,
+	}
+
+	return dsnURL.String(), nil
+}
+
+// NewPostgresqlConnectorWithIAMAuth returns a connector that authenticates with RDS IAM
+// tokens. The AWS region and credentials come from the default AWS configuration loaded
+// with ctx. It returns an error if cfg is nil or incomplete, if the AWS configuration
+// cannot be loaded, or if no region is configured.
+func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) (*PostgresqlConnector, error) {
+	if cfg == nil {
+		return nil, errors.New("IAM auth config is required")
+	}
+	if cfg.RDSEndpoint == "" || cfg.User == "" || cfg.Database == "" {
+		return nil, errors.New("RDS endpoint, user, and database are required")
+	}
+
+	awsCfg, err := awsconfig.LoadDefaultConfig(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("load AWS config: %w", err)
+	}
+
+	if awsCfg.Region == "" {
+		return nil, errors.New("AWS region is not configured")
+	}
+
+	return &PostgresqlConnector{
+		baseConnectionStringProvider: &iamAuthConnectionStringProvider{
+			IAMAuthConfig: *cfg,
+			region:        awsCfg.Region,
+			creds:         awsCfg.Credentials,
+		},
+	}, nil
+}
+
+// OpenDB opens a database handle backed by conn and wraps it in an *sqlx.DB registered
+// under the "postgres" driver name. It stands in for the sqlx.OpenDB that sqlx lacks.
+func OpenDB(conn *PostgresqlConnector) *sqlx.DB {
+	sqlDB := sql.OpenDB(conn)
+	return sqlx.NewDb(sqlDB, "postgres")
+}
+
diff --git a/pgutils/listener.go b/pgutils/listener.go
new file mode 100644
index 0000000..958462c
--- /dev/null
+++ b/pgutils/listener.go
@@ -0,0 +1,177 @@
+package pgutils
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"time"
+
+	"github.com/lib/pq"
+)
+
+// This listener exists because we cannot rely on lib/pq's built-in reconnect
+// logic when using RDS IAM authentication. IAM auth uses short-lived tokens in
+// the connection string; lib/pq will happily reconnect forever with the *same*
+// DSN, but once the token expires those reconnect attempts can never succeed.
+//
+// To handle this, we treat pq.Listener as a disposable object and rebuild it
+// whenever the connection is lost:
+//
+//   - PostgresqlConnector.GetConnectionString(ctx) returns a fresh DSN, which
+//     includes a new IAM token.
+//   - makeListener() uses that DSN to construct a new pq.Listener and issues
+//     LISTEN pgChannelName.
+//
+// A goroutine runs the event loop: it drains listener.Notify,
+// invokes the caller's callback for each notification, periodically calls
+// Ping() to surface dead sockets, and handles reconnects with backoff when
+// needed.
+//
+// Listen() orchestrates the lifecycle: it watches pq.Listener events,
+// triggers reconnects via a small buffered channel, and applies exponential
+// backoff when creating a new listener fails. Notifications that arrive
+// while the listener is being rebuilt are not replayed; they are
+// intentionally dropped.
+//
+// Callers only need to pass a context, channel name, and callback; Listen
+// hides the IAM token refresh, reconnection, and backoff mechanics.
+//
+// This is the path a Postgres notification takes:
+//   Postgres NOTIFY
+//     -> pq.Listener internals
+//     -> listener.Notify channel
+//     -> Listen loop
+//     -> callback(notification)
+//     -> your business logic
+
+// listenerEventToString returns a log-friendly label for t; unknown values include their number.
+func listenerEventToString(t pq.ListenerEventType) string {
+	var label string
+	switch t {
+	case pq.ListenerEventConnected:
+		label = "connected"
+	case pq.ListenerEventDisconnected:
+		label = "disconnected"
+	case pq.ListenerEventReconnected:
+		label = "reconnected"
+	case pq.ListenerEventConnectionAttemptFailed:
+		label = "connection failed"
+	default:
+		label = fmt.Sprintf("Unknown: (%d)", t)
+	}
+	return label
+}
+
+// Listen subscribes to a Postgres LISTEN channel (pgChannelName) and invokes callback for each
+// notification. It automatically reconnects with backoff and pings periodically to surface dead
+// sockets. Notifications that arrive while the listener is being rebuilt are intentionally
+// dropped. If an onClose callback is provided, it is called once when the listener goroutine exits.
+//
+// The callback is invoked from the listener goroutine; it MUST NOT block for long periods. If you
+// need to do heavy work, offload it to another goroutine.
+func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string, callback func(*pq.Notification), onClose func()) error {
+	if conn == nil {
+		return fmt.Errorf("listener connector cannot be nil")
+	}
+	if callback == nil {
+		return fmt.Errorf("listener callback cannot be nil")
+	}
+
+	// makeListener returns a listener built from a fresh DSN plus the channel it signals loss on.
+	makeListener := func() (*pq.Listener, <-chan struct{}, error) {
+		dsn, err := conn.GetConnectionString(ctx)
+		if err != nil {
+			return nil, nil, fmt.Errorf("get url: %w", err)
+		}
+		// Per-listener channel: a late event from a closed listener must not tear down its successor.
+		lost := make(chan struct{}, 1)
+		onEvent := func(t pq.ListenerEventType, e error) {
+			log.Printf("Postgres listener (%s): %s (err=%v)", pgChannelName, listenerEventToString(t), e)
+			if t == pq.ListenerEventDisconnected || t == pq.ListenerEventConnectionAttemptFailed {
+				select {
+				case lost <- struct{}{}:
+				default:
+				}
+			}
+		}
+		// pq's Listen blocks until a connection exists while pq retries with this same DSN, so run
+		// it on the side and give up on the first failed attempt or when ctx ends.
+		l := pq.NewListener(dsn, time.Second, 30*time.Second, onEvent)
+		listened := make(chan error, 1)
+		go func() { listened <- l.Listen(pgChannelName) }()
+		select {
+		case err = <-listened:
+		case <-lost:
+			err = fmt.Errorf("connection attempt failed")
+		case <-ctx.Done():
+			err = ctx.Err()
+		}
+		if err != nil {
+			_ = l.Close() // best effort; also unblocks the pending Listen call
+			return nil, nil, fmt.Errorf("listen %q: %w", pgChannelName, err)
+		}
+		return l, lost, nil
+	}
+
+	// Build the first listener eagerly so callers learn about init failures immediately.
+	listener, lost, err := makeListener()
+	if err != nil {
+		return err
+	}
+
+	go func() {
+		defer func() {
+			log.Printf("Postgres listener (%s): shutting down (ctx err=%v)", pgChannelName, ctx.Err())
+			if onClose != nil {
+				onClose()
+			}
+			if listener != nil {
+				_ = listener.Close()
+			}
+		}()
+		const maxBackoff = 30 * time.Second
+		backoff := time.Second
+		ping := time.NewTicker(60 * time.Second)
+		defer ping.Stop()
+
+		for {
+			// Rebuild listener with a fresh URL when needed.
+			for listener == nil {
+				l, ch, err := makeListener()
+				if err == nil {
+					listener, lost, backoff = l, ch, time.Second // reset backoff on success
+					break
+				}
+				log.Printf("listener create failed: %v (retry in %s)", err, backoff)
+				select {
+				case <-time.After(backoff):
+				case <-ctx.Done():
+					return
+				}
+				if backoff *= 2; backoff > maxBackoff {
+					backoff = maxBackoff
+				}
+			}
+
+			select {
+			case n, ok := <-listener.Notify:
+				if !ok {
+					return
+				}
+				if n != nil { // a nil notification is seen right after reconnects; skip it
+					callback(n)
+				}
+			case <-ping.C:
+				go listener.Ping() // nudge the connection to surface dead sockets
+			case <-lost:
+				// This pq.Listener lost its connection; replace it so the next one gets a fresh DSN.
+				_ = listener.Close()
+				listener = nil
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
+	return nil
+}
+