diff --git a/pkg/sql/delayed_mysql.go b/pkg/sql/delayed_mysql.go new file mode 100644 index 0000000..a83f290 --- /dev/null +++ b/pkg/sql/delayed_mysql.go @@ -0,0 +1,198 @@ +package sql + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/ThreeDotsLabs/watermill" + "github.com/ThreeDotsLabs/watermill/components/delay" + "github.com/ThreeDotsLabs/watermill/message" +) + +type DelayedMySQLPublisherConfig struct { + // DelayPublisherConfig is a configuration for the delay.Publisher. + DelayPublisherConfig delay.PublisherConfig + + // OverridePublisherConfig allows overriding the default PublisherConfig. + OverridePublisherConfig func(config *PublisherConfig) error + + Logger watermill.LoggerAdapter +} + +func (c *DelayedMySQLPublisherConfig) setDefaults() { + if c.Logger == nil { + c.Logger = watermill.NopLogger{} + } +} + +// NewDelayedMySQLPublisher creates a new Publisher that stores messages in MySQL with a delay. +// The delay can be set per message with the Watermill's components/delay metadata. +func NewDelayedMySQLPublisher(db ContextExecutor, config DelayedMySQLPublisherConfig) (message.Publisher, error) { + config.setDefaults() + + publisherConfig := PublisherConfig{ + SchemaAdapter: delayedMySQLSchemaAdapter{ + MySQLQueueSchema: MySQLQueueSchema{}, + }, + AutoInitializeSchema: true, + } + + if config.OverridePublisherConfig != nil { + err := config.OverridePublisherConfig(&publisherConfig) + if err != nil { + return nil, err + } + } + + var publisher message.Publisher + var err error + + publisher, err = NewPublisher(db, publisherConfig, config.Logger) + if err != nil { + return nil, err + } + + publisher, err = delay.NewPublisher(publisher, config.DelayPublisherConfig) + if err != nil { + return nil, err + } + + return publisher, nil +} + +type DelayedMySQLSubscriberConfig struct { + // OverrideSubscriberConfig allows overriding the default SubscriberConfig. + OverrideSubscriberConfig func(config *SubscriberConfig) error + + // DeleteOnAck deletes the message from the queue when it's acknowledged. + DeleteOnAck bool + + // AllowNoDelay allows receiving messages without the delay metadata. + // By default, such messages will be skipped. + // If set to true, messages without delay metadata will be received immediately. + AllowNoDelay bool + + Logger watermill.LoggerAdapter +} + +func (c *DelayedMySQLSubscriberConfig) setDefaults() { + if c.Logger == nil { + c.Logger = watermill.NopLogger{} + } +} + +// NewDelayedMySQLSubscriber creates a new Subscriber that reads messages from MySQL with a delay. +// The delay can be set per message with the Watermill's components/delay metadata. +func NewDelayedMySQLSubscriber(db Beginner, config DelayedMySQLSubscriberConfig) (message.Subscriber, error) { + config.setDefaults() + + where := "delayed_until <= NOW()" + + if config.AllowNoDelay { + where += " OR delayed_until IS NULL" + } + + schemaAdapter := delayedMySQLSchemaAdapter{ + MySQLQueueSchema: MySQLQueueSchema{ + GenerateWhereClause: func(params GenerateWhereClauseParams) (string, []any) { + return where, nil + }, + }, + } + + subscriberConfig := SubscriberConfig{ + SchemaAdapter: schemaAdapter, + OffsetsAdapter: MySQLQueueOffsetsAdapter{ + DeleteOnAck: config.DeleteOnAck, + }, + InitializeSchema: true, + } + + if config.OverrideSubscriberConfig != nil { + err := config.OverrideSubscriberConfig(&subscriberConfig) + if err != nil { + return nil, err + } + } + + sub, err := NewSubscriber(db, subscriberConfig, config.Logger) + if err != nil { + return nil, err + } + + return sub, nil +} + +type delayedMySQLSchemaAdapter struct { + MySQLQueueSchema +} + +func (a delayedMySQLSchemaAdapter) SchemaInitializingQueries(params SchemaInitializingQueriesParams) ([]Query, error) { + createMessagesTable := ` + CREATE TABLE IF NOT EXISTS ` + a.MessagesTable(params.Topic) + ` ( + ` + "`offset`" + ` BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY, + ` + "`uuid`" + ` VARCHAR(36) NOT NULL, + ` + "`payload`" + ` ` + a.payloadColumnType(params.Topic) + ` DEFAULT NULL, + ` + "`metadata`" + ` JSON DEFAULT NULL, + ` + "`acked`" + ` BOOLEAN NOT NULL DEFAULT FALSE, + ` + "`created_at`" + ` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + ` + "`delayed_until`" + ` TIMESTAMP NULL DEFAULT NULL, + INDEX ` + "`delayed_until_idx`" + ` (` + "`delayed_until`" + `) + ); + ` + + return []Query{{Query: createMessagesTable}}, nil +} + +func (a delayedMySQLSchemaAdapter) InsertQuery(params InsertQueryParams) (Query, error) { + insertQuery := fmt.Sprintf( + `INSERT INTO %s (uuid, payload, metadata, delayed_until) VALUES %s`, + a.MessagesTable(params.Topic), + delayedMySQLInsertMarkers(len(params.Msgs)), + ) + + args, err := delayedMySQLInsertArgs(params.Msgs) + if err != nil { + return Query{}, err + } + + return Query{insertQuery, args}, nil +} + +func delayedMySQLInsertMarkers(count int) string { + result := strings.Builder{} + + for range count { + result.WriteString("(?,?,?,?),") + } + + return strings.TrimRight(result.String(), ",") +} + +func delayedMySQLInsertArgs(msgs message.Messages) ([]any, error) { + var args []any + + for _, msg := range msgs { + metadata, err := json.Marshal(msg.Metadata) + if err != nil { + return nil, fmt.Errorf("could not marshal metadata into JSON for message %s: %w", msg.UUID, err) + } + + args = append(args, msg.UUID, msg.Payload, metadata) + + // Extract delayed_until from metadata + delayedUntilStr := msg.Metadata.Get(delay.DelayedUntilKey) + if delayedUntilStr == "" { + args = append(args, nil) + } else { + // Convert ISO 8601 to MySQL TIMESTAMP format: "2025-10-22T09:58:00Z" -> "2025-10-22 09:58:00" + delayedUntilStr = strings.Replace(delayedUntilStr, "T", " ", 1) + delayedUntilStr = strings.TrimSuffix(delayedUntilStr, "Z") + + args = append(args, delayedUntilStr) + } + } + + return args, nil +} diff --git a/pkg/sql/delayed_mysql_test.go b/pkg/sql/delayed_mysql_test.go new file mode 100644 index 0000000..ce3f6e1 --- /dev/null +++ b/pkg/sql/delayed_mysql_test.go @@ -0,0 +1,134 @@ +package sql_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ThreeDotsLabs/watermill" + "github.com/ThreeDotsLabs/watermill-sql/v4/pkg/sql" + "github.com/ThreeDotsLabs/watermill/components/delay" + "github.com/ThreeDotsLabs/watermill/message" +) + +func TestDelayedMySQL(t *testing.T) { + t.Parallel() + + db := newMySQL(t) + + pub, err := sql.NewDelayedMySQLPublisher(db, sql.DelayedMySQLPublisherConfig{ + DelayPublisherConfig: delay.PublisherConfig{ + DefaultDelayGenerator: func(params delay.DefaultDelayGeneratorParams) (delay.Delay, error) { + return delay.For(time.Second), nil + }, + }, + Logger: logger, + }) + require.NoError(t, err) + + sub, err := sql.NewDelayedMySQLSubscriber(db, sql.DelayedMySQLSubscriberConfig{ + DeleteOnAck: true, + Logger: logger, + }) + require.NoError(t, err) + + topic := watermill.NewUUID() + + messages, err := sub.Subscribe(context.Background(), topic) + require.NoError(t, err) + + msg := message.NewMessage(watermill.NewUUID(), []byte("{}")) + + err = pub.Publish(topic, msg) + require.NoError(t, err) + + select { + case <-messages: + t.Errorf("message should not be received") + case <-time.After(time.Millisecond * 200): + } + + assert.EventuallyWithT(t, func(t *assert.CollectT) { + select { + case received := <-messages: + assert.Equal(t, msg.UUID, received.UUID) + received.Ack() + default: + t.Errorf("message should be received") + } + }, time.Second, time.Millisecond*10) +} + +func TestDelayedMySQL_NoDelay(t *testing.T) { + t.Parallel() + + db := newMySQL(t) + + pub, err := sql.NewDelayedMySQLPublisher(db, sql.DelayedMySQLPublisherConfig{ + DelayPublisherConfig: delay.PublisherConfig{ + AllowNoDelay: true, + }, + Logger: logger, + }) + require.NoError(t, err) + + t.Run("skip_empty", func(t *testing.T) { + t.Parallel() + + sub, err := sql.NewDelayedMySQLSubscriber(db, sql.DelayedMySQLSubscriberConfig{ + DeleteOnAck: true, + Logger: logger, + }) + require.NoError(t, err) + + topic := watermill.NewUUID() + + messages, err := sub.Subscribe(context.Background(), topic) + require.NoError(t, err) + + msg := message.NewMessage(watermill.NewUUID(), []byte("{}")) + + err = pub.Publish(topic, msg) + require.NoError(t, err) + + select { + case <-messages: + t.Errorf("message should not be received") + case <-time.After(time.Second * 2): + } + }) + + t.Run("allow_empty", func(t *testing.T) { + t.Parallel() + + sub, err := sql.NewDelayedMySQLSubscriber(db, sql.DelayedMySQLSubscriberConfig{ + DeleteOnAck: true, + AllowNoDelay: true, + Logger: logger, + }) + require.NoError(t, err) + + topic := watermill.NewUUID() + + messages, err := sub.Subscribe(context.Background(), topic) + require.NoError(t, err) + + msg := message.NewMessage(watermill.NewUUID(), []byte("{}")) + + err = pub.Publish(topic, msg) + require.NoError(t, err) + + assert.EventuallyWithT(t, func(t *assert.CollectT) { + select { + case received := <-messages: + assert.Equal(t, msg.UUID, received.UUID) + received.Ack() + default: + t.Errorf("message should be received") + } + }, time.Second*2, time.Millisecond*10) + }) +} diff --git a/pkg/sql/queue_offsets_adapter_mysql.go b/pkg/sql/queue_offsets_adapter_mysql.go new file mode 100644 index 0000000..4673351 --- /dev/null +++ b/pkg/sql/queue_offsets_adapter_mysql.go @@ -0,0 +1,68 @@ +package sql + +import ( + "fmt" + "strings" +) + +// MySQLQueueOffsetsAdapter is an OffsetsAdapter for the MySQLQueueSchema. +type MySQLQueueOffsetsAdapter struct { + // DeleteOnAck determines whether the message should be deleted from the table when it is acknowledged. + // If false, the message will be marked as acked. + DeleteOnAck bool + + // GenerateMessagesTableName may be used to override how the messages table name is generated. + GenerateMessagesTableName func(topic string) string +} + +func (a MySQLQueueOffsetsAdapter) SchemaInitializingQueries(params OffsetsSchemaInitializingQueriesParams) ([]Query, error) { + return []Query{}, nil +} + +func (a MySQLQueueOffsetsAdapter) NextOffsetQuery(params NextOffsetQueryParams) (Query, error) { + return Query{}, nil +} + +func (a MySQLQueueOffsetsAdapter) AckMessageQuery(params AckMessageQueryParams) (Query, error) { + if params.ConsumerGroup != "" { + panic("consumer groups are not supported in MySQLQueueOffsetsAdapter") + } + + var ackQuery string + + table := a.MessagesTable(params.Topic) + + if a.DeleteOnAck { + // Build the WHERE clause for multiple offsets + offsetPlaceholders := strings.Repeat("?,", len(params.Rows)) + offsetPlaceholders = strings.TrimRight(offsetPlaceholders, ",") + ackQuery = fmt.Sprintf(`DELETE FROM %s WHERE `+"`offset`"+` IN (%s)`, table, offsetPlaceholders) + } else { + // Build the WHERE clause for multiple offsets + offsetPlaceholders := strings.Repeat("?,", len(params.Rows)) + offsetPlaceholders = strings.TrimRight(offsetPlaceholders, ",") + ackQuery = fmt.Sprintf(`UPDATE %s SET acked = TRUE WHERE `+"`offset`"+` IN (%s)`, table, offsetPlaceholders) + } + + offsets := make([]any, len(params.Rows)) + for i, row := range params.Rows { + offsets[i] = row.Offset + } + + return Query{ackQuery, offsets}, nil +} + +func (a MySQLQueueOffsetsAdapter) MessagesTable(topic string) string { + if a.GenerateMessagesTableName != nil { + return a.GenerateMessagesTableName(topic) + } + return fmt.Sprintf("`watermill_%s`", topic) +} + +func (a MySQLQueueOffsetsAdapter) ConsumedMessageQuery(params ConsumedMessageQueryParams) (Query, error) { + return Query{}, nil +} + +func (a MySQLQueueOffsetsAdapter) BeforeSubscribingQueries(params BeforeSubscribingQueriesParams) ([]Query, error) { + return []Query{}, nil +} diff --git a/pkg/sql/queue_schema_adapter_mysql.go b/pkg/sql/queue_schema_adapter_mysql.go new file mode 100644 index 0000000..4f9a000 --- /dev/null +++ b/pkg/sql/queue_schema_adapter_mysql.go @@ -0,0 +1,156 @@ +package sql + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/ThreeDotsLabs/watermill/message" +) + +// MySQLQueueSchema is a schema adapter for MySQL that allows filtering messages by some condition. +// It DOES NOT support consumer groups. +// It supports deleting messages on ack. +type MySQLQueueSchema struct { + // GenerateWhereClause is a function that returns a where clause and arguments for the SELECT query. + // It may be used to filter messages by some condition. + // If empty, no where clause will be added. + GenerateWhereClause func(params GenerateWhereClauseParams) (string, []any) + + // GeneratePayloadType is the type of the payload column in the messages table. + // By default, it's JSON. If your payload is not JSON, you can use LONGBLOB. + GeneratePayloadType func(topic string) string + + // GenerateMessagesTableName may be used to override how the messages table name is generated. + GenerateMessagesTableName func(topic string) string + + // SubscribeBatchSize is the number of messages to be queried at once. + // + // Higher value, increases a chance of message re-delivery in case of crash or networking issues. + // 1 is the safest value, but it may have a negative impact on performance when consuming a lot of messages. + // + // Default value is 100. + SubscribeBatchSize int +} + +func (s MySQLQueueSchema) SchemaInitializingQueries(params SchemaInitializingQueriesParams) ([]Query, error) { + createMessagesTable := ` + CREATE TABLE IF NOT EXISTS ` + s.MessagesTable(params.Topic) + ` ( + ` + "`offset`" + ` BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY, + ` + "`uuid`" + ` VARCHAR(36) NOT NULL, + ` + "`payload`" + ` ` + s.payloadColumnType(params.Topic) + ` DEFAULT NULL, + ` + "`metadata`" + ` JSON DEFAULT NULL, + ` + "`acked`" + ` BOOLEAN NOT NULL DEFAULT FALSE, + ` + "`created_at`" + ` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + ` + + return []Query{{Query: createMessagesTable}}, nil +} + +func (s MySQLQueueSchema) InsertQuery(params InsertQueryParams) (Query, error) { + insertQuery := fmt.Sprintf( + `INSERT INTO %s (uuid, payload, metadata) VALUES %s`, + s.MessagesTable(params.Topic), + mysqlQueueInsertMarkers(len(params.Msgs)), + ) + + args, err := defaultInsertArgs(params.Msgs) + if err != nil { + return Query{}, err + } + + return Query{insertQuery, args}, nil +} + +func mysqlQueueInsertMarkers(count int) string { + result := strings.Builder{} + + for range count { + result.WriteString("(?,?,?),") + } + + return strings.TrimRight(result.String(), ",") +} + +func (s MySQLQueueSchema) batchSize() int { + if s.SubscribeBatchSize == 0 { + return 100 + } + + return s.SubscribeBatchSize +} + +func (s MySQLQueueSchema) SelectQuery(params SelectQueryParams) (Query, error) { + if params.ConsumerGroup != "" { + return Query{}, errors.New("consumer groups are not supported in MySQLQueueSchema") + } + + whereParams := GenerateWhereClauseParams{ + Topic: params.Topic, + } + + var where string + var args []any + + if s.GenerateWhereClause != nil { + where, args = s.GenerateWhereClause(whereParams) + if where != "" { + where = "AND " + where + } + } + + selectQuery := ` + SELECT ` + "`offset`" + `, uuid, payload, metadata FROM ` + s.MessagesTable(params.Topic) + ` + WHERE acked = false ` + where + ` + ORDER BY + ` + "`offset`" + ` ASC + LIMIT ` + fmt.Sprintf("%d", s.batchSize()) + ` + FOR UPDATE` + + return Query{selectQuery, args}, nil +} + +func (s MySQLQueueSchema) UnmarshalMessage(params UnmarshalMessageParams) (Row, error) { + r := Row{} + + err := params.Row.Scan(&r.Offset, &r.UUID, &r.Payload, &r.Metadata) + if err != nil { + return Row{}, fmt.Errorf("could not scan message row: %w", err) + } + + msg := message.NewMessage(string(r.UUID), r.Payload) + + if r.Metadata != nil { + err = json.Unmarshal(r.Metadata, &msg.Metadata) + if err != nil { + return Row{}, fmt.Errorf("could not unmarshal metadata as JSON: %w", err) + } + } + + r.Msg = msg + + return r, nil +} + +func (s MySQLQueueSchema) MessagesTable(topic string) string { + if s.GenerateMessagesTableName != nil { + return s.GenerateMessagesTableName(topic) + } + return fmt.Sprintf("`watermill_%s`", topic) +} + +func (s MySQLQueueSchema) SubscribeIsolationLevel() sql.IsolationLevel { + // MySQL requires serializable isolation level for not losing messages. + return sql.LevelSerializable +} + +func (s MySQLQueueSchema) payloadColumnType(topic string) string { + if s.GeneratePayloadType == nil { + return "JSON" + } + + return s.GeneratePayloadType(topic) +} diff --git a/pkg/sql/queue_schema_adapter_mysql_test.go b/pkg/sql/queue_schema_adapter_mysql_test.go new file mode 100644 index 0000000..a1b0787 --- /dev/null +++ b/pkg/sql/queue_schema_adapter_mysql_test.go @@ -0,0 +1,80 @@ +package sql_test + +import ( + "context" + "fmt" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ThreeDotsLabs/watermill" + "github.com/ThreeDotsLabs/watermill-sql/v4/pkg/sql" + "github.com/ThreeDotsLabs/watermill/message" +) + +func TestMySQLQueueSchemaAdapter(t *testing.T) { + t.Parallel() + + db := newMySQL(t) + + schemaAdapter := sql.MySQLQueueSchema{ + GenerateWhereClause: func(params sql.GenerateWhereClauseParams) (string, []any) { + return "JSON_EXTRACT(metadata, '$.skip') IS NULL OR JSON_EXTRACT(metadata, '$.skip') != 'true'", nil + }, + } + + pub, err := sql.NewPublisher(db, sql.PublisherConfig{ + SchemaAdapter: schemaAdapter, + AutoInitializeSchema: true, + }, logger) + require.NoError(t, err) + + sub, err := sql.NewSubscriber(db, sql.SubscriberConfig{ + SchemaAdapter: schemaAdapter, + OffsetsAdapter: sql.MySQLQueueOffsetsAdapter{ + DeleteOnAck: true, + }, + InitializeSchema: true, + }, logger) + require.NoError(t, err) + + topic := watermill.NewUUID() + + messages, err := sub.Subscribe(context.Background(), topic) + require.NoError(t, err) + + for i := 0; i < 10; i++ { + msg := message.NewMessage(fmt.Sprint(i), []byte("{}")) + if i%2 != 0 { + msg.Metadata.Set("skip", "true") + } + err = pub.Publish(topic, msg) + require.NoError(t, err) + } + + var receivedMessages []*message.Message + for i := 0; i < 5; i++ { + select { + case msg := <-messages: + receivedMessages = append(receivedMessages, msg) + msg.Ack() + case <-time.After(5 * time.Second): + t.Errorf("expected to receive message") + break + } + } + + require.Len(t, receivedMessages, 5) + + for _, msg := range receivedMessages { + assert.NotEqual(t, "true", msg.Metadata.Get("skip")) + + id, err := strconv.Atoi(msg.UUID) + require.NoError(t, err) + + assert.Equal(t, id%2, 0) + } +}