From 674d67a023b53f483aa3524b882ef86bc543322a Mon Sep 17 00:00:00 2001 From: Haakam Aujla Date: Sun, 1 Mar 2026 21:26:10 -0800 Subject: [PATCH] WS wait for open on connect --- src/wrapper/WebsocketsClient.ts | 16 +++++++++------- tests/unit/wrapper/WebsocketsClient.test.ts | 4 +++- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/wrapper/WebsocketsClient.ts b/src/wrapper/WebsocketsClient.ts index 15ba15b..5e130b4 100644 --- a/src/wrapper/WebsocketsClient.ts +++ b/src/wrapper/WebsocketsClient.ts @@ -14,6 +14,8 @@ export class WebsocketsClient extends FernWebsocketsClient { } public override async connect(args: FernWebsocketsClient.ConnectArgs = {}): Promise { + let connectArgs = args; + if (this._getPaymentCredentials) { const wsUrl = core.url.join( (await core.Supplier.get(this._options.baseUrl)) ?? @@ -22,17 +24,17 @@ export class WebsocketsClient extends FernWebsocketsClient { "/v0", ); const credentials = await this._getPaymentCredentials(wsUrl); - return super.connect({ + connectArgs = { ...args, queryParams: { ...credentials, ...args.queryParams }, - }); - } - - if (!args.apiKey) { + }; + } else if (!args.apiKey) { const apiKey = (await core.Supplier.get(this._options.apiKey)) ?? process.env.AGENTMAIL_API_KEY; - return super.connect({ ...args, apiKey }); + connectArgs = { ...args, apiKey }; } - return super.connect(args); + const socket = await super.connect(connectArgs); + await socket.waitForOpen(); + return socket; } } diff --git a/tests/unit/wrapper/WebsocketsClient.test.ts b/tests/unit/wrapper/WebsocketsClient.test.ts index b78d204..3ef0879 100644 --- a/tests/unit/wrapper/WebsocketsClient.test.ts +++ b/tests/unit/wrapper/WebsocketsClient.test.ts @@ -5,7 +5,9 @@ import * as x402Helpers from "../../../src/wrapper/x402"; import * as mppHelpers from "../../../src/wrapper/mppx"; function mockConnect() { - return vi.spyOn(FernWebsocketsClient.prototype, "connect").mockResolvedValue({} as WebsocketsSocket); + return vi + .spyOn(FernWebsocketsClient.prototype, "connect") + .mockResolvedValue({ waitForOpen: vi.fn().mockResolvedValue(undefined) } as unknown as WebsocketsSocket); } describe("WebsocketsClient wrapper", () => {