diff --git a/.changeset/harden-rfq-invalid-frames.md b/.changeset/harden-rfq-invalid-frames.md new file mode 100644 index 0000000..59170e9 --- /dev/null +++ b/.changeset/harden-rfq-invalid-frames.md @@ -0,0 +1,5 @@ +--- +"@polymarket/client": patch +--- + +Harden RFQ quoter WebSocket handling for unknown and malformed inbound frames. diff --git a/packages/bindings/src/rfq.ts b/packages/bindings/src/rfq.ts index 1713bc6..51abf60 100644 --- a/packages/bindings/src/rfq.ts +++ b/packages/bindings/src/rfq.ts @@ -165,45 +165,60 @@ export type RfqAuthMessage = { }; }; -export const RfqAuthResponseMessageSchema = z.object({ - type: z.literal('auth'), - success: z.boolean(), - address: EvmAddressSchema.optional(), - role: z.string().optional(), - error: z.string().optional(), +export enum RfqKnownInboundType { + Auth = 'auth', + QuoteRequest = 'RFQ_REQUEST', + QuoteAck = 'ACK_RFQ_QUOTE', + QuoteCancelAck = 'ACK_RFQ_QUOTE_CANCEL', + ConfirmationRequest = 'RFQ_CONFIRMATION_REQUEST', + ConfirmationAck = 'ACK_RFQ_CONFIRMATION_RESPONSE', + ExecutionUpdate = 'RFQ_EXECUTION_UPDATE', + Error = 'RFQ_ERROR', +} + +export const RfqKnownInboundMessageSchema = z.object({ + type: z.enum(RfqKnownInboundType), }); +export const RfqAuthResponseMessageSchema = RfqKnownInboundMessageSchema.extend( + { + type: z.literal(RfqKnownInboundType.Auth), + success: z.boolean(), + address: EvmAddressSchema.optional(), + role: z.string().optional(), + error: z.string().optional(), + }, +); + export type RfqAuthResponseMessage = z.infer< typeof RfqAuthResponseMessageSchema >; -export const RfqQuoteRequestSchema = z - .object({ - type: z.literal('RFQ_REQUEST'), - rfq_id: RfqIdSchema, - requestor_public_id: RfqRequestorPublicIdSchema, - leg_position_ids: z.array(PositionIdSchema), - condition_id: ConditionIdSchema, - yes_position_id: PositionIdSchema, - no_position_id: PositionIdSchema, - direction: RfqDirectionSchema, - side: RfqSideSchema, - requested_size: RfqRequestedSizeSchema, - submission_deadline: EpochMillisecondsSchema, - }) - .transform((message) => ({ - conditionId: message.condition_id, - direction: message.direction, - legPositionIds: message.leg_position_ids, - noPositionId: message.no_position_id, - requestorPublicId: message.requestor_public_id, - requestedSize: message.requested_size, - rfqId: message.rfq_id, - side: message.side, - submissionDeadline: message.submission_deadline, - type: 'quote_request' as const, - yesPositionId: message.yes_position_id, - })); +export const RfqQuoteRequestSchema = RfqKnownInboundMessageSchema.extend({ + type: z.literal(RfqKnownInboundType.QuoteRequest), + rfq_id: RfqIdSchema, + requestor_public_id: RfqRequestorPublicIdSchema, + leg_position_ids: z.array(PositionIdSchema), + condition_id: ConditionIdSchema, + yes_position_id: PositionIdSchema, + no_position_id: PositionIdSchema, + direction: RfqDirectionSchema, + side: RfqSideSchema, + requested_size: RfqRequestedSizeSchema, + submission_deadline: EpochMillisecondsSchema, +}).transform((message) => ({ + conditionId: message.condition_id, + direction: message.direction, + legPositionIds: message.leg_position_ids, + noPositionId: message.no_position_id, + requestorPublicId: message.requestor_public_id, + requestedSize: message.requested_size, + rfqId: message.rfq_id, + side: message.side, + submissionDeadline: message.submission_deadline, + type: 'quote_request' as const, + yesPositionId: message.yes_position_id, +})); export type RfqQuoteRequest = z.infer; @@ -223,37 +238,33 @@ export type RfqQuoteCancelMessage = { maker_address: EvmAddress; }; -export const RfqQuoteAckSchema = z - .object({ - type: z.literal('ACK_RFQ_QUOTE'), - rfq_id: RfqIdSchema, - quote_id: RfqQuoteIdSchema, - }) - .transform((message) => ({ - quoteId: message.quote_id, - rfqId: message.rfq_id, - type: 'quote_ack' as const, - })); +export const RfqQuoteAckSchema = RfqKnownInboundMessageSchema.extend({ + type: z.literal(RfqKnownInboundType.QuoteAck), + rfq_id: RfqIdSchema, + quote_id: RfqQuoteIdSchema, +}).transform((message) => ({ + quoteId: message.quote_id, + rfqId: message.rfq_id, + type: 'quote_ack' as const, +})); export type RfqQuoteAck = z.infer; -export const RfqQuoteCancelAckSchema = z - .object({ - type: z.literal('ACK_RFQ_QUOTE_CANCEL'), - rfq_id: RfqIdSchema, - quote_id: RfqQuoteIdSchema, - }) - .transform((message) => ({ - quoteId: message.quote_id, - rfqId: message.rfq_id, - type: 'quote_cancel_ack' as const, - })); +export const RfqQuoteCancelAckSchema = RfqKnownInboundMessageSchema.extend({ + type: z.literal(RfqKnownInboundType.QuoteCancelAck), + rfq_id: RfqIdSchema, + quote_id: RfqQuoteIdSchema, +}).transform((message) => ({ + quoteId: message.quote_id, + rfqId: message.rfq_id, + type: 'quote_cancel_ack' as const, +})); export type RfqQuoteCancelAck = z.infer; -export const RfqConfirmationRequestSchema = z - .object({ - type: z.literal('RFQ_CONFIRMATION_REQUEST'), +export const RfqConfirmationRequestSchema = RfqKnownInboundMessageSchema.extend( + { + type: z.literal(RfqKnownInboundType.ConfirmationRequest), rfq_id: RfqIdSchema, quote_id: RfqQuoteIdSchema, signer_address: EvmAddressSchema, @@ -268,24 +279,24 @@ export const RfqConfirmationRequestSchema = z fill_size_e6: BigIntStringToDecimalStringSchema, price_e6: BigIntStringToDecimalStringSchema, confirm_by: EpochMillisecondsSchema, - }) - .transform((message) => ({ - conditionId: message.condition_id, - confirmBy: message.confirm_by, - direction: message.direction, - fillSize: message.fill_size_e6, - legPositionIds: message.leg_position_ids, - makerAddress: message.maker_address, - noPositionId: message.no_position_id, - price: message.price_e6, - quoteId: message.quote_id, - rfqId: message.rfq_id, - side: message.side, - signatureType: message.signature_type, - signerAddress: message.signer_address, - type: 'confirmation_request' as const, - yesPositionId: message.yes_position_id, - })); + }, +).transform((message) => ({ + conditionId: message.condition_id, + confirmBy: message.confirm_by, + direction: message.direction, + fillSize: message.fill_size_e6, + legPositionIds: message.leg_position_ids, + makerAddress: message.maker_address, + noPositionId: message.no_position_id, + price: message.price_e6, + quoteId: message.quote_id, + rfqId: message.rfq_id, + side: message.side, + signatureType: message.signature_type, + signerAddress: message.signer_address, + type: 'confirmation_request' as const, + yesPositionId: message.yes_position_id, +})); export type RfqConfirmationRequest = z.infer< typeof RfqConfirmationRequestSchema @@ -298,56 +309,50 @@ export type RfqConfirmationResponseMessage = { decision: RfqConfirmationDecision; }; -export const RfqConfirmationAckSchema = z - .object({ - type: z.literal('ACK_RFQ_CONFIRMATION_RESPONSE'), - rfq_id: RfqIdSchema, - quote_id: RfqQuoteIdSchema, - decision: RfqConfirmationDecisionSchema, - }) - .transform((message) => ({ - decision: message.decision, - quoteId: message.quote_id, - rfqId: message.rfq_id, - type: 'confirmation_ack' as const, - })); +export const RfqConfirmationAckSchema = RfqKnownInboundMessageSchema.extend({ + type: z.literal(RfqKnownInboundType.ConfirmationAck), + rfq_id: RfqIdSchema, + quote_id: RfqQuoteIdSchema, + decision: RfqConfirmationDecisionSchema, +}).transform((message) => ({ + decision: message.decision, + quoteId: message.quote_id, + rfqId: message.rfq_id, + type: 'confirmation_ack' as const, +})); export type RfqConfirmationAck = z.infer; -export const RfqExecutionUpdateSchema = z - .object({ - type: z.literal('RFQ_EXECUTION_UPDATE'), - rfq_id: RfqIdSchema, - status: RfqExecutionStatusSchema, - tx_hash: TxHashSchema.optional(), - }) - .transform((message) => ({ - rfqId: message.rfq_id, - status: message.status, - ...(message.tx_hash === undefined ? {} : { txHash: message.tx_hash }), - type: 'execution_update' as const, - })); +export const RfqExecutionUpdateSchema = RfqKnownInboundMessageSchema.extend({ + type: z.literal(RfqKnownInboundType.ExecutionUpdate), + rfq_id: RfqIdSchema, + status: RfqExecutionStatusSchema, + tx_hash: TxHashSchema.optional(), +}).transform((message) => ({ + rfqId: message.rfq_id, + status: message.status, + ...(message.tx_hash === undefined ? {} : { txHash: message.tx_hash }), + type: 'execution_update' as const, +})); export type RfqExecutionUpdate = z.infer; -export const RfqErrorMessageSchema = z - .object({ - type: z.literal('RFQ_ERROR'), - request_type: z.string().optional(), - rfq_id: RfqIdSchema.optional(), - quote_id: RfqQuoteIdSchema.optional(), - code: RfqErrorCodeSchema, - error: z.string(), - request: z.unknown().optional(), - }) - .transform((message) => ({ - code: message.code, - message: message.error, - quoteId: message.quote_id, - requestType: message.request_type, - rfqId: message.rfq_id, - type: 'rfq_error' as const, - })); +export const RfqErrorMessageSchema = RfqKnownInboundMessageSchema.extend({ + type: z.literal(RfqKnownInboundType.Error), + request_type: z.string().optional(), + rfq_id: RfqIdSchema.optional(), + quote_id: RfqQuoteIdSchema.optional(), + code: RfqErrorCodeSchema, + error: z.string(), + request: z.unknown().optional(), +}).transform((message) => ({ + code: message.code, + message: message.error, + quoteId: message.quote_id, + requestType: message.request_type, + rfqId: message.rfq_id, + type: 'rfq_error' as const, +})); export type RfqErrorMessage = z.infer; diff --git a/packages/client/src/websockets/rfq/quoter.ts b/packages/client/src/websockets/rfq/quoter.ts index 20b9fad..868b6f3 100644 --- a/packages/client/src/websockets/rfq/quoter.ts +++ b/packages/client/src/websockets/rfq/quoter.ts @@ -3,6 +3,7 @@ import { type RfqConfirmationDecision, type RfqErrorMessage, type RfqId, + RfqKnownInboundMessageSchema, type RfqQuoteId, type RfqQuoteRequest, RfqQuoterInboundMessageSchema, @@ -47,6 +48,7 @@ import { createRfqQuote, parseRfqQuoteResponse } from './quote'; const AUTH_TIMEOUT_MS = 30_000; const ACK_TIMEOUT_MS = 30_000; +const RFQ_WEBSOCKET_CLOSED_ERROR = 'RFQ quoter websocket closed.'; export type RfqQuoterWebSocketManagerOptions = { account: AccountIdentity; @@ -223,23 +225,42 @@ class RfqWebSocketSession implements RfqSession, RfqEventController { } async #shutdown(): Promise { - this.#pending.rejectAll(new TransportError('RFQ quoter websocket closed.')); + const error = new TransportError(RFQ_WEBSOCKET_CLOSED_ERROR); + await this.#shutdownWithError(error); + } + + async #fail(error: Error): Promise { + if (this.#closing === undefined) { + this.#closing = this.#shutdownWithError(error); + } + await this.#closing; + } + + async #shutdownWithError(error: Error): Promise { + this.#failPending(error); this.#queue.end(); await this.#connection.close(); this.#onClose(); } + #failPending(error: Error): void { + this.#auth?.reject(error); + this.#pending.rejectAll(error); + } + #sendAuthMessage(): void { this.#connection.send(createAuthMessage(this.#account, this.#credentials)); } #handleMessage(rawMessage: unknown): void { + if (!RfqKnownInboundMessageSchema.safeParse(rawMessage).success) return; + const parsed = RfqQuoterInboundMessageSchema.safeParse(rawMessage); if (!parsed.success) { const error = new TransportError('Invalid RFQ quoter message.', { cause: parsed.error, }); - this.#pending.rejectAll(error); + void this.#fail(error); return; } diff --git a/packages/client/tests/integration/rfq-frames.ts b/packages/client/tests/integration/rfq-frames.ts index f0e5ec6..ccf1bc0 100644 --- a/packages/client/tests/integration/rfq-frames.ts +++ b/packages/client/tests/integration/rfq-frames.ts @@ -156,6 +156,20 @@ export function confirmationAckMessage(decision: string) { return JSON.stringify(confirmationAckFrame(decision)); } +export function malformedQuoteAckMessage() { + return JSON.stringify({ + rfq_id: RFQ_ID, + type: 'ACK_RFQ_QUOTE', + }); +} + +export function unknownRfqMessage() { + return JSON.stringify({ + payload: 'ignored', + type: 'RFQ_FUTURE_MESSAGE', + }); +} + function executionUpdateFrame() { return { rfq_id: RFQ_ID, diff --git a/packages/client/tests/integration/rfq.test.ts b/packages/client/tests/integration/rfq.test.ts index 4020467..5a48456 100644 --- a/packages/client/tests/integration/rfq.test.ts +++ b/packages/client/tests/integration/rfq.test.ts @@ -16,6 +16,7 @@ import { confirmationDecision, confirmationRequestMessage, executionUpdateMessage, + malformedQuoteAckMessage, type OutboundFrame, QUOTE_ID, QUOTE_SIZE_E6, @@ -27,6 +28,7 @@ import { recordOutboundFrame, rfqErrorMessage, TX_HASH, + unknownRfqMessage, } from './rfq-frames'; const rfq = ws.link(production.rfqQuoterWs); @@ -805,6 +807,130 @@ describe('RFQ sessions', () => { }); }); + describe('when the server sends unknown RFQ frames', () => { + beforeEach(() => { + server.resetHandlers(); + server.use( + rfq.addEventListener('connection', ({ client: socket }) => { + socket.addEventListener('message', (event) => { + const frame = recordOutboundFrame(event.data, outboundFrames); + + if (frame.type === 'auth') { + socket.send(authAckMessage()); + socket.send(quoteRequestMessage()); + return; + } + + if (frame.type === 'RFQ_QUOTE') { + const quote = quoteAmounts(frame); + socket.send(unknownRfqMessage()); + socket.send(quoteAckMessage()); + socket.send( + confirmationRequestMessage(quote.priceE6, quote.sizeE6), + ); + return; + } + + if (frame.type === 'RFQ_QUOTE_CANCEL') { + socket.send(unknownRfqMessage()); + socket.send(quoteCancelAckMessage()); + return; + } + + if (frame.type === 'RFQ_CONFIRMATION_RESPONSE') { + const decision = confirmationDecision(frame); + socket.send(unknownRfqMessage()); + socket.send(confirmationAckMessage(decision)); + } + }); + }), + ); + }); + + it('keeps quote, cancellation, and confirmation waits pending until their acknowledgements arrive', async ({ + secureClientWithDepositWallet, + }) => { + const session = await secureClientWithDepositWallet.openRfqSession(); + + try { + for await (const event of session) { + if (event.type === 'quote_request') { + const quote = await event.quote({ price: 0.45 }); + const cancellation = await session.cancelQuote(quote); + + expect(quote).toEqual({ + quoteId: QUOTE_ID, + rfqId: event.rfqId, + }); + expect(cancellation).toEqual(quote); + continue; + } + + if (event.type === 'confirmation_request') { + const confirmation = await event.confirm(); + + expect(confirmation).toEqual({ + quoteId: event.quoteId, + rfqId: event.rfqId, + }); + + await session.close(); + break; + } + } + } finally { + await secureClientWithDepositWallet.closeSubscriptions(); + } + }); + }); + + describe('when the server sends a malformed known RFQ frame', () => { + beforeEach(() => { + server.resetHandlers(); + server.use( + rfq.addEventListener('connection', ({ client: socket }) => { + socket.addEventListener('message', (event) => { + const frame = recordOutboundFrame(event.data, outboundFrames); + + if (frame.type === 'auth') { + socket.send(authAckMessage()); + socket.send(quoteRequestMessage()); + return; + } + + if (frame.type === 'RFQ_QUOTE') { + quoteAmounts(frame); + socket.send(malformedQuoteAckMessage()); + } + }); + }), + ); + }); + + it('fails the session and ends the event stream', async ({ + secureClientWithDepositWallet, + }) => { + const session = await secureClientWithDepositWallet.openRfqSession(); + + try { + const iterator = session[Symbol.asyncIterator](); + const next = await iterator.next(); + + if (next.done === true || next.value.type !== 'quote_request') { + throw new Error('Expected RFQ quote request.'); + } + + await expect(next.value.quote({ price: 0.45 })).rejects.toMatchObject({ + message: 'Invalid RFQ quoter message.', + name: 'TransportError', + }); + await expect(iterator.next()).resolves.toMatchObject({ done: true }); + } finally { + await secureClientWithDepositWallet.closeSubscriptions(); + } + }); + }); + describe('when the connection closes before quote acknowledgement', () => { beforeEach(() => { server.resetHandlers();