diff --git a/.github/workflows/client_go.yml b/.github/workflows/client_go.yml new file mode 100644 index 0000000..98aea48 --- /dev/null +++ b/.github/workflows/client_go.yml @@ -0,0 +1,40 @@ +name: Client (Go) + +on: + push: + branches: [main] + paths: + - "client_go/**" + - ".github/workflows/client_go.yml" + pull_request: + branches: [main] + paths: + - "client_go/**" + - ".github/workflows/client_go.yml" + +jobs: + test: + name: Test + runs-on: ubuntu-latest + + defaults: + run: + working-directory: client_go + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: client_go/go.mod + cache-dependency-path: client_go/go.sum + + - name: Build + run: go build ./... + + - name: Vet + run: go vet ./... + + - name: Test + run: go test -race ./... diff --git a/README.md b/README.md index 9eff558..b048939 100644 --- a/README.md +++ b/README.md @@ -10,13 +10,14 @@ +


Topical is an Elixir library for synchronising server-maintained state (_topics_) to connected clients. Topic lifecycle is managed by the server: topics are initialised as needed, shared between subscribing clients, and automatically shut down when not in use. -The accompanying JavaScript library (and React hooks) allow clients to easily connect to topics, and efficiently receive real-time updates. Clients can also send requests (or notifications) upstream to the server. +The accompanying JavaScript library (and React hooks) and Go client allow clients to easily connect to topics, and efficiently receive real-time updates. Clients can also send requests (or notifications) upstream to the server.

Architecture diagram @@ -137,6 +138,7 @@ This repository is separated into: - [`server_ex`](server_ex/) - the Elixir library for implementing topic servers, including adapters. - [`client_js`](client_js/) - the vanilla JavaScript WebSocket client. - [`client_react`](client_react/) - React hooks built on top of the JavaScript client. +- [`client_go`](client_go/) - Go WebSocket client. ## License diff --git a/client_go/README.md b/client_go/README.md new file mode 100644 index 0000000..a721498 --- /dev/null +++ b/client_go/README.md @@ -0,0 +1,115 @@ +# Topical → Client (Go) + +A Go client for [Topical](https://github.com/joefreeman/topical), a real-time state synchronization library. Connects to a Topical server over WebSocket and keeps local state in sync. + +## Install + +``` +go get github.com/joefreeman/topical/client_go +``` + +## Usage + +### Connecting + +```go +ctx := context.Background() +client, err := topical.Connect(ctx, "ws://localhost:4000/socket") +if err != nil { + log.Fatal(err) +} +defer client.Close() +``` + +By default the client reconnects automatically with exponential backoff. This can be configured: + +```go +client, err := topical.Connect(ctx, url, + topical.WithReconnect(false), + topical.WithBackoff(1*time.Second, 60*time.Second), +) +``` + +### Subscribing to topics + +Subscribe returns a `*Subscription` with channels for receiving values and errors. Multiple subscriptions to the same topic share a single server-side subscription. + +```go +sub := client.Subscribe("lists/my-list", nil) +defer sub.Unsubscribe() + +for val := range sub.Values() { + fmt.Println("new value:", val) +} +``` + +Topics can take parameters: + +```go +sub := client.Subscribe("lists/my-list", topical.Params{"user_id": "123"}) +``` + +### Typed subscriptions + +Use the generic `Subscribe` function to automatically unmarshal values into a struct: + +```go +type TodoList struct { + Items map[string]Item `json:"items"` + Order []string `json:"order"` +} + +sub := topical.Subscribe[TodoList](client, "lists/my-list", nil) +defer sub.Unsubscribe() + +for list := range sub.Values() { + fmt.Printf("got %d items\n", len(list.Items)) +} +``` + +### Execute (RPC) + +Send a request and wait for a response. The context controls the timeout: + +```go +ctx, cancel := context.WithTimeout(ctx, 5*time.Second) +defer cancel() + +result, err := client.Execute(ctx, "lists/my-list", "add_item", []any{"buy milk"}, nil) +``` + +### Notify (fire-and-forget) + +Send a one-way message with no response: + +```go +err := client.Notify("lists/my-list", "mark_done", []any{"item-id"}, nil) +``` + +### Connection state + +```go +fmt.Println(client.State()) // "connected", "connecting", or "disconnected" + +stateSub := client.StateChanges() +defer stateSub.Close() + +for s := range stateSub.C() { + fmt.Println("state changed:", s) +} +``` + +### Error handling + +Check for subscription errors on the `Err()` channel: + +```go +select { +case val := <-sub.Values(): + handleValue(val) +case err := <-sub.Err(): + handleError(err) +} +``` + +Operations return `topical.ErrNotConnected` when the client is disconnected. diff --git a/client_go/client.go b/client_go/client.go new file mode 100644 index 0000000..568dab3 --- /dev/null +++ b/client_go/client.go @@ -0,0 +1,642 @@ +package topical + +import ( + "context" + "encoding/json" + "errors" + "math/rand/v2" + "net/url" + "sort" + "strings" + "sync" + "time" + + "github.com/coder/websocket" +) + +// State represents the connection state. +type State int + +const ( + Connecting State = iota + Connected + Disconnected +) + +func (s State) String() string { + switch s { + case Connecting: + return "connecting" + case Connected: + return "connected" + case Disconnected: + return "disconnected" + default: + return "unknown" + } +} + +// Params holds topic parameters. +type Params map[string]string + +var ( + // ErrNotConnected is returned when an operation requires a connection but the client is not connected. + ErrNotConnected = errors.New("topical: not connected") + // ErrInvalidMessage is returned when a received message cannot be decoded. + ErrInvalidMessage = errors.New("topical: invalid message") +) + +type clientConfig struct { + reconnect bool + backoffBase time.Duration + backoffMax time.Duration + dialOptions *websocket.DialOptions +} + +// Option configures a Client. +type Option func(*clientConfig) + +// WithReconnect enables or disables automatic reconnection. +func WithReconnect(enabled bool) Option { + return func(c *clientConfig) { c.reconnect = enabled } +} + +// WithBackoff configures reconnection backoff timing. +func WithBackoff(base, max time.Duration) Option { + return func(c *clientConfig) { + c.backoffBase = base + c.backoffMax = max + } +} + +// WithDialOptions sets WebSocket dial options. +func WithDialOptions(opts *websocket.DialOptions) Option { + return func(c *clientConfig) { c.dialOptions = opts } +} + +type request struct { + result chan any + err chan error +} + +type topicEntry struct { + listeners []*listener + topic string + params Params + channelID int + value any + hasValue bool +} + +type listener struct { + values chan any + errors chan error + closed bool +} + +// Client manages a WebSocket connection to a Topical server. +type Client struct { + mu sync.Mutex + url string + config clientConfig + conn *websocket.Conn + state State + closed bool + lastChannelID int + topics map[string]*topicEntry + requests map[int]*request + subscriptions map[int]string // channelID -> topic key + aliases map[int]int // aliased channelID -> target channelID + stateListeners map[int]chan State + nextListenerID int + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +// Connect establishes a WebSocket connection to the given URL. +func Connect(ctx context.Context, rawURL string, opts ...Option) (*Client, error) { + cfg := clientConfig{ + reconnect: true, + backoffBase: 500 * time.Millisecond, + backoffMax: 30 * time.Second, + } + for _, o := range opts { + o(&cfg) + } + + clientCtx, cancel := context.WithCancel(context.Background()) + + c := &Client{ + url: rawURL, + config: cfg, + state: Connecting, + topics: make(map[string]*topicEntry), + requests: make(map[int]*request), + subscriptions: make(map[int]string), + aliases: make(map[int]int), + stateListeners: make(map[int]chan State), + ctx: clientCtx, + cancel: cancel, + } + + conn, _, err := websocket.Dial(ctx, rawURL, cfg.dialOptions) + if err != nil { + cancel() + return nil, err + } + conn.SetReadLimit(-1) + c.conn = conn + c.state = Connected + + // Resubscribe existing topics (none on first connect, but used after reconnect) + c.mu.Lock() + for key := range c.topics { + c.setupSubscriptionLocked(key) + } + c.mu.Unlock() + + c.wg.Add(1) + go c.readLoop() + + return c, nil +} + +// Close shuts down the client and its WebSocket connection. +func (c *Client) Close() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil + } + c.closed = true + c.mu.Unlock() + + c.cancel() + if c.conn != nil { + c.conn.Close(websocket.StatusNormalClosure, "") + } + c.wg.Wait() + return nil +} + +// State returns the current connection state. +func (c *Client) State() State { + c.mu.Lock() + defer c.mu.Unlock() + return c.state +} + +// StateSubscription receives connection state changes. +type StateSubscription struct { + client *Client + id int + ch chan State +} + +// C returns a channel that receives state changes. +func (ss *StateSubscription) C() <-chan State { + return ss.ch +} + +// Close removes this state subscription. +func (ss *StateSubscription) Close() { + ss.client.mu.Lock() + defer ss.client.mu.Unlock() + delete(ss.client.stateListeners, ss.id) +} + +// StateChanges returns a subscription that receives connection state changes. +func (c *Client) StateChanges() *StateSubscription { + c.mu.Lock() + id := c.nextListenerID + c.nextListenerID++ + ch := make(chan State, 1) + c.stateListeners[id] = ch + c.mu.Unlock() + return &StateSubscription{client: c, id: id, ch: ch} +} + +// notifyStateListeners sends the new state to all state listener channels. +func (c *Client) notifyStateListeners(s State) { + c.mu.Lock() + listeners := make([]chan State, 0, len(c.stateListeners)) + for _, ch := range c.stateListeners { + listeners = append(listeners, ch) + } + c.mu.Unlock() + for _, ch := range listeners { + sendReplace(ch, s) + } +} + +func (c *Client) nextChannelID() int { + c.lastChannelID++ + return c.lastChannelID +} + +func (c *Client) readLoop() { + defer c.wg.Done() + for { + _, data, err := c.conn.Read(c.ctx) + if err != nil { + c.handleDisconnect() + return + } + c.handleMessage(data) + } +} + +func (c *Client) handleMessage(data []byte) { + opcode, fields, err := decodeResponse(data) + if err != nil { + return + } + + switch opcode { + case respError: + c.handleError(fields) + case respResult: + c.handleResult(fields) + case respTopicReset: + c.handleTopicReset(fields) + case respTopicUpdates: + c.handleTopicUpdates(fields) + case respTopicAlias: + c.handleTopicAlias(fields) + } +} + +func (c *Client) handleError(fields []json.RawMessage) { + if len(fields) < 2 { + return + } + var channelID int + if err := json.Unmarshal(fields[0], &channelID); err != nil { + return + } + var errorVal any + json.Unmarshal(fields[1], &errorVal) + + c.mu.Lock() + defer c.mu.Unlock() + + if key, ok := c.subscriptions[channelID]; ok { + if t, ok := c.topics[key]; ok { + for _, l := range t.listeners { + sendNonBlocking(l.errors, errors.New(errorString(errorVal))) + } + delete(c.topics, key) + } + delete(c.subscriptions, channelID) + } else if req, ok := c.requests[channelID]; ok { + sendNonBlocking(req.err, errors.New(errorString(errorVal))) + delete(c.requests, channelID) + } +} + +func (c *Client) handleResult(fields []json.RawMessage) { + if len(fields) < 2 { + return + } + var channelID int + if err := json.Unmarshal(fields[0], &channelID); err != nil { + return + } + var result any + json.Unmarshal(fields[1], &result) + + c.mu.Lock() + req, ok := c.requests[channelID] + if ok { + delete(c.requests, channelID) + } + c.mu.Unlock() + + if ok { + sendNonBlocking(req.result, result) + } +} + +func (c *Client) handleTopicReset(fields []json.RawMessage) { + if len(fields) < 2 { + return + } + var channelID int + if err := json.Unmarshal(fields[0], &channelID); err != nil { + return + } + var value any + json.Unmarshal(fields[1], &value) + + c.mu.Lock() + key, ok := c.subscriptions[channelID] + if !ok { + c.mu.Unlock() + return + } + t, ok := c.topics[key] + if !ok { + c.mu.Unlock() + return + } + t.value = value + t.hasValue = true + listeners := make([]*listener, len(t.listeners)) + copy(listeners, t.listeners) + c.mu.Unlock() + + for _, l := range listeners { + sendReplace(l.values, value) + } +} + +func (c *Client) handleTopicUpdates(fields []json.RawMessage) { + if len(fields) < 2 { + return + } + var channelID int + if err := json.Unmarshal(fields[0], &channelID); err != nil { + return + } + var updates [][]any + if err := json.Unmarshal(fields[1], &updates); err != nil { + return + } + + c.mu.Lock() + key, ok := c.subscriptions[channelID] + if !ok { + c.mu.Unlock() + return + } + t, ok := c.topics[key] + if !ok { + c.mu.Unlock() + return + } + + value := t.value + for _, u := range updates { + var err error + value, err = applyUpdate(value, u) + if err != nil { + c.mu.Unlock() + return + } + } + t.value = value + t.hasValue = true + listeners := make([]*listener, len(t.listeners)) + copy(listeners, t.listeners) + c.mu.Unlock() + + for _, l := range listeners { + sendReplace(l.values, value) + } +} + +func (c *Client) handleTopicAlias(fields []json.RawMessage) { + if len(fields) < 2 { + return + } + var aliasedChannelID, targetChannelID int + if err := json.Unmarshal(fields[0], &aliasedChannelID); err != nil { + return + } + if err := json.Unmarshal(fields[1], &targetChannelID); err != nil { + return + } + + c.mu.Lock() + aliasedKey, ok1 := c.subscriptions[aliasedChannelID] + targetKey, ok2 := c.subscriptions[targetChannelID] + if !ok1 || !ok2 { + c.mu.Unlock() + return + } + aliasedTopic := c.topics[aliasedKey] + targetTopic := c.topics[targetKey] + if aliasedTopic == nil || targetTopic == nil { + c.mu.Unlock() + return + } + + // Move listeners from aliased to target + movedListeners := make([]*listener, len(aliasedTopic.listeners)) + copy(movedListeners, aliasedTopic.listeners) + targetTopic.listeners = append(targetTopic.listeners, movedListeners...) + + hasValue := targetTopic.hasValue + value := targetTopic.value + + // Clean up aliased topic + delete(c.topics, aliasedKey) + delete(c.subscriptions, aliasedChannelID) + c.aliases[aliasedChannelID] = targetChannelID + c.mu.Unlock() + + // Notify moved listeners of current value + if hasValue { + for _, l := range movedListeners { + sendReplace(l.values, value) + } + } +} + +func (c *Client) handleDisconnect() { + c.mu.Lock() + c.state = Disconnected + + // Clear channel IDs from topics (they'll be reassigned on reconnect) + for _, t := range c.topics { + t.channelID = 0 + } + c.subscriptions = make(map[int]string) + c.aliases = make(map[int]int) + + // Reject pending requests + for _, req := range c.requests { + sendNonBlocking(req.err, ErrNotConnected) + } + c.requests = make(map[int]*request) + + shouldReconnect := c.config.reconnect && !c.closed + + // If closed and not reconnecting, close all listener channels so consumers unblock + if c.closed { + for _, t := range c.topics { + for _, l := range t.listeners { + if !l.closed { + l.closed = true + close(l.values) + close(l.errors) + } + } + } + for _, ch := range c.stateListeners { + close(ch) + } + } + + c.mu.Unlock() + + c.notifyStateListeners(Disconnected) + + if shouldReconnect { + c.reconnect() + } +} + +func (c *Client) reconnect() { + delay := c.config.backoffBase + for { + select { + case <-c.ctx.Done(): + return + case <-time.After(jitter(delay)): + } + + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return + } + c.state = Connecting + c.mu.Unlock() + + c.notifyStateListeners(Connecting) + + conn, _, err := websocket.Dial(c.ctx, c.url, c.config.dialOptions) + if err != nil { + delay = min(delay*2, c.config.backoffMax) + c.mu.Lock() + c.state = Disconnected + c.mu.Unlock() + c.notifyStateListeners(Disconnected) + continue + } + conn.SetReadLimit(-1) + + c.mu.Lock() + c.conn = conn + c.state = Connected + + // Resubscribe all active topics + for key := range c.topics { + c.setupSubscriptionLocked(key) + } + c.mu.Unlock() + + c.notifyStateListeners(Connected) + + c.wg.Add(1) + go c.readLoop() + return + } +} + +func (c *Client) setupSubscriptionLocked(key string) { + t := c.topics[key] + if t == nil { + return + } + channelID := c.nextChannelID() + t.channelID = channelID + c.subscriptions[channelID] = key + + data, err := encodeSubscribe(channelID, t.topic, t.params) + if err != nil { + return + } + c.conn.Write(c.ctx, websocket.MessageText, data) +} + +func (c *Client) send(data []byte) error { + c.mu.Lock() + conn := c.conn + ctx := c.ctx + c.mu.Unlock() + if conn == nil { + return ErrNotConnected + } + return conn.Write(ctx, websocket.MessageText, data) +} + +// sendLocked writes data on the current connection. Must be called with c.mu held. +func (c *Client) sendLocked(data []byte) error { + if c.conn == nil { + return ErrNotConnected + } + return c.conn.Write(c.ctx, websocket.MessageText, data) +} + +// topicKey generates a deterministic key for a topic + params combination. +func topicKey(topic string, params Params) string { + var b strings.Builder + b.WriteString(topic) + b.WriteByte('?') + keys := make([]string, 0, len(params)) + for k := range params { + keys = append(keys, k) + } + sort.Strings(keys) + for i, k := range keys { + if i > 0 { + b.WriteByte('&') + } + b.WriteString(url.QueryEscape(k)) + b.WriteByte('=') + b.WriteString(url.QueryEscape(params[k])) + } + return b.String() +} + +// sendReplace sends a value on a buffered channel (size 1), draining the old value if needed. +// This function assumes a single goroutine sends to ch at a time (the readLoop goroutine). +// Multiple concurrent senders would race on the drain-then-send sequence. +func sendReplace[T any](ch chan T, val T) { + select { + case ch <- val: + default: + // Drain stale value and send new one + select { + case <-ch: + default: + } + select { + case ch <- val: + default: + } + } +} + +// sendNonBlocking tries to send on a buffered channel without blocking. +func sendNonBlocking[T any](ch chan T, val T) { + select { + case ch <- val: + default: + } +} + +func errorString(v any) string { + switch e := v.(type) { + case string: + return e + case map[string]any: + if msg, ok := e["message"]; ok { + return errorString(msg) + } + b, _ := json.Marshal(e) + return string(b) + default: + b, _ := json.Marshal(v) + return string(b) + } +} + +func jitter(d time.Duration) time.Duration { + // +/-25% jitter + factor := 0.75 + rand.Float64()*0.5 + return time.Duration(float64(d) * factor) +} diff --git a/client_go/client_test.go b/client_go/client_test.go new file mode 100644 index 0000000..70fc8fe --- /dev/null +++ b/client_go/client_test.go @@ -0,0 +1,473 @@ +package topical + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/coder/websocket" +) + +// mockServer creates a test WebSocket server that echoes back protocol messages +// according to the Topical protocol. +type mockServer struct { + server *httptest.Server + // handler is called for each incoming message; it returns messages to send back. + handler func(msg []any) [][]any +} + +func newMockServer(handler func(msg []any) [][]any) *mockServer { + ms := &mockServer{handler: handler} + ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + defer conn.CloseNow() + + ctx := r.Context() + for { + _, data, err := conn.Read(ctx) + if err != nil { + return + } + var msg []any + if err := json.Unmarshal(data, &msg); err != nil { + continue + } + responses := ms.handler(msg) + for _, resp := range responses { + respData, _ := json.Marshal(resp) + conn.Write(ctx, websocket.MessageText, respData) + } + } + })) + return ms +} + +func (ms *mockServer) wsURL() string { + return "ws" + strings.TrimPrefix(ms.server.URL, "http") +} + +func (ms *mockServer) close() { + ms.server.Close() +} + +func TestConnectAndClose(t *testing.T) { + t.Parallel() + ms := newMockServer(func(msg []any) [][]any { return nil }) + defer ms.close() + + ctx := context.Background() + client, err := Connect(ctx, ms.wsURL(), WithReconnect(false)) + if err != nil { + t.Fatal(err) + } + if client.State() != Connected { + t.Errorf("expected Connected, got %v", client.State()) + } + client.Close() +} + +func TestSubscribeReceivesReset(t *testing.T) { + t.Parallel() + ms := newMockServer(func(msg []any) [][]any { + opcode := int(msg[0].(float64)) + if opcode == reqSubscribe { + channelID := msg[1].(float64) + // Send a topic reset with initial value + return [][]any{ + {float64(respTopicReset), channelID, map[string]any{"items": map[string]any{}, "order": []any{}}}, + } + } + return nil + }) + defer ms.close() + + ctx := context.Background() + client, err := Connect(ctx, ms.wsURL(), WithReconnect(false)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + sub := client.Subscribe("lists/test", nil) + defer sub.Unsubscribe() + + select { + case val := <-sub.Values(): + m, ok := val.(map[string]any) + if !ok { + t.Fatalf("expected map, got %T", val) + } + if _, ok := m["items"]; !ok { + t.Error("expected 'items' key in value") + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for value") + } +} + +func TestSubscribeReceivesUpdates(t *testing.T) { + t.Parallel() + ms := newMockServer(func(msg []any) [][]any { + opcode := int(msg[0].(float64)) + if opcode == reqSubscribe { + channelID := msg[1].(float64) + return [][]any{ + // Initial reset + {float64(respTopicReset), channelID, map[string]any{"count": float64(0)}}, + // Then an update + {float64(respTopicUpdates), channelID, []any{ + []any{float64(0), []any{"count"}, float64(1)}, + }}, + } + } + return nil + }) + defer ms.close() + + ctx := context.Background() + client, err := Connect(ctx, ms.wsURL(), WithReconnect(false)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + sub := client.Subscribe("counter", nil) + defer sub.Unsubscribe() + + // Should eventually get the updated value (count=1) + deadline := time.After(2 * time.Second) + for { + select { + case val := <-sub.Values(): + m := val.(map[string]any) + if m["count"] == float64(1) { + return // success + } + case <-deadline: + t.Fatal("timeout waiting for updated value") + } + } +} + +func TestExecute(t *testing.T) { + t.Parallel() + ms := newMockServer(func(msg []any) [][]any { + opcode := int(msg[0].(float64)) + if opcode == reqExecute { + channelID := msg[1].(float64) + return [][]any{ + {float64(respResult), channelID, "hello"}, + } + } + return nil + }) + defer ms.close() + + ctx := context.Background() + client, err := Connect(ctx, ms.wsURL(), WithReconnect(false)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + result, err := client.Execute(ctx, "lists/test", "greet", []any{"world"}, nil) + if err != nil { + t.Fatal(err) + } + if result != "hello" { + t.Errorf("expected 'hello', got %v", result) + } +} + +func TestExecuteError(t *testing.T) { + t.Parallel() + ms := newMockServer(func(msg []any) [][]any { + opcode := int(msg[0].(float64)) + if opcode == reqExecute { + channelID := msg[1].(float64) + return [][]any{ + {float64(respError), channelID, "not_found"}, + } + } + return nil + }) + defer ms.close() + + ctx := context.Background() + client, err := Connect(ctx, ms.wsURL(), WithReconnect(false)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + _, err = client.Execute(ctx, "lists/test", "missing", nil, nil) + if err == nil { + t.Fatal("expected error") + } + if err.Error() != "not_found" { + t.Errorf("expected 'not_found', got %v", err) + } +} + +func TestNotify(t *testing.T) { + t.Parallel() + received := make(chan []any, 1) + ms := newMockServer(func(msg []any) [][]any { + opcode := int(msg[0].(float64)) + if opcode == reqNotify { + received <- msg + } + return nil + }) + defer ms.close() + + ctx := context.Background() + client, err := Connect(ctx, ms.wsURL(), WithReconnect(false)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + err = client.Notify("lists/test", "ping", []any{"data"}, nil) + if err != nil { + t.Fatal(err) + } + + select { + case msg := <-received: + action := msg[2].(string) + if action != "ping" { + t.Errorf("expected action 'ping', got %v", action) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for notify") + } +} + +func TestTypedSubscription(t *testing.T) { + t.Parallel() + ms := newMockServer(func(msg []any) [][]any { + opcode := int(msg[0].(float64)) + if opcode == reqSubscribe { + channelID := msg[1].(float64) + return [][]any{ + {float64(respTopicReset), channelID, map[string]any{ + "items": map[string]any{"a": map[string]any{"text": "hello"}}, + "order": []any{"a"}, + }}, + } + } + return nil + }) + defer ms.close() + + type Item struct { + Text string `json:"text"` + } + type List struct { + Items map[string]Item `json:"items"` + Order []string `json:"order"` + } + + ctx := context.Background() + client, err := Connect(ctx, ms.wsURL(), WithReconnect(false)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + sub := Subscribe[List](client, "lists/test", nil) + defer sub.Unsubscribe() + + select { + case list := <-sub.Values(): + if len(list.Items) != 1 { + t.Errorf("expected 1 item, got %d", len(list.Items)) + } + if list.Items["a"].Text != "hello" { + t.Errorf("expected 'hello', got %s", list.Items["a"].Text) + } + if len(list.Order) != 1 || list.Order[0] != "a" { + t.Errorf("unexpected order: %v", list.Order) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for typed value") + } +} + +func TestSubscriptionDedup(t *testing.T) { + t.Parallel() + var subscribeCount atomic.Int32 + ms := newMockServer(func(msg []any) [][]any { + opcode := int(msg[0].(float64)) + if opcode == reqSubscribe { + subscribeCount.Add(1) + channelID := msg[1].(float64) + return [][]any{ + {float64(respTopicReset), channelID, map[string]any{"count": float64(0)}}, + } + } + return nil + }) + defer ms.close() + + ctx := context.Background() + client, err := Connect(ctx, ms.wsURL(), WithReconnect(false)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + sub1 := client.Subscribe("counter", nil) + // Wait for first value + select { + case <-sub1.Values(): + case <-time.After(2 * time.Second): + t.Fatal("timeout") + } + + sub2 := client.Subscribe("counter", nil) + // Second subscriber should immediately get the cached value + select { + case <-sub2.Values(): + case <-time.After(2 * time.Second): + t.Fatal("timeout") + } + + if count := subscribeCount.Load(); count != 1 { + t.Errorf("expected 1 server subscribe, got %d", count) + } + + sub1.Unsubscribe() + sub2.Unsubscribe() +} + +func TestTopicAlias(t *testing.T) { + t.Parallel() + ms := newMockServer(func(msg []any) [][]any { + opcode := int(msg[0].(float64)) + if opcode == reqSubscribe { + channelID := msg[1].(float64) + topicName := msg[2].(string) + if topicName == "first" { + // First subscription gets a reset + return [][]any{ + {float64(respTopicReset), channelID, map[string]any{"data": "hello"}}, + } + } + if topicName == "second" { + // Second subscription is an alias to the first (channelID 1) + return [][]any{ + {float64(respTopicAlias), channelID, float64(1)}, + } + } + } + return nil + }) + defer ms.close() + + ctx := context.Background() + client, err := Connect(ctx, ms.wsURL(), WithReconnect(false)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + sub1 := client.Subscribe("first", nil) + select { + case <-sub1.Values(): + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for first value") + } + + sub2 := client.Subscribe("second", nil) + // Should receive the aliased value from the first topic + select { + case val := <-sub2.Values(): + m := val.(map[string]any) + if m["data"] != "hello" { + t.Errorf("expected 'hello', got %v", m["data"]) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for aliased value") + } + + sub1.Unsubscribe() + sub2.Unsubscribe() +} + +func TestCloseUnblocksSubscribers(t *testing.T) { + t.Parallel() + ms := newMockServer(func(msg []any) [][]any { + opcode := int(msg[0].(float64)) + if opcode == reqSubscribe { + channelID := msg[1].(float64) + return [][]any{ + {float64(respTopicReset), channelID, map[string]any{"data": "initial"}}, + } + } + return nil + }) + defer ms.close() + + ctx := context.Background() + client, err := Connect(ctx, ms.wsURL(), WithReconnect(false)) + if err != nil { + t.Fatal(err) + } + + sub := client.Subscribe("test", nil) + // Drain initial value + select { + case <-sub.Values(): + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for initial value") + } + + // Close should cause Values() channel to be closed + client.Close() + + select { + case _, ok := <-sub.Values(): + if ok { + t.Error("expected channel to be closed") + } + case <-time.After(2 * time.Second): + t.Fatal("timeout: Values() channel was not closed after Client.Close()") + } +} + +func TestExecuteTimeout(t *testing.T) { + t.Parallel() + ms := newMockServer(func(msg []any) [][]any { + // Never respond to execute requests + return nil + }) + defer ms.close() + + ctx := context.Background() + client, err := Connect(ctx, ms.wsURL(), WithReconnect(false)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + execCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + + _, err = client.Execute(execCtx, "lists/test", "slow", nil, nil) + if err == nil { + t.Fatal("expected error from timeout") + } + if err != context.DeadlineExceeded { + t.Errorf("expected DeadlineExceeded, got %v", err) + } +} diff --git a/client_go/doc.go b/client_go/doc.go new file mode 100644 index 0000000..6beacc0 --- /dev/null +++ b/client_go/doc.go @@ -0,0 +1,4 @@ +// Package topical provides a Go client for Topical, a real-time state +// synchronization library. It connects to a Topical server over WebSocket +// and keeps local state in sync with efficient diff-based updates. +package topical diff --git a/client_go/execute.go b/client_go/execute.go new file mode 100644 index 0000000..25310e8 --- /dev/null +++ b/client_go/execute.go @@ -0,0 +1,69 @@ +package topical + +import "context" + +// Execute sends an RPC-style request and blocks until the server responds. +// The context controls the timeout. +func (c *Client) Execute(ctx context.Context, topic string, action string, args []any, params Params) (any, error) { + c.mu.Lock() + if c.state != Connected { + c.mu.Unlock() + return nil, ErrNotConnected + } + channelID := c.nextChannelID() + req := &request{ + result: make(chan any, 1), + err: make(chan error, 1), + } + c.requests[channelID] = req + c.mu.Unlock() + + if args == nil { + args = []any{} + } + data, err := encodeExecute(channelID, topic, action, args, params) + if err != nil { + c.mu.Lock() + delete(c.requests, channelID) + c.mu.Unlock() + return nil, err + } + + if err := c.send(data); err != nil { + c.mu.Lock() + delete(c.requests, channelID) + c.mu.Unlock() + return nil, err + } + + select { + case result := <-req.result: + return result, nil + case err := <-req.err: + return nil, err + case <-ctx.Done(): + c.mu.Lock() + delete(c.requests, channelID) + c.mu.Unlock() + return nil, ctx.Err() + } +} + +// Notify sends a fire-and-forget notification. +func (c *Client) Notify(topic string, action string, args []any, params Params) error { + c.mu.Lock() + if c.state != Connected { + c.mu.Unlock() + return ErrNotConnected + } + c.mu.Unlock() + + if args == nil { + args = []any{} + } + data, err := encodeNotify(topic, action, args, params) + if err != nil { + return err + } + return c.send(data) +} diff --git a/client_go/go.mod b/client_go/go.mod new file mode 100644 index 0000000..a5e3622 --- /dev/null +++ b/client_go/go.mod @@ -0,0 +1,5 @@ +module github.com/joefreeman/topical/client_go + +go 1.23 + +require github.com/coder/websocket v1.8.12 diff --git a/client_go/go.sum b/client_go/go.sum new file mode 100644 index 0000000..029cf47 --- /dev/null +++ b/client_go/go.sum @@ -0,0 +1,2 @@ +github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= diff --git a/client_go/protocol.go b/client_go/protocol.go new file mode 100644 index 0000000..e581991 --- /dev/null +++ b/client_go/protocol.go @@ -0,0 +1,71 @@ +package topical + +import "encoding/json" + +// Request opcodes (client -> server) +const ( + reqNotify = 0 + reqExecute = 1 + reqSubscribe = 2 + reqUnsubscribe = 3 +) + +// Response opcodes (server -> client) +const ( + respError = 0 + respResult = 1 + respTopicReset = 2 + respTopicUpdates = 3 + respTopicAlias = 4 +) + +func encodeNotify(topic string, action string, args []any, params Params) ([]byte, error) { + var msg []any + if len(params) > 0 { + msg = []any{reqNotify, topic, action, args, params} + } else { + msg = []any{reqNotify, topic, action, args} + } + return json.Marshal(msg) +} + +func encodeExecute(channelID int, topic string, action string, args []any, params Params) ([]byte, error) { + var msg []any + if len(params) > 0 { + msg = []any{reqExecute, channelID, topic, action, args, params} + } else { + msg = []any{reqExecute, channelID, topic, action, args} + } + return json.Marshal(msg) +} + +func encodeSubscribe(channelID int, topic string, params Params) ([]byte, error) { + var msg []any + if len(params) > 0 { + msg = []any{reqSubscribe, channelID, topic, params} + } else { + msg = []any{reqSubscribe, channelID, topic} + } + return json.Marshal(msg) +} + +func encodeUnsubscribe(channelID int) ([]byte, error) { + return json.Marshal([]any{reqUnsubscribe, channelID}) +} + +// decodeResponse parses a server message and returns the opcode and raw fields. +// The caller is responsible for interpreting fields based on the opcode. +func decodeResponse(data []byte) (int, []json.RawMessage, error) { + var raw []json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return 0, nil, err + } + if len(raw) < 2 { + return 0, nil, ErrInvalidMessage + } + var opcode int + if err := json.Unmarshal(raw[0], &opcode); err != nil { + return 0, nil, err + } + return opcode, raw[1:], nil +} diff --git a/client_go/protocol_test.go b/client_go/protocol_test.go new file mode 100644 index 0000000..76e015f --- /dev/null +++ b/client_go/protocol_test.go @@ -0,0 +1,142 @@ +package topical + +import ( + "encoding/json" + "testing" +) + +func TestEncodeNotifyWithoutParams(t *testing.T) { + t.Parallel() + data, err := encodeNotify("lists/abc", "add_item", []any{"hello"}, nil) + if err != nil { + t.Fatal(err) + } + var msg []any + json.Unmarshal(data, &msg) + if int(msg[0].(float64)) != reqNotify { + t.Errorf("expected opcode %d, got %v", reqNotify, msg[0]) + } + if len(msg) != 4 { + t.Errorf("expected 4 fields without params, got %d", len(msg)) + } +} + +func TestEncodeNotifyWithParams(t *testing.T) { + t.Parallel() + data, err := encodeNotify("lists/abc", "add_item", []any{"hello"}, Params{"user": "joe"}) + if err != nil { + t.Fatal(err) + } + var msg []any + json.Unmarshal(data, &msg) + if len(msg) != 5 { + t.Errorf("expected 5 fields with params, got %d", len(msg)) + } +} + +func TestEncodeExecuteWithoutParams(t *testing.T) { + t.Parallel() + data, err := encodeExecute(42, "lists/abc", "get_item", []any{1}, nil) + if err != nil { + t.Fatal(err) + } + var msg []any + json.Unmarshal(data, &msg) + if int(msg[0].(float64)) != reqExecute { + t.Errorf("expected opcode %d, got %v", reqExecute, msg[0]) + } + if int(msg[1].(float64)) != 42 { + t.Errorf("expected channelID 42, got %v", msg[1]) + } + if len(msg) != 5 { + t.Errorf("expected 5 fields without params, got %d", len(msg)) + } +} + +func TestEncodeSubscribeWithoutParams(t *testing.T) { + t.Parallel() + data, err := encodeSubscribe(1, "lists/abc", nil) + if err != nil { + t.Fatal(err) + } + var msg []any + json.Unmarshal(data, &msg) + if int(msg[0].(float64)) != reqSubscribe { + t.Errorf("expected opcode %d, got %v", reqSubscribe, msg[0]) + } + if len(msg) != 3 { + t.Errorf("expected 3 fields without params, got %d", len(msg)) + } +} + +func TestEncodeSubscribeWithParams(t *testing.T) { + t.Parallel() + data, err := encodeSubscribe(1, "lists/abc", Params{"key": "val"}) + if err != nil { + t.Fatal(err) + } + var msg []any + json.Unmarshal(data, &msg) + if len(msg) != 4 { + t.Errorf("expected 4 fields with params, got %d", len(msg)) + } +} + +func TestEncodeUnsubscribe(t *testing.T) { + t.Parallel() + data, err := encodeUnsubscribe(7) + if err != nil { + t.Fatal(err) + } + var msg []any + json.Unmarshal(data, &msg) + if int(msg[0].(float64)) != reqUnsubscribe { + t.Errorf("expected opcode %d, got %v", reqUnsubscribe, msg[0]) + } + if int(msg[1].(float64)) != 7 { + t.Errorf("expected channelID 7, got %v", msg[1]) + } +} + +func TestDecodeResponse(t *testing.T) { + t.Parallel() + data := []byte(`[2, 1, {"items": {}}]`) + opcode, fields, err := decodeResponse(data) + if err != nil { + t.Fatal(err) + } + if opcode != respTopicReset { + t.Errorf("expected opcode %d, got %d", respTopicReset, opcode) + } + if len(fields) != 2 { + t.Errorf("expected 2 fields, got %d", len(fields)) + } + var channelID int + json.Unmarshal(fields[0], &channelID) + if channelID != 1 { + t.Errorf("expected channelID 1, got %d", channelID) + } +} + +func TestDecodeResponseTooShort(t *testing.T) { + t.Parallel() + data := []byte(`[2]`) + _, _, err := decodeResponse(data) + if err == nil { + t.Error("expected error for short message") + } +} + +func TestTopicKey(t *testing.T) { + t.Parallel() + key := topicKey("lists/abc", nil) + if key != "lists/abc?" { + t.Errorf("unexpected key: %s", key) + } + + key2 := topicKey("lists/abc", Params{"b": "2", "a": "1"}) + expected := "lists/abc?a=1&b=2" + if key2 != expected { + t.Errorf("expected %s, got %s", expected, key2) + } +} diff --git a/client_go/subscribe.go b/client_go/subscribe.go new file mode 100644 index 0000000..7fd182b --- /dev/null +++ b/client_go/subscribe.go @@ -0,0 +1,100 @@ +package topical + +// Subscription delivers untyped topic values. +type Subscription struct { + client *Client + key string + listener *listener +} + +// Subscribe creates a subscription to the given topic. Multiple calls with the +// same topic and params share a single server subscription (reference counted). +// The topic is a slash-separated path (e.g. "lists/my-list"). +func (c *Client) Subscribe(topic string, params Params) *Subscription { + key := topicKey(topic, params) + + l := &listener{ + values: make(chan any, 1), + errors: make(chan error, 1), + } + + c.mu.Lock() + defer c.mu.Unlock() + + if t, ok := c.topics[key]; ok { + // Existing topic - add listener + t.listeners = append(t.listeners, l) + if t.hasValue { + sendReplace(l.values, t.value) + } + } else { + // New topic + t := &topicEntry{ + listeners: []*listener{l}, + topic: topic, + params: params, + } + c.topics[key] = t + if c.state == Connected { + c.setupSubscriptionLocked(key) + } + } + + return &Subscription{ + client: c, + key: key, + listener: l, + } +} + +// Values returns a channel that receives the latest topic value on each change. +func (s *Subscription) Values() <-chan any { + return s.listener.values +} + +// Err returns a channel that receives server-side topic errors. +func (s *Subscription) Err() <-chan error { + return s.listener.errors +} + +// Unsubscribe removes this subscription. When the last subscriber for a topic +// leaves, the server subscription is also cancelled. +func (s *Subscription) Unsubscribe() { + c := s.client + c.mu.Lock() + defer c.mu.Unlock() + + // Already closed (e.g., by Client.Close) + if s.listener.closed { + return + } + + t, ok := c.topics[s.key] + if !ok { + return + } + + // Remove this listener + for i, l := range t.listeners { + if l == s.listener { + t.listeners = append(t.listeners[:i], t.listeners[i+1:]...) + break + } + } + + // If no more listeners, unsubscribe from server + if len(t.listeners) == 0 { + if t.channelID != 0 && c.state == Connected { + data, err := encodeUnsubscribe(t.channelID) + if err == nil { + c.sendLocked(data) + } + delete(c.subscriptions, t.channelID) + } + delete(c.topics, s.key) + } + + s.listener.closed = true + close(s.listener.values) + close(s.listener.errors) +} \ No newline at end of file diff --git a/client_go/typed.go b/client_go/typed.go new file mode 100644 index 0000000..0322cfd --- /dev/null +++ b/client_go/typed.go @@ -0,0 +1,71 @@ +package topical + +import ( + "encoding/json" + "fmt" +) + +// TypedSubscription delivers typed topic values, converting from the internal +// untyped representation via a JSON round-trip. +type TypedSubscription[T any] struct { + sub *Subscription + values chan T + errors chan error +} + +// Subscribe is a generic wrapper that converts untyped topic values to T +// via JSON marshaling/unmarshaling on each update. +func Subscribe[T any](c *Client, topic string, params Params) *TypedSubscription[T] { + sub := c.Subscribe(topic, params) + ts := &TypedSubscription[T]{ + sub: sub, + values: make(chan T, 1), + errors: make(chan error, 1), + } + go ts.convert() + return ts +} + +func (ts *TypedSubscription[T]) convert() { + defer close(ts.values) + defer close(ts.errors) + for { + select { + case v, ok := <-ts.sub.Values(): + if !ok { + return + } + data, err := json.Marshal(v) + if err != nil { + sendNonBlocking(ts.errors, fmt.Errorf("topical: marshal: %w", err)) + continue + } + var typed T + if err := json.Unmarshal(data, &typed); err != nil { + sendNonBlocking(ts.errors, fmt.Errorf("topical: unmarshal: %w", err)) + continue + } + sendReplace(ts.values, typed) + case err, ok := <-ts.sub.Err(): + if !ok { + return + } + sendNonBlocking(ts.errors, err) + } + } +} + +// Values returns a channel that receives the latest typed topic value on each change. +func (ts *TypedSubscription[T]) Values() <-chan T { + return ts.values +} + +// Err returns a channel that receives subscription and conversion errors. +func (ts *TypedSubscription[T]) Err() <-chan error { + return ts.errors +} + +// Unsubscribe removes this subscription. +func (ts *TypedSubscription[T]) Unsubscribe() { + ts.sub.Unsubscribe() +} diff --git a/client_go/updates.go b/client_go/updates.go new file mode 100644 index 0000000..18aa5c7 --- /dev/null +++ b/client_go/updates.go @@ -0,0 +1,201 @@ +package topical + +import "fmt" + +// updateIn traverses a nested structure along path, applying callback at the leaf. +// Path elements are strings (map keys) or float64 (slice indices, from JSON). +func updateIn(value any, path []any, callback func(any) (any, error)) (any, error) { + if len(path) == 0 { + return callback(value) + } + key := path[0] + rest := path[1:] + + switch k := key.(type) { + case float64: + idx := int(k) + slice, ok := value.([]any) + if !ok { + return nil, fmt.Errorf("expected array, got %T", value) + } + if idx < 0 || idx >= len(slice) { + return nil, fmt.Errorf("index %d out of range (len %d)", idx, len(slice)) + } + updated, err := updateIn(slice[idx], rest, callback) + if err != nil { + return nil, err + } + result := make([]any, len(slice)) + copy(result, slice) + result[idx] = updated + return result, nil + + case string: + m, ok := value.(map[string]any) + if !ok { + if value == nil { + // Handle nil by creating a new map + m = map[string]any{} + } else { + return nil, fmt.Errorf("expected map, got %T", value) + } + } + updated, err := updateIn(m[k], rest, callback) + if err != nil { + return nil, err + } + result := make(map[string]any, len(m)+1) + for mk, mv := range m { + result[mk] = mv + } + result[k] = updated + return result, nil + + default: + return nil, fmt.Errorf("invalid path element type: %T", key) + } +} + +// applyUpdate applies a single update operation to a value. +// Update formats: +// +// [0, path, value] - set +// [1, path, key] - unset (delete key from map) +// [2, path, index, vals] - insert into slice (null index = append) +// [3, path, index, count]- delete from slice +// [4, path, value] - merge (shallow) into map +func applyUpdate(current any, update []any) (any, error) { + if len(update) < 3 { + return nil, fmt.Errorf("update too short: %v", update) + } + opcodeF, ok := update[0].(float64) + if !ok { + return nil, fmt.Errorf("invalid update opcode type: %T", update[0]) + } + opcode := int(opcodeF) + path, ok := toPath(update[1]) + if !ok { + return nil, fmt.Errorf("invalid update path: %v", update[1]) + } + + switch opcode { + case 0: // set + val := update[2] + return updateIn(current, path, func(_ any) (any, error) { + return val, nil + }) + + case 1: // unset + key, ok := update[2].(string) + if !ok { + return nil, fmt.Errorf("unset key must be string, got %T", update[2]) + } + return updateIn(current, path, func(value any) (any, error) { + m, ok := value.(map[string]any) + if !ok { + return nil, fmt.Errorf("expected map for unset, got %T", value) + } + result := make(map[string]any, len(m)) + for k, v := range m { + if k != key { + result[k] = v + } + } + return result, nil + }) + + case 2: // insert + if len(update) < 4 { + return nil, fmt.Errorf("insert update too short") + } + values, ok := update[3].([]any) + if !ok { + return nil, fmt.Errorf("insert values must be array, got %T", update[3]) + } + return updateIn(current, path, func(value any) (any, error) { + list, ok := value.([]any) + if !ok { + return nil, fmt.Errorf("expected array for insert, got %T", value) + } + var idx int + if update[2] == nil { + idx = len(list) + } else { + f, ok := update[2].(float64) + if !ok { + return nil, fmt.Errorf("insert index must be number or null, got %T", update[2]) + } + idx = int(f) + } + if idx < 0 || idx > len(list) { + return nil, fmt.Errorf("insert index %d out of range (len %d)", idx, len(list)) + } + result := make([]any, 0, len(list)+len(values)) + result = append(result, list[:idx]...) + result = append(result, values...) + result = append(result, list[idx:]...) + return result, nil + }) + + case 3: // delete + if len(update) < 4 { + return nil, fmt.Errorf("delete update too short") + } + idxF, ok := update[2].(float64) + if !ok { + return nil, fmt.Errorf("delete index must be number, got %T", update[2]) + } + countF, ok := update[3].(float64) + if !ok { + return nil, fmt.Errorf("delete count must be number, got %T", update[3]) + } + idx := int(idxF) + count := int(countF) + return updateIn(current, path, func(value any) (any, error) { + list, ok := value.([]any) + if !ok { + return nil, fmt.Errorf("expected array for delete, got %T", value) + } + if idx < 0 || idx+count > len(list) { + return nil, fmt.Errorf("delete range [%d:%d] out of range (len %d)", idx, idx+count, len(list)) + } + result := make([]any, 0, len(list)-count) + result = append(result, list[:idx]...) + result = append(result, list[idx+count:]...) + return result, nil + }) + + case 4: // merge + mergeVal, ok := update[2].(map[string]any) + if !ok { + return nil, fmt.Errorf("merge value must be map, got %T", update[2]) + } + return updateIn(current, path, func(value any) (any, error) { + existing, ok := value.(map[string]any) + if !ok { + // If existing is nil/non-map, start with empty map + existing = map[string]any{} + } + result := make(map[string]any, len(existing)+len(mergeVal)) + for k, v := range existing { + result[k] = v + } + for k, v := range mergeVal { + result[k] = v + } + return result, nil + }) + + default: + return nil, fmt.Errorf("unhandled update opcode: %d", opcode) + } +} + +// toPath converts a JSON-decoded path ([]any of strings and float64s) into []any. +func toPath(v any) ([]any, bool) { + arr, ok := v.([]any) + if !ok { + return nil, false + } + return arr, true +} diff --git a/client_go/updates_test.go b/client_go/updates_test.go new file mode 100644 index 0000000..b5d6a89 --- /dev/null +++ b/client_go/updates_test.go @@ -0,0 +1,171 @@ +package topical + +import ( + "encoding/json" + "testing" +) + +// helper: parse JSON string into any (mimics what we get from the wire) +func j(s string) any { + var v any + if err := json.Unmarshal([]byte(s), &v); err != nil { + panic(err) + } + return v +} + +func TestSetRootValue(t *testing.T) { + t.Parallel() + current := j(`{"foo": 1}`) + update := j(`[0, [], 2]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + assertJSON(t, result, `2`) +} + +func TestSetNewValue(t *testing.T) { + t.Parallel() + current := j(`{}`) + update := j(`[0, ["foo", "bar"], 2]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + assertJSON(t, result, `{"foo": {"bar": 2}}`) +} + +func TestSetValueWithinList(t *testing.T) { + t.Parallel() + current := j(`{"foo": [0, {"bar": 1}, 2]}`) + update := j(`[0, ["foo", 1, "bar"], 3]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + assertJSON(t, result, `{"foo": [0, {"bar": 3}, 2]}`) +} + +func TestReplaceExistingValue(t *testing.T) { + t.Parallel() + current := j(`{"foo": {"bar": 1, "baz": 2}}`) + update := j(`[0, ["foo", "bar"], 3]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + assertJSON(t, result, `{"foo": {"bar": 3, "baz": 2}}`) +} + +func TestUnsetValue(t *testing.T) { + t.Parallel() + current := j(`{"foo": {"bar": 2}}`) + update := j(`[1, ["foo"], "bar"]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + assertJSON(t, result, `{"foo": {}}`) +} + +func TestUnsetValueWithinList(t *testing.T) { + t.Parallel() + current := j(`{"foo": [0, {"bar": 1}, 2]}`) + update := j(`[1, ["foo", 1], "bar"]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + assertJSON(t, result, `{"foo": [0, {}, 2]}`) +} + +func TestResetValue(t *testing.T) { + t.Parallel() + current := j(`{"foo": {"bar": 2}}`) + update := j(`[0, [], null]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + if result != nil { + t.Fatalf("expected nil, got %v", result) + } +} + +func TestInsertIntoList(t *testing.T) { + t.Parallel() + current := j(`{"foo": [0, 1, 2]}`) + update := j(`[2, ["foo"], 1, [3, 4]]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + assertJSON(t, result, `{"foo": [0, 3, 4, 1, 2]}`) +} + +func TestDeleteFromList(t *testing.T) { + t.Parallel() + current := j(`{"foo": [0, 1, 2, 3]}`) + update := j(`[3, ["foo"], 1, 2]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + assertJSON(t, result, `{"foo": [0, 3]}`) +} + +func TestMergeValue(t *testing.T) { + t.Parallel() + current := j(`{"foo": {"bar": {"a": 1, "b": 2}}}`) + update := j(`[4, ["foo", "bar"], {"b": 3, "c": 4}]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + assertJSON(t, result, `{"foo": {"bar": {"a": 1, "b": 3, "c": 4}}}`) +} + +func TestMergeNonExistingValue(t *testing.T) { + t.Parallel() + current := j(`{"foo": {}}`) + update := j(`[4, ["foo", "bar"], {"a": 1}]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + assertJSON(t, result, `{"foo": {"bar": {"a": 1}}}`) +} + +func TestInsertAppend(t *testing.T) { + t.Parallel() + current := j(`{"foo": [0, 1]}`) + update := j(`[2, ["foo"], null, [2, 3]]`).([]any) + result, err := applyUpdate(current, update) + if err != nil { + t.Fatal(err) + } + assertJSON(t, result, `{"foo": [0, 1, 2, 3]}`) +} + +// assertJSON checks that the JSON representation of got matches the expected JSON string. +func assertJSON(t *testing.T, got any, expectedJSON string) { + t.Helper() + gotBytes, err := json.Marshal(got) + if err != nil { + t.Fatalf("failed to marshal result: %v", err) + } + // Normalize both by unmarshaling and remarshaling + var gotNorm, expNorm any + if err := json.Unmarshal(gotBytes, &gotNorm); err != nil { + t.Fatalf("failed to unmarshal result: %v", err) + } + if err := json.Unmarshal([]byte(expectedJSON), &expNorm); err != nil { + t.Fatalf("failed to unmarshal expected: %v", err) + } + gotNormBytes, _ := json.Marshal(gotNorm) + expNormBytes, _ := json.Marshal(expNorm) + if string(gotNormBytes) != string(expNormBytes) { + t.Errorf("mismatch:\n got: %s\n expected: %s", gotNormBytes, expNormBytes) + } +}