diff --git a/ably/example_message_updates_test.go b/ably/example_message_updates_test.go new file mode 100644 index 00000000..13fa45dd --- /dev/null +++ b/ably/example_message_updates_test.go @@ -0,0 +1,96 @@ +package ably_test + +import ( + "context" + "fmt" + + "github.com/ably/ably-go/ably" +) + +// Example demonstrating how to publish a message and get its serial +func ExampleRESTChannel_PublishWithResult() { + client, err := ably.NewREST(ably.WithKey("xxx:xxx")) + if err != nil { + panic(err) + } + + channel := client.Channels.Get("example-channel") + + // Publish a message and get its serial + result, err := channel.PublishWithResult(context.Background(), "event-name", "message data") + if err != nil { + panic(err) + } + + fmt.Printf("Message published with serial: %s\n", result.Serial) +} + +// Example demonstrating how to update a message +func ExampleRESTChannel_UpdateMessage() { + client, err := ably.NewREST(ably.WithKey("xxx:xxx")) + if err != nil { + panic(err) + } + + channel := client.Channels.Get("example-channel") + + // First publish a message to get its serial + result, err := channel.PublishWithResult(context.Background(), "event", "initial data") + if err != nil { + panic(err) + } + + // Update the message + msg := &ably.Message{ + Serial: result.Serial, + Data: "updated data", + } + + updateResult, err := channel.UpdateMessage( + context.Background(), + msg, + ably.UpdateWithDescription("Fixed typo"), + ably.UpdateWithMetadata(map[string]string{"editor": "alice"}), + ) + if err != nil { + panic(err) + } + + fmt.Printf("Message updated with version serial: %s\n", updateResult.VersionSerial) +} + +// Example demonstrating async message append for AI streaming +func ExampleRealtimeChannel_AppendMessageAsync() { + client, err := ably.NewRealtime(ably.WithKey("xxx:xxx")) + if err != nil { + panic(err) + } + + channel := client.Channels.Get("chat-channel") + + // Publish initial message + result, err := channel.PublishWithResult(context.Background(), "ai-response", "The answer is") + if err != nil { + panic(err) + } + + // Stream tokens asynchronously without blocking + tokens := []string{" 42", ".", " This", " is", " the", " answer."} + for _, token := range tokens { + msg := &ably.Message{ + Serial: result.Serial, + Data: token, + } + // Non-blocking append - critical for AI streaming + err := channel.AppendMessageAsync(msg, func(r *ably.UpdateResult, err error) { + if err != nil { + fmt.Printf("Append failed: %v\n", err) + } + }) + if err != nil { + panic(err) + } + } + + fmt.Println("All tokens queued for append") +} diff --git a/ably/export_test.go b/ably/export_test.go index 9ac12bbb..da3b94d9 100644 --- a/ably/export_test.go +++ b/ably/export_test.go @@ -222,7 +222,7 @@ func (c *Connection) AckAll() { c.mtx.Unlock() c.log().Infof("Ack all %d messages waiting for ACK/NACK", len(cx)) for _, v := range cx { - v.onAck(nil) + v.callback.call(nil, nil) } } diff --git a/ably/message_updates_integration_test.go b/ably/message_updates_integration_test.go new file mode 100644 index 00000000..8327ee00 --- /dev/null +++ b/ably/message_updates_integration_test.go @@ -0,0 +1,347 @@ +//go:build !unit +// +build !unit + +package ably_test + +import ( + "context" + "testing" + "time" + + "github.com/ably/ably-go/ably" + "github.com/ably/ably-go/ablytest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRESTChannel_MessageUpdates(t *testing.T) { + app, err := ablytest.NewSandbox(nil) + require.NoError(t, err) + defer app.Close() + + client, err := ably.NewREST(app.Options()...) + require.NoError(t, err) + + ctx := context.Background() + + t.Run("PublishWithResult", func(t *testing.T) { + // Use mutable: namespace to enable message operations feature + channel := client.Channels.Get("mutable:test_publish_with_result") + + t.Run("returns serial for published message", func(t *testing.T) { + result, err := channel.PublishWithResult(ctx, "event1", "test data") + require.NoError(t, err) + assert.NotEmpty(t, result.Serial, "Expected non-empty serial") + }) + + t.Run("PublishMultipleWithResult returns serials for all messages", func(t *testing.T) { + messages := []*ably.Message{ + {Name: "event1", Data: "data1"}, + {Name: "event2", Data: "data2"}, + {Name: "event3", Data: "data3"}, + } + + results, err := channel.PublishMultipleWithResult(ctx, messages) + require.NoError(t, err) + assert.Len(t, results, 3) + + for i, result := range results { + assert.NotEmpty(t, result.Serial, "Expected non-empty serial for message %d", i) + } + }) + }) + + t.Run("UpdateMessage", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_update_message") + + t.Run("updates a message with new data", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "test-event", "initial data") + require.NoError(t, err) + + // Update the message + msg := &ably.Message{ + Serial: publishResult.Serial, + Data: "updated data", + } + updateResult, err := channel.UpdateMessage(ctx, msg, + ably.UpdateWithDescription("Fixed typo"), + ably.UpdateWithMetadata(map[string]string{"editor": "test"}), + ) + require.NoError(t, err) + assert.NotEmpty(t, updateResult.VersionSerial, "Expected version serial") + assert.NotEqual(t, publishResult.Serial, updateResult.VersionSerial, "VersionSerial should differ from original Serial") + + // Verify the update by fetching the message (eventually consistent) + require.Eventually(t, func() bool { + retrieved, err := channel.GetMessage(ctx, publishResult.Serial) + if err != nil { + return false + } + return retrieved.Data == "updated data" + }, 5*time.Second, 100*time.Millisecond, "Updated message should be retrievable") + }) + + t.Run("returns error when message has no serial", func(t *testing.T) { + msg := &ably.Message{Data: "test"} + _, err := channel.UpdateMessage(ctx, msg) + require.Error(t, err) + + errorInfo, ok := err.(*ably.ErrorInfo) + require.True(t, ok, "Expected ErrorInfo") + assert.Equal(t, ably.ErrorCode(40003), errorInfo.Code) + assert.Contains(t, errorInfo.Message(), "lacks a serial") + }) + }) + + t.Run("DeleteMessage", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_delete_message") + + t.Run("deletes a message", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "test-event", "data to delete") + require.NoError(t, err) + + // Delete the message + msg := &ably.Message{ + Serial: publishResult.Serial, + } + deleteResult, err := channel.DeleteMessage(ctx, msg, + ably.UpdateWithDescription("Deleted by test"), + ) + require.NoError(t, err) + assert.NotEmpty(t, deleteResult.VersionSerial, "Expected version serial") + }) + }) + + t.Run("AppendMessage", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_append_message") + + t.Run("appends to a message", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "test-event", "Hello") + require.NoError(t, err) + + // Append to the message + msg := &ably.Message{ + Serial: publishResult.Serial, + Data: " World", + } + appendResult, err := channel.AppendMessage(ctx, msg) + require.NoError(t, err) + assert.NotEmpty(t, appendResult.VersionSerial, "Expected version serial") + + // Verify by fetching the message - data should be appended (eventually consistent) + require.Eventually(t, func() bool { + retrieved, err := channel.GetMessage(ctx, publishResult.Serial) + if err != nil { + return false + } + // Verify the data was appended: "Hello" + " World" = "Hello World" + return retrieved.Data == "Hello World" + }, 5*time.Second, 100*time.Millisecond, "Message data should be appended") + }) + }) + + t.Run("GetMessage", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_get_message") + + t.Run("retrieves a message by serial", func(t *testing.T) { + // Publish a message + publishResult, err := channel.PublishWithResult(ctx, "test-event", "test data") + require.NoError(t, err) + + // GetMessage is eventually consistent - retry until message is available + require.Eventually(t, func() bool { + msg, err := channel.GetMessage(ctx, publishResult.Serial) + if err != nil { + return false + } + return msg.Data == "test data" && msg.Serial == publishResult.Serial + }, 5*time.Second, 100*time.Millisecond, "Message should be retrievable") + }) + }) + + t.Run("GetMessageVersions", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_get_message_versions") + + t.Run("retrieves all versions after updates", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "test-event", "version 1") + require.NoError(t, err) + + // Update the message twice + msg := &ably.Message{ + Serial: publishResult.Serial, + Data: "version 2", + } + _, err = channel.UpdateMessage(ctx, msg, ably.UpdateWithDescription("First update")) + require.NoError(t, err) + + msg.Data = "version 3" + _, err = channel.UpdateMessage(ctx, msg, ably.UpdateWithDescription("Second update")) + require.NoError(t, err) + + // GetMessageVersions is eventually consistent - retry until all versions are available + var versions []*ably.Message + require.Eventually(t, func() bool { + page, err := channel.GetMessageVersions(publishResult.Serial, nil).Pages(ctx) + if err != nil { + return false + } + + // Must call Next() to decode the response body into items + if !page.Next(ctx) { + return false + } + + versions = page.Items() + + // Should have exactly 3 versions: original publish + 2 updates + return len(versions) == 3 + }, 10*time.Second, 200*time.Millisecond, "All three message versions should be retrievable") + + // Verify we have exactly 3 versions in the correct order + require.Equal(t, 3, len(versions)) + assert.Equal(t, ably.MessageActionCreate, versions[0].Action, "First version should be message.create") + assert.Equal(t, ably.MessageActionUpdate, versions[1].Action, "Second version should be message.update") + assert.Equal(t, ably.MessageActionUpdate, versions[2].Action, "Third version should be message.update") + }) + }) +} + +func TestRealtimeChannel_MessageUpdates(t *testing.T) { + app, err := ablytest.NewSandbox(nil) + require.NoError(t, err) + defer app.Close() + + client, err := ably.NewRealtime(app.Options()...) + require.NoError(t, err) + defer client.Close() + + ctx := context.Background() + + // Wait for connection + err = ablytest.Wait(ablytest.ConnWaiter(client, nil, ably.ConnectionEventConnected), nil) + require.NoError(t, err) + + t.Run("PublishWithResult", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_realtime_publish_with_result") + + // Attach channel + err := channel.Attach(ctx) + require.NoError(t, err) + + t.Run("returns serial for published message", func(t *testing.T) { + result, err := channel.PublishWithResult(ctx, "event1", "realtime data") + require.NoError(t, err) + assert.NotEmpty(t, result.Serial, "Expected non-empty serial") + }) + + t.Run("PublishMultipleWithResult returns serials", func(t *testing.T) { + messages := []*ably.Message{ + {Name: "evt1", Data: "data1"}, + {Name: "evt2", Data: "data2"}, + } + + results, err := channel.PublishMultipleWithResult(ctx, messages) + require.NoError(t, err) + assert.Len(t, results, 2) + + for i, result := range results { + assert.NotEmpty(t, result.Serial, "Expected serial for message %d", i) + } + }) + }) + + t.Run("UpdateMessageAsync", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_realtime_update_async") + + // Attach channel + err := channel.Attach(ctx) + require.NoError(t, err) + + t.Run("updates message asynchronously", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "event", "initial") + require.NoError(t, err) + + // Update asynchronously + done := make(chan *ably.UpdateResult, 1) + errChan := make(chan error, 1) + + msg := &ably.Message{ + Serial: publishResult.Serial, + Data: "updated async", + } + err = channel.UpdateMessageAsync(msg, func(result *ably.UpdateResult, err error) { + if err != nil { + errChan <- err + } else { + done <- result + } + }) + require.NoError(t, err) + + // Wait for callback + select { + case result := <-done: + assert.NotEmpty(t, result.VersionSerial) + case err := <-errChan: + t.Fatalf("Update failed: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for update callback") + } + }) + }) + + t.Run("AppendMessageAsync", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_ai_streaming") + + // Attach channel + err := channel.Attach(ctx) + require.NoError(t, err) + + t.Run("rapid async appends for AI token streaming", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "ai-response", "The answer is") + require.NoError(t, err) + + // Simulate rapid token streaming + tokens := []string{" 42", ".", " This", " is", " correct", "."} + type ack struct { + result *ably.UpdateResult + err error + } + acks := make(chan ack, len(tokens)) + + for _, token := range tokens { + msg := &ably.Message{ + Serial: publishResult.Serial, + Data: token, + } + + err := channel.AppendMessageAsync(msg, func(result *ably.UpdateResult, err error) { + acks <- ack{result, err} + }) + require.NoError(t, err, "Failed to queue append %q", token) + } + + // Wait for all appends to complete (with timeout) + timeout := time.After(10 * time.Second) + ackCount := 0 + for ackCount < len(tokens) { + select { + case ack := <-acks: + require.NoError(t, ack.err) + ackCount++ + case <-timeout: + t.Fatalf("Timeout: Only %d/%d appends completed before timeout", ackCount, len(tokens)) + } + } + + assert.Equal(t, len(tokens), ackCount, "All appends should complete") + }) + }) +} diff --git a/ably/paginated_result.go b/ably/paginated_result.go index d4957f9a..d3296352 100644 --- a/ably/paginated_result.go +++ b/ably/paginated_result.go @@ -20,6 +20,7 @@ type paginatedRequest struct { path string rawPath string params url.Values + header http.Header // Optional custom headers query queryFunc } diff --git a/ably/proto_http.go b/ably/proto_http.go index 28dc553b..787c6f91 100644 --- a/ably/proto_http.go +++ b/ably/proto_http.go @@ -13,11 +13,18 @@ const ( ablyErrorMessageHeader = "X-Ably-Errormessage" clientLibraryVersion = "1.3.0" clientRuntimeName = "go" - ablyProtocolVersion = "2" // CSV2 - ablyClientIDHeader = "X-Ably-ClientId" - hostHeader = "Host" - ablyAgentHeader = "Ably-Agent" // RSC7d - ablySDKIdentifier = "ably-go/" + clientLibraryVersion // RSC7d1 + // ablyProtocolVersion is the default Ably protocol version used for all requests. + // Protocol v5 is required for message operations (publish/update/delete/append) to return + // message serials and version information. + // + // Note: Stats requests explicitly override this to use protocol v2 to maintain compatibility + // with the existing nested Stats structure. Migrating stats to the flattened v3+ format + // requires breaking API changes and is planned for ably-go v2.0. + ablyProtocolVersion = "5" // CSV2 + ablyClientIDHeader = "X-Ably-ClientId" + hostHeader = "Host" + ablyAgentHeader = "Ably-Agent" // RSC7d + ablySDKIdentifier = "ably-go/" + clientLibraryVersion // RSC7d1 ) var goRuntimeIdentifier = func() string { diff --git a/ably/proto_message.go b/ably/proto_message.go index 3dd2e4a4..8cba7117 100644 --- a/ably/proto_message.go +++ b/ably/proto_message.go @@ -7,6 +7,8 @@ import ( "fmt" "strings" "unicode/utf8" + + "github.com/ugorji/go/codec" ) // encodings @@ -18,6 +20,140 @@ const ( encVCDiff = "vcdiff" ) +// MessageAction represents the type of operation performed on a message. +type MessageAction string + +const ( + MessageActionCreate MessageAction = "message.create" + MessageActionUpdate MessageAction = "message.update" + MessageActionDelete MessageAction = "message.delete" + MessageActionAppend MessageAction = "message.append" +) + +// MarshalJSON implements json.Marshaler to encode MessageAction as numeric for wire compatibility. +func (a MessageAction) MarshalJSON() ([]byte, error) { + var num int + switch a { + case MessageActionCreate: + num = 0 + case MessageActionUpdate: + num = 1 + case MessageActionDelete: + num = 2 + case MessageActionAppend: + num = 5 + default: + num = 0 + } + return json.Marshal(num) +} + +// UnmarshalJSON implements json.Unmarshaler to decode numeric wire format to MessageAction. +func (a *MessageAction) UnmarshalJSON(data []byte) error { + var num int + if err := json.Unmarshal(data, &num); err != nil { + return err + } + switch num { + case 0: + *a = MessageActionCreate + case 1: + *a = MessageActionUpdate + case 2: + *a = MessageActionDelete + case 5: + *a = MessageActionAppend + default: + *a = MessageActionCreate + } + return nil +} + +// CodecEncodeSelf implements codec.Selfer for MessagePack encoding. +func (a MessageAction) CodecEncodeSelf(encoder *codec.Encoder) { + var num int + switch a { + case MessageActionCreate: + num = 0 + case MessageActionUpdate: + num = 1 + case MessageActionDelete: + num = 2 + case MessageActionAppend: + num = 5 + default: + num = 0 + } + encoder.MustEncode(num) +} + +// CodecDecodeSelf implements codec.Selfer for MessagePack decoding. +func (a *MessageAction) CodecDecodeSelf(decoder *codec.Decoder) { + var num int + decoder.MustDecode(&num) + switch num { + case 0: + *a = MessageActionCreate + case 1: + *a = MessageActionUpdate + case 2: + *a = MessageActionDelete + case 5: + *a = MessageActionAppend + default: + *a = MessageActionCreate + } +} + +// MessageVersion contains version information for a message operation. +type MessageVersion struct { + Serial string `json:"serial,omitempty" codec:"serial,omitempty"` + Timestamp int64 `json:"timestamp,omitempty" codec:"timestamp,omitempty"` + ClientID string `json:"clientId,omitempty" codec:"clientId,omitempty"` + Description string `json:"description,omitempty" codec:"description,omitempty"` + Metadata map[string]string `json:"metadata,omitempty" codec:"metadata,omitempty"` +} + +// PublishResult contains the result of a publish operation with serial tracking. +type PublishResult struct { + Serial string // May be empty if message discarded by conflation +} + +// UpdateResult contains the result of an update, delete, or append operation. +type UpdateResult struct { + VersionSerial string // Serial of new version, may be empty if superseded +} + +// UpdateOption is a functional option for message update operations. +type UpdateOption func(*updateOptions) + +type updateOptions struct { + description string + clientID string + metadata map[string]string +} + +// UpdateWithDescription sets a description for the update operation. +func UpdateWithDescription(description string) UpdateOption { + return func(o *updateOptions) { + o.description = description + } +} + +// UpdateWithClientID sets the client ID for the update operation. +func UpdateWithClientID(clientID string) UpdateOption { + return func(o *updateOptions) { + o.clientID = clientID + } +} + +// UpdateWithMetadata sets metadata for the update operation. +func UpdateWithMetadata(metadata map[string]string) UpdateOption { + return func(o *updateOptions) { + o.metadata = metadata + } +} + // Message contains an individual message that is sent to, or received from, Ably. type Message struct { // ID is a unique identifier assigned by Ably to this message (TM2a). @@ -42,6 +178,12 @@ type Message struct { // Extras is a JSON object of arbitrary key-value pairs that may contain metadata, and/or ancillary payloads. // Valid payloads include push, deltaExtras, ReferenceExtras and headers (TM2i). Extras map[string]interface{} `json:"extras,omitempty" codec:"extras,omitempty"` + // Serial is a permanent identifier for this message assigned by the server. + Serial string `json:"serial,omitempty" codec:"serial,omitempty"` + // Action indicates the type of operation (create, update, delete, append) performed on this message. + Action MessageAction `json:"action,omitempty" codec:"action,omitempty"` + // Version contains version information for message update/delete/append operations. + Version *MessageVersion `json:"version,omitempty" codec:"version,omitempty"` } // DeltaExtras describes a message whose payload is a "vcdiff"-encoded delta generated with respect to a base message (DE1, DE2). diff --git a/ably/proto_message_operations_test.go b/ably/proto_message_operations_test.go new file mode 100644 index 00000000..418a2fa6 --- /dev/null +++ b/ably/proto_message_operations_test.go @@ -0,0 +1,232 @@ +package ably + +import ( + "encoding/json" + "testing" + + "github.com/ugorji/go/codec" +) + +func TestMessageAction_JSON_Encoding(t *testing.T) { + tests := []struct { + action MessageAction + expected string + }{ + {MessageActionCreate, "0"}, + {MessageActionUpdate, "1"}, + {MessageActionDelete, "2"}, + {MessageActionAppend, "5"}, + } + + for _, tt := range tests { + t.Run(string(tt.action), func(t *testing.T) { + data, err := json.Marshal(tt.action) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + if string(data) != tt.expected { + t.Errorf("Expected %s, got %s", tt.expected, string(data)) + } + }) + } +} + +func TestMessageAction_JSON_Decoding(t *testing.T) { + tests := []struct { + input string + expected MessageAction + }{ + {"0", MessageActionCreate}, + {"1", MessageActionUpdate}, + {"2", MessageActionDelete}, + {"5", MessageActionAppend}, + {"999", MessageActionCreate}, // Unknown values default to Create + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + var action MessageAction + err := json.Unmarshal([]byte(tt.input), &action) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if action != tt.expected { + t.Errorf("Expected %s, got %s", tt.expected, action) + } + }) + } +} + +func TestMessageAction_Codec_Encoding(t *testing.T) { + tests := []struct { + action MessageAction + num int + }{ + {MessageActionCreate, 0}, + {MessageActionUpdate, 1}, + {MessageActionDelete, 2}, + {MessageActionAppend, 5}, + } + + for _, tt := range tests { + t.Run(string(tt.action), func(t *testing.T) { + var buf []byte + enc := codec.NewEncoderBytes(&buf, &codec.MsgpackHandle{}) + tt.action.CodecEncodeSelf(enc) + + // Decode to verify + var result int + dec := codec.NewDecoderBytes(buf, &codec.MsgpackHandle{}) + dec.MustDecode(&result) + + if result != tt.num { + t.Errorf("Expected %d, got %d", tt.num, result) + } + }) + } +} + +func TestMessageAction_Codec_Decoding(t *testing.T) { + tests := []struct { + num int + expected MessageAction + }{ + {0, MessageActionCreate}, + {1, MessageActionUpdate}, + {2, MessageActionDelete}, + {5, MessageActionAppend}, + {999, MessageActionCreate}, // Unknown values default to Create + } + + for _, tt := range tests { + t.Run(string(tt.expected), func(t *testing.T) { + var buf []byte + enc := codec.NewEncoderBytes(&buf, &codec.MsgpackHandle{}) + enc.MustEncode(tt.num) + + var action MessageAction + dec := codec.NewDecoderBytes(buf, &codec.MsgpackHandle{}) + action.CodecDecodeSelf(dec) + + if action != tt.expected { + t.Errorf("Expected %s, got %s", tt.expected, action) + } + }) + } +} + +func TestMessageVersion_Serialization(t *testing.T) { + version := &MessageVersion{ + Serial: "abc123", + Timestamp: 1234567890, + ClientID: "client1", + Description: "Test update", + Metadata: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + } + + // Test JSON serialization + data, err := json.Marshal(version) + if err != nil { + t.Fatalf("Failed to marshal MessageVersion: %v", err) + } + + var decoded MessageVersion + err = json.Unmarshal(data, &decoded) + if err != nil { + t.Fatalf("Failed to unmarshal MessageVersion: %v", err) + } + + if decoded.Serial != version.Serial { + t.Errorf("Serial: expected %s, got %s", version.Serial, decoded.Serial) + } + if decoded.Timestamp != version.Timestamp { + t.Errorf("Timestamp: expected %d, got %d", version.Timestamp, decoded.Timestamp) + } + if decoded.ClientID != version.ClientID { + t.Errorf("ClientID: expected %s, got %s", version.ClientID, decoded.ClientID) + } + if decoded.Description != version.Description { + t.Errorf("Description: expected %s, got %s", version.Description, decoded.Description) + } + if len(decoded.Metadata) != len(version.Metadata) { + t.Errorf("Metadata length: expected %d, got %d", len(version.Metadata), len(decoded.Metadata)) + } +} + +func TestMessage_NewFields_Serialization(t *testing.T) { + msg := &Message{ + ID: "msg123", + Data: "test data", + Serial: "serial123", + Action: MessageActionUpdate, + Version: &MessageVersion{ + Serial: "version_serial", + ClientID: "client1", + Description: "Updated message", + }, + } + + // Test JSON serialization + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("Failed to marshal Message: %v", err) + } + + var decoded Message + err = json.Unmarshal(data, &decoded) + if err != nil { + t.Fatalf("Failed to unmarshal Message: %v", err) + } + + if decoded.Serial != msg.Serial { + t.Errorf("Serial: expected %s, got %s", msg.Serial, decoded.Serial) + } + if decoded.Action != msg.Action { + t.Errorf("Action: expected %s, got %s", msg.Action, decoded.Action) + } + if decoded.Version == nil { + t.Fatal("Version should not be nil") + } + if decoded.Version.Serial != msg.Version.Serial { + t.Errorf("Version.Serial: expected %s, got %s", msg.Version.Serial, decoded.Version.Serial) + } +} + +func TestValidateMessageSerial(t *testing.T) { + t.Run("nil message", func(t *testing.T) { + err := validateMessageSerial(nil) + if err == nil { + t.Fatal("Expected error for nil message") + } + if code(err) != 40003 { + t.Errorf("Expected error code 40003, got %d", code(err)) + } + }) + + t.Run("empty serial", func(t *testing.T) { + msg := &Message{Data: "test"} + err := validateMessageSerial(msg) + if err == nil { + t.Fatal("Expected error for message without serial") + } + if code(err) != 40003 { + t.Errorf("Expected error code 40003, got %d", code(err)) + } + // Verify exact error message matches TypeScript + expectedMsg := "this message lacks a serial and cannot be updated. Make sure you have enabled \"Message annotations, updates, and deletes\" in channel settings on your dashboard" + if err.(*ErrorInfo).Message() != expectedMsg { + t.Errorf("Error message mismatch.\nExpected: %s\nGot: %s", expectedMsg, err.(*ErrorInfo).Message()) + } + }) + + t.Run("valid serial", func(t *testing.T) { + msg := &Message{Data: "test", Serial: "abc123"} + err := validateMessageSerial(msg) + if err != nil { + t.Errorf("Expected no error for valid message, got %v", err) + } + }) +} diff --git a/ably/proto_protocol_message.go b/ably/proto_protocol_message.go index 809d2119..b9979dd4 100644 --- a/ably/proto_protocol_message.go +++ b/ably/proto_protocol_message.go @@ -117,25 +117,31 @@ func coerceInt64(v interface{}) int64 { } } +// protocolPublishResult matches the wire format for publish results in ACK messages. +type protocolPublishResult struct { + Serials []string `json:"serials,omitempty" codec:"serials,omitempty"` +} + type protocolMessage struct { - Messages []*Message `json:"messages,omitempty" codec:"messages,omitempty"` - Presence []*PresenceMessage `json:"presence,omitempty" codec:"presence,omitempty"` - State []*objects.Message `json:"state,omitempty" codec:"state,omitempty"` - ID string `json:"id,omitempty" codec:"id,omitempty"` - ApplicationID string `json:"applicationId,omitempty" codec:"applicationId,omitempty"` - ConnectionID string `json:"connectionId,omitempty" codec:"connectionId,omitempty"` - ConnectionKey string `json:"connectionKey,omitempty" codec:"connectionKey,omitempty"` - Channel string `json:"channel,omitempty" codec:"channel,omitempty"` - ChannelSerial string `json:"channelSerial,omitempty" codec:"channelSerial,omitempty"` - ConnectionDetails *connectionDetails `json:"connectionDetails,omitempty" codec:"connectionDetails,omitempty"` - Error *errorInfo `json:"error,omitempty" codec:"error,omitempty"` - MsgSerial int64 `json:"msgSerial" codec:"msgSerial"` - Timestamp int64 `json:"timestamp,omitempty" codec:"timestamp,omitempty"` - Count int `json:"count,omitempty" codec:"count,omitempty"` - Action protoAction `json:"action,omitempty" codec:"action,omitempty"` - Flags protoFlag `json:"flags,omitempty" codec:"flags,omitempty"` - Params channelParams `json:"params,omitempty" codec:"params,omitempty"` - Auth *authDetails `json:"auth,omitempty" codec:"auth,omitempty"` + Messages []*Message `json:"messages,omitempty" codec:"messages,omitempty"` + Presence []*PresenceMessage `json:"presence,omitempty" codec:"presence,omitempty"` + State []*objects.Message `json:"state,omitempty" codec:"state,omitempty"` + ID string `json:"id,omitempty" codec:"id,omitempty"` + ApplicationID string `json:"applicationId,omitempty" codec:"applicationId,omitempty"` + ConnectionID string `json:"connectionId,omitempty" codec:"connectionId,omitempty"` + ConnectionKey string `json:"connectionKey,omitempty" codec:"connectionKey,omitempty"` + Channel string `json:"channel,omitempty" codec:"channel,omitempty"` + ChannelSerial string `json:"channelSerial,omitempty" codec:"channelSerial,omitempty"` + ConnectionDetails *connectionDetails `json:"connectionDetails,omitempty" codec:"connectionDetails,omitempty"` + Error *errorInfo `json:"error,omitempty" codec:"error,omitempty"` + MsgSerial int64 `json:"msgSerial" codec:"msgSerial"` + Timestamp int64 `json:"timestamp,omitempty" codec:"timestamp,omitempty"` + Count int `json:"count,omitempty" codec:"count,omitempty"` + Action protoAction `json:"action,omitempty" codec:"action,omitempty"` + Flags protoFlag `json:"flags,omitempty" codec:"flags,omitempty"` + Params channelParams `json:"params,omitempty" codec:"params,omitempty"` + Auth *authDetails `json:"auth,omitempty" codec:"auth,omitempty"` + Res []*protocolPublishResult `json:"res,omitempty" codec:"res,omitempty"` } // authDetails contains the token string used to authenticate a client with Ably. diff --git a/ably/realtime_channel.go b/ably/realtime_channel.go index da96dd8e..7418b73f 100644 --- a/ably/realtime_channel.go +++ b/ably/realtime_channel.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "errors" "fmt" + "net/url" "sort" "sync" @@ -728,7 +729,87 @@ func (c *RealtimeChannel) PublishMultipleAsync(messages []*Message, onAck func(e Channel: c.Name, Messages: messages, } - return c.send(msg, onAck) + return c.send(msg, &ackCallback{onAck: onAck}) +} + +// PublishWithResult publishes a single message to the channel and returns the serial assigned by the server. +// This will block until either the publish is acknowledged or fails to deliver. +func (c *RealtimeChannel) PublishWithResult(ctx context.Context, name string, data interface{}) (*PublishResult, error) { + results, err := c.PublishMultipleWithResult(ctx, []*Message{{Name: name, Data: data}}) + if err != nil { + return nil, err + } + if len(results) == 0 { + return &PublishResult{}, nil + } + return &results[0], nil +} + +// PublishWithResultAsync is the same as PublishWithResult except instead of blocking it calls onAck +// with the result or error. Note onAck must not block as it would block the internal client. +func (c *RealtimeChannel) PublishWithResultAsync(name string, data interface{}, onAck func(*PublishResult, error)) error { + return c.PublishMultipleWithResultAsync([]*Message{{Name: name, Data: data}}, func(results []PublishResult, err error) { + if err != nil { + onAck(nil, err) + return + } + if len(results) == 0 { + onAck(&PublishResult{}, nil) + return + } + onAck(&results[0], nil) + }) +} + +// PublishMultipleWithResult publishes all given messages on the channel and returns the serials assigned by the server. +func (c *RealtimeChannel) PublishMultipleWithResult(ctx context.Context, messages []*Message) ([]PublishResult, error) { + type resultOrError struct { + results []PublishResult + err error + } + listen := make(chan resultOrError, 1) + onAck := func(results []PublishResult, err error) { + listen <- resultOrError{results, err} + } + if err := c.PublishMultipleWithResultAsync(messages, onAck); err != nil { + return nil, err + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-listen: + return result.results, result.err + } +} + +// PublishMultipleWithResultAsync is the same as PublishMultipleWithResult except it calls onAck instead of blocking. +func (c *RealtimeChannel) PublishMultipleWithResultAsync(messages []*Message, onAck func([]PublishResult, error)) error { + id := c.client.Auth.clientIDForCheck() + for _, v := range messages { + if v.ClientID != "" && id != wildcardClientID && v.ClientID != id { + // Spec RTL6g3,RTL6g4 + return fmt.Errorf("Unable to publish message containing a clientId (%s) that is incompatible with the library clientId (%s)", v.ClientID, id) + } + } + msg := &protocolMessage{ + Action: actionMessage, + Channel: c.Name, + Messages: messages, + } + return c.sendWithSerialCallback(msg, func(serials []string, err error) { + if err != nil { + onAck(nil, err) + return + } + results := make([]PublishResult, len(messages)) + for i := range results { + if i < len(serials) { + results[i].Serial = serials[i] + } + } + onAck(results, nil) + }) } // History retrieves a [ably.HistoryRequest] object, containing an array of historical @@ -768,8 +849,116 @@ func (c *RealtimeChannel) HistoryUntilAttach(o ...HistoryOption) (*HistoryReques return &historyRequest, nil } -func (c *RealtimeChannel) send(msg *protocolMessage, onAck func(err error)) error { - if enqueued := c.maybeEnqueue(msg, onAck); enqueued { +// performMessageOperationAsync is a shared helper for UpdateMessageAsync, DeleteMessageAsync, and AppendMessageAsync. +// It validates the message serial, applies update options, sets the action, and sends the protocol message. +func (c *RealtimeChannel) performMessageOperationAsync(msg *Message, action MessageAction, onAck func(*UpdateResult, error), options ...UpdateOption) error { + if err := validateMessageSerial(msg); err != nil { + return err + } + + // Apply options + var opts updateOptions + for _, o := range options { + o(&opts) + } + + // Build version from options + version := &MessageVersion{ + Description: opts.description, + ClientID: opts.clientID, + Metadata: opts.metadata, + } + + // Create message for the operation + opMsg := *msg + opMsg.Action = action + opMsg.Version = version + + protoMsg := &protocolMessage{ + Action: actionMessage, + Channel: c.Name, + Messages: []*Message{&opMsg}, + } + + return c.sendWithSerialCallback(protoMsg, func(serials []string, err error) { + if err != nil { + onAck(nil, err) + return + } + result := &UpdateResult{} + if len(serials) > 0 { + result.VersionSerial = serials[0] + } + onAck(result, nil) + }) +} + +// performMessageOperation is a shared blocking helper for UpdateMessage, DeleteMessage, and AppendMessage. +// It wraps performMessageOperationAsync with a channel-based blocking pattern. +func (c *RealtimeChannel) performMessageOperation(ctx context.Context, msg *Message, action MessageAction, options ...UpdateOption) (*UpdateResult, error) { + type resultOrError struct { + result *UpdateResult + err error + } + listen := make(chan resultOrError, 1) + onAck := func(result *UpdateResult, err error) { + listen <- resultOrError{result, err} + } + if err := c.performMessageOperationAsync(msg, action, onAck, options...); err != nil { + return nil, err + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-listen: + return result.result, result.err + } +} + +// UpdateMessage updates a previously published message. +func (c *RealtimeChannel) UpdateMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionUpdate, options...) +} + +// UpdateMessageAsync is the same as UpdateMessage except instead of blocking it calls onAck. +func (c *RealtimeChannel) UpdateMessageAsync(msg *Message, onAck func(*UpdateResult, error), options ...UpdateOption) error { + return c.performMessageOperationAsync(msg, MessageActionUpdate, onAck, options...) +} + +// DeleteMessage deletes a previously published message. +func (c *RealtimeChannel) DeleteMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionDelete, options...) +} + +// DeleteMessageAsync is the same as DeleteMessage except instead of blocking it calls onAck. +func (c *RealtimeChannel) DeleteMessageAsync(msg *Message, onAck func(*UpdateResult, error), options ...UpdateOption) error { + return c.performMessageOperationAsync(msg, MessageActionDelete, onAck, options...) +} + +// AppendMessage appends to a previously published message. +func (c *RealtimeChannel) AppendMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionAppend, options...) +} + +// AppendMessageAsync is the same as AppendMessage except instead of blocking it calls onAck. +// This is critical for AI token streaming use cases where rapid appends should not block. +func (c *RealtimeChannel) AppendMessageAsync(msg *Message, onAck func(*UpdateResult, error), options ...UpdateOption) error { + return c.performMessageOperationAsync(msg, MessageActionAppend, onAck, options...) +} + +// GetMessage retrieves a message by its serial (delegates to REST). +func (c *RealtimeChannel) GetMessage(ctx context.Context, serial string) (*Message, error) { + return c.client.rest.Channels.Get(c.Name).GetMessage(ctx, serial) +} + +// GetMessageVersions retrieves the version history of a message by its serial (delegates to REST). +func (c *RealtimeChannel) GetMessageVersions(serial string, params url.Values) HistoryRequest { + return c.client.rest.Channels.Get(c.Name).GetMessageVersions(serial, params) +} + +func (c *RealtimeChannel) send(msg *protocolMessage, callback *ackCallback) error { + if enqueued := c.maybeEnqueue(msg, callback); enqueued { return nil } @@ -777,11 +966,16 @@ func (c *RealtimeChannel) send(msg *protocolMessage, onAck func(err error)) erro return newError(ErrChannelOperationFailedInvalidChannelState, nil) } - c.client.Connection.send(msg, onAck) + c.client.Connection.send(msg, callback) return nil } -func (c *RealtimeChannel) maybeEnqueue(msg *protocolMessage, onAck func(err error)) bool { +// sendWithSerialCallback sends a message and calls onAck with serials extracted from ACK. +func (c *RealtimeChannel) sendWithSerialCallback(msg *protocolMessage, onAck func(serials []string, err error)) error { + return c.send(msg, &ackCallback{onAckWithSerials: onAck}) +} + +func (c *RealtimeChannel) maybeEnqueue(msg *protocolMessage, callback *ackCallback) bool { // RTL6c2 if c.opts().NoQueueing { return false @@ -803,7 +997,7 @@ func (c *RealtimeChannel) maybeEnqueue(msg *protocolMessage, onAck func(err erro ChannelStateDetaching: } - c.queue.Enqueue(msg, onAck) + c.queue.Enqueue(msg, callback) return true } diff --git a/ably/realtime_conn.go b/ably/realtime_conn.go index b913ab2e..d99c99fb 100644 --- a/ably/realtime_conn.go +++ b/ably/realtime_conn.go @@ -614,35 +614,36 @@ func (c *Connection) advanceSerial() { c.msgSerial = (c.msgSerial + 1) % maxint64 } -func (c *Connection) send(msg *protocolMessage, onAck func(err error)) { +// send sends a message with a callback. +func (c *Connection) send(msg *protocolMessage, callback *ackCallback) { hasMsgSerial := msg.Action == actionMessage || msg.Action == actionPresence || msg.Action == actionObject c.mtx.Lock() // RTP16a - in case of presence msg send, check for connection status and send accordingly switch state := c.state; state { default: c.mtx.Unlock() - if onAck != nil { + if callback != nil { if c.state == ConnectionStateClosed { - onAck(errClosed) + callback.call(nil, errClosed) } else { - onAck(connStateError(state, nil)) + callback.call(nil, connStateError(state, nil)) } } case ConnectionStateInitialized, ConnectionStateConnecting, ConnectionStateDisconnected: c.mtx.Unlock() if c.opts.NoQueueing { - if onAck != nil { - onAck(connStateError(state, errQueueing)) + if callback != nil { + callback.call(nil, connStateError(state, errQueueing)) } } else { - c.queue.Enqueue(msg, onAck) // RTL4i + c.queue.Enqueue(msg, callback) // RTL4i } case ConnectionStateConnected: if err := c.verifyAndUpdateMessages(msg); err != nil { c.mtx.Unlock() - if onAck != nil { - onAck(err) + if callback != nil { + callback.call(nil, err) } return } @@ -660,13 +661,13 @@ func (c *Connection) send(msg *protocolMessage, onAck func(err error)) { c.log().Warnf("transport level failure while sending message, %v", err) c.conn.Close() c.mtx.Unlock() - c.queue.Enqueue(msg, onAck) + c.queue.Enqueue(msg, callback) } else { if hasMsgSerial { c.advanceSerial() } - if onAck != nil { - c.pending.Enqueue(msg, onAck) + if callback != nil { + c.pending.Enqueue(msg, callback) } c.mtx.Unlock() } @@ -760,7 +761,7 @@ func (c *Connection) resendPending() { c.mtx.Unlock() c.log().Debugf("resending %d messages waiting for ACK/NACK", len(cx)) for _, v := range cx { - c.send(v.msg, v.onAck) + c.send(v.msg, v.callback) } } diff --git a/ably/realtime_experimental_objects.go b/ably/realtime_experimental_objects.go index b40ff5a1..e824935d 100644 --- a/ably/realtime_experimental_objects.go +++ b/ably/realtime_experimental_objects.go @@ -50,7 +50,7 @@ func (o *RealtimeExperimentalObjects) PublishObjects(ctx context.Context, msgs . Channel: o.channel.getName(), State: msgs, } - if err := o.channel.send(msg, onAck); err != nil { + if err := o.channel.send(msg, &ackCallback{onAck: onAck}); err != nil { return err } @@ -63,7 +63,7 @@ func (o *RealtimeExperimentalObjects) PublishObjects(ctx context.Context, msgs . } type channel interface { - send(msg *protocolMessage, onAck func(error)) error + send(msg *protocolMessage, callback *ackCallback) error getClientOptions() *clientOptions getName() string } diff --git a/ably/realtime_experimental_objects_test.go b/ably/realtime_experimental_objects_test.go index 21d1eaab..16ad3d32 100644 --- a/ably/realtime_experimental_objects_test.go +++ b/ably/realtime_experimental_objects_test.go @@ -167,14 +167,14 @@ func TestRealtimeExperimentalObjects_PublishObjects(t *testing.T) { // Create channel mock channelMock := &channelMock{ - SendFunc: func(msg *protocolMessage, onAck func(error)) error { + SendFunc: func(msg *protocolMessage, callback *ackCallback) error { if tt.sendError != nil { return tt.sendError } capturedProtocolMsg = msg // Simulate async ack go func() { - onAck(tt.ackError) + callback.call(nil, tt.ackError) }() return nil }, @@ -234,8 +234,8 @@ func TestRealtimeExperimentalObjects_PublishObjects(t *testing.T) { func TestRealtimeExperimentalObjects_PublishObjectsContextCancellation(t *testing.T) { // Test context cancellation during publish channelMock := &channelMock{ - SendFunc: func(msg *protocolMessage, onAck func(error)) error { - // Don't call onAck to simulate hanging + SendFunc: func(msg *protocolMessage, callback *ackCallback) error { + // Don't call callback to simulate hanging return nil }, GetClientOptionsFunc: func() *clientOptions { @@ -275,13 +275,13 @@ func TestRealtimeExperimentalObjects_PublishObjectsContextCancellation(t *testin // channelMock implements the channel interface type channelMock struct { - SendFunc func(msg *protocolMessage, onAck func(error)) error + SendFunc func(msg *protocolMessage, callback *ackCallback) error GetClientOptionsFunc func() *clientOptions GetNameFunc func() string } -func (c channelMock) send(msg *protocolMessage, onAck func(error)) error { - return c.SendFunc(msg, onAck) +func (c channelMock) send(msg *protocolMessage, callback *ackCallback) error { + return c.SendFunc(msg, callback) } func (c channelMock) getClientOptions() *clientOptions { diff --git a/ably/realtime_presence.go b/ably/realtime_presence.go index e93192bd..8ffc5f39 100644 --- a/ably/realtime_presence.go +++ b/ably/realtime_presence.go @@ -73,14 +73,14 @@ func (pres *RealtimePresence) onChannelSuspended(err error) { pres.queue.Fail(err) } -func (pres *RealtimePresence) maybeEnqueue(msg *protocolMessage, onAck func(err error)) bool { +func (pres *RealtimePresence) maybeEnqueue(msg *protocolMessage, callback *ackCallback) bool { if pres.channel.opts().NoQueueing { - if onAck != nil { - onAck(errors.New("unable enqueue message because Options.QueueMessages is set to false")) + if callback != nil { + callback.call(nil, errors.New("unable enqueue message because Options.QueueMessages is set to false")) } return false } - pres.queue.Enqueue(msg, onAck) + pres.queue.Enqueue(msg, callback) return true } @@ -98,15 +98,16 @@ func (pres *RealtimePresence) send(msg *PresenceMessage) (result, error) { onAck := func(err error) { listen <- err } + callback := &ackCallback{onAck: onAck} switch pres.channel.State() { case ChannelStateInitialized: // RTP16b - if pres.maybeEnqueue(protomsg, onAck) { + if pres.maybeEnqueue(protomsg, callback) { pres.channel.attach() } case ChannelStateAttaching: // RTP16b - pres.maybeEnqueue(protomsg, onAck) + pres.maybeEnqueue(protomsg, callback) case ChannelStateAttached: // RTP16a - pres.channel.client.Connection.send(protomsg, onAck) // RTP16a, RTL6c + pres.channel.client.Connection.send(protomsg, callback) // RTP16a, RTL6c } return resultFunc(func(ctx context.Context) error { diff --git a/ably/rest_channel.go b/ably/rest_channel.go index ffd55804..11a9c415 100644 --- a/ably/rest_channel.go +++ b/ably/rest_channel.go @@ -64,6 +64,22 @@ type publishMultipleOptions struct { params map[string]string } +// publishResponse represents the response from the server after publishing messages. +type publishResponse struct { + Serials []string `json:"serials,omitempty" codec:"serials,omitempty"` +} + +// validateMessageSerial validates that a message has a serial for update operations. +func validateMessageSerial(msg *Message) error { + if msg == nil { + return newError(40003, fmt.Errorf("message cannot be nil")) + } + if msg.Serial == "" { + return newError(40003, fmt.Errorf("this message lacks a serial and cannot be updated. Make sure you have enabled \"Message annotations, updates, and deletes\" in channel settings on your dashboard")) + } + return nil +} + // PublishWithConnectionKey allows a message to be published for a specified connectionKey. func PublishWithConnectionKey(connectionKey string) PublishMultipleOption { return func(options *publishMultipleOptions) { @@ -86,8 +102,9 @@ func PublishMultipleWithParams(params map[string]string) PublishMultipleOption { return PublishWithParams(params) } -// PublishMultiple publishes multiple messages in a batch. Returns error if there is a problem publishing message (RSL1). -func (c *RESTChannel) PublishMultiple(ctx context.Context, messages []*Message, options ...PublishMultipleOption) error { +// publishMultiple is the internal implementation for publishing multiple messages. +// If out is non-nil, the response body will be decoded into it. +func (c *RESTChannel) publishMultiple(ctx context.Context, messages []*Message, out interface{}, options ...PublishMultipleOption) error { var publishOpts publishMultipleOptions for _, o := range options { o(&publishOpts) @@ -146,13 +163,18 @@ func (c *RESTChannel) PublishMultiple(ctx context.Context, messages []*Message, } } - res, err := c.client.post(ctx, c.baseURL+"/messages"+query, messages, nil) + res, err := c.client.post(ctx, c.baseURL+"/messages"+query, messages, out) if err != nil { return err } return res.Body.Close() } +// PublishMultiple publishes multiple messages in a batch. Returns error if there is a problem publishing message (RSL1). +func (c *RESTChannel) PublishMultiple(ctx context.Context, messages []*Message, options ...PublishMultipleOption) error { + return c.publishMultiple(ctx, messages, nil, options...) +} + // PublishMultipleWithOptions is the same as PublishMultiple. // // Deprecated: Use PublishMultiple instead. @@ -162,6 +184,104 @@ func (c *RESTChannel) PublishMultipleWithOptions(ctx context.Context, messages [ return c.PublishMultiple(ctx, messages, options...) } +// PublishWithResult publishes a single message to the channel with the given event name and payload, +// and returns the serial assigned by the server. Returns error if there is a problem performing message publish. +func (c *RESTChannel) PublishWithResult(ctx context.Context, name string, data interface{}, options ...PublishMultipleOption) (*PublishResult, error) { + results, err := c.PublishMultipleWithResult(ctx, []*Message{{Name: name, Data: data}}, options...) + if err != nil { + return nil, err + } + if len(results) == 0 { + return &PublishResult{}, nil + } + return &results[0], nil +} + +// PublishMultipleWithResult publishes multiple messages in a batch and returns the serials assigned by the server. +// Returns error if there is a problem publishing messages. +func (c *RESTChannel) PublishMultipleWithResult(ctx context.Context, messages []*Message, options ...PublishMultipleOption) ([]PublishResult, error) { + var response publishResponse + if err := c.publishMultiple(ctx, messages, &response, options...); err != nil { + return nil, err + } + + // Debug: log response + c.log().Debugf("PublishMultipleWithResult response: serials=%v, count=%d", response.Serials, len(response.Serials)) + + // Build results from serials + results := make([]PublishResult, len(messages)) + for i := range results { + if i < len(response.Serials) { + results[i].Serial = response.Serials[i] + } + } + return results, nil +} + +// performMessageOperation is a shared helper for UpdateMessage, DeleteMessage, and AppendMessage. +// It validates the message serial, applies update options, sets the action, encodes data, and sends the request. +func (c *RESTChannel) performMessageOperation(ctx context.Context, msg *Message, action MessageAction, options ...UpdateOption) (*UpdateResult, error) { + if err := validateMessageSerial(msg); err != nil { + return nil, err + } + + // Apply options + var opts updateOptions + for _, o := range options { + o(&opts) + } + + // Build version from options + version := &MessageVersion{ + Description: opts.description, + ClientID: opts.clientID, + Metadata: opts.metadata, + } + + // Create message for the operation + opMsg := *msg + opMsg.Action = action + opMsg.Version = version + + // Encode data + cipher, _ := c.options.GetCipher() + var err error + opMsg, err = opMsg.withEncodedData(cipher) + if err != nil { + return nil, fmt.Errorf("encoding data for message: %w", err) + } + + // POST to API + var response publishResponse + res, err := c.client.post(ctx, c.baseURL+"/messages", []*Message{&opMsg}, &response) + if err != nil { + return nil, err + } + defer res.Body.Close() + + // Extract version serial + result := &UpdateResult{} + if len(response.Serials) > 0 { + result.VersionSerial = response.Serials[0] + } + return result, nil +} + +// UpdateMessage updates a previously published message. +func (c *RESTChannel) UpdateMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionUpdate, options...) +} + +// DeleteMessage deletes a previously published message. +func (c *RESTChannel) DeleteMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionDelete, options...) +} + +// AppendMessage appends to a previously published message. +func (c *RESTChannel) AppendMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionAppend, options...) +} + // ChannelDetails contains the details of a [ably.RESTChannel] or [ably.RealtimeChannel] object // such as its ID and [ably.ChannelStatus]. type ChannelDetails struct { @@ -370,11 +490,20 @@ func (c *RESTChannel) fullMessagesDecoder(dst *[]*Message) interface{} { return &fullMessagesDecoder{dst: dst, c: c} } +func (c *RESTChannel) fullMessageDecoder(dst *Message) interface{} { + return &fullMessageDecoder{dst: dst, c: c} +} + type fullMessagesDecoder struct { dst *[]*Message c *RESTChannel } +type fullMessageDecoder struct { + dst *Message + c *RESTChannel +} + func (t *fullMessagesDecoder) UnmarshalJSON(b []byte) error { err := json.Unmarshal(b, &t.dst) if err != nil { @@ -410,6 +539,38 @@ func (t *fullMessagesDecoder) decodeMessagesData() { } } +func (t *fullMessageDecoder) UnmarshalJSON(b []byte) error { + err := json.Unmarshal(b, t.dst) + if err != nil { + return err + } + t.decodeMessageData() + return nil +} + +func (t *fullMessageDecoder) CodecEncodeSelf(*codec.Encoder) { + panic("fullMessageDecoder cannot be used as encoder") +} + +func (t *fullMessageDecoder) CodecDecodeSelf(decoder *codec.Decoder) { + decoder.MustDecode(t.dst) + t.decodeMessageData() +} + +var _ interface { + json.Unmarshaler + codec.Selfer +} = (*fullMessageDecoder)(nil) + +func (t *fullMessageDecoder) decodeMessageData() { + cipher, _ := t.c.options.GetCipher() + var err error + *t.dst, err = t.dst.withDecodedData(cipher) + if err != nil { + t.c.log().Errorf("Couldn't fully decode message data from channel %q: %w", t.c.Name, err) + } +} + type MessagesPaginatedItems struct { PaginatedResult items []*Message @@ -436,6 +597,32 @@ func (p *MessagesPaginatedItems) Item() *Message { return p.item } +// GetMessage retrieves a message by its serial. +func (c *RESTChannel) GetMessage(ctx context.Context, serial string) (*Message, error) { + var message Message + req := &request{ + Method: "GET", + Path: c.baseURL + "/messages/" + url.PathEscape(serial), + Out: c.fullMessageDecoder(&message), + } + _, err := c.client.do(ctx, req) + if err != nil { + return nil, err + } + return &message, nil +} + +// GetMessageVersions retrieves the version history of a message by its serial. +// Returns a HistoryRequest that can be used to paginate through message versions. +func (c *RESTChannel) GetMessageVersions(serial string, params url.Values) HistoryRequest { + path := c.baseURL + "/messages/" + url.PathEscape(serial) + "/versions" + rawPath := "/channels/" + c.pathName() + "/messages/" + url.PathEscape(serial) + "/versions" + return HistoryRequest{ + r: c.client.newPaginatedRequest(path, rawPath, params), + channel: c, + } +} + func (c *RESTChannel) log() logger { return c.client.log } diff --git a/ably/rest_client.go b/ably/rest_client.go index 233e79a4..3db36cf0 100644 --- a/ably/rest_client.go +++ b/ably/rest_client.go @@ -188,9 +188,26 @@ func (c *REST) Time(ctx context.Context) (time.Time, error) { // [ably.PaginatedResult] object, containing an array of [Stats]{@link Stats} objects (RSC6a). // // See package-level documentation => [ably] Pagination for handling stats pagination. +// +// Note: Stats requests use protocol version 2 to maintain compatibility with the existing +// nested Stats structure. Migrating to the flattened protocol v3+ stats format is planned +// for ably-go v2 as it requires breaking API changes. func (c *REST) Stats(o ...StatsOption) StatsRequest { params := (&statsOptions{}).apply(o...) - return StatsRequest{r: c.newPaginatedRequest("/stats", "", params)} + + // Use protocol v2 for stats to maintain compatibility with existing Stats structure. + // Protocol v3+ uses a flattened format that would require breaking API changes. + statsHeader := make(http.Header) + statsHeader.Set(ablyProtocolVersionHeader, "2") + + req := c.newPaginatedRequest("/stats", "", params) + // Override the query function to use getWithHeader with protocol v2 + req.query = func(ctx context.Context, path string) (*http.Response, error) { + // Pass nil for out because pagination decodes the response separately + return c.getWithHeader(ctx, path, nil, statsHeader) + } + + return StatsRequest{r: req} } func (c *REST) setActiveRealtimeHost(realtimeHost string) { @@ -622,6 +639,17 @@ func (c *REST) get(ctx context.Context, path string, out interface{}) (*http.Res return c.do(ctx, r) } +// getWithHeader is like get but allows specifying custom HTTP headers. +func (c *REST) getWithHeader(ctx context.Context, path string, out interface{}, header http.Header) (*http.Response, error) { + r := &request{ + Method: "GET", + Path: path, + Out: out, + header: header, + } + return c.do(ctx, r) +} + func (c *REST) post(ctx context.Context, path string, in, out interface{}) (*http.Response, error) { r := &request{ Method: "POST", @@ -784,8 +812,11 @@ func (c *REST) newHTTPRequest(ctx context.Context, r *request) (*http.Request, e if r.header != nil { copyHeader(req.Header, r.header) } - req.Header.Set("Accept", protocol) // RSC19c - req.Header.Set(ablyProtocolVersionHeader, ablyProtocolVersion) // RSC7a + req.Header.Set("Accept", protocol) // RSC19c + // RSC7a - Only set protocol version if not already set by custom header + if req.Header.Get(ablyProtocolVersionHeader) == "" { + req.Header.Set(ablyProtocolVersionHeader, ablyProtocolVersion) + } req.Header.Set(ablyAgentHeader, ablyAgentIdentifier(c.opts.Agents)) // RSC7d if c.opts.ClientID != "" && c.Auth.method == authBasic { // References RSA7e2 diff --git a/ably/state.go b/ably/state.go index a0e92227..8416ce60 100644 --- a/ably/state.go +++ b/ably/state.go @@ -125,14 +125,14 @@ func (q *pendingEmitter) Dismiss() []msgWithAckCallback { return cx } -func (q *pendingEmitter) Enqueue(msg *protocolMessage, onAck func(err error)) { +func (q *pendingEmitter) Enqueue(msg *protocolMessage, callback *ackCallback) { if len(q.queue) > 0 { expected := q.queue[len(q.queue)-1].msg.MsgSerial + 1 if got := msg.MsgSerial; expected != got { panic(fmt.Sprintf("protocol violation: expected next enqueued message to have msgSerial %d; got %d", expected, got)) } } - q.queue = append(q.queue, msgWithAckCallback{msg, onAck}) + q.queue = append(q.queue, msgWithAckCallback{msg, callback}) } func (q *pendingEmitter) Ack(msg *protocolMessage, errInfo *ErrorInfo) { @@ -180,15 +180,48 @@ func (q *pendingEmitter) Ack(msg *protocolMessage, errInfo *ErrorInfo) { err = errImplictNACK } q.log.Verbosef("received %v for message serial %d", msg.Action, sch.msg.MsgSerial) - if sch.onAck != nil { - sch.onAck(err) + + // Extract the corresponding result for this message from the res array. + // The res array contains results only for messages that were actually ACKed, + // not for implicitly NACKed messages (those with i < serialShift). + var result []*protocolPublishResult + resIndex := i - serialShift + if msg.Res != nil && resIndex >= 0 && resIndex < len(msg.Res) { + result = []*protocolPublishResult{msg.Res[resIndex]} } + sch.callback.call(result, err) + } +} + +type ackCallback struct { + onAck func(err error) + onAckWithSerials func(serials []string, err error) +} + +// call invokes the appropriate callback based on which is set. +// If onAckWithSerials is set, extracts serials from res before calling. +// If onAck is set, calls it with just the error. +func (cb *ackCallback) call(res []*protocolPublishResult, err error) { + if cb == nil { + return + } + if cb.onAckWithSerials != nil { + // Extract serials from results + var serials []string + for _, result := range res { + if result != nil && len(result.Serials) > 0 { + serials = append(serials, result.Serials...) + } + } + cb.onAckWithSerials(serials, err) + } else if cb.onAck != nil { + cb.onAck(err) } } type msgWithAckCallback struct { - msg *protocolMessage - onAck func(err error) + msg *protocolMessage + callback *ackCallback } type msgQueue struct { @@ -203,17 +236,17 @@ func newMsgQueue(conn *Connection) *msgQueue { } } -func (q *msgQueue) Enqueue(msg *protocolMessage, onAck func(err error)) { +func (q *msgQueue) Enqueue(msg *protocolMessage, callback *ackCallback) { q.mtx.Lock() // TODO(rjeczalik): reorder the queue so Presence / Messages can be merged - q.queue = append(q.queue, msgWithAckCallback{msg, onAck}) + q.queue = append(q.queue, msgWithAckCallback{msg, callback}) q.mtx.Unlock() } func (q *msgQueue) Flush() { q.mtx.Lock() for _, queueMsg := range q.queue { - q.conn.send(queueMsg.msg, queueMsg.onAck) + q.conn.send(queueMsg.msg, queueMsg.callback) } q.queue = nil q.mtx.Unlock() @@ -223,9 +256,7 @@ func (q *msgQueue) Fail(err error) { q.mtx.Lock() for _, queueMsg := range q.queue { q.log().Errorf("failure sending message (serial=%d): %v", queueMsg.msg.MsgSerial, err) - if queueMsg.onAck != nil { - queueMsg.onAck(newError(90000, err)) - } + queueMsg.callback.call(nil, newError(90000, err)) } q.queue = nil q.mtx.Unlock() diff --git a/ably/state_test.go b/ably/state_test.go new file mode 100644 index 00000000..b61505e8 --- /dev/null +++ b/ably/state_test.go @@ -0,0 +1,162 @@ +//go:build !integration +// +build !integration + +package ably + +import ( + "io" + "log" + "testing" + + "github.com/stretchr/testify/assert" +) + +// This test verifies that when an ACK message contains a "res" array with multiple results, +// each result is correctly associated with its corresponding message based on msgSerial. +// +// Scenario: SDK sends multiple messages with consecutive msgSerials +// Server responds with ACK: msgSerial=X, count=N, res=[result1, result2, ..., resultN] +// Expected behavior: Each message should receive serials from only its corresponding res element +func TestPendingEmitter_AckResult(t *testing.T) { + t.Run("two messages with single serial each", func(t *testing.T) { + testLogger := logger{l: &stdLogger{log.New(io.Discard, "", 0)}} + emitter := newPendingEmitter(testLogger) + + // Track what serials each message receives + var msg1Serials, msg2Serials []string + + // Create two protocol messages with consecutive msgSerials + protoMsg1 := &protocolMessage{ + MsgSerial: 5, + Action: actionMessage, + Channel: "test-channel", + } + callback1 := &ackCallback{ + onAckWithSerials: func(serials []string, err error) { + msg1Serials = serials + }, + } + + protoMsg2 := &protocolMessage{ + MsgSerial: 6, + Action: actionMessage, + Channel: "test-channel", + } + callback2 := &ackCallback{ + onAckWithSerials: func(serials []string, err error) { + msg2Serials = serials + }, + } + + // Enqueue both messages + emitter.Enqueue(protoMsg1, callback1) + emitter.Enqueue(protoMsg2, callback2) + + // Simulate receiving an ACK with msgSerial=5, count=2, and res array with two distinct results + ackMsg := &protocolMessage{ + Action: actionAck, + MsgSerial: 5, + Count: 2, + Res: []*protocolPublishResult{ + {Serials: []string{"serial-for-msg-5"}}, // Should only go to message 1 + {Serials: []string{"serial-for-msg-6"}}, // Should only go to message 2 + }, + } + + // Process the ACK + emitter.Ack(ackMsg, nil) + + // Verify each message received only its corresponding serials + assert.Equal(t, []string{"serial-for-msg-5"}, msg1Serials, + "Message 1 (msgSerial=5) should only receive serials from res[0]") + assert.Equal(t, []string{"serial-for-msg-6"}, msg2Serials, + "Message 2 (msgSerial=6) should only receive serials from res[1]") + }) + + t.Run("two messages with multiple serials each", func(t *testing.T) { + testLogger := logger{l: &stdLogger{log.New(io.Discard, "", 0)}} + emitter := newPendingEmitter(testLogger) + + var msg1Serials, msg2Serials []string + + protoMsg1 := &protocolMessage{ + MsgSerial: 10, + Action: actionMessage, + Channel: "test-channel", + } + callback1 := &ackCallback{ + onAckWithSerials: func(serials []string, err error) { + msg1Serials = serials + }, + } + + protoMsg2 := &protocolMessage{ + MsgSerial: 11, + Action: actionMessage, + Channel: "test-channel", + } + callback2 := &ackCallback{ + onAckWithSerials: func(serials []string, err error) { + msg2Serials = serials + }, + } + + emitter.Enqueue(protoMsg1, callback1) + emitter.Enqueue(protoMsg2, callback2) + + // ACK with multiple serials per result + ackMsg := &protocolMessage{ + Action: actionAck, + MsgSerial: 10, + Count: 2, + Res: []*protocolPublishResult{ + {Serials: []string{"serial-10-a", "serial-10-b"}}, // Both should go to message 1 + {Serials: []string{"serial-11-a", "serial-11-b"}}, // Both should go to message 2 + }, + } + + emitter.Ack(ackMsg, nil) + + assert.Equal(t, []string{"serial-10-a", "serial-10-b"}, msg1Serials, + "Message 1 (msgSerial=10) should receive both serials from res[0]") + assert.Equal(t, []string{"serial-11-a", "serial-11-b"}, msg2Serials, + "Message 2 (msgSerial=11) should receive both serials from res[1]") + }) + + t.Run("three messages", func(t *testing.T) { + testLogger := logger{l: &stdLogger{log.New(io.Discard, "", 0)}} + emitter := newPendingEmitter(testLogger) + + var msg1Serials, msg2Serials, msg3Serials []string + + protoMsg1 := &protocolMessage{MsgSerial: 1, Action: actionMessage, Channel: "test"} + callback1 := &ackCallback{onAckWithSerials: func(serials []string, err error) { msg1Serials = serials }} + + protoMsg2 := &protocolMessage{MsgSerial: 2, Action: actionMessage, Channel: "test"} + callback2 := &ackCallback{onAckWithSerials: func(serials []string, err error) { msg2Serials = serials }} + + protoMsg3 := &protocolMessage{MsgSerial: 3, Action: actionMessage, Channel: "test"} + callback3 := &ackCallback{onAckWithSerials: func(serials []string, err error) { msg3Serials = serials }} + + emitter.Enqueue(protoMsg1, callback1) + emitter.Enqueue(protoMsg2, callback2) + emitter.Enqueue(protoMsg3, callback3) + + ackMsg := &protocolMessage{ + Action: actionAck, + MsgSerial: 1, + Count: 3, + Res: []*protocolPublishResult{ + {Serials: []string{"serial-1"}}, + {Serials: []string{"serial-2"}}, + {Serials: []string{"serial-3"}}, + }, + } + + emitter.Ack(ackMsg, nil) + + assert.Equal(t, []string{"serial-1"}, msg1Serials, "Message 1 should receive serial-1") + assert.Equal(t, []string{"serial-2"}, msg2Serials, "Message 2 should receive serial-2") + assert.Equal(t, []string{"serial-3"}, msg3Serials, "Message 3 should receive serial-3") + }) +} diff --git a/ablytest/sandbox.go b/ablytest/sandbox.go index a2294bd3..e8898bb2 100644 --- a/ablytest/sandbox.go +++ b/ablytest/sandbox.go @@ -34,10 +34,11 @@ type Key struct { } type Namespace struct { - ID string `json:"id"` - Created int `json:"created,omitempty"` - Modified int `json:"modified,omitempty"` - Persisted bool `json:"persisted,omitempty"` + ID string `json:"id"` + Created int `json:"created,omitempty"` + Modified int `json:"modified,omitempty"` + Persisted bool `json:"persisted,omitempty"` + MutableMessages bool `json:"mutableMessages,omitempty"` } type Presence struct { @@ -80,6 +81,7 @@ func DefaultConfig() *Config { }, Namespaces: []Namespace{ {ID: "persisted", Persisted: true}, + {ID: "mutable", MutableMessages: true}, }, Channels: []Channel{ {