diff --git a/build/native-image/build.sh b/build/native-image/build.sh new file mode 100644 index 0000000000..f7d7f30c9b --- /dev/null +++ b/build/native-image/build.sh @@ -0,0 +1,11 @@ +mvn clean package -P assembly -DskipTests +cd target/pgadapter +native-image \ + --initialize-at-build-time=com.google.protobuf,com.google.gson,com.google.cloud.spanner.pgadapter.Server \ + -J-Xmx14g \ + -H:IncludeResources=".*" \ + -H:ReflectionConfigurationFiles=../../build/native-image/reflectconfig.json \ + -jar pgadapter.jar \ + --no-fallback + +./pgadapter -p appdev-soda-spanner-staging -i knut-test-ycsb -s 5433 diff --git a/build/native-image/reflectconfig.json b/build/native-image/reflectconfig.json new file mode 100644 index 0000000000..8d116ca9f1 --- /dev/null +++ b/build/native-image/reflectconfig.json @@ -0,0 +1,79 @@ +[ + { + "name" : "java.lang.Class", + "queryAllDeclaredConstructors" : true, + "queryAllPublicConstructors" : true, + "queryAllDeclaredMethods" : true, + "queryAllPublicMethods" : true, + "allDeclaredClasses" : true, + "allPublicClasses" : true + }, + { + "name" : "sun.misc.Signal", + "queryAllDeclaredConstructors" : true, + "queryAllPublicConstructors" : true, + "queryAllDeclaredMethods" : true, + "queryAllPublicMethods" : true, + "allDeclaredClasses" : true, + "allPublicClasses" : true, + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true + }, + { + "name" : "sun.misc.SignalHandler", + "queryAllDeclaredConstructors" : true, + "queryAllPublicConstructors" : true, + "queryAllDeclaredMethods" : true, + "queryAllPublicMethods" : true, + "allDeclaredClasses" : true, + "allPublicClasses" : true, + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true + }, + { + "name" : "com.google.cloud.spanner.pgadapter.Server", + "methods" : [ + { "name" : "handleTerm" }, + { "name" : "handleInt" }, + { "name" : "handleQuit" } + ] + }, + { + "name" : "com.google.cloud.spanner.pgadapter.logging.StdoutHandler", + "queryAllDeclaredConstructors" : true, + "queryAllPublicConstructors" : true, + "queryAllDeclaredMethods" : true, + "queryAllPublicMethods" : true, + "allDeclaredClasses" : true, + "allPublicClasses" : true, + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true + }, + { + "name" : "com.google.cloud.spanner.pgadapter.logging.StderrHandler", + "queryAllDeclaredConstructors" : true, + "queryAllPublicConstructors" : true, + "queryAllDeclaredMethods" : true, + "queryAllPublicMethods" : true, + "allDeclaredClasses" : true, + "allPublicClasses" : true, + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true + }, + { + "name" : "java.time.Instant", + "methods" : [ + { "name" : "now" }, + { "name" : "getNano" }, + { "name" : "getEpochSecond" } + ] + } +] diff --git a/pom.xml b/pom.xml index acaa283407..882faa3abc 100644 --- a/pom.xml +++ b/pom.xml @@ -100,6 +100,10 @@ + + org.graalvm.sdk + graal-sdk + com.google.cloud google-cloud-spanner diff --git a/samples/nodejs/pgadapter-childprocess/package.json b/samples/nodejs/pgadapter-childprocess/package.json new file mode 100644 index 0000000000..a47f960b38 --- /dev/null +++ b/samples/nodejs/pgadapter-childprocess/package.json @@ -0,0 +1,21 @@ +{ + "name": "pgadapter-childprocess-sample", + "version": "0.0.1", + "description": "PGAdapter Child Process Sample", + "type": "commonjs", + "devDependencies": { + "@types/node": "^22.1.0", + "@types/pg": "^8.11.4", + "ts-node": "10.9.2", + "typescript": "5.8.2" + }, + "dependencies": { + "pg": "^8.11.3", + "test-wrapped-binary": "/Users/loite/IdeaProjects/pgadapter/wrappers/nodejs", + "umzug": "^3.6.1", + "yargs": "^17.5.1" + }, + "scripts": { + "start": "ts-node src/index.ts" + } +} diff --git a/samples/nodejs/pgadapter-childprocess/src/index.ts b/samples/nodejs/pgadapter-childprocess/src/index.ts new file mode 100644 index 0000000000..a323fbb904 --- /dev/null +++ b/samples/nodejs/pgadapter-childprocess/src/index.ts @@ -0,0 +1,62 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import {startPGAdapter} from 'test-wrapped-binary' +import { Client } from 'pg'; + +function sleep(ms) { + return new Promise((resolve) => { + setTimeout(resolve, ms); + }); +} + +async function main() { + const pgAdapter = await startPGAdapter({ + project: "appdev-soda-spanner-staging", + instance: "knut-test-ycsb", + port: 5433, + }); + try { + //await sleep(500); + + console.log('Started PGAdapter'); + + // Execute a simple query. + const connection = new Client({ + host: "localhost", + port: 5433, + database: "knut-test-db", + }); + await connection.connect(); + + const result = await connection.query("SELECT * " + + "FROM all_types " + + "LIMIT 10"); + for (const row of result.rows) { + console.log(JSON.stringify(row)); + } + + // Close the connection. + await connection.end(); + } finally { + pgAdapter.kill(); + } +} + +(async () => { + await main(); +})().catch(e => { + console.error(e); + process.exit(1); +}); diff --git a/samples/nodejs/pgadapter-childprocess/src/init.ts b/samples/nodejs/pgadapter-childprocess/src/init.ts new file mode 100644 index 0000000000..c8dab81912 --- /dev/null +++ b/samples/nodejs/pgadapter-childprocess/src/init.ts @@ -0,0 +1,117 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import {QueryTypes, Sequelize} from 'sequelize'; +import {GenericContainer, PullPolicy, StartedTestContainer, TestContainer} from "testcontainers"; + +/** + * Creates the data model that is needed for this sample application. + * + * The Cloud Spanner PostgreSQL dialect does not support all system tables (pg_catalog tables) that + * are present in open-source PostgreSQL databases. Those tables are used by Sequelize migrations. + * Migrations are therefore not supported. + */ +export async function createDataModel(sequelize: Sequelize) { + console.log("Checking whether tables already exists"); + const result: any = await sequelize.query( + `SELECT COUNT(1) AS c + FROM information_schema.tables + WHERE table_schema='public' + AND table_name IN ('Singers', 'Albums', 'Tracks', 'Venues', 'Concerts', 'TicketSales')`, + { type: QueryTypes.SELECT, raw: true, plain: true }); + if (result.c == '6') { + return; + } + console.log("Creating tables"); + // Create the data model. + await sequelize.query( + ` + alter database db set spanner.default_sequence_kind='bit_reversed_positive'; + create table "Singers" ( + id serial primary key, + "firstName" varchar, + "lastName" varchar, + "fullName" varchar generated always as ( + CASE WHEN "firstName" IS NULL THEN "lastName" + WHEN "lastName" IS NULL THEN "firstName" + ELSE "firstName" || ' ' || "lastName" + END) stored, + "active" boolean, + "createdAt" timestamptz, + "updatedAt" timestamptz + ); + + create table "Albums" ( + id serial primary key, + title varchar, + "marketingBudget" numeric, + "SingerId" bigint, + "createdAt" timestamptz default current_timestamp, + "updatedAt" timestamptz, + constraint fk_albums_singers foreign key ("SingerId") references "Singers" (id) + ); + + create table if not exists "Tracks" ( + id bigint not null, + "trackNumber" bigint not null, + title varchar not null, + "sampleRate" float8 not null, + "createdAt" timestamptz default current_timestamp, + "updatedAt" timestamptz, + primary key (id, "trackNumber") + ) interleave in parent "Albums" on delete cascade; + + create table if not exists "Venues" ( + id serial primary key, + name varchar not null, + description varchar not null, + "createdAt" timestamptz default current_timestamp, + "updatedAt" timestamptz + ); + + create table if not exists "Concerts" ( + id serial primary key, + "VenueId" bigint not null, + "SingerId" bigint not null, + name varchar not null, + "startTime" timestamptz not null, + "endTime" timestamptz not null, + "createdAt" timestamptz default current_timestamp, + "updatedAt" timestamptz, + constraint fk_concerts_venues foreign key ("VenueId") references "Venues" (id), + constraint fk_concerts_singers foreign key ("SingerId") references "Singers" (id), + constraint chk_end_time_after_start_time check ("endTime" > "startTime") + ); + + create table if not exists "TicketSales" ( + id serial primary key, + "ConcertId" bigint not null, + "customerName" varchar not null, + price decimal not null, + seats text[], + "createdAt" timestamptz default current_timestamp, + "updatedAt" timestamptz, + constraint fk_ticket_sales_concerts foreign key ("ConcertId") references "Concerts" (id) + );`, + {type: QueryTypes.RAW}) +} + +export async function startPGAdapter(): Promise { + console.log("Pulling PGAdapter and Spanner emulator"); + const container: TestContainer = new GenericContainer("gcr.io/cloud-spanner-pg-adapter/pgadapter-emulator") + .withPullPolicy(PullPolicy.alwaysPull()) + .withExposedPorts(5432); + console.log("Starting PGAdapter and Spanner emulator"); + return await container.start(); +} diff --git a/samples/nodejs/pgadapter-childprocess/src/random.ts b/samples/nodejs/pgadapter-childprocess/src/random.ts new file mode 100644 index 0000000000..c806053c95 --- /dev/null +++ b/samples/nodejs/pgadapter-childprocess/src/random.ts @@ -0,0 +1,134 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +export function randomInt(min: number, max: number): number { + return Math.floor(Math.random() * (max - min + 1) + min); +} + +export function randomFirstName(): string { + return randomArrayElement(first_names); +} + +export function randomLastName(): string { + return randomArrayElement(last_names); +} + +export function randomAlbumTitle(): string { + return `${randomArrayElement(adjectives)} ${randomArrayElement(nouns)}`; +} + +export function randomTrackTitle(): string { + return `${randomArrayElement(adverbs)} ${randomArrayElement(verbs)}`; +} + +function randomArrayElement(array: Array): string { + return array[Math.floor(Math.random() * array.length)]; +} + +const first_names: string[] = [ + "Saffron", "Eleanor", "Ann", "Salma", "Kiera", "Mariam", "Georgie", "Eden", "Carmen", "Darcie", + "Antony", "Benjamin", "Donald", "Keaton", "Jared", "Simon", "Tanya", "Julian", "Eugene", "Laurence"]; + +const last_names: string[] = [ + "Terry", "Ford", "Mills", "Connolly", "Newton", "Rodgers", "Austin", "Floyd", "Doherty", "Nguyen", + "Chavez", "Crossley", "Silva", "George", "Baldwin", "Burns", "Russell", "Ramirez", "Hunter", "Fuller"]; + +export const adjectives: string[] = [ + "ultra", + "happy", + "emotional", + "lame", + "charming", + "alleged", + "talented", + "exotic", + "lamentable", + "splendid", + "old-fashioned", + "savory", + "delicate", + "willing", + "habitual", + "upset", + "gainful", + "nonchalant", + "kind", + "unruly"]; + +export const nouns: string[] = [ + "improvement", + "control", + "tennis", + "gene", + "department", + "person", + "awareness", + "health", + "development", + "platform", + "garbage", + "suggestion", + "agreement", + "knowledge", + "introduction", + "recommendation", + "driver", + "elevator", + "industry", + "extent"]; + +export const verbs: string[] = [ + "instruct", + "rescue", + "disappear", + "import", + "inhibit", + "accommodate", + "dress", + "describe", + "mind", + "strip", + "crawl", + "lower", + "influence", + "alter", + "prove", + "race", + "label", + "exhaust", + "reach", + "remove"]; + +export const adverbs: string[] = [ + "cautiously", + "offensively", + "immediately", + "soon", + "judgementally", + "actually", + "honestly", + "slightly", + "limply", + "rigidly", + "fast", + "normally", + "unnecessarily", + "wildly", + "unimpressively", + "helplessly", + "rightfully", + "kiddingly", + "early", + "queasily"]; diff --git a/samples/nodejs/pgadapter-childprocess/src/tsconfig.json b/samples/nodejs/pgadapter-childprocess/src/tsconfig.json new file mode 100644 index 0000000000..d73f85e265 --- /dev/null +++ b/samples/nodejs/pgadapter-childprocess/src/tsconfig.json @@ -0,0 +1,7 @@ +{ + "noImplicitAny": false, + "compilerOptions": { + "target": "es6", + "module": "commonjs" + }, +} diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java b/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java index 5fd96fd7e4..e895711ae7 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java @@ -26,6 +26,7 @@ import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata.TextFormat; import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement; import com.google.cloud.spanner.pgadapter.utils.Metrics; +import com.google.cloud.spanner.pgadapter.wireprotocol.MessageReader; import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage; import com.google.common.collect.ImmutableList; import io.opentelemetry.api.OpenTelemetry; @@ -69,6 +70,7 @@ public class ProxyServer extends AbstractApiService { private final Metrics metrics; private final Properties properties; private final List handlers = new LinkedList<>(); + private final MessageReader messageReader; /** * Latch that is closed when the TCP server has started. We need this to know the exact port that @@ -162,6 +164,7 @@ public ProxyServer(OptionsMetadata optionsMetadata, OpenTelemetry openTelemetry) public ProxyServer( OptionsMetadata optionsMetadata, OpenTelemetry openTelemetry, Properties properties) { this.options = optionsMetadata; + this.messageReader = new MessageReader(optionsMetadata); this.openTelemetry = openTelemetry; this.metrics = optionsMetadata.isEnableOpenTelemetryMetrics() @@ -176,6 +179,10 @@ public ProxyServer( addConnectionProperties(); } + public MessageReader getMessageReader() { + return this.messageReader; + } + private void addConnectionProperties() { for (Map.Entry entry : options.getPropertyMap().entrySet()) { properties.setProperty(entry.getKey(), entry.getValue()); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/Server.java b/src/main/java/com/google/cloud/spanner/pgadapter/Server.java index b0d4e6925a..066cf68172 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/Server.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/Server.java @@ -63,6 +63,10 @@ public class Server { private static volatile ShutdownHandler shutdownHandler; + static { + registerSignalHandlers(); + } + /** * Main method for running a Spanner PostgreSQL Adapter {@link Server} as a stand-alone * application. Here we call for parameter parsing and start the Proxy Server. @@ -78,7 +82,6 @@ public static void main(String[] args) { // Create a shutdown handler and register signal handlers for the signals that should // terminate the server. Server.shutdownHandler = proxyServer.getOrCreateShutdownHandler(); - registerSignalHandlers(); } catch (Exception e) { printError(e, System.err, System.out); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java index a6ce2561d5..428e74b29e 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/IntermediateStatement.java @@ -110,7 +110,7 @@ protected IntermediateStatement( } this.parsedStatement = potentiallyReplacedStatement; this.connection = connectionHandler.getSpannerConnection(); - this.command = parseCommand(this.parsedStatement.getSqlWithoutComments()); + this.command = parseCommand(originalStatement.getSql()); this.commandTag = this.command; this.outputStream = connectionHandler.getConnectionMetadata().getOutputStream(); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BindMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BindMessage.java index a12916710b..35d28d8899 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BindMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BindMessage.java @@ -47,9 +47,9 @@ public BindMessage(ConnectionHandler connection) throws Exception { super(connection); this.portalName = this.readString(); this.statementName = this.readString(); - this.formatCodes = getFormatCodes(this.inputStream); + this.formatCodes = MessageReader.getFormatCodes(this.inputStream); this.parameters = getParameters(this.inputStream); - this.resultFormatCodes = getFormatCodes(this.inputStream); + this.resultFormatCodes = MessageReader.getFormatCodes(this.inputStream); IntermediatePreparedStatement statement = connection.getStatement(statementName); this.statement = statement.createPortal( diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java index bd4dff0934..91430f16d4 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java @@ -15,7 +15,6 @@ package com.google.cloud.spanner.pgadapter.wireprotocol; import static com.google.cloud.spanner.pgadapter.statements.BackendConnection.DB_STATEMENT; -import static com.google.cloud.spanner.pgadapter.statements.IntermediatePortalStatement.NO_FORMAT_CODES; import com.google.api.core.InternalApi; import com.google.api.gax.grpc.GrpcCallContext; @@ -32,12 +31,8 @@ import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.connection.StatementResult.ResultType; import com.google.cloud.spanner.pgadapter.ConnectionHandler; -import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus; import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode; -import com.google.cloud.spanner.pgadapter.error.PGException; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; -import com.google.cloud.spanner.pgadapter.error.SQLState; -import com.google.cloud.spanner.pgadapter.error.Severity; import com.google.cloud.spanner.pgadapter.metadata.SendResultSetState; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.PartitionQueryResult; import com.google.cloud.spanner.pgadapter.statements.CopyToStatement; @@ -61,7 +56,6 @@ import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Scope; -import java.io.DataInputStream; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -75,15 +69,12 @@ /** * Generic representation for a control wire message: that is, a message which does not handle any - * form of start-up, but reather general communications. + * form of start-up, but rather general communications. */ @InternalApi public abstract class ControlMessage extends WireMessage { private static final Logger logger = Logger.getLogger(ControlMessage.class.getName()); - /** Maximum number of invalid messages in a row allowed before we terminate the connection. */ - static final int MAX_INVALID_MESSAGE_COUNT = 50; - /** * Token that is used to mark {@link ControlMessage}s that are manually created to execute a * {@link QueryMessage}. @@ -109,117 +100,6 @@ public boolean isExtendedProtocol() { return manuallyCreatedToken == null; } - /** - * Factory method to create the message from the specific command type char. - * - * @param connection The connection handler object setup with the ability to send/receive. - * @return The constructed wire message given the input message. - * @throws Exception If construction or reading fails. - */ - public static ControlMessage create(ConnectionHandler connection) throws Exception { - boolean validMessage = true; - char nextMsg = (char) connection.getConnectionMetadata().getInputStream().readUnsignedByte(); - try { - if (connection.getStatus() == ConnectionStatus.COPY_IN) { - switch (nextMsg) { - case CopyDoneMessage.IDENTIFIER: - return new CopyDoneMessage(connection); - case CopyDataMessage.IDENTIFIER: - return new CopyDataMessage(connection); - case CopyFailMessage.IDENTIFIER: - return new CopyFailMessage(connection); - case SyncMessage.IDENTIFIER: - case FlushMessage.IDENTIFIER: - // Skip sync/flush in COPY_IN. This is consistent with real PostgreSQL which also does - // this to accommodate clients that do not check what type of statement they sent in an - // ExecuteMessage, and instead always blindly send a flush/sync after each execute. - return SkipMessage.createForValidStream(connection); - default: - // Skip other unexpected messages and throw an exception to fail the copy operation. - validMessage = false; - SkipMessage.createForInvalidStream(connection); - throw new IllegalStateException( - String.format( - "Expected CopyData ('d'), CopyDone ('c') or CopyFail ('f') messages, got: '%c'", - nextMsg)); - } - } else { - switch (nextMsg) { - case QueryMessage.IDENTIFIER: - return new QueryMessage(connection); - case ParseMessage.IDENTIFIER: - return new ParseMessage(connection); - case BindMessage.IDENTIFIER: - return new BindMessage(connection); - case DescribeMessage.IDENTIFIER: - return new DescribeMessage(connection); - case ExecuteMessage.IDENTIFIER: - return new ExecuteMessage(connection); - case CloseMessage.IDENTIFIER: - return new CloseMessage(connection); - case TerminateMessage.IDENTIFIER: - return new TerminateMessage(connection); - case FunctionCallMessage.IDENTIFIER: - return new FunctionCallMessage(connection); - case FlushMessage.IDENTIFIER: - return new FlushMessage(connection); - case SyncMessage.IDENTIFIER: - return new SyncMessage(connection); - case CopyDoneMessage.IDENTIFIER: - case CopyDataMessage.IDENTIFIER: - case CopyFailMessage.IDENTIFIER: - // Silently skip COPY messages in non-COPY mode. This is consistent with the PG wire - // protocol. If we continue to receive COPY messages while in non-COPY mode, we'll - // terminate the connection to prevent the server from being flooded with invalid - // messages. - validMessage = false; - // Note: The stream itself is still valid as we received a message that we recognized. - return SkipMessage.createForValidStream(connection); - default: - throw new IllegalStateException(String.format("Unknown message: %c", nextMsg)); - } - } - } finally { - if (validMessage) { - connection.clearInvalidMessageCount(); - } else { - connection.increaseInvalidMessageCount(); - if (connection.getInvalidMessageCount() > MAX_INVALID_MESSAGE_COUNT) { - new ErrorResponse( - connection, - PGException.newBuilder( - String.format( - "Received %d invalid/unexpected messages. Last received message: '%c'", - connection.getInvalidMessageCount(), nextMsg)) - .setSQLState(SQLState.ProtocolViolation) - .setSeverity(Severity.FATAL) - .build()) - .send(); - connection.setStatus(ConnectionStatus.TERMINATED); - } - } - } - } - - /** - * Extract format codes from message (useful for both input and output format codes). - * - * @param input The data stream containing the user request. - * @return A list of format codes. - * @throws Exception If reading fails in any way. - */ - protected static short[] getFormatCodes(DataInputStream input) throws Exception { - short numberOfFormatCodes = input.readShort(); - if (numberOfFormatCodes == 0) { - return NO_FORMAT_CODES; - } - short[] formatCodes = new short[numberOfFormatCodes]; - for (int i = 0; i < numberOfFormatCodes; i++) { - formatCodes[i] = input.readShort(); - } - return formatCodes; - } - public enum PreparedType { Portal, Statement; diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/FunctionCallMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/FunctionCallMessage.java index bbdee28930..49af27141a 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/FunctionCallMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/FunctionCallMessage.java @@ -37,7 +37,7 @@ public class FunctionCallMessage extends ControlMessage { public FunctionCallMessage(ConnectionHandler connection) throws Exception { super(connection); this.functionID = this.inputStream.readInt(); - this.argumentFormatCodes = getFormatCodes(this.inputStream); + this.argumentFormatCodes = MessageReader.getFormatCodes(this.inputStream); this.arguments = getParameters(this.inputStream); this.resultFormatCode = this.inputStream.readShort(); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/MessageReader.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/MessageReader.java new file mode 100644 index 0000000000..7f0ade3a00 --- /dev/null +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/MessageReader.java @@ -0,0 +1,178 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.wireprotocol; + +import static com.google.cloud.spanner.pgadapter.statements.IntermediatePortalStatement.NO_FORMAT_CODES; + +import com.google.cloud.spanner.pgadapter.ConnectionHandler; +import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus; +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.error.SQLState; +import com.google.cloud.spanner.pgadapter.error.Severity; +import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; +import com.google.cloud.spanner.pgadapter.wireoutput.ErrorResponse; +import java.io.DataInputStream; +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.LockSupport; + +public class MessageReader { + /** Maximum number of invalid messages in a row allowed before we terminate the connection. */ + static final int MAX_INVALID_MESSAGE_COUNT = 50; + + private final AtomicInteger busyWaiters = new AtomicInteger(); + private final int maxBusyWaiters = + Integer.getInteger( + "pgadapter.max_busy_waiters", + Math.max(Runtime.getRuntime().availableProcessors() / 2, 1)); + private final int maxBusyWaitTime = Integer.getInteger("pgadapter.max_busy_wait_millis", 0); + private final long busyWaitParkTime = + Long.getLong("pgadapter.busy_wait_park_time_nanos", TimeUnit.MICROSECONDS.toNanos(1)); + + public MessageReader(OptionsMetadata options) { + // TODO: Read busy-wait options from options + } + + /** + * Factory method to create the message from the specific command type char. + * + * @param connection The connection handler object setup with the ability to send/receive. + * @return The constructed wire message given the input message. + * @throws Exception If construction or reading fails. + */ + public ControlMessage create(ConnectionHandler connection) throws Exception { + boolean validMessage = true; + char nextMsg = readNextMsgIdentifier(connection); + try { + if (connection.getStatus() == ConnectionStatus.COPY_IN) { + switch (nextMsg) { + case CopyDoneMessage.IDENTIFIER: + return new CopyDoneMessage(connection); + case CopyDataMessage.IDENTIFIER: + return new CopyDataMessage(connection); + case CopyFailMessage.IDENTIFIER: + return new CopyFailMessage(connection); + case SyncMessage.IDENTIFIER: + case FlushMessage.IDENTIFIER: + // Skip sync/flush in COPY_IN. This is consistent with real PostgreSQL which also does + // this to accommodate clients that do not check what type of statement they sent in an + // ExecuteMessage, and instead always blindly send a flush/sync after each execute. + return SkipMessage.createForValidStream(connection); + default: + // Skip other unexpected messages and throw an exception to fail the copy operation. + validMessage = false; + SkipMessage.createForInvalidStream(connection); + throw new IllegalStateException( + String.format( + "Expected CopyData ('d'), CopyDone ('c') or CopyFail ('f') messages, got: '%c'", + nextMsg)); + } + } else { + switch (nextMsg) { + case QueryMessage.IDENTIFIER: + return new QueryMessage(connection); + case ParseMessage.IDENTIFIER: + return new ParseMessage(connection); + case BindMessage.IDENTIFIER: + return new BindMessage(connection); + case DescribeMessage.IDENTIFIER: + return new DescribeMessage(connection); + case ExecuteMessage.IDENTIFIER: + return new ExecuteMessage(connection); + case CloseMessage.IDENTIFIER: + return new CloseMessage(connection); + case TerminateMessage.IDENTIFIER: + return new TerminateMessage(connection); + case FunctionCallMessage.IDENTIFIER: + return new FunctionCallMessage(connection); + case FlushMessage.IDENTIFIER: + return new FlushMessage(connection); + case SyncMessage.IDENTIFIER: + return new SyncMessage(connection); + case CopyDoneMessage.IDENTIFIER: + case CopyDataMessage.IDENTIFIER: + case CopyFailMessage.IDENTIFIER: + // Silently skip COPY messages in non-COPY mode. This is consistent with the PG wire + // protocol. If we continue to receive COPY messages while in non-COPY mode, we'll + // terminate the connection to prevent the server from being flooded with invalid + // messages. + validMessage = false; + // Note: The stream itself is still valid as we received a message that we recognized. + return SkipMessage.createForValidStream(connection); + default: + throw new IllegalStateException(String.format("Unknown message: %c", nextMsg)); + } + } + } finally { + if (validMessage) { + connection.clearInvalidMessageCount(); + } else { + connection.increaseInvalidMessageCount(); + if (connection.getInvalidMessageCount() > MAX_INVALID_MESSAGE_COUNT) { + new ErrorResponse( + connection, + PGException.newBuilder( + String.format( + "Received %d invalid/unexpected messages. Last received message: '%c'", + connection.getInvalidMessageCount(), nextMsg)) + .setSQLState(SQLState.ProtocolViolation) + .setSeverity(Severity.FATAL) + .build()) + .send(); + connection.setStatus(ConnectionStatus.TERMINATED); + } + } + } + } + + private char readNextMsgIdentifier(ConnectionHandler connection) throws IOException { + DataInputStream inputStream = connection.getConnectionMetadata().getInputStream(); + if (maxBusyWaitTime > 0 && busyWaiters.get() < maxBusyWaiters) { + try { + busyWaiters.incrementAndGet(); + long wait = busyWaitParkTime; + long startTime = System.currentTimeMillis(); + while (inputStream.available() == 0 + && System.currentTimeMillis() - startTime < maxBusyWaitTime) { + LockSupport.parkNanos(wait); + wait = Math.min(wait * 2, maxBusyWaitTime); + } + } finally { + busyWaiters.decrementAndGet(); + } + } + return (char) inputStream.readUnsignedByte(); + } + + /** + * Extract format codes from message (useful for both input and output format codes). + * + * @param input The data stream containing the user request. + * @return A list of format codes. + * @throws Exception If reading fails in any way. + */ + static short[] getFormatCodes(DataInputStream input) throws Exception { + short numberOfFormatCodes = input.readShort(); + if (numberOfFormatCodes == 0) { + return NO_FORMAT_CODES; + } + short[] formatCodes = new short[numberOfFormatCodes]; + for (int i = 0; i < numberOfFormatCodes; i++) { + formatCodes[i] = input.readShort(); + } + return formatCodes; + } +} diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/WireMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/WireMessage.java index 7fb6060e51..1b6abfda11 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/WireMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/WireMessage.java @@ -33,15 +33,17 @@ public abstract class WireMessage { private static final Logger logger = Logger.getLogger(WireMessage.class.getName()); protected int length; - protected DataInputStream inputStream; - protected DataOutputStream outputStream; - protected ConnectionHandler connection; + protected final DataInputStream inputStream; + protected final DataOutputStream outputStream; + protected final ConnectionHandler connection; + protected final MessageReader messageReader; public WireMessage(ConnectionHandler connection, int length) { Preconditions.checkArgument(length >= 4); this.connection = connection; this.inputStream = connection.getConnectionMetadata().getInputStream(); this.outputStream = connection.getConnectionMetadata().getOutputStream(); + this.messageReader = connection.getServer().getMessageReader(); this.length = length; } @@ -163,32 +165,35 @@ public String read(int length) throws IOException { */ public String readString() throws IOException { this.inputStream.mark(MARK_READ_LIMIT); - int index = 0; - while (index < MARK_READ_LIMIT) { - byte b = this.inputStream.readByte(); - if (b == 0) { - break; + try { + int index = 0; + while (index < MARK_READ_LIMIT) { + byte b = this.inputStream.readByte(); + if (b == 0) { + break; + } + index++; + if (index == MARK_READ_LIMIT) { + throw new IOException("No null terminator found"); + } } - index++; - if (index == MARK_READ_LIMIT) { - throw new IOException("No null terminator found"); + if (index == 0) { + // Empty string, we don't have to ready anything. + return ""; } - } - // Reset the stream to the mark and read the name (if any). - this.inputStream.reset(); - if (index == 0) { - // No name, but we still need to skip the null-terminator. + + // Reset the stream to the mark and read the string. + this.inputStream.reset(); + byte[] result = new byte[index]; + this.inputStream.readFully(result); + // Skip the null-terminator. //noinspection StatementWithEmptyBody while (this.inputStream.skip(1) < 1) {} - return ""; + return new String(result, StandardCharsets.UTF_8); + } finally { + // Drop the mark. + this.inputStream.mark(0); } - - byte[] result = new byte[index]; - this.inputStream.readFully(result); - // Skip the null-terminator. - //noinspection StatementWithEmptyBody - while (this.inputStream.skip(1) < 1) {} - return new String(result, StandardCharsets.UTF_8); } /** @@ -207,6 +212,6 @@ protected int getHeaderLength() { * setting for {@link ConnectionHandler}. */ public void nextHandler() throws Exception { - this.connection.setMessageState(ControlMessage.create(this.connection)); + this.connection.setMessageState(messageReader.create(this.connection)); } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java index bae11d0d61..ab9ebbffd6 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/StatementTest.java @@ -57,7 +57,7 @@ import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; import com.google.cloud.spanner.pgadapter.utils.Metrics; import com.google.cloud.spanner.pgadapter.utils.MutationWriter; -import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage; +import com.google.cloud.spanner.pgadapter.wireprotocol.MessageReader; import com.google.cloud.spanner.pgadapter.wireprotocol.QueryMessage; import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage; import com.google.common.collect.ImmutableList; @@ -477,15 +477,16 @@ public void testBatchStatementsWithComments() throws Exception { byte[] value = Bytes.concat(messageMetadata, payload.getBytes()); DataInputStream inputStream = new DataInputStream(new ByteArrayInputStream(value)); - when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); + when(server.getMessageReader()).thenReturn(new MessageReader(mock(OptionsMetadata.class))); when(connectionHandler.getServer()).thenReturn(server); + when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); when(server.getOptions()).thenReturn(options); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(QueryMessage.class, message.getClass()); SimpleQueryStatement simpleQueryStatement = ((QueryMessage) message).getSimpleQueryStatement(); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessageTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessageTest.java index b429c7f0a2..d5c7c56775 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessageTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessageTest.java @@ -20,7 +20,6 @@ import static org.mockito.Mockito.when; import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType; -import com.google.cloud.spanner.connection.Connection; import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.connection.StatementResult.ResultType; import com.google.cloud.spanner.pgadapter.ConnectionHandler; @@ -65,7 +64,6 @@ public final class ControlMessageTest { @Mock private ExtendedQueryProtocolHandler extendedQueryProtocolHandler; @Mock private IntermediateStatement intermediateStatement; @Mock private ConnectionMetadata connectionMetadata; - @Mock private Connection connection; @Test public void testInsertResult() throws Exception { @@ -102,7 +100,7 @@ public void testInsertResult() throws Exception { when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); - ControlMessage controlMessage = ControlMessage.create(connectionHandler); + ControlMessage controlMessage = server.getMessageReader().create(connectionHandler); controlMessage.sendSpannerResult(intermediateStatement, QueryMode.SIMPLE, 0L); DataInputStream outputReader = @@ -126,6 +124,7 @@ public void testUnknownStatementTypeDoesNotThrowError() throws Exception { new DataInputStream( new ByteArrayInputStream(new byte[] {(byte) QUERY_IDENTIFIER, 0, 0, 0, 5, 0})); + when(connectionHandler.getServer()).thenReturn(mock(ProxyServer.class)); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); when(connectionHandler.getExtendedQueryProtocolHandler()) @@ -157,6 +156,7 @@ public void testUnknownStatementTypeDoesNotThrowError() throws Exception { public void testSendNoRowsAsResultSetFails() { OpenTelemetry otel = OpenTelemetry.noop(); Tracer tracer = otel.getTracer("test"); + when(connectionHandler.getServer()).thenReturn(mock(ProxyServer.class)); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java index 273145ad51..fc83f50dd6 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/wireprotocol/ProtocolTest.java @@ -15,7 +15,7 @@ package com.google.cloud.spanner.pgadapter.wireprotocol; import static com.google.cloud.spanner.pgadapter.statements.IntermediatePortalStatement.NO_FORMAT_CODES; -import static com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.MAX_INVALID_MESSAGE_COUNT; +import static com.google.cloud.spanner.pgadapter.wireprotocol.MessageReader.MAX_INVALID_MESSAGE_COUNT; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -145,6 +145,7 @@ public void testQueryMessage() throws Exception { when(connectionHandler.getServer()).thenReturn(server); when(server.getOptions()).thenReturn(options); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(options.requiresMatcher()).thenReturn(false); when(connectionHandler.getSpannerConnection()).thenReturn(connection); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); @@ -156,7 +157,7 @@ public void testQueryMessage() throws Exception { when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(QueryMessage.class, message.getClass()); assertEquals(expectedSQL, ((QueryMessage) message).getStatement().getSql()); @@ -179,11 +180,12 @@ public void testQueryUsesPSQLStatementWhenPSQLModeSelectedMessage() throws Excep .thenReturn(extendedQueryProtocolHandler); when(connectionHandler.getServer()).thenReturn(server); when(server.getOptions()).thenReturn(options); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(QueryMessage.class, message.getClass()); assertNotNull(((QueryMessage) message).getSimpleQueryStatement()); assertEquals(expectedSQL, ((QueryMessage) message).getStatement().getSql()); @@ -197,13 +199,15 @@ public void testQueryMessageFailsWhenNotNullTerminated() { DataInputStream inputStream = new DataInputStream(new ByteArrayInputStream(value)); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); - assertThrows(IOException.class, () -> ControlMessage.create(connectionHandler)); + assertThrows(IOException.class, () -> server.getMessageReader().create(connectionHandler)); } @Test @@ -242,6 +246,7 @@ public void testParseMessageException() throws Exception { when(connectionHandler.getServer()).thenReturn(server); when(server.getOptions()).thenReturn(options); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getSpannerConnection()).thenReturn(connection); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); @@ -249,7 +254,7 @@ public void testParseMessageException() throws Exception { when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(ParseMessage.class, message.getClass()); assertEquals(expectedMessageName, ((ParseMessage) message).getName()); assertEquals(expectedSQL, ((ParseMessage) message).getStatement().getSql()); @@ -306,6 +311,7 @@ public void testParseMessage() throws Exception { when(connectionHandler.getServer()).thenReturn(server); when(server.getOptions()).thenReturn(options); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getSpannerConnection()).thenReturn(connection); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); @@ -313,7 +319,7 @@ public void testParseMessage() throws Exception { when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(ParseMessage.class, message.getClass()); assertEquals(expectedMessageName, ((ParseMessage) message).getName()); assertEquals(expectedSQL, ((ParseMessage) message).getStatement().getSql()); @@ -371,6 +377,7 @@ public void testParseMessageAcceptsUntypedParameter() throws Exception { when(connectionHandler.getServer()).thenReturn(server); when(server.getOptions()).thenReturn(options); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getSpannerConnection()).thenReturn(connection); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); @@ -378,7 +385,7 @@ public void testParseMessageAcceptsUntypedParameter() throws Exception { when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(ParseMessage.class, message.getClass()); assertEquals(expectedMessageName, ((ParseMessage) message).getName()); assertEquals(expectedSQL, ((ParseMessage) message).getStatement().getSql()); @@ -421,6 +428,7 @@ public void testParseMessageWithNonMatchingParameterTypeCount() throws Exception when(connectionHandler.getServer()).thenReturn(server); when(server.getOptions()).thenReturn(options); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getSpannerConnection()).thenReturn(connection); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); @@ -428,7 +436,7 @@ public void testParseMessageWithNonMatchingParameterTypeCount() throws Exception when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(ParseMessage.class, message.getClass()); assertEquals(expectedMessageName, ((ParseMessage) message).getName()); assertEquals(expectedSQL, ((ParseMessage) message).getStatement().getSql()); @@ -479,6 +487,7 @@ public void testParseMessageExceptsIfNameIsInUse() throws Exception { when(connectionHandler.getServer()).thenReturn(server); when(server.getOptions()).thenReturn(options); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getSpannerConnection()).thenReturn(connection); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); @@ -486,7 +495,7 @@ public void testParseMessageExceptsIfNameIsInUse() throws Exception { when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); when(connectionHandler.hasStatement(anyString())).thenReturn(true); assertThrows(IllegalStateException.class, message::send); @@ -525,6 +534,7 @@ public void testParseMessageExceptsIfNameIsNull() throws Exception { when(connectionHandler.getServer()).thenReturn(server); when(server.getOptions()).thenReturn(options); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getSpannerConnection()).thenReturn(connection); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); @@ -532,7 +542,7 @@ public void testParseMessageExceptsIfNameIsNull() throws Exception { when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); when(connectionHandler.hasStatement(anyString())).thenReturn(true); assertThrows(IllegalStateException.class, message::send); @@ -572,6 +582,7 @@ public void testParseMessageWorksIfNameIsEmpty() throws Exception { when(connectionHandler.getServer()).thenReturn(server); when(server.getOptions()).thenReturn(options); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getSpannerConnection()).thenReturn(connection); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); @@ -579,7 +590,7 @@ public void testParseMessageWorksIfNameIsEmpty() throws Exception { when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); message.send(); } @@ -621,6 +632,7 @@ public void testBindMessage() throws Exception { parameter, resultCodesCount); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getStatement(anyString())).thenReturn(intermediatePreparedStatement); when(intermediatePreparedStatement.createPortal(anyString(), any(), any(), any())) .thenReturn(intermediatePortalStatement); @@ -642,8 +654,9 @@ public void testBindMessage() throws Exception { when(connectionMetadata.getOutputStream()).thenReturn(outputStream); when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(BindMessage.class, message.getClass()); assertEquals(expectedPortalName, ((BindMessage) message).getPortalName()); assertEquals(expectedStatementName, ((BindMessage) message).getStatementName()); @@ -720,6 +733,7 @@ public void testBindMessageOneNonTextParam() throws Exception { DataInputStream inputStream = new DataInputStream(new ByteArrayInputStream(value)); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); @@ -728,8 +742,9 @@ public void testBindMessageOneNonTextParam() throws Exception { when(connectionHandler.getStatement(anyString())).thenReturn(intermediatePreparedStatement); when(intermediatePreparedStatement.createPortal(anyString(), any(), any(), any())) .thenReturn(intermediatePortalStatement); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(BindMessage.class, message.getClass()); assertEquals(expectedPortalName, ((BindMessage) message).getPortalName()); assertEquals(expectedStatementName, ((BindMessage) message).getStatementName()); @@ -795,6 +810,7 @@ public void testBindMessageAllNonTextParam() throws Exception { DataInputStream inputStream = new DataInputStream(new ByteArrayInputStream(value)); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); @@ -803,8 +819,9 @@ public void testBindMessageAllNonTextParam() throws Exception { when(connectionHandler.getStatement(anyString())).thenReturn(intermediatePreparedStatement); when(intermediatePreparedStatement.createPortal(anyString(), any(), any(), any())) .thenReturn(intermediatePortalStatement); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(BindMessage.class, message.getClass()); assertEquals(expectedPortalName, ((BindMessage) message).getPortalName()); assertEquals(expectedStatementName, ((BindMessage) message).getStatementName()); @@ -828,6 +845,7 @@ public void testDescribePortalMessage() throws Exception { ByteArrayOutputStream result = new ByteArrayOutputStream(); DataOutputStream outputStream = new DataOutputStream(result); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getPortal(anyString())).thenReturn(intermediatePortalStatement); when(intermediatePortalStatement.getSql()).thenReturn("select * from foo"); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); @@ -835,8 +853,9 @@ public void testDescribePortalMessage() throws Exception { when(connectionMetadata.getOutputStream()).thenReturn(outputStream); when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(DescribeMessage.class, message.getClass()); assertEquals(expectedStatementName, ((DescribeMessage) message).getName()); assertEquals("select * from foo", ((DescribeMessage) message).getSql()); @@ -867,14 +886,16 @@ public void testDescribeStatementMessage() throws Exception { ByteArrayOutputStream result = new ByteArrayOutputStream(); DataOutputStream outputStream = new DataOutputStream(result); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getStatement(anyString())).thenReturn(intermediatePreparedStatement); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(DescribeMessage.class, message.getClass()); assertEquals(expectedStatementName, ((DescribeMessage) message).getName()); @@ -903,6 +924,7 @@ public void testDescribeMessageWithException() throws Exception { ByteArrayOutputStream result = new ByteArrayOutputStream(); DataOutputStream outputStream = new DataOutputStream(result); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getStatement(anyString())).thenReturn(intermediatePreparedStatement); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); @@ -912,8 +934,9 @@ public void testDescribeMessageWithException() throws Exception { when(intermediatePreparedStatement.hasException()).thenReturn(true); when(intermediatePreparedStatement.getException()) .thenReturn(PGExceptionFactory.newPGException("test error", SQLState.InternalError)); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(DescribeMessage.class, message.getClass()); DescribeMessage describeMessage = (DescribeMessage) message; @@ -938,6 +961,8 @@ public void testExecuteMessage() throws Exception { ByteArrayOutputStream result = new ByteArrayOutputStream(); DataOutputStream outputStream = new DataOutputStream(result); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); when(connectionHandler.getPortal(anyString())).thenReturn(intermediatePortalStatement); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); @@ -946,7 +971,7 @@ public void testExecuteMessage() throws Exception { .thenReturn(extendedQueryProtocolHandler); when(extendedQueryProtocolHandler.getBackendConnection()).thenReturn(backendConnection); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(ExecuteMessage.class, message.getClass()); assertEquals(expectedStatementName, ((ExecuteMessage) message).getName()); assertEquals(totalRows, ((ExecuteMessage) message).getMaxRows()); @@ -985,6 +1010,7 @@ public void testExecuteMessageWithException() throws Exception { PGException testException = PGExceptionFactory.newPGException("test error", SQLState.SyntaxError); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(intermediatePortalStatement.hasException()).thenReturn(true); when(intermediatePortalStatement.getException()).thenReturn(testException); when(connectionHandler.getWellKnownClient()).thenReturn(WellKnownClient.UNSPECIFIED); @@ -995,8 +1021,9 @@ public void testExecuteMessageWithException() throws Exception { when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); when(extendedQueryProtocolHandler.getBackendConnection()).thenReturn(backendConnection); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(ExecuteMessage.class, message.getClass()); assertEquals(expectedStatementName, ((ExecuteMessage) message).getName()); assertEquals(totalRows, ((ExecuteMessage) message).getMaxRows()); @@ -1028,12 +1055,14 @@ public void testClosePortalMessage() throws Exception { ByteArrayOutputStream result = new ByteArrayOutputStream(); DataOutputStream outputStream = new DataOutputStream(result); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getPortal(anyString())).thenReturn(intermediatePortalStatement); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(CloseMessage.class, message.getClass()); assertEquals(expectedStatementName, ((CloseMessage) message).getName()); assertEquals(expectedType, ((CloseMessage) message).getType()); @@ -1066,12 +1095,14 @@ public void testCloseStatementMessage() throws Exception { ByteArrayOutputStream result = new ByteArrayOutputStream(); DataOutputStream outputStream = new DataOutputStream(result); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); when(connectionHandler.getStatement(anyString())).thenReturn(intermediatePortalStatement); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(CloseMessage.class, message.getClass()); assertEquals(expectedStatementName, ((CloseMessage) message).getName()); assertEquals(expectedType, ((CloseMessage) message).getType()); @@ -1099,6 +1130,8 @@ public void testSyncMessage() throws Exception { ByteArrayOutputStream result = new ByteArrayOutputStream(); DataOutputStream outputStream = new DataOutputStream(result); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); when(connectionHandler.getTraceConnectionId()).thenReturn(UUID.randomUUID()); ExtendedQueryProtocolHandler extendedQueryProtocolHandler = new ExtendedQueryProtocolHandler(connectionHandler, backendConnection); @@ -1111,7 +1144,7 @@ public void testSyncMessage() throws Exception { .thenReturn(extendedQueryProtocolHandler); when(backendConnection.getConnectionState()).thenReturn(ConnectionState.IDLE); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(message.getClass(), SyncMessage.class); message.send(); @@ -1135,6 +1168,7 @@ public void testSyncMessageInTransaction() throws Exception { ByteArrayOutputStream result = new ByteArrayOutputStream(); DataOutputStream outputStream = new DataOutputStream(result); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getTraceConnectionId()).thenReturn(UUID.randomUUID()); ExtendedQueryProtocolHandler extendedQueryProtocolHandler = new ExtendedQueryProtocolHandler(connectionHandler, backendConnection); @@ -1142,12 +1176,13 @@ public void testSyncMessageInTransaction() throws Exception { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(connectionHandler.getServer()).thenReturn(server); when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); when(backendConnection.getConnectionState()).thenReturn(ConnectionState.TRANSACTION); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(SyncMessage.class, message.getClass()); message.send(); @@ -1170,13 +1205,15 @@ public void testFlushMessage() throws Exception { DataInputStream inputStream = new DataInputStream(new ByteArrayInputStream(value)); DataOutputStream outputStream = mock(DataOutputStream.class); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(message.getClass(), FlushMessage.class); message.send(); @@ -1196,13 +1233,15 @@ public void testFlushMessageInTransaction() throws Exception { DataInputStream inputStream = new DataInputStream(new ByteArrayInputStream(value)); DataOutputStream outputStream = mock(DataOutputStream.class); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(FlushMessage.class, message.getClass()); message.send(); @@ -1239,6 +1278,7 @@ public void testQueryMessageInTransaction() throws Exception { when(backendConnection.getMetrics()).thenReturn(new Metrics(OpenTelemetry.noop())); OptionsMetadata options = mock(OptionsMetadata.class); when(server.getOptions()).thenReturn(options); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); @@ -1247,7 +1287,7 @@ public void testQueryMessageInTransaction() throws Exception { when(connectionHandler.getExtendedQueryProtocolHandler()) .thenReturn(extendedQueryProtocolHandler); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(QueryMessage.class, message.getClass()); assertEquals(expectedSQL, ((QueryMessage) message).getStatement().getSql()); @@ -1286,8 +1326,10 @@ public void testTerminateMessage() throws Exception { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(TerminateMessage.class, message.getClass()); } @@ -1299,8 +1341,10 @@ public void testUnknownMessageTypeCausesException() { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); - assertThrows(IllegalStateException.class, () -> ControlMessage.create(connectionHandler)); + assertThrows( + IllegalStateException.class, () -> server.getMessageReader().create(connectionHandler)); } @Test @@ -1318,11 +1362,13 @@ public void testCopyDataMessage() throws Exception { when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.COPY_IN); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); MutationWriter mw = mock(MutationWriter.class); when(copyStatement.getMutationWriter()).thenReturn(mw); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(CopyDataMessage.class, message.getClass()); assertArrayEquals(payload, ((CopyDataMessage) message).getPayload()); @@ -1346,8 +1392,10 @@ public void testCopyDataMessageWithNoCopyStatement() throws Exception { when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.COPY_IN); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); // This should be a no-op. message.sendPayload(); } @@ -1357,6 +1405,7 @@ public void testMultipleCopyDataMessages() throws Exception { when(connectionHandler.getSpannerConnection()).thenReturn(connection); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.COPY_IN); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); byte[] messageMetadata = {'d'}; byte[] payload1 = "1\t'one'\n2\t".getBytes(); @@ -1382,10 +1431,11 @@ public void testMultipleCopyDataMessages() throws Exception { when(connectionHandler.getActiveCopyStatement()).thenReturn(copyStatement); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(connectionHandler.getServer()).thenReturn(server); { when(connectionMetadata.getInputStream()).thenReturn(inputStream1); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(CopyDataMessage.class, message.getClass()); assertArrayEquals(payload1, ((CopyDataMessage) message).getPayload()); CopyDataMessage copyDataMessage = (CopyDataMessage) message; @@ -1393,7 +1443,7 @@ public void testMultipleCopyDataMessages() throws Exception { } { when(connectionMetadata.getInputStream()).thenReturn(inputStream2); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(CopyDataMessage.class, message.getClass()); assertArrayEquals(payload2, ((CopyDataMessage) message).getPayload()); CopyDataMessage copyDataMessage = (CopyDataMessage) message; @@ -1401,7 +1451,7 @@ public void testMultipleCopyDataMessages() throws Exception { } { when(connectionMetadata.getInputStream()).thenReturn(inputStream3); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(CopyDataMessage.class, message.getClass()); assertArrayEquals(payload3, ((CopyDataMessage) message).getPayload()); CopyDataMessage copyDataMessage = (CopyDataMessage) message; @@ -1427,8 +1477,10 @@ public void testCopyDoneMessage() throws Exception { when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.COPY_IN); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(CopyDoneMessage.class, message.getClass()); CopyDoneMessage messageSpy = (CopyDoneMessage) spy(message); @@ -1459,8 +1511,10 @@ public void testCopyFailMessage() throws Exception { when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.COPY_IN); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(CopyFailMessage.class, message.getClass()); assertEquals(expectedErrorMessage, ((CopyFailMessage) message).getErrorMessage()); @@ -1514,8 +1568,10 @@ public void testFunctionCallMessageThrowsException() throws Exception { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); - WireMessage message = ControlMessage.create(connectionHandler); + WireMessage message = server.getMessageReader().create(connectionHandler); assertEquals(FunctionCallMessage.class, message.getClass()); assertThrows(IllegalStateException.class, message::send); @@ -1564,6 +1620,7 @@ public void testStartUpMessage() throws Exception { when(sessionState.get(null, "server_version")).thenReturn(serverVersionSetting); when(backendConnection.getSessionState()).thenReturn(sessionState); when(server.getOptions()).thenReturn(options); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); when(options.shouldAuthenticate()).thenReturn(false); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); @@ -1664,6 +1721,8 @@ public void testCancelMessage() throws Exception { when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); WireMessage message = BootstrapMessage.create(connectionHandler); assertEquals(CancelMessage.class, message.getClass()); @@ -1694,6 +1753,7 @@ public void testSSLMessage() throws Exception { when(options.getSslMode()).thenReturn(SslMode.Enable); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); WireMessage message = BootstrapMessage.create(connectionHandler); assertEquals(SSLMessage.class, message.getClass()); @@ -1724,6 +1784,7 @@ public void testSSLMessageFailsWhenCalledTwice() throws Exception { when(options.getSslMode()).thenReturn(SslMode.Disable); when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); WireMessage message = BootstrapMessage.create(connectionHandler); assertEquals(SSLMessage.class, message.getClass()); @@ -1747,6 +1808,8 @@ public void testGetPortalMetadataBeforeFlushFails() { when(intermediatePortalStatement.containsResultSet()).thenReturn(true); when(intermediatePortalStatement.describeAsync(backendConnection)) .thenReturn(SettableFuture.create()); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); DescribeMessage describeMessage = new DescribeMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN); @@ -1768,6 +1831,8 @@ public void testSkipMessage() throws Exception { when(connectionMetadata.getInputStream()).thenReturn(inputStream); when(connectionMetadata.getOutputStream()).thenReturn(outputStream); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); SkipMessage message = SkipMessage.createForValidStream(connectionHandler); message.send(); @@ -1791,8 +1856,10 @@ public void testFlushSkippedInCopyMode() throws Exception { when(connectionMetadata.getOutputStream()).thenReturn(outputStream); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.COPY_IN); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); - ControlMessage message = ControlMessage.create(connectionHandler); + ControlMessage message = server.getMessageReader().create(connectionHandler); assertEquals(SkipMessage.class, message.getClass()); // Verify that nothing was written to the output. @@ -1811,8 +1878,10 @@ public void testSyncSkippedInCopyMode() throws Exception { when(connectionMetadata.getOutputStream()).thenReturn(outputStream); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.COPY_IN); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); - ControlMessage message = ControlMessage.create(connectionHandler); + ControlMessage message = server.getMessageReader().create(connectionHandler); assertEquals(SkipMessage.class, message.getClass()); // Verify that nothing was written to the output. @@ -1833,8 +1902,10 @@ public void testCopyDataSkippedInNormalMode() throws Exception { when(connectionMetadata.getOutputStream()).thenReturn(outputStream); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.AUTHENTICATED); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); - ControlMessage message = ControlMessage.create(connectionHandler); + ControlMessage message = server.getMessageReader().create(connectionHandler); assertEquals(SkipMessage.class, message.getClass()); // Verify that nothing was written to the output. @@ -1853,8 +1924,10 @@ public void testCopyDoneSkippedInNormalMode() throws Exception { when(connectionMetadata.getOutputStream()).thenReturn(outputStream); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.AUTHENTICATED); + when(connectionHandler.getServer()).thenReturn(server); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); - ControlMessage message = ControlMessage.create(connectionHandler); + ControlMessage message = server.getMessageReader().create(connectionHandler); assertEquals(SkipMessage.class, message.getClass()); // Verify that nothing was written to the output. @@ -1873,8 +1946,10 @@ public void testCopyFailSkippedInNormalMode() throws Exception { when(connectionMetadata.getOutputStream()).thenReturn(outputStream); when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.AUTHENTICATED); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); - ControlMessage message = ControlMessage.create(connectionHandler); + ControlMessage message = server.getMessageReader().create(connectionHandler); assertEquals(SkipMessage.class, message.getClass()); // Verify that nothing was written to the output. @@ -1908,16 +1983,18 @@ public void testRepeatedCopyDataInNormalMode_TerminatesConnectionAndReturnsError when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.AUTHENTICATED); doCallRealMethod().when(connectionHandler).increaseInvalidMessageCount(); when(connectionHandler.getInvalidMessageCount()).thenCallRealMethod(); + when(server.getMessageReader()).thenReturn(new MessageReader(options)); + when(connectionHandler.getServer()).thenReturn(server); for (int i = 0; i < MAX_INVALID_MESSAGE_COUNT; i++) { - ControlMessage message = ControlMessage.create(connectionHandler); + ControlMessage message = server.getMessageReader().create(connectionHandler); assertEquals(SkipMessage.class, message.getClass()); // Verify that nothing was written to the output. assertEquals(0, result.size()); verify(connectionHandler, never()).setStatus(ConnectionStatus.TERMINATED); } - ControlMessage.create(connectionHandler); + server.getMessageReader().create(connectionHandler); verify(connectionHandler).setStatus(ConnectionStatus.TERMINATED); byte[] resultBytes = result.toByteArray(); assertEquals('E', resultBytes[0]); diff --git a/wrappers/nodejs/.gitignore b/wrappers/nodejs/.gitignore new file mode 100644 index 0000000000..1521c8b765 --- /dev/null +++ b/wrappers/nodejs/.gitignore @@ -0,0 +1 @@ +dist diff --git a/wrappers/nodejs/package.json b/wrappers/nodejs/package.json new file mode 100644 index 0000000000..837d53b66a --- /dev/null +++ b/wrappers/nodejs/package.json @@ -0,0 +1,26 @@ +{ + "name": "test-wrapped-binary", + "version": "0.0.4", + "description": "Test running a wrapped binary", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "files": [ + "/dist" + ], + "devDependencies": { + "@types/gunzip-maybe": "^1.4.2", + "@types/node": "^22.1.0", + "@types/tar-fs": "^2.0.4", + "ts-node": "10.9.2", + "typescript": "^5.8.2" + }, + "dependencies": { + "gunzip-maybe": "^1.4.2", + "tar-fs": "^3.0.8", + "tcp-port-used": "^1.0.2", + "typed-rest-client": "^2.1.0" + }, + "scripts": { + "install": "node dist/install/install.js" + } +} diff --git a/wrappers/nodejs/src/index.ts b/wrappers/nodejs/src/index.ts new file mode 100644 index 0000000000..4c3ceb1658 --- /dev/null +++ b/wrappers/nodejs/src/index.ts @@ -0,0 +1 @@ +export {startPGAdapter} from './pgadapter' diff --git a/wrappers/nodejs/src/install/binary.ts b/wrappers/nodejs/src/install/binary.ts new file mode 100644 index 0000000000..dc97a40f92 --- /dev/null +++ b/wrappers/nodejs/src/install/binary.ts @@ -0,0 +1,11 @@ +const supportedPlatforms: NodeJS.Platform[] = ["linux", "darwin"]; +const supportedArchitectures: NodeJS.Architecture[] = ["x64", "arm64"]; + +export function checkPlatform(platform: NodeJS.Platform, arch: NodeJS.Architecture) { + if (!supportedPlatforms.includes(platform)) { + throw new Error(`Unsupported platform: ${platform}`); + } + if (!supportedArchitectures.includes(arch)) { + throw new Error(`Unsupported architecture: ${arch}`); + } +} diff --git a/wrappers/nodejs/src/install/install.ts b/wrappers/nodejs/src/install/install.ts new file mode 100644 index 0000000000..2e5bca66ef --- /dev/null +++ b/wrappers/nodejs/src/install/install.ts @@ -0,0 +1,49 @@ +import path from "path"; +import {HttpClient} from "typed-rest-client/HttpClient"; +import * as fs from "fs"; +import gunzip from "gunzip-maybe"; +import tar from "tar-fs"; +import {checkPlatform} from "./binary"; + +async function installBinary() { + const client = new HttpClient("pgadapter-nodejs"); + const url = determineUrl(); + const response = await client.get(url); + const folder = path.join(__dirname, "..", "..", "bin"); + + if (!fs.existsSync(folder)) { + fs.mkdirSync(folder); + } + + if (response.message.statusCode !== 200) { + const err: Error = new Error(`Unexpected HTTP response: ${response.message.statusCode}`); + err["httpStatusCode"] = response.message.statusCode; + throw err; + } + + console.log(`Downloading pgadapter from ${url}`); + return new Promise((resolve, reject) => { + const stream = response.message.pipe(gunzip()).pipe(tar.extract(folder)); + stream.on("error", (err) => reject(err)); + stream.on("close", () => { + try { resolve(folder); } catch (err) { + reject(err); + } + }); + }); +} + +function determineUrl(): string { + checkPlatform(process.platform, process.arch); + + const host = "https://storage.googleapis.com/test-pgadapter-native-image"; + const version = require(path.join(__dirname, "..", "..", "package.json")).version; + return `${host}/v${version}/pgadapter-${process.platform}-${process.arch}.tar.gz`; +} + +(async () => { + await installBinary(); +})().catch(e => { + console.error(e); + process.exit(1); +}); diff --git a/wrappers/nodejs/src/pgadapter.ts b/wrappers/nodejs/src/pgadapter.ts new file mode 100644 index 0000000000..e15f416c2b --- /dev/null +++ b/wrappers/nodejs/src/pgadapter.ts @@ -0,0 +1,60 @@ +import {checkPlatform} from "./install/binary"; +import {ChildProcessWithoutNullStreams, spawn} from 'child_process'; +import * as path from "path"; +const tcpPortUsed = require('tcp-port-used'); + +export interface PGAdapterOptions { + project?: string + instance?: string + database?: string + + port?: number + credentials?: string +} + +export interface StartupOptions { + skipStartupProbe?: boolean, + timeoutMs?: number, + probeRetryMs?: number, + + platform?: NodeJS.Platform, + arch?: NodeJS.Architecture, +} + +export async function startPGAdapter(options?: PGAdapterOptions, startupOptions?: StartupOptions): + Promise { + const platform = startupOptions?.platform || process.platform; + const arch = startupOptions?.arch || process.arch; + checkPlatform(platform, arch); + + const binary = path.join(__dirname, "..", "bin", `pgadapter-${platform}-${arch}`); + const args: string[] = []; + if (options?.project) { + args.push("-p", options.project); + } + if (options?.instance) { + args.push("-i", options.instance); + } + if (options?.database) { + args.push("-d", options.database); + } + if (options?.port) { + args.push("-s", `${options.port}`); + } + if (options?.credentials) { + args.push("-c", options.credentials); + } + + const pgAdapter = spawn(binary, args, {stdio: 'inherit'}); + await new Promise((resolve, reject) => { + pgAdapter.on("spawn", resolve); + pgAdapter.on("error", reject); + }); + if (!startupOptions?.skipStartupProbe) { + await tcpPortUsed.waitUntilUsed( + options.port || 5432, + startupOptions?.probeRetryMs || 100, + startupOptions?.timeoutMs || 10000); + } + return pgAdapter; +} diff --git a/wrappers/nodejs/tsconfig.json b/wrappers/nodejs/tsconfig.json new file mode 100644 index 0000000000..bf42db879b --- /dev/null +++ b/wrappers/nodejs/tsconfig.json @@ -0,0 +1,12 @@ +{ + "compilerOptions": { + "module": "commonjs", + "target": "es2019", + "declaration": true, + "outDir": "./dist", + "esModuleInterop": true + }, + "include": [ + "src/**/*" + ] +}