diff --git a/client/client.go b/client/client.go index 92971f4..ea7f747 100644 --- a/client/client.go +++ b/client/client.go @@ -4,9 +4,11 @@ package client import ( + "crypto/tls" "errors" "fmt" "log/slog" + "net/http" "sync" "sync/atomic" "time" @@ -113,6 +115,8 @@ type Client struct { state int32 + tlsConfig *tls.Config + mut sync.RWMutex } @@ -125,6 +129,21 @@ func WithLogger(log *slog.Logger) Option { } } +// WithTLSConfig lets the caller set an optional TLS configuration for connections +// to Mattermost. This is needed when the server uses a self-signed or private CA +// certificate. The config is applied to both the HTTP API client and the WebSocket +// connection. +func WithTLSConfig(tlsConfig *tls.Config) Option { + return func(c *Client) error { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = tlsConfig + transport.ForceAttemptHTTP2 = true + c.apiClient.HTTPClient = &http.Client{Transport: transport} + c.tlsConfig = tlsConfig + return nil + } +} + // New initializes and returns a new Calls client. func New(cfg Config, opts ...Option) (*Client, error) { if err := cfg.Parse(); err != nil { diff --git a/client/websocket.go b/client/websocket.go index 301dcbb..3ab00e9 100644 --- a/client/websocket.go +++ b/client/websocket.go @@ -330,11 +330,15 @@ func (c *Client) handleWSMsg(msg ws.Message) error { } func (c *Client) wsOpen() error { + wsOpts := []ws.ClientOption{ws.WithLogger(c.log)} + if c.tlsConfig != nil { + wsOpts = append(wsOpts, ws.WithTLSConfig(c.tlsConfig)) + } ws, err := ws.NewClient(ws.ClientConfig{ URL: c.cfg.wsURL, AuthToken: c.cfg.AuthToken, AuthType: ws.BearerClientAuthType, - }, ws.WithLogger(c.log)) + }, wsOpts...) if err != nil { return fmt.Errorf("failed to create websocket client: %w", err) } diff --git a/client/websocket_test.go b/client/websocket_test.go index 77cc5c2..c99e65d 100644 --- a/client/websocket_test.go +++ b/client/websocket_test.go @@ -4,6 +4,8 @@ package client import ( + "fmt" + "net" "testing" "time" @@ -115,12 +117,21 @@ func TestClientWSReconnectTimeout(t *testing.T) { require.Fail(t, "timed out waiting for connect event") } - th.userClient.cfg.wsURL = "ws://localhost:8080" + // Bind a listener to get an unused port, then close it so the port gives + // immediate ECONNREFUSED on reconnect (avoids slow TCP timeouts from + // non-routable IPs, and avoids accidentally hitting a real server). + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + unusedAddr := ln.Addr().String() + ln.Close() + th.userClient.cfg.wsURL = fmt.Sprintf("ws://%s", unusedAddr) - errorCh := make(chan struct{}) + errorCh := make(chan error, 1) err = th.userClient.On(ErrorEvent, func(ctx any) error { - close(errorCh) - require.EqualError(t, ctx.(error), "ws reconnection timeout reached") + select { + case errorCh <- ctx.(error): + default: + } return nil }) require.NoError(t, err) @@ -136,7 +147,8 @@ func TestClientWSReconnectTimeout(t *testing.T) { require.NoError(t, err) select { - case <-errorCh: + case err := <-errorCh: + require.EqualError(t, err, "ws reconnection timeout reached") case <-time.After(wsReconnectionTimeout * 2): require.Fail(t, "timed out waiting for error event") } diff --git a/service/ws/client.go b/service/ws/client.go index 3dc8dd3..ebfd96c 100644 --- a/service/ws/client.go +++ b/service/ws/client.go @@ -4,6 +4,7 @@ package ws import ( + "crypto/tls" "fmt" "log/slog" "net/http" @@ -33,6 +34,7 @@ type Client struct { wg sync.WaitGroup connState int32 dialFn DialContextFn + tlsConfig *tls.Config pingHandlerFn func(msg string) error log *slog.Logger } @@ -73,6 +75,9 @@ func NewClient(cfg ClientConfig, opts ...ClientOption) (*Client, error) { if c.dialFn != nil { dialer.NetDialContext = c.dialFn } + if c.tlsConfig != nil { + dialer.TLSClientConfig = c.tlsConfig + } ws, _, err := dialer.Dial(cfg.URL, header) if err != nil { return nil, fmt.Errorf("failed to dial: %w", err) diff --git a/service/ws/option.go b/service/ws/option.go index 91c90a8..132e2d1 100644 --- a/service/ws/option.go +++ b/service/ws/option.go @@ -5,6 +5,7 @@ package ws import ( "context" + "crypto/tls" "log/slog" "net" ) @@ -38,3 +39,13 @@ func WithLogger(log *slog.Logger) ClientOption { return nil } } + +// WithTLSConfig lets the caller set an optional TLS configuration for the +// WebSocket connection. This is needed when connecting to a server using a +// self-signed or private CA certificate. +func WithTLSConfig(tlsConfig *tls.Config) ClientOption { + return func(c *Client) error { + c.tlsConfig = tlsConfig + return nil + } +}